#![allow(clippy::explicit_iter_loop)]
use hybrid_array::Array;
use hybrid_array::typenum::Unsigned;
use rand_core::Rng;
use subtle::{
Choice,
ConditionallySelectable,
ConstantTimeEq,
};
use crate::algebra::{
FieldElement,
NttPolynomial,
NttVector,
Polynomial,
PolynomialVector,
};
use crate::crypto::PrfOutput;
use crate::encode::Encode;
pub(crate) use crate::hardened_rng::OsRngFill;
use crate::param::{
ArraySize,
CbdSamplingSize,
EncodedCiphertext,
KemParams,
};
use crate::util::B32;
/// Looks up entry `v` of `Eta::ONES` without indexing the table by `v`.
///
/// Every index below `Eta::OnesSize::USIZE` is visited; the entry whose index
/// equals `v` is kept through `conditional_select`, so the sequence of table
/// reads does not depend on `v`. If `v` is not below the table length, no
/// entry matches and the result is `FieldElement(0)`.
// The table length is narrowed to `u8` so it can be compared against `v`.
#[allow(clippy::cast_possible_truncation)]
pub(crate) fn cbd_table_lookup_ct<Eta: CbdSamplingSize>(v: u8) -> FieldElement {
    let table_len = Eta::OnesSize::USIZE as u8;
    let selected = (0..table_len).fold(0u16, |current, slot| {
        let entry = Eta::ONES[usize::from(slot)].0;
        // Keep `current` unless this slot is the requested one.
        u16::conditional_select(&current, &entry, v.ct_eq(&slot))
    });
    FieldElement(selected)
}
/// Builds a polynomial from PRF output `b` using the scanning table lookup.
///
/// `b` is first decoded with `Encode::<Eta::SampleSize>` into a polynomial of
/// raw values; each raw value is then narrowed to `u8` and mapped through
/// [`cbd_table_lookup_ct`].
// Each decoded value is narrowed to `u8` before the lookup.
#[allow(clippy::cast_possible_truncation)]
pub(crate) fn sample_poly_cbd_ct<Eta: CbdSamplingSize>(b: &PrfOutput<Eta>) -> Polynomial {
    let packed: Polynomial = Encode::<Eta::SampleSize>::decode(b);
    let coefficients = packed
        .0
        .iter()
        .map(|raw| {
            let table_index = raw.0 as u8;
            cbd_table_lookup_ct::<Eta>(table_index)
        })
        .collect();
    Polynomial(coefficients)
}
/// Computes the dot product of `s_hat` and `u_hat` through blinded, additively
/// shared and shuffled intermediate values.
///
/// Steps, in order:
/// 1. Draw a random nonzero `rho`; scale every `s_hat[i]` by `rho` and every
///    `u_hat[i]` by the inverse of `rho`, so the two factors cancel in each
///    product (the unit test in this file checks the result against the
///    unmasked product).
/// 2. Split each scaled `s` polynomial into two shares `s0[i]` and `s1[i]`
///    with `split_ntt_polynomial_share`.
/// 3. Visit the `k` indices in the order written by
///    `crate::hardened_rng::shuffle_indices` and accumulate
///    `s0[idx] * u_masked[idx] + s1[idx] * u_masked[idx]`.
///
/// # Panics
///
/// Panics if `rho.inv()` returns `None`. Debug builds also assert that both
/// vectors have `K::USIZE` entries and that `K::USIZE <= 4`; the index buffer
/// holds four entries, so a larger `K` would panic when slicing `perm[..k]`
/// even in release builds.
pub(crate) fn ntt_vector_dot_masked<K: ArraySize, R: Rng>(
s_hat: &NttVector<K>,
u_hat: &NttVector<K>,
rng: &mut R,
) -> NttPolynomial {
let k = K::USIZE;
debug_assert_eq!(k, s_hat.0.len());
debug_assert_eq!(k, u_hat.0.len());
debug_assert!(k <= 4);
// Multiplicative blinding factor and its inverse.
let rho = FieldElement::random_nonzero(rng);
let rho_inv = rho
.inv()
.expect("nonzero rho must have an inverse mod FIELD_MODULUS");
let mut s_masked = NttVector::<K>(Array::default());
let mut u_masked = NttVector::<K>(Array::default());
for i in 0..k {
s_masked.0[i] = scale_ntt_polynomial(&s_hat.0[i], rho);
u_masked.0[i] = scale_ntt_polynomial(&u_hat.0[i], rho_inv);
}
// Additive sharing of the blinded `s`: one fresh mask per coefficient.
let mut s0 = NttVector::<K>(Array::default());
let mut s1 = NttVector::<K>(Array::default());
for i in 0..k {
split_ntt_polynomial_share(&s_masked.0[i], rng, &mut s0.0[i], &mut s1.0[i]);
}
// NOTE(review): presumably `shuffle_indices` fills `perm[..k]` with a random
// permutation of `0..k` — confirm against `hardened_rng`.
let mut perm = [0usize; 4];
crate::hardened_rng::shuffle_indices(rng, k, &mut perm[..k]);
let mut acc = NttPolynomial::default();
for &idx in perm[..k].iter() {
// Each share is multiplied separately by the same blinded `u` term and
// then added into the accumulator.
let t0 = &s0.0[idx] * &u_masked.0[idx];
let t1 = &s1.0[idx] * &u_masked.0[idx];
acc = &(&acc + &t0) + &t1;
}
acc
}
/// Multiplies `matrix` by `r_hat`, computing each output entry with
/// [`ntt_vector_dot_masked`].
///
/// Entry `i` of the result is the masked dot product of `matrix.row(i)` with
/// `r_hat`. Rows are processed in ascending order and all draw from the same
/// `rng`.
pub(crate) fn ntt_matrix_vector_masked<K: ArraySize, R: Rng>(
    matrix: &crate::algebra::NttMatrix<K>,
    r_hat: &NttVector<K>,
    rng: &mut R,
) -> NttVector<K> {
    let mut product = NttVector::<K>(Array::default());
    for (row_index, entry) in product.0.iter_mut().enumerate() {
        *entry = ntt_vector_dot_masked(matrix.row(row_index), r_hat, rng);
    }
    product
}
/// Returns a copy of `p` with every coefficient multiplied by the scalar `s`.
fn scale_ntt_polynomial(p: &NttPolynomial, s: FieldElement) -> NttPolynomial {
    let scaled = p
        .0
        .iter()
        .copied()
        .map(|coefficient| coefficient * s)
        .collect();
    NttPolynomial(scaled)
}
/// Splits `s` into two shares written to `s0` and `s1`.
///
/// For each coefficient, in ascending order, one mask is drawn with
/// `random_fe`; the mask is stored in `s1` and `s - mask` is stored in `s0`.
fn split_ntt_polynomial_share<R: Rng>(
    s: &NttPolynomial,
    rng: &mut R,
    s0: &mut NttPolynomial,
    s1: &mut NttPolynomial,
) {
    let lanes = s.0.iter().zip(s0.0.iter_mut()).zip(s1.0.iter_mut());
    for ((&secret, difference_share), mask_share) in lanes {
        let mask = random_fe(rng);
        *mask_share = mask;
        *difference_share = secret - mask;
    }
}
/// Draws two bytes from `rng`, reads them as a little-endian `u16`, and turns
/// that word into a field element with `FieldElement::from_u16_reduced`.
// NOTE(review): if `from_u16_reduced` reduces modulo the field modulus, mapping
// 65536 equally likely inputs onto a modulus that does not divide 65536 is
// slightly non-uniform — confirm this is acceptable for mask generation.
fn random_fe<R: Rng>(rng: &mut R) -> FieldElement {
    let mut raw = [0u8; 2];
    rng.fill_bytes(&mut raw);
    let word = u16::from_le_bytes(raw);
    FieldElement::from_u16_reduced(word)
}
pub(crate) fn ciphertexts_equal_ct(a: &[u8], b: &[u8]) -> Choice {
debug_assert_eq!(a.len(), b.len());
let mut eq = Choice::from(1u8);
for (&x, &y) in a.iter().zip(b.iter()) {
eq &= x.ct_eq(&y);
}
eq
}
/// Compares the first 256 coefficients of `a` and `b`.
///
/// Every coefficient pair is compared with `ct_eq` and the results are AND-ed
/// together; the returned [`Choice`] is set only if all pairs match.
fn polynomial_ct_eq(a: &Polynomial, b: &Polynomial) -> Choice {
    (0..256usize).fold(Choice::from(1u8), |all_equal, idx| {
        all_equal & a.0[idx].0.ct_eq(&b.0[idx].0)
    })
}
/// Compares two polynomial vectors entry by entry with [`polynomial_ct_eq`].
///
/// All `K::USIZE` entries are compared and the per-entry results are AND-ed
/// together; the returned [`Choice`] is set only if every entry matches.
fn polynomial_vector_ct_eq<K: ArraySize>(
    a: &PolynomialVector<K>,
    b: &PolynomialVector<K>,
) -> Choice {
    (0..K::USIZE).fold(Choice::from(1u8), |all_equal, idx| {
        all_equal & polynomial_ct_eq(&a.0[idx], &b.0[idx])
    })
}
/// Compares two encoded ciphertexts on their decoded coefficients rather than
/// on their raw bytes.
///
/// Each ciphertext is split with `P::split_ct` into two parts. The first part
/// is decoded as a `PolynomialVector<P::K>` using `P::Du`, the second as a
/// `Polynomial` using `P::Dv`. The returned [`Choice`] is set only if both
/// decoded parts match.
pub(crate) fn ciphertexts_equal_arithmetic_domain_ct<P: KemParams>(
    a: &EncodedCiphertext<P>,
    b: &EncodedCiphertext<P>,
) -> Choice {
    let (lhs_u_bytes, lhs_v_bytes) = P::split_ct(a);
    let (rhs_u_bytes, rhs_v_bytes) = P::split_ct(b);

    let lhs_u: PolynomialVector<P::K> = Encode::<P::Du>::decode(lhs_u_bytes);
    let rhs_u: PolynomialVector<P::K> = Encode::<P::Du>::decode(rhs_u_bytes);
    let lhs_v: Polynomial = Encode::<P::Dv>::decode(lhs_v_bytes);
    let rhs_v: Polynomial = Encode::<P::Dv>::decode(rhs_v_bytes);

    let u_matches = polynomial_vector_ct_eq(&lhs_u, &rhs_u);
    let v_matches = polynomial_ct_eq(&lhs_v, &rhs_v);
    u_matches & v_matches
}
/// Returns a copy of `lhs` when `eq` is set and a copy of `rhs` otherwise.
///
/// The choice is applied to each of the 32 bytes with
/// `u8::conditional_select`, so there is no branch on `eq`.
pub(crate) fn select_shared_key_bytes_ct(eq: Choice, lhs: &B32, rhs: &B32) -> B32 {
    let mut selected = B32::default();
    for idx in 0..32usize {
        let when_equal = lhs[idx];
        let otherwise = rhs[idx];
        // `conditional_select(a, b, c)` yields `b` when `c` is set.
        selected[idx] = u8::conditional_select(&otherwise, &when_equal, eq);
    }
    selected
}
#[cfg(test)]
// The counter RNG deliberately narrows its state to `u32` / `u8`.
#[allow(clippy::cast_possible_truncation)]
mod tests {
    use rand_core::{
        Infallible,
        TryRng,
    };

    use super::*;
    use crate::MlKem768Params;
    use crate::algebra::PolynomialVector;
    use crate::param::ParameterSet;
    use crate::pke::DecryptionKey;

    /// Deterministic test RNG: every 32-bit draw increments the counter and
    /// returns its low 32 bits.
    struct CountingRng(u64);

    impl TryRng for CountingRng {
        type Error = Infallible;

        fn try_next_u32(&mut self) -> Result<u32, Self::Error> {
            let advanced = self.0.wrapping_add(1);
            self.0 = advanced;
            Ok(advanced as u32)
        }

        /// Widens a single 32-bit draw; the upper 32 bits are always zero.
        fn try_next_u64(&mut self) -> Result<u64, Self::Error> {
            self.try_next_u32().map(u64::from)
        }

        /// Fills `dest` with the low byte of one 32-bit draw per output byte.
        fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Self::Error> {
            dest.iter_mut().try_for_each(|byte| {
                *byte = self.try_next_u32()? as u8;
                Ok(())
            })
        }
    }

    /// The masked dot product must equal the plain `s_hat * u_hat` product.
    #[test]
    fn ntt_dot_masked_matches_plain() {
        type Params = MlKem768Params;

        let seed = B32::default();
        let (dk, _ek) = DecryptionKey::<Params>::generate(&seed);

        let u_hat = PolynomialVector::<<Params as ParameterSet>::K>::sample_cbd::<
            <Params as ParameterSet>::Eta1,
        >(&seed, 0)
        .ntt();

        let expected = dk.test_s_hat() * &u_hat;

        let mut rng = CountingRng(0xDEAD_BEEF);
        let actual = ntt_vector_dot_masked(dk.test_s_hat(), &u_hat, &mut rng);

        assert_eq!(expected, actual);
    }
}