use crate::error::{KeyRejected, Unspecified};
use crate::pqdsa::signature::{PqdsaSigningAlgorithm, PublicKey};
use crate::pqdsa::AlgorithmID;
use crate::signature::KeyPair;
use crate::wolfcrypt_rs::{
wc_FreeRng, wc_InitRng, wc_dilithium_export_private, wc_dilithium_export_public,
wc_dilithium_free, wc_dilithium_import_key, wc_dilithium_init, wc_dilithium_key,
wc_dilithium_make_key, wc_dilithium_make_key_from_seed, wc_dilithium_set_level,
wc_dilithium_sign_ctx_msg, WC_RNG,
};
use core::fmt::{Debug, Formatter};
#[cfg(not(feature = "std"))]
use crate::prelude::*;
/// An ML-DSA signing key pair backed by wolfCrypt.
///
/// The key pair stores the raw private and public key bytes. The private key bytes are
/// overwritten with zeros when the key pair is dropped.
// The public type name deliberately keeps its `Pqdsa` prefix.
#[allow(clippy::module_name_repetitions)]
pub struct PqdsaKeyPair {
    algorithm: &'static PqdsaSigningAlgorithm,
    // Secret: raw private key bytes (wolfCrypt export format).
    priv_key: Box<[u8]>,
    pubkey: PublicKey,
}

impl Drop for PqdsaKeyPair {
    /// Wipes the private key bytes before their allocation is released.
    fn drop(&mut self) {
        // Volatile writes plus a compiler fence keep the wipe from being removed as a dead
        // store right before the buffer is deallocated.
        for byte in self.priv_key.iter_mut() {
            // SAFETY: `byte` is a valid, aligned, exclusive reference into `priv_key`.
            unsafe { core::ptr::write_volatile(byte, 0) };
        }
        core::sync::atomic::compiler_fence(core::sync::atomic::Ordering::SeqCst);
    }
}
// Only the algorithm is printed: `priv_key` is secret and must never reach logs, and
// `pubkey` is left out as well. The lint allowance acknowledges the omitted fields.
#[allow(clippy::missing_fields_in_debug)]
impl Debug for PqdsaKeyPair {
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        let mut builder = f.debug_struct("PqdsaKeyPair");
        builder.field("algorithm", &self.algorithm);
        builder.finish()
    }
}
impl KeyPair for PqdsaKeyPair {
    type PublicKey = PublicKey;

    /// Returns a reference to the public half of this key pair.
    fn public_key(&self) -> &Self::PublicKey {
        let Self { pubkey, .. } = self;
        pubkey
    }
}
/// Borrowed handle to the private half of a [`PqdsaKeyPair`].
pub struct PqdsaPrivateKey<'a>(pub(crate) &'a PqdsaKeyPair);

impl PqdsaPrivateKey<'_> {
    /// Serializes the key pair as the raw private key immediately followed by the raw
    /// public key. `PqdsaKeyPair::from_raw_private_key` accepts this layout.
    ///
    /// The returned value contains secret key material.
    ///
    /// # Errors
    ///
    /// This implementation always returns `Ok`; the `Result` is part of the public signature.
    pub fn as_raw_bytes(&self) -> Result<PqdsaPrivateKeyRaw, Unspecified> {
        let pair: &PqdsaKeyPair = self.0;
        let halves: [&[u8]; 2] = [&pair.priv_key, &pair.pubkey.octets];
        Ok(PqdsaPrivateKeyRaw(halves.concat().into_boxed_slice()))
    }
}
/// Owned raw serialization of a key pair, as produced by `PqdsaPrivateKey::as_raw_bytes`.
///
/// The buffer holds secret key material and is overwritten with zeros when dropped.
pub struct PqdsaPrivateKeyRaw(Box<[u8]>);

impl AsRef<[u8]> for PqdsaPrivateKeyRaw {
    fn as_ref(&self) -> &[u8] {
        &self.0
    }
}

impl Drop for PqdsaPrivateKeyRaw {
    /// Wipes the serialized key bytes before their allocation is released.
    fn drop(&mut self) {
        // Volatile writes plus a compiler fence keep the wipe from being removed as a dead
        // store right before the buffer is deallocated.
        for byte in self.0.iter_mut() {
            // SAFETY: `byte` is a valid, aligned, exclusive reference into our own buffer.
            unsafe { core::ptr::write_volatile(byte, 0) };
        }
        core::sync::atomic::compiler_fence(core::sync::atomic::Ordering::SeqCst);
    }
}
/// Initializes `key` with `wc_dilithium_init` and selects the ML-DSA parameter set for `id`.
///
/// This function never frees `key`, on success or on failure. The caller owns cleanup and
/// must call `wc_dilithium_free` on `key` exactly once after this returns, whatever the
/// result. Freeing here as well would release the key twice.
///
/// # Errors
///
/// Returns `Unspecified` if either wolfCrypt call reports a non-zero status.
///
/// # Safety
///
/// `key` must not already be initialized. Callers in this module pass a freshly
/// `wc_dilithium_key::zeroed()` value.
unsafe fn init_dilithium_key(
    key: &mut wc_dilithium_key,
    id: &AlgorithmID,
) -> Result<(), Unspecified> {
    if wc_dilithium_init(key) != 0 {
        return Err(Unspecified);
    }
    // Deliberately no `wc_dilithium_free` here: the caller releases the key.
    if wc_dilithium_set_level(key, id.level()) != 0 {
        return Err(Unspecified);
    }
    Ok(())
}
/// Copies the public key held by `key` into a new buffer of exactly
/// `id.pub_key_size_bytes()` bytes.
///
/// # Errors
///
/// Returns `Unspecified` if wolfCrypt reports an error or writes a different number of
/// bytes than the parameter set's public-key size.
///
/// # Safety
///
/// `key` must be an initialized wolfCrypt ML-DSA key.
unsafe fn export_public_key(
    key: &mut wc_dilithium_key,
    id: &AlgorithmID,
) -> Result<Box<[u8]>, Unspecified> {
    let expected = id.pub_key_size_bytes();
    let mut out = vec![0u8; expected];
    let mut written = expected as u32;
    let exported = wc_dilithium_export_public(key, out.as_mut_ptr(), &mut written) == 0;
    if exported && written as usize == expected {
        Ok(out.into_boxed_slice())
    } else {
        Err(Unspecified)
    }
}
/// Copies the private key held by `key` into a new buffer of exactly
/// `id.priv_key_size_bytes()` bytes.
///
/// On failure, the scratch buffer is wiped before it is dropped, because wolfCrypt may
/// already have written some or all of the secret into it.
///
/// # Errors
///
/// Returns `Unspecified` if wolfCrypt reports an error or writes a different number of
/// bytes than the parameter set's private-key size.
///
/// # Safety
///
/// `key` must be an initialized wolfCrypt ML-DSA key.
unsafe fn export_private_key(
    key: &mut wc_dilithium_key,
    id: &AlgorithmID,
) -> Result<Box<[u8]>, Unspecified> {
    let priv_size = id.priv_key_size_bytes();
    let mut priv_buf = vec![0u8; priv_size];
    let mut priv_len = priv_size as u32;
    let rc = wc_dilithium_export_private(key, priv_buf.as_mut_ptr(), &mut priv_len);
    if rc != 0 || priv_len as usize != priv_size {
        // Volatile writes plus a fence so the wipe is not optimized away before the free.
        for byte in priv_buf.iter_mut() {
            core::ptr::write_volatile(byte, 0);
        }
        core::sync::atomic::compiler_fence(core::sync::atomic::Ordering::SeqCst);
        return Err(Unspecified);
    }
    Ok(priv_buf.into_boxed_slice())
}
/// Loads raw private and public key bytes into `key` with `wc_dilithium_import_key`.
///
/// # Errors
///
/// Returns `Unspecified` if either slice is longer than `u32::MAX` bytes (wolfCrypt takes
/// 32-bit lengths) or if wolfCrypt rejects the input.
///
/// # Safety
///
/// `key` must be an initialized wolfCrypt ML-DSA key with its parameter level set.
unsafe fn import_key_pair(
    key: &mut wc_dilithium_key,
    priv_bytes: &[u8],
    pub_bytes: &[u8],
) -> Result<(), Unspecified> {
    // An `as u32` cast would silently truncate longer inputs, so reject them up front.
    let max_len = u32::MAX as usize;
    if priv_bytes.len() > max_len || pub_bytes.len() > max_len {
        return Err(Unspecified);
    }
    let rc = wc_dilithium_import_key(
        priv_bytes.as_ptr(),
        priv_bytes.len() as u32,
        pub_bytes.as_ptr(),
        pub_bytes.len() as u32,
        key,
    );
    if rc != 0 {
        return Err(Unspecified);
    }
    Ok(())
}
impl PqdsaKeyPair {
    /// Generates a new random key pair for `algorithm`.
    ///
    /// A wolfCrypt RNG is created for this call only and released before returning.
    ///
    /// # Errors
    ///
    /// Returns `Unspecified` in either of these cases:
    /// * the RNG cannot be initialized;
    /// * wolfCrypt fails to initialize, generate, or export the key.
    pub fn generate(algorithm: &'static PqdsaSigningAlgorithm) -> Result<Self, Unspecified> {
        let id = algorithm.0.id;
        // SAFETY: `rng` and `key` are zero-initialized locals that never escape this block.
        // The closure uses them only while they are alive, and both are released afterwards
        // whether or not generation succeeded.
        unsafe {
            let mut rng = WC_RNG::zeroed();
            let rc = wc_InitRng(&mut rng);
            if rc != 0 {
                return Err(Unspecified);
            }
            let mut key = wc_dilithium_key::zeroed();
            // The fallible steps run in a closure so every early `?` still reaches the
            // cleanup calls below.
            let result = (|| -> Result<Self, Unspecified> {
                init_dilithium_key(&mut key, id)?;
                let rc = wc_dilithium_make_key(&mut key, &mut rng);
                if rc != 0 {
                    return Err(Unspecified);
                }
                let pub_bytes = export_public_key(&mut key, id)?;
                let priv_bytes = export_private_key(&mut key, id)?;
                Ok(Self {
                    algorithm,
                    priv_key: priv_bytes,
                    pubkey: PublicKey::new(pub_bytes),
                })
            })();
            wc_dilithium_free(&mut key);
            wc_FreeRng(&mut rng);
            result
        }
    }

    /// Reconstructs a key pair from the layout produced by `PqdsaPrivateKey::as_raw_bytes`:
    /// the raw private key immediately followed by the raw public key.
    ///
    /// The bytes are loaded into a temporary wolfCrypt key so that wolfCrypt can reject
    /// malformed input. The returned key pair keeps its own copies of both halves.
    /// NOTE(review): this code cannot show whether `wc_dilithium_import_key` checks that the
    /// public half matches the private half. Confirm before relying on it.
    ///
    /// # Errors
    ///
    /// * `KeyRejected::wrong_algorithm()` if the input length is not
    ///   `priv_key_size_bytes() + pub_key_size_bytes()` for `algorithm`.
    /// * `KeyRejected::unspecified()` if wolfCrypt fails to initialize or import the key.
    pub fn from_raw_private_key(
        algorithm: &'static PqdsaSigningAlgorithm,
        raw_private_key: &[u8],
    ) -> Result<Self, KeyRejected> {
        let id = algorithm.0.id;
        let priv_size = id.priv_key_size_bytes();
        let pub_size = id.pub_key_size_bytes();
        let combined_size = priv_size + pub_size;
        if raw_private_key.len() != combined_size {
            return Err(KeyRejected::wrong_algorithm());
        }
        let (priv_bytes, pub_bytes) = raw_private_key.split_at(priv_size);
        // SAFETY: `key` is a zero-initialized local that never escapes this block and is
        // released after the closure that uses it has returned.
        unsafe {
            let mut key = wc_dilithium_key::zeroed();
            let result = (|| -> Result<Self, KeyRejected> {
                init_dilithium_key(&mut key, id).map_err(|_| KeyRejected::unspecified())?;
                import_key_pair(&mut key, priv_bytes, pub_bytes)
                    .map_err(|_| KeyRejected::unspecified())?;
                Ok(Self {
                    algorithm,
                    priv_key: priv_bytes.to_vec().into_boxed_slice(),
                    pubkey: PublicKey::new(pub_bytes.to_vec().into_boxed_slice()),
                })
            })();
            wc_dilithium_free(&mut key);
            result
        }
    }

    /// Deterministically derives a key pair for `algorithm` from `seed`.
    ///
    /// The same seed and algorithm always yield the same key pair.
    ///
    /// # Errors
    ///
    /// * `KeyRejected::too_small()` / `KeyRejected::too_large()` if `seed` is shorter or
    ///   longer than the algorithm's seed size.
    /// * `KeyRejected::unspecified()` if any wolfCrypt step fails.
    pub fn from_seed(
        algorithm: &'static PqdsaSigningAlgorithm,
        seed: &[u8],
    ) -> Result<Self, KeyRejected> {
        let id = algorithm.0.id;
        let expected_seed_len = id.seed_size_bytes();
        match seed.len().cmp(&expected_seed_len) {
            core::cmp::Ordering::Less => return Err(KeyRejected::too_small()),
            core::cmp::Ordering::Greater => return Err(KeyRejected::too_large()),
            core::cmp::Ordering::Equal => {}
        }
        // SAFETY: `key` is a zero-initialized local that never escapes this block. `seed` has
        // just been checked to be exactly the length wolfCrypt reads through the pointer.
        unsafe {
            let mut key = wc_dilithium_key::zeroed();
            let result = (|| -> Result<Self, KeyRejected> {
                init_dilithium_key(&mut key, id).map_err(|_| KeyRejected::unspecified())?;
                let rc = wc_dilithium_make_key_from_seed(&mut key, seed.as_ptr());
                if rc != 0 {
                    return Err(KeyRejected::unspecified());
                }
                let pub_bytes =
                    export_public_key(&mut key, id).map_err(|_| KeyRejected::unspecified())?;
                let priv_bytes =
                    export_private_key(&mut key, id).map_err(|_| KeyRejected::unspecified())?;
                Ok(Self {
                    algorithm,
                    priv_key: priv_bytes,
                    pubkey: PublicKey::new(pub_bytes),
                })
            })();
            wc_dilithium_free(&mut key);
            result
        }
    }

    /// Signs `msg` and writes the signature to the start of `signature`.
    ///
    /// The signature uses an empty context (null pointer, zero length). Each call creates a
    /// fresh wolfCrypt RNG and re-imports the stored key bytes into a temporary wolfCrypt
    /// key.
    ///
    /// Returns the number of signature bytes written.
    ///
    /// # Errors
    ///
    /// Returns `Unspecified` if:
    /// * `signature` is shorter than the algorithm's signature size;
    /// * `msg` is longer than `u32::MAX` bytes, wolfCrypt's message-length type;
    /// * any wolfCrypt call fails.
    pub fn sign(&self, msg: &[u8], signature: &mut [u8]) -> Result<usize, Unspecified> {
        let id = self.algorithm.0.id;
        let sig_length = id.signature_size_bytes();
        if signature.len() < sig_length {
            return Err(Unspecified);
        }
        // Casting a longer length with `as u32` would silently sign only a truncated prefix
        // of `msg`, so such messages are rejected instead.
        if msg.len() > u32::MAX as usize {
            return Err(Unspecified);
        }
        let msg_len = msg.len() as u32;
        // SAFETY: `rng` and `key` are zero-initialized locals that never escape this block.
        // `msg` is valid for `msg_len` bytes. `signature` is at least `sig_length` bytes, the
        // capacity reported to wolfCrypt through `sig_len`.
        unsafe {
            let mut rng = WC_RNG::zeroed();
            let rc = wc_InitRng(&mut rng);
            if rc != 0 {
                return Err(Unspecified);
            }
            let mut key = wc_dilithium_key::zeroed();
            let result = (|| -> Result<usize, Unspecified> {
                init_dilithium_key(&mut key, id)?;
                import_key_pair(&mut key, &self.priv_key, self.pubkey.octets.as_ref())?;
                let mut sig_len = sig_length as u32;
                let rc = wc_dilithium_sign_ctx_msg(
                    core::ptr::null(),
                    0,
                    msg.as_ptr(),
                    msg_len,
                    signature.as_mut_ptr(),
                    &mut sig_len,
                    &mut key,
                    &mut rng,
                );
                if rc != 0 {
                    return Err(Unspecified);
                }
                Ok(sig_len as usize)
            })();
            wc_dilithium_free(&mut key);
            wc_FreeRng(&mut rng);
            result
        }
    }

    /// Returns the signing algorithm this key pair was created for.
    #[must_use]
    pub fn algorithm(&self) -> &'static PqdsaSigningAlgorithm {
        self.algorithm
    }

    /// Returns a borrowed handle for accessing this key pair's private key material.
    #[must_use]
    pub fn private_key(&self) -> PqdsaPrivateKey<'_> {
        PqdsaPrivateKey(self)
    }
}
// SAFETY: `PqdsaKeyPair` stores only a `&'static PqdsaSigningAlgorithm`, an owned
// `Box<[u8]>` and a `PublicKey`. The wolfCrypt objects (`wc_dilithium_key`, `WC_RNG`) are
// created as locals inside each method and freed before the method returns, so no FFI
// handle is kept in, or shared through, this type.
// NOTE(review): this relies on `PublicKey` and `PqdsaSigningAlgorithm` themselves being safe
// to send and share across threads — confirm. If they are, these impls are redundant with
// the auto traits.
unsafe impl Send for PqdsaKeyPair {}
unsafe impl Sync for PqdsaKeyPair {}
#[cfg(all(test, feature = "unstable"))]
mod tests {
    use super::*;
    use crate::signature::UnparsedPublicKey;
    use crate::unstable::signature::{ML_DSA_44_SIGNING, ML_DSA_65_SIGNING, ML_DSA_87_SIGNING};

    /// All ML-DSA parameter sets exercised by the tests below.
    const TEST_ALGORITHMS: &[&PqdsaSigningAlgorithm] =
        &[&ML_DSA_44_SIGNING, &ML_DSA_65_SIGNING, &ML_DSA_87_SIGNING];

    /// Seed length `PqdsaKeyPair::from_seed` expects for `alg`. It is read from the algorithm
    /// rather than hard-coded, so the tests follow the implementation.
    fn seed_len(alg: &PqdsaSigningAlgorithm) -> usize {
        alg.0.id.seed_size_bytes()
    }

    #[test]
    fn test_generate_sign_verify_roundtrip() {
        for &alg in TEST_ALGORITHMS {
            let keypair = PqdsaKeyPair::generate(alg).unwrap();
            let message = b"Test message for ML-DSA";
            let mut signature = vec![0u8; alg.signature_len()];
            let sig_len = keypair.sign(message, &mut signature).unwrap();
            assert_eq!(sig_len, alg.signature_len());
            let verify_alg = alg.0;
            let pk = UnparsedPublicKey::new(verify_alg, keypair.public_key().as_ref());
            pk.verify(message, &signature).unwrap();
        }
    }

    #[test]
    fn test_sign_buffer_too_small() {
        for &alg in TEST_ALGORITHMS {
            let keypair = PqdsaKeyPair::generate(alg).unwrap();
            let message = b"Test message";
            let mut small_buf = vec![0u8; alg.signature_len() - 1];
            assert!(keypair.sign(message, &mut small_buf).is_err());
        }
    }

    /// A buffer larger than a signature is accepted, and the signature occupies the
    /// returned-length prefix.
    #[test]
    fn test_sign_buffer_larger_than_signature() {
        for &alg in TEST_ALGORITHMS {
            let keypair = PqdsaKeyPair::generate(alg).unwrap();
            let message = b"Test message";
            let mut big_buf = vec![0u8; alg.signature_len() + 16];
            let sig_len = keypair.sign(message, &mut big_buf).unwrap();
            assert_eq!(sig_len, alg.signature_len());
            let pk = UnparsedPublicKey::new(alg.0, keypair.public_key().as_ref());
            pk.verify(message, &big_buf[..sig_len]).unwrap();
        }
    }

    #[test]
    fn test_from_seed() {
        for &alg in TEST_ALGORITHMS {
            let seed = vec![1u8; seed_len(alg)];
            let kp = PqdsaKeyPair::from_seed(alg, &seed).unwrap();
            assert_eq!(kp.algorithm(), alg);
            let msg = b"seed test";
            let mut sig = vec![0; alg.signature_len()];
            let sig_len = kp.sign(msg, &mut sig).unwrap();
            assert_eq!(sig_len, alg.signature_len());
        }
    }

    #[test]
    fn test_from_seed_deterministic() {
        for &alg in TEST_ALGORITHMS {
            let seed = vec![42u8; seed_len(alg)];
            let kp1 = PqdsaKeyPair::from_seed(alg, &seed).unwrap();
            let kp2 = PqdsaKeyPair::from_seed(alg, &seed).unwrap();
            assert_eq!(kp1.public_key().as_ref(), kp2.public_key().as_ref());
        }
    }

    #[test]
    fn test_from_seed_wrong_size() {
        for &alg in TEST_ALGORITHMS {
            let len = seed_len(alg);
            assert_eq!(
                PqdsaKeyPair::from_seed(alg, &vec![0u8; len - 1]).err(),
                Some(KeyRejected::too_small())
            );
            assert_eq!(
                PqdsaKeyPair::from_seed(alg, &vec![0u8; len + 1]).err(),
                Some(KeyRejected::too_large())
            );
            assert_eq!(
                PqdsaKeyPair::from_seed(alg, &[]).err(),
                Some(KeyRejected::too_small())
            );
        }
    }

    #[test]
    fn test_from_seed_different_seeds_different_keys() {
        for &alg in TEST_ALGORITHMS {
            let kp1 = PqdsaKeyPair::from_seed(alg, &vec![1u8; seed_len(alg)]).unwrap();
            let kp2 = PqdsaKeyPair::from_seed(alg, &vec![2u8; seed_len(alg)]).unwrap();
            assert_ne!(kp1.public_key().as_ref(), kp2.public_key().as_ref());
        }
    }

    #[test]
    fn test_from_seed_raw_private_key_roundtrip() {
        for &alg in TEST_ALGORITHMS {
            let seed = vec![55u8; seed_len(alg)];
            let kp = PqdsaKeyPair::from_seed(alg, &seed).unwrap();
            let raw_bytes = kp.private_key().as_raw_bytes().unwrap();
            let kp2 = PqdsaKeyPair::from_raw_private_key(alg, raw_bytes.as_ref()).unwrap();
            assert_eq!(kp.public_key().as_ref(), kp2.public_key().as_ref());
        }
    }

    /// A key pair restored from its raw bytes produces signatures that verify under the
    /// original public key.
    #[test]
    fn test_raw_private_key_restored_key_signs() {
        for &alg in TEST_ALGORITHMS {
            let kp = PqdsaKeyPair::generate(alg).unwrap();
            let raw_key = kp.private_key().as_raw_bytes().unwrap();
            let restored = PqdsaKeyPair::from_raw_private_key(alg, raw_key.as_ref()).unwrap();
            let msg = b"restored key test";
            let mut sig = vec![0u8; alg.signature_len()];
            let sig_len = restored.sign(msg, &mut sig).unwrap();
            let pk = UnparsedPublicKey::new(alg.0, kp.public_key().as_ref());
            pk.verify(msg, &sig[..sig_len]).unwrap();
        }
    }

    /// Raw key input of the wrong length is rejected before it reaches wolfCrypt.
    #[test]
    fn test_from_raw_private_key_wrong_length() {
        for &alg in TEST_ALGORITHMS {
            let kp = PqdsaKeyPair::generate(alg).unwrap();
            let raw_key = kp.private_key().as_raw_bytes().unwrap();
            let bytes = raw_key.as_ref();
            assert_eq!(
                PqdsaKeyPair::from_raw_private_key(alg, &bytes[..bytes.len() - 1]).err(),
                Some(KeyRejected::wrong_algorithm())
            );
            assert_eq!(
                PqdsaKeyPair::from_raw_private_key(alg, &[]).err(),
                Some(KeyRejected::wrong_algorithm())
            );
        }
    }

    #[test]
    fn test_from_seed_same_seed_different_algorithms() {
        let from_fixed_seed = |alg: &'static PqdsaSigningAlgorithm| {
            PqdsaKeyPair::from_seed(alg, &vec![42u8; seed_len(alg)]).unwrap()
        };
        let kp_44 = from_fixed_seed(&ML_DSA_44_SIGNING);
        let kp_65 = from_fixed_seed(&ML_DSA_65_SIGNING);
        let kp_87 = from_fixed_seed(&ML_DSA_87_SIGNING);
        assert_ne!(
            kp_44.public_key().as_ref().len(),
            kp_65.public_key().as_ref().len()
        );
        assert_ne!(
            kp_65.public_key().as_ref().len(),
            kp_87.public_key().as_ref().len()
        );
    }

    #[test]
    fn test_algorithm_getter() {
        for &alg in TEST_ALGORITHMS {
            let keypair = PqdsaKeyPair::generate(alg).unwrap();
            assert_eq!(keypair.algorithm(), alg);
        }
    }

    #[test]
    fn test_debug() {
        for &alg in TEST_ALGORITHMS {
            let keypair = PqdsaKeyPair::generate(alg).unwrap();
            let debug_str = format!("{keypair:?}");
            assert!(
                debug_str.starts_with("PqdsaKeyPair { algorithm: PqdsaSigningAlgorithm(PqdsaVerificationAlgorithm { id:"),
                "{debug_str}"
            );
            let pubkey = keypair.public_key();
            let pk_debug = format!("{pubkey:?}");
            assert!(pk_debug.starts_with("PqdsaPublicKey("), "{pk_debug}");
        }
    }

    #[test]
    fn test_negative_verify_wrong_key() {
        for &alg in TEST_ALGORITHMS {
            let kp1 = PqdsaKeyPair::generate(alg).unwrap();
            let kp2 = PqdsaKeyPair::generate(alg).unwrap();
            let msg = b"wrong key test";
            let mut sig = vec![0u8; alg.signature_len()];
            kp1.sign(msg, &mut sig).unwrap();
            let wrong_pk = UnparsedPublicKey::new(alg.0, kp2.public_key().as_ref());
            assert!(wrong_pk.verify(msg, &sig).is_err());
        }
    }

    #[test]
    fn test_negative_corrupted_signature() {
        for &alg in TEST_ALGORITHMS {
            let kp = PqdsaKeyPair::generate(alg).unwrap();
            let msg = b"corrupted sig test";
            let mut sig = vec![0u8; alg.signature_len()];
            kp.sign(msg, &mut sig).unwrap();
            sig[0] ^= 0xff;
            let pk = UnparsedPublicKey::new(alg.0, kp.public_key().as_ref());
            assert!(pk.verify(msg, &sig).is_err());
        }
    }
}