#![cfg_attr(not(feature = "std"), no_std)]
#![warn(clippy::all)]
#![allow(clippy::needless_range_loop)]
extern crate alloc;
mod compress;
mod field;
mod hash;
mod kpke;
mod mlkem;
mod ntt;
mod params;
mod poly;
mod sample;
mod serialize;
use rand_core::{CryptoRng, RngCore};
use subtle::ConstantTimeEq;
use zeroize::{Zeroize, ZeroizeOnDrop};
pub use params::{Params, Params1024, Params512, Params768};
/// Generates the public API for one ML-KEM parameter set.
///
/// Produces:
/// - a zero-sized marker type `$name` with key generation, encapsulation and
///   decapsulation associated functions that delegate to `mlkem::MlKem<$params>`;
/// - byte-array newtypes `$pkty` (`$pk` bytes), `$skty` (`$sk` bytes),
///   `$ctty` (`$ct` bytes) and `$ssty` (always 32 bytes).
///
/// Secret-bearing types (`$skty`, `$ssty`) are wiped on drop, compare in
/// constant time and print redacted `Debug` output.
macro_rules! mlkem_api {
    ($name:ident, $params:ty, $pkty:ident, $skty:ident, $ctty:ident, $ssty:ident,
     $pk:expr, $sk:expr, $ct:expr) => {
        /// ML-KEM key-encapsulation API for a single parameter set.
        pub struct $name;

        impl $name {
            /// Length of an encoded public (encapsulation) key, in bytes.
            pub const PUBLIC_KEY_SIZE: usize = $pk;
            /// Length of an encoded secret (decapsulation) key, in bytes.
            pub const SECRET_KEY_SIZE: usize = $sk;
            /// Length of a ciphertext, in bytes.
            pub const CIPHERTEXT_SIZE: usize = $ct;
            /// Length of the shared secret, in bytes.
            pub const SHARED_SECRET_SIZE: usize = 32;

            /// Deterministically derives a key pair from a 64-byte seed.
            ///
            /// The first 32 bytes of `seed` are passed as `d` and the last 32 bytes
            /// as `z` to `mlkem::MlKem::keygen` (presumably FIPS 203
            /// `ML-KEM.KeyGen_internal(d, z)` — confirm in `mlkem`). Anyone who
            /// knows `seed` can recompute the secret key, so treat it as secret.
            pub fn keygen_deterministic(seed: &[u8; 64]) -> ($pkty, $skty) {
                let mut d = [0u8; 32];
                let mut z = [0u8; 32];
                d.copy_from_slice(&seed[..32]);
                z.copy_from_slice(&seed[32..]);
                let (pk, sk) = mlkem::MlKem::<$params>::keygen(&d, &z);
                // `d` and `z` are as sensitive as the secret key itself; wipe the
                // local copies instead of leaving them behind on the stack.
                d.zeroize();
                z.zeroize();
                let mut pk_arr = [0u8; $pk];
                let mut sk_arr = [0u8; $sk];
                pk_arr.copy_from_slice(&pk);
                sk_arr.copy_from_slice(&sk);
                // NOTE(review): the `sk` buffer returned by `mlkem::MlKem::keygen` is
                // dropped here without an explicit wipe unless its type zeroizes
                // itself on drop — confirm its type in `mlkem`.
                ($pkty(pk_arr), $skty(sk_arr))
            }

            /// Generates a fresh key pair using 64 bytes drawn from `rng`.
            ///
            /// The temporary seed is wiped before returning.
            pub fn keygen<R: RngCore + CryptoRng>(rng: &mut R) -> ($pkty, $skty) {
                let mut seed = [0u8; 64];
                rng.fill_bytes(&mut seed);
                let keypair = Self::keygen_deterministic(&seed);
                // The seed fully determines the secret key; don't leave it around.
                seed.zeroize();
                keypair
            }

            /// Deterministically encapsulates to `pk` using the caller-supplied
            /// 32-byte message `m`, returning the ciphertext and shared secret.
            ///
            /// `m` determines the shared secret and must be kept secret; this
            /// function does not wipe the caller's buffer.
            pub fn encapsulate_deterministic(pk: &$pkty, m: &[u8; 32]) -> ($ctty, $ssty) {
                let (ct, ss) = mlkem::MlKem::<$params>::encapsulate(&pk.0, m);
                let mut ct_arr = [0u8; $ct];
                ct_arr.copy_from_slice(&ct);
                ($ctty(ct_arr), $ssty(ss))
            }

            /// Encapsulates to `pk` using a fresh 32-byte message drawn from `rng`,
            /// returning the ciphertext and shared secret.
            ///
            /// The temporary message is wiped before returning.
            pub fn encapsulate<R: RngCore + CryptoRng>(pk: &$pkty, rng: &mut R) -> ($ctty, $ssty) {
                let mut m = [0u8; 32];
                rng.fill_bytes(&mut m);
                let result = Self::encapsulate_deterministic(pk, &m);
                // `m` determines the shared secret; wipe it once it has been used.
                m.zeroize();
                result
            }

            /// Recovers the shared secret from `ct` using `sk`.
            ///
            /// NOTE(review): FIPS 203 requires implicit rejection (a pseudorandom
            /// secret for invalid ciphertexts rather than an error); confirm that
            /// `mlkem::MlKem::decapsulate` implements it.
            pub fn decapsulate(sk: &$skty, ct: &$ctty) -> $ssty {
                $ssty(mlkem::MlKem::<$params>::decapsulate(&sk.0, &ct.0))
            }
        }

        /// Encoded public (encapsulation) key.
        #[derive(Clone)]
        pub struct $pkty(pub(crate) [u8; $pk]);

        /// Encoded secret (decapsulation) key; wiped on drop.
        #[derive(Clone, ZeroizeOnDrop)]
        pub struct $skty(pub(crate) [u8; $sk]);

        /// Encapsulation ciphertext.
        #[derive(Clone)]
        pub struct $ctty(pub(crate) [u8; $ct]);

        /// 32-byte shared secret; wiped on drop.
        #[derive(Clone, ZeroizeOnDrop)]
        pub struct $ssty(pub(crate) [u8; 32]);

        impl $pkty {
            /// Borrows the encoded key bytes.
            pub fn as_bytes(&self) -> &[u8; $pk] {
                &self.0
            }

            /// Wraps an encoded public key.
            ///
            /// No validation is performed here. NOTE(review): FIPS 203 requires an
            /// input (modulus) check on encapsulation keys before use; confirm
            /// whether `mlkem::MlKem::encapsulate` performs it.
            pub fn from_bytes(b: &[u8; $pk]) -> Self {
                Self(*b)
            }
        }

        impl $skty {
            /// Borrows the encoded secret key bytes.
            pub fn as_bytes(&self) -> &[u8; $sk] {
                &self.0
            }

            /// Wraps an encoded secret key. No validation is performed here.
            pub fn from_bytes(b: &[u8; $sk]) -> Self {
                Self(*b)
            }
        }

        impl $ctty {
            /// Borrows the ciphertext bytes.
            pub fn as_bytes(&self) -> &[u8; $ct] {
                &self.0
            }

            /// Wraps ciphertext bytes. No validation is performed here.
            pub fn from_bytes(b: &[u8; $ct]) -> Self {
                Self(*b)
            }
        }

        impl $ssty {
            /// Borrows the shared secret bytes.
            pub fn as_bytes(&self) -> &[u8; 32] {
                &self.0
            }
        }

        // All equality checks go through `subtle` so that comparing secrets does
        // not leak the position of the first differing byte via timing.
        impl PartialEq for $pkty {
            fn eq(&self, other: &Self) -> bool {
                self.0.as_slice().ct_eq(other.0.as_slice()).into()
            }
        }
        impl Eq for $pkty {}

        impl PartialEq for $skty {
            fn eq(&self, other: &Self) -> bool {
                self.0.as_slice().ct_eq(other.0.as_slice()).into()
            }
        }
        impl Eq for $skty {}

        impl PartialEq for $ctty {
            fn eq(&self, other: &Self) -> bool {
                self.0.as_slice().ct_eq(other.0.as_slice()).into()
            }
        }
        impl Eq for $ctty {}

        impl PartialEq for $ssty {
            fn eq(&self, other: &Self) -> bool {
                self.0.as_slice().ct_eq(other.0.as_slice()).into()
            }
        }
        impl Eq for $ssty {}

        // Public values print only their length; secret values are fully redacted.
        impl core::fmt::Debug for $pkty {
            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                write!(
                    f,
                    concat!(stringify!($pkty), "(..{} bytes..)"),
                    self.0.len()
                )
            }
        }

        impl core::fmt::Debug for $skty {
            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                write!(f, concat!(stringify!($skty), "(..REDACTED..)"))
            }
        }

        impl core::fmt::Debug for $ctty {
            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                write!(
                    f,
                    concat!(stringify!($ctty), "(..{} bytes..)"),
                    self.0.len()
                )
            }
        }

        impl core::fmt::Debug for $ssty {
            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                write!(f, concat!(stringify!($ssty), "(..REDACTED..)"))
            }
        }

        impl Zeroize for $skty {
            fn zeroize(&mut self) {
                self.0.zeroize();
            }
        }

        impl Zeroize for $ssty {
            fn zeroize(&mut self) {
                self.0.zeroize();
            }
        }
    };
}
// ML-KEM-512 (FIPS 203: k = 2, du = 10, dv = 4).
// Sizes: ek = 384k + 32 = 800, dk = 768k + 96 = 1632, ct = 32(du*k + dv) = 768.
mlkem_api! {
    MlKem512, Params512,
    PublicKey512, SecretKey512, Ciphertext512, SharedSecret512,
    384 * 2 + 32,
    768 * 2 + 96,
    32 * (10 * 2 + 4)
}
// ML-KEM-768 (FIPS 203: k = 3, du = 10, dv = 4).
// Sizes: ek = 384k + 32 = 1184, dk = 768k + 96 = 2400, ct = 32(du*k + dv) = 1088.
mlkem_api! {
    MlKem768, Params768,
    PublicKey768, SecretKey768, Ciphertext768, SharedSecret768,
    384 * 3 + 32,
    768 * 3 + 96,
    32 * (10 * 3 + 4)
}
// ML-KEM-1024 (FIPS 203: k = 4, du = 11, dv = 5).
// Sizes: ek = 384k + 32 = 1568, dk = 768k + 96 = 3168, ct = 32(du*k + dv) = 1568.
mlkem_api! {
    MlKem1024, Params1024,
    PublicKey1024, SecretKey1024, Ciphertext1024, SharedSecret1024,
    384 * 4 + 32,
    768 * 4 + 96,
    32 * (11 * 4 + 5)
}
pub type PublicKey = PublicKey768;
pub type SecretKey = SecretKey768;
pub type Ciphertext = Ciphertext768;
pub type SharedSecret = SharedSecret768;