use alloc::vec;
use alloc::vec::Vec;
use core::ffi::c_int;
use core::marker::PhantomData;
use core::ptr;
use zeroize::ZeroizeOnDrop;
use crate::error::{check, len_as_u32, WolfCryptError};
use wolfcrypt_rs::{
wc_FreeRng, wc_InitRng, wc_MlKemKey_Decapsulate, wc_MlKemKey_DecodePrivateKey,
wc_MlKemKey_DecodePublicKey, wc_MlKemKey_Delete, wc_MlKemKey_Encapsulate,
wc_MlKemKey_EncodePrivateKey, wc_MlKemKey_EncodePublicKey, wc_MlKemKey_MakeKey,
wc_MlKemKey_New, MlKemKey, INVALID_DEVID, WC_ML_KEM_1024, WC_ML_KEM_1024_CIPHER_TEXT_SIZE,
WC_ML_KEM_1024_PRIVATE_KEY_SIZE, WC_ML_KEM_1024_PUBLIC_KEY_SIZE, WC_ML_KEM_512,
WC_ML_KEM_512_CIPHER_TEXT_SIZE, WC_ML_KEM_512_PRIVATE_KEY_SIZE, WC_ML_KEM_512_PUBLIC_KEY_SIZE,
WC_ML_KEM_768, WC_ML_KEM_768_CIPHER_TEXT_SIZE, WC_ML_KEM_768_PRIVATE_KEY_SIZE,
WC_ML_KEM_768_PUBLIC_KEY_SIZE, WC_ML_KEM_SS_SZ, WC_RNG,
};
/// Private module that seals [`MlKemLevel`]: because `sealed` is not `pub`,
/// code outside this module cannot name `sealed::Sealed`, and therefore cannot
/// implement `MlKemLevel` for its own types.
mod sealed {
    /// Supertrait marker required by [`super::MlKemLevel`].
    pub trait Sealed {}
}

/// Compile-time description of one ML-KEM parameter set.
///
/// Each implementor fixes the wolfCrypt type identifier and the byte lengths
/// used to size buffers handed to the wolfCrypt ML-KEM functions.
pub trait MlKemLevel: sealed::Sealed + Send + 'static {
    /// Parameter-set identifier passed to `wc_MlKemKey_New`.
    const TYPE: c_int;
    /// Length in bytes of an encoded public (encapsulation) key.
    const PK_SIZE: usize;
    /// Length in bytes of an encoded private (decapsulation) key.
    const SK_SIZE: usize;
    /// Length in bytes of a ciphertext.
    const CT_SIZE: usize;
    /// Length in bytes of a shared secret.
    const SS_SIZE: usize;
}
/// ML-KEM-512 parameter set (wolfCrypt `WC_ML_KEM_512`).
pub struct MlKem512;

impl sealed::Sealed for MlKem512 {}

impl MlKemLevel for MlKem512 {
    const TYPE: c_int = WC_ML_KEM_512;
    const PK_SIZE: usize = WC_ML_KEM_512_PUBLIC_KEY_SIZE;
    const SK_SIZE: usize = WC_ML_KEM_512_PRIVATE_KEY_SIZE;
    const CT_SIZE: usize = WC_ML_KEM_512_CIPHER_TEXT_SIZE;
    const SS_SIZE: usize = WC_ML_KEM_SS_SZ;
}

/// ML-KEM-768 parameter set (wolfCrypt `WC_ML_KEM_768`).
pub struct MlKem768;

impl sealed::Sealed for MlKem768 {}

impl MlKemLevel for MlKem768 {
    const TYPE: c_int = WC_ML_KEM_768;
    const PK_SIZE: usize = WC_ML_KEM_768_PUBLIC_KEY_SIZE;
    const SK_SIZE: usize = WC_ML_KEM_768_PRIVATE_KEY_SIZE;
    const CT_SIZE: usize = WC_ML_KEM_768_CIPHER_TEXT_SIZE;
    const SS_SIZE: usize = WC_ML_KEM_SS_SZ;
}

/// ML-KEM-1024 parameter set (wolfCrypt `WC_ML_KEM_1024`).
pub struct MlKem1024;

impl sealed::Sealed for MlKem1024 {}

impl MlKemLevel for MlKem1024 {
    const TYPE: c_int = WC_ML_KEM_1024;
    const PK_SIZE: usize = WC_ML_KEM_1024_PUBLIC_KEY_SIZE;
    const SK_SIZE: usize = WC_ML_KEM_1024_PRIVATE_KEY_SIZE;
    const CT_SIZE: usize = WC_ML_KEM_1024_CIPHER_TEXT_SIZE;
    const SS_SIZE: usize = WC_ML_KEM_SS_SZ;
}
/// Shared secret produced by ML-KEM encapsulation or decapsulation.
///
/// The byte array is wiped when the value is dropped (`ZeroizeOnDrop`).
#[derive(ZeroizeOnDrop)]
pub struct SharedSecret(#[zeroize(drop)] [u8; WC_ML_KEM_SS_SZ]);

impl SharedSecret {
    /// Borrows the raw shared-secret bytes (`WC_ML_KEM_SS_SZ` of them).
    pub fn as_bytes(&self) -> &[u8] {
        self.0.as_slice()
    }
}
/// Compares two shared secrets without an early exit.
///
/// Plain array `==` may stop at the first differing byte, which would let an
/// observer learn through timing how many leading bytes two secrets share.
/// This implementation visits every byte and ORs together the XOR of each pair,
/// so the amount of work does not depend on where (or whether) they differ.
impl PartialEq for SharedSecret {
    fn eq(&self, other: &Self) -> bool {
        let diff = self
            .0
            .iter()
            .zip(other.0.iter())
            // `black_box` on every step discourages the optimizer from
            // reintroducing a data-dependent early exit. It is a best-effort
            // hint, not a hard constant-time guarantee.
            .fold(0u8, |acc, (&a, &b)| core::hint::black_box(acc | (a ^ b)));
        diff == 0
    }
}

/// Byte-wise equality is reflexive, symmetric and transitive.
impl Eq for SharedSecret {}
/// ML-KEM decapsulation (private) key for parameter set `L`.
///
/// Wraps a wolfCrypt `MlKemKey` obtained from `wc_MlKemKey_New` together with
/// a wolfCrypt RNG. Both are released by this type's `Drop` impl.
pub struct MlKemDecapsulationKey<L: MlKemLevel> {
// Owned key object from `wc_MlKemKey_New`; every constructor in this file
// null-checks it before storing it.
key: *mut MlKemKey,
// RNG set up with `wc_InitRng`. Among the methods visible here it is only
// consumed by `wc_MlKemKey_MakeKey` in `generate`; `decapsulate` does not use it.
// NOTE(review): this value is moved into `Self` after `wc_InitRng` has run,
// which assumes `WC_RNG` holds no pointers into itself in the wolfSSL build in
// use (e.g. a NO_MALLOC configuration) — confirm.
rng: WC_RNG,
// Binds the key to its parameter set at the type level; stores no data.
_level: PhantomData<L>,
}
// SAFETY: the struct exclusively owns `key` and `rng`; no method visible in
// this file returns `self.key` or a reference to `self.rng`, so sending the
// value to another thread transfers sole ownership of both. No `Sync` impl is
// visible here (the raw-pointer field suppresses the auto impl), so `&self`
// methods that pass `self.key` to wolfCrypt cannot run on two threads at once.
// NOTE(review): assumes wolfCrypt key/RNG objects have no thread affinity
// (e.g. thread-local state) — confirm for the build configuration.
unsafe impl<L: MlKemLevel> Send for MlKemDecapsulationKey<L> {}
impl<L: MlKemLevel> MlKemDecapsulationKey<L> {
    /// Generates a fresh ML-KEM key pair for parameter set `L`.
    ///
    /// # Errors
    ///
    /// - `WolfCryptError::ALLOC_FAILED` if `wc_MlKemKey_New` returns null.
    /// - `WolfCryptError::Ffi` if `wc_InitRng` or `wc_MlKemKey_MakeKey` fails.
    ///   Every resource acquired up to that point is released first.
    pub fn generate() -> Result<Self, WolfCryptError> {
        // SAFETY: only plain values are passed; the result is null-checked
        // before any use.
        let key = unsafe { wc_MlKemKey_New(L::TYPE, ptr::null_mut(), INVALID_DEVID) };
        if key.is_null() {
            return Err(WolfCryptError::ALLOC_FAILED);
        }
        let mut rng = WC_RNG::zeroed();
        // SAFETY: `rng` is a live, zero-initialized WC_RNG owned by this frame.
        let rc = unsafe { wc_InitRng(&mut rng) };
        if rc != 0 {
            // SAFETY: `key` came from `wc_MlKemKey_New` and is not used again.
            unsafe {
                wc_MlKemKey_Delete(key, ptr::null_mut());
            }
            return Err(WolfCryptError::Ffi {
                code: rc,
                func: "wc_InitRng",
            });
        }
        // SAFETY: `key` is a valid key object and `rng` was initialized above.
        let rc = unsafe { wc_MlKemKey_MakeKey(key, &mut rng) };
        if rc != 0 {
            // SAFETY: both resources were acquired above and are not used again.
            unsafe {
                wc_FreeRng(&mut rng);
                wc_MlKemKey_Delete(key, ptr::null_mut());
            }
            return Err(WolfCryptError::Ffi {
                code: rc,
                func: "wc_MlKemKey_MakeKey",
            });
        }
        Ok(Self {
            key,
            rng,
            _level: PhantomData,
        })
    }

    /// Builds the matching encapsulation (public) key.
    ///
    /// The public key is encoded with [`Self::public_key_bytes`] and decoded
    /// into a new, independent `MlKemEncapsulationKey`.
    ///
    /// # Errors
    ///
    /// Propagates any error from encoding the public key or from
    /// `MlKemEncapsulationKey::from_bytes`.
    pub fn encapsulation_key(&self) -> Result<MlKemEncapsulationKey<L>, WolfCryptError> {
        let pk = self.public_key_bytes()?;
        MlKemEncapsulationKey::from_bytes(&pk)
    }

    /// Encodes the public (encapsulation) half of this key into `L::PK_SIZE`
    /// bytes.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `check` when
    /// `wc_MlKemKey_EncodePublicKey` reports failure.
    pub fn public_key_bytes(&self) -> Result<Vec<u8>, WolfCryptError> {
        let mut pk_buf = vec![0u8; L::PK_SIZE];
        // SAFETY: `self.key` is a valid key object and `pk_buf` provides
        // exactly `L::PK_SIZE` writable bytes, the length reported to wolfCrypt.
        let rc = unsafe {
            wc_MlKemKey_EncodePublicKey(self.key, pk_buf.as_mut_ptr(), len_as_u32(L::PK_SIZE))
        };
        check(rc, "wc_MlKemKey_EncodePublicKey")?;
        Ok(pk_buf)
    }

    /// Encodes the private (decapsulation) key into `L::SK_SIZE` bytes.
    ///
    /// The buffer is wrapped in `Zeroizing` before wolfCrypt writes to it.
    /// That way any partially written key material is wiped even when encoding
    /// fails and this function returns early.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `check` when
    /// `wc_MlKemKey_EncodePrivateKey` reports failure.
    pub fn private_key_bytes(&self) -> Result<zeroize::Zeroizing<Vec<u8>>, WolfCryptError> {
        let mut sk_buf = zeroize::Zeroizing::new(vec![0u8; L::SK_SIZE]);
        // SAFETY: `self.key` is a valid key object and `sk_buf` provides
        // exactly `L::SK_SIZE` writable bytes, the length reported to wolfCrypt.
        let rc = unsafe {
            wc_MlKemKey_EncodePrivateKey(self.key, sk_buf.as_mut_ptr(), len_as_u32(L::SK_SIZE))
        };
        check(rc, "wc_MlKemKey_EncodePrivateKey")?;
        Ok(sk_buf)
    }

    /// Reconstructs a decapsulation key from its `L::SK_SIZE`-byte encoding.
    ///
    /// # Errors
    ///
    /// - `WolfCryptError::INVALID_INPUT` if `bytes.len() != L::SK_SIZE`.
    /// - `WolfCryptError::ALLOC_FAILED` if `wc_MlKemKey_New` returns null.
    /// - `WolfCryptError::Ffi` if `wc_MlKemKey_DecodePrivateKey` or
    ///   `wc_InitRng` fails. The key object is deleted first.
    pub fn from_private_bytes(bytes: &[u8]) -> Result<Self, WolfCryptError> {
        if bytes.len() != L::SK_SIZE {
            return Err(WolfCryptError::INVALID_INPUT);
        }
        // SAFETY: only plain values are passed; the result is null-checked.
        let key = unsafe { wc_MlKemKey_New(L::TYPE, ptr::null_mut(), INVALID_DEVID) };
        if key.is_null() {
            return Err(WolfCryptError::ALLOC_FAILED);
        }
        // SAFETY: `key` is valid and `bytes` is readable for `bytes.len()` bytes.
        let rc =
            unsafe { wc_MlKemKey_DecodePrivateKey(key, bytes.as_ptr(), len_as_u32(bytes.len())) };
        if rc != 0 {
            // SAFETY: `key` came from `wc_MlKemKey_New` and is not used again.
            unsafe {
                wc_MlKemKey_Delete(key, ptr::null_mut());
            }
            return Err(WolfCryptError::Ffi {
                code: rc,
                func: "wc_MlKemKey_DecodePrivateKey",
            });
        }
        let mut rng = WC_RNG::zeroed();
        // SAFETY: `rng` is a live, zero-initialized WC_RNG owned by this frame.
        let rc = unsafe { wc_InitRng(&mut rng) };
        if rc != 0 {
            // SAFETY: `key` came from `wc_MlKemKey_New` and is not used again.
            unsafe {
                wc_MlKemKey_Delete(key, ptr::null_mut());
            }
            return Err(WolfCryptError::Ffi {
                code: rc,
                func: "wc_InitRng",
            });
        }
        Ok(Self {
            key,
            rng,
            _level: PhantomData,
        })
    }

    /// Recovers the shared secret from a ciphertext of `L::CT_SIZE` bytes.
    ///
    /// NOTE(review): ML-KEM specifies implicit rejection, so a tampered
    /// ciphertext presumably yields an unrelated secret rather than an error —
    /// confirm against the wolfCrypt implementation.
    ///
    /// # Errors
    ///
    /// - `WolfCryptError::INVALID_INPUT` if `ct.len() != L::CT_SIZE`.
    /// - The error produced by `check` when `wc_MlKemKey_Decapsulate` fails.
    pub fn decapsulate(&self, ct: &[u8]) -> Result<SharedSecret, WolfCryptError> {
        if ct.len() != L::CT_SIZE {
            return Err(WolfCryptError::INVALID_INPUT);
        }
        // Write straight into a `SharedSecret`. Its zeroize-on-drop covers the
        // error path, and the secret does not pass through an extra plain
        // array. Moves may still leave stack copies, so this is best-effort.
        let mut ss = SharedSecret([0u8; WC_ML_KEM_SS_SZ]);
        // SAFETY: `self.key` is valid, `ss.0` has `WC_ML_KEM_SS_SZ` writable
        // bytes, and `ct` has `L::CT_SIZE` readable bytes (checked above).
        let rc = unsafe {
            wc_MlKemKey_Decapsulate(self.key, ss.0.as_mut_ptr(), ct.as_ptr(), len_as_u32(ct.len()))
        };
        check(rc, "wc_MlKemKey_Decapsulate")?;
        Ok(ss)
    }
}
impl<L: MlKemLevel> Drop for MlKemDecapsulationKey<L> {
    /// Deletes the wolfCrypt key object, then releases the RNG state.
    fn drop(&mut self) {
        let owned_key = self.key;
        // SAFETY: every constructor stores a non-null pointer from
        // `wc_MlKemKey_New`, and nothing touches it after this call.
        unsafe { wc_MlKemKey_Delete(owned_key, ptr::null_mut()) };
        // SAFETY: `self.rng` was initialized by `wc_InitRng` before `self`
        // was built, and is released exactly once here.
        unsafe { wc_FreeRng(&mut self.rng) };
    }
}
/// ML-KEM encapsulation (public) key for parameter set `L`.
///
/// Wraps a wolfCrypt `MlKemKey` holding only a decoded public key, plus the RNG
/// that `encapsulate` draws randomness from. Both are released by this type's
/// `Drop` impl.
pub struct MlKemEncapsulationKey<L: MlKemLevel> {
// Owned key object from `wc_MlKemKey_New`; `from_bytes` null-checks it.
key: *mut MlKemKey,
// RNG set up with `wc_InitRng`, passed to `wc_MlKemKey_Encapsulate`.
// NOTE(review): this value is moved into `Self` after `wc_InitRng` has run,
// which assumes `WC_RNG` holds no pointers into itself in the wolfSSL build in
// use — confirm.
rng: WC_RNG,
// Binds the key to its parameter set at the type level; stores no data.
_level: PhantomData<L>,
}
// SAFETY: the struct exclusively owns `key` and `rng`; no method visible in
// this file returns `self.key` or a reference to `self.rng`, so sending the
// value to another thread transfers sole ownership of both. No `Sync` impl is
// visible here (the raw-pointer field suppresses the auto impl).
// NOTE(review): assumes wolfCrypt key/RNG objects have no thread affinity —
// confirm for the build configuration.
unsafe impl<L: MlKemLevel> Send for MlKemEncapsulationKey<L> {}
impl<L: MlKemLevel> MlKemEncapsulationKey<L> {
    /// Decodes an encapsulation key from its `L::PK_SIZE`-byte encoding.
    ///
    /// # Errors
    ///
    /// - `WolfCryptError::INVALID_INPUT` if `bytes.len() != L::PK_SIZE`.
    /// - `WolfCryptError::ALLOC_FAILED` if `wc_MlKemKey_New` returns null.
    /// - `WolfCryptError::Ffi` if `wc_MlKemKey_DecodePublicKey` or
    ///   `wc_InitRng` fails. The key object is deleted first.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, WolfCryptError> {
        if bytes.len() != L::PK_SIZE {
            return Err(WolfCryptError::INVALID_INPUT);
        }
        // SAFETY: only plain values are passed; the result is null-checked.
        let key = unsafe { wc_MlKemKey_New(L::TYPE, ptr::null_mut(), INVALID_DEVID) };
        if key.is_null() {
            return Err(WolfCryptError::ALLOC_FAILED);
        }
        // SAFETY: `key` is valid and `bytes` is readable for `bytes.len()` bytes.
        let rc =
            unsafe { wc_MlKemKey_DecodePublicKey(key, bytes.as_ptr(), len_as_u32(bytes.len())) };
        if rc != 0 {
            // SAFETY: `key` came from `wc_MlKemKey_New` and is not used again.
            unsafe {
                wc_MlKemKey_Delete(key, ptr::null_mut());
            }
            return Err(WolfCryptError::Ffi {
                code: rc,
                func: "wc_MlKemKey_DecodePublicKey",
            });
        }
        let mut rng = WC_RNG::zeroed();
        // SAFETY: `rng` is a live, zero-initialized WC_RNG owned by this frame.
        let rc = unsafe { wc_InitRng(&mut rng) };
        if rc != 0 {
            // SAFETY: `key` came from `wc_MlKemKey_New` and is not used again.
            unsafe {
                wc_MlKemKey_Delete(key, ptr::null_mut());
            }
            return Err(WolfCryptError::Ffi {
                code: rc,
                func: "wc_InitRng",
            });
        }
        Ok(Self {
            key,
            rng,
            _level: PhantomData,
        })
    }

    /// Re-encodes this public key into a freshly allocated `L::PK_SIZE`-byte
    /// vector.
    ///
    /// Despite the `as_` prefix, this allocates and returns owned bytes.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `check` when
    /// `wc_MlKemKey_EncodePublicKey` reports failure.
    pub fn as_bytes(&self) -> Result<Vec<u8>, WolfCryptError> {
        let mut pk_buf = vec![0u8; L::PK_SIZE];
        // SAFETY: `self.key` is a valid key object and `pk_buf` provides
        // exactly `L::PK_SIZE` writable bytes, the length reported to wolfCrypt.
        let rc = unsafe {
            wc_MlKemKey_EncodePublicKey(self.key, pk_buf.as_mut_ptr(), len_as_u32(L::PK_SIZE))
        };
        check(rc, "wc_MlKemKey_EncodePublicKey")?;
        Ok(pk_buf)
    }

    /// Encapsulates a fresh shared secret to this public key.
    ///
    /// Returns the `L::CT_SIZE`-byte ciphertext and the shared secret. The RNG
    /// owned by this key supplies the randomness, which is why `&mut self` is
    /// required.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `check` when `wc_MlKemKey_Encapsulate`
    /// reports failure.
    pub fn encapsulate(&mut self) -> Result<(Vec<u8>, SharedSecret), WolfCryptError> {
        let mut ct = vec![0u8; L::CT_SIZE];
        // Write straight into a `SharedSecret` so its zeroize-on-drop also
        // covers the error path (previously a plain array was left unwiped).
        let mut ss = SharedSecret([0u8; WC_ML_KEM_SS_SZ]);
        // SAFETY: `self.key` is valid, `ct` has `L::CT_SIZE` writable bytes,
        // `ss.0` has `WC_ML_KEM_SS_SZ` writable bytes, and `self.rng` was
        // initialized in `from_bytes`.
        let rc = unsafe {
            wc_MlKemKey_Encapsulate(self.key, ct.as_mut_ptr(), ss.0.as_mut_ptr(), &mut self.rng)
        };
        check(rc, "wc_MlKemKey_Encapsulate")?;
        Ok((ct, ss))
    }
}
impl<L: MlKemLevel> Drop for MlKemEncapsulationKey<L> {
    /// Deletes the wolfCrypt key object, then releases the RNG state.
    fn drop(&mut self) {
        let owned_key = self.key;
        // SAFETY: `from_bytes` stores a non-null pointer from
        // `wc_MlKemKey_New`, and nothing touches it after this call.
        unsafe { wc_MlKemKey_Delete(owned_key, ptr::null_mut()) };
        // SAFETY: `self.rng` was initialized by `wc_InitRng` before `self`
        // was built, and is released exactly once here.
        unsafe { wc_FreeRng(&mut self.rng) };
    }
}
/// ML-KEM-512 decapsulation (private) key.
pub type MlKem512DecapsulationKey = MlKemDecapsulationKey<MlKem512>;
/// ML-KEM-512 encapsulation (public) key.
pub type MlKem512EncapsulationKey = MlKemEncapsulationKey<MlKem512>;

/// ML-KEM-768 decapsulation (private) key.
pub type MlKem768DecapsulationKey = MlKemDecapsulationKey<MlKem768>;
/// ML-KEM-768 encapsulation (public) key.
pub type MlKem768EncapsulationKey = MlKemEncapsulationKey<MlKem768>;

/// ML-KEM-1024 decapsulation (private) key.
pub type MlKem1024DecapsulationKey = MlKemDecapsulationKey<MlKem1024>;
/// ML-KEM-1024 encapsulation (public) key.
pub type MlKem1024EncapsulationKey = MlKemEncapsulationKey<MlKem1024>;