use ndarray::{Array1, ArrayBase, Data, Ix1, s};
use thiserror::Error;
/// Block-wise puncturer.
///
/// A codeword is split into `pattern.len()` contiguous blocks of equal size.
/// A `true` entry in the pattern marks a block that is kept by puncturing;
/// a `false` entry marks a block that is removed.
#[derive(Debug, Clone)]
pub struct Puncturer {
    /// Puncturing pattern, one entry per codeword block (`true` = kept).
    pattern: Box<[bool]>,
    /// Number of `true` entries in `pattern`, cached at construction.
    num_trues: usize,
}
/// Errors returned by [`Puncturer`] operations.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Error)]
pub enum Error {
    /// The input length cannot be split into equal-size blocks.
    ///
    /// When puncturing, the codeword length must be a multiple of the
    /// pattern length. When depuncturing, the input length must be a
    /// multiple of the number of `true` pattern entries.
    #[error("codeword size not divisible by puncturing pattern length")]
    CodewordSizeNotDivisible,
}
impl Puncturer {
pub fn new(pattern: &[bool]) -> Puncturer {
assert!(!pattern.is_empty());
Puncturer {
pattern: pattern.into(),
num_trues: pattern.iter().filter(|&&b| b).count(),
}
}
pub fn puncture<S, A>(&self, codeword: &ArrayBase<S, Ix1>) -> Result<Array1<A>, Error>
where
S: Data<Elem = A>,
A: Clone,
{
let pattern_len = self.pattern.len();
let codeword_len = codeword.shape()[0];
if !codeword_len.is_multiple_of(pattern_len) {
return Err(Error::CodewordSizeNotDivisible);
}
let block_size = codeword_len / pattern_len;
let output_size = block_size * self.num_trues;
let mut out = Array1::uninit(output_size);
for (j, k) in self
.pattern
.iter()
.enumerate()
.filter_map(|(k, &b)| if b { Some(k) } else { None })
.enumerate()
{
codeword
.slice(s![k * block_size..(k + 1) * block_size])
.assign_to(out.slice_mut(s![j * block_size..(j + 1) * block_size]));
}
Ok(unsafe { out.assume_init() })
}
pub fn depuncture<T: Copy + Default>(&self, llrs: &[T]) -> Result<Vec<T>, Error> {
if !llrs.len().is_multiple_of(self.num_trues) {
return Err(Error::CodewordSizeNotDivisible);
}
let block_size = llrs.len() / self.num_trues;
let output_size = self.pattern.len() * block_size;
let mut output = vec![T::default(); output_size];
for (j, k) in self
.pattern
.iter()
.enumerate()
.filter_map(|(k, &b)| if b { Some(k) } else { None })
.enumerate()
{
output[k * block_size..(k + 1) * block_size]
.copy_from_slice(&llrs[j * block_size..(j + 1) * block_size]);
}
Ok(output)
}
pub fn rate(&self) -> f64 {
self.pattern.len() as f64 / self.num_trues as f64
}
}
#[cfg(test)]
mod test {
    use super::*;
    use ndarray::array;

    /// Puncturing keeps the `true` blocks, and depuncturing puts them back
    /// with zeros in the punctured positions.
    #[test]
    fn puncturing() {
        let puncturer = Puncturer::new(&[true, true, false, true, false]);
        let codeword = array![0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        let punctured = puncturer.puncture(&codeword).unwrap();
        let expected = array![0, 1, 2, 3, 6, 7];
        assert_eq!(&punctured, &expected);
        let llrs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let llrs_out = puncturer.depuncture(&llrs).unwrap();
        let expected = [1.0, 2.0, 3.0, 4.0, 0.0, 0.0, 5.0, 6.0, 0.0, 0.0];
        assert_eq!(&llrs_out, &expected);
    }

    /// A codeword whose length is not a multiple of the pattern length is
    /// rejected.
    #[test]
    fn puncture_rejects_indivisible_codeword() {
        let puncturer = Puncturer::new(&[true, true, false, true, false]);
        let codeword = array![0, 1, 2, 3, 4, 5, 6];
        assert_eq!(
            puncturer.puncture(&codeword).unwrap_err(),
            Error::CodewordSizeNotDivisible
        );
    }

    /// An LLR slice whose length is not a multiple of the number of kept
    /// blocks is rejected.
    #[test]
    fn depuncture_rejects_indivisible_input() {
        let puncturer = Puncturer::new(&[true, true, false, true, false]);
        assert_eq!(
            puncturer.depuncture(&[1.0, 2.0, 3.0, 4.0]).unwrap_err(),
            Error::CodewordSizeNotDivisible
        );
    }

    /// Puncturing followed by depuncturing restores the kept blocks.
    #[test]
    fn puncture_depuncture_round_trip() {
        let puncturer = Puncturer::new(&[true, false]);
        let codeword = array![1, 2, 3, 4];
        let punctured = puncturer.puncture(&codeword).unwrap();
        assert_eq!(punctured, array![1, 2]);
        let restored = puncturer
            .depuncture(punctured.as_slice().unwrap())
            .unwrap();
        assert_eq!(restored, vec![1, 2, 0, 0]);
    }

    /// The rate factor is pattern length over number of kept blocks.
    #[test]
    fn rate() {
        let puncturer = Puncturer::new(&[true, true, false, true, false]);
        assert_eq!(puncturer.rate(), 5.0 / 3.0);
    }
}