#![allow(clippy::mutex_atomic)]
use crate::codes::BinaryCode;
use m4ri_rust::friendly::BinMatrix;
use m4ri_rust::friendly::BinVector;
use std::cell::UnsafeCell;
use std::ptr;
use std::sync::Mutex;
/// A binary code assembled from several component codes.
///
/// Length and dimension are the sums of the components' lengths and
/// dimensions. The generator matrix is built on first request and cached
/// behind `init` / `generator`.
#[derive(Serialize)]
pub struct ConcatenatedCode<'a> {
    /// Component codes, in the order their blocks are processed.
    codes: Vec<&'a dyn BinaryCode>,
    /// Guards construction of the cached generator matrix; holds `true` once
    /// `generator` points at a built matrix.
    #[serde(skip, default = "default_mutex")]
    init: Mutex<bool>,
    /// Null until the generator matrix is first requested; afterwards a pointer
    /// obtained from `Box::into_raw`.
    #[serde(skip, default = "default_generator")]
    generator: UnsafeCell<*mut BinMatrix>,
}

impl<'a> Drop for ConcatenatedCode<'a> {
    /// Releases the cached generator matrix, if one was ever built.
    ///
    /// Without this the allocation handed to `Box::into_raw` when the matrix is
    /// first built would never be freed.
    fn drop(&mut self) {
        // SAFETY: `&mut self` guarantees exclusive access to the cell, so
        // reading the stored pointer cannot race with anything.
        let raw = unsafe { *self.generator.get() };
        if !raw.is_null() {
            // SAFETY: the only non-null value ever stored in `generator` comes
            // from `Box::into_raw`, it is stored at most once (guarded by
            // `init`), and every reference handed out borrows `self`, so none
            // can outlive this call.
            drop(unsafe { Box::from_raw(raw) });
        }
    }
}
/// Serde default for the `init` field: a fresh mutex whose flag says the
/// generator matrix has not been built yet.
#[allow(dead_code)]
fn default_mutex() -> Mutex<bool> {
    let already_built = false;
    Mutex::new(already_built)
}
#[allow(dead_code)]
fn default_generator() -> UnsafeCell<*mut BinMatrix> {
UnsafeCell::new(ptr::null_mut())
}
// SAFETY: the only interior mutability is the `generator` cell, which is
// written solely while the `init` mutex is held and before its flag is set.
// NOTE(review): this also relies on the component codes and `BinMatrix` being
// safe to share between threads, which is not visible here; confirm.
unsafe impl<'codes> Sync for ConcatenatedCode<'codes> {}
impl<'codes> Clone for ConcatenatedCode<'codes> {
fn clone(&self) -> Self {
ConcatenatedCode {
codes: self.codes.clone(),
init: Mutex::new(false),
generator: UnsafeCell::new(ptr::null_mut()),
}
}
}
impl<'codes> ConcatenatedCode<'codes> {
    /// Creates a concatenated code from `codes`, kept in the given order.
    ///
    /// The generator matrix is not computed here; the cache starts empty.
    pub fn new(codes: Vec<&'codes dyn BinaryCode>) -> ConcatenatedCode<'codes> {
        let init = Mutex::new(false);
        let generator = UnsafeCell::new(ptr::null_mut());
        Self {
            codes,
            init,
            generator,
        }
    }
}
impl<'codes> BinaryCode for ConcatenatedCode<'codes> {
    /// Returns `"[n, k] Concatenated code codes=[...]"`, where `n` and `k` are
    /// this code's length and dimension.
    ///
    /// Every component name is followed by `", "`, including the last one; that
    /// trailing separator is kept so the output stays unchanged for callers.
    fn name(&self) -> String {
        // Query each component once; the names are needed both for sizing the
        // buffer and for filling it.
        let parts: Vec<String> = self.codes.iter().map(|code| code.name()).collect();
        let capacity = parts.iter().map(|part| part.len() + 2).sum::<usize>();
        let mut names = String::with_capacity(capacity);
        for part in &parts {
            names.push_str(part);
            names.push_str(", ");
        }
        format!(
            "[{}, {}] Concatenated code codes=[{}]",
            self.length(),
            self.dimension(),
            names,
        )
    }

    /// Sum of the component codes' lengths.
    fn length(&self) -> usize {
        self.codes.iter().map(|code| code.length()).sum()
    }

    /// Sum of the component codes' dimensions.
    fn dimension(&self) -> usize {
        self.codes.iter().map(|code| code.dimension()).sum()
    }

    /// Returns the block-diagonal generator matrix, building it on first use.
    ///
    /// Each component's generator matrix is placed at the bottom-right corner
    /// of the matrix accumulated so far; the remaining entries are zero.
    ///
    /// # Panics
    ///
    /// Panics if there are no component codes or if the `init` mutex is
    /// poisoned.
    fn generator_matrix(&self) -> &BinMatrix {
        debug_assert_ne!(
            self.codes.len(),
            0,
            "We need at least one code for this to work"
        );
        {
            // The guard is held for the whole construction so the matrix is
            // built and published at most once.
            let mut initialized = self.init.lock().unwrap();
            if !*initialized {
                let mut matrix = self.codes[0].generator_matrix().clone();
                for code in self.codes.iter().skip(1) {
                    // Where the next component's block starts.
                    let (row, col) = (matrix.nrows(), matrix.ncols());
                    // Grow to the right, then downwards, with zero blocks.
                    matrix = matrix.augmented(&BinMatrix::zero(matrix.nrows(), code.length()));
                    matrix = matrix.stacked(&BinMatrix::zero(code.dimension(), matrix.ncols()));
                    matrix.set_window(row, col, code.generator_matrix());
                }
                // SAFETY: the `init` guard is held and the flag is still
                // false, so no other thread is writing the cell and no
                // reference into it has been handed out yet.
                unsafe {
                    *self.generator.get() = Box::into_raw(Box::new(matrix));
                }
                *initialized = true;
            }
        }
        // SAFETY: the block above guarantees the pointer was stored via
        // `Box::into_raw` before the flag was set, and it is never written
        // again, so it is valid for as long as `self` is borrowed.
        unsafe {
            (*self.generator.get())
                .as_ref()
                .expect("generator matrix pointer is set once `init` is true")
        }
    }

    /// Not implemented; always panics.
    fn parity_check_matrix(&self) -> &BinMatrix {
        panic!("Not yet implemented");
    }

    /// Encodes `c` block by block.
    ///
    /// For each component in order, its `dimension()` leading bits are split
    /// off the remaining input, encoded with that component, and appended to
    /// the output.
    fn encode(&self, c: &BinVector) -> BinVector {
        // The output holds one codeword per component, i.e. `length()` bits.
        let mut encoded = BinVector::with_capacity(self.length());
        let mut remaining = c.clone();
        for code in self.codes.iter() {
            let rest = BinVector::from(remaining.split_off(code.dimension()));
            debug_assert_eq!(
                remaining.len(),
                code.dimension(),
                "split dimension incorrect"
            );
            encoded.extend_from_binvec(&code.encode(&remaining));
            remaining = rest;
        }
        encoded
    }

    /// Decodes `c` block by block into the concatenated message.
    ///
    /// For each component in order, its `length()` leading bits are split off
    /// the remaining input and decoded with that component.
    ///
    /// # Errors
    ///
    /// Returns the first error reported by a component's `decode_to_message`.
    fn decode_to_message(&self, c: &BinVector) -> Result<BinVector, &str> {
        let mut decoded = BinVector::with_capacity(self.dimension());
        let mut remaining = c.clone();
        for code in self.codes.iter() {
            let rest = BinVector::from(remaining.split_off(code.length()));
            debug_assert_eq!(remaining.len(), code.length(), "Split length incorrect");
            let message_part = code.decode_to_message(&remaining)?;
            decoded.extend_from_binvec(&message_part);
            remaining = rest;
        }
        Ok(decoded)
    }

    /// Product of the component codes' biases for `delta`.
    fn bias(&self, delta: f64) -> f64 {
        self.codes
            .iter()
            .fold(1f64, |acc, code| acc * code.bias(delta))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codes::hamming::*;
    use m4ri_rust::friendly::BinVector;

    /// Concatenation of `HammingCode7_4` and `HammingCode3_1` used by the
    /// fixed-vector test.
    fn get_code() -> ConcatenatedCode<'static> {
        ConcatenatedCode::new(vec![&HammingCode7_4, &HammingCode3_1])
    }

    /// Checks length, dimension, one known encoding and an encode/decode
    /// round trip.
    #[test]
    fn test_concatenated_code() {
        let code = get_code();
        assert_eq!(code.length(), 7 + 3, "Length wrong");
        assert_eq!(code.dimension(), 4 + 1, "Dimension wrong");

        // Five message bits: 1 0 1 0 1.
        let message = {
            let mut bits = BinVector::from_bytes(&[0b1010_1000]);
            bits.truncate(5);
            bits
        };
        assert_eq!(
            message,
            BinVector::from_bools(&[true, false, true, false, true])
        );

        // Expected ten-bit codeword for that message.
        let expected = {
            let mut bits = BinVector::from_bytes(&[0b1010_1011, 0b1100_0000]);
            bits.truncate(10);
            bits
        };
        assert_eq!(
            *expected,
            *code.encode(&message),
            "encode(input) doesn't match encoded"
        );

        let round_trip = code.decode_to_message(&code.encode(&message)).unwrap();
        assert_eq!(message, round_trip, "not idempotent");
    }

    /// Decodes random words and checks that the decoded message re-encodes to
    /// the decoded codeword, which must differ from the input in fewer than
    /// five positions.
    #[test]
    fn test_random_samples() {
        let code = ConcatenatedCode::new(vec![&HammingCode15_11, &HammingCode7_4, &HammingCode3_1]);
        for _ in 0..100 {
            let received = BinVector::random(code.length());
            let message = code.decode_to_message(&received).unwrap();
            let codeword = code.decode_to_code(&received).unwrap();
            assert_eq!(&code.encode(&message), &codeword);
            assert!((received + codeword).count_ones() < 5);
        }
    }
}