use std::boxed::Box;
use std::default::Default;
use std::sync::Once;
use fnv::FnvHashMap;
use m4ri_rust::friendly::BinMatrix;
use m4ri_rust::friendly::BinVector;
use crate::codes::BinaryCode;
/// ``[{{n}}, {{k}}]`` {{ name }} code
///{% if comment %}
/// {{ comment }}{% endif %}
///
/// Decodes using syndrome decoding: the syndrome of the received word is
/// looked up in a precomputed table that maps each syndrome to the error
/// pattern to add to the word.
///
/// The type itself carries no data. The generator matrix, the parity-check
/// matrix (and its transpose) and the syndrome table live in module-level
/// statics that `init` fills in the first time any of them is needed.
#[derive(Clone, Serialize)]
pub struct {{ name }}Code{{n}}_{{k}};
/// Guards the one-time construction of the tables below (see `init`).
static INIT: Once = Once::new();

// Every pointer starts out null; `init` replaces each one with a pointer
// obtained from `Box::into_raw`.

/// Generator matrix of the code.
static mut GENERATOR_MATRIX: *const BinMatrix = std::ptr::null();
/// Parity-check matrix of the code.
static mut PARITY_MATRIX: *const BinMatrix = std::ptr::null();
/// Transpose of the parity-check matrix, used to compute syndromes.
static mut PARITY_MATRIX_T: *const BinMatrix = std::ptr::null();
/// Maps a syndrome to the storage words of the error pattern to correct.
static mut SYNDROME_MAP: *const FnvHashMap<u64, &'static [usize; {{ syndrome_map_itemlen }}]> =
    std::ptr::null();
/// Builds the generator matrix, the parity-check matrix, its transpose and
/// the syndrome table, and publishes them through the module-level statics.
///
/// The construction is wrapped in `INIT.call_once`, so only the first call
/// builds anything. Every value is boxed and handed over with
/// `Box::into_raw`; nothing in this function frees those allocations.
fn init() {
    INIT.call_once(|| {
        let generator = Box::new(BinMatrix::from_slices(&[
            {% for row in generator %}&[ {{ row|intlist }} ],
            {% endfor %}
        ], {{ n }}));

        let parity = Box::new(BinMatrix::from_slices(&[
            {% for row in parity_matrix %}&[ {{ row|intlist }} ],
            {% endfor %}
        ], {{ n }}));
        let parity_t = Box::new(parity.transposed());

        // One entry per syndrome; the value holds the storage words of the
        // error pattern associated with that syndrome.
        let mut syndromes = Box::new(FnvHashMap::with_capacity_and_hasher({{ syndrome_map|length }}, Default::default()));
        {% for (he, e) in syndrome_map.items() %}syndromes.insert({{ he }}, &[{{ e|intlist }}]); // {{ he }} => {{ e }}
        {% endfor %}

        // SAFETY: these statics are only written here, and `call_once` runs
        // this closure a single time while other callers of `call_once` wait
        // for it to finish, so no read can overlap with these writes.
        unsafe {
            GENERATOR_MATRIX = Box::into_raw(generator);
            PARITY_MATRIX = Box::into_raw(parity);
            PARITY_MATRIX_T = Box::into_raw(parity_t);
            SYNDROME_MAP = Box::into_raw(syndromes);
        }
    });
}
impl {{ name }}Code{{n}}_{{k}} {
    /// Returns the transposed parity-check matrix, building the static
    /// tables first if that has not happened yet.
    fn parity_check_matrix_transposed(&self) -> &BinMatrix {
        init();
        // SAFETY: `init` has returned, so the pointer was set from
        // `Box::into_raw`; nothing in this file frees or rewrites it later.
        let transposed = unsafe { PARITY_MATRIX_T.as_ref() };
        transposed.unwrap()
    }
}
impl BinaryCode for {{ name }}Code{{n}}_{{k}} {
    /// Returns the human-readable name of this code.
    fn name(&self) -> String {
        "[{{ n }}, {{ k }}] {{ name }} code".to_owned()
    }

    /// Returns the length `n` of the code.
    fn length(&self) -> usize {
        {{ n }}
    }

    /// Returns the dimension `k` of the code.
    fn dimension(&self) -> usize {
        {{ k }}
    }

    /// Returns the generator matrix, building the static tables on first use.
    fn generator_matrix(&self) -> &BinMatrix {
        init();
        // SAFETY: `init` has returned, so the pointer was set from
        // `Box::into_raw`; nothing in this file frees or rewrites it later.
        unsafe {
            GENERATOR_MATRIX.as_ref().unwrap()
        }
    }

    /// Returns the parity-check matrix, building the static tables on first use.
    fn parity_check_matrix(&self) -> &BinMatrix {
        init();
        // SAFETY: same reasoning as in `generator_matrix`.
        unsafe {
            PARITY_MATRIX.as_ref().unwrap()
        }
    }

    /// Decodes the received word `c` to a codeword by syndrome lookup.
    ///
    /// The syndrome `c * H^T` is used as the key into the syndrome map and the
    /// error pattern stored for it is added to `c`.
    ///
    /// # Errors
    ///
    /// Returns an error when the syndrome of `c` has no entry in the syndrome
    /// map.
    fn decode_to_code(&self, c: &BinVector) -> Result<BinVector, &str> {
        init();
        // SAFETY: `init` has returned, so the pointer was set from
        // `Box::into_raw`; nothing in this file frees or rewrites it later.
        let map = unsafe {
            SYNDROME_MAP.as_ref().unwrap()
        };
        debug_assert_eq!(c.len(), self.length(), "the length doesn't match the expected length (length of the code)");

        let he = c * self.parity_check_matrix_transposed();
        // Report an unknown syndrome through the `Result` instead of
        // panicking on a missing key.
        let errbytes = match map.get(&he.as_u64()) {
            Some(&pattern) => pattern,
            None => return Err("the syndrome of this word is not in the syndrome map"),
        };
        debug_assert_eq!(errbytes.len(), {{ n }} / 64 + if {{ n }} % 64 != 0 { 1 } else { 0 });

        // Build the error vector directly from the stored storage words.
        let mut error = BinVector::with_capacity({{ n }});
        // NOTE(review): presumably sound because the storage ends up holding
        // exactly the words needed for {{ n }} bits (checked by the debug
        // assertion above) before the length is set; confirm against the
        // contract of `get_storage_mut` / `set_len` in m4ri_rust.
        let stor = unsafe { error.get_storage_mut() };
        stor.clear();
        stor.extend_from_slice(&errbytes[..]);
        unsafe { error.set_len({{ n }}) };
        debug_assert_eq!(error.len(), self.length(), "internal: the error vector is of the wrong length");

        let result = c + &error;
        debug_assert_eq!(result.len(), self.length(), "internal: the result vector is of the wrong length");
        // The corrected word must have an all-zero syndrome.
        debug_assert_eq!((&result * self.parity_check_matrix_transposed()).count_ones(), 0);
        Ok(result)
    }

    /// Decodes the received word `c` and extracts the `k` message bits from
    /// the resulting codeword.
    ///
    /// # Errors
    ///
    /// Propagates the error of `decode_to_code`.
    fn decode_to_message(&self, c: &BinVector) -> Result<BinVector, &str> {
        {% if info_set|max == k - 1 %}
        // The largest information-set index is k - 1, so the message is taken
        // as the first k bits of the codeword.
        let mut codeword = self.decode_to_code(c)?;
        codeword.truncate({{ k }});
        Ok(codeword)
        {% else %}
        // Collect the message bits from the information-set positions.
        let codeword = self.decode_to_code(c)?;
        let mut new_codeword = BinVector::with_capacity({{k}});
        {% for idx in info_set %}new_codeword.push(codeword[{{idx}}]);
        {% endfor %}
        Ok(new_codeword)
        {% endif %}
    }

    /// Decodes the word stored in the `u64` words of `c` in place by XOR-ing
    /// the error pattern of its syndrome into it.
    ///
    /// # Panics
    ///
    /// Panics when the syndrome has no entry in the syndrome map. In debug
    /// builds it also panics when bits beyond the code length are set in the
    /// word that holds bit `n`.
    fn decode_slice(&self, c: &mut [u64]) {
        init();
        // `get` rather than indexing: when the length is a multiple of 64,
        // word `n / 64` may lie past the end of `c`, in which case there is
        // nothing to check.
        debug_assert_eq!(
            c.get({{ n }} / 64).map_or(0, |&word| word & !((1u64 << {{ n % 64 }}) - 1)),
            0,
            "this message has excess bits"
        );
        // SAFETY: `init` has returned, so the pointer was set from
        // `Box::into_raw`; nothing in this file frees or rewrites it later.
        let map = unsafe {
            SYNDROME_MAP.as_ref().unwrap()
        };
        let he = &BinMatrix::from_slices(&[&c[..]], self.length()) * self.parity_check_matrix_transposed();
        // NOTE(review): assumes the whole syndrome fits in the first word of
        // the 1-row product and that (0, 0) is always a valid position;
        // confirm against `get_word_unchecked` in m4ri_rust.
        let syndrome = unsafe { he.get_word_unchecked(0, 0) };
        let error = map[&syndrome];
        for (word, flips) in c.iter_mut().zip(error.iter().copied()) {
            *word ^= flips as u64;
        }
    }
    {% if name == "Hamming" or (name == "Golay" and n == 23) %}
    /// We know how to give the bias directly for this code: it is computed
    /// from `delta` with a closed-form expression.
    fn bias(&self, delta: f64) -> f64 {
        {% if name == "Hamming" %}
        (1f64 + f64::from({{ n }}) * delta) / f64::from({{ n }} + 1)
        {% elif name == "Golay" and n == 23%}
        (1.0 + 23.0 * delta + 253.0 * delta.powi(2) + 1771.0 * delta.powi(3))/2f64.powi(11)
        {% endif %}
    }
    {% endif %}
}
#[cfg(test)]
mod tests {
    use super::*;
    use m4ri_rust::friendly::BinVector;
    use crate::oracle::Sample;

    /// The generator matrix has `n` columns and `k` rows.
    #[test]
    fn size() {
        let generator = {{ name }}Code{{n}}_{{k}}.generator_matrix();
        assert_eq!(generator.ncols(), {{n}});
        assert_eq!(generator.nrows(), {{k}});
    }

    /// Decoding a `Sample` must agree with decoding the vector it was built
    /// from, and must leave the sample's product untouched.
    #[test]
    fn test_decode_sample() {
        let code = {{ name }}Code{{n}}_{{k}};
        for _ in 0..1000 {
            // setup: one random word, wrapped once per product value
            let received = BinVector::random(code.length());
            let mut actual_false = Sample::from_binvector(&received, false);
            let mut actual_true = Sample::from_binvector(&received, true);

            let message = code.decode_to_message(&received).unwrap();
            println!("decoded_vec: {:?}", message);

            // reference samples built from the decoded message
            let expected_false = Sample::from_binvector(&message, false);
            let expected_true = Sample::from_binvector(&message, true);

            code.decode_sample(&mut actual_false);
            code.decode_sample(&mut actual_true);

            assert_eq!(actual_false.get_product(), false);
            assert_eq!(actual_true.get_product(), true);
            assert_eq!(actual_false, expected_false);
            assert_eq!(actual_true, expected_true);
        }
    }

    /// Checks `decode_to_code` against the test vectors rendered into this file.
    #[test]
    fn random_decode_tests() {
        {% for testcase in testcases %}
        {
            let code = {{ name}}Code{{n}}_{{ k }};
            let received = BinVector::from_bools(&[{{ testcase.randvec|boollist }}]);
            let expected = BinVector::from_bools(&[{{ testcase.codeword|boollist }}]);
            assert_eq!(code.decode_to_code(&received), Ok(expected));
        }
        {% endfor %}
    }

    /// The first row of the stored generator matrix matches the first row
    /// rendered into this file.
    #[test]
    fn test_generator_representation() {
        init();
        let generator = unsafe { GENERATOR_MATRIX.as_ref().unwrap() };
        let expected_row = BinVector::from_bools(&[ {{ generator_bools[0]|boollist }} ]);
        let top_row = generator.get_window(0, 0, 1, generator.ncols());
        assert_eq!(expected_row, top_row.as_vector());
    }
}