bc_components/mlkem/
mlkem_private_key.rs

1use anyhow::{Result, anyhow};
2use dcbor::prelude::*;
3use pqcrypto_mlkem::*;
4use pqcrypto_traits::kem::{SecretKey, SharedSecret};
5
6use super::{MLKEM, MLKEMCiphertext};
7use crate::{Decrypter, EncapsulationPrivateKey, SymmetricKey, tags};
8
9/// A private key for the ML-KEM post-quantum key encapsulation mechanism.
10///
11/// `MLKEMPrivateKey` represents a private key that can be used to decapsulate
12/// shared secrets using the ML-KEM (Module Lattice-based Key Encapsulation
13/// Mechanism) post-quantum algorithm. It supports multiple security levels
14/// through the variants:
15///
16/// - `MLKEM512`: NIST security level 1 (roughly equivalent to AES-128), 1632
17///   bytes
18/// - `MLKEM768`: NIST security level 3 (roughly equivalent to AES-192), 2400
19///   bytes
20/// - `MLKEM1024`: NIST security level 5 (roughly equivalent to AES-256), 3168
21///   bytes
22///
23/// # Security
24///
25/// ML-KEM private keys should be kept secure and never exposed. They provide
26/// resistance against attacks from both classical and quantum computers.
27///
28/// # Examples
29///
30/// ```
31/// use bc_components::MLKEM;
32///
33/// // Generate a keypair
34/// let (private_key, public_key) = MLKEM::MLKEM512.keypair();
35///
36/// // Party A encapsulates a shared secret using the public key
37/// let (shared_secret_a, ciphertext) = public_key.encapsulate_new_shared_secret();
38///
39/// // Party B decapsulates the shared secret using the private key and ciphertext
40/// let shared_secret_b = private_key.decapsulate_shared_secret(&ciphertext).unwrap();
41///
42/// // Both parties now have the same shared secret
43/// assert_eq!(shared_secret_a, shared_secret_b);
44/// ```
45#[derive(Clone, PartialEq)]
46pub enum MLKEMPrivateKey {
47    /// An ML-KEM-512 private key (NIST security level 1)
48    MLKEM512(Box<mlkem512::SecretKey>),
49    /// An ML-KEM-768 private key (NIST security level 3)
50    MLKEM768(Box<mlkem768::SecretKey>),
51    /// An ML-KEM-1024 private key (NIST security level 5)
52    MLKEM1024(Box<mlkem1024::SecretKey>),
53}
54
55impl Eq for MLKEMPrivateKey {}
56
57/// Implements hashing for ML-KEM private keys.
58impl std::hash::Hash for MLKEMPrivateKey {
59    /// Hashes the raw bytes of the private key.
60    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
61        match self {
62            MLKEMPrivateKey::MLKEM512(sk) => sk.as_bytes().hash(state),
63            MLKEMPrivateKey::MLKEM768(sk) => sk.as_bytes().hash(state),
64            MLKEMPrivateKey::MLKEM1024(sk) => sk.as_bytes().hash(state),
65        }
66    }
67}
68
69impl MLKEMPrivateKey {
70    /// Returns the security level of this ML-KEM private key.
71    pub fn level(&self) -> MLKEM {
72        match self {
73            MLKEMPrivateKey::MLKEM512(_) => MLKEM::MLKEM512,
74            MLKEMPrivateKey::MLKEM768(_) => MLKEM::MLKEM768,
75            MLKEMPrivateKey::MLKEM1024(_) => MLKEM::MLKEM1024,
76        }
77    }
78
79    /// Returns the size of this ML-KEM private key in bytes.
80    pub fn size(&self) -> usize { self.level().private_key_size() }
81
82    /// Returns the raw bytes of this ML-KEM private key.
83    pub fn as_bytes(&self) -> &[u8] { self.as_ref() }
84
85    /// Creates an ML-KEM private key from raw bytes and a security level.
86    ///
87    /// # Parameters
88    ///
89    /// * `level` - The security level of the key.
90    /// * `bytes` - The raw bytes of the key.
91    ///
92    /// # Returns
93    ///
94    /// An `MLKEMPrivateKey` if the bytes represent a valid key for the given
95    /// level, or an error otherwise.
96    ///
97    /// # Errors
98    ///
99    /// Returns an error if the bytes do not represent a valid ML-KEM private
100    /// key for the specified security level.
101    pub fn from_bytes(level: MLKEM, bytes: &[u8]) -> Result<Self> {
102        match level {
103            MLKEM::MLKEM512 => Ok(MLKEMPrivateKey::MLKEM512(Box::new(
104                mlkem512::SecretKey::from_bytes(bytes)
105                    .map_err(|e| anyhow!(e))?,
106            ))),
107            MLKEM::MLKEM768 => Ok(MLKEMPrivateKey::MLKEM768(Box::new(
108                mlkem768::SecretKey::from_bytes(bytes)
109                    .map_err(|e| anyhow!(e))?,
110            ))),
111            MLKEM::MLKEM1024 => Ok(MLKEMPrivateKey::MLKEM1024(Box::new(
112                mlkem1024::SecretKey::from_bytes(bytes)
113                    .map_err(|e| anyhow!(e))?,
114            ))),
115        }
116    }
117
118    /// Decapsulates a shared secret from a ciphertext using this private key.
119    ///
120    /// # Parameters
121    ///
122    /// * `ciphertext` - The ciphertext containing the encapsulated shared
123    ///   secret.
124    ///
125    /// # Returns
126    ///
127    /// A `SymmetricKey` containing the decapsulated shared secret, or an error
128    /// if decapsulation fails.
129    ///
130    /// # Errors
131    ///
132    /// Returns an error if the security level of the ciphertext doesn't match
133    /// the security level of this private key, or if decapsulation fails
134    /// for any other reason.
135    ///
136    /// # Panics
137    ///
138    /// Panics if the security level of the ciphertext doesn't match the
139    /// security level of this private key.
140    pub fn decapsulate_shared_secret(
141        &self,
142        ciphertext: &MLKEMCiphertext,
143    ) -> Result<SymmetricKey> {
144        match (self, ciphertext) {
145            (MLKEMPrivateKey::MLKEM512(sk), MLKEMCiphertext::MLKEM512(ct)) => {
146                let ss = mlkem512::decapsulate(ct.as_ref(), sk.as_ref());
147                SymmetricKey::from_data_ref(ss.as_bytes())
148            }
149            (MLKEMPrivateKey::MLKEM768(sk), MLKEMCiphertext::MLKEM768(ct)) => {
150                let ss = mlkem768::decapsulate(ct.as_ref(), sk.as_ref());
151                SymmetricKey::from_data_ref(ss.as_bytes())
152            }
153            (
154                MLKEMPrivateKey::MLKEM1024(sk),
155                MLKEMCiphertext::MLKEM1024(ct),
156            ) => {
157                let ss = mlkem1024::decapsulate(ct.as_ref(), sk.as_ref());
158                SymmetricKey::from_data_ref(ss.as_bytes())
159            }
160            _ => panic!("MLKEM level mismatch"),
161        }
162    }
163}
164
165impl AsRef<[u8]> for MLKEMPrivateKey {
166    /// Returns the raw bytes of the private key.
167    fn as_ref(&self) -> &[u8] {
168        match self {
169            MLKEMPrivateKey::MLKEM512(sk) => sk.as_ref().as_bytes(),
170            MLKEMPrivateKey::MLKEM768(sk) => sk.as_ref().as_bytes(),
171            MLKEMPrivateKey::MLKEM1024(sk) => sk.as_ref().as_bytes(),
172        }
173    }
174}
175
176/// Implements the `Decrypter` trait for ML-KEM private keys.
177impl Decrypter for MLKEMPrivateKey {
178    /// Returns this key as an `EncapsulationPrivateKey`.
179    fn encapsulation_private_key(&self) -> EncapsulationPrivateKey {
180        EncapsulationPrivateKey::MLKEM(self.clone())
181    }
182}
183
184/// Provides debug formatting for ML-KEM private keys.
185impl std::fmt::Debug for MLKEMPrivateKey {
186    /// Formats the private key as a string for debugging purposes.
187    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
188        match self {
189            MLKEMPrivateKey::MLKEM512(_) => f.write_str("MLKEM512PrivateKey"),
190            MLKEMPrivateKey::MLKEM768(_) => f.write_str("MLKEM768PrivateKey"),
191            MLKEMPrivateKey::MLKEM1024(_) => f.write_str("MLKEM1024PrivateKey"),
192        }
193    }
194}
195
196/// Defines CBOR tags for ML-KEM private keys.
197impl CBORTagged for MLKEMPrivateKey {
198    /// Returns the CBOR tag for ML-KEM private keys.
199    fn cbor_tags() -> Vec<Tag> {
200        tags_for_values(&[tags::TAG_MLKEM_PRIVATE_KEY])
201    }
202}
203
204/// Converts an `MLKEMPrivateKey` to CBOR.
205impl From<MLKEMPrivateKey> for CBOR {
206    /// Converts to tagged CBOR.
207    fn from(value: MLKEMPrivateKey) -> Self { value.tagged_cbor() }
208}
209
210/// Implements CBOR encoding for ML-KEM private keys.
211impl CBORTaggedEncodable for MLKEMPrivateKey {
212    /// Creates the untagged CBOR representation as an array with level and key
213    /// bytes.
214    fn untagged_cbor(&self) -> CBOR {
215        vec![self.level().into(), CBOR::to_byte_string(self.as_bytes())].into()
216    }
217}
218
219/// Attempts to convert CBOR to an `MLKEMPrivateKey`.
220impl TryFrom<CBOR> for MLKEMPrivateKey {
221    type Error = dcbor::Error;
222
223    /// Converts from tagged CBOR.
224    fn try_from(cbor: CBOR) -> dcbor::Result<Self> {
225        Self::from_tagged_cbor(cbor)
226    }
227}
228
229/// Implements CBOR decoding for ML-KEM private keys.
230impl CBORTaggedDecodable for MLKEMPrivateKey {
231    /// Creates an `MLKEMPrivateKey` from untagged CBOR.
232    ///
233    /// # Errors
234    /// Returns an error if the CBOR value doesn't represent a valid ML-KEM
235    /// private key.
236    fn from_untagged_cbor(untagged_cbor: CBOR) -> dcbor::Result<Self> {
237        match untagged_cbor.as_case() {
238            CBORCase::Array(elements) => {
239                if elements.len() != 2 {
240                    return Err("MLKEMPrivateKey must have two elements".into());
241                }
242
243                let level = MLKEM::try_from(elements[0].clone())?;
244                let data = CBOR::try_into_byte_string(elements[1].clone())?;
245                Ok(MLKEMPrivateKey::from_bytes(level, &data)?)
246            }
247            _ => Err("MLKEMPrivateKey must be an array".into()),
248        }
249    }
250}