use std::cmp::min;
/// A linear equation over GF(2) whose coefficients occupy a window of
/// `W * 64` consecutive variables.
///
/// Bit `j` of limb `a[i]` is the coefficient of variable `s + 64 * i + j`.
/// Equations produced by [`Equation::inhomogeneous`], [`Equation::homogeneous`]
/// and [`Equation::add`] are kept normalized: unless every coefficient is
/// zero, bit 0 of `a[0]` is set, so `s` is the index of the first variable
/// that actually occurs. The fields are public, so values built with a struct
/// literal are not necessarily normalized.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Equation<const W: usize> {
    /// Index of the variable that bit 0 of `a[0]` refers to.
    pub s: usize,
    /// Coefficient bits, lowest limb first.
    pub a: [u64; W],
    /// Constant term of the equation; combined by XOR in [`Equation::add`].
    pub b: u8,
}

impl<const W: usize> Equation<W> {
    /// Builds a normalized equation with constant term `0`.
    ///
    /// Equivalent to `Equation::inhomogeneous(s, a, 0)`.
    pub fn homogeneous(s: usize, a: [u64; W]) -> Equation<W> {
        Equation::inhomogeneous(s, a, 0)
    }

    /// Builds a normalized equation from coefficients `a` whose bit 0 refers
    /// to variable `s`, with constant term `b`.
    ///
    /// Leading zero coefficients are stripped and the returned `s` is advanced
    /// by the number of stripped positions. An all-zero `a` is returned as is.
    pub fn inhomogeneous(s: usize, a: [u64; W], b: u8) -> Equation<W> {
        let mut eq = Equation { s, a, b };
        eq.normalize();
        eq
    }

    /// Returns the equation with `s == 0`, all coefficients zero and `b == 0`.
    pub fn zero() -> Self {
        Equation {
            s: 0,
            a: [0u64; W],
            b: 0,
        }
    }

    /// Returns `true` when every coefficient is zero.
    ///
    /// The constant term `b` is not inspected.
    pub fn is_zero(&self) -> bool {
        self.a.iter().all(|&limb| limb == 0)
    }

    /// Adds `other` to `self` over GF(2) (XOR of coefficients and of `b`) and
    /// re-normalizes the result.
    ///
    /// # Panics
    ///
    /// Panics if the two equations do not share the same start index `s`.
    pub fn add(&mut self, other: &Equation<W>) {
        assert_eq!(
            self.s, other.s,
            "equations must start at the same variable index"
        );
        for (lhs, rhs) in self.a.iter_mut().zip(other.a.iter()) {
            *lhs ^= *rhs;
        }
        self.b ^= other.b;
        self.normalize();
    }

    /// Strips leading zero coefficients so that bit 0 of `a[0]` is set, and
    /// advances `s` by the number of stripped positions. All-zero coefficient
    /// arrays are left untouched.
    fn normalize(&mut self) {
        let leading = match self.a.iter().position(|&limb| limb != 0) {
            Some(idx) => idx,
            // No coefficient is set: nothing to align.
            None => return,
        };

        // Drop whole zero limbs. Each one stands for 64 variables, so the
        // start index has to move with them; otherwise the equation would
        // silently refer to different variables afterwards.
        self.a.rotate_left(leading);
        self.s += 64 * leading;

        // `a[0]` is non-zero here, hence `k` is in `0..=63`.
        let k = self.a[0].trailing_zeros();
        if k == 0 {
            return;
        }
        // Multi-limb right shift by `k` bits, pulling the low bits of the
        // next limb into the top of the current one.
        for i in 0..W - 1 {
            self.a[i] = (self.a[i] >> k) | (self.a[i + 1] << (64 - k));
        }
        self.a[W - 1] >>= k;
        self.s += k as usize;
    }

    /// Evaluates the coefficient part of the equation on the bit vector `z`.
    ///
    /// `z` is read as a little-endian bit string (bit `j` of `z[i]` is
    /// variable `64 * i + j`); variables past the end of `z` count as zero.
    /// Returns the parity (0 or 1) of the AND of the coefficients with the
    /// matching bits of `z`. The constant term `b` is not part of the result.
    pub fn eval(&self, z: &[u64]) -> u8 {
        let limb = self.s / 64;
        let shift = self.s % 64;
        let mut acc = 0u64;
        for i in limb..min(z.len(), limb + W) {
            // Assemble the 64 bits of `z` that start at bit `shift` of `z[i]`.
            let mut window = z[i] >> shift;
            // `shift == 0` must be skipped: shifting a u64 by 64 overflows.
            if shift != 0 && i + 1 < z.len() {
                window |= z[i + 1] << (64 - shift);
            }
            acc ^= window & self.a[i - limb];
        }
        (acc.count_ones() & 1) as u8
    }
}
#[cfg(test)]
mod tests {
    use crate::Equation;

    /// `add` XORs coefficients and constant terms, then shifts the leading
    /// zero bit out of the sum and bumps `s` accordingly.
    #[test]
    fn test_equation_add() {
        // Single limb: 0b11 ^ 0b01 = 0b10, which is shifted down by one bit.
        let mut lhs = Equation { s: 127, a: [0b11], b: 1 };
        let rhs = Equation { s: 127, a: [0b01], b: 1 };
        lhs.add(&rhs);
        assert_eq!(lhs.s, 128);
        assert_eq!(lhs.a[0], 0b1);
        assert_eq!(lhs.b, 0);

        // Four limbs: the one-bit shift carries bit 0 of limb 2 into the top
        // bit of limb 1.
        let mut lhs = Equation { s: 127, a: [0b11, 0b1110, 0b1, 0], b: 1 };
        let rhs = Equation { s: 127, a: [0b01, 0b0100, 0b0, 0], b: 1 };
        lhs.add(&rhs);
        assert_eq!(lhs.s, 128);
        assert_eq!(lhs.a[0], 0b1);
        assert_eq!(lhs.a[1], (1u64 << 63) | 0b101);
        assert_eq!(lhs.a[2], 0);
        assert_eq!(lhs.a[3], 0);
        assert_eq!(lhs.b, 0);
    }

    /// With all 64 coefficients of the first limb set, `eval` on a one-hot
    /// input is 1 exactly when the hot bit falls in `s..s + 64`.
    #[test]
    fn test_equation_eval() {
        for s in 0..64 {
            let eq = Equation { s, a: [u64::MAX, 0, 0, 0], b: 0 };
            assert_eq!(eq.eval(&[]), 0);
            for i in 0..64 {
                let hot = 1u64 << i;
                // Hot bit is variable `i`.
                assert_eq!(eq.eval(&[hot, 0]), (i >= eq.s) as u8);
                // Hot bit is variable `64 + i`.
                assert_eq!(eq.eval(&[0, hot]), (i < eq.s) as u8);
                // Hot bit is variable `128 + i`, always outside the window.
                assert_eq!(eq.eval(&[0, 0, hot]), 0);
            }
        }
    }
}