use alloc::vec;
use alloc::vec::Vec;
use core::marker::PhantomData;
use core::ffi::c_int;
use core::ptr;
use zeroize::ZeroizeOnDrop;
use crate::error::{check, len_as_u32, WolfCryptError};
use wolfcrypt_rs::{
wc_FreeRng, wc_InitRng,
wc_MlKemKey_Decapsulate, wc_MlKemKey_Delete, wc_MlKemKey_Encapsulate,
wc_MlKemKey_DecodePublicKey, wc_MlKemKey_EncodePublicKey,
wc_MlKemKey_EncodePrivateKey,
wc_MlKemKey_MakeKey, wc_MlKemKey_New,
MlKemKey, WC_RNG, INVALID_DEVID,
WC_ML_KEM_512, WC_ML_KEM_768, WC_ML_KEM_1024,
WC_ML_KEM_SS_SZ,
WC_ML_KEM_512_PUBLIC_KEY_SIZE, WC_ML_KEM_512_PRIVATE_KEY_SIZE, WC_ML_KEM_512_CIPHER_TEXT_SIZE,
WC_ML_KEM_768_PUBLIC_KEY_SIZE, WC_ML_KEM_768_PRIVATE_KEY_SIZE, WC_ML_KEM_768_CIPHER_TEXT_SIZE,
WC_ML_KEM_1024_PUBLIC_KEY_SIZE, WC_ML_KEM_1024_PRIVATE_KEY_SIZE, WC_ML_KEM_1024_CIPHER_TEXT_SIZE,
};
/// Compile-time description of one ML-KEM parameter set as exposed by wolfCrypt.
///
/// Used as the type parameter of the key wrappers in this module, so the key type
/// identifier and all buffer sizes are fixed at compile time.
pub trait MlKemLevel
where
    Self: Send + 'static,
{
    /// wolfCrypt key-type identifier passed to `wc_MlKemKey_New`.
    const TYPE: c_int;
    /// Length in bytes of an encoded public (encapsulation) key.
    const PK_SIZE: usize;
    /// Length in bytes of an encoded private (decapsulation) key.
    const SK_SIZE: usize;
    /// Length in bytes of a ciphertext.
    const CT_SIZE: usize;
    /// Length in bytes of the shared secret.
    const SS_SIZE: usize;
}
/// Zero-sized marker selecting the ML-KEM-512 parameter set.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct MlKem512;
/// Zero-sized marker selecting the ML-KEM-768 parameter set.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct MlKem768;
/// Zero-sized marker selecting the ML-KEM-1024 parameter set.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct MlKem1024;
/// Binds [`MlKem512`] to wolfCrypt's `WC_ML_KEM_512*` constants.
impl MlKemLevel for MlKem512 {
    // Encoded sizes in bytes for this parameter set.
    const PK_SIZE: usize = WC_ML_KEM_512_PUBLIC_KEY_SIZE;
    const SK_SIZE: usize = WC_ML_KEM_512_PRIVATE_KEY_SIZE;
    const CT_SIZE: usize = WC_ML_KEM_512_CIPHER_TEXT_SIZE;
    // Shared-secret length comes from the level-independent wolfCrypt constant.
    const SS_SIZE: usize = WC_ML_KEM_SS_SZ;
    // Key-type identifier handed to `wc_MlKemKey_New`.
    const TYPE: c_int = WC_ML_KEM_512;
}
/// Binds [`MlKem768`] to wolfCrypt's `WC_ML_KEM_768*` constants.
impl MlKemLevel for MlKem768 {
    // Encoded sizes in bytes for this parameter set.
    const PK_SIZE: usize = WC_ML_KEM_768_PUBLIC_KEY_SIZE;
    const SK_SIZE: usize = WC_ML_KEM_768_PRIVATE_KEY_SIZE;
    const CT_SIZE: usize = WC_ML_KEM_768_CIPHER_TEXT_SIZE;
    // Shared-secret length comes from the level-independent wolfCrypt constant.
    const SS_SIZE: usize = WC_ML_KEM_SS_SZ;
    // Key-type identifier handed to `wc_MlKemKey_New`.
    const TYPE: c_int = WC_ML_KEM_768;
}
/// Binds [`MlKem1024`] to wolfCrypt's `WC_ML_KEM_1024*` constants.
impl MlKemLevel for MlKem1024 {
    // Encoded sizes in bytes for this parameter set.
    const PK_SIZE: usize = WC_ML_KEM_1024_PUBLIC_KEY_SIZE;
    const SK_SIZE: usize = WC_ML_KEM_1024_PRIVATE_KEY_SIZE;
    const CT_SIZE: usize = WC_ML_KEM_1024_CIPHER_TEXT_SIZE;
    // Shared-secret length comes from the level-independent wolfCrypt constant.
    const SS_SIZE: usize = WC_ML_KEM_SS_SZ;
    // Key-type identifier handed to `wc_MlKemKey_New`.
    const TYPE: c_int = WC_ML_KEM_1024;
}
/// Shared secret returned by ML-KEM encapsulation and decapsulation.
///
/// Holds exactly `WC_ML_KEM_SS_SZ` bytes. The `ZeroizeOnDrop` derive wipes the
/// buffer when the value is dropped. The type deliberately has no `Debug` or
/// `Clone` impl.
// NOTE(review): zeroize documents `#[zeroize(drop)]` as an attribute on the
// containing type for `#[derive(Zeroize)]`. On a field under `ZeroizeOnDrop` it
// looks redundant. Confirm against the zeroize_derive version in use before
// removing it.
#[derive(ZeroizeOnDrop)]
pub struct SharedSecret(#[zeroize(drop)] [u8; WC_ML_KEM_SS_SZ]);
impl SharedSecret {
    /// Borrows the raw shared-secret bytes without copying them.
    ///
    /// Any copy the caller makes is outside this type's zeroize-on-drop protection.
    pub fn as_bytes(&self) -> &[u8] {
        self.0.as_slice()
    }
}
/// Compares two shared secrets without a data-dependent early exit.
///
/// Every byte pair is XOR-ed and OR-ed into one accumulator, so the amount of work
/// does not depend on where, or whether, the secrets differ. A plain array
/// comparison may stop at the first mismatching byte and leak its position through
/// timing.
///
/// `core::hint::black_box` is applied at each step to discourage the optimizer
/// from reintroducing a short-circuit. It is a best-effort hint, not a formal
/// constant-time guarantee.
impl PartialEq for SharedSecret {
    fn eq(&self, other: &Self) -> bool {
        let diff = self
            .0
            .iter()
            .zip(other.0.iter())
            .fold(0u8, |acc, (a, b)| core::hint::black_box(acc | (a ^ b)));
        diff == 0
    }
}
/// Byte-wise equality is a full equivalence relation.
impl Eq for SharedSecret {}
/// ML-KEM private (decapsulation) key backed by a wolfCrypt `MlKemKey`.
///
/// The parameter set is fixed by the type parameter `L`.
// NOTE(review): `rng` is initialised by `wc_InitRng` on the stack and then moved
// into this struct. If wolfSSL is built so that `WC_RNG` points into itself (e.g.
// no-malloc builds), that move would be unsound. Confirm for the configuration in
// use.
pub struct MlKemDecapsulationKey<L>
where
    L: MlKemLevel,
{
    /// Key handle obtained from `wc_MlKemKey_New`; released in `Drop`.
    key: *mut MlKemKey,
    /// RNG initialised during key generation. No method in this module reads it
    /// afterwards; it is kept only so `Drop` can free it.
    rng: WC_RNG,
    /// Ties the value to its parameter set without storing one.
    _level: PhantomData<L>,
}
// SAFETY: the struct is the only holder of its `MlKemKey` pointer and `WC_RNG`
// state, and both fields are private, so moving the whole value to another thread
// creates no aliasing. This assumes wolfCrypt key/RNG objects have no thread
// affinity (NOTE(review): confirm). `Sync` is not implemented; the raw-pointer
// field keeps the type `!Sync`.
unsafe impl<L> Send for MlKemDecapsulationKey<L> where L: MlKemLevel {}
impl<L: MlKemLevel> MlKemDecapsulationKey<L> {
pub fn generate() -> Result<Self, WolfCryptError> {
let key = unsafe {
wc_MlKemKey_New(L::TYPE, ptr::null_mut(), INVALID_DEVID)
};
if key.is_null() {
return Err(WolfCryptError::ALLOC_FAILED);
}
let mut rng = WC_RNG::zeroed();
let rc = unsafe { wc_InitRng(&mut rng) };
if rc != 0 {
unsafe { wc_MlKemKey_Delete(key, ptr::null_mut()); }
return Err(WolfCryptError::Ffi { code: rc, func: "wc_InitRng" });
}
let rc = unsafe { wc_MlKemKey_MakeKey(key, &mut rng) };
if rc != 0 {
unsafe {
wc_FreeRng(&mut rng);
wc_MlKemKey_Delete(key, ptr::null_mut());
}
return Err(WolfCryptError::Ffi { code: rc, func: "wc_MlKemKey_MakeKey" });
}
Ok(Self { key, rng, _level: PhantomData })
}
pub fn encapsulation_key(&self) -> Result<MlKemEncapsulationKey<L>, WolfCryptError> {
let mut pk_buf = vec![0u8; L::PK_SIZE];
let rc = unsafe {
wc_MlKemKey_EncodePublicKey(self.key, pk_buf.as_mut_ptr(), L::PK_SIZE as u32)
};
check(rc, "wc_MlKemKey_EncodePublicKey")?;
MlKemEncapsulationKey::from_bytes(&pk_buf)
}
pub fn public_key_bytes(&self) -> Result<Vec<u8>, WolfCryptError> {
let mut pk_buf = vec![0u8; L::PK_SIZE];
let rc = unsafe {
wc_MlKemKey_EncodePublicKey(self.key, pk_buf.as_mut_ptr(), L::PK_SIZE as u32)
};
check(rc, "wc_MlKemKey_EncodePublicKey")?;
Ok(pk_buf)
}
pub fn private_key_bytes(&self) -> Result<zeroize::Zeroizing<Vec<u8>>, WolfCryptError> {
let mut sk_buf = vec![0u8; L::SK_SIZE];
let rc = unsafe {
wc_MlKemKey_EncodePrivateKey(self.key, sk_buf.as_mut_ptr(), L::SK_SIZE as u32)
};
check(rc, "wc_MlKemKey_EncodePrivateKey")?;
Ok(zeroize::Zeroizing::new(sk_buf))
}
pub fn decapsulate(&self, ct: &[u8]) -> Result<SharedSecret, WolfCryptError> {
if ct.len() != L::CT_SIZE {
return Err(WolfCryptError::INVALID_INPUT);
}
let mut ss = [0u8; WC_ML_KEM_SS_SZ];
let rc = unsafe {
wc_MlKemKey_Decapsulate(self.key, ss.as_mut_ptr(), ct.as_ptr(), len_as_u32(ct.len()))
};
check(rc, "wc_MlKemKey_Decapsulate")?;
Ok(SharedSecret(ss))
}
}
impl<L: MlKemLevel> Drop for MlKemDecapsulationKey<L> {
    /// Frees the wolfCrypt key object, then the RNG state.
    ///
    /// The wolfCrypt return codes are discarded because `drop` cannot report them.
    fn drop(&mut self) {
        // SAFETY: `generate` (the constructor visible in this module) only builds
        // `Self` after both resources were acquired; this is their single release.
        let _ = unsafe { wc_MlKemKey_Delete(self.key, ptr::null_mut()) };
        let _ = unsafe { wc_FreeRng(&mut self.rng) };
    }
}
/// ML-KEM public (encapsulation) key backed by a wolfCrypt `MlKemKey`.
///
/// The parameter set is fixed by the type parameter `L`.
// NOTE(review): `rng` is initialised by `wc_InitRng` on the stack and then moved
// into this struct. If wolfSSL is built so that `WC_RNG` points into itself (e.g.
// no-malloc builds), that move would be unsound. Confirm for the configuration in
// use.
pub struct MlKemEncapsulationKey<L>
where
    L: MlKemLevel,
{
    /// Key handle obtained from `wc_MlKemKey_New`; released in `Drop`.
    key: *mut MlKemKey,
    /// RNG consumed by `encapsulate`; freed in `Drop`.
    rng: WC_RNG,
    /// Ties the value to its parameter set without storing one.
    _level: PhantomData<L>,
}
// SAFETY: the struct is the only holder of its `MlKemKey` pointer and `WC_RNG`
// state, and both fields are private, so moving the whole value to another thread
// creates no aliasing. This assumes wolfCrypt key/RNG objects have no thread
// affinity (NOTE(review): confirm). `Sync` is not implemented; the raw-pointer
// field keeps the type `!Sync`.
unsafe impl<L> Send for MlKemEncapsulationKey<L> where L: MlKemLevel {}
impl<L: MlKemLevel> MlKemEncapsulationKey<L> {
    /// Imports a public key from its encoded form.
    ///
    /// `bytes` must be exactly `L::PK_SIZE` long. On success the returned value
    /// owns a newly allocated wolfCrypt `MlKemKey` and an initialised `WC_RNG`,
    /// which [`Self::encapsulate`] uses.
    ///
    /// # Errors
    ///
    /// * `WolfCryptError::INVALID_INPUT` if `bytes` has the wrong length.
    /// * `WolfCryptError::ALLOC_FAILED` if `wc_MlKemKey_New` returns a null pointer.
    /// * `WolfCryptError::Ffi` with the wolfCrypt return code if
    ///   `wc_MlKemKey_DecodePublicKey` or `wc_InitRng` returns non-zero.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, WolfCryptError> {
        if bytes.len() != L::PK_SIZE {
            return Err(WolfCryptError::INVALID_INPUT);
        }
        // SAFETY: plain constructor call with a null heap and INVALID_DEVID; the
        // result is null-checked before any use.
        let key = unsafe { wc_MlKemKey_New(L::TYPE, ptr::null_mut(), INVALID_DEVID) };
        if key.is_null() {
            return Err(WolfCryptError::ALLOC_FAILED);
        }
        // SAFETY: `key` is live; `bytes` is readable for the length passed.
        let rc = unsafe {
            wc_MlKemKey_DecodePublicKey(key, bytes.as_ptr(), len_as_u32(bytes.len()))
        };
        if rc != 0 {
            // SAFETY: `key` came from `wc_MlKemKey_New` and is freed once.
            unsafe { wc_MlKemKey_Delete(key, ptr::null_mut()); }
            return Err(WolfCryptError::Ffi { code: rc, func: "wc_MlKemKey_DecodePublicKey" });
        }
        let mut rng = WC_RNG::zeroed();
        // SAFETY: `rng` is a live, exclusively borrowed `WC_RNG`.
        let rc = unsafe { wc_InitRng(&mut rng) };
        if rc != 0 {
            // SAFETY: `key` came from `wc_MlKemKey_New` and is freed once.
            unsafe { wc_MlKemKey_Delete(key, ptr::null_mut()); }
            return Err(WolfCryptError::Ffi { code: rc, func: "wc_InitRng" });
        }
        Ok(Self { key, rng, _level: PhantomData })
    }

    /// Re-encodes the public key as exactly `L::PK_SIZE` bytes.
    ///
    /// Despite the `as_` prefix, this allocates a new `Vec` on every call.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `check` if `wc_MlKemKey_EncodePublicKey` fails.
    pub fn as_bytes(&self) -> Result<Vec<u8>, WolfCryptError> {
        let mut pk_buf = vec![0u8; L::PK_SIZE];
        // SAFETY: `self.key` is live for the lifetime of `self`; `pk_buf` is
        // writable for exactly the length passed.
        let rc = unsafe {
            wc_MlKemKey_EncodePublicKey(self.key, pk_buf.as_mut_ptr(), len_as_u32(pk_buf.len()))
        };
        check(rc, "wc_MlKemKey_EncodePublicKey")?;
        Ok(pk_buf)
    }

    /// Encapsulates a fresh shared secret to this public key.
    ///
    /// Returns `(ciphertext, shared_secret)`; the ciphertext is `L::CT_SIZE` bytes.
    /// Takes `&mut self` because the stored RNG is advanced.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `check` if `wc_MlKemKey_Encapsulate` fails.
    pub fn encapsulate(&mut self) -> Result<(Vec<u8>, SharedSecret), WolfCryptError> {
        let mut ct = vec![0u8; L::CT_SIZE];
        // Encapsulate straight into a `SharedSecret`, so the buffer is zeroized on
        // drop even when wolfCrypt reports an error.
        let mut ss = SharedSecret([0u8; WC_ML_KEM_SS_SZ]);
        // SAFETY: `self.key` and `self.rng` are live and exclusively borrowed;
        // `ct` holds L::CT_SIZE bytes and `ss.0` holds WC_ML_KEM_SS_SZ bytes.
        let rc = unsafe {
            wc_MlKemKey_Encapsulate(
                self.key,
                ct.as_mut_ptr(),
                ss.0.as_mut_ptr(),
                &mut self.rng,
            )
        };
        check(rc, "wc_MlKemKey_Encapsulate")?;
        Ok((ct, ss))
    }
}
impl<L: MlKemLevel> Drop for MlKemEncapsulationKey<L> {
    /// Frees the wolfCrypt key object, then the RNG state.
    ///
    /// The wolfCrypt return codes are discarded because `drop` cannot report them.
    fn drop(&mut self) {
        // SAFETY: `from_bytes` (the constructor visible in this module) only builds
        // `Self` after both resources were acquired; this is their single release.
        let _ = unsafe { wc_MlKemKey_Delete(self.key, ptr::null_mut()) };
        let _ = unsafe { wc_FreeRng(&mut self.rng) };
    }
}
// Convenience aliases, grouped by parameter set.

/// Private (decapsulation) key for ML-KEM-512.
pub type MlKem512DecapsulationKey = MlKemDecapsulationKey<MlKem512>;
/// Public (encapsulation) key for ML-KEM-512.
pub type MlKem512EncapsulationKey = MlKemEncapsulationKey<MlKem512>;

/// Private (decapsulation) key for ML-KEM-768.
pub type MlKem768DecapsulationKey = MlKemDecapsulationKey<MlKem768>;
/// Public (encapsulation) key for ML-KEM-768.
pub type MlKem768EncapsulationKey = MlKemEncapsulationKey<MlKem768>;

/// Private (decapsulation) key for ML-KEM-1024.
pub type MlKem1024DecapsulationKey = MlKemDecapsulationKey<MlKem1024>;
/// Public (encapsulation) key for ML-KEM-1024.
pub type MlKem1024EncapsulationKey = MlKemEncapsulationKey<MlKem1024>;