//! Public facade for ML-DSA building blocks.
//!
//! Apart from `MlDsaLevel` and its three level constants, every item here
//! forwards to the parent module or to its `field`, `reduce`, `sample` and
//! `encode` submodules, giving callers one flat namespace.
// `unpack_eta2` / `unpack_eta4` below keep the `Result<Poly, ()>` signature of
// the `encode` functions they forward to; `clippy::result_unit_err` would
// otherwise flag each public function returning `Result<_, ()>`.
#![allow(clippy::result_unit_err)]
use alloc::vec::Vec;
pub use super::Params;
pub use super::field::Poly;
/// Polynomial degree: the number of coefficients in a `Poly`. Polynomial
/// products wrap around modulo `X^N + 1`.
pub const N: usize = super::field::N;
/// Coefficient modulus; coefficients are handled as residues in `[0, Q)`.
pub const Q: u32 = super::field::Q;
/// Forwarded from `field::D`.
/// NOTE(review): presumably FIPS 204's `d` (low bits dropped from `t` by
/// `power2_round`) — confirm in `field`.
pub const D: u32 = super::field::D;
/// One ML-DSA security level: its scalar `Params` plus the matrix dimensions.
///
/// `k` and `l` use FIPS 204 naming (the public matrix `A` is `k × l`); the
/// three constants below carry exactly the `(k, l)` pairs of FIPS 204 Table 1.
#[derive(Copy, Clone)]
pub struct MlDsaLevel {
    /// Scalar parameters for this level (one of `super::P44`/`P65`/`P87`).
    pub params: Params,
    /// FIPS 204 `k`: number of rows of the public matrix.
    pub k: usize,
    /// FIPS 204 `l`: number of columns of the public matrix.
    pub l: usize,
}

/// ML-DSA-44 level: `(k, l) = (4, 4)`.
pub const ML_DSA_44: MlDsaLevel = MlDsaLevel {
    k: 4,
    l: 4,
    params: super::P44,
};

/// ML-DSA-65 level: `(k, l) = (6, 5)`.
pub const ML_DSA_65: MlDsaLevel = MlDsaLevel {
    k: 6,
    l: 5,
    params: super::P65,
};

/// ML-DSA-87 level: `(k, l) = (8, 7)`.
pub const ML_DSA_87: MlDsaLevel = MlDsaLevel {
    k: 8,
    l: 7,
    params: super::P87,
};
// These scalar wrappers are non-generic and call another function, so without
// `#[inline]` downstream crates (absent LTO) cannot inline them and every field
// operation would pay an extra call.

/// Forwards to `field::reduce_once`.
///
/// NOTE(review): the name suggests a single conditional subtraction of `Q`;
/// confirm the accepted input range in `field` before relying on it.
#[inline]
pub fn reduce_once(a: u32) -> u32 {
    super::field::reduce_once(a)
}

/// Forwards to `field::add`.
///
/// NOTE(review): presumably `(a + b) mod Q` for inputs in `[0, Q)` — confirm in `field`.
#[inline]
pub fn add(a: u32, b: u32) -> u32 {
    super::field::add(a, b)
}

/// Forwards to `field::sub`.
///
/// NOTE(review): presumably `(a - b) mod Q` for inputs in `[0, Q)` — confirm in `field`.
#[inline]
pub fn sub(a: u32, b: u32) -> u32 {
    super::field::sub(a, b)
}

/// Forwards to `field::mul`.
///
/// NOTE(review): confirm whether this is a plain `(a * b) mod Q` or a
/// Montgomery product; the name alone does not say.
#[inline]
pub fn mul(a: u32, b: u32) -> u32 {
    super::field::mul(a, b)
}

/// Multiplies two polynomials that are already in NTT form.
///
/// Applying `inv_ntt` to the result yields the negacyclic product of the
/// original polynomials, i.e. their product modulo `X^N + 1` with coefficients
/// reduced modulo `Q`.
#[inline]
pub fn ntt_mul(a: &Poly, b: &Poly) -> Poly {
    super::field::ntt_mul(a, b)
}

/// Forwards to `field::zeta`.
///
/// NOTE(review): presumably the `i`-th precomputed NTT twiddle factor; confirm
/// the indexing order and representation in `field`.
#[inline]
pub fn zeta(i: usize) -> u32 {
    super::field::zeta(i)
}
/// Forwarded from `reduce::GAMMA2_32`.
///
/// NOTE(review): presumably `(Q - 1) / 32`, the FIPS 204 `gamma2` of
/// ML-DSA-65/87 — confirm in `reduce`.
pub const GAMMA2_32: u32 = super::reduce::GAMMA2_32;

/// Forwarded from `reduce::GAMMA2_88`.
///
/// NOTE(review): presumably `(Q - 1) / 88`, the FIPS 204 `gamma2` of
/// ML-DSA-44 — confirm in `reduce`.
pub const GAMMA2_88: u32 = super::reduce::GAMMA2_88;

// As with the field wrappers, `#[inline]` lets other crates inline these
// non-generic one-line delegations without LTO.

/// Forwards to `reduce::power2_round`.
///
/// NOTE(review): presumably splits `r` into high and low parts around `2^D`;
/// confirm which tuple element is which and how the low part is represented
/// (it is returned as `u32`).
#[inline]
pub fn power2_round(r: u32) -> (u32, u32) {
    super::reduce::power2_round(r)
}

/// Forwards to `reduce::high_bits` for the given `gamma2`.
#[inline]
pub fn high_bits(r: u32, gamma2: u32) -> u32 {
    super::reduce::high_bits(r, gamma2)
}

/// Forwards to `reduce::decompose` for the given `gamma2`.
///
/// Returns an unsigned part and a signed part.
/// NOTE(review): presumably `(r1, r0)` as in FIPS 204 `Decompose` — confirm in `reduce`.
#[inline]
pub fn decompose(r: u32, gamma2: u32) -> (u32, i32) {
    super::reduce::decompose(r, gamma2)
}

/// Forwards to `reduce::make_hint` for the given `gamma2`.
///
/// NOTE(review): presumably returns 0 or 1 — confirm in `reduce`.
#[inline]
pub fn make_hint(z: u32, r: u32, gamma2: u32) -> u32 {
    super::reduce::make_hint(z, r, gamma2)
}

/// Forwards to `reduce::use_hint` for the given `gamma2`.
#[inline]
pub fn use_hint(hint: u32, r: u32, gamma2: u32) -> u32 {
    super::reduce::use_hint(hint, r, gamma2)
}

/// Forwards to `reduce::inf_norm`.
///
/// NOTE(review): presumably the absolute value of the centered representative
/// of `a` mod `Q` — confirm in `reduce`.
#[inline]
pub fn inf_norm(a: u32) -> u32 {
    super::reduce::inf_norm(a)
}
/// Forwards to `sample::sample_ntt_poly`: expands `rho` together with the two
/// index bytes `s` and `r` into a polynomial.
///
/// Deterministic in `(rho, s, r)`; `s` and `r` are separate domain-separation
/// inputs, so swapping them is expected to change the output.
/// NOTE(review): the name suggests the result is already in NTT form (FIPS 204
/// RejNTTPoly / ExpandA) — confirm in `sample`.
pub fn sample_ntt_poly(rho: &[u8], s: u8, r: u8) -> Poly {
    super::sample::sample_ntt_poly(rho, s, r)
}
/// Forwards to `sample::sample_bounded_poly`.
///
/// NOTE(review): presumably samples coefficients in `[-eta, eta]` (stored mod
/// `Q`) from `seed` and the 16-bit `nonce` — confirm in `sample`.
pub fn sample_bounded_poly(seed: &[u8], eta: u32, nonce: u16) -> Poly {
    super::sample::sample_bounded_poly(seed, eta, nonce)
}
/// Forwards to `sample::sample_challenge`.
///
/// NOTE(review): presumably FIPS 204 SampleInBall (`tau` nonzero coefficients
/// derived from `seed`) — confirm in `sample`.
pub fn sample_challenge(seed: &[u8], tau: usize) -> Poly {
    super::sample::sample_challenge(seed, tau)
}
/// Forwards to `sample::expand_mask`.
///
/// NOTE(review): presumably derives one mask polynomial whose coefficient range
/// is set by `gamma1_bits` — confirm the exact bounds in `sample`.
pub fn expand_mask(seed: &[u8], gamma1_bits: u32) -> Poly {
    super::sample::expand_mask(seed, gamma1_bits)
}
/// Packs a polynomial with 10-bit coefficients into `N * 10 / 8` bytes.
pub fn pack_t1(f: &Poly) -> Vec<u8> {
    super::encode::pack_t1(f)
}
/// Inverse of `pack_t1`: decodes `N * 10 / 8` bytes back into 10-bit coefficients.
pub fn unpack_t1(b: &[u8]) -> Poly {
    super::encode::unpack_t1(b)
}
/// Forwards to `encode::pack_t0`.
/// NOTE(review): presumably the FIPS 204 `t0` encoding — confirm in `encode`.
pub fn pack_t0(f: &Poly) -> Vec<u8> {
    super::encode::pack_t0(f)
}
/// Inverse of `pack_t0`; forwards to `encode::unpack_t0`.
pub fn unpack_t0(b: &[u8]) -> Poly {
    super::encode::unpack_t0(b)
}
/// Forwards to `encode::pack_eta2`.
/// NOTE(review): presumably for secret coefficients with eta = 2 — confirm in `encode`.
pub fn pack_eta2(f: &Poly) -> Vec<u8> {
    super::encode::pack_eta2(f)
}
/// Inverse of `pack_eta2`; returns `Err(())` when `encode::unpack_eta2` rejects
/// the input.
/// NOTE(review): confirm exactly which inputs are rejected (presumably
/// out-of-range encoded values).
pub fn unpack_eta2(b: &[u8]) -> Result<Poly, ()> {
    super::encode::unpack_eta2(b)
}
/// Forwards to `encode::pack_eta4`.
/// NOTE(review): presumably for secret coefficients with eta = 4 — confirm in `encode`.
pub fn pack_eta4(f: &Poly) -> Vec<u8> {
    super::encode::pack_eta4(f)
}
/// Inverse of `pack_eta4`; returns `Err(())` when `encode::unpack_eta4` rejects
/// the input.
pub fn unpack_eta4(b: &[u8]) -> Result<Poly, ()> {
    super::encode::unpack_eta4(b)
}
/// Forwards to `encode::pack_z17`.
/// NOTE(review): presumably the `z` encoding for gamma1 = 2^17 — confirm in `encode`.
pub fn pack_z17(f: &Poly) -> Vec<u8> {
    super::encode::pack_z17(f)
}
/// Forwards to `encode::pack_z19`.
/// NOTE(review): presumably the `z` encoding for gamma1 = 2^19 — confirm in `encode`.
pub fn pack_z19(f: &Poly) -> Vec<u8> {
    super::encode::pack_z19(f)
}
/// Inverse of `pack_z17`; forwards to `encode::unpack_z17`.
pub fn unpack_z17(b: &[u8]) -> Poly {
    super::encode::unpack_z17(b)
}
/// Inverse of `pack_z19`; forwards to `encode::unpack_z19`.
pub fn unpack_z19(b: &[u8]) -> Poly {
    super::encode::unpack_z19(b)
}
/// Forwards to `encode::pack_w1_4`.
/// NOTE(review): presumably 4-bit `w1` coefficients — confirm in `encode`.
pub fn pack_w1_4(f: &Poly) -> Vec<u8> {
    super::encode::pack_w1_4(f)
}
/// Forwards to `encode::pack_w1_6`.
/// NOTE(review): presumably 6-bit `w1` coefficients — confirm in `encode`.
pub fn pack_w1_6(f: &Poly) -> Vec<u8> {
    super::encode::pack_w1_6(f)
}
/// Encodes the hint polynomials `hints` using the capacity parameter `omega`.
pub fn pack_hint(hints: &[Poly], omega: usize) -> Vec<u8> {
    super::encode::pack_hint(hints, omega)
}
/// Decodes `b` into `hints` in place, using the capacity parameter `omega`.
///
/// Unlike the `unpack_eta*` functions this reports its outcome as a `bool`.
/// NOTE(review): presumably `true` on success / well-formed input; confirm in
/// `encode`, including what `hints` holds after a failure.
pub fn unpack_hint(b: &[u8], hints: &mut [Poly], omega: usize) -> bool {
    super::encode::unpack_hint(b, hints, omega)
}
/// Level-generic eta packing; forwards to the parent module's `pack_eta`.
/// NOTE(review): presumably selects the eta = 2 or eta = 4 encoding from `p`.
pub fn pack_eta(f: &Poly, p: &Params) -> Vec<u8> {
    super::pack_eta(f, p)
}
/// Level-generic inverse of `pack_eta`.
///
/// Note the error type is the parent module's `Error`, unlike the `Result<_, ()>`
/// returned by `unpack_eta2` / `unpack_eta4`.
pub fn unpack_eta(b: &[u8], p: &Params) -> Result<Poly, super::Error> {
    super::unpack_eta(b, p)
}
/// Level-generic `z` packing for the parameter set `p`; round-trips through
/// `unpack_z` with the same `p`.
pub fn pack_z(f: &Poly, p: &Params) -> Vec<u8> {
    super::pack_z(f, p)
}
/// Level-generic inverse of `pack_z` for the parameter set `p`.
pub fn unpack_z(b: &[u8], p: &Params) -> Poly {
    super::unpack_z(b, p)
}
/// Level-generic `w1` packing; forwards to the parent module's `pack_w1`.
pub fn pack_w1(f: &Poly, p: &Params) -> Vec<u8> {
    super::pack_w1(f, p)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic test polynomial: coefficient `i` is
    /// `(i * factor + offset) mod 2^32`, then reduced into `[0, Q)`.
    fn hashed_poly(factor: u32, offset: u32) -> Poly {
        let mut p = Poly::zero();
        for (i, c) in p.c.iter_mut().enumerate() {
            *c = (i as u32).wrapping_mul(factor).wrapping_add(offset) % Q;
        }
        p
    }

    /// First fixed test polynomial (multiplicative hash of the index).
    fn sample_poly() -> Poly {
        hashed_poly(2_654_435_761, 0)
    }

    /// Second fixed test polynomial, different from `sample_poly`.
    fn other_poly() -> Poly {
        hashed_poly(40_503, 7)
    }

    /// Reference schoolbook product in `Z_Q[X] / (X^N + 1)`: terms whose degree
    /// reaches `N` wrap around with a sign flip. The `i128` accumulators keep the
    /// unreduced sums exact until the final reduction into `[0, Q)`.
    fn negacyclic_mul(a: &Poly, b: &Poly) -> Poly {
        let mut acc = [0i128; N];
        for (i, &ai) in a.c.iter().enumerate() {
            for (j, &bj) in b.c.iter().enumerate() {
                let prod = (ai as i128) * (bj as i128);
                let k = i + j;
                if k < N {
                    acc[k] += prod;
                } else {
                    acc[k - N] -= prod;
                }
            }
        }
        let mut r = Poly::zero();
        for (dst, &a) in r.c.iter_mut().zip(acc.iter()) {
            *dst = a.rem_euclid(Q as i128) as u32;
        }
        r
    }

    #[test]
    fn ntt_multiply_matches_schoolbook() {
        let a = sample_poly();
        let b = other_poly();
        let want = negacyclic_mul(&a, &b);
        let mut a_ntt = a;
        a_ntt.ntt();
        let mut b_ntt = b;
        b_ntt.ntt();
        let mut got = ntt_mul(&a_ntt, &b_ntt);
        got.inv_ntt();
        assert_eq!(got, want, "NTT product != schoolbook negacyclic product");
    }

    #[test]
    fn ntt_mul_commutes() {
        let mut a = sample_poly();
        a.ntt();
        let mut b = other_poly();
        b.ntt();
        assert_eq!(
            ntt_mul(&a, &b),
            ntt_mul(&b, &a),
            "ntt_mul is not commutative"
        );
    }

    #[test]
    fn pack_unpack_t1_roundtrip() {
        // Scatter the values over the whole 10-bit range. Plain `i & 0x3ff`
        // never exceeds N - 1, which for N = 256 leaves bits 8 and 9 of every
        // coefficient clear, so a bug in packing those bits went unnoticed.
        // 0x9E37 is odd, so i -> i * 0x9E37 mod 2^10 is a bijection on 0..1024.
        let mut p = Poly::zero();
        for (i, c) in p.c.iter_mut().enumerate() {
            *c = (i as u32).wrapping_mul(0x9E37) & 0x3ff;
        }
        // Pin the maximum value too (index 0 already maps to 0).
        p.c[N - 1] = 0x3ff;
        let bytes = pack_t1(&p);
        assert_eq!(bytes.len(), N * 10 / 8);
        assert_eq!(unpack_t1(&bytes), p);
    }

    /// Round-trips `z` coefficients from both ends of the range
    /// `[-(gamma1 - 1), gamma1]` (stored as residues mod `Q`) through
    /// `pack_z` / `unpack_z` for `level`.
    fn assert_z_roundtrip(level: &MlDsaLevel) {
        let params = &level.params;
        let gamma1 = params.gamma1;
        let mut p = Poly::zero();
        for (i, c) in p.c.iter_mut().enumerate() {
            let j = (i as u32) % 11;
            *c = if i % 2 == 0 {
                // Near the positive bound: gamma1 - j.
                sub(gamma1, j)
            } else {
                // Near the negative bound: -(gamma1 - 1 - j) mod Q.
                sub(1 + j, gamma1)
            };
        }
        let bytes = pack_z(&p, params);
        assert_eq!(
            unpack_z(&bytes, params),
            p,
            "z round-trip failed for gamma1 = {}",
            gamma1
        );
    }

    #[test]
    fn pack_unpack_z_roundtrip() {
        for level in [&ML_DSA_44, &ML_DSA_65, &ML_DSA_87] {
            assert_z_roundtrip(level);
        }
    }

    #[test]
    fn sample_ntt_poly_deterministic() {
        let rho = [7u8; 32];
        let a = sample_ntt_poly(&rho, 1, 2);
        let b = sample_ntt_poly(&rho, 1, 2);
        assert_eq!(a, b, "sample_ntt_poly is not deterministic");
        let c = sample_ntt_poly(&rho, 2, 1);
        assert_ne!(a, c, "domain separation had no effect");
    }

    #[test]
    fn level_dimensions() {
        assert_eq!((ML_DSA_44.k, ML_DSA_44.l), (4, 4));
        assert_eq!((ML_DSA_65.k, ML_DSA_65.l), (6, 5));
        assert_eq!((ML_DSA_87.k, ML_DSA_87.l), (8, 7));
    }
}