use crate::{
B32, EncapsulationKey, Seed, SharedKey,
crypto::{G, J},
param::{DecapsulationKeySize, ExpandedDecapsulationKey, KemParams},
pke::{DecryptionKey, EncryptionKey},
};
use array::{
Array, ArraySize,
sizes::{U32, U64},
};
use kem::{
Ciphertext, Decapsulate, Decapsulator, Generate, InvalidKey, Kem, KeyExport, KeyInit,
KeySizeUser,
};
use module_lattice::ctutils::{CtEq, CtSelect};
use rand_core::{TryCryptoRng, TryRng};
#[cfg(feature = "zeroize")]
use zeroize::{Zeroize, ZeroizeOnDrop};
/// ML-KEM decapsulation (private) key.
///
/// Holds the K-PKE decryption key, the matching encapsulation key, the optional
/// seed half `d` the key was generated from, and the implicit-rejection value `z`.
#[derive(Clone)]
pub struct DecapsulationKey<P>
where
    P: KemParams,
{
    /// K-PKE decryption key used during decapsulation to recover the message.
    dk_pke: DecryptionKey<P>,
    /// Public encapsulation key; also used to re-encrypt during decapsulation.
    ek: EncapsulationKey<P>,
    /// First seed half: `Some` when generated from a seed, `None` when loaded
    /// from the deprecated expanded encoding.
    d: Option<B32>,
    /// Implicit-rejection secret, hashed with the ciphertext when re-encryption
    /// does not reproduce it.
    z: B32,
}

/// Redacting `Debug`: only the public encapsulation key is printed, so the seed
/// half `d`, the rejection value `z` and the decryption key never reach logs.
impl<P> core::fmt::Debug for DecapsulationKey<P>
where
    P: KemParams,
    EncapsulationKey<P>: core::fmt::Debug,
{
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("DecapsulationKey")
            .field("ek", &self.ek)
            .finish_non_exhaustive()
    }
}
impl<P> DecapsulationKey<P>
where
    P: KemParams,
{
    /// Derives a decapsulation key from a 64-byte seed.
    ///
    /// The seed is split into its two 32-byte halves `(d, z)` and the key pair
    /// is regenerated deterministically from them.
    #[inline]
    #[must_use]
    pub fn from_seed(seed: Seed) -> Self {
        let (d_half, z_half) = seed.split();
        Self::generate_deterministic(d_half, z_half)
    }

    /// Parses the legacy expanded decapsulation-key encoding.
    ///
    /// Keys obtained this way carry no seed, so [`Self::to_seed`] returns `None`
    /// for them.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidKey`] if the embedded encryption key is rejected by
    /// `EncryptionKey::from_bytes`, or if the stored hash `h` differs from the
    /// hash of that encryption key.
    #[deprecated(since = "0.3.0", note = "use `DecapsulationKey::from_seed` instead")]
    pub fn from_expanded(enc: &ExpandedDecapsulationKey<P>) -> Result<Self, InvalidKey> {
        let (dk_bytes, ek_bytes, stored_h, z_bytes) = P::split_dk(enc);
        let dk_pke = DecryptionKey::from_bytes(dk_bytes);
        let ek = EncapsulationKey::from_encryption_key(EncryptionKey::from_bytes(ek_bytes)?);

        // Reject encodings whose cached hash disagrees with the encryption key.
        if ek.h() == *stored_h {
            Ok(Self {
                dk_pke,
                ek,
                d: None,
                z: z_bytes.clone(),
            })
        } else {
            Err(InvalidKey)
        }
    }

    /// Returns the 64-byte seed `d || z` this key was derived from, or `None`
    /// when the key was loaded from the expanded encoding.
    #[inline]
    pub fn to_seed(&self) -> Option<Seed> {
        let d_half = self.d?;
        Some(d_half.concat(self.z))
    }

    /// Returns the public encapsulation key paired with this decapsulation key.
    pub fn encapsulation_key(&self) -> &EncapsulationKey<P> {
        let Self { ek, .. } = self;
        ek
    }

    /// Generates a fresh key by drawing `d` and then `z` from `rng`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the RNG.
    #[inline]
    pub(crate) fn try_generate_from_rng<R>(rng: &mut R) -> Result<Self, <R as TryRng>::Error>
    where
        R: TryCryptoRng + ?Sized,
    {
        // `d` is drawn before `z`; keep this order.
        let d_half = B32::try_generate_from_rng(rng)?;
        let z_half = B32::try_generate_from_rng(rng)?;
        Ok(Self::generate_deterministic(d_half, z_half))
    }

    /// Deterministically expands the seed halves `d` and `z` into a full key.
    #[inline]
    #[must_use]
    // Presumably silences the lint on the parallel `dk_pke` / `ek_pke` bindings.
    #[allow(clippy::similar_names)]
    pub(crate) fn generate_deterministic(d: B32, z: B32) -> Self {
        let (dk_pke, ek_pke) = DecryptionKey::generate(&d);
        Self {
            dk_pke,
            ek: EncapsulationKey::from_encryption_key(ek_pke),
            d: Some(d),
            z,
        }
    }
}
impl<P> PartialEq for DecapsulationKey<P>
where
    P: KemParams,
{
    /// Compares two decapsulation keys by `dk_pke`, `ek` and `z`.
    ///
    /// The secret components (`dk_pke` and `z`) are compared in constant time.
    /// All three comparisons always run, so the timing does not reveal which
    /// component differed. The seed half `d` is not compared, so a key loaded
    /// from the expanded encoding can equal its seed-derived counterpart.
    fn eq(&self, other: &Self) -> bool {
        let dk_eq: bool = self.dk_pke.ct_eq(&other.dk_pke).into();
        let z_eq: bool = self.z.ct_eq(&other.z).into();
        let ek_eq = self.ek == other.ek;
        // Non-short-circuiting `&`: every comparison is always evaluated.
        dk_eq & z_eq & ek_eq
    }
}
#[cfg(feature = "zeroize")]
impl<P> Drop for DecapsulationKey<P>
where
    P: KemParams,
{
    /// Wipes the secret material; the public `ek` is left untouched.
    fn drop(&mut self) {
        let Self { dk_pke, d, z, .. } = self;
        dk_pke.zeroize();
        d.zeroize();
        z.zeroize();
    }
}
/// Marker trait: the secret fields are wiped in `Drop` when `zeroize` is enabled.
#[cfg(feature = "zeroize")]
impl<P: KemParams> ZeroizeOnDrop for DecapsulationKey<P> {}
impl<P> From<Seed> for DecapsulationKey<P>
where
P: KemParams,
{
fn from(seed: Seed) -> Self {
Self::from_seed(seed)
}
}
/// ML-KEM decapsulation with implicit rejection (FIPS 203 `ML-KEM.Decaps_internal`).
///
/// The return type is not a `Result`: if the ciphertext does not re-encrypt to
/// itself, the rejection key `J(z || c)` is returned instead of an error.
impl<P> Decapsulate for DecapsulationKey<P>
where
    P: Kem<EncapsulationKey = EncapsulationKey<P>, SharedKeySize = U32> + KemParams,
{
    fn decapsulate(&self, encapsulated_key: &Ciphertext<P>) -> SharedKey {
        // m' = K-PKE.Decrypt(dk_pke, c)
        let mp = self.dk_pke.decrypt(encapsulated_key);
        // (K', r') = G(m' || h), with h = self.ek.h() (the hash also checked in `from_expanded`)
        let (Kp, rp) = G(&[&mp, &self.ek.h()]);
        // Implicit-rejection key Kbar = J(z || c)
        let Kbar = J(&[self.z.as_slice(), encapsulated_key.as_ref()]);
        // c' = K-PKE.Encrypt(ek_pke, m', r')
        let cp = self.ek.ek_pke().encrypt(&mp, &rp);
        // Constant-time selection with no secret-dependent branch:
        // K' when c' == c, otherwise Kbar.
        Kbar.ct_select(&Kp, cp.ct_eq(encapsulated_key))
    }
}
impl<P> Decapsulator for DecapsulationKey<P>
where
    P: Kem<EncapsulationKey = EncapsulationKey<P>, SharedKeySize = U32> + KemParams,
{
    type Kem = P;

    /// Returns the public encapsulation key paired with this decapsulation key.
    fn encapsulation_key(&self) -> &EncapsulationKey<P> {
        let Self { ek, .. } = self;
        ek
    }
}
impl<P> Generate for DecapsulationKey<P>
where
P: KemParams,
{
fn try_generate_from_rng<R>(rng: &mut R) -> Result<Self, <R as TryRng>::Error>
where
R: TryCryptoRng + ?Sized,
{
Self::try_generate_from_rng(rng)
}
}
/// Key material is the 64-byte (`U64`) seed.
impl<P: KemParams> KeySizeUser for DecapsulationKey<P> {
    type KeySize = U64;
}
impl<P> KeyInit for DecapsulationKey<P>
where
    P: KemParams,
{
    /// Builds a decapsulation key from a borrowed 64-byte seed.
    #[inline]
    fn new(seed: &Seed) -> Self {
        let owned: Seed = *seed;
        Self::from_seed(owned)
    }
}
impl<P> KeyExport for DecapsulationKey<P>
where
    P: KemParams,
{
    /// Exports the 64-byte seed this key was derived from.
    ///
    /// # Panics
    ///
    /// Panics if the key carries no seed, i.e. it was loaded through the
    /// deprecated expanded encoding.
    fn to_bytes(&self) -> Seed {
        match self.to_seed() {
            Some(seed) => seed,
            None => panic!("should be initialized from a seed"),
        }
    }
}
/// Legacy encoding of a decapsulation key in its expanded form.
///
/// Deprecated in favour of seed-based keys (`DecapsulationKey::from_seed`).
#[deprecated(since = "0.3.0", note = "use `DecapsulationKey::from_seed` instead")]
pub trait ExpandedKeyEncoding: Sized {
    /// Size of the expanded encoding in bytes.
    type EncodedSize: ArraySize;
    /// Decodes a key from its expanded encoding.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidKey`] if the encoding is rejected.
    fn from_expanded_bytes(enc: &Array<u8, Self::EncodedSize>) -> Result<Self, InvalidKey>;
    /// Encodes the key in its expanded form.
    fn to_expanded_bytes(&self) -> Array<u8, Self::EncodedSize>;
}
// Implements a deprecated trait and calls the deprecated `from_expanded`.
#[allow(deprecated)]
impl<P> ExpandedKeyEncoding for DecapsulationKey<P>
where
    P: KemParams,
{
    type EncodedSize = DecapsulationKeySize<P>;

    /// Parses the expanded encoding by delegating to `from_expanded`.
    fn from_expanded_bytes(expanded: &ExpandedDecapsulationKey<P>) -> Result<Self, InvalidKey> {
        Self::from_expanded(expanded)
    }

    /// Produces the expanded encoding. It hands the encoded decryption key, the
    /// encoded encapsulation key, its hash `h` and `z` to `P::concat_dk`.
    fn to_expanded_bytes(&self) -> ExpandedDecapsulationKey<P> {
        P::concat_dk(
            self.dk_pke.to_bytes(),
            self.ek.to_bytes(),
            self.ek.h(),
            self.z.clone(),
        )
    }
}