// ml_kem/decapsulation_key.rs

use crate::{
    B32, EncapsulationKey, Seed, SharedKey,
    crypto::{G, J},
    param::{DecapsulationKeySize, ExpandedDecapsulationKey, KemParams},
    pke::{DecryptionKey, EncryptionKey},
};
use array::{
    Array, ArraySize,
    sizes::{U32, U64},
};
use core::fmt;
use kem::{
    Ciphertext, Decapsulate, Decapsulator, Generate, InvalidKey, Kem, KeyExport, KeyInit,
    KeySizeUser,
};
use rand_core::{TryCryptoRng, TryRng};
use subtle::{ConditionallySelectable, ConstantTimeEq};

#[cfg(feature = "zeroize")]
use zeroize::{Zeroize, ZeroizeOnDrop};

21/// A `DecapsulationKey` provides the ability to generate a new key pair, and decapsulate an
22/// encapsulated shared key.
23#[derive(Clone, Debug)]
24pub struct DecapsulationKey<P>
25where
26    P: KemParams,
27{
28    dk_pke: DecryptionKey<P>,
29    ek: EncapsulationKey<P>,
30    d: Option<B32>,
31    z: B32,
32}
33
34impl<P> DecapsulationKey<P>
35where
36    P: KemParams,
37{
38    /// Create a [`DecapsulationKey`] instance from a 64-byte random seed value.
39    #[inline]
40    #[must_use]
41    pub fn from_seed(seed: Seed) -> Self {
42        let (d, z) = seed.split();
43        Self::generate_deterministic(d, z)
44    }
45
46    /// Initialize a [`DecapsulationKey`] from the serialized expanded key form.
47    ///
48    /// Note that this form is deprecated in practice; prefer to use
49    /// [`DecapsulationKey::from_seed`]. See [`ExpandedKeyEncoding`] for more information.
50    ///
51    /// # Errors
52    /// - Returns [`InvalidKey`] in the event the expanded key failed validation
53    #[deprecated(since = "0.3.0", note = "use `DecapsulationKey::from_seed` instead")]
54    pub fn from_expanded(enc: &ExpandedDecapsulationKey<P>) -> Result<Self, InvalidKey> {
55        let (dk_pke, ek_pke, h, z) = P::split_dk(enc);
56        let dk_pke = DecryptionKey::from_bytes(dk_pke);
57        let ek_pke = EncryptionKey::from_bytes(ek_pke)?;
58
59        let ek = EncapsulationKey::from_encryption_key(ek_pke);
60        if ek.h() != *h {
61            return Err(InvalidKey);
62        }
63
64        Ok(Self {
65            dk_pke,
66            ek,
67            d: None,
68            z: z.clone(),
69        })
70    }
71
72    /// Serialize the [`Seed`] value: 64-bytes which can be used to reconstruct the
73    /// [`DecapsulationKey`].
74    ///
75    /// <div class="warning">
76    /// <b>Warning!</B>
77    ///
78    /// This value is key material. Please treat it with care.
79    /// </div>
80    ///
81    /// # Returns
82    /// - `Some` if the [`DecapsulationKey`] was initialized using `from_seed` or `generate`.
83    /// - `None` if the [`DecapsulationKey`] was initialized from the expanded form.
84    #[inline]
85    pub fn to_seed(&self) -> Option<Seed> {
86        self.d.map(|d| d.concat(self.z))
87    }
88
89    /// Get the [`EncapsulationKey`] which corresponds to this [`DecapsulationKey`].
90    pub fn encapsulation_key(&self) -> &EncapsulationKey<P> {
91        &self.ek
92    }
93
94    #[inline]
95    pub(crate) fn try_generate_from_rng<R>(rng: &mut R) -> Result<Self, <R as TryRng>::Error>
96    where
97        R: TryCryptoRng + ?Sized,
98    {
99        let d = B32::try_generate_from_rng(rng)?;
100        let z = B32::try_generate_from_rng(rng)?;
101        Ok(Self::generate_deterministic(d, z))
102    }
103
104    #[inline]
105    #[must_use]
106    #[allow(clippy::similar_names)] // allow dk_pke, ek_pke, following the spec
107    pub(crate) fn generate_deterministic(d: B32, z: B32) -> Self {
108        let (dk_pke, ek_pke) = DecryptionKey::generate(&d);
109        let ek = EncapsulationKey::from_encryption_key(ek_pke);
110        let d = Some(d);
111        Self { dk_pke, ek, d, z }
112    }
113}
114
115// Handwritten to omit `d` in the comparisons, so keys initialized from seeds compare equally to
116// keys initialized from the expanded form
117impl<P> PartialEq for DecapsulationKey<P>
118where
119    P: KemParams,
120{
121    fn eq(&self, other: &Self) -> bool {
122        self.dk_pke.ct_eq(&other.dk_pke).into() && self.ek.eq(&other.ek) && self.z.eq(&other.z)
123    }
124}
125
/// Wipes the secret components (`dk_pke`, the seed `d` and `z`) when the key is dropped.
///
/// The public encapsulation key is left as-is.
#[cfg(feature = "zeroize")]
impl<P: KemParams> Drop for DecapsulationKey<P> {
    fn drop(&mut self) {
        let Self { dk_pke, d, z, .. } = self;
        dk_pke.zeroize();
        d.zeroize();
        z.zeroize();
    }
}

/// Marks [`DecapsulationKey`] as wiping its secret fields when dropped (done by its `Drop` impl).
#[cfg(feature = "zeroize")]
impl<P: KemParams> ZeroizeOnDrop for DecapsulationKey<P> {}

141impl<P> From<Seed> for DecapsulationKey<P>
142where
143    P: KemParams,
144{
145    fn from(seed: Seed) -> Self {
146        Self::from_seed(seed)
147    }
148}
149
150impl<P> Decapsulate for DecapsulationKey<P>
151where
152    P: Kem<EncapsulationKey = EncapsulationKey<P>, SharedKeySize = U32> + KemParams,
153{
154    fn decapsulate(&self, encapsulated_key: &Ciphertext<P>) -> SharedKey {
155        let mp = self.dk_pke.decrypt(encapsulated_key);
156        let (Kp, rp) = G(&[&mp, &self.ek.h()]);
157        let Kbar = J(&[self.z.as_slice(), encapsulated_key.as_ref()]);
158        let cp = self.ek.ek_pke().encrypt(&mp, &rp);
159        B32::conditional_select(&Kbar, &Kp, cp.ct_eq(encapsulated_key))
160    }
161}
162
163impl<P> Decapsulator for DecapsulationKey<P>
164where
165    P: Kem<EncapsulationKey = EncapsulationKey<P>, SharedKeySize = U32> + KemParams,
166{
167    type Kem = P;
168
169    fn encapsulation_key(&self) -> &EncapsulationKey<P> {
170        &self.ek
171    }
172}
173
174impl<P> Generate for DecapsulationKey<P>
175where
176    P: KemParams,
177{
178    fn try_generate_from_rng<R>(rng: &mut R) -> Result<Self, <R as TryRng>::Error>
179    where
180        R: TryCryptoRng + ?Sized,
181    {
182        Self::try_generate_from_rng(rng)
183    }
184}
185
186impl<P> KeySizeUser for DecapsulationKey<P>
187where
188    P: KemParams,
189{
190    type KeySize = U64;
191}
192
193/// Initialize [`DecapsulationKey`] from a 64-byte uniformly random [`Seed`] value.
194impl<P> KeyInit for DecapsulationKey<P>
195where
196    P: KemParams,
197{
198    #[inline]
199    fn new(seed: &Seed) -> Self {
200        Self::from_seed(*seed)
201    }
202}
203
204/// Serialize the 64-byte [`Seed`] value used to initialize this [`DecapsulationKey`].
205///
206/// # Panics
207/// If this [`DecapsulationKey`] was initialized using legacy expanded key support
208/// (see [`ExpandedKeyEncoding`]).
209impl<P> KeyExport for DecapsulationKey<P>
210where
211    P: KemParams,
212{
213    fn to_bytes(&self) -> Seed {
214        self.to_seed().expect("should be initialized from a seed")
215    }
216}
217
218/// DEPRECATED: support for encoding and decoding [`DecapsulationKey`]s in the legacy expanded form,
219/// as opposed to the more widely adopted [`Seed`] form.
220///
221/// The expanded encoding format is problematic for several reasons, notably they need to validated
222/// whereas generation from seeds is always correct, meaning there is no performance advantage to
223/// using them, only additional complexity.
224///
225/// They are significantly larger than seeds (which are 64-bytes) and their sizes vary depending on
226/// security level whereas the size of a seed is constant:
227/// - ML-KEM-512: 1632 bytes
228/// - ML-KEM-768: 2400 bytes
229/// - ML-KEM-1024: 3168 bytes
230///
231/// Many ML-KEM libraries have dropped support for this format entirely.
232#[deprecated(since = "0.3.0", note = "use `DecapsulationKey::from_seed` instead")]
233pub trait ExpandedKeyEncoding: Sized {
234    /// The size of an expanded decapsulation key.
235    type EncodedSize: ArraySize;
236
237    /// Parse a [`DecapsulationKey`] from its legacy expanded form.
238    ///
239    /// # Errors
240    /// - If the key fails to validate successfully.
241    fn from_expanded_bytes(enc: &Array<u8, Self::EncodedSize>) -> Result<Self, InvalidKey>;
242
243    /// Serialize a [`DecapsulationKey`] to its legacy expanded form.
244    fn to_expanded_bytes(&self) -> Array<u8, Self::EncodedSize>;
245}
246
247#[allow(deprecated)]
248impl<P> ExpandedKeyEncoding for DecapsulationKey<P>
249where
250    P: KemParams,
251{
252    type EncodedSize = DecapsulationKeySize<P>;
253
254    fn from_expanded_bytes(expanded: &ExpandedDecapsulationKey<P>) -> Result<Self, InvalidKey> {
255        Self::from_expanded(expanded)
256    }
257
258    fn to_expanded_bytes(&self) -> ExpandedDecapsulationKey<P> {
259        let dk_pke = self.dk_pke.to_bytes();
260        let ek = self.ek.to_bytes();
261        P::concat_dk(dk_pke, ek, self.ek.h(), self.z.clone())
262    }
263}