#![cfg_attr(not(feature = "std"), no_std)]
#![warn(clippy::all, clippy::pedantic)]
#![warn(missing_debug_implementations)]
#![allow(clippy::missing_errors_doc, clippy::missing_panics_doc)]
use rand_core::{CryptoRng, RngCore};
use subtle::ConstantTimeEq;
use x25519_dalek::{PublicKey as XPub, StaticSecret};
use zeroize::{Zeroize, ZeroizeOnDrop};
/// Length in bytes of an X25519 public key; also the length of the X25519 secret-key
/// portion stored in a decapsulation key.
pub const X25519_BYTES: usize = 32;
/// Length in bytes of the X25519 Diffie-Hellman output.
pub const X25519_SS_BYTES: usize = 32;
/// Length in bytes of an ML-KEM shared secret.
pub const MLKEM_SS_BYTES: usize = 32;
/// Length in bytes of the combined hybrid shared secret (ML-KEM part followed by the
/// X25519 part).
pub const SHARED_SECRET_BYTES: usize = X25519_SS_BYTES + MLKEM_SS_BYTES;
/// Error produced when a byte slice does not have the exact length required by one of
/// the fixed-size key or ciphertext types.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LengthError {
    /// Number of bytes the target type requires.
    pub expected: usize,
    /// Number of bytes actually supplied.
    pub got: usize,
}

impl core::fmt::Display for LengthError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_fmt(format_args!(
            "wrong byte length: expected {}, got {}",
            self.expected, self.got
        ))
    }
}

#[cfg(feature = "std")]
impl std::error::Error for LengthError {}
fn x25519_keypair_from_seed(seed: [u8; 32]) -> (StaticSecret, XPub) {
let sk = StaticSecret::from(seed);
let pk = XPub::from(&sk);
(sk, pk)
}
/// Draws fresh seed material from `rng`.
///
/// Returns `(x25519_seed, mlkem_seed)`: the 32-byte buffer is filled first, then the
/// 64-byte buffer.
fn fill_seed_pair<R: RngCore + CryptoRng>(rng: &mut R) -> ([u8; 32], [u8; 64]) {
    let mut seeds: ([u8; 32], [u8; 64]) = ([0; 32], [0; 64]);
    rng.fill_bytes(&mut seeds.0);
    rng.fill_bytes(&mut seeds.1);
    seeds
}
/// Generates one hybrid X25519 + ML-KEM key-encapsulation mechanism: a zero-sized API type
/// plus byte-array newtypes for its encapsulation key, decapsulation key, ciphertext and
/// shared secret.
///
/// Arguments, in order:
/// - `$name`: the generated API type exposing `keygen` / `encapsulate` / `decapsulate`.
/// - `$pq`: the ML-KEM parameter set under `mlkem` providing `keygen_deterministic`,
///   `encapsulate` and `decapsulate`.
/// - `$pq_pk_ty`, `$pq_sk_ty`, `$pq_ct_ty`: the matching `mlkem` public-key, secret-key and
///   ciphertext types.
/// - `$ek_ty`, `$dk_ty`, `$ct_ty`, `$ss_ty`: names for the generated hybrid newtypes.
/// - `$pq_pk`, `$pq_sk`, `$pq_ct`: ML-KEM public-key, secret-key and ciphertext sizes.
/// - `$ek_size`, `$dk_size`, `$ct_size`: hybrid sizes. Each must equal the matching ML-KEM
///   size plus `X25519_BYTES`; this is checked at compile time.
///
/// Byte layouts (ML-KEM part first, X25519 part last):
/// - encapsulation key = ML-KEM public key || X25519 public key
/// - decapsulation key = ML-KEM secret key || X25519 static secret
/// - ciphertext        = ML-KEM ciphertext || ephemeral X25519 public key
/// - shared secret     = ML-KEM shared secret || X25519 Diffie-Hellman output
///
/// NOTE(review): the shared secret is a plain concatenation. It is not hashed together with
/// the ciphertext or public keys (as, e.g., the X-Wing combiner does), and an all-zero X25519
/// output from a low-order peer point is not rejected. Confirm that callers feed it through
/// a KDF.
macro_rules! hybrid_kem {
    ($name:ident, $pq:ident, $pq_pk_ty:ident, $pq_sk_ty:ident, $pq_ct_ty:ident,
     $ek_ty:ident, $dk_ty:ident, $ct_ty:ident, $ss_ty:ident,
     $pq_pk:expr, $pq_sk:expr, $pq_ct:expr,
     $ek_size:expr, $dk_size:expr, $ct_size:expr) => {
        // A mismatched size would otherwise surface only as a slice-length panic inside
        // `copy_from_slice` / `try_into` at runtime. Fail the build instead.
        const _: () = {
            assert!(
                $ek_size == $pq_pk + X25519_BYTES,
                "hybrid_kem!: encapsulation key size must be ML-KEM public key size + 32"
            );
            assert!(
                $dk_size == $pq_sk + X25519_BYTES,
                "hybrid_kem!: decapsulation key size must be ML-KEM secret key size + 32"
            );
            assert!(
                $ct_size == $pq_ct + X25519_BYTES,
                "hybrid_kem!: ciphertext size must be ML-KEM ciphertext size + 32"
            );
        };

        /// Hybrid X25519 + ML-KEM key-encapsulation mechanism. It has no state; every
        /// operation is an associated function.
        #[derive(Debug)]
        pub struct $name;

        impl $name {
            /// Length in bytes of an encoded encapsulation (public) key.
            pub const ENCAPSULATION_KEY_SIZE: usize = $ek_size;
            /// Length in bytes of an encoded decapsulation (secret) key.
            pub const DECAPSULATION_KEY_SIZE: usize = $dk_size;
            /// Length in bytes of an encoded ciphertext.
            pub const CIPHERTEXT_SIZE: usize = $ct_size;
            /// Length in bytes of the combined shared secret.
            pub const SHARED_SECRET_SIZE: usize = SHARED_SECRET_BYTES;

            /// Generates a fresh hybrid key pair from seed material drawn from `rng`.
            ///
            /// A 32-byte seed feeds the X25519 key and a 64-byte seed feeds ML-KEM's
            /// deterministic key generation. Both seed buffers are wiped once the key pairs
            /// have been derived.
            #[must_use]
            pub fn keygen<R: RngCore + CryptoRng>(rng: &mut R) -> ($ek_ty, $dk_ty) {
                let (mut x_seed, mut m_seed) = fill_seed_pair(rng);
                let (xsk, xpk) = x25519_keypair_from_seed(x_seed);
                let (mpk, msk) = mlkem::$pq::keygen_deterministic(&m_seed);
                // The seeds are `Copy` arrays, so these locals still hold secret bytes.
                x_seed.zeroize();
                m_seed.zeroize();

                let mut ek = [0u8; $ek_size];
                ek[..$pq_pk].copy_from_slice(mpk.as_bytes());
                ek[$pq_pk..].copy_from_slice(xpk.as_bytes());

                // Assemble the secret key inside its zeroize-on-drop wrapper rather than in
                // a bare local array that would never be wiped.
                let mut dk = $dk_ty([0u8; $dk_size]);
                dk.0[..$pq_sk].copy_from_slice(msk.as_bytes());
                let mut xsk_bytes = xsk.to_bytes();
                dk.0[$pq_sk..].copy_from_slice(&xsk_bytes);
                xsk_bytes.zeroize();

                ($ek_ty(ek), dk)
            }

            /// Encapsulates to `ek` and returns the hybrid ciphertext and shared secret.
            ///
            /// `rng` is first handed to ML-KEM encapsulation and is then used to draw a
            /// 32-byte ephemeral X25519 secret. The order is kept so that output for a given
            /// RNG stream is unchanged.
            #[must_use]
            pub fn encapsulate<R: RngCore + CryptoRng>(
                ek: &$ek_ty,
                rng: &mut R,
            ) -> ($ct_ty, $ss_ty) {
                let mpk_bytes: &[u8; $pq_pk] =
                    (&ek.0[..$pq_pk]).try_into().expect("ek length checked");
                let xpk_bytes: &[u8; X25519_BYTES] =
                    (&ek.0[$pq_pk..]).try_into().expect("ek length checked");
                let mpk = mlkem::$pq_pk_ty::from_bytes(mpk_bytes);
                let xpk = XPub::from(*xpk_bytes);

                let (mct, mss) = mlkem::$pq::encapsulate(&mpk, rng);

                let mut x_seed = [0u8; 32];
                rng.fill_bytes(&mut x_seed);
                let xsk = ReusableSecretWrapper::from(x_seed);
                // The wrapper holds its own copy; wipe the local seed.
                x_seed.zeroize();
                let xpk_eph = XPub::from(&xsk.0);
                let xss = xsk.0.diffie_hellman(&xpk);

                let mut ct = [0u8; $ct_size];
                ct[..$pq_ct].copy_from_slice(mct.as_bytes());
                ct[$pq_ct..].copy_from_slice(xpk_eph.as_bytes());

                // Build the shared secret directly inside its zeroize-on-drop wrapper.
                let mut ss = $ss_ty([0u8; SHARED_SECRET_BYTES]);
                ss.0[..MLKEM_SS_BYTES].copy_from_slice(mss.as_bytes());
                ss.0[MLKEM_SS_BYTES..].copy_from_slice(xss.as_bytes());

                ($ct_ty(ct), ss)
            }

            /// Recovers the shared secret for `ct` using the decapsulation key `dk`.
            ///
            /// This layer always returns a value. How invalid ML-KEM ciphertexts are handled
            /// is determined by `mlkem`'s `decapsulate`.
            #[must_use]
            pub fn decapsulate(dk: &$dk_ty, ct: &$ct_ty) -> $ss_ty {
                let msk_bytes: &[u8; $pq_sk] =
                    (&dk.0[..$pq_sk]).try_into().expect("dk length checked");
                let xsk_bytes: &[u8; X25519_BYTES] =
                    (&dk.0[$pq_sk..]).try_into().expect("dk length checked");
                let msk = mlkem::$pq_sk_ty::from_bytes(msk_bytes);
                // Copy into a named local so the copy can be wiped after conversion.
                let mut xsk_copy = *xsk_bytes;
                let xsk = StaticSecret::from(xsk_copy);
                xsk_copy.zeroize();

                let mct_bytes: &[u8; $pq_ct] =
                    (&ct.0[..$pq_ct]).try_into().expect("ct length checked");
                let xpk_bytes: &[u8; X25519_BYTES] =
                    (&ct.0[$pq_ct..]).try_into().expect("ct length checked");
                let mct = mlkem::$pq_ct_ty::from_bytes(mct_bytes);
                let xpk = XPub::from(*xpk_bytes);

                let mss = mlkem::$pq::decapsulate(&msk, &mct);
                let xss = xsk.diffie_hellman(&xpk);

                let mut ss = $ss_ty([0u8; SHARED_SECRET_BYTES]);
                ss.0[..MLKEM_SS_BYTES].copy_from_slice(mss.as_bytes());
                ss.0[MLKEM_SS_BYTES..].copy_from_slice(xss.as_bytes());
                ss
            }
        }

        /// Encoded hybrid encapsulation (public) key.
        #[derive(Clone)]
        pub struct $ek_ty(pub(crate) [u8; $ek_size]);
        /// Encoded hybrid decapsulation (secret) key; its bytes are zeroized on drop.
        #[derive(Clone, ZeroizeOnDrop)]
        pub struct $dk_ty(pub(crate) [u8; $dk_size]);
        /// Encoded hybrid ciphertext.
        #[derive(Clone)]
        pub struct $ct_ty(pub(crate) [u8; $ct_size]);
        /// Combined hybrid shared secret; its bytes are zeroized on drop.
        #[derive(Clone, ZeroizeOnDrop)]
        pub struct $ss_ty(pub(crate) [u8; SHARED_SECRET_BYTES]);

        impl $ek_ty {
            /// Borrows the encoded key bytes.
            #[must_use]
            pub fn as_bytes(&self) -> &[u8; $ek_size] {
                &self.0
            }
            /// Wraps an already-encoded key. No validation beyond the fixed length.
            #[must_use]
            pub fn from_bytes(b: &[u8; $ek_size]) -> Self {
                Self(*b)
            }
        }
        impl $dk_ty {
            /// Borrows the encoded secret-key bytes.
            #[must_use]
            pub fn as_bytes(&self) -> &[u8; $dk_size] {
                &self.0
            }
            /// Wraps an already-encoded secret key. No validation beyond the fixed length.
            #[must_use]
            pub fn from_bytes(b: &[u8; $dk_size]) -> Self {
                Self(*b)
            }
        }
        impl $ct_ty {
            /// Borrows the encoded ciphertext bytes.
            #[must_use]
            pub fn as_bytes(&self) -> &[u8; $ct_size] {
                &self.0
            }
            /// Wraps an already-encoded ciphertext. No validation beyond the fixed length.
            #[must_use]
            pub fn from_bytes(b: &[u8; $ct_size]) -> Self {
                Self(*b)
            }
        }
        impl $ss_ty {
            /// Borrows the shared-secret bytes (ML-KEM part, then X25519 part).
            #[must_use]
            pub fn as_bytes(&self) -> &[u8; SHARED_SECRET_BYTES] {
                &self.0
            }
        }

        impl AsRef<[u8]> for $ek_ty {
            fn as_ref(&self) -> &[u8] {
                &self.0
            }
        }
        impl AsRef<[u8]> for $ct_ty {
            fn as_ref(&self) -> &[u8] {
                &self.0
            }
        }
        impl AsRef<[u8]> for $ss_ty {
            fn as_ref(&self) -> &[u8] {
                &self.0
            }
        }
        impl AsRef<[u8]> for $dk_ty {
            fn as_ref(&self) -> &[u8] {
                &self.0
            }
        }

        impl TryFrom<&[u8]> for $ek_ty {
            type Error = LengthError;
            fn try_from(b: &[u8]) -> Result<Self, LengthError> {
                <[u8; $ek_size]>::try_from(b).map(Self).map_err(|_| LengthError {
                    expected: $ek_size,
                    got: b.len(),
                })
            }
        }
        impl TryFrom<&[u8]> for $ct_ty {
            type Error = LengthError;
            fn try_from(b: &[u8]) -> Result<Self, LengthError> {
                <[u8; $ct_size]>::try_from(b).map(Self).map_err(|_| LengthError {
                    expected: $ct_size,
                    got: b.len(),
                })
            }
        }
        impl TryFrom<&[u8]> for $dk_ty {
            type Error = LengthError;
            fn try_from(b: &[u8]) -> Result<Self, LengthError> {
                if b.len() != $dk_size {
                    return Err(LengthError {
                        expected: $dk_size,
                        got: b.len(),
                    });
                }
                // Fill the zeroize-on-drop wrapper directly instead of an unwiped local array.
                let mut dk = Self([0u8; $dk_size]);
                dk.0.copy_from_slice(b);
                Ok(dk)
            }
        }

        // Equality is constant-time so that comparing secrets does not leak timing.
        impl PartialEq for $ek_ty {
            fn eq(&self, other: &Self) -> bool {
                self.0.ct_eq(&other.0).into()
            }
        }
        impl Eq for $ek_ty {}
        impl PartialEq for $ct_ty {
            fn eq(&self, other: &Self) -> bool {
                self.0.ct_eq(&other.0).into()
            }
        }
        impl Eq for $ct_ty {}
        impl PartialEq for $ss_ty {
            fn eq(&self, other: &Self) -> bool {
                self.0.ct_eq(&other.0).into()
            }
        }
        impl Eq for $ss_ty {}
        impl PartialEq for $dk_ty {
            fn eq(&self, other: &Self) -> bool {
                self.0.ct_eq(&other.0).into()
            }
        }
        impl Eq for $dk_ty {}

        // Debug output never prints key or secret bytes.
        impl core::fmt::Debug for $ek_ty {
            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                write!(
                    f,
                    concat!(stringify!($ek_ty), "(..{} bytes..)"),
                    self.0.len()
                )
            }
        }
        impl core::fmt::Debug for $dk_ty {
            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                write!(f, concat!(stringify!($dk_ty), "(..REDACTED..)"))
            }
        }
        impl core::fmt::Debug for $ct_ty {
            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                write!(
                    f,
                    concat!(stringify!($ct_ty), "(..{} bytes..)"),
                    self.0.len()
                )
            }
        }
        impl core::fmt::Debug for $ss_ty {
            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                write!(f, concat!(stringify!($ss_ty), "(..REDACTED..)"))
            }
        }

        impl Zeroize for $dk_ty {
            fn zeroize(&mut self) {
                self.0.zeroize();
            }
        }
        impl Zeroize for $ss_ty {
            fn zeroize(&mut self) {
                self.0.zeroize();
            }
        }
    };
}
struct ReusableSecretWrapper(StaticSecret);
impl From<[u8; 32]> for ReusableSecretWrapper {
fn from(b: [u8; 32]) -> Self {
Self(StaticSecret::from(b))
}
}
// X25519 + ML-KEM-768. Each hybrid size is written as the ML-KEM-768 component size plus
// one 32-byte X25519 value, matching the concatenated byte layout the macro produces.
hybrid_kem!(
    X25519MlKem768,
    // ML-KEM parameter set and its public-key / secret-key / ciphertext types.
    MlKem768, PublicKey768, SecretKey768, Ciphertext768,
    // Generated hybrid types: encapsulation key, decapsulation key, ciphertext, shared secret.
    EncapsKey768, DecapsKey768, Ciphertext768Hybrid, SharedSecret768Hybrid,
    // ML-KEM-768 public key, secret key and ciphertext sizes (FIPS 203).
    1184, 2400, 1088,
    // Hybrid encapsulation key, decapsulation key and ciphertext sizes.
    1184 + X25519_BYTES, 2400 + X25519_BYTES, 1088 + X25519_BYTES
);
// X25519 + ML-KEM-1024. Each hybrid size is written as the ML-KEM-1024 component size plus
// one 32-byte X25519 value, matching the concatenated byte layout the macro produces.
hybrid_kem!(
    X25519MlKem1024,
    // ML-KEM parameter set and its public-key / secret-key / ciphertext types.
    MlKem1024, PublicKey1024, SecretKey1024, Ciphertext1024,
    // Generated hybrid types: encapsulation key, decapsulation key, ciphertext, shared secret.
    EncapsKey1024, DecapsKey1024, Ciphertext1024Hybrid, SharedSecret1024Hybrid,
    // ML-KEM-1024 public key, secret key and ciphertext sizes (FIPS 203).
    1568, 3168, 1568,
    // Hybrid encapsulation key, decapsulation key and ciphertext sizes.
    1568 + X25519_BYTES, 3168 + X25519_BYTES, 1568 + X25519_BYTES
);