use crate::error::PqRascvError;
use sha3::{Digest, Sha3_256};
use zeroize::{Zeroize, ZeroizeOnDrop};
/// Length in bytes of the ML-DSA-65 key-generation seed that `SigningKeySeed` stores.
pub const ML_DSA_65_SEED_SIZE: usize = 32;
/// Length in bytes of an encoded ML-DSA-65 verifying (public) key.
pub const ML_DSA_65_VERIFYING_KEY_SIZE: usize = 1_952;
/// Length in bytes of an encoded ML-DSA-65 signature.
pub const ML_DSA_65_SIGNATURE_SIZE: usize = 3_309;
#[derive(Clone)]
pub struct SignatureBytes(pub [u8; ML_DSA_65_SIGNATURE_SIZE]);
impl AsRef<[u8]> for SignatureBytes {
fn as_ref(&self) -> &[u8] {
&self.0
}
}
/// Secret ML-DSA-65 seed from which `MlDsaBackend::sign` re-derives the signing key.
///
/// The derived `Zeroize`/`ZeroizeOnDrop` impls wipe the stored bytes when a value is dropped.
/// This definition derives no `Debug`, which keeps the secret out of formatted output.
#[derive(Clone, Zeroize, ZeroizeOnDrop)]
pub struct SigningKeySeed(pub [u8; ML_DSA_65_SEED_SIZE]);

impl SigningKeySeed {
    /// Wraps the given seed bytes, taking ownership of them.
    #[must_use]
    pub fn new(bytes: [u8; ML_DSA_65_SEED_SIZE]) -> Self {
        SigningKeySeed(bytes)
    }

    /// Borrows the raw seed bytes without copying them.
    #[must_use]
    pub fn as_bytes(&self) -> &[u8; ML_DSA_65_SEED_SIZE] {
        let SigningKeySeed(inner) = self;
        inner
    }
}
/// A signing/verification backend.
pub trait CryptoBackend {
    /// Signs `message` with the key derived from `signing_seed`.
    ///
    /// # Errors
    /// Returns an error when a signature cannot be produced (the error variant is chosen by the
    /// implementation).
    fn sign(&self, message: &[u8], signing_seed: &[u8]) -> Result<SignatureBytes, PqRascvError>;

    /// Checks that `signature` is a valid signature over `message` under `verifying_key`.
    ///
    /// # Errors
    /// Returns an error when the signature does not verify (the error variant is chosen by the
    /// implementation).
    fn verify(
        &self,
        message: &[u8],
        verifying_key: &[u8],
        signature: &[u8],
    ) -> Result<(), PqRascvError>;

    /// Returns a 32-byte identifier for `verifying_key`: its SHA3-256 digest.
    ///
    /// This is an associated function (no `self`), so call it as `Backend::pub_key_id(..)`.
    fn pub_key_id(verifying_key: &[u8]) -> [u8; 32] {
        Sha3_256::digest(verifying_key).into()
    }
}
pub struct MlDsaBackend;
impl CryptoBackend for MlDsaBackend {
fn sign(&self, message: &[u8], signing_seed: &[u8]) -> Result<SignatureBytes, PqRascvError> {
use ml_dsa::{KeyGen, MlDsa65};
let seed_array: &[u8; ML_DSA_65_SEED_SIZE] = signing_seed
.try_into()
.map_err(|_| PqRascvError::SigningFailed)?;
let seed = ml_dsa::B32::from(*seed_array);
let sk = MlDsa65::from_seed(&seed);
let sig = sk
.signing_key()
.sign_deterministic(message, b"")
.map_err(|_| PqRascvError::SigningFailed)?;
let encoded = sig.encode();
let sig_bytes: [u8; ML_DSA_65_SIGNATURE_SIZE] = (*encoded).try_into()
.map_err(|_| PqRascvError::SigningFailed)?;
Ok(SignatureBytes(sig_bytes))
}
fn verify(
&self,
message: &[u8],
verifying_key: &[u8],
signature: &[u8],
) -> Result<(), PqRascvError> {
use ml_dsa::{EncodedVerifyingKey, MlDsa65, Signature, VerifyingKey};
if verifying_key.len() != ML_DSA_65_VERIFYING_KEY_SIZE {
return Err(PqRascvError::VerificationFailed);
}
if signature.len() != ML_DSA_65_SIGNATURE_SIZE {
return Err(PqRascvError::VerificationFailed);
}
let vk_array: [u8; ML_DSA_65_VERIFYING_KEY_SIZE] = verifying_key
.try_into()
.map_err(|_| PqRascvError::VerificationFailed)?;
let encoded_vk = EncodedVerifyingKey::<MlDsa65>::from(vk_array);
let vk = VerifyingKey::<MlDsa65>::decode(&encoded_vk);
let sig_array: [u8; ML_DSA_65_SIGNATURE_SIZE] = signature
.try_into()
.map_err(|_| PqRascvError::VerificationFailed)?;
let encoded_sig = ml_dsa::EncodedSignature::<MlDsa65>::from(sig_array);
let sig = Signature::<MlDsa65>::decode(&encoded_sig)
.ok_or(PqRascvError::VerificationFailed)?;
if vk.verify_with_context(message, b"", &sig) {
Ok(())
} else {
Err(PqRascvError::VerificationFailed)
}
}
}
/// Generates a fresh ML-DSA-65 key pair from operating-system randomness.
///
/// Returns two values:
/// - the 32-byte secret seed, which `MlDsaBackend::sign` re-expands into the signing key;
/// - the encoded verifying key.
///
/// # Errors
/// Returns [`PqRascvError::KeyGenerationFailed`] if the OS random number generator fails, or if
/// the encoded verifying key is not `ML_DSA_65_VERIFYING_KEY_SIZE` bytes long.
pub fn generate_ml_dsa_keypair(
) -> Result<(SigningKeySeed, [u8; ML_DSA_65_VERIFYING_KEY_SIZE]), PqRascvError> {
    use ml_dsa::signature::Keypair;
    use ml_dsa::{KeyGen, MlDsa65};

    // Fill the seed in place inside the zeroize-on-drop wrapper. An RNG failure is reported as
    // `KeyGenerationFailed` rather than aborting the process with a panic.
    let mut seed = SigningKeySeed::new([0u8; ML_DSA_65_SEED_SIZE]);
    getrandom::fill(&mut seed.0).map_err(|_| PqRascvError::KeyGenerationFailed)?;

    // Expand the seed the same way `MlDsaBackend::sign` does, so the returned seed and
    // verifying key belong to the same key pair by construction.
    let keypair = MlDsa65::from_seed(&ml_dsa::B32::from(seed.0));
    let vk_encoded = keypair.verifying_key().encode();
    let vk_bytes = <[u8; ML_DSA_65_VERIFYING_KEY_SIZE]>::try_from(&vk_encoded[..])
        .map_err(|_| PqRascvError::KeyGenerationFailed)?;

    Ok((seed, vk_bytes))
}
#[cfg(test)]
mod tests {
    //! Unit tests for the ML-DSA-65 backend: happy path, determinism, and every rejection path.
    use super::*;

    #[test]
    fn sign_and_verify_roundtrip() {
        let (seed, vk) = generate_ml_dsa_keypair().expect("keygen failed");
        let backend = MlDsaBackend;
        let message = b"hello pqrascv-core";
        let sig = backend.sign(message, seed.as_bytes()).expect("sign failed");
        backend.verify(message, &vk, sig.as_ref()).expect("verify failed");
    }

    #[test]
    fn verify_rejects_tampered_message() {
        let (seed, vk) = generate_ml_dsa_keypair().expect("keygen failed");
        let backend = MlDsaBackend;
        let sig = backend.sign(b"original", seed.as_bytes()).expect("sign failed");
        assert!(backend.verify(b"tampered", &vk, sig.as_ref()).is_err());
    }

    #[test]
    fn verify_rejects_tampered_signature() {
        let (seed, vk) = generate_ml_dsa_keypair().expect("keygen failed");
        let backend = MlDsaBackend;
        let message = b"bit flip";
        let mut sig = backend.sign(message, seed.as_bytes()).expect("sign failed");
        sig.0[0] ^= 0x01;
        assert!(backend.verify(message, &vk, sig.as_ref()).is_err());
    }

    #[test]
    fn verify_rejects_wrong_key() {
        let (seed1, _vk1) = generate_ml_dsa_keypair().expect("keygen 1 failed");
        let (_seed2, vk2) = generate_ml_dsa_keypair().expect("keygen 2 failed");
        let backend = MlDsaBackend;
        let sig = backend
            .sign(b"cross-key test", seed1.as_bytes())
            .expect("sign failed");
        assert!(backend.verify(b"cross-key test", &vk2, sig.as_ref()).is_err());
    }

    #[test]
    fn verify_rejects_wrong_length_inputs() {
        let (seed, vk) = generate_ml_dsa_keypair().expect("keygen failed");
        let backend = MlDsaBackend;
        let message = b"length checks";
        let sig = backend.sign(message, seed.as_bytes()).expect("sign failed");
        let sig: &[u8] = sig.as_ref();

        assert!(backend.verify(message, &vk[..vk.len() - 1], sig).is_err());
        assert!(backend.verify(message, &vk, &sig[..sig.len() - 1]).is_err());
        let mut long_sig = sig.to_vec();
        long_sig.push(0);
        assert!(backend.verify(message, &vk, &long_sig).is_err());

        // Sanity check: the unmodified inputs still verify.
        backend.verify(message, &vk, sig).expect("verify failed");
    }

    #[test]
    fn sign_rejects_wrong_seed_length() {
        let backend = MlDsaBackend;
        assert!(backend.sign(b"msg", &[0u8; ML_DSA_65_SEED_SIZE - 1]).is_err());
        assert!(backend.sign(b"msg", &[0u8; ML_DSA_65_SEED_SIZE + 1]).is_err());
        assert!(backend.sign(b"msg", &[]).is_err());
    }

    #[test]
    fn pub_key_id_is_deterministic() {
        let vk = [0u8; ML_DSA_65_VERIFYING_KEY_SIZE];
        assert_eq!(MlDsaBackend::pub_key_id(&vk), MlDsaBackend::pub_key_id(&vk));
    }

    #[test]
    fn pub_key_id_is_sha3_256() {
        // SHA3-256 of the empty input (FIPS 202 known-answer value).
        let expected: [u8; 32] = [
            0xa7, 0xff, 0xc6, 0xf8, 0xbf, 0x1e, 0xd7, 0x66, 0x51, 0xc1, 0x47, 0x56, 0xa0, 0x61,
            0xd6, 0x62, 0xf5, 0x80, 0xff, 0x4d, 0xe4, 0x3b, 0x49, 0xfa, 0x82, 0xd8, 0x0a, 0x4b,
            0x80, 0xf8, 0x43, 0x4a,
        ];
        assert_eq!(MlDsaBackend::pub_key_id(&[]), expected);
    }

    #[test]
    fn pub_key_id_distinguishes_keys() {
        let (_, vk1) = generate_ml_dsa_keypair().expect("keygen 1 failed");
        let (_, vk2) = generate_ml_dsa_keypair().expect("keygen 2 failed");
        assert_ne!(MlDsaBackend::pub_key_id(&vk1), MlDsaBackend::pub_key_id(&vk2));
    }

    #[test]
    fn signing_is_deterministic() {
        let (seed, _vk) = generate_ml_dsa_keypair().expect("keygen failed");
        let backend = MlDsaBackend;
        let message = b"determinism test";
        let sig1 = backend.sign(message, seed.as_bytes()).unwrap();
        let sig2 = backend.sign(message, seed.as_bytes()).unwrap();
        assert_eq!(sig1.0, sig2.0);
    }
}