1use crate::{gf2::GF2, linalg, sparse::SparseMatrix};
27use ndarray::{Array1, Array2, ArrayBase, Data, Ix1, s};
28use num_traits::One;
29use thiserror::Error;
30
31mod staircase;
32
/// Error returned when an [`Encoder`] cannot be constructed.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Error)]
pub enum Error {
    /// The square matrix formed by the last columns of the parity check
    /// matrix is not invertible.
    ///
    /// `Encoder::from_h` returns this when `linalg::gauss_reduction` reports
    /// `linalg::Error::NotInvertible`.
    #[error("the square matrix formed by the last columns of the parity check is not invertible")]
    SubmatrixNotInvertible,
}
41
/// Encoder that maps a message to a codeword consisting of the message
/// followed by parity symbols.
///
/// It is built from a parity check matrix with `Encoder::from_h` and used
/// through `Encoder::encode`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Encoder {
    // Encoding strategy and its precomputed matrix, selected in `from_h`.
    encoder: EncoderType,
}
47
/// Encoding strategy used internally by `Encoder`.
#[derive(Debug, Clone, PartialEq, Eq)]
enum EncoderType {
    /// Encoding with a dense matrix obtained through `linalg::gauss_reduction`.
    ///
    /// The parity symbols are computed as `gen_matrix.dot(message)`.
    DenseGenerator { gen_matrix: Array2<GF2> },
    /// Encoding for parity check matrices accepted by `staircase::is_staircase`.
    ///
    /// `gen_matrix` holds only the entries of the parity check matrix that
    /// lie in the message columns (all columns except the last `num_rows`).
    /// The parity symbols are the running sum of the per-row sums of the
    /// message symbols selected by `gen_matrix`.
    Staircase { gen_matrix: SparseMatrix },
}
56
57impl Encoder {
58 pub fn from_h(h: &SparseMatrix) -> Result<Encoder, Error> {
60 let n = h.num_rows();
61 let m = h.num_cols();
62
63 let encoder = if staircase::is_staircase(h) {
64 let mut gen_matrix = SparseMatrix::new(n, m - n);
69 for (j, k) in h.iter_all() {
70 if k < m - n {
71 gen_matrix.insert(j, k);
72 }
73 }
74 EncoderType::Staircase { gen_matrix }
75 } else {
76 let mut a = Array2::zeros((n, m));
82 for (j, k) in h.iter_all() {
83 let t = if k < m - n { k + n } else { k - (m - n) };
84 a[[j, t]] = GF2::one();
85 }
86
87 match linalg::gauss_reduction(&mut a) {
88 Ok(()) => (),
89 Err(linalg::Error::NotInvertible) => return Err(Error::SubmatrixNotInvertible),
90 };
91
92 let gen_matrix = a.slice(s![.., n..]).to_owned();
93 EncoderType::DenseGenerator { gen_matrix }
94 };
95 Ok(Encoder { encoder })
96 }
97
98 pub fn encode<S>(&self, message: &ArrayBase<S, Ix1>) -> Array1<GF2>
100 where
101 S: Data<Elem = GF2>,
102 {
103 let parity = match &self.encoder {
104 EncoderType::DenseGenerator { gen_matrix } => gen_matrix.dot(message),
105 EncoderType::Staircase { gen_matrix } => {
106 let mut parity = Array1::from_iter(
108 (0..gen_matrix.num_rows())
109 .map(|j| gen_matrix.iter_row(j).map(|&k| message[k]).sum()),
110 );
111 for j in 1..parity.len() {
113 let previous = parity[j - 1];
114 parity[j] += previous;
115 }
116 parity
117 }
118 };
119 ndarray::concatenate(ndarray::Axis(0), &[message.view(), parity.view()]).unwrap()
120 }
121}
122
#[cfg(test)]
mod test {
    use super::*;
    use num_traits::Zero;

    /// Encodes two 8-symbol messages with an encoder built from a parity
    /// check matrix given in alist format and compares the resulting
    /// 12-symbol codewords against the expected ones.
    #[test]
    fn encode() {
        let alist = "12 4
3 9
3 3 3 3 3 3 3 3 3 3 3 3
9 9 9 9
1 2 3
1 3 4
2 3 4
2 3 4
1 2 4
1 2 3
1 3 4
1 2 4
1 2 3
2 3 4
1 2 4
1 3 4
1 2 5 6 7 8 9 11 12
1 3 4 5 6 8 9 10 11
1 2 3 4 6 7 9 10 12
2 3 4 5 7 8 10 11 12
";
        let parity_check = SparseMatrix::from_alist(alist).unwrap();
        let encoder = Encoder::from_h(&parity_check).unwrap();
        let (i, o) = (GF2::one(), GF2::zero());

        // (message, expected codeword) pairs.
        let cases = [
            ([i, o, i, i, o, o, i, o], [i, o, i, i, o, o, i, o, i, o, o, i]),
            ([o, i, o, o, i, i, i, o], [o, i, o, o, i, i, i, o, i, o, i, o]),
        ];
        for (message, expected) in cases.iter() {
            let codeword = encoder.encode(&ndarray::arr1(message));
            assert_eq!(codeword.as_slice().unwrap(), &expected[..]);
        }
    }

    /// Checks that the staircase encoder is selected for this parity check
    /// matrix and that it produces the expected 5-symbol codewords for two
    /// 2-symbol messages.
    #[test]
    fn encode_staircase() {
        let alist = "5 3
2 4
2 2 2 2 1
2 4 4
1 3
2 3
1 2
2 3
3
1 3
2 3 4
1 2 4 5
";
        let parity_check = SparseMatrix::from_alist(alist).unwrap();
        let encoder = Encoder::from_h(&parity_check).unwrap();
        let uses_staircase = matches!(encoder.encoder, EncoderType::Staircase { .. });
        assert!(uses_staircase);
        let (i, o) = (GF2::one(), GF2::zero());

        // (message, expected codeword) pairs.
        let cases = [([i, o], [i, o, i, i, o]), ([o, i], [o, i, o, i, o])];
        for (message, expected) in cases.iter() {
            let codeword = encoder.encode(&ndarray::arr1(message));
            assert_eq!(codeword.as_slice().unwrap(), &expected[..]);
        }
    }
}