use crate::types::Z;
use crate::Q;
use sha3::digest::XofReader;
/// Samples 256 coefficients from an XOF byte stream by rejection sampling
/// (the arithmetic matches FIPS 203, Algorithm 7 `SampleNTT`).
///
/// Bytes are pulled from `xof_reader` three at a time. Each triple is split into
/// two 12-bit candidates `d1` and `d2`. A candidate becomes the next coefficient
/// only if it is strictly less than `Q`; otherwise it is rejected. Sampling stops
/// as soon as 256 coefficients have been accepted, so a trailing `d2` is dropped
/// when `d1` already filled the last slot.
///
/// The number of bytes consumed from the reader is not fixed: it depends on how
/// many candidates are rejected.
#[must_use]
pub(crate) fn sample_ntt(mut xof_reader: impl XofReader) -> [Z; 256] {
    let mut coeffs = [Z::default(); 256];
    let mut triple = [0u8; 3];
    let mut filled = 0usize;
    while filled < 256 {
        xof_reader.read(&mut triple);
        let [b0, b1, b2] = triple.map(u16::from);
        // d1 = all of b0 plus the low nibble of b1 (12 bits, max 4095).
        let d1 = b0 + 256 * (b1 & 0x0F);
        // d2 = the high nibble of b1 plus all of b2 (12 bits, max 4095).
        let d2 = (b1 >> 4) + 16 * b2;
        if d1 < Q {
            coeffs[filled].set_u16(d1);
            filled += 1;
        }
        // d1 may have just filled the array; only take d2 if there is still room.
        if filled < 256 && d2 < Q {
            coeffs[filled].set_u16(d2);
            filled += 1;
        }
    }
    coeffs
}
/// Samples 256 coefficients from a centered binomial distribution
/// (Alg 8, `SamplePolyCBD`), deriving `eta` from the input length.
///
/// `eta` is `byte_array_b.len() / 64`. Callers are expected to pass exactly
/// `64 * eta` bytes; this is only checked in debug builds. The bytes are read as
/// a little-endian bit stream. Every `2 * eta` bits produce one coefficient:
/// the popcount of the low `eta` bits minus the popcount of the next `eta` bits,
/// combined via `Z::sub` (presumably modulo `Q` — confirm against `Z`).
///
/// # Panics
///
/// Panics if the input length does not fit in a `u32`. It also panics (index
/// out of bounds) if the length is not a multiple of 64, because that yields
/// more than 256 coefficients.
#[must_use]
pub(crate) fn sample_poly_cbd(byte_array_b: &[u8]) -> [Z; 256] {
    let eta = u32::try_from(byte_array_b.len()).unwrap() >> 6;
    debug_assert_eq!(byte_array_b.len(), 64 * eta as usize, "Alg 8: byte array not 64 * eta");

    // Width of one coefficient's bit group, and a mask selecting `eta` bits.
    let bits_per_coeff = 2 * (eta as usize);
    let eta_mask: u32 = (1 << eta) - 1;

    let mut coeffs: [Z; 256] = [Z::default(); 256];
    let mut bit_buf: u32 = 0;
    let mut buffered_bits: usize = 0;
    let mut out_idx = 0;

    for &byte in byte_array_b {
        // Append the next byte above the bits still pending in the buffer.
        bit_buf |= u32::from(byte) << buffered_bits;
        buffered_bits += 8;

        // Drain as many complete coefficient groups as the buffer now holds.
        while buffered_bits >= bits_per_coeff {
            let ones_x = count_ones(bit_buf & eta_mask);
            let ones_y = count_ones((bit_buf >> eta) & eta_mask);

            let mut x_z = Z::default();
            let mut y_z = Z::default();
            x_z.set_u16(ones_x);
            y_z.set_u16(ones_y);
            coeffs[out_idx] = x_z.sub(y_z);
            out_idx += 1;

            bit_buf >>= bits_per_coeff;
            buffered_bits -= bits_per_coeff;
        }
    }
    coeffs
}
/// Returns the number of set bits in `x` (population count), in `0..=32`.
///
/// Uses a branch-free SWAR reduction built only from masks, shifts and
/// additions. Each step sums neighbouring bit-fields into wider fields until the
/// total sits in the low byte.
///
/// NOTE(review): `u32::count_ones` computes the same value. The hand-rolled form
/// is kept, presumably so the operation is data-independent on every target —
/// confirm before swapping.
// The final count is at most 32 (masked to 6 bits), so the narrowing cast
// below can never truncate.
#[allow(clippy::cast_possible_truncation)]
fn count_ones(x: u32) -> u16 {
    // Each 2-bit field now holds the count of its two bits (0..=2).
    let x = (x & 0x5555_5555) + ((x >> 1) & 0x5555_5555);
    // Each 4-bit field now holds the count of its four bits (0..=4).
    let x = (x & 0x3333_3333) + ((x >> 2) & 0x3333_3333);
    // Each byte now holds the count of its eight bits (0..=8).
    let x = (x & 0x0F0F_0F0F) + ((x >> 4) & 0x0F0F_0F0F);
    // Fold the four byte counts into the low byte. The per-byte values are
    // small, so no sum carries into a neighbouring byte.
    let x = x + (x >> 8);
    let x = x + (x >> 16);
    (x & 0x3F) as u16
}