ldpc_toolbox/simulation/puncturing.rs

1//! Code puncturing.
2
3use ndarray::{Array1, ArrayBase, Data, Ix1, s};
4use thiserror::Error;
5
/// Puncturer.
///
/// This struct is used to perform puncturing on codewords to be transmitted,
/// and "depuncturing" on demodulated LLRs.
///
/// Two puncturers compare equal (and hash equally) when they have the same
/// puncturing pattern.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Puncturer {
    /// Puncturing pattern: entry `k` tells whether block `k` of the codeword
    /// is kept (`true`) or punctured (`false`).
    pattern: Box<[bool]>,
    /// Number of `true` entries in `pattern`, cached so that the punctured
    /// length can be computed without rescanning the pattern.
    num_trues: usize,
}
15
16/// Puncturer error.
17#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Error)]
18pub enum Error {
19    /// The codeword size is not divisible by the puncturing pattern length
20    #[error("codeword size not divisible by puncturing pattern length")]
21    CodewordSizeNotDivisible,
22}
23
24impl Puncturer {
25    /// Creates a new puncturer.
26    ///
27    /// The puncturing pattern is defined by blocks. For example `[true, true,
28    /// true, false]` means that the first 3/4 of the codeword bits are
29    /// preserved, and the last 1/4 is punctured.
30    ///
31    /// # Panics
32    ///
33    /// This function panics if the pattern is empty.
34    pub fn new(pattern: &[bool]) -> Puncturer {
35        assert!(!pattern.is_empty());
36        Puncturer {
37            pattern: pattern.into(),
38            num_trues: pattern.iter().filter(|&&b| b).count(),
39        }
40    }
41
42    /// Puncture a codeword.
43    ///
44    /// Given a codeword, returns the punctured codeword. An error is returned
45    /// if the length of the codeword is not divisible by the length of the
46    /// puncturing pattern.
47    pub fn puncture<S, A>(&self, codeword: &ArrayBase<S, Ix1>) -> Result<Array1<A>, Error>
48    where
49        S: Data<Elem = A>,
50        A: Clone,
51    {
52        let pattern_len = self.pattern.len();
53        let codeword_len = codeword.shape()[0];
54        if codeword_len % pattern_len != 0 {
55            return Err(Error::CodewordSizeNotDivisible);
56        }
57        let block_size = codeword_len / pattern_len;
58        let output_size = block_size * self.num_trues;
59        let mut out = Array1::uninit(output_size);
60        for (j, k) in self
61            .pattern
62            .iter()
63            .enumerate()
64            .filter_map(|(k, &b)| if b { Some(k) } else { None })
65            .enumerate()
66        {
67            codeword
68                .slice(s![k * block_size..(k + 1) * block_size])
69                .assign_to(out.slice_mut(s![j * block_size..(j + 1) * block_size]));
70        }
71        // Safety: all the elements of out have been assigned by the loop above.
72        Ok(unsafe { out.assume_init() })
73    }
74
75    /// Depuncture LLRs.
76    ///
77    /// This function depunctures demodulated LLRs by inserting zeros (which
78    /// indicate erasures) in the positions of the codeword that were
79    /// punctured. The input length must correspond to the punctured codeword,
80    /// while the output length is equal to the codeword length. An error is
81    /// returned if the length of input is not divisible by the number of `true`
82    /// elements in the pattern.
83    pub fn depuncture<T: Copy + Default>(&self, llrs: &[T]) -> Result<Vec<T>, Error> {
84        if llrs.len() % self.num_trues != 0 {
85            return Err(Error::CodewordSizeNotDivisible);
86        }
87        let block_size = llrs.len() / self.num_trues;
88        let output_size = self.pattern.len() * block_size;
89        let mut output = vec![T::default(); output_size];
90        for (j, k) in self
91            .pattern
92            .iter()
93            .enumerate()
94            .filter_map(|(k, &b)| if b { Some(k) } else { None })
95            .enumerate()
96        {
97            output[k * block_size..(k + 1) * block_size]
98                .copy_from_slice(&llrs[j * block_size..(j + 1) * block_size]);
99        }
100        Ok(output)
101    }
102
103    /// Returns the rate of the puncturer.
104    ///
105    /// The rate is defined as the length of the original codeword divided by
106    /// the length of the punctured codeword, and so it is always greater or
107    /// equal to one.
108    pub fn rate(&self) -> f64 {
109        self.pattern.len() as f64 / self.num_trues as f64
110    }
111}
112
#[cfg(test)]
mod test {
    use super::*;
    use ndarray::array;

    /// Pattern shared by most tests: 5 blocks, 3 of which are kept.
    const PATTERN: [bool; 5] = [true, true, false, true, false];

    #[test]
    fn puncturing() {
        let puncturer = Puncturer::new(&PATTERN);
        let codeword = array![0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        let punctured = puncturer.puncture(&codeword).unwrap();
        let expected = array![0, 1, 2, 3, 6, 7];
        assert_eq!(&punctured, &expected);
        let llrs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let llrs_out = puncturer.depuncture(&llrs).unwrap();
        let expected = [1.0, 2.0, 3.0, 4.0, 0.0, 0.0, 5.0, 6.0, 0.0, 0.0];
        assert_eq!(&llrs_out, &expected);
    }

    #[test]
    fn puncture_rejects_indivisible_length() {
        let puncturer = Puncturer::new(&PATTERN);
        // 7 is not a multiple of the pattern length (5).
        let codeword = array![0, 1, 2, 3, 4, 5, 6];
        assert_eq!(
            puncturer.puncture(&codeword).unwrap_err(),
            Error::CodewordSizeNotDivisible
        );
    }

    #[test]
    fn depuncture_rejects_indivisible_length() {
        let puncturer = Puncturer::new(&PATTERN);
        // 5 is not a multiple of the number of kept blocks (3).
        let llrs = [1.0, 2.0, 3.0, 4.0, 5.0];
        assert_eq!(
            puncturer.depuncture(&llrs).unwrap_err(),
            Error::CodewordSizeNotDivisible
        );
    }

    #[test]
    fn depuncture_restores_kept_positions() {
        let puncturer = Puncturer::new(&[false, true, true, false]);
        let codeword = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let punctured = puncturer.puncture(&codeword).unwrap();
        assert_eq!(punctured, array![3.0, 4.0, 5.0, 6.0]);
        let restored = puncturer
            .depuncture(punctured.as_slice().unwrap())
            .unwrap();
        assert_eq!(restored, vec![0.0, 0.0, 3.0, 4.0, 5.0, 6.0, 0.0, 0.0]);
    }

    #[test]
    fn rate() {
        let puncturer = Puncturer::new(&PATTERN);
        assert_eq!(puncturer.rate(), 5.0 / 3.0);
    }

    #[test]
    #[should_panic]
    fn empty_pattern_panics() {
        let _ = Puncturer::new(&[]);
    }
}