use crate::B32;
use crate::algebra::{
Ntt, NttInverse, NttMatrix, NttVector, Polynomial, Vector, matrix_sample_ntt, sample_poly_cbd,
sample_poly_vec_cbd,
};
use crate::compress::Compress;
use crate::crypto::{G, PRF};
use crate::param::{EncodedDecryptionKey, EncodedEncryptionKey, PkeParams};
use array::typenum::{U1, Unsigned};
use kem::{Ciphertext, InvalidKey};
use module_lattice::{
Encode,
ctutils::{Choice, CtEq},
};
#[cfg(feature = "zeroize")]
use zeroize::Zeroize;
/// K-PKE decryption (secret) key: the secret vector held in NTT form.
///
/// `Debug` is implemented by hand right after this struct. That keeps formatting a key
/// (logs, panic messages, `assert_eq!` failures) from printing secret coefficients.
#[derive(Clone, Default)]
pub(crate) struct DecryptionKey<P>
where
    P: PkeParams,
{
    /// Secret vector `s`, already transformed into the NTT domain.
    s_hat: NttVector<P::K>,
}

impl<P> core::fmt::Debug for DecryptionKey<P>
where
    P: PkeParams,
{
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Deliberately omit `s_hat`: it is secret key material.
        f.debug_struct("DecryptionKey").finish_non_exhaustive()
    }
}
/// Constant-time equality for decryption keys, delegated to the secret vector's own `CtEq`.
impl<P: PkeParams> CtEq for DecryptionKey<P> {
    fn ct_eq(&self, other: &Self) -> Choice {
        let (mine, theirs) = (&self.s_hat, &other.s_hat);
        mine.ct_eq(theirs)
    }
}
/// Marker impl: key equality (see the `PartialEq` impl for this type) is a full equivalence.
impl<P: PkeParams> Eq for DecryptionKey<P> {}
impl<P> PartialEq for DecryptionKey<P>
where
P: PkeParams,
{
fn eq(&self, other: &Self) -> bool {
self.ct_eq(other).into()
}
}
#[cfg(feature = "zeroize")]
impl<P: PkeParams> Zeroize for DecryptionKey<P> {
    /// Wipes the secret NTT vector in place using its own `Zeroize` implementation.
    fn zeroize(&mut self) {
        let Self { s_hat } = self;
        s_hat.zeroize();
    }
}
impl<P> DecryptionKey<P>
where
P: PkeParams,
{
pub(crate) fn generate(d: &B32) -> (Self, EncryptionKey<P>) {
let k = P::K::U8;
let (rho, sigma) = G(&[&d[..], &[k]]);
let A_hat: NttMatrix<P::K> = matrix_sample_ntt(&rho, false);
let s: Vector<P::K> = sample_poly_vec_cbd::<P::Eta1, P::K>(&sigma, 0);
let e: Vector<P::K> = sample_poly_vec_cbd::<P::Eta1, P::K>(&sigma, P::K::U8);
let s_hat = s.ntt();
let e_hat = e.ntt();
let t_hat = &(&A_hat * &s_hat) + &e_hat;
let dk = DecryptionKey { s_hat };
let ek = EncryptionKey { t_hat, rho };
(dk, ek)
}
pub(crate) fn decrypt(&self, ciphertext: &Ciphertext<P>) -> B32 {
let (c1, c2) = P::split_ct(ciphertext);
let mut u: Vector<P::K> = Encode::<P::Du>::decode(c1);
u.decompress::<P::Du>();
let mut v: Polynomial = Encode::<P::Dv>::decode(c2);
v.decompress::<P::Dv>();
let u_hat = u.ntt();
let sTu = (&self.s_hat * &u_hat).ntt_inverse();
let mut w = &v - &sTu;
Encode::<U1>::encode(w.compress::<U1>())
}
pub(crate) fn to_bytes(&self) -> EncodedDecryptionKey<P> {
P::encode_u12(&self.s_hat)
}
pub(crate) fn from_bytes(enc: &EncodedDecryptionKey<P>) -> Self {
let s_hat = P::decode_u12(enc);
Self { s_hat }
}
}
/// K-PKE encryption (public) key.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub(crate) struct EncryptionKey<P>
where
    P: PkeParams,
{
    /// Public vector `t = A * s + e`, stored in the NTT domain.
    t_hat: NttVector<P::K>,
    /// 32-byte public seed from which the matrix `A_hat` is re-sampled on every encryption.
    rho: B32,
}
impl<P> EncryptionKey<P>
where
    P: PkeParams,
{
    /// Encrypts the 32-byte `message` under this key.
    ///
    /// Encryption is deterministic given `randomness`, which seeds every noise sample:
    /// - `r` uses parameter `Eta1` with starting nonce `0`.
    /// - `e1` uses parameter `Eta2` with starting nonce `k`.
    /// - `e2` uses parameter `Eta2` with nonce `2k`.
    ///
    /// Produces the ciphertext halves:
    /// - `u = NTT^-1(A_hat^T * r_hat) + e1`, compressed with `Du`.
    /// - `v = NTT^-1(t_hat * r_hat) + e2 + mu`, compressed with `Dv`. Here `mu` is the
    ///   decoded and decompressed message.
    pub(crate) fn encrypt(&self, message: &B32, randomness: &B32) -> Ciphertext<P> {
        let k = P::K::U8;

        // Noise polynomials, drawn with distinct PRF nonces so they are independent.
        let r = sample_poly_vec_cbd::<P::Eta1, P::K>(randomness, 0);
        let e1 = sample_poly_vec_cbd::<P::Eta2, P::K>(randomness, k);
        let prf_output = PRF::<P::Eta2>(randomness, 2 * k);
        let e2: Polynomial = sample_poly_cbd::<P::Eta2>(&prf_output);

        // `true` here (vs. `false` in key generation) presumably selects the transposed
        // matrix, matching the `A_hat_t` name.
        let A_hat_t: NttMatrix<P::K> = matrix_sample_ntt(&self.rho, true);
        let r_hat: NttVector<P::K> = r.ntt();
        let ATr: Vector<P::K> = (&A_hat_t * &r_hat).ntt_inverse();
        let mut u = ATr + e1;

        // Message bits -> polynomial (1-bit decode, then 1-bit decompress).
        let mut mu_poly: Polynomial = Encode::<U1>::decode(message);
        mu_poly.decompress::<U1>();

        let tTr: Polynomial = (&self.t_hat * &r_hat).ntt_inverse();
        // The message term must be added here; without it `v` carries no message at all.
        let mut v = &(&tTr + &e2) + &mu_poly;

        let c1 = Encode::<P::Du>::encode(u.compress::<P::Du>());
        let c2 = Encode::<P::Dv>::encode(v.compress::<P::Dv>());
        P::concat_ct(c1, c2)
    }

    /// Serializes the key as the 12-bit encoding of `t_hat` followed by `rho`.
    pub(crate) fn to_bytes(&self) -> EncodedEncryptionKey<P> {
        let t_hat = P::encode_u12(&self.t_hat);
        P::concat_ek(t_hat, self.rho.clone())
    }

    /// Parses an encoded encryption key.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidKey`] if re-encoding the decoded key does not reproduce `enc` byte for
    /// byte. Such a mismatch means `decode_u12` did not see a canonical encoding of `t_hat`.
    pub(crate) fn from_bytes(enc: &EncodedEncryptionKey<P>) -> Result<Self, InvalidKey> {
        let (t_hat, rho) = P::split_ek(enc);
        let t_hat = P::decode_u12(t_hat);
        let ret = Self {
            t_hat,
            rho: rho.clone(),
        };

        if &ret.to_bytes() == enc {
            Ok(ret)
        } else {
            Err(InvalidKey)
        }
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::{MlKem512, MlKem768, MlKem1024};
    use ::kem::Generate;
    use getrandom::{SysRng, rand_core::UnwrapErr};

    /// Encrypts `message` with `randomness` under a fresh random key pair and asserts that
    /// decryption returns the same message.
    fn round_trip_with<P>(message: &B32, randomness: &B32)
    where
        P: PkeParams,
    {
        let mut rng = UnwrapErr(SysRng);
        let d = B32::generate_from_rng(&mut rng);
        let (dk, ek) = DecryptionKey::<P>::generate(&d);
        let encrypted = ek.encrypt(message, randomness);
        let decrypted = dk.decrypt(&encrypted);
        assert_eq!(message, &decrypted);
    }

    fn round_trip_test<P>()
    where
        P: PkeParams,
    {
        // Edge case: all-zero message and all-zero encryption randomness.
        round_trip_with::<P>(&B32::default(), &B32::default());

        // A non-zero message is needed to exercise the message term of encryption; an
        // all-zero message still decrypts correctly even if that term is dropped.
        let mut rng = UnwrapErr(SysRng);
        let message = B32::generate_from_rng(&mut rng);
        let randomness = B32::generate_from_rng(&mut rng);
        round_trip_with::<P>(&message, &randomness);
    }

    #[test]
    fn round_trip() {
        round_trip_test::<MlKem512>();
        round_trip_test::<MlKem768>();
        round_trip_test::<MlKem1024>();
    }

    fn codec_test<P>()
    where
        P: PkeParams,
    {
        let mut rng = UnwrapErr(SysRng);
        let d = B32::generate_from_rng(&mut rng);
        let (dk_original, ek_original) = DecryptionKey::<P>::generate(&d);

        let dk_encoded = dk_original.to_bytes();
        let dk_decoded = DecryptionKey::from_bytes(&dk_encoded);
        assert_eq!(dk_original, dk_decoded);

        let ek_encoded = ek_original.to_bytes();
        let ek_decoded = EncryptionKey::from_bytes(&ek_encoded).unwrap();
        assert_eq!(ek_original, ek_decoded);
    }

    #[test]
    fn codec() {
        codec_test::<MlKem512>();
        codec_test::<MlKem768>();
        codec_test::<MlKem1024>();
    }

    #[test]
    fn reject_invalid_encryption_keys() {
        // 1184 bytes is the ML-KEM-768 encryption key size; 0xFF bytes decode to
        // out-of-range coefficients, so re-encoding cannot reproduce the input.
        let invalid_key = [0xFF; 1184];
        assert!(EncryptionKey::<MlKem768>::from_bytes(&invalid_key.into()).is_err());
    }

    fn key_inequality_test<P>()
    where
        P: PkeParams,
    {
        let mut rng = UnwrapErr(SysRng);
        let d1 = B32::generate_from_rng(&mut rng);
        let d2 = B32::generate_from_rng(&mut rng);
        let (dk1, _) = DecryptionKey::<P>::generate(&d1);
        let (dk2, _) = DecryptionKey::<P>::generate(&d2);
        assert_ne!(dk1, dk2);
    }

    #[test]
    fn key_inequality() {
        key_inequality_test::<MlKem512>();
        key_inequality_test::<MlKem768>();
        key_inequality_test::<MlKem1024>();
    }
}