use crate::conversion::{bit_unpack, coeff_from_half_byte, coeff_from_three_bytes};
use crate::helpers::{bit_length, is_in_range};
use crate::types::{Ph, R, R0, T, T0};
use sha2::{Digest, Sha256, Sha512};
use sha3::digest::{ExtendableOutput, Update, XofReader};
use sha3::{Shake128, Shake256};
/// SHAKE256 over the concatenation of every slice in `v`, absorbed in order.
///
/// Returns the squeeze-side reader so callers can pull as many output bytes as they need.
pub(crate) fn h256_xof(v: &[&[u8]]) -> impl XofReader {
    let mut shake = Shake256::default();
    for &chunk in v {
        shake.update(chunk);
    }
    shake.finalize_xof()
}
/// SHAKE128 over the concatenation of every slice in `v`, absorbed in order.
///
/// Returns the squeeze-side reader so callers can pull as many output bytes as they need.
pub(crate) fn g128_xof(v: &[&[u8]]) -> impl XofReader {
    let mut shake = Shake128::default();
    for &chunk in v {
        shake.update(chunk);
    }
    shake.finalize_xof()
}
/// Algorithm 29 (SampleInBall): builds a polynomial with `tau` coefficients set to +1/-1
/// and all others zero, derived deterministically from the seed `rho`.
///
/// The SHAKE256 stream of `rho` is consumed as follows:
/// - first, 8 bytes of sign bits;
/// - then, for each position `i` in `256 - tau ..= 255`, bytes are read until one is `<= i`.
///
/// When `CTEST` is true, no index bytes are read and `j = i` is used for every position.
/// This looks like a test hook (NOTE(review): confirm with callers).
///
/// # Panics
/// Panics if `tau` is negative (the `try_from` fails). A `tau` above 64 would index past
/// the 8 sign bytes and also panic.
pub(crate) fn sample_in_ball<const CTEST: bool>(tau: i32, rho: &[u8]) -> R {
    let tau = usize::try_from(tau).expect("Alg 29: try_from fail");
    let mut xof = h256_xof(&[rho]);

    // One sign bit per nonzero coefficient, taken LSB-first in the order positions are visited.
    let mut signs = [0u8; 8];
    xof.read(&mut signs);

    let mut c = R0;
    for (k, i) in (256 - tau..256).enumerate() {
        // Rejection-sample an index j in 0..=i from the byte stream.
        let j = if CTEST {
            i
        } else {
            let mut byte = [0u8; 1];
            xof.read(&mut byte);
            while usize::from(byte[0]) > i {
                xof.read(&mut byte);
            }
            usize::from(byte[0])
        };
        // Move the current value at j up to i, then place +1 (bit 0) or -1 (bit 1) at j.
        let sign_bit = (signs[k / 8] >> (k % 8)) & 1;
        c.0[i] = c.0[j];
        c.0[j] = 1 - 2 * i32::from(sign_bit);
    }

    debug_assert!(
        c.0.iter().filter(|&&e| e != 0).count() == tau,
        "Alg 29: bad hamming weight (a)"
    );
    debug_assert!(
        c.0.iter().map(|&e| e & 1).sum::<i32>() == i32::try_from(tau).expect("cannot fail"),
        "Alg 29: bad hamming weight (b)"
    );
    c
}
/// Algorithm 30 (RejNTTPoly): fills 256 coefficients of an NTT-domain polynomial by
/// rejection sampling 3-byte chunks of SHAKE128 output over the concatenated `rhos`.
///
/// `rhos` must total 34 bytes (272 bits); this is checked in debug builds. Each chunk goes
/// to `coeff_from_three_bytes`. A rejected chunk is discarded and the next one is tried,
/// so coefficients are written in index order.
pub(crate) fn rej_ntt_poly<const CTEST: bool>(rhos: &[&[u8]]) -> T {
    debug_assert_eq!(
        rhos.iter().map(|&i| i.len()).sum::<usize>(),
        272 / 8,
        "Alg 30: bad rho size"
    );
    let mut xof = g128_xof(rhos);
    let mut a_hat = T0;
    for slot in &mut a_hat.0[..256] {
        // Keep drawing 3-byte chunks until one is accepted.
        *slot = loop {
            let mut chunk = [0u8; 3];
            xof.read(&mut chunk);
            if let Ok(coeff) = coeff_from_three_bytes::<CTEST>(chunk) {
                break coeff;
            }
        };
    }
    a_hat
}
/// Algorithm 31 (RejBoundedPoly): fills 256 coefficients by rejection sampling the
/// nibbles of SHAKE256 output over the concatenated `rhos`.
///
/// `rhos` must total 66 bytes (528 bits); this is checked in debug builds. Each output
/// byte yields two candidates, the low nibble first and then the high nibble. Both are
/// passed to `coeff_from_half_byte` with `eta`, and accepted values fill the next free slot.
pub(crate) fn rej_bounded_poly<const CTEST: bool>(eta: i32, rhos: &[&[u8]]) -> R {
    debug_assert_eq!(
        rhos.iter().map(|&i| i.len()).sum::<usize>(),
        528 / 8,
        "Alg 31: bad rho size"
    );
    let mut xof = h256_xof(rhos);
    let mut a = R0;
    let mut filled = 0;
    let mut byte = [0u8];
    while filled < 256 {
        xof.read(&mut byte);
        // Both nibbles are always evaluated, in low-then-high order.
        let candidates =
            [byte[0] & 0x0f, byte[0] >> 4].map(|nibble| coeff_from_half_byte::<CTEST>(eta, nibble));
        for coeff in candidates.into_iter().flatten() {
            // The final byte may produce two accepted values when only one slot remains.
            if filled < 256 {
                a.0[filled] = coeff;
                filled += 1;
            }
        }
    }
    a
}
/// Algorithm 32 (ExpandA): derives the `K x L` matrix of NTT-domain polynomials from the
/// 32-byte seed `rho`.
///
/// Entry `[row][col]` is `rej_ntt_poly` over `rho || col as u8 || row as u8`.
// `as u8` truncates indices >= 256; this assumes K and L stay below that.
#[allow(clippy::cast_possible_truncation)]
pub(crate) fn expand_a<const CTEST: bool, const K: usize, const L: usize>(
    rho: &[u8; 32],
) -> [[T; L]; K] {
    core::array::from_fn(|row| {
        let row_byte = [row as u8];
        core::array::from_fn(|col| {
            let col_byte = [col as u8];
            rej_ntt_poly::<CTEST>(&[rho.as_slice(), &col_byte, &row_byte])
        })
    })
}
/// Algorithm 33 (ExpandS): samples the secret vectors `s1` (length `L`) and `s2`
/// (length `K`) from the 64-byte seed `rho`.
///
/// Polynomial number `idx` (`0..L` for `s1`, `L..L+K` for `s2`) comes from
/// `rej_bounded_poly` over `rho || idx as u8 || 0x00`. For `idx < 256`, those two bytes
/// are the index in little-endian order.
// `idx as u8` truncates indices >= 256; this assumes K + L stays below that.
#[allow(clippy::cast_possible_truncation)]
pub(crate) fn expand_s<const CTEST: bool, const K: usize, const L: usize>(
    eta: i32,
    rho: &[u8; 64],
) -> ([R; L], [R; K]) {
    let sample = |idx: usize| rej_bounded_poly::<CTEST>(eta, &[rho, &[idx as u8], &[0]]);
    let s1: [R; L] = core::array::from_fn(sample);
    let s2: [R; K] = core::array::from_fn(|r| sample(r + L));
    debug_assert!(s1.iter().all(|r| is_in_range(r, eta, eta)), "Alg 33: s1 out of range");
    debug_assert!(s2.iter().all(|r| is_in_range(r, eta, eta)), "Alg 33: s2 out of range");
    (s1, s2)
}
/// Algorithm 34 (ExpandMask): samples the masking vector `y` of `L` polynomials.
///
/// For each `r` in `0..L`:
/// - SHAKE256 runs over `rho || (mu + r)`, with the counter encoded as two
///   little-endian bytes.
/// - The first `32 * c` output bytes are passed to `bit_unpack(.., gamma1 - 1, gamma1)`,
///   where `c = 1 + bit_length(gamma1 - 1)` must be 18 or 20 (checked in debug builds).
///
/// The counter uses `wrapping_add` because the 2-byte encoding keeps only the low 16 bits.
/// Plain `+` on `u16` would panic in debug builds but wrap in release builds once
/// `mu + r > u16::MAX`, so the two build profiles would disagree.
///
/// # Panics
/// Panics if `L` does not fit in a `u16`, or if `bit_unpack` returns an error.
pub(crate) fn expand_mask<const L: usize>(gamma1: i32, rho: &[u8; 64], mu: u16) -> [R; L] {
    let c = 1 + bit_length(gamma1 - 1);
    debug_assert!((c == 18) || (c == 20), "Alg 34: illegal c");
    let mut y = [R0; L];
    // Sized for the largest allowed c (20); only the first 32 * c bytes are filled and used.
    let mut v = [0u8; 32 * 20];
    let needed = 32 * c;
    for r in 0..u16::try_from(L).expect("Alg 34: try_from1 fail") {
        let n = mu.wrapping_add(r);
        let mut xof = h256_xof(&[rho, &n.to_le_bytes()]);
        // An XOF's leading bytes do not depend on how much is squeezed, so reading only the
        // bytes actually consumed gives the same result as reading the whole buffer.
        xof.read(&mut v[..needed]);
        y[usize::from(r)] =
            bit_unpack(&v[..needed], gamma1 - 1, gamma1).expect("Alg 34: try_from2 fail");
    }
    debug_assert!(
        y.iter().all(|r| is_in_range(r, gamma1 - 1, gamma1)),
        "Alg 34: s coeff out of range"
    );
    y
}
pub(crate) fn hash_message(message: &[u8], ph: &Ph, phm: &mut [u8; 64]) -> ([u8; 11], usize) {
match ph {
Ph::SHA256 => (
[
0x06u8, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01,
],
{
let mut hasher = Sha256::new();
Digest::update(&mut hasher, message);
phm[0..32].copy_from_slice(&hasher.finalize());
32
},
),
Ph::SHA512 => (
[
0x06u8, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x03,
],
{
let mut hasher = Sha512::new();
Digest::update(&mut hasher, message);
phm.copy_from_slice(&hasher.finalize());
64
},
),
Ph::SHAKE128 => (
[
0x06u8, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x0B,
],
{
let mut hasher = Shake128::default();
hasher.update(message);
let mut reader = hasher.finalize_xof();
reader.read(&mut phm[0..32]);
32
},
),
}
}