use alloc::boxed::Box;
use super::common::{PcStatus, guard, out_write, slice};
use crate::mlkem::{
MlKem512Ciphertext, MlKem512DecapsKey, MlKem512EncapsKey, MlKem768Ciphertext,
MlKem768DecapsKey, MlKem768EncapsKey, MlKem1024Ciphertext, MlKem1024DecapsKey,
MlKem1024EncapsKey,
};
use crate::rng::OsRng;
/// Numeric identifiers for the ML-KEM parameter sets, as accepted by the `set`
/// argument of `pc_mlkem_generate` and `pc_mlkem_encaps`.
///
/// Any other value is rejected by those functions: `pc_mlkem_generate` returns null
/// and `pc_mlkem_encaps` returns `PcStatus::Unsupported`.
pub mod set_id {
    /// Selects the ML-KEM-512 parameter set.
    pub const ML_KEM_512: i32 = 1;
    /// Selects the ML-KEM-768 parameter set.
    pub const ML_KEM_768: i32 = 2;
    /// Selects the ML-KEM-1024 parameter set.
    pub const ML_KEM_1024: i32 = 3;
}
/// Handle to an ML-KEM decapsulation (private) key, passed across the C ABI as a
/// `*mut PcMlKem` / `*const PcMlKem` pointer.
///
/// The variant records which parameter set the key belongs to. Handles are produced by
/// `pc_mlkem_generate` and `pc_mlkem_from_pkcs8_pem` (via `Box::into_raw`) and must be
/// released with `pc_mlkem_free`.
pub enum PcMlKem {
/// ML-KEM-512 decapsulation key.
K512(Box<MlKem512DecapsKey>),
/// ML-KEM-768 decapsulation key.
K768(Box<MlKem768DecapsKey>),
/// ML-KEM-1024 decapsulation key.
K1024(Box<MlKem1024DecapsKey>),
}
/// Generates a fresh ML-KEM key pair for the parameter set `set` (one of the `set_id`
/// constants) and returns an owned handle to its decapsulation key.
///
/// The encapsulation key produced alongside it is not stored separately; it can be
/// recovered from the handle through `encapsulation_key()`. An unrecognised `set`
/// yields a null pointer. The returned handle must be released with `pc_mlkem_free`.
#[unsafe(no_mangle)]
pub extern "C" fn pc_mlkem_generate(set: i32) -> *mut PcMlKem {
    crate::ffi::common::guard_ptr(|| {
        // `generate` returns a (decapsulation key, encapsulation key) pair; keep only `.0`.
        let handle = match set {
            set_id::ML_KEM_512 => {
                PcMlKem::K512(Box::new(MlKem512DecapsKey::generate(&mut OsRng).0))
            }
            set_id::ML_KEM_768 => {
                PcMlKem::K768(Box::new(MlKem768DecapsKey::generate(&mut OsRng).0))
            }
            set_id::ML_KEM_1024 => {
                PcMlKem::K1024(Box::new(MlKem1024DecapsKey::generate(&mut OsRng).0))
            }
            _ => return core::ptr::null_mut(),
        };
        Box::into_raw(Box::new(handle))
    })
}
/// Parses a PKCS#8 PEM private key of any supported ML-KEM parameter set and returns an
/// owned handle to it.
///
/// The parameter sets are tried in the order ML-KEM-768, ML-KEM-512, ML-KEM-1024, and
/// the first whose `from_pkcs8_pem` succeeds wins. A buffer rejected by `slice`, input
/// that is not UTF-8, or PEM that no parser accepts all yield a null pointer. The
/// returned handle must be released with `pc_mlkem_free`.
///
/// # Safety
/// `pem` must be readable for `len` bytes, as required by `slice`.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn pc_mlkem_from_pkcs8_pem(pem: *const u8, len: usize) -> *mut PcMlKem {
    crate::ffi::common::guard_ptr(|| {
        let text = match (unsafe { slice(pem, len) }).map(core::str::from_utf8) {
            Some(Ok(text)) => text,
            _ => return core::ptr::null_mut(),
        };
        MlKem768DecapsKey::from_pkcs8_pem(text)
            .map(|sk| PcMlKem::K768(Box::new(sk)))
            .or_else(|_| {
                MlKem512DecapsKey::from_pkcs8_pem(text).map(|sk| PcMlKem::K512(Box::new(sk)))
            })
            .or_else(|_| {
                MlKem1024DecapsKey::from_pkcs8_pem(text).map(|sk| PcMlKem::K1024(Box::new(sk)))
            })
            .map_or(core::ptr::null_mut(), |handle| Box::into_raw(Box::new(handle)))
    })
}
/// Serialises the private key behind handle `k` with `to_pkcs8_pem` and writes the PEM
/// text through `out_write` into `out` / `out_len`.
///
/// Returns `PcStatus::NullPointer` when `k` is null; otherwise returns whatever status
/// `out_write` reports.
///
/// # Safety
/// `k` must be null or a live handle from this module's constructors; `out` and
/// `out_len` must satisfy `out_write`'s contract.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn pc_mlkem_private_to_pem(
    k: *const PcMlKem,
    out: *mut u8,
    out_len: *mut usize,
) -> PcStatus {
    guard(|| {
        // SAFETY: `as_ref` maps null to `None`; a non-null `k` is a live handle per the
        // caller contract.
        let Some(key) = (unsafe { k.as_ref() }) else {
            return PcStatus::NullPointer;
        };
        // NOTE(review): `encoded` holds private-key material. Whether it is wiped when
        // dropped depends on the type `to_pkcs8_pem` returns (a plain `String` would not
        // be) — confirm.
        let encoded = match key {
            PcMlKem::K512(sk) => sk.to_pkcs8_pem(),
            PcMlKem::K768(sk) => sk.to_pkcs8_pem(),
            PcMlKem::K1024(sk) => sk.to_pkcs8_pem(),
        };
        // SAFETY: the output pointers are forwarded unchanged to `out_write`.
        unsafe { out_write(encoded.as_bytes(), out, out_len) }
    })
}
/// Writes the encapsulation (public) key of handle `k`, encoded by `to_spki_pem`, through
/// `out_write` into `out` / `out_len`.
///
/// Returns `PcStatus::NullPointer` when `k` is null; otherwise returns whatever status
/// `out_write` reports.
///
/// # Safety
/// `k` must be null or a live handle from this module's constructors; `out` and
/// `out_len` must satisfy `out_write`'s contract.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn pc_mlkem_public_to_pem(
    k: *const PcMlKem,
    out: *mut u8,
    out_len: *mut usize,
) -> PcStatus {
    guard(|| {
        // SAFETY: `as_ref` maps null to `None`; a non-null `k` is a live handle per the
        // caller contract.
        let spki_pem = match unsafe { k.as_ref() } {
            None => return PcStatus::NullPointer,
            Some(PcMlKem::K512(sk)) => sk.encapsulation_key().to_spki_pem(),
            Some(PcMlKem::K768(sk)) => sk.encapsulation_key().to_spki_pem(),
            Some(PcMlKem::K1024(sk)) => sk.encapsulation_key().to_spki_pem(),
        };
        // SAFETY: the output pointers are forwarded unchanged to `out_write`.
        unsafe { out_write(spki_pem.as_bytes(), out, out_len) }
    })
}
/// Writes the encapsulation (public) key of handle `k`, encoded by `to_spki_der`, through
/// `out_write` into `out` / `out_len`.
///
/// Returns `PcStatus::NullPointer` when `k` is null; otherwise returns whatever status
/// `out_write` reports.
///
/// # Safety
/// `k` must be null or a live handle from this module's constructors; `out` and
/// `out_len` must satisfy `out_write`'s contract.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn pc_mlkem_public_to_der(
    k: *const PcMlKem,
    out: *mut u8,
    out_len: *mut usize,
) -> PcStatus {
    guard(|| {
        // SAFETY: `as_ref` maps null to `None`; a non-null `k` is a live handle per the
        // caller contract.
        let spki_der = match unsafe { k.as_ref() } {
            None => return PcStatus::NullPointer,
            Some(PcMlKem::K512(sk)) => sk.encapsulation_key().to_spki_der(),
            Some(PcMlKem::K768(sk)) => sk.encapsulation_key().to_spki_der(),
            Some(PcMlKem::K1024(sk)) => sk.encapsulation_key().to_spki_der(),
        };
        // SAFETY: the output pointers are forwarded unchanged to `out_write`.
        unsafe { out_write(&spki_der, out, out_len) }
    })
}
/// Encapsulates a fresh 32-byte shared secret to a recipient's public key.
///
/// * `set` — parameter set, one of the `set_id` constants. After the pointer checks,
///   any other value yields `PcStatus::Unsupported`.
/// * `ek_spki` / `ek_spki_len` — the recipient's encapsulation key as SPKI DER.
/// * `ct` / `ct_len` — ciphertext output, written through `out_write`.
/// * `ss` — receives the 32-byte shared secret.
///
/// Returns `PcStatus::NullPointer` if `slice` rejects the key buffer or `ss` is null.
/// Returns `PcStatus::BadEncoding` if the SPKI does not parse or the key fails
/// `from_bytes_validated`. Otherwise it returns the status of writing the ciphertext,
/// and `ss` is written only when that status is `Ok`. The local copy of the shared
/// secret is wiped on both the success and the failure path.
///
/// # Safety
/// `ek_spki` must be readable for `ek_spki_len` bytes, `ct` / `ct_len` must satisfy
/// `out_write`'s contract, and `ss` must be writable for 32 bytes.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn pc_mlkem_encaps(
    set: i32,
    ek_spki: *const u8,
    ek_spki_len: usize,
    ct: *mut u8,
    ct_len: *mut usize,
    ss: *mut u8,
) -> PcStatus {
    guard(|| {
        let Some(spki) = (unsafe { slice(ek_spki, ek_spki_len) }) else {
            return PcStatus::NullPointer;
        };
        if ss.is_null() {
            return PcStatus::NullPointer;
        }
        // Each arm re-checks the parsed key by round-tripping its raw bytes through
        // `from_bytes_validated` before encapsulating (presumably the FIPS 203 input
        // validation — TODO confirm against the mlkem module).
        //
        // `secret` is bound `mut` directly in the pattern. `[u8; 32]` is `Copy`, so a
        // separate `let mut secret = secret;` would leave an extra stack copy that is
        // never wiped.
        let (ct_bytes, mut secret): (alloc::vec::Vec<u8>, [u8; 32]) = match set {
            set_id::ML_KEM_512 => {
                let k = match MlKem512EncapsKey::from_spki_der(spki) {
                    Ok(k) => k,
                    Err(_) => return PcStatus::BadEncoding,
                };
                let bytes = k.to_bytes();
                if MlKem512EncapsKey::from_bytes_validated(bytes).is_err() {
                    return PcStatus::BadEncoding;
                }
                let (c, s) = k.encapsulate(&mut OsRng);
                (c.to_bytes().to_vec(), s)
            }
            set_id::ML_KEM_768 => {
                let k = match MlKem768EncapsKey::from_spki_der(spki) {
                    Ok(k) => k,
                    Err(_) => return PcStatus::BadEncoding,
                };
                let bytes = k.to_bytes();
                if MlKem768EncapsKey::from_bytes_validated(bytes).is_err() {
                    return PcStatus::BadEncoding;
                }
                let (c, s) = k.encapsulate(&mut OsRng);
                (c.to_bytes().to_vec(), s)
            }
            set_id::ML_KEM_1024 => {
                let k = match MlKem1024EncapsKey::from_spki_der(spki) {
                    Ok(k) => k,
                    Err(_) => return PcStatus::BadEncoding,
                };
                let bytes = k.to_bytes();
                if MlKem1024EncapsKey::from_bytes_validated(bytes).is_err() {
                    return PcStatus::BadEncoding;
                }
                let (c, s) = k.encapsulate(&mut OsRng);
                (c.to_bytes().to_vec(), s)
            }
            _ => return PcStatus::Unsupported,
        };
        // SAFETY: the ciphertext output pointers are forwarded unchanged to `out_write`.
        let status = unsafe { out_write(&ct_bytes, ct, ct_len) };
        if status == PcStatus::Ok {
            // SAFETY: `ss` was checked non-null above, and the caller guarantees it is
            // writable for 32 bytes. `secret` is a local, so the regions cannot overlap.
            unsafe { core::ptr::copy_nonoverlapping(secret.as_ptr(), ss, secret.len()) };
        }
        // Single wipe point covering both the success and the failure path.
        wipe_array(&mut secret);
        status
    })
}
/// Overwrites every byte of `buf` with zero in a way the optimizer is not allowed to
/// remove, for scrubbing shared secrets after use.
///
/// Plain stores followed by `core::hint::black_box` are not sufficient. `black_box` is
/// documented as best-effort only and offers no guarantees for cryptographic purposes,
/// so zeroing a buffer that is never read again could legally be eliminated as a dead
/// store. Instead, each byte is cleared with a volatile write, which the compiler must
/// treat as an observable side effect. The trailing compiler fence keeps later code
/// from being reordered ahead of the wipe.
fn wipe_array(buf: &mut [u8]) {
    for byte in buf.iter_mut() {
        // SAFETY: `byte` is a valid, aligned, exclusive reference to an initialised `u8`.
        unsafe { core::ptr::write_volatile(byte, 0) };
    }
    core::sync::atomic::compiler_fence(core::sync::atomic::Ordering::SeqCst);
}
/// Decapsulates the ciphertext `ct` with the private key behind handle `k` and writes
/// the 32-byte shared secret to `ss`.
///
/// The ciphertext length must match the key's parameter set: 768 bytes for ML-KEM-512,
/// 1088 for ML-KEM-768 and 1568 for ML-KEM-1024. Any other length yields
/// `PcStatus::BadEncoding`. A null `k` or `ss`, or a ciphertext buffer rejected by
/// `slice`, yields `PcStatus::NullPointer`. The local copy of the secret is wiped
/// before `PcStatus::Ok` is returned.
///
/// # Safety
/// `k` must be null or a live handle from this module's constructors, `ct` must be
/// readable for `ct_len` bytes, and `ss` must be writable for 32 bytes.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn pc_mlkem_decaps(
    k: *const PcMlKem,
    ct: *const u8,
    ct_len: usize,
    ss: *mut u8,
) -> PcStatus {
    guard(|| {
        // SAFETY: `as_ref` maps null to `None`; a non-null `k` is a live handle per the
        // caller contract.
        let Some(key) = (unsafe { k.as_ref() }) else {
            return PcStatus::NullPointer;
        };
        let Some(ciphertext) = (unsafe { slice(ct, ct_len) }) else {
            return PcStatus::NullPointer;
        };
        if ss.is_null() {
            return PcStatus::NullPointer;
        }
        let mut shared = match key {
            PcMlKem::K512(sk) => {
                let fixed: Result<[u8; 768], _> = ciphertext.try_into();
                let Ok(raw) = fixed else {
                    return PcStatus::BadEncoding;
                };
                sk.decapsulate(&MlKem512Ciphertext::from_bytes(raw))
            }
            PcMlKem::K768(sk) => {
                let fixed: Result<[u8; 1088], _> = ciphertext.try_into();
                let Ok(raw) = fixed else {
                    return PcStatus::BadEncoding;
                };
                sk.decapsulate(&MlKem768Ciphertext::from_bytes(raw))
            }
            PcMlKem::K1024(sk) => {
                let fixed: Result<[u8; 1568], _> = ciphertext.try_into();
                let Ok(raw) = fixed else {
                    return PcStatus::BadEncoding;
                };
                sk.decapsulate(&MlKem1024Ciphertext::from_bytes(raw))
            }
        };
        // SAFETY: `ss` was checked non-null above, and the caller guarantees it is
        // writable for 32 bytes. `shared` is a local, so the regions cannot overlap.
        unsafe { core::ptr::copy_nonoverlapping(shared.as_ptr(), ss, 32) };
        wipe_array(&mut shared);
        PcStatus::Ok
    })
}
/// Releases a handle returned by `pc_mlkem_generate` or `pc_mlkem_from_pkcs8_pem`.
///
/// A null pointer is ignored. Whether the key material is zeroised on drop depends on
/// the `Drop` behaviour of the underlying decapsulation-key types, which is not visible
/// here.
///
/// # Safety
/// `k` must be null or a pointer obtained from one of those constructors that has not
/// already been freed. It must not be used after this call.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn pc_mlkem_free(k: *mut PcMlKem) {
    if k.is_null() {
        return;
    }
    // SAFETY: a non-null `k` came from `Box::into_raw` in this module's constructors and,
    // per the caller contract, is freed exactly once.
    let owned = unsafe { Box::from_raw(k) };
    drop(owned);
}