ml_kem/
decapsulation_key.rs1use crate::{
2 B32, EncapsulationKey, Seed, SharedKey,
3 crypto::{G, J},
4 param::{DecapsulationKeySize, ExpandedDecapsulationKey, KemParams},
5 pke::{DecryptionKey, EncryptionKey},
6};
7use array::{
8 Array, ArraySize,
9 sizes::{U32, U64},
10};
11use kem::{
12 Ciphertext, Decapsulate, Decapsulator, Generate, InvalidKey, Kem, KeyExport, KeyInit,
13 KeySizeUser,
14};
15use rand_core::{TryCryptoRng, TryRng};
16use subtle::{ConditionallySelectable, ConstantTimeEq};
17
18#[cfg(feature = "zeroize")]
19use zeroize::{Zeroize, ZeroizeOnDrop};
20
21#[derive(Clone, Debug)]
24pub struct DecapsulationKey<P>
25where
26 P: KemParams,
27{
28 dk_pke: DecryptionKey<P>,
29 ek: EncapsulationKey<P>,
30 d: Option<B32>,
31 z: B32,
32}
33
34impl<P> DecapsulationKey<P>
35where
36 P: KemParams,
37{
38 #[inline]
40 #[must_use]
41 pub fn from_seed(seed: Seed) -> Self {
42 let (d, z) = seed.split();
43 Self::generate_deterministic(d, z)
44 }
45
46 #[deprecated(since = "0.3.0", note = "use `DecapsulationKey::from_seed` instead")]
54 pub fn from_expanded(enc: &ExpandedDecapsulationKey<P>) -> Result<Self, InvalidKey> {
55 let (dk_pke, ek_pke, h, z) = P::split_dk(enc);
56 let dk_pke = DecryptionKey::from_bytes(dk_pke);
57 let ek_pke = EncryptionKey::from_bytes(ek_pke)?;
58
59 let ek = EncapsulationKey::from_encryption_key(ek_pke);
60 if ek.h() != *h {
61 return Err(InvalidKey);
62 }
63
64 Ok(Self {
65 dk_pke,
66 ek,
67 d: None,
68 z: z.clone(),
69 })
70 }
71
72 #[inline]
85 pub fn to_seed(&self) -> Option<Seed> {
86 self.d.map(|d| d.concat(self.z))
87 }
88
89 pub fn encapsulation_key(&self) -> &EncapsulationKey<P> {
91 &self.ek
92 }
93
94 #[inline]
95 pub(crate) fn try_generate_from_rng<R>(rng: &mut R) -> Result<Self, <R as TryRng>::Error>
96 where
97 R: TryCryptoRng + ?Sized,
98 {
99 let d = B32::try_generate_from_rng(rng)?;
100 let z = B32::try_generate_from_rng(rng)?;
101 Ok(Self::generate_deterministic(d, z))
102 }
103
104 #[inline]
105 #[must_use]
106 #[allow(clippy::similar_names)] pub(crate) fn generate_deterministic(d: B32, z: B32) -> Self {
108 let (dk_pke, ek_pke) = DecryptionKey::generate(&d);
109 let ek = EncapsulationKey::from_encryption_key(ek_pke);
110 let d = Some(d);
111 Self { dk_pke, ek, d, z }
112 }
113}
114
115impl<P> PartialEq for DecapsulationKey<P>
118where
119 P: KemParams,
120{
121 fn eq(&self, other: &Self) -> bool {
122 self.dk_pke.ct_eq(&other.dk_pke).into() && self.ek.eq(&other.ek) && self.z.eq(&other.z)
123 }
124}
125
/// Wipes all secret components of the key when it goes out of scope.
#[cfg(feature = "zeroize")]
impl<P> Drop for DecapsulationKey<P>
where
    P: KemParams,
{
    fn drop(&mut self) {
        // `ek` is public and is left alone; every secret field is wiped in place.
        let Self { dk_pke, d, z, .. } = self;
        dk_pke.zeroize();
        d.zeroize();
        z.zeroize();
    }
}
137
/// Marker: the `Drop` impl of `DecapsulationKey` wipes its secret fields when the
/// `zeroize` feature is enabled.
#[cfg(feature = "zeroize")]
impl<P> ZeroizeOnDrop for DecapsulationKey<P>
where
    P: KemParams,
{
}
140
141impl<P> From<Seed> for DecapsulationKey<P>
142where
143 P: KemParams,
144{
145 fn from(seed: Seed) -> Self {
146 Self::from_seed(seed)
147 }
148}
149
150impl<P> Decapsulate for DecapsulationKey<P>
151where
152 P: Kem<EncapsulationKey = EncapsulationKey<P>, SharedKeySize = U32> + KemParams,
153{
154 fn decapsulate(&self, encapsulated_key: &Ciphertext<P>) -> SharedKey {
155 let mp = self.dk_pke.decrypt(encapsulated_key);
156 let (Kp, rp) = G(&[&mp, &self.ek.h()]);
157 let Kbar = J(&[self.z.as_slice(), encapsulated_key.as_ref()]);
158 let cp = self.ek.ek_pke().encrypt(&mp, &rp);
159 B32::conditional_select(&Kbar, &Kp, cp.ct_eq(encapsulated_key))
160 }
161}
162
163impl<P> Decapsulator for DecapsulationKey<P>
164where
165 P: Kem<EncapsulationKey = EncapsulationKey<P>, SharedKeySize = U32> + KemParams,
166{
167 type Kem = P;
168
169 fn encapsulation_key(&self) -> &EncapsulationKey<P> {
170 &self.ek
171 }
172}
173
174impl<P> Generate for DecapsulationKey<P>
175where
176 P: KemParams,
177{
178 fn try_generate_from_rng<R>(rng: &mut R) -> Result<Self, <R as TryRng>::Error>
179 where
180 R: TryCryptoRng + ?Sized,
181 {
182 Self::try_generate_from_rng(rng)
183 }
184}
185
186impl<P> KeySizeUser for DecapsulationKey<P>
187where
188 P: KemParams,
189{
190 type KeySize = U64;
191}
192
193impl<P> KeyInit for DecapsulationKey<P>
195where
196 P: KemParams,
197{
198 #[inline]
199 fn new(seed: &Seed) -> Self {
200 Self::from_seed(*seed)
201 }
202}
203
204impl<P> KeyExport for DecapsulationKey<P>
210where
211 P: KemParams,
212{
213 fn to_bytes(&self) -> Seed {
214 self.to_seed().expect("should be initialized from a seed")
215 }
216}
217
218#[deprecated(since = "0.3.0", note = "use `DecapsulationKey::from_seed` instead")]
233pub trait ExpandedKeyEncoding: Sized {
234 type EncodedSize: ArraySize;
236
237 fn from_expanded_bytes(enc: &Array<u8, Self::EncodedSize>) -> Result<Self, InvalidKey>;
242
243 fn to_expanded_bytes(&self) -> Array<u8, Self::EncodedSize>;
245}
246
247#[allow(deprecated)]
248impl<P> ExpandedKeyEncoding for DecapsulationKey<P>
249where
250 P: KemParams,
251{
252 type EncodedSize = DecapsulationKeySize<P>;
253
254 fn from_expanded_bytes(expanded: &ExpandedDecapsulationKey<P>) -> Result<Self, InvalidKey> {
255 Self::from_expanded(expanded)
256 }
257
258 fn to_expanded_bytes(&self) -> ExpandedDecapsulationKey<P> {
259 let dk_pke = self.dk_pke.to_bytes();
260 let ek = self.ek.to_bytes();
261 P::concat_dk(dk_pke, ek, self.ek.h(), self.z.clone())
262 }
263}