#![allow(clippy::cast_possible_truncation, reason = "tests")]
#![allow(clippy::integer_division_remainder_used, reason = "tests")]
use array::typenum::U2;
use module_lattice::{Elem, Field, NttMatrix, NttPolynomial, NttVector, Polynomial, Vector};
// Test field over 3329 whose elements are stored as `u16` (the tests below compare `Elem.0`
// against `u16` values). NOTE(review): the `u32`/`u64` arguments are presumably the wider
// working types used for products and reduction — confirm against `define_field!`.
module_lattice::define_field!(KyberField, u16, u32, u64, 3329);
// Test field over 8_380_417. NOTE(review): presumably `u32` elements with `u64`/`u128` as the
// wider working types, mirroring the definition above — confirm against `define_field!`.
module_lattice::define_field!(DilithiumField, u32, u64, u128, 8_380_417);
/// Boundary check for `small_reduce` on both fields: the largest value below the modulus is
/// returned unchanged, and the modulus itself comes back as zero.
#[test]
fn small_reduce() {
    // KyberField is defined with 3329.
    let kyber_below = KyberField::small_reduce(3328);
    assert_eq!(kyber_below, 3328);
    let kyber_at = KyberField::small_reduce(3329);
    assert_eq!(kyber_at, 0);

    // DilithiumField is defined with 8_380_417.
    let dilithium_below = DilithiumField::small_reduce(8_380_416);
    assert_eq!(dilithium_below, 8_380_416);
    let dilithium_at = DilithiumField::small_reduce(8_380_417);
    assert_eq!(dilithium_at, 0);
}
/// Exercises `barrett_reduce` on zero, on values around 3329 and its double, and on the
/// product of two values below 3329; finishes with two spot checks on `DilithiumField`.
#[test]
fn barrett_reduce() {
    // (input, expected) pairs: 3329 and 2 * 3329 reduce to zero, 3328 is left unchanged.
    let kyber_cases: [(u32, u16); 4] = [(0, 0), (3329, 0), (3328, 3328), (6658, 0)];
    for (input, expected) in kyber_cases {
        assert_eq!(KyberField::barrett_reduce(input), expected);
    }

    // A product of two in-range values must agree with the plain `%` result.
    let wide: u32 = 3000 * 3000;
    let reduced = KyberField::barrett_reduce(wide);
    assert!(reduced < 3329);
    let expected = (wide % 3329) as u16;
    assert_eq!(reduced, expected);

    let dilithium_zero = DilithiumField::barrett_reduce(0);
    assert_eq!(dilithium_zero, 0);
    let dilithium_modulus = DilithiumField::barrett_reduce(8_380_417);
    assert_eq!(dilithium_modulus, 0);
}
/// Negating an element yields its additive inverse modulo 3329, negating twice returns the
/// original value, and zero negates to zero.
#[test]
fn elem_negation() {
    let value = Elem::<KyberField>::new(100);
    let negated = -value;
    // 3329 - 100 = 3229.
    assert_eq!(negated.0, 3229);

    let restored = -negated;
    assert_eq!(restored.0, 100);

    let zero = Elem::<KyberField>::new(0);
    let negated_zero = -zero;
    assert_eq!(negated_zero.0, 0);
}
/// Element addition: a sum below 3329, a sum that wraps past 3329, and adding zero.
#[test]
fn elem_addition() {
    let hundred = Elem::<KyberField>::new(100);
    let two_hundred = Elem::<KyberField>::new(200);
    assert_eq!((hundred + two_hundred).0, 300);

    // 3300 + 100 = 3400 = 3329 + 71, so the result wraps to 71.
    let near_top = Elem::<KyberField>::new(3300);
    let bump = Elem::<KyberField>::new(100);
    let wrapped = near_top + bump;
    assert_eq!(wrapped.0, 71);

    // Zero is the additive identity.
    let zero = Elem::<KyberField>::new(0);
    let unchanged = hundred + zero;
    assert_eq!(unchanged.0, 100);
}
/// Element subtraction: a non-negative difference, a difference that wraps below zero,
/// subtracting zero, and subtracting an element from itself.
#[test]
fn elem_subtraction() {
    let three_hundred = Elem::<KyberField>::new(300);
    let hundred = Elem::<KyberField>::new(100);
    assert_eq!((three_hundred - hundred).0, 200);

    // 100 - 300 = -200, which is 3329 - 200 = 3129 in the field.
    let smaller = Elem::<KyberField>::new(100);
    let larger = Elem::<KyberField>::new(300);
    let wrapped = smaller - larger;
    assert_eq!(wrapped.0, 3129);

    let zero = Elem::<KyberField>::new(0);
    let minus_zero = three_hundred - zero;
    assert_eq!(minus_zero.0, 300);

    let minus_self = three_hundred - three_hundred;
    assert_eq!(minus_self.0, 0);
}
/// Element multiplication: a small product, multiplication by one and by zero, and a product
/// large enough that it must be computed in `u32` before reducing modulo 3329.
#[test]
fn elem_multiplication() {
    let hundred = Elem::<KyberField>::new(100);
    let two_hundred = Elem::<KyberField>::new(200);
    let small_product = hundred * two_hundred;
    assert_eq!(small_product.0, (100 * 200) % 3329);

    // One is the multiplicative identity.
    let one = Elem::<KyberField>::new(1);
    let times_one = hundred * one;
    assert_eq!(times_one.0, 100);

    // Zero annihilates.
    let zero = Elem::<KyberField>::new(0);
    let times_zero = hundred * zero;
    assert_eq!(times_zero.0, 0);

    // 3000 * 3000 does not fit in `u16`, so the reference value is computed in `u32`.
    let big = Elem::<KyberField>::new(3000);
    let also_big = Elem::<KyberField>::new(3000);
    let large_product = big * also_big;
    let expected = ((3000u32 * 3000u32) % 3329) as u16;
    assert_eq!(large_product.0, expected);
}
/// Addition and subtraction undo each other, and an element plus its negation is zero.
#[test]
fn elem_arithmetic_consistency() {
    let x = Elem::<KyberField>::new(1234);
    // 5678 is reduced by hand so the constructor receives a value below 3329.
    let y = Elem::<KyberField>::new(5678 % 3329);

    let add_then_sub = x + y - y;
    assert_eq!(add_then_sub.0, x.0);

    let sub_then_add = x - y + y;
    assert_eq!(sub_then_add.0, x.0);

    let cancelled = x + (-x);
    assert_eq!(cancelled.0, 0);
}
fn make_test_polynomial<F: Field>(base: F::Int) -> Polynomial<F>
where
F::Int: From<u8>,
{
let mut coeffs = [Elem::new(F::Int::from(0u8)); 256];
for (i, c) in coeffs.iter_mut().enumerate().take(10) {
*c = Elem::new(base + F::Int::from(i as u8));
}
Polynomial::new(coeffs.into())
}
/// Adding two test polynomials adds them coefficient by coefficient.
#[test]
fn polynomial_addition() {
    let lhs = make_test_polynomial::<KyberField>(100);
    let rhs = make_test_polynomial::<KyberField>(200);
    let sum = &lhs + &rhs;

    // (index, expected): index i < 10 holds (100 + i) + (200 + i); index 10 is 0 + 0.
    let expectations: [(usize, u16); 4] = [(0, 300), (1, 302), (9, 318), (10, 0)];
    for (index, expected) in expectations {
        assert_eq!(sum.0[index].0, expected);
    }
}
/// Subtracting two test polynomials subtracts them coefficient by coefficient.
#[test]
fn polynomial_subtraction() {
    let minuend = make_test_polynomial::<KyberField>(300);
    let subtrahend = make_test_polynomial::<KyberField>(100);
    let difference = &minuend - &subtrahend;

    // Both inputs grow by one per index, so the difference stays at 200.
    let first = difference.0[0].0;
    assert_eq!(first, 200);
    let second = difference.0[1].0;
    assert_eq!(second, 200);
}
/// Negating a polynomial negates each coefficient modulo 3329; negating twice restores it.
#[test]
fn polynomial_negation() {
    let poly = make_test_polynomial::<KyberField>(100);
    let negated = -&poly;
    // 3329 - 100 and 3329 - 101.
    assert_eq!(negated.0[0].0, 3229);
    assert_eq!(negated.0[1].0, 3228);

    let restored = -&negated;
    let original_first = poly.0[0].0;
    assert_eq!(restored.0[0].0, original_first);
}
/// Multiplying a polynomial by a field element scales every coefficient.
#[test]
fn polynomial_scalar_multiplication() {
    let poly = make_test_polynomial::<KyberField>(100);
    let factor = Elem::<KyberField>::new(3);
    let scaled = factor * &poly;

    // 3 * 100 and 3 * 101.
    let first = scaled.0[0].0;
    assert_eq!(first, 300);
    let second = scaled.0[1].0;
    assert_eq!(second, 303);
}
/// Builds a two-entry `Vector` of test polynomials: the first starts at `base`, the second
/// at `base + 50`.
fn make_test_vector<F: Field>(base: F::Int) -> Vector<F, U2>
where
    F::Int: From<u8>,
{
    let first = make_test_polynomial::<F>(base);
    let second_base = base + F::Int::from(50u8);
    let second = make_test_polynomial::<F>(second_base);
    let entries = [first, second];
    Vector::new(entries.into())
}
/// Adding two vectors by reference adds the corresponding polynomials.
#[test]
fn vector_addition() {
    let lhs = make_test_vector::<KyberField>(100);
    let rhs = make_test_vector::<KyberField>(200);
    let sum = &lhs + &rhs;

    // Leading coefficients: 100 + 200 in the first entry, 150 + 250 in the second.
    let first_entry = sum.0[0].0[0].0;
    assert_eq!(first_entry, 300);
    let second_entry = sum.0[1].0[0].0;
    assert_eq!(second_entry, 400);
}
/// Same as the by-reference addition test, but consumes both operands by value.
#[test]
fn vector_addition_owned() {
    let lhs = make_test_vector::<KyberField>(100);
    let rhs = make_test_vector::<KyberField>(200);
    let sum = lhs + rhs;

    // Leading coefficients: 100 + 200 in the first entry, 150 + 250 in the second.
    let first_entry = sum.0[0].0[0].0;
    assert_eq!(first_entry, 300);
    let second_entry = sum.0[1].0[0].0;
    assert_eq!(second_entry, 400);
}
/// Subtracting two vectors subtracts the corresponding polynomials.
#[test]
fn vector_subtraction() {
    let minuend = make_test_vector::<KyberField>(300);
    let subtrahend = make_test_vector::<KyberField>(100);
    let difference = &minuend - &subtrahend;

    // Leading coefficients: 300 - 100 in the first entry, 350 - 150 in the second.
    let first_entry = difference.0[0].0[0].0;
    assert_eq!(first_entry, 200);
    let second_entry = difference.0[1].0[0].0;
    assert_eq!(second_entry, 200);
}
/// Negating a vector negates its coefficients modulo 3329.
#[test]
fn vector_negation() {
    let vector = make_test_vector::<KyberField>(100);
    let negated = -&vector;
    // 3329 - 100 = 3229.
    let leading = negated.0[0].0[0].0;
    assert_eq!(leading, 3229);
}
/// Multiplying a vector by a field element scales every polynomial in it.
#[test]
fn vector_scalar_multiplication() {
    let vector = make_test_vector::<KyberField>(100);
    let factor = Elem::<KyberField>::new(2);
    let scaled = factor * &vector;

    // Leading coefficients: 2 * 100 in the first entry, 2 * 150 in the second.
    let first_entry = scaled.0[0].0[0].0;
    assert_eq!(first_entry, 200);
    let second_entry = scaled.0[1].0[0].0;
    assert_eq!(second_entry, 300);
}
fn make_test_ntt_polynomial<F: Field>(base: F::Int) -> NttPolynomial<F>
where
F::Int: From<u8>,
{
let mut coeffs = [Elem::new(F::Int::from(0u8)); 256];
for (i, c) in coeffs.iter_mut().enumerate().take(10) {
*c = Elem::new(base + F::Int::from(i as u8));
}
NttPolynomial::new(coeffs.into())
}
/// Adding two NTT polynomials adds them coefficient by coefficient.
#[test]
fn ntt_polynomial_addition() {
    let lhs = make_test_ntt_polynomial::<KyberField>(100);
    let rhs = make_test_ntt_polynomial::<KyberField>(200);
    let sum = &lhs + &rhs;

    // 100 + 200 at index 0, 101 + 201 at index 1.
    let first = sum.0[0].0;
    assert_eq!(first, 300);
    let second = sum.0[1].0;
    assert_eq!(second, 302);
}
/// Subtracting two NTT polynomials subtracts them coefficient by coefficient.
#[test]
fn ntt_polynomial_subtraction() {
    let minuend = make_test_ntt_polynomial::<KyberField>(300);
    let subtrahend = make_test_ntt_polynomial::<KyberField>(100);
    let difference = &minuend - &subtrahend;

    // 300 - 100 at index 0.
    let leading = difference.0[0].0;
    assert_eq!(leading, 200);
}
/// Negating an NTT polynomial negates its coefficients modulo 3329.
#[test]
fn ntt_polynomial_negation() {
    let poly = make_test_ntt_polynomial::<KyberField>(100);
    let negated = -&poly;
    // 3329 - 100 = 3229.
    let leading = negated.0[0].0;
    assert_eq!(leading, 3229);
}
/// Multiplying an NTT polynomial by a field element scales its coefficients.
#[test]
fn ntt_polynomial_scalar_multiplication() {
    let poly = make_test_ntt_polynomial::<KyberField>(100);
    let factor = Elem::<KyberField>::new(3);
    let scaled = factor * &poly;

    // 3 * 100 at index 0.
    let leading = scaled.0[0].0;
    assert_eq!(leading, 300);
}
/// Converts an `Array` of elements into an `NttPolynomial` and back again, checking that the
/// coefficients survive both conversions.
#[test]
fn ntt_polynomial_from_array() {
    use array::{Array, typenum::U256};

    // Coefficient i holds the value i (indices stay below 3329, so `%` is a no-op here).
    let coeffs: Array<Elem<KyberField>, U256> =
        core::array::from_fn(|i| Elem::new((i % 3329) as u16)).into();

    // `coeffs` is read again below, after being converted here.
    let poly: NttPolynomial<KyberField> = coeffs.into();
    let first = poly.0[0].0;
    assert_eq!(first, 0);
    let second = poly.0[1].0;
    assert_eq!(second, 1);

    let round_trip: Array<Elem<KyberField>, U256> = poly.into();
    assert_eq!(round_trip[0].0, coeffs[0].0);
}
/// Builds a two-entry `NttVector` of test NTT polynomials: the first starts at `base`, the
/// second at `base + 50`.
fn make_test_ntt_vector<F: Field>(base: F::Int) -> NttVector<F, U2>
where
    F::Int: From<u8>,
{
    let first = make_test_ntt_polynomial::<F>(base);
    let second_base = base + F::Int::from(50u8);
    let second = make_test_ntt_polynomial::<F>(second_base);
    let entries = [first, second];
    NttVector::new(entries.into())
}
/// Adding two NTT vectors adds the corresponding NTT polynomials.
#[test]
fn ntt_vector_addition() {
    let lhs = make_test_ntt_vector::<KyberField>(100);
    let rhs = make_test_ntt_vector::<KyberField>(200);
    let sum = &lhs + &rhs;

    // Leading coefficients: 100 + 200 in the first entry, 150 + 250 in the second.
    let first_entry = sum.0[0].0[0].0;
    assert_eq!(first_entry, 300);
    let second_entry = sum.0[1].0[0].0;
    assert_eq!(second_entry, 400);
}
/// Subtracting two NTT vectors subtracts the corresponding NTT polynomials.
#[test]
fn ntt_vector_subtraction() {
    let minuend = make_test_ntt_vector::<KyberField>(300);
    let subtrahend = make_test_ntt_vector::<KyberField>(100);
    let difference = &minuend - &subtrahend;

    // Leading coefficients: 300 - 100 in the first entry, 350 - 150 in the second.
    let first_entry = difference.0[0].0[0].0;
    assert_eq!(first_entry, 200);
    let second_entry = difference.0[1].0[0].0;
    assert_eq!(second_entry, 200);
}
/// Elements built from the same value compare equal; different values compare unequal.
#[test]
fn elem_equality() {
    let original = Elem::<KyberField>::new(100);
    let twin = Elem::<KyberField>::new(100);
    let other = Elem::<KyberField>::new(200);

    assert_eq!(original, twin);
    assert_ne!(original, other);
}
/// Polynomials built from the same base compare equal; a different base compares unequal.
#[test]
fn polynomial_equality() {
    let original = make_test_polynomial::<KyberField>(100);
    let twin = make_test_polynomial::<KyberField>(100);
    let other = make_test_polynomial::<KyberField>(200);

    assert_eq!(original, twin);
    assert_ne!(original, other);
}
/// Vectors built from the same base compare equal; a different base compares unequal.
#[test]
fn vector_equality() {
    let original = make_test_vector::<KyberField>(100);
    let twin = make_test_vector::<KyberField>(100);
    let other = make_test_vector::<KyberField>(200);

    assert_eq!(original, twin);
    assert_ne!(original, other);
}
/// NTT polynomials built from the same base compare equal; a different base compares unequal.
#[test]
fn ntt_polynomial_equality() {
    let original = make_test_ntt_polynomial::<KyberField>(100);
    let twin = make_test_ntt_polynomial::<KyberField>(100);
    let other = make_test_ntt_polynomial::<KyberField>(200);

    assert_eq!(original, twin);
    assert_ne!(original, other);
}
/// NTT vectors built from the same base compare equal; a different base compares unequal.
#[test]
fn ntt_vector_equality() {
    let original = make_test_ntt_vector::<KyberField>(100);
    let twin = make_test_ntt_vector::<KyberField>(100);
    let other = make_test_ntt_vector::<KyberField>(200);

    assert_eq!(original, twin);
    assert_ne!(original, other);
}
/// Matrices assembled from equal rows compare equal; replacing one row makes them unequal.
#[test]
fn ntt_matrix_equality() {
    let shared_row = make_test_ntt_vector::<KyberField>(100);
    let second_row = make_test_ntt_vector::<KyberField>(150);

    // Two matrices built from clones of the same pair of rows.
    let matrix: NttMatrix<KyberField, U2, U2> =
        NttMatrix::new([shared_row.clone(), second_row.clone()].into());
    let same_rows: NttMatrix<KyberField, U2, U2> =
        NttMatrix::new([shared_row.clone(), second_row.clone()].into());

    // Same first row, different second row.
    let different_row = make_test_ntt_vector::<KyberField>(200);
    let changed: NttMatrix<KyberField, U2, U2> =
        NttMatrix::new([shared_row, different_row].into());

    assert_eq!(matrix, same_rows);
    assert_ne!(matrix, changed);
}
/// Converting an `NttPolynomial` into an `Array` keeps every coefficient in place.
#[test]
fn ntt_polynomial_into_array() {
    use array::{Array, typenum::U256};

    let poly = make_test_ntt_polynomial::<KyberField>(100);
    // Convert a clone so `poly` stays available for the full comparison below.
    let as_array: Array<Elem<KyberField>, U256> = poly.clone().into();

    // (index, expected): the first ten slots hold 100 + index, everything after is zero.
    let spot_checks: [(usize, u16); 4] = [(0, 100), (1, 101), (9, 109), (10, 0)];
    for (index, expected) in spot_checks {
        assert_eq!(as_array[index].0, expected);
    }

    // Full element-wise comparison against the source polynomial.
    for index in 0..256 {
        assert_eq!(as_array[index].0, poly.0[index].0);
    }
}
/// Tests that `Zeroize::zeroize` resets every stored coefficient to zero for each algebraic
/// type. Compiled only when the `zeroize` feature is enabled.
#[cfg(feature = "zeroize")]
mod zeroize_tests {
    use super::*;
    use zeroize::Zeroize;

    /// A single element is zero after `zeroize`.
    #[test]
    fn elem_zeroize() {
        let mut a: Elem<KyberField> = Elem::new(1234);
        // Guard against a vacuous pass: the value must start out non-zero.
        assert_ne!(a.0, 0);
        a.zeroize();
        assert_eq!(a.0, 0);
    }

    /// All 256 coefficients of a polynomial are zero after `zeroize`.
    #[test]
    fn polynomial_zeroize() {
        let mut p = make_test_polynomial::<KyberField>(100);
        assert_ne!(p.0[0].0, 0);
        p.zeroize();
        for i in 0..256 {
            assert_eq!(p.0[i].0, 0, "Coefficient {i} not zeroed");
        }
    }

    /// Every coefficient of both polynomials in a vector is zero after `zeroize`.
    #[test]
    fn vector_zeroize() {
        let mut v = make_test_vector::<KyberField>(100);
        assert_ne!(v.0[0].0[0].0, 0);
        v.zeroize();
        for i in 0..2 {
            for j in 0..256 {
                assert_eq!(v.0[i].0[j].0, 0, "Element [{i}][{j}] not zeroed");
            }
        }
    }

    /// All 256 coefficients of an NTT polynomial are zero after `zeroize`.
    #[test]
    fn ntt_polynomial_zeroize() {
        let mut p = make_test_ntt_polynomial::<KyberField>(100);
        assert_ne!(p.0[0].0, 0);
        p.zeroize();
        for i in 0..256 {
            assert_eq!(p.0[i].0, 0, "Coefficient {i} not zeroed");
        }
    }

    /// Every coefficient of both NTT polynomials in a vector is zero after `zeroize`.
    #[test]
    fn ntt_vector_zeroize() {
        let mut v = make_test_ntt_vector::<KyberField>(100);
        assert_ne!(v.0[0].0[0].0, 0);
        v.zeroize();
        for i in 0..2 {
            for j in 0..256 {
                assert_eq!(v.0[i].0[j].0, 0, "Element [{i}][{j}] not zeroed");
            }
        }
    }
}
/// Tests for the `CtEq` comparison on elements and NTT polynomials. Compiled only when the
/// `ctutils` feature is enabled.
#[cfg(feature = "ctutils")]
mod ctutils_tests {
    use super::*;
    use ctutils::CtEq;

    /// `ct_eq` reports equal for identical elements and unequal for different ones.
    #[test]
    fn elem_ct_eq() {
        let original = Elem::<KyberField>::new(100);
        let twin = Elem::<KyberField>::new(100);
        let other = Elem::<KyberField>::new(200);

        let same: bool = original.ct_eq(&twin).into();
        assert!(same);
        let differs: bool = original.ct_eq(&other).into();
        assert!(!differs);
    }

    /// `ct_eq` reports equal for identical NTT polynomials and unequal for different ones.
    #[test]
    fn ntt_polynomial_ct_eq() {
        let original = make_test_ntt_polynomial::<KyberField>(100);
        let twin = make_test_ntt_polynomial::<KyberField>(100);
        let other = make_test_ntt_polynomial::<KyberField>(200);

        let same: bool = original.ct_eq(&twin).into();
        assert!(same);
        let differs: bool = original.ct_eq(&other).into();
        assert!(!differs);
    }
}