use super::fe::*;
use super::re::*;
use super::serialization::bytes_to_bits;
use crate::errors::UnknownCryptoError;
use crate::hazardous::hash::sha3::{shake128, shake256};
/// Samples a ring element in NTT representation by rejection sampling on SHAKE-128 output.
///
/// `seed` and then `ij` are absorbed into a fresh SHAKE-128 instance. The output stream is
/// consumed three bytes at a time; every triple is split into two 12-bit candidates and a
/// candidate is accepted as the next coefficient only if it is smaller than `KYBER_Q`.
/// Sampling stops as soon as all 256 coefficients have been filled.
///
/// # Errors
/// Propagates any error returned by the XOF's `absorb` or `squeeze` calls.
pub fn sample_ntt(seed: &[u8; 32], ij: &[u8; 2]) -> Result<RingElementNTT, UnknownCryptoError> {
    // Exclusive upper bound of a 12-bit candidate.
    const CANDIDATE_BOUND: i16 = 1 << 12;

    let mut xof = shake128::Shake128::new();
    xof.absorb(seed)?;
    xof.absorb(ij)?;

    let mut a_hat = RingElementNTT::zero();
    let mut j = 0;
    while j < 256 {
        let mut c = [0u8; 3];
        xof.squeeze(&mut c)?;

        // First candidate: all of c[0] plus the low nibble of c[1].
        let d1: i16 = (c[0] as i16) + 256 * ((c[1] as i16) & 15);
        // Both bounds must hold at once; joining them with `||` would make the check vacuous.
        debug_assert!((0..CANDIDATE_BOUND).contains(&d1));

        // Second candidate: the high nibble of c[1] plus all of c[2].
        let d2: i16 = ((c[1] as i16) >> 4) + 16 * (c[2] as i16);
        debug_assert!((0..CANDIDATE_BOUND).contains(&d2));

        if d1 < KYBER_Q as i16 {
            a_hat[j] = FieldElement(d1 as u32);
            j += 1;
        }
        // `j` may have just reached 256 through the first candidate, so re-check before writing.
        if d2 < KYBER_Q as i16 && j < 256 {
            a_hat[j] = FieldElement(d2 as u32);
            j += 1;
        }
    }

    Ok(a_hat)
}
/// Derives a ring element whose coefficients are differences of two `eta`-bit sums.
///
/// A fresh SHAKE-256 instance absorbs `seed` followed by the single byte `b` and is squeezed
/// into `prf_out`. `bytes_to_bits` then expands `prf_out` into `bits`. Coefficient `i` is
/// computed as `x - y` in the field, where `x` sums `bits[2*i*eta .. 2*i*eta + eta]` and `y`
/// sums the following `eta` entries.
///
/// `prf_out` and `bits` are caller-provided scratch buffers. `bits` is indexed up to
/// `512 * eta - 1`, so it must be at least that long or indexing panics.
/// NOTE(review): presumably `prf_out` is `64 * eta` bytes and `bits` holds one 0/1 value per
/// entry — confirm against `bytes_to_bits` and the callers.
///
/// # Errors
/// Propagates any error returned by the PRF's `absorb` or `squeeze` calls.
pub fn sample_poly_cbd(
    seed: &[u8],
    b: u8,
    prf_out: &mut [u8],
    bits: &mut [u8],
    eta: usize,
) -> Result<RingElement, UnknownCryptoError> {
    debug_assert_eq!(seed.len(), 32);
    debug_assert!(eta == 2 || eta == 3);

    // PRF(seed || b), expanded into individual bit values.
    let mut prf = shake256::Shake256::new();
    prf.absorb(seed)?;
    prf.absorb(&[b])?;
    prf.squeeze(prf_out)?;
    bytes_to_bits(prf_out, bits);

    let mut poly: RingElement = RingElement::zero();
    for idx in 0..256 {
        // Each coefficient consumes 2 * eta consecutive entries of `bits`.
        let base = 2 * idx * eta;
        let mut plus: u8 = 0;
        let mut minus: u8 = 0;
        for k in 0..eta {
            plus += bits[base + k];
            minus += bits[base + eta + k];
        }
        debug_assert!(plus <= eta as u8);
        debug_assert!(minus <= eta as u8);

        poly[idx] = FieldElement::new(plus as u32) - FieldElement::new(minus as u32);
        // The result is either in [0, eta] or the field representation of a value in [-eta, -1].
        debug_assert!(
            (poly[idx].0 <= eta as u32)
                || ((KYBER_Q - (eta as u32)) <= poly[idx].0 && poly[idx].0 < KYBER_Q)
        );
    }

    Ok(poly)
}