use core::marker::PhantomData;
use hybrid_array::typenum::U32;
use rand_core::{
CryptoRng,
Rng,
};
use zeroize::{
Zeroize,
ZeroizeOnDrop,
Zeroizing,
};
use crate::crypto::{
G,
H,
J,
rand,
};
use crate::param::{
DecapsulationKeySize,
EncapsulationKeySize,
EncodedCiphertext,
KemParams,
};
use crate::pke::{
DecryptionKey,
EncryptionKey,
};
use crate::util::{
B32,
SecretB32,
};
pub use crate::{
Decapsulate,
Encapsulate,
};
use crate::{
Encoded,
EncodedSizeUser,
};
/// The 32-byte shared secret returned by encapsulation and decapsulation.
pub(crate) type SharedKey = B32;
/// ML-KEM decapsulation (private) key.
///
/// Holds three things:
/// - the K-PKE decryption key `dk_pke`;
/// - a copy of the matching public [`EncapsulationKey`];
/// - the 32-byte seed `z`. Decapsulation hashes `z` together with the ciphertext to form the
///   implicit-rejection key.
///
/// `Debug` is implemented by hand so that the secret components (`dk_pke` and `z`) never appear
/// in formatted output. Only the public `ek` is shown.
// NOTE(review): the derived `PartialEq` is plain field-wise `==` and is not guaranteed to be
// constant-time; avoid it for comparisons an attacker can influence.
#[derive(Clone, PartialEq)]
pub struct DecapsulationKey<P>
where
    P: KemParams,
{
    dk_pke: DecryptionKey<P>,
    ek: EncapsulationKey<P>,
    z: B32,
}

// Same bounds the previous `#[derive(Debug)]` generated (`P: Debug` plus the struct's where
// clause), so every existing `Debug` use keeps compiling.
impl<P> core::fmt::Debug for DecapsulationKey<P>
where
    P: KemParams + core::fmt::Debug,
{
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Secret material (`dk_pke`, `z`) is deliberately omitted; `..` marks the hidden fields.
        f.debug_struct("DecapsulationKey")
            .field("ek", &self.ek)
            .finish_non_exhaustive()
    }
}
impl<P> Drop for DecapsulationKey<P>
where
    P: KemParams,
{
    /// Clears the secret parts of the key (the K-PKE decryption key, then the seed `z`).
    ///
    /// The embedded encapsulation key is public and is left untouched.
    fn drop(&mut self) {
        let Self { dk_pke, z, .. } = self;
        dk_pke.zeroize();
        z.zeroize();
    }
}
/// Marker: dropping a `DecapsulationKey` zeroizes its secret fields (`dk_pke` and `z`) through
/// the hand-written `Drop` impl for this type. The public `ek` field is not cleared.
impl<P> ZeroizeOnDrop for DecapsulationKey<P> where P: KemParams {}
impl<P> EncodedSizeUser for DecapsulationKey<P>
where
    P: KemParams,
{
    type EncodedSize = DecapsulationKeySize<P>;

    /// Decodes a decapsulation key from the four components `(dk_pke, ek_pke, h, z)` produced by
    /// `P::split_dk`.
    ///
    /// The embedded hash `h` is copied verbatim. It is not recomputed from `ek_pke`.
    // The `*_bytes` bindings intentionally share a naming pattern, which `similar_names` flags.
    #[allow(clippy::similar_names)]
    fn from_bytes(enc: &Encoded<Self>) -> Self {
        let (dk_pke_bytes, ek_pke_bytes, h_bytes, z_bytes) = P::split_dk(enc);
        let ek = EncapsulationKey {
            ek_pke: EncryptionKey::from_bytes(ek_pke_bytes),
            h: h_bytes.clone(),
        };
        Self {
            dk_pke: DecryptionKey::from_bytes(dk_pke_bytes),
            ek,
            z: z_bytes.clone(),
        }
    }

    /// Encodes the key with `P::concat_dk`, in the same component order `from_bytes` expects.
    ///
    /// The result is wrapped in `Zeroizing` because it contains secret key material.
    fn as_bytes(&self) -> Zeroizing<Encoded<Self>> {
        let ek = &self.ek;
        Zeroizing::new(P::concat_dk(
            self.dk_pke.as_bytes(),
            ek.ek_pke.as_bytes(),
            ek.h.clone(),
            self.z.clone(),
        ))
    }
}
/// Compares two bytes without a data-dependent branch at the source level.
///
/// Returns `0xFF` when `x == y` and `0x00` otherwise, so the result can be used directly as a
/// bit mask.
#[cfg(not(feature = "hardened"))]
fn constant_time_eq(x: u8, y: u8) -> u8 {
    // Zero exactly when the two bytes match.
    let delta = x ^ y;
    // When delta == 0, both `!delta` and `delta - 1` are 0xFF, so bit 7 survives the AND.
    // When delta != 0, at least one of the two has bit 7 clear, so it does not survive.
    let top_bit = (!delta & delta.wrapping_sub(1)) >> 7;
    // Broadcast the single bit: 1 -> 0xFF, 0 -> 0x00.
    top_bit.wrapping_neg()
}
/// Decapsulation: recovers the shared key for a ciphertext.
///
/// Follows the re-encryption pattern visible below:
/// 1. Decrypt the ciphertext to a candidate message.
/// 2. Re-derive the encryption randomness and re-encrypt.
/// 3. Compare the result with the received ciphertext.
///
/// If they match, the derived key `K'` is returned. Otherwise the implicit-rejection key
/// `Kbar = J(z, c)` is returned instead. A malformed ciphertext therefore never produces an
/// error, which is why `Error` is `Infallible`.
impl<P> crate::Decapsulate<EncodedCiphertext<P>, SharedKey> for DecapsulationKey<P>
where
    P: KemParams,
{
    type Error = core::convert::Infallible;

    fn decapsulate(
        &self,
        encapsulated_key: &EncodedCiphertext<P>,
    ) -> Result<SharedKey, Self::Error> {
        // m' = decrypt(dk_pke, c): the candidate message, held in a secret wrapper.
        let mp = SecretB32::new(self.dk_pke.decrypt(encapsulated_key));
        // (K', r') = G(m', h): the candidate shared key and the re-encryption randomness.
        let (kp, rp_raw) = G(&[&*mp, &self.ek.h]);
        let Kp = SecretB32::new(kp);
        let rp = SecretB32::new(rp_raw);
        // Kbar = J(z, c): the implicit-rejection key, used when the ciphertext check fails.
        let Kbar = SecretB32::new(J(&[self.z.as_slice(), encapsulated_key.as_ref()]));
        // c' = encrypt(ek_pke, m', r'). It should equal c exactly for an honest ciphertext.
        let cp = self.ek.ek_pke.encrypt(&mp, &rp);
        #[cfg(feature = "hardened")]
        {
            // NOTE(review): the `masking` helpers are defined elsewhere. Judging by their names,
            // this path:
            // - compares c and c' twice (in the byte encoding and, presumably, in the decoded
            //   arithmetic domain);
            // - ANDs the two results;
            // - selects K' or Kbar without branching.
            // Confirm against `crate::masking`.
            let eq_bytes =
                crate::masking::ciphertexts_equal_ct(cp.as_ref(), encapsulated_key.as_ref());
            let eq_ring =
                crate::masking::ciphertexts_equal_arithmetic_domain_ct::<P>(&cp, encapsulated_key);
            let eq = eq_bytes & eq_ring;
            Ok(crate::masking::select_shared_key_bytes_ct(eq, &Kp, &Kbar))
        }
        #[cfg(not(feature = "hardened"))]
        {
            // `equal` is 0xFF if every byte of c' matches c, and 0x00 otherwise.
            // `constant_time_eq` yields a 0xFF/0x00 mask per byte. The fold ANDs all of them with
            // no early exit, so at the source level the work does not depend on where the first
            // mismatch is.
            let equal = cp
                .iter()
                .zip(encapsulated_key.iter())
                .map(|(&x, &y)| constant_time_eq(x, y))
                .fold(0xFF, |x, y| x & y);
            // Branch-free select: each output byte is K'[i] when `equal` is 0xFF and Kbar[i]
            // when it is 0x00.
            Ok(Kp
                .iter()
                .zip(Kbar.iter())
                .map(|(x, y)| (equal & x) | (!equal & y))
                .collect())
        }
    }
}
impl<P> DecapsulationKey<P>
where
P: KemParams,
{
pub fn encapsulation_key(&self) -> &EncapsulationKey<P> {
&self.ek
}
pub(crate) fn generate<R: CryptoRng + Rng + ?Sized>(rng: &mut R) -> Self {
let d: B32 = rand(rng);
let z: B32 = rand(rng);
Self::generate_deterministic(&d, &z)
}
#[must_use]
#[allow(clippy::similar_names)] pub(crate) fn generate_deterministic(d: &B32, z: &B32) -> Self {
let (dk_pke, ek_pke) = DecryptionKey::generate(d);
let ek = EncapsulationKey::new(ek_pke);
let z = z.clone();
Self { dk_pke, ek, z }
}
}
/// ML-KEM encapsulation (public) key.
///
/// Wraps the K-PKE encryption key `ek_pke` together with `h`, the `H` hash of its encoding.
/// Both encapsulation and decapsulation feed `h` into `G`.
///
/// Where `h` comes from depends on how the key was built:
/// - via `EncapsulationKey::new` (including `from_bytes` on an encoded encapsulation key), `h` is
///   computed from `ek_pke`;
/// - when decoded as part of a `DecapsulationKey`, `h` is copied from the encoding and not
///   recomputed.
// NOTE(review): the derived `Default` fills each field with its own default, so `h` is not
// recomputed for the default `ek_pke`. Confirm whether a default key is ever meant to be used.
#[derive(Clone, Default, Debug, PartialEq)]
pub struct EncapsulationKey<P>
where
    P: KemParams,
{
    ek_pke: EncryptionKey<P>,
    h: B32,
}
impl<P> EncapsulationKey<P>
where
    P: KemParams,
{
    /// Wraps a K-PKE encryption key and caches `h = H(encoding of ek_pke)` next to it.
    fn new(ek_pke: EncryptionKey<P>) -> Self {
        let h = {
            let encoded = ek_pke.as_bytes();
            H(encoded.as_slice())
        };
        Self { ek_pke, h }
    }

    /// Encapsulates with an explicit message `m`.
    ///
    /// Hashes `m` and `h` with `G` to obtain the shared key and the encryption randomness, then
    /// encrypts `m` under that randomness.
    fn encapsulate_deterministic_inner(&self, m: &B32) -> (EncodedCiphertext<P>, SharedKey) {
        let (shared_key, coins) = G(&[m, &self.h]);
        // The encryption randomness is secret; wrap it like the other secret 32-byte values here.
        let coins = SecretB32::new(coins);
        let ciphertext = self.ek_pke.encrypt(m, &coins);
        (ciphertext, shared_key)
    }
}
impl<P> EncodedSizeUser for EncapsulationKey<P>
where
    P: KemParams,
{
    type EncodedSize = EncapsulationKeySize<P>;

    /// Decodes the K-PKE encryption key and recomputes the cached hash `h` from it.
    fn from_bytes(enc: &Encoded<Self>) -> Self {
        let ek_pke = EncryptionKey::from_bytes(enc);
        Self::new(ek_pke)
    }

    /// Encodes the key. Only the K-PKE encryption key is serialized, because `h` is rebuilt
    /// from it on decode.
    fn as_bytes(&self) -> Zeroizing<Encoded<Self>> {
        let Self { ek_pke, .. } = self;
        Zeroizing::new(ek_pke.as_bytes())
    }
}
impl<P> crate::Encapsulate<EncodedCiphertext<P>, SharedKey> for EncapsulationKey<P>
where
    P: KemParams,
{
    type Error = core::convert::Infallible;

    /// Encapsulates to this key using a fresh random 32-byte message drawn from `rng`.
    ///
    /// Never fails; the `Result` only exists to satisfy the trait.
    fn encapsulate<R: CryptoRng + Rng + ?Sized>(
        &self,
        rng: &mut R,
    ) -> Result<(EncodedCiphertext<P>, SharedKey), Self::Error> {
        // The shared key is derived from this message, so keep it in the secret wrapper.
        let message = SecretB32::new(rand(rng));
        let (ciphertext, shared_key) = self.encapsulate_deterministic_inner(&message);
        Ok((ciphertext, shared_key))
    }
}
#[cfg(feature = "deterministic")]
impl<P> crate::EncapsulateDeterministic<EncodedCiphertext<P>, SharedKey> for EncapsulationKey<P>
where
    P: KemParams,
{
    type Error = core::convert::Infallible;

    /// Encapsulates using the caller-supplied message `m` instead of fresh randomness.
    ///
    /// Only compiled with the `deterministic` feature. The encryption randomness is derived
    /// from `m` and this key's `h`, so for a given key the same `m` reproduces the same output.
    fn encapsulate_deterministic(
        &self,
        m: &B32,
    ) -> Result<(EncodedCiphertext<P>, SharedKey), Self::Error> {
        let outcome = self.encapsulate_deterministic_inner(m);
        Ok(outcome)
    }
}
/// Zero-sized type that ties together the ML-KEM key, ciphertext and shared-key types for
/// parameter set `P` through its `crate::KemCore` implementation.
#[derive(Clone)]
pub struct Kem<P>
where
    P: KemParams,
{
    // Records the parameter set at the type level; holds no data.
    _phantom: PhantomData<P>,
}
impl<P> crate::KemCore for Kem<P>
where
P: KemParams,
{
type SharedKeySize = U32;
type CiphertextSize = P::CiphertextSize;
type DecapsulationKey = DecapsulationKey<P>;
type EncapsulationKey = EncapsulationKey<P>;
fn generate<R: CryptoRng + Rng + ?Sized>(
rng: &mut R,
) -> (Self::DecapsulationKey, Self::EncapsulationKey) {
let dk = Self::DecapsulationKey::generate(rng);
let ek = dk.encapsulation_key().clone();
(dk, ek)
}
#[cfg(feature = "deterministic")]
fn generate_deterministic(
d: &B32,
z: &B32,
) -> (Self::DecapsulationKey, Self::EncapsulationKey) {
let dk = Self::DecapsulationKey::generate_deterministic(d, z);
let ek = dk.encapsulation_key().clone();
(dk, ek)
}
}
#[cfg(test)]
mod test {
    #[cfg(feature = "random")]
    use super::*;
    #[cfg(feature = "random")]
    use crate::{
        Decapsulate,
        Encapsulate,
        MlKem512Params,
        MlKem768Params,
        MlKem1024Params,
    };

    /// Generates a key pair, encapsulates to it, and checks both sides agree on the shared key.
    #[cfg(feature = "random")]
    fn round_trip_test<P>()
    where
        P: KemParams,
    {
        let mut rng = lib_q_random::LibQRng::new_secure().expect("Failed to create secure RNG");
        let dk = DecapsulationKey::<P>::generate(&mut rng);
        let ek = dk.encapsulation_key();
        let (ct, k_send) = ek.encapsulate(&mut rng).unwrap();
        let k_recv = dk.decapsulate(&ct).unwrap();
        assert_eq!(k_send, k_recv);
    }

    #[test]
    #[cfg(feature = "random")]
    fn round_trip() {
        round_trip_test::<MlKem512Params>();
        round_trip_test::<MlKem768Params>();
        round_trip_test::<MlKem1024Params>();
    }

    /// Exercises the implicit-rejection path.
    ///
    /// A tampered ciphertext must still decapsulate successfully. The resulting key must differ
    /// from the sender's and must be reproducible for the same ciphertext.
    #[cfg(feature = "random")]
    fn implicit_rejection_test<P>()
    where
        P: KemParams,
    {
        let mut rng = lib_q_random::LibQRng::new_secure().expect("Failed to create secure RNG");
        let dk = DecapsulationKey::<P>::generate(&mut rng);
        let ek = dk.encapsulation_key();
        let (mut ct, k_send) = ek.encapsulate(&mut rng).unwrap();
        // Flip one bit so that the re-encryption comparison in `decapsulate` is expected to fail.
        ct[0] ^= 0x01;
        let k_reject = dk.decapsulate(&ct).unwrap();
        assert_ne!(k_send, k_reject);
        // The rejection key is J(z, c). It depends only on the key and the ciphertext, so
        // decapsulating the same tampered ciphertext again must give the same value.
        let k_reject_again = dk.decapsulate(&ct).unwrap();
        assert_eq!(k_reject, k_reject_again);
    }

    #[test]
    #[cfg(feature = "random")]
    fn implicit_rejection() {
        implicit_rejection_test::<MlKem512Params>();
        implicit_rejection_test::<MlKem768Params>();
        implicit_rejection_test::<MlKem1024Params>();
    }

    /// Checks that decapsulation and encapsulation keys survive an encode/decode round trip.
    #[cfg(feature = "random")]
    fn codec_test<P>()
    where
        P: KemParams,
    {
        let mut rng = lib_q_random::LibQRng::new_secure().expect("Failed to create secure RNG");
        let dk_original = DecapsulationKey::<P>::generate(&mut rng);
        let ek_original = dk_original.encapsulation_key().clone();
        let dk_encoded = dk_original.as_bytes();
        let dk_decoded = DecapsulationKey::from_bytes(&*dk_encoded);
        assert_eq!(dk_original, dk_decoded);
        let ek_encoded = ek_original.as_bytes();
        let ek_decoded = EncapsulationKey::from_bytes(&*ek_encoded);
        assert_eq!(ek_original, ek_decoded);
    }

    #[test]
    #[cfg(feature = "random")]
    fn codec() {
        codec_test::<MlKem512Params>();
        codec_test::<MlKem768Params>();
        codec_test::<MlKem1024Params>();
    }
}