use crate::oracle::{query_bits_range, LpnOracle, Sample};
use itertools::Itertools;
use m4ri_rust::friendly::BinMatrix;
use m4ri_rust::friendly::BinVector;
use rayon::prelude::*;
use std::mem;
use std::ops;
/// Converts the `size` least-significant bits of `c` into a [`BinVector`].
///
/// The most significant of the retained bits comes first, so
/// `usize_to_binvec(2, 4)` yields `[false, false, true, false]`.
///
/// # Panics
///
/// Panics if `size` is larger than the bit width of `usize`.
#[inline]
fn usize_to_binvec(c: usize, size: usize) -> BinVector {
    // Derive the word width instead of assuming 64 bits, so this also works on
    // 32-bit targets.
    let usize_bits = mem::size_of::<usize>() * 8;
    assert!(
        size <= usize_bits,
        "can't take {} bits from a {}-bit usize",
        size,
        usize_bits
    );
    // Big-endian puts the most significant byte first, which is the bit order
    // we want in the resulting vector. No `unsafe` transmute needed.
    let bytes = c.to_be_bytes();
    // Skip leading bytes that contain none of the requested bits.
    let skip = (usize_bits - size) / 8;
    let mut binvec = BinVector::from_bytes(&bytes[skip..]);
    // `binvec` holds `(bytes.len() - skip) * 8 >= size` bits; split off the
    // surplus high-order bits so exactly the low `size` bits remain.
    let result = BinVector::from(binvec.split_off((bytes.len() - skip) * 8 - size));
    debug_assert_eq!(result.len(), size);
    result
}
/// Finds the most likely `b`-bit secret for the oracle's samples by exhaustive
/// search (the LF1 solving step).
///
/// Every candidate `s` in `{0,1}^b` is scored by `n' - 2 * wt(A·s + c)`. This
/// includes the all-zero vector, which is a valid secret. The rows of `A` are
/// the sample vectors `a`, `c` collects the sample labels, and `n'` is the
/// number of samples. The highest-scoring candidate is returned as a `b`-bit
/// vector.
///
/// # Panics
///
/// Panics if the oracle holds no samples, or if the sample length `b` is 0 or
/// at least 21 (the search is exponential in `b`).
pub fn lf1_solve(oracle: LpnOracle) -> BinVector {
    let n_prime = oracle.samples.len();
    assert!(n_prime > 0, "What, no samples?");
    let b = oracle.samples[0].a.len();
    assert!(b < 21, "Don't use too large b! b = {}", b);
    assert!(b > 0, "Wtf, b = 0?");
    // Split the samples into the matrix of `a` rows and the label vector `c`.
    let (a_matrix, c) = {
        let mut c = BinVector::with_capacity(n_prime);
        (
            BinMatrix::new(
                oracle
                    .samples
                    .into_iter()
                    .map(|q| {
                        c.push(q.c);
                        q.a
                    })
                    .collect(),
            ),
            c,
        )
    };
    // Score a candidate. Each sample whose label agrees with <a, s> adds 1 and
    // each disagreement subtracts 1. `i64` avoids overflow for huge sample
    // counts.
    let computation = |candidate: usize| {
        let candidate_vector = usize_to_binvec(candidate, b);
        let mut matrix_vector_product: BinVector = &a_matrix * &candidate_vector;
        matrix_vector_product += &c;
        let hw = matrix_vector_product.count_ones();
        n_prime as i64 - 2 * (hw as i64)
    };
    println!("Doing LF1 naively");
    let max = 2usize.pow(b as u32);
    // Search the whole candidate space, starting at 0.
    let best_candidate = (0..max)
        .into_par_iter()
        .max_by_key(|candidate| computation(*candidate))
        .expect("Can't work on an empty list");
    println!("Best candidate weight: {}", best_candidate.count_ones());
    usize_to_binvec(best_candidate, b)
}
pub fn xor_reduce(oracle: &mut LpnOracle, b: u32) {
let k = oracle.k;
assert!(b <= k);
let num_samples = oracle.samples.len();
println!("Xor-reduce iteration, {} samples", num_samples);
let maxj = 2usize.pow(b);
let mut vector_partitions = Vec::with_capacity(maxj);
let query_capacity = ((num_samples / maxj) * (num_samples / maxj - 1)) / 2;
for _ in 0..maxj {
vector_partitions.push(Vec::with_capacity(query_capacity));
}
let bitrange: ops::Range<usize> = ((k - b) as usize)..(k as usize);
for mut q in oracle.samples.drain(..) {
let idx = query_bits_range(&(q.a), &bitrange) as usize;
if vector_partitions[idx].capacity() == 0 {
println!(
"Vector {} is full, will need to resize from {}",
idx,
vector_partitions[idx].len()
);
}
q.a.truncate((k - b) as usize);
vector_partitions[idx].push(q);
}
vector_partitions.par_iter_mut().for_each(|partition| {
*partition = partition
.iter()
.tuple_combinations()
.map(|(v1, v2)| Sample {
a: &v1.a + &v2.a,
c: v1.c ^ v2.c,
e: v1.e ^ v2.e,
})
.collect();
});
oracle
.samples
.reserve(vector_partitions.iter().fold(0, |acc, x| acc + x.len()));
for partition in vector_partitions {
oracle.samples.extend(partition.into_iter());
}
oracle.k = k - b;
oracle.secret.truncate(oracle.k as usize);
oracle.delta = oracle.delta.powi(2);
println!(
"Xor-reduce iteration done, {} samples now, k' = {}",
oracle.samples.len(),
oracle.k
);
}
/// Picks the most likely `k`-bit secret using a fast Walsh–Hadamard transform.
///
/// Each sample adds +1 (label `false`) or -1 (label `true`) to the counter at
/// index `a.as_u64()`. After the counters are transformed, the index with the
/// largest value is chosen (the last one on ties). That index is returned as a
/// `k`-bit vector, least significant bit first.
pub fn fwht_solve(oracle: LpnOracle) -> BinVector {
    println!("FWHT solving...");
    debug_assert_eq!(oracle.samples[0].a.len() as u32, oracle.k);
    let k = oracle.k;
    let mut majority_counter = vec![0i64; 2usize.pow(k)];
    for sample in oracle.samples {
        let vote = if sample.c { -1 } else { 1 };
        majority_counter[sample.a.as_u64() as usize] += vote;
    }
    fwht(&mut majority_counter, k);
    // `Iterator::max_by_key` returns the last maximum, matching the original argmax.
    let (guess, _) = majority_counter
        .iter()
        .enumerate()
        .max_by_key(|&(_, &count)| count)
        .unwrap();
    let mut result = BinVector::with_capacity(k as usize);
    for bit in 0..k {
        result.push((guess >> bit) & 1 == 1);
    }
    result
}
/// In-place, unnormalised fast Walsh–Hadamard transform of `data[..2^bits]`.
///
/// Elements past index `2^bits` are left untouched.
///
/// # Panics
///
/// Panics if `data` holds fewer than `2^bits` elements (when `bits > 0`).
#[inline]
fn fwht(data: &mut [i64], bits: u32) {
    for level in 0..bits {
        let half = 1usize << level;
        let span = 1usize << bits;
        // Each block of `2 * half` elements is one group of butterflies: the
        // lower half pairs element-wise with the upper half.
        for block in data[..span].chunks_exact_mut(2 * half) {
            let (low, high) = block.split_at_mut(half);
            for (lhs, rhs) in low.iter_mut().zip(high.iter_mut()) {
                let sum = *lhs + *rhs;
                let difference = *lhs - *rhs;
                *lhs = sum;
                *rhs = difference;
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn transmute_usize_to_u8s() {
        assert_eq!(
            usize_to_binvec(2, 4),
            BinVector::from_bools(&[false, false, true, false])
        );
        let a = 0x0000_0000_0000_0001usize;
        let binvec = usize_to_binvec(a, 50);
        for i in 0..49 {
            assert_eq!(binvec.get(i), Some(false), "bit {} isn't 0", i);
        }
        assert_eq!(binvec.get(49), Some(true));
    }

    #[test]
    fn usize_to_binvec_keeps_only_low_bits() {
        // 0b1_0110: only the low four bits (0110) survive, most significant first.
        assert_eq!(
            usize_to_binvec(0b1_0110, 4),
            BinVector::from_bools(&[false, true, true, false])
        );
    }

    #[test]
    fn fwht_matches_hadamard_transform() {
        let mut data = [1i64, 2, 3, 4];
        fwht(&mut data, 2);
        assert_eq!(data, [10, -2, -4, 0]);
    }

    #[test]
    fn fwht_of_impulse_is_flat() {
        let mut data = [1i64, 0, 0, 0, 0, 0, 0, 0];
        fwht(&mut data, 3);
        assert_eq!(data, [1i64; 8]);
    }
}