use ml_kem::kem::{Decapsulate, Encapsulate};
use ml_kem::{Ciphertext, Encoded, EncodedSizeUser, KemCore, MlKem768};
use rand_core::CryptoRngCore;
use zeroize::{Zeroize, ZeroizeOnDrop};
use crate::error::RatchetError;
// ML-KEM-768 wire sizes, written as the FIPS 203 formulas for the ML-KEM-768
// parameter set (k = 3, du = 10, dv = 4) so the derivation is visible.

/// Encapsulation-key length in bytes: `384 * k + 32` (= 1184).
pub const PQ_EK_LEN: usize = 384 * 3 + 32;
/// Ciphertext length in bytes: `32 * (du * k + dv)` (= 1088).
pub const PQ_CT_LEN: usize = 32 * (10 * 3 + 4);
/// Shared-secret length in bytes (= 32).
pub const PQ_SS_LEN: usize = 32;
/// Decapsulation-key length in bytes: `768 * k + 96` (= 2400).
pub const PQ_DK_LEN: usize = 768 * 3 + 96;
const _: () = assert!(
core::mem::size_of::<Encoded<DK768>>() == PQ_DK_LEN,
"PQ_DK_LEN mismatch -- update the constant if ml-kem changes the DK size"
);
const _: () = assert!(
core::mem::size_of::<Encoded<EK768>>() == PQ_EK_LEN,
"PQ_EK_LEN mismatch -- update the constant if ml-kem changes the EK size"
);
const _: () = assert!(
core::mem::size_of::<CT768>() == PQ_CT_LEN,
"PQ_CT_LEN mismatch -- update the constant if ml-kem changes the CT size"
);
/// ml-kem's encapsulation-key type for the ML-KEM-768 parameter set.
type EK768 = <MlKem768 as KemCore>::EncapsulationKey;
/// ml-kem's decapsulation-key (secret key) type for the ML-KEM-768 parameter set.
type DK768 = <MlKem768 as KemCore>::DecapsulationKey;
/// ml-kem's fixed-size ML-KEM-768 ciphertext array (its size is asserted against `PQ_CT_LEN`).
type CT768 = Ciphertext<MlKem768>;
/// Serialized ML-KEM-768 encapsulation key (`PQ_EK_LEN` = 1184 bytes).
///
/// Used both for our own key (`SckaState::our_ek`) and for a peer's key passed to
/// `SckaState::encap_to`. It is public-key material; the zeroize-on-drop derive is
/// harmless hygiene rather than a secrecy requirement.
#[derive(Clone, Debug, Zeroize, ZeroizeOnDrop)]
pub struct PqEk(pub [u8; PQ_EK_LEN]);
/// Serialized ML-KEM-768 ciphertext (`PQ_CT_LEN` = 1088 bytes).
///
/// Produced by `SckaState::encap_to` and consumed by `SckaState::decap`.
#[derive(Clone, Debug, Zeroize, ZeroizeOnDrop)]
pub struct PqCt(pub [u8; PQ_CT_LEN]);
/// ML-KEM-768 key state for one side of the ratchet: our decapsulation key, the
/// serialized encapsulation key paired with it, and an optional stored ciphertext.
pub struct SckaState {
    /// Our ML-KEM-768 decapsulation (secret) key.
    our_dk: DK768,
    /// Serialized encapsulation key; `new` takes it from the same keygen call as `our_dk`.
    our_ek: PqEk,
    /// Ciphertext stored via `set_pending_ct`.
    /// NOTE(review): presumably awaiting delivery to the peer — confirm against callers.
    pending_ct: Option<PqCt>,
}
impl Drop for SckaState {
fn drop(&mut self) {
self.our_ek.0.zeroize();
}
}
impl SckaState {
    /// Generates a fresh ML-KEM-768 key pair and returns a state holding it, with no
    /// pending ciphertext.
    pub fn new(rng: &mut impl CryptoRngCore) -> Self {
        let (dk, ek) = MlKem768::generate(rng);
        Self {
            our_ek: PqEk(ek_to_bytes(&ek)),
            our_dk: dk,
            pending_ct: None,
        }
    }

    /// Encapsulates a fresh shared secret to `peer_ek`.
    ///
    /// Returns `(shared_secret, ciphertext)`; the ciphertext is what the peer needs to
    /// recover the same secret with its decapsulation key. `self` is not read.
    ///
    /// # Errors
    /// - [`RatchetError::InvalidPqEk`] if `peer_ek` is rejected by `ek_from_bytes`.
    /// - [`RatchetError::EncapFailed`] if ml-kem reports an encapsulation error.
    pub fn encap_to(
        &self,
        peer_ek: &PqEk,
        rng: &mut impl CryptoRngCore,
    ) -> Result<([u8; PQ_SS_LEN], PqCt), RatchetError> {
        let ek = ek_from_bytes(&peer_ek.0).ok_or(RatchetError::InvalidPqEk)?;
        let (ct, mut ss) = ek.encapsulate(rng).map_err(|_| RatchetError::EncapFailed)?;

        let mut ct_bytes = [0u8; PQ_CT_LEN];
        ct_bytes.copy_from_slice(ct.as_slice());

        let mut ss_bytes = [0u8; PQ_SS_LEN];
        ss_bytes.copy_from_slice(ss.as_slice());
        // Wipe ml-kem's copy of the secret; the caller now owns the only copy.
        ss.as_mut_slice().zeroize();

        Ok((ss_bytes, PqCt(ct_bytes)))
    }

    /// Recovers the shared secret from a ciphertext produced by the peer's `encap_to`.
    ///
    /// Per FIPS 203, ML-KEM decapsulation uses implicit rejection: a tampered but
    /// correctly sized ciphertext is expected to yield an unrelated secret rather than
    /// an error, so a mismatch surfaces later (e.g. as an AEAD failure).
    ///
    /// # Errors
    /// - [`RatchetError::InvalidPqCt`] if `peer_ct` cannot be converted by `ct_from_bytes`.
    /// - [`RatchetError::DecapFailed`] if ml-kem reports a decapsulation error.
    pub fn decap(&self, peer_ct: &PqCt) -> Result<[u8; PQ_SS_LEN], RatchetError> {
        let ct = ct_from_bytes(&peer_ct.0).ok_or(RatchetError::InvalidPqCt)?;
        let mut ss = self
            .our_dk
            .decapsulate(&ct)
            .map_err(|_| RatchetError::DecapFailed)?;

        let mut ss_bytes = [0u8; PQ_SS_LEN];
        ss_bytes.copy_from_slice(ss.as_slice());
        // Wipe ml-kem's copy of the secret; the caller now owns the only copy.
        ss.as_mut_slice().zeroize();

        Ok(ss_bytes)
    }

    /// Returns our serialized encapsulation key.
    pub fn our_ek(&self) -> &PqEk {
        &self.our_ek
    }

    /// Serializes our decapsulation (secret) key.
    ///
    /// The returned array is secret material; the caller is responsible for wiping it.
    pub(crate) fn dk_bytes(&self) -> [u8; PQ_DK_LEN] {
        let mut encoded = self.our_dk.as_bytes();
        let mut buf = [0u8; PQ_DK_LEN];
        buf.copy_from_slice(encoded.as_slice());
        // Don't leave a second copy of the secret key in the temporary encoding buffer.
        encoded.as_mut_slice().zeroize();
        buf
    }

    /// Returns the raw bytes of our serialized encapsulation key.
    pub(crate) fn ek_bytes_raw(&self) -> &[u8; PQ_EK_LEN] {
        &self.our_ek.0
    }

    /// Returns the stored pending ciphertext, if any.
    pub(crate) fn pending_ct_ref(&self) -> Option<&PqCt> {
        self.pending_ct.as_ref()
    }

    /// Rebuilds a state from its serialized parts.
    ///
    /// Returns `None` if `dk_bytes` cannot be converted into ml-kem's encoding, or if
    /// `ek_bytes` does not equal the encapsulation key carried inside the decapsulation
    /// key. Without that check, a corrupted or mismatched saved state would advertise an
    /// encapsulation key we cannot decapsulate for. Because ML-KEM uses implicit
    /// rejection, that would desync silently instead of erroring.
    pub(crate) fn from_parts(
        dk_bytes: &[u8; PQ_DK_LEN],
        ek_bytes: [u8; PQ_EK_LEN],
        pending_ct: Option<PqCt>,
    ) -> Option<Self> {
        let mut arr = Encoded::<DK768>::try_from(dk_bytes.as_slice()).ok()?;
        let dk = DK768::from_bytes(&arr);
        // `arr` is a copy of the secret key; wipe it once the key object is built.
        arr.as_mut_slice().zeroize();

        if ek_to_bytes(dk.encapsulation_key()) != ek_bytes {
            return None;
        }

        Some(Self {
            our_dk: dk,
            our_ek: PqEk(ek_bytes),
            pending_ct,
        })
    }

    /// Stores `ct` as the pending ciphertext, replacing (and, via `PqCt`'s
    /// zeroize-on-drop, wiping) any previous one.
    pub fn set_pending_ct(&mut self, ct: PqCt) {
        self.pending_ct = Some(ct);
    }
}
/// Serializes an ML-KEM-768 encapsulation key into a fixed-size `PQ_EK_LEN` array.
fn ek_to_bytes(ek: &EK768) -> [u8; PQ_EK_LEN] {
    let mut serialized = [0u8; PQ_EK_LEN];
    serialized.copy_from_slice(ek.as_bytes().as_slice());
    serialized
}
/// Parses a serialized ML-KEM-768 encapsulation key, typically one received from a peer.
///
/// Returns `None` in two cases:
/// - the key fails the FIPS 203 §7.2 modulus check, meaning a packed coefficient is not
///   reduced mod q;
/// - the bytes cannot be converted into ml-kem's encoded representation.
///
/// The modulus check is done here because `EK768::from_bytes` is infallible and so
/// cannot report malformed input.
fn ek_from_bytes(bytes: &[u8; PQ_EK_LEN]) -> Option<EK768> {
    if !ek_passes_modulus_check(bytes) {
        return None;
    }
    let arr = Encoded::<EK768>::try_from(bytes.as_slice()).ok()?;
    Some(EK768::from_bytes(&arr))
}

/// FIPS 203 §7.2 modulus check for an ML-KEM-768 encapsulation key.
///
/// The leading `384 * k` bytes (k = 3) pack 768 little-endian 12-bit coefficients, two
/// per 3 bytes. Each coefficient must be below q = 3329. This is equivalent to the spec's
/// `ByteEncode12(ByteDecode12(ek[..384k])) == ek[..384k]`. Inputs shorter than
/// `384 * k` bytes fail. Encapsulation keys are public, so this need not be constant-time.
fn ek_passes_modulus_check(ek: &[u8]) -> bool {
    const K: usize = 3;
    const Q: u16 = 3329;
    const T_HAT_LEN: usize = 384 * K;

    match ek.get(..T_HAT_LEN) {
        Some(t_hat) => t_hat.chunks_exact(3).all(|b| {
            let c0 = u16::from(b[0]) | (u16::from(b[1] & 0x0F) << 8);
            let c1 = u16::from(b[1] >> 4) | (u16::from(b[2]) << 4);
            c0 < Q && c1 < Q
        }),
        None => false,
    }
}
/// Converts a serialized ciphertext into ml-kem's ciphertext array type.
///
/// The input length is fixed at `PQ_CT_LEN`, which a compile-time assertion ties to
/// `size_of::<CT768>()`, so the conversion is expected to always succeed. `None` is
/// returned if it does not.
fn ct_from_bytes(bytes: &[u8; PQ_CT_LEN]) -> Option<CT768> {
    let raw: &[u8] = bytes;
    CT768::try_from(raw).ok()
}