use crate::ntt32::Ntt32Context;
/// Parameter sets of the ML-DSA signature scheme that this module can build
/// NTT contexts for.
///
/// All variants currently report the same ring parameters (`n = 256`,
/// `q = 8380417`); they differ in [`k`](Self::k),
/// [`security_level`](Self::security_level) and [`name`](Self::name).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PqScheme {
    /// The "ML-DSA-44" parameter set.
    MlDsa44,
    /// The "ML-DSA-65" parameter set.
    MlDsa65,
    /// The "ML-DSA-87" parameter set.
    MlDsa87,
}

impl PqScheme {
    /// Transform length `n` for this scheme: 256 for every variant.
    ///
    /// The match is deliberately exhaustive (no `_` arm) so that adding a
    /// variant forces an explicit choice here.
    #[inline]
    pub const fn n(self) -> usize {
        match self {
            PqScheme::MlDsa44 | PqScheme::MlDsa65 | PqScheme::MlDsa87 => 256,
        }
    }

    /// Modulus `q` for this scheme: 8380417 for every variant.
    #[inline]
    pub const fn q(self) -> u32 {
        match self {
            PqScheme::MlDsa44 | PqScheme::MlDsa65 | PqScheme::MlDsa87 => 8_380_417,
        }
    }

    /// The scheme's `k` parameter: 4, 6 or 8.
    ///
    /// NOTE(review): presumably the module dimension `k` of the FIPS 204
    /// parameter table — confirm against the specification.
    #[inline]
    pub const fn k(self) -> usize {
        match self {
            PqScheme::MlDsa44 => 4,
            PqScheme::MlDsa65 => 6,
            PqScheme::MlDsa87 => 8,
        }
    }

    /// Numeric security level associated with the scheme: 2, 3 or 5.
    ///
    /// NOTE(review): presumably the NIST security category — confirm.
    #[inline]
    pub const fn security_level(self) -> u8 {
        match self {
            PqScheme::MlDsa44 => 2,
            PqScheme::MlDsa65 => 3,
            PqScheme::MlDsa87 => 5,
        }
    }

    /// Human-readable scheme name, e.g. `"ML-DSA-65"`.
    #[inline]
    pub const fn name(self) -> &'static str {
        match self {
            PqScheme::MlDsa44 => "ML-DSA-44",
            PqScheme::MlDsa65 => "ML-DSA-65",
            PqScheme::MlDsa87 => "ML-DSA-87",
        }
    }

    /// Label of the standard the scheme belongs to: `"FIPS 204"` for every
    /// variant.
    #[inline]
    pub const fn fips(self) -> &'static str {
        match self {
            PqScheme::MlDsa44 | PqScheme::MlDsa65 | PqScheme::MlDsa87 => "FIPS 204",
        }
    }
}
/// NTT helper bound to a single [`PqScheme`].
///
/// Owns an [`Ntt32Context`] created from the scheme's `n` and `q` and
/// remembers which scheme it was built for. Every arithmetic method simply
/// delegates to that context.
pub struct PqNtt {
    /// Underlying transform context that performs the actual work.
    ctx: Ntt32Context,
    /// Scheme this instance was constructed for.
    scheme: PqScheme,
}

impl PqNtt {
    /// Creates the helper for `scheme`, building the context from
    /// `scheme.n()` and `scheme.q()`.
    #[inline]
    pub fn new(scheme: PqScheme) -> Self {
        Self {
            ctx: Ntt32Context::new(scheme.n(), scheme.q()),
            scheme,
        }
    }

    /// Scheme this instance was constructed for.
    #[inline]
    pub fn scheme(&self) -> PqScheme {
        self.scheme
    }

    /// Transform length as stored in the underlying context.
    #[inline]
    pub fn n(&self) -> usize {
        self.context().n
    }

    /// Modulus as stored in the underlying context.
    #[inline]
    pub fn q(&self) -> u32 {
        self.context().q
    }

    /// Security level reported by the scheme (see
    /// [`PqScheme::security_level`]).
    #[inline]
    pub fn security_level(&self) -> u8 {
        self.scheme().security_level()
    }

    /// Borrows the underlying [`Ntt32Context`].
    #[inline]
    pub fn context(&self) -> &Ntt32Context {
        &self.ctx
    }

    /// Runs the context's forward transform on `data` in place.
    ///
    /// Length and range requirements on `data` are whatever
    /// `Ntt32Context::forward` imposes; this wrapper adds no checks.
    #[inline]
    pub fn forward(&self, data: &mut [u32]) {
        self.context().forward(data);
    }

    /// Runs the context's inverse transform on `data` in place.
    ///
    /// The tests in this file check that `forward` followed by `inverse`
    /// restores the original 256-element input.
    #[inline]
    pub fn inverse(&self, data: &mut [u32]) {
        self.context().inverse(data);
    }

    /// Multiplies `a` by `b` through the context's `negacyclic_mul` and
    /// returns the freshly allocated product.
    #[inline]
    pub fn multiply(&self, a: &[u32], b: &[u32]) -> alloc::vec::Vec<u32> {
        self.context().negacyclic_mul(a, b)
    }

    /// Multiplies `a` by `b` through the context's `negacyclic_mul_into`,
    /// writing the product to `result`.
    ///
    /// NOTE(review): `a` and `b` are taken mutably, so the context presumably
    /// overwrites them (e.g. transforms them in place) — confirm against
    /// `Ntt32Context::negacyclic_mul_into` before reusing the inputs.
    #[inline]
    pub fn multiply_into(&self, a: &mut [u32], b: &mut [u32], result: &mut [u32]) {
        self.context().negacyclic_mul_into(a, b, result);
    }
}
/// Summarises the instance by its scheme parameters (name, `n`, `q`,
/// security level, standard label) rather than printing the context itself.
impl core::fmt::Debug for PqNtt {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let scheme = self.scheme;
        let mut out = f.debug_struct("PqNtt");
        out.field("scheme", &scheme.name());
        // `n` and `q` are read back from the context, not from the scheme.
        out.field("n", &self.n());
        out.field("q", &self.q());
        out.field("security_level", &self.security_level());
        out.field("fips", &scheme.fips());
        out.finish()
    }
}
/// Unit tests for [`PqScheme`] metadata and the [`PqNtt`] wrapper.
#[cfg(test)]
mod tests {
    use super::*;

    /// Forward must change a non-trivial input, and inverse must undo it.
    #[test]
    fn test_mldsa_44_roundtrip() {
        let ntt = PqNtt::new(PqScheme::MlDsa44);
        assert_eq!(ntt.q(), 8380417);
        assert_eq!(ntt.n(), 256);
        assert_eq!(ntt.security_level(), 2);

        let mut data: alloc::vec::Vec<u32> = (0..256).map(|i| i * 1000 % 8380417).collect();
        let original = data.clone();

        ntt.forward(&mut data);
        assert_ne!(data, original, "NTT forward did nothing");

        ntt.inverse(&mut data);
        assert_eq!(data, original);
    }

    /// Round-trip with every coefficient at the largest reduced value, `q - 1`.
    #[test]
    fn test_mldsa_65_roundtrip() {
        let ntt = PqNtt::new(PqScheme::MlDsa65);
        assert_eq!(ntt.security_level(), 3);

        let mut data = alloc::vec![8380416u32; 256];
        let original = data.clone();

        ntt.forward(&mut data);
        ntt.inverse(&mut data);
        assert_eq!(data, original);
    }

    /// Round-trip of the constant polynomial 1.
    #[test]
    fn test_mldsa_87_roundtrip() {
        let ntt = PqNtt::new(PqScheme::MlDsa87);
        assert_eq!(ntt.security_level(), 5);

        let mut data = alloc::vec![0u32; 256];
        data[0] = 1;
        let original = data.clone();

        ntt.forward(&mut data);
        ntt.inverse(&mut data);
        assert_eq!(data, original);
    }

    /// (1 + x)^2 must come out as 1 + 2x + x^2 with all other coefficients zero.
    #[test]
    fn test_multiply() {
        let ntt = PqNtt::new(PqScheme::MlDsa44);

        let mut a = alloc::vec![0u32; 256];
        a[0] = 1;
        a[1] = 1;

        let result = ntt.multiply(&a, &a);
        assert_eq!(result[0], 1);
        assert_eq!(result[1], 2);
        assert_eq!(result[2], 1);

        // Slicing keeps the original bounds behaviour: it panics if the
        // product has fewer than 256 coefficients.
        for (offset, &coeff) in result[3..256].iter().enumerate() {
            let i = offset + 3;
            assert_eq!(coeff, 0, "unexpected non-zero at index {i}");
        }
    }

    /// Names, standard label and `k` for each scheme.
    #[test]
    fn test_scheme_metadata() {
        assert_eq!(PqScheme::MlDsa44.name(), "ML-DSA-44");
        assert_eq!(PqScheme::MlDsa44.fips(), "FIPS 204");
        assert_eq!(PqScheme::MlDsa44.k(), 4);

        assert_eq!(PqScheme::MlDsa65.name(), "ML-DSA-65");
        assert_eq!(PqScheme::MlDsa65.k(), 6);

        assert_eq!(PqScheme::MlDsa87.name(), "ML-DSA-87");
        assert_eq!(PqScheme::MlDsa87.k(), 8);
    }

    /// Every forward-transform output must lie in `[0, q)`.
    #[test]
    fn test_output_fully_reduced() {
        let ntt = PqNtt::new(PqScheme::MlDsa65);
        let mut data: alloc::vec::Vec<u32> =
            (0..ntt.n()).map(|i| (i as u32 * 7 + 13) % ntt.q()).collect();

        ntt.forward(&mut data);

        assert!(
            data.iter().all(|&x| x < ntt.q()),
            "Output not fully reduced for ML-DSA-65"
        );
    }

    /// The `Debug` output must mention the scheme name and the standard label.
    #[test]
    fn test_debug_display() {
        let ntt = PqNtt::new(PqScheme::MlDsa65);
        let debug = alloc::format!("{:?}", ntt);
        assert!(debug.contains("ML-DSA-65"));
        assert!(debug.contains("FIPS 204"));
    }
}