use core::cell::UnsafeCell;
use core::marker::PhantomData;
use alloc::boxed::Box;
use alloc::vec;
use alloc::vec::Vec;
use crate::error::{check, len_as_u32, WolfCryptError};
use wolfcrypt_rs::{
wc_dilithium_export_private, wc_dilithium_export_public, wc_dilithium_free,
wc_dilithium_import_public, wc_dilithium_init, wc_dilithium_make_key,
wc_dilithium_set_level, wc_dilithium_sign_msg, wc_dilithium_verify_msg,
wc_dilithium_key, wc_FreeRng, wc_InitRng, WC_RNG,
DILITHIUM_ML_DSA_44_KEY_SIZE, DILITHIUM_ML_DSA_44_PUB_KEY_SIZE,
DILITHIUM_ML_DSA_44_SIG_SIZE, DILITHIUM_ML_DSA_65_KEY_SIZE,
DILITHIUM_ML_DSA_65_PUB_KEY_SIZE, DILITHIUM_ML_DSA_65_SIG_SIZE,
DILITHIUM_ML_DSA_87_KEY_SIZE, DILITHIUM_ML_DSA_87_PUB_KEY_SIZE,
DILITHIUM_ML_DSA_87_SIG_SIZE, WC_ML_DSA_44, WC_ML_DSA_65, WC_ML_DSA_87,
};
#[derive(Debug)]
pub struct MlDsaSignature<L: MlDsaLevel> {
bytes: Vec<u8>,
_level: PhantomData<L>,
}
// Written by hand so that cloning a signature does not require the level marker `L`
// itself to be `Clone`.
impl<L> Clone for MlDsaSignature<L>
where
    L: MlDsaLevel,
{
    /// Returns an independent copy of the encoded signature bytes.
    fn clone(&self) -> Self {
        let copied = self.bytes.to_vec();
        Self { _level: PhantomData, bytes: copied }
    }
}
impl<L> AsRef<[u8]> for MlDsaSignature<L>
where
    L: MlDsaLevel,
{
    /// Borrows the raw encoded signature bytes.
    fn as_ref(&self) -> &[u8] {
        self.bytes.as_slice()
    }
}
// The canonical encoding is the raw signature bytes, returned as an owned boxed slice.
impl<L> signature_trait::SignatureEncoding for MlDsaSignature<L>
where
    L: MlDsaLevel,
{
    type Repr = Box<[u8]>;
}
impl<L: MlDsaLevel> TryFrom<&[u8]> for MlDsaSignature<L> {
type Error = signature_trait::Error;
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
if bytes.len() == L::SIG_SIZE {
Ok(Self { bytes: bytes.to_vec(), _level: PhantomData })
} else {
Err(signature_trait::Error::new())
}
}
}
impl<L> From<MlDsaSignature<L>> for Box<[u8]>
where
    L: MlDsaLevel,
{
    /// Consumes the signature and hands its byte buffer over as a boxed slice.
    fn from(signature: MlDsaSignature<L>) -> Self {
        Box::from(signature.bytes)
    }
}
/// Compile-time description of an ML-DSA parameter set.
///
/// Implementors are zero-sized marker types. The constants are used to configure wolfCrypt
/// keys and to size the buffers handed to wolfCrypt's export and sign functions.
pub trait MlDsaLevel {
    /// wolfCrypt level identifier passed to `wc_dilithium_set_level`.
    const LEVEL: u8;
    /// Exact length in bytes of an encoded signature.
    const SIG_SIZE: usize;
    /// Length in bytes of an encoded public key.
    const PUB_KEY_SIZE: usize;
    /// Buffer size in bytes used when exporting the private key.
    const PRIV_KEY_SIZE: usize;
}
/// Marker type selecting the ML-DSA-44 parameter set.
///
/// It derives the common value traits so that types generic over a level (for example a
/// `#[derive(Debug)]` struct holding `PhantomData<MlDsa44>`) get those impls too.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct MlDsa44;

impl MlDsaLevel for MlDsa44 {
    const LEVEL: u8 = WC_ML_DSA_44;
    const SIG_SIZE: usize = DILITHIUM_ML_DSA_44_SIG_SIZE;
    const PUB_KEY_SIZE: usize = DILITHIUM_ML_DSA_44_PUB_KEY_SIZE;
    const PRIV_KEY_SIZE: usize = DILITHIUM_ML_DSA_44_KEY_SIZE;
}
/// Marker type selecting the ML-DSA-65 parameter set.
///
/// It derives the common value traits so that types generic over a level (for example a
/// `#[derive(Debug)]` struct holding `PhantomData<MlDsa65>`) get those impls too.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct MlDsa65;

impl MlDsaLevel for MlDsa65 {
    const LEVEL: u8 = WC_ML_DSA_65;
    const SIG_SIZE: usize = DILITHIUM_ML_DSA_65_SIG_SIZE;
    const PUB_KEY_SIZE: usize = DILITHIUM_ML_DSA_65_PUB_KEY_SIZE;
    const PRIV_KEY_SIZE: usize = DILITHIUM_ML_DSA_65_KEY_SIZE;
}
/// Marker type selecting the ML-DSA-87 parameter set.
///
/// It derives the common value traits so that types generic over a level (for example a
/// `#[derive(Debug)]` struct holding `PhantomData<MlDsa87>`) get those impls too.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct MlDsa87;

impl MlDsaLevel for MlDsa87 {
    const LEVEL: u8 = WC_ML_DSA_87;
    const SIG_SIZE: usize = DILITHIUM_ML_DSA_87_SIG_SIZE;
    const PUB_KEY_SIZE: usize = DILITHIUM_ML_DSA_87_PUB_KEY_SIZE;
    const PRIV_KEY_SIZE: usize = DILITHIUM_ML_DSA_87_KEY_SIZE;
}
/// An ML-DSA private (signing) key for parameter set `L`, backed by wolfCrypt.
///
/// The key owns both the wolfCrypt key object and a private wolfCrypt RNG that is used when
/// signing. Both sit in `UnsafeCell` because wolfCrypt takes mutable pointers even for
/// logically read-only operations such as `try_sign(&self, ..)`. Containing `UnsafeCell`
/// also makes this type `!Sync`, so a shared `&self` is never used from two threads at once.
pub struct MlDsaSigningKey<L>
where
    L: MlDsaLevel,
{
    key: UnsafeCell<wc_dilithium_key>,
    rng: UnsafeCell<WC_RNG>,
    _level: PhantomData<L>,
}

// SAFETY: the key and RNG are exclusively owned by this value and are only reached through
// it, so moving the value to another thread moves sole ownership of the wolfCrypt state.
// NOTE(review): this assumes wolfCrypt's dilithium key and RNG carry no thread-affine state —
// confirm for the wolfSSL configuration in use.
unsafe impl<L> Send for MlDsaSigningKey<L> where L: MlDsaLevel {}
impl<L: MlDsaLevel> MlDsaSigningKey<L> {
pub fn generate(rng: &mut crate::rand::WolfRng) -> Result<Self, WolfCryptError> {
let mut key = wc_dilithium_key::zeroed();
let rc = unsafe { wc_dilithium_init(&mut key) };
check(rc, "wc_dilithium_init")?;
let rc = unsafe { wc_dilithium_set_level(&mut key, L::LEVEL) };
check(rc, "wc_dilithium_set_level")?;
let rc = unsafe { wc_dilithium_make_key(&mut key, &mut rng.rng) };
check(rc, "wc_dilithium_make_key")?;
let mut own_rng = WC_RNG::zeroed();
let rc = unsafe { wc_InitRng(&mut own_rng) };
check(rc, "wc_InitRng")?;
Ok(Self {
key: UnsafeCell::new(key),
rng: UnsafeCell::new(own_rng),
_level: PhantomData,
})
}
pub fn verifying_key(&self) -> MlDsaVerifyingKey<L> {
let mut pub_buf = vec![0u8; L::PUB_KEY_SIZE];
let mut pub_len: u32 = L::PUB_KEY_SIZE as u32;
let rc = unsafe {
wc_dilithium_export_public(
self.key.get(),
pub_buf.as_mut_ptr(),
&mut pub_len,
)
};
assert_eq!(rc, 0, "wc_dilithium_export_public failed (key not initialized)");
assert_eq!(pub_len as usize, L::PUB_KEY_SIZE);
MlDsaVerifyingKey::from_bytes(&pub_buf)
.expect("exported public key must be valid")
}
pub fn to_private_bytes(&self) -> zeroize::Zeroizing<Vec<u8>> {
let mut priv_buf = vec![0u8; L::PRIV_KEY_SIZE];
let mut priv_len: u32 = L::PRIV_KEY_SIZE as u32;
let rc = unsafe {
wc_dilithium_export_private(
self.key.get(),
priv_buf.as_mut_ptr(),
&mut priv_len,
)
};
assert_eq!(rc, 0, "wc_dilithium_export_private failed (key not initialized)");
priv_buf.truncate(priv_len as usize);
zeroize::Zeroizing::new(priv_buf)
}
}
impl<L> Drop for MlDsaSigningKey<L>
where
    L: MlDsaLevel,
{
    /// Releases the wolfCrypt key and then the key's private RNG.
    fn drop(&mut self) {
        let key = self.key.get_mut();
        let rng = self.rng.get_mut();
        // SAFETY: `&mut self` grants exclusive access, both objects were initialised by the
        // constructor, and each is released exactly once here.
        unsafe { wc_dilithium_free(key) };
        // SAFETY: see above.
        unsafe { wc_FreeRng(rng) };
    }
}
impl<L: MlDsaLevel> signature_trait::Signer<MlDsaSignature<L>> for MlDsaSigningKey<L> {
    /// Signs `msg` with this key, drawing randomness from the key's own wolfCrypt RNG.
    ///
    /// # Errors
    ///
    /// Returns an opaque `signature_trait::Error` in either of these cases:
    /// - `wc_dilithium_sign_msg` returns a non-zero status;
    /// - it reports a signature length other than `L::SIG_SIZE`.
    ///
    /// Enforcing the length upholds the invariant that `TryFrom<&[u8]>` checks. Every
    /// signature produced here therefore round-trips through its own encoding.
    fn try_sign(&self, msg: &[u8]) -> Result<MlDsaSignature<L>, signature_trait::Error> {
        let mut sig_buf = vec![0u8; L::SIG_SIZE];
        // In/out parameter: buffer capacity on input, bytes written on output.
        let mut sig_len: u32 = L::SIG_SIZE as u32;
        // SAFETY: `msg` is readable for `msg.len()` bytes and `sig_buf` is writable for
        // `sig_len` bytes. The key and RNG pointers come from `UnsafeCell`s owned by `self`.
        // The type is `!Sync`, so no other thread can use them through a shared reference.
        let rc = unsafe {
            wc_dilithium_sign_msg(
                msg.as_ptr(),
                len_as_u32(msg.len()),
                sig_buf.as_mut_ptr(),
                &mut sig_len,
                self.key.get(),
                self.rng.get(),
            )
        };
        if rc != 0 || sig_len as usize != L::SIG_SIZE {
            return Err(signature_trait::Error::new());
        }
        Ok(MlDsaSignature { bytes: sig_buf, _level: PhantomData })
    }
}
/// An ML-DSA public (verifying) key for parameter set `L`, backed by wolfCrypt.
///
/// The imported wolfCrypt key object sits in an `UnsafeCell` because wolfCrypt takes a
/// mutable pointer even when verifying. A copy of the encoded public key is kept for
/// `as_bytes`.
pub struct MlDsaVerifyingKey<L>
where
    L: MlDsaLevel,
{
    key: UnsafeCell<wc_dilithium_key>,
    pub_bytes: Vec<u8>,
    _level: PhantomData<L>,
}

// SAFETY: the wolfCrypt key is exclusively owned by this value and only reached through it, so
// sending the value to another thread transfers sole ownership.
// NOTE(review): this assumes wolfCrypt's dilithium key carries no thread-affine state — confirm.
unsafe impl<L> Send for MlDsaVerifyingKey<L> where L: MlDsaLevel {}
impl<L: MlDsaLevel> MlDsaVerifyingKey<L> {
pub fn from_bytes(bytes: &[u8]) -> Result<Self, WolfCryptError> {
if bytes.len() != L::PUB_KEY_SIZE {
return Err(WolfCryptError::INVALID_INPUT);
}
let mut key = wc_dilithium_key::zeroed();
let rc = unsafe { wc_dilithium_init(&mut key) };
check(rc, "wc_dilithium_init")?;
let rc = unsafe { wc_dilithium_set_level(&mut key, L::LEVEL) };
check(rc, "wc_dilithium_set_level")?;
let rc = unsafe {
wc_dilithium_import_public(bytes.as_ptr(), len_as_u32(bytes.len()), &mut key)
};
check(rc, "wc_dilithium_import_public")?;
Ok(Self {
key: UnsafeCell::new(key),
pub_bytes: bytes.to_vec(),
_level: PhantomData,
})
}
pub fn as_bytes(&self) -> &[u8] {
&self.pub_bytes
}
}
impl<L> Drop for MlDsaVerifyingKey<L>
where
    L: MlDsaLevel,
{
    /// Releases the wolfCrypt key object behind this public key.
    fn drop(&mut self) {
        let key = self.key.get_mut();
        // SAFETY: `&mut self` grants exclusive access; the key was initialised by the
        // constructor and is released exactly once here.
        unsafe { wc_dilithium_free(key) };
    }
}
impl<L: MlDsaLevel> signature_trait::Verifier<MlDsaSignature<L>> for MlDsaVerifyingKey<L> {
fn verify(
&self,
msg: &[u8],
signature: &MlDsaSignature<L>,
) -> Result<(), signature_trait::Error> {
let sig_bytes = signature.as_ref();
let mut result: i32 = 0;
let rc = unsafe {
wc_dilithium_verify_msg(
sig_bytes.as_ptr(),
len_as_u32(sig_bytes.len()),
msg.as_ptr(),
len_as_u32(msg.len()),
&mut result,
self.key.get(),
)
};
if rc != 0 || result != 1 {
return Err(signature_trait::Error::new());
}
Ok(())
}
}
// Convenience aliases, grouped by role, for each supported ML-DSA parameter set.

/// Signing key for ML-DSA-44.
pub type MlDsa44SigningKey = MlDsaSigningKey<MlDsa44>;
/// Signing key for ML-DSA-65.
pub type MlDsa65SigningKey = MlDsaSigningKey<MlDsa65>;
/// Signing key for ML-DSA-87.
pub type MlDsa87SigningKey = MlDsaSigningKey<MlDsa87>;

/// Verifying key for ML-DSA-44.
pub type MlDsa44VerifyingKey = MlDsaVerifyingKey<MlDsa44>;
/// Verifying key for ML-DSA-65.
pub type MlDsa65VerifyingKey = MlDsaVerifyingKey<MlDsa65>;
/// Verifying key for ML-DSA-87.
pub type MlDsa87VerifyingKey = MlDsaVerifyingKey<MlDsa87>;

/// Signature for ML-DSA-44.
pub type MlDsa44Signature = MlDsaSignature<MlDsa44>;
/// Signature for ML-DSA-65.
pub type MlDsa65Signature = MlDsaSignature<MlDsa65>;
/// Signature for ML-DSA-87.
pub type MlDsa87Signature = MlDsaSignature<MlDsa87>;