use crate::ntt::multiply_ntts;
use crate::types::Z;
use crate::Q;
use sha3::digest::{ExtendableOutput, Update, XofReader};
use sha3::{Digest, Sha3_256, Sha3_512, Shake128, Shake256};
/// Early-returns `Err($msg)` from the enclosing function when `$cond` is false.
///
/// `$msg` must be a literal and is passed to `Err` unchanged (no `From`
/// conversion), so the enclosing function's error type must be the literal's
/// type (e.g. `&'static str`). A trailing comma after `$msg` is accepted.
macro_rules! ensure {
    ($cond:expr, $msg:literal $(,)?) => {
        if !$cond {
            // Absolute path: a caller-side item named `Err` must not be able to
            // hijack the expansion (macro_rules item names are not hygienic).
            return ::core::result::Result::Err($msg);
        }
    };
}
pub(crate) use ensure;
#[must_use]
pub(crate) fn add_vecs<const K: usize>(
vec_a: &[[Z; 256]; K], vec_b: &[[Z; 256]; K],
) -> [[Z; 256]; K] {
core::array::from_fn(|k| core::array::from_fn(|n| vec_a[k][n].add(vec_b[k][n])))
}
/// Matrix–vector product in the NTT domain: entry `i` of the result is
/// `Σ_j multiply_ntts(a_hat[i][j], u_hat[j])`, accumulated with `add_vecs`
/// in increasing `j` order.
#[must_use]
pub(crate) fn mul_mat_vec<const K: usize>(
    a_hat: &[[[Z; 256]; K]; K], u_hat: &[[Z; 256]; K],
) -> [[Z; 256]; K] {
    let mut w_hat = [[Z::default(); 256]; K];
    for (row_acc, a_row) in w_hat.iter_mut().zip(a_hat) {
        for (a_entry, u_poly) in a_row.iter().zip(u_hat) {
            let product = multiply_ntts(a_entry, u_poly);
            *row_acc = add_vecs(&[*row_acc], &[product])[0];
        }
    }
    w_hat
}
/// Transposed matrix–vector product in the NTT domain: entry `i` of the
/// result is `Σ_j multiply_ntts(a_hat[j][i], u_hat[j])`, accumulated with
/// `add_vecs` in increasing `j` order.
#[must_use]
pub(crate) fn mul_mat_t_vec<const K: usize>(
    a_hat: &[[[Z; 256]; K]; K], u_hat: &[[Z; 256]; K],
) -> [[Z; 256]; K] {
    let mut y_hat = [[Z::default(); 256]; K];
    for (col, acc) in y_hat.iter_mut().enumerate() {
        // Walking rows of `a_hat` and picking column `col` reads `a_hat[j][col]`.
        for (a_row, u_poly) in a_hat.iter().zip(u_hat) {
            let product = multiply_ntts(&a_row[col], u_poly);
            *acc = add_vecs(&[*acc], &[product])[0];
        }
    }
    y_hat
}
/// Inner product of two NTT-domain vectors:
/// `Σ_j multiply_ntts(u_hat[j], v_hat[j])`, accumulated with `add_vecs`
/// in increasing `j` order, starting from the all-default polynomial.
#[must_use]
pub(crate) fn dot_t_prod<const K: usize>(u_hat: &[[Z; 256]; K], v_hat: &[[Z; 256]; K]) -> [Z; 256] {
    u_hat.iter().zip(v_hat).fold([Z::default(); 256], |acc, (u_poly, v_poly)| {
        add_vecs(&[acc], &[multiply_ntts(u_poly, v_poly)])[0]
    })
}
/// PRF: SHAKE256 over `s ‖ b`, squeezed to exactly `ETA_64` bytes
/// (presumably `64·η` per FIPS 203 — confirm at call sites).
#[must_use]
pub(crate) fn prf<const ETA_64: usize>(s: &[u8; 32], b: u8) -> [u8; ETA_64] {
    let mut shake = Shake256::default();
    shake.update(s);
    shake.update(&[b]);
    let mut out = [0u8; ETA_64];
    shake.finalize_xof().read(&mut out);
    out
}
/// XOF: absorbs `rho ‖ i ‖ j` into SHAKE128 and returns the squeezing reader,
/// so the caller can pull as many bytes as it needs.
#[must_use]
pub(crate) fn xof(rho: &[u8; 32], i: u8, j: u8) -> impl XofReader {
    let mut shake = Shake128::default();
    shake.update(rho);
    // Absorbing is streaming, so one update of `[i, j]` equals two single-byte updates.
    shake.update(&[i, j]);
    shake.finalize_xof()
}
/// G: SHA3-512 over the concatenation of every slice in `bytes`, returned as
/// the two 32-byte halves `(first, second)` of the 64-byte digest.
#[must_use]
pub(crate) fn g(bytes: &[&[u8]]) -> ([u8; 32], [u8; 32]) {
    let mut hasher = Sha3_512::new();
    for chunk in bytes {
        // Fully-qualified: both `Digest` and `Update` provide an `update` method.
        Digest::update(&mut hasher, chunk);
    }
    let digest = hasher.finalize();
    // SHA3-512 output is always 64 bytes, so indices 0..64 are in bounds.
    let first = core::array::from_fn(|i| digest[i]);
    let second = core::array::from_fn(|i| digest[32 + i]);
    (first, second)
}
#[must_use]
pub(crate) fn h(bytes: &[u8]) -> [u8; 32] {
let mut hasher = Sha3_256::new();
Digest::update(&mut hasher, bytes);
let digest = hasher.finalize();
digest.into()
}
/// J: SHAKE256 over `z ‖ ct`, squeezed to 32 bytes.
#[must_use]
pub(crate) fn j(z: &[u8; 32], ct: &[u8]) -> [u8; 32] {
    let mut out = [0u8; 32];
    let mut shake = Shake256::default();
    shake.update(z);
    shake.update(ct);
    shake.finalize_xof().read(&mut out);
    out
}
/// Compress_d (FIPS 203, eq. 4.7), applied in place to every coefficient:
/// `x ↦ ⌈(2^d / q) · x⌋ mod 2^d`, where `⌈·⌋` rounds to nearest.
///
/// The division by `Q` is replaced by multiplication with `M = ⌈2^36 / Q⌉`
/// followed by a 36-bit right shift (presumably to avoid a data-dependent
/// division). Assumes coefficients are in `0..Q` and `d <= 11` — TODO confirm
/// at callers. Under those bounds `x << d` stays below 2^23 and the product
/// below 2^48, so neither overflows.
///
/// The final `mod 2^d` matters at the top of the range. For example, with
/// `x = Q - 1` and `d = 1`, the rounded value is 2 and must reduce to 0.
#[allow(clippy::cast_possible_truncation)] // result is masked to < 2^d, which fits in u16
pub(crate) fn compress_vector(d: u32, inout: &mut [Z]) {
    const M: u32 = (((1u64 << 36) + Q as u64 - 1) / Q as u64) as u32;
    let mask = (1u64 << d) - 1;
    for x_ref in &mut *inout {
        // Adding ⌊Q/2⌋ before the (approximate) division turns floor into round-to-nearest.
        let y = (x_ref.get_u32() << d) + (u32::from(Q) >> 1);
        let result = ((u64::from(y) * u64::from(M)) >> 36) & mask;
        x_ref.set_u16(result as u16);
    }
}
/// Decompress_d (FIPS 203, eq. 4.8), applied in place to every coefficient:
/// `y ↦ ⌈(q / 2^d) · y⌋`, rounding to nearest with ties rounded up.
///
/// This is computed exactly as `(q·y + 2^(d-1)) >> d`. Adding `2^d - 1`
/// instead would compute the ceiling, which disagrees with the spec whenever
/// the fractional part is in (0, 1/2). For example, `d = 4, y = 1` must give
/// 208, not 209. Assumes `y < 2^d` and `d <= 11` — TODO confirm at callers;
/// under those bounds `q·y` fits comfortably in `u32`.
#[allow(clippy::cast_possible_truncation)] // result <= Q, which fits in u16
pub(crate) fn decompress_vector(d: u32, inout: &mut [Z]) {
    // Half of the divisor 2^d; written as a shift so `d == 0` yields 0 instead of underflowing.
    let half = (1u32 << d) >> 1;
    for y_ref in &mut *inout {
        let qy = u32::from(Q) * y_ref.get_u32() + half;
        y_ref.set_u16((qy >> d) as u16);
    }
}