#![allow(dead_code)]
pub mod keygen;
pub mod ntt;
#[cfg(all(feature = "simd", target_arch = "x86_64"))]
#[allow(unsafe_code)]
pub mod ntt_avx2;
pub mod params;
pub mod poly;
#[cfg(all(feature = "simd", target_arch = "x86_64"))]
#[allow(unsafe_code)]
pub mod poly_simd;
pub mod rounding;
pub mod sampling;
pub mod sign;
pub mod verify;
#[cfg(test)]
mod tests;
pub use params::{MlDsaParams, Params44, Params65, Params87};
pub use poly::{Poly, PolyMatrix, PolyVecK};
use core::marker::PhantomData;
/// Signing (secret) key for the parameter set `P`.
///
/// `bytes` holds the complete encoded key. When the key is built with
/// `from_bytes`, `rho`, `key` and `tr` are copies of the first 128 bytes of
/// that encoding. The type has a hand-written `Debug` impl that prints
/// `[REDACTED]` instead of key material.
#[derive(Clone)]
pub struct MlDsaSigningKey<P: MlDsaParams> {
// Bytes 0..32 of the encoding (presumably the FIPS 204 public seed rho — confirm).
rho: [u8; 32],
// Bytes 32..64 of the encoding (presumably the private seed K — confirm).
key: [u8; 32],
// Bytes 64..128 of the encoding (presumably the public-key hash tr — confirm).
tr: [u8; 64],
// Full encoded key; its length is `P::SK_SIZE` when produced by `from_bytes`.
bytes: alloc::vec::Vec<u8>,
// Ties the key to its parameter set without storing a `P` value.
_params: PhantomData<P>,
}
impl<P: MlDsaParams> MlDsaSigningKey<P> {
    /// Exact length, in bytes, of an encoded signing key (`P::SK_SIZE`).
    pub const SIZE: usize = P::SK_SIZE;

    /// Wraps an encoded signing key.
    ///
    /// Returns `None` unless `bytes.len() == Self::SIZE`. Only the length is
    /// checked; the contents are copied without further validation.
    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
        /// Copies the `N` bytes of `src` starting at `offset` into an array.
        fn array_at<const N: usize>(src: &[u8], offset: usize) -> [u8; N] {
            let mut out = [0u8; N];
            out.copy_from_slice(&src[offset..offset + N]);
            out
        }

        if bytes.len() == Self::SIZE {
            // The three fixed-size leading fields sit back to back at the
            // start of the encoding: 32 + 32 + 64 bytes.
            Some(Self {
                rho: array_at(bytes, 0),
                key: array_at(bytes, 32),
                tr: array_at(bytes, 64),
                bytes: bytes.to_vec(),
                _params: PhantomData,
            })
        } else {
            None
        }
    }

    /// Returns a fresh copy of the encoded key.
    pub fn to_bytes(&self) -> alloc::vec::Vec<u8> {
        self.bytes.as_slice().to_vec()
    }
}
impl<P: MlDsaParams> core::fmt::Debug for MlDsaSigningKey<P> {
    /// Prints only the algorithm name; key material is replaced by `[REDACTED]`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let algorithm = P::ALGORITHM;
        f.write_fmt(format_args!("MlDsaSigningKey<{}>([REDACTED])", algorithm))
    }
}
/// Verifying (public) key for the parameter set `P`.
///
/// `bytes` holds the complete encoded key. When the key is built with
/// `from_bytes`, `rho` is a copy of the first 32 bytes of that encoding.
#[derive(Clone, PartialEq, Eq)]
pub struct MlDsaVerifyingKey<P: MlDsaParams> {
// Bytes 0..32 of the encoding (presumably the FIPS 204 public seed rho — confirm).
rho: [u8; 32],
// Full encoded key; its length is `P::PK_SIZE` when produced by `from_bytes`.
bytes: alloc::vec::Vec<u8>,
// Ties the key to its parameter set without storing a `P` value.
_params: PhantomData<P>,
}
impl<P: MlDsaParams> MlDsaVerifyingKey<P> {
    /// Exact length, in bytes, of an encoded verifying key (`P::PK_SIZE`).
    pub const SIZE: usize = P::PK_SIZE;

    /// Wraps an encoded verifying key.
    ///
    /// Returns `None` unless `bytes.len() == Self::SIZE`. Only the length is
    /// checked; the contents are copied without further validation.
    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
        if bytes.len() == Self::SIZE {
            // The encoding starts with the 32-byte `rho` field.
            let mut seed = [0u8; 32];
            seed.copy_from_slice(&bytes[..32]);
            Some(Self {
                rho: seed,
                bytes: bytes.to_vec(),
                _params: PhantomData,
            })
        } else {
            None
        }
    }

    /// Returns a fresh copy of the encoded key.
    pub fn to_bytes(&self) -> alloc::vec::Vec<u8> {
        self.bytes.as_slice().to_vec()
    }
}
impl<P: MlDsaParams> core::fmt::Debug for MlDsaVerifyingKey<P> {
    /// Prints the algorithm name and the encoded length, not the key bytes.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let encoded_len = self.bytes.len();
        f.write_fmt(format_args!(
            "MlDsaVerifyingKey<{}>({} bytes)",
            P::ALGORITHM,
            encoded_len
        ))
    }
}
/// Encoded signature for the parameter set `P`.
///
/// The signature is stored as an opaque byte string; equality compares those
/// bytes.
#[derive(Clone, PartialEq, Eq)]
pub struct MlDsaSignature<P: MlDsaParams> {
// Encoded signature; its length is `P::SIG_SIZE` when produced by `from_bytes`.
bytes: alloc::vec::Vec<u8>,
// Ties the signature to its parameter set without storing a `P` value.
_params: PhantomData<P>,
}
impl<P: MlDsaParams> MlDsaSignature<P> {
    /// Exact length, in bytes, of an encoded signature (`P::SIG_SIZE`).
    pub const SIZE: usize = P::SIG_SIZE;

    /// Wraps an encoded signature.
    ///
    /// Returns `None` unless `bytes.len() == Self::SIZE`. Only the length is
    /// checked; the contents are copied without further validation.
    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
        let length_matches = bytes.len() == Self::SIZE;
        length_matches.then(|| Self {
            bytes: bytes.to_vec(),
            _params: PhantomData,
        })
    }

    /// Returns a fresh copy of the encoded signature.
    pub fn to_bytes(&self) -> alloc::vec::Vec<u8> {
        self.bytes.as_slice().to_vec()
    }
}
impl<P: MlDsaParams> core::fmt::Debug for MlDsaSignature<P> {
    /// Prints the algorithm name and the encoded length, not the signature bytes.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let encoded_len = self.bytes.len();
        f.write_fmt(format_args!(
            "MlDsaSignature<{}>({} bytes)",
            P::ALGORITHM,
            encoded_len
        ))
    }
}
/// Signature-scheme operations for the parameter set `P`.
///
/// All three operations are associated functions without a `self` receiver,
/// so an implementing type acts as a selector for an implementation rather
/// than carrying state.
pub trait MlDsa<P: MlDsaParams> {
/// Creates a new signing key together with its matching verifying key.
///
/// The signature offers no way to report failure to the caller.
fn generate_keypair() -> (MlDsaSigningKey<P>, MlDsaVerifyingKey<P>);
/// Produces a signature over `message` using `sk`.
///
/// The signature offers no way to report failure to the caller.
fn sign(sk: &MlDsaSigningKey<P>, message: &[u8]) -> MlDsaSignature<P>;
/// Checks `signature` over `message` against `vk`.
///
/// Returns `Ok(())` when the signature is accepted and an [`MlDsaError`]
/// otherwise.
fn verify(
vk: &MlDsaVerifyingKey<P>,
message: &[u8],
signature: &MlDsaSignature<P>,
) -> Result<(), MlDsaError>;
}
/// Failure categories reported by the ML-DSA API.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MlDsaError {
    /// A signature did not verify.
    VerificationFailed,
    /// A key has an invalid format.
    InvalidKey,
    /// A signature has an invalid format.
    InvalidSignature,
    /// An unspecified internal failure.
    InternalError,
}

impl core::fmt::Display for MlDsaError {
    /// Writes a short, fixed, lower-case description of the error.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let description = match self {
            Self::VerificationFailed => "signature verification failed",
            Self::InvalidKey => "invalid key format",
            Self::InvalidSignature => "invalid signature format",
            Self::InternalError => "internal error",
        };
        f.write_str(description)
    }
}
/// [`MlDsaNative`] instantiated with [`Params44`] (presumably the ML-DSA-44 parameter set — confirm in `params`).
pub type MlDsa44 = MlDsaNative<Params44>;
/// [`MlDsaNative`] instantiated with [`Params65`] (presumably the ML-DSA-65 parameter set — confirm in `params`).
pub type MlDsa65 = MlDsaNative<Params65>;
/// [`MlDsaNative`] instantiated with [`Params87`] (presumably the ML-DSA-87 parameter set — confirm in `params`).
pub type MlDsa87 = MlDsaNative<Params87>;
/// Marker type whose [`MlDsa`] implementation delegates to this crate's
/// `keygen`, `sign` and `verify` modules.
///
/// It stores no data besides a `PhantomData<P>` and is used only through its
/// associated functions.
pub struct MlDsaNative<P: MlDsaParams> {
// Records the parameter set at the type level.
_params: PhantomData<P>,
}
impl<P: MlDsaParams> MlDsa<P> for MlDsaNative<P> {
    /// Generates a key pair from a fresh 32-byte seed obtained via `getrandom`.
    ///
    /// # Panics
    ///
    /// Panics if `getrandom` reports an error.
    fn generate_keypair() -> (MlDsaSigningKey<P>, MlDsaVerifyingKey<P>) {
        let mut entropy = [0u8; 32];
        getrandom::getrandom(&mut entropy).expect("Failed to generate random seed");

        let material = keygen::generate_keypair_internal::<P>(&entropy);

        // Encode the public key first, then the secret key.
        let encoded_pk = keygen::pack_pk::<P>(&material.rho, &material.t1);
        let encoded_sk = keygen::pack_sk::<P>(
            &material.rho,
            &material.key,
            &material.tr,
            &material.s1,
            &material.s2,
            &material.t0,
        );

        (
            MlDsaSigningKey {
                rho: material.rho,
                key: material.key,
                tr: material.tr,
                bytes: encoded_sk,
                _params: PhantomData,
            },
            MlDsaVerifyingKey {
                rho: material.rho,
                bytes: encoded_pk,
                _params: PhantomData,
            },
        )
    }

    /// Signs `message` with the encoded key held by `sk`.
    ///
    /// # Panics
    ///
    /// Panics if the internal signing routine does not produce a signature.
    /// `MlDsaSigningKey::from_bytes` checks only the length, so a key built
    /// from arbitrary bytes may reach this path.
    fn sign(sk: &MlDsaSigningKey<P>, message: &[u8]) -> MlDsaSignature<P> {
        let outcome = sign::sign_internal::<P>(&sk.bytes, message);
        MlDsaSignature {
            bytes: outcome.expect("Signing failed - this should not happen with valid keys"),
            _params: PhantomData,
        }
    }

    /// Verifies `signature` over `message` with `vk`.
    ///
    /// Every rejection is reported as [`MlDsaError::VerificationFailed`].
    fn verify(
        vk: &MlDsaVerifyingKey<P>,
        message: &[u8],
        signature: &MlDsaSignature<P>,
    ) -> Result<(), MlDsaError> {
        let accepted = verify::verify_internal::<P>(&vk.bytes, message, &signature.bytes);
        if !accepted {
            return Err(MlDsaError::VerificationFailed);
        }
        Ok(())
    }
}
extern crate alloc;
#[cfg(test)]
mod api_tests {
    use super::*;

    /// Generates a key pair for `P` and checks that both encodings have the
    /// lengths declared by the parameter set.
    fn assert_generated_sizes<P: MlDsaParams>() {
        let (sk, vk) = MlDsaNative::<P>::generate_keypair();
        assert_eq!(sk.to_bytes().len(), P::SK_SIZE);
        assert_eq!(vk.to_bytes().len(), P::PK_SIZE);
    }

    #[test]
    fn test_signing_key_size() {
        let observed = [
            MlDsaSigningKey::<Params44>::SIZE,
            MlDsaSigningKey::<Params65>::SIZE,
            MlDsaSigningKey::<Params87>::SIZE,
        ];
        assert_eq!(observed, [2560, 4032, 4896]);
    }

    #[test]
    fn test_verifying_key_size() {
        let observed = [
            MlDsaVerifyingKey::<Params44>::SIZE,
            MlDsaVerifyingKey::<Params65>::SIZE,
            MlDsaVerifyingKey::<Params87>::SIZE,
        ];
        assert_eq!(observed, [1312, 1952, 2592]);
    }

    #[test]
    fn test_signature_size() {
        let observed = [
            MlDsaSignature::<Params44>::SIZE,
            MlDsaSignature::<Params65>::SIZE,
            MlDsaSignature::<Params87>::SIZE,
        ];
        assert_eq!(observed, [2420, 3309, 4627]);
    }

    #[test]
    fn test_key_from_bytes_wrong_size() {
        // 100 bytes matches neither key size for Params65.
        let too_short = [0u8; 100];
        assert!(MlDsaSigningKey::<Params65>::from_bytes(&too_short).is_none());
        assert!(MlDsaVerifyingKey::<Params65>::from_bytes(&too_short).is_none());
    }

    #[test]
    fn test_signature_from_bytes_wrong_size() {
        let too_short = [0u8; 100];
        assert!(MlDsaSignature::<Params65>::from_bytes(&too_short).is_none());
    }

    #[test]
    fn test_key_roundtrip() {
        // `from_bytes` checks only the length, so an all-zero buffer is accepted.
        let zeros = vec![0u8; Params65::SK_SIZE];
        let parsed = MlDsaSigningKey::<Params65>::from_bytes(&zeros).unwrap();
        assert_eq!(parsed.to_bytes(), zeros);
    }

    #[test]
    fn test_generate_keypair_44() {
        assert_generated_sizes::<Params44>();
    }

    #[test]
    fn test_generate_keypair_65() {
        assert_generated_sizes::<Params65>();
    }

    #[test]
    fn test_generate_keypair_87() {
        assert_generated_sizes::<Params87>();
    }

    #[test]
    fn test_keypair_keys_are_different() {
        // Two independent generations must not yield the same encodings.
        let (first_sk, first_vk) = MlDsa65::generate_keypair();
        let (second_sk, second_vk) = MlDsa65::generate_keypair();
        assert_ne!(first_sk.to_bytes(), second_sk.to_bytes());
        assert_ne!(first_vk.to_bytes(), second_vk.to_bytes());
    }
}