use crate::{gf2::GF2, linalg, sparse::SparseMatrix};
use ndarray::{Array1, Array2, ArrayBase, Data, Ix1, s};
use num_traits::One;
use thiserror::Error;
mod staircase;
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Error)]
pub enum Error {
#[error("the square matrix formed by the last columns of the parity check is not invertible")]
SubmatrixNotInvertible,
}
/// Systematic encoder for a binary linear code given by a parity check matrix.
///
/// Construct it with [`Encoder::from_h`]. Then use [`Encoder::encode`] to turn
/// a message into a codeword made of the message bits followed by the parity bits.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Encoder {
    /// Encoding strategy chosen when the encoder was built.
    encoder: EncoderType,
}
/// Internal representation of the encoding strategy picked by [`Encoder::from_h`].
#[derive(Debug, Clone, Eq, PartialEq)]
enum EncoderType {
    /// Dense matrix mapping the message bits directly to the parity bits
    /// (`parity = gen_matrix · message`).
    DenseGenerator { gen_matrix: Array2<GF2> },
    /// Sparse message part of a staircase parity check matrix. The parity bits
    /// are the running (prefix) sum of `gen_matrix · message`.
    Staircase { gen_matrix: SparseMatrix },
}
impl Encoder {
pub fn from_h(h: &SparseMatrix) -> Result<Encoder, Error> {
let n = h.num_rows();
let m = h.num_cols();
let encoder = if staircase::is_staircase(h) {
let mut gen_matrix = SparseMatrix::new(n, m - n);
for (j, k) in h.iter_all() {
if k < m - n {
gen_matrix.insert(j, k);
}
}
EncoderType::Staircase { gen_matrix }
} else {
let mut a = Array2::zeros((n, m));
for (j, k) in h.iter_all() {
let t = if k < m - n { k + n } else { k - (m - n) };
a[[j, t]] = GF2::one();
}
match linalg::gauss_reduction(&mut a) {
Ok(()) => (),
Err(linalg::Error::NotInvertible) => return Err(Error::SubmatrixNotInvertible),
};
let gen_matrix = a.slice(s![.., n..]).to_owned();
EncoderType::DenseGenerator { gen_matrix }
};
Ok(Encoder { encoder })
}
pub fn encode<S>(&self, message: &ArrayBase<S, Ix1>) -> Array1<GF2>
where
S: Data<Elem = GF2>,
{
let parity = match &self.encoder {
EncoderType::DenseGenerator { gen_matrix } => gen_matrix.dot(message),
EncoderType::Staircase { gen_matrix } => {
let mut parity = Array1::from_iter(
(0..gen_matrix.num_rows())
.map(|j| gen_matrix.iter_row(j).map(|&k| message[k]).sum()),
);
for j in 1..parity.len() {
let previous = parity[j - 1];
parity[j] += previous;
}
parity
}
};
ndarray::concatenate(ndarray::Axis(0), &[message.view(), parity.view()]).unwrap()
}
}
#[cfg(test)]
mod test {
    use super::*;
    use num_traits::Zero;

    /// Encodes `message` with `encoder` and asserts the codeword equals `expected`.
    fn check_codeword(encoder: &Encoder, message: &[GF2], expected: &[GF2]) {
        let codeword = encoder.encode(&ndarray::arr1(message));
        assert_eq!(codeword.as_slice().unwrap(), expected);
    }

    #[test]
    fn encode() {
        let alist = "12 4
3 9
3 3 3 3 3 3 3 3 3 3 3 3
9 9 9 9
1 2 3
1 3 4
2 3 4
2 3 4
1 2 4
1 2 3
1 3 4
1 2 4
1 2 3
2 3 4
1 2 4
1 3 4
1 2 5 6 7 8 9 11 12
1 3 4 5 6 8 9 10 11
1 2 3 4 6 7 9 10 12
2 3 4 5 7 8 10 11 12
";
        let h = SparseMatrix::from_alist(alist).unwrap();
        let encoder = Encoder::from_h(&h).unwrap();
        let (i, o) = (GF2::one(), GF2::zero());
        check_codeword(
            &encoder,
            &[i, o, i, i, o, o, i, o],
            &[i, o, i, i, o, o, i, o, i, o, o, i],
        );
        check_codeword(
            &encoder,
            &[o, i, o, o, i, i, i, o],
            &[o, i, o, o, i, i, i, o, i, o, i, o],
        );
    }

    #[test]
    fn encode_staircase() {
        let alist = "5 3
2 4
2 2 2 2 1
2 4 4
1 3
2 3
1 2
2 3
3
1 3
2 3 4
1 2 4 5
";
        let h = SparseMatrix::from_alist(alist).unwrap();
        let encoder = Encoder::from_h(&h).unwrap();
        // The staircase structure must be detected, not just encoded correctly.
        assert!(matches!(encoder.encoder, EncoderType::Staircase { .. }));
        let (i, o) = (GF2::one(), GF2::zero());
        check_codeword(&encoder, &[i, o], &[i, o, i, i, o]);
        check_codeword(&encoder, &[o, i], &[o, i, o, i, o]);
    }
}