use m4ri_rust::friendly::BinMatrix;
use m4ri_rust::friendly::BinVector;
use std::collections::HashSet;
use std::fmt;
use std::mem;
use crate::oracle::Sample;
/// Sample budget used by `BinaryCode::bias`: when a code has more than
/// `1.5 * N` possible words, the bias is estimated from `N` distinct random
/// words instead of enumerating every word.
pub(crate) static N: usize = 10_000;
/// Builds a `BinVector` of exactly `size` bits from the low-order bits of `c`.
///
/// The integer is serialised big-endian, the leading whole bytes that cannot
/// contain any of the requested bits are skipped, and the remaining bytes are
/// handed to `BinVector::from_bytes`. The trailing `size` bits of that vector
/// are returned (bit order within the result is whatever `from_bytes`
/// produces for big-endian input).
///
/// # Panics
///
/// Panics if `size` exceeds the number of bits in a `usize`.
fn usize_to_binvec(c: usize, size: usize) -> BinVector {
    // Derive the word width from the target instead of assuming 64 bits.
    const WORD_BYTES: usize = mem::size_of::<usize>();
    const WORD_BITS: usize = WORD_BYTES * 8;
    assert!(
        size <= WORD_BITS,
        "cannot extract {} bits from a {}-bit usize",
        size,
        WORD_BITS
    );

    // Safe equivalent of transmuting `c.to_be()` into a byte array.
    let bytes = c.to_be_bytes();
    // Number of leading bytes that hold none of the requested bits.
    let skip = (WORD_BITS - size) / 8;
    let mut binvec = BinVector::from_bytes(&bytes[skip..]);
    // Drop the surplus leading bits of the first kept byte; keep the tail.
    let result = BinVector::from(binvec.split_off(((WORD_BYTES - skip) * 8) - size));
    debug_assert_eq!(result.len(), size);
    result
}
/// Common interface for binary codes.
///
/// Implementors provide the code parameters, its matrices and
/// `decode_to_message`; encoding, decoding to a codeword, in-place decoding of
/// packed samples and bias estimation are provided on top of those.
pub trait BinaryCode {
    /// Name of the code.
    fn name(&self) -> String;

    /// Length of an encoded word, in bits.
    fn length(&self) -> usize;

    /// Length of a message (an un-encoded word), in bits.
    fn dimension(&self) -> usize;

    /// Generator matrix; `encode` multiplies the message by it.
    fn generator_matrix(&self) -> &BinMatrix;

    /// Parity check matrix of the code.
    fn parity_check_matrix(&self) -> &BinMatrix;

    /// Decodes `c` to a message and re-encodes that message, yielding the
    /// codeword that `c` decodes to.
    ///
    /// # Errors
    ///
    /// Forwards any error reported by `decode_to_message`.
    fn decode_to_code(&self, c: &BinVector) -> Result<BinVector, &str> {
        self.decode_to_message(c)
            .map(|message| self.encode(&message))
    }

    /// Decodes the word `c` to a message.
    ///
    /// # Errors
    ///
    /// Returns a description of the failure when `c` cannot be decoded.
    fn decode_to_message(&self, c: &BinVector) -> Result<BinVector, &str>;

    /// Encodes the message `c` by multiplying it with the generator matrix.
    ///
    /// In debug builds this checks that `c` has `dimension()` bits and that
    /// the product has `length()` bits.
    fn encode(&self, c: &BinVector) -> BinVector {
        debug_assert_eq!(
            c.len(),
            self.dimension(),
            "Vector to encode should be of length {}",
            self.dimension()
        );
        let codeword = c * self.generator_matrix();
        debug_assert_eq!(
            codeword.len(),
            self.length(),
            "wtf, product should be of length"
        );
        codeword
    }

    /// Decodes the packed blocks of a `Sample` in place.
    ///
    /// When block `length() / 64` is not the block carrying the noise bit,
    /// only blocks `0..=length() / 64` are decoded and the sample is then
    /// truncated to `dimension()`. Otherwise the bits selected by
    /// `NOISE_BIT_MASK` are saved and cleared, the whole slice is decoded,
    /// the noise block is masked down to its low `dimension()` bits and the
    /// saved bits are put back.
    fn decode_sample(&self, c: &mut Sample) {
        use crate::oracle::{NOISE_BIT_BLOCK, NOISE_BIT_MASK};

        let blocks = c.get_sample_mut();
        if NOISE_BIT_BLOCK != self.length() / 64 {
            self.decode_slice(&mut blocks[..=self.length() / 64]);
            c.truncate(self.dimension(), false);
            return;
        }

        // Keep the noise bit out of the decoder's way and restore it afterwards.
        let saved_noise = blocks[NOISE_BIT_BLOCK] & NOISE_BIT_MASK;
        blocks[NOISE_BIT_BLOCK] &= !NOISE_BIT_MASK;
        self.decode_slice(blocks);
        // NOTE(review): this shift assumes `dimension()` is below 64 on this
        // path; confirm against the callers.
        blocks[NOISE_BIT_BLOCK] &= (1 << self.dimension()) - 1;
        blocks[NOISE_BIT_BLOCK] |= saved_noise;
    }

    /// Decodes a word stored as 64-bit blocks, overwriting the leading blocks
    /// of `c` with the storage blocks of the decoded message.
    ///
    /// # Panics
    ///
    /// Panics if `decode_to_message` returns an error.
    fn decode_slice(&self, c: &mut [u64]) {
        let mut word = BinVector::with_capacity(self.length());
        // NOTE(review): the safety contract of `get_storage_mut` is not
        // visible here; this relies on it being sound to append raw blocks to
        // the backing storage directly. Confirm against the m4ri_rust docs.
        let storage = unsafe { word.get_storage_mut() };
        storage.extend(c.iter().map(|&block| block as usize));

        let message = self.decode_to_message(&word).unwrap();
        // `zip` stops at the shorter side, so trailing blocks of `c` beyond
        // the message's storage are left as they were.
        for (block, decoded) in c.iter_mut().zip(message.get_storage().iter()) {
            *block = *decoded as u64;
        }
    }

    /// Estimates the bias introduced by decoding with this code.
    ///
    /// For each examined word the Hamming distance `d` to the codeword it
    /// decodes to is computed; the result is the mean of `delta.powi(d)`.
    /// If the code has more than `1.5 * N` words, `N` distinct random words
    /// are examined, otherwise every word of `length()` bits is.
    ///
    /// Prints a message and returns `0.0` as soon as any word fails to decode.
    fn bias(&self, delta: f64) -> f64 {
        // Distance between `word` and the codeword it decodes to, or `None`
        // when decoding fails.
        let error_weight = |word: &BinVector| -> Option<i32> {
            match self.decode_to_code(word) {
                Ok(codeword) => Some((word + &codeword).count_ones() as i32),
                Err(_) => None,
            }
        };

        let mut distances = Vec::with_capacity(N);
        let sample_randomly = 2f64.powi(self.length() as i32) > 1.5 * N as f64;
        if sample_randomly {
            let mut seen = HashSet::with_capacity(N);
            while seen.len() < N {
                let candidate = BinVector::random(self.length());
                if seen.contains(&candidate) {
                    continue;
                }
                match error_weight(&candidate) {
                    Some(weight) => {
                        distances.push(weight);
                        seen.insert(candidate);
                    }
                    None => {
                        println!("Decoding something failed");
                        return 0.0;
                    }
                }
            }
        } else {
            for i in 0..2usize.pow(self.length() as u32) {
                let word = usize_to_binvec(i, self.length());
                match error_weight(&word) {
                    Some(weight) => distances.push(weight),
                    None => {
                        println!("Decoding something failed");
                        return 0.0;
                    }
                }
            }
        }

        let count = distances.len();
        let mut total = 0f64;
        for dist in distances {
            total += delta.powi(dist);
        }
        total / (count as f64)
    }
}
impl fmt::Debug for dyn BinaryCode {
    /// Writes the code as `[length, dimension] Binary Code`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let (n, k) = (self.length(), self.dimension());
        f.write_fmt(format_args!("[{}, {}] Binary Code", n, k))
    }
}
impl serde::Serialize for &dyn BinaryCode {
    /// Serializes the code as a plain string containing its `name()`.
    fn serialize<S>(&self, ser: S) -> Result<S::Ok, S::Error>
    where
        S: serde::ser::Serializer,
    {
        let label = self.name();
        ser.serialize_str(label.as_str())
    }
}
// Submodules of this module. The public items of every submodule are glob
// re-exported here, so callers can reach them through this module directly.
// Submodules carrying a `#[cfg(feature = "...")]` attribute are only compiled
// (and re-exported) when the corresponding cargo feature is enabled.

// Only with the `hamming` feature.
#[cfg(feature = "hamming")]
mod hamming;
#[cfg(feature = "hamming")]
pub use self::hamming::*;
// Only with the `golay` feature.
#[cfg(feature = "golay")]
mod golay;
#[cfg(feature = "golay")]
pub use self::golay::*;
// Always compiled.
mod concatenated;
pub use self::concatenated::*;
// Only with the `stgen` feature.
#[cfg(feature = "stgen")]
mod stgen;
#[cfg(feature = "stgen")]
pub use self::stgen::*;
// Always compiled.
mod identity;
pub use self::identity::*;
mod repetition;
pub use self::repetition::*;
mod bogosrnd;
pub use self::bogosrnd::*;
// Only with the `mds` feature.
#[cfg(feature = "mds")]
mod mds;
#[cfg(feature = "mds")]
pub use self::mds::*;
// Only with the `custom` feature.
#[cfg(feature = "custom")]
mod custom;
#[cfg(feature = "custom")]
pub use self::custom::*;
// Always compiled.
mod wagner;
pub use self::wagner::*;
mod guava;
pub use self::guava::*;