// bc_components/mlkem/mlkem_private_key.rs

1use anyhow::{ anyhow, Result };
2use pqcrypto_mlkem::*;
3use pqcrypto_traits::kem::{ SecretKey, SharedSecret };
4use dcbor::prelude::*;
5
6use crate::{ tags, Decrypter, EncapsulationPrivateKey, SymmetricKey };
7
8use super::{ MLKEMCiphertext, MLKEM };
9
10/// A private key for the ML-KEM post-quantum key encapsulation mechanism.
11///
12/// `MLKEMPrivateKey` represents a private key that can be used to decapsulate shared secrets
13/// using the ML-KEM (Module Lattice-based Key Encapsulation Mechanism) post-quantum algorithm.
14/// It supports multiple security levels through the variants:
15///
16/// - `MLKEM512`: NIST security level 1 (roughly equivalent to AES-128), 1632 bytes
17/// - `MLKEM768`: NIST security level 3 (roughly equivalent to AES-192), 2400 bytes
18/// - `MLKEM1024`: NIST security level 5 (roughly equivalent to AES-256), 3168 bytes
19///
20/// # Security
21///
22/// ML-KEM private keys should be kept secure and never exposed. They provide
23/// resistance against attacks from both classical and quantum computers.
24///
25/// # Examples
26///
27/// ```
28/// use bc_components::MLKEM;
29///
30/// // Generate a keypair
31/// let (private_key, public_key) = MLKEM::MLKEM512.keypair();
32///
33/// // Party A encapsulates a shared secret using the public key
34/// let (shared_secret_a, ciphertext) = public_key.encapsulate_new_shared_secret();
35///
36/// // Party B decapsulates the shared secret using the private key and ciphertext
37/// let shared_secret_b = private_key.decapsulate_shared_secret(&ciphertext).unwrap();
38///
39/// // Both parties now have the same shared secret
40/// assert_eq!(shared_secret_a, shared_secret_b);
41/// ```
42#[derive(Clone, PartialEq)]
43pub enum MLKEMPrivateKey {
44    /// An ML-KEM-512 private key (NIST security level 1)
45    MLKEM512(Box<mlkem512::SecretKey>),
46    /// An ML-KEM-768 private key (NIST security level 3)
47    MLKEM768(Box<mlkem768::SecretKey>),
48    /// An ML-KEM-1024 private key (NIST security level 5)
49    MLKEM1024(Box<mlkem1024::SecretKey>),
50}
51
52impl Eq for MLKEMPrivateKey {}
53
54/// Implements hashing for ML-KEM private keys.
55impl std::hash::Hash for MLKEMPrivateKey {
56    /// Hashes the raw bytes of the private key.
57    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
58        match self {
59            MLKEMPrivateKey::MLKEM512(sk) => sk.as_bytes().hash(state),
60            MLKEMPrivateKey::MLKEM768(sk) => sk.as_bytes().hash(state),
61            MLKEMPrivateKey::MLKEM1024(sk) => sk.as_bytes().hash(state),
62        }
63    }
64}
65
66impl MLKEMPrivateKey {
67    /// Returns the security level of this ML-KEM private key.
68    pub fn level(&self) -> MLKEM {
69        match self {
70            MLKEMPrivateKey::MLKEM512(_) => MLKEM::MLKEM512,
71            MLKEMPrivateKey::MLKEM768(_) => MLKEM::MLKEM768,
72            MLKEMPrivateKey::MLKEM1024(_) => MLKEM::MLKEM1024,
73        }
74    }
75
76    /// Returns the size of this ML-KEM private key in bytes.
77    pub fn size(&self) -> usize {
78        self.level().private_key_size()
79    }
80
81    /// Returns the raw bytes of this ML-KEM private key.
82    pub fn as_bytes(&self) -> &[u8] {
83        match self {
84            MLKEMPrivateKey::MLKEM512(sk) => sk.as_ref().as_bytes(),
85            MLKEMPrivateKey::MLKEM768(sk) => sk.as_ref().as_bytes(),
86            MLKEMPrivateKey::MLKEM1024(sk) => sk.as_ref().as_bytes(),
87        }
88    }
89
90    /// Creates an ML-KEM private key from raw bytes and a security level.
91    ///
92    /// # Parameters
93    ///
94    /// * `level` - The security level of the key.
95    /// * `bytes` - The raw bytes of the key.
96    ///
97    /// # Returns
98    ///
99    /// An `MLKEMPrivateKey` if the bytes represent a valid key for the given level,
100    /// or an error otherwise.
101    ///
102    /// # Errors
103    ///
104    /// Returns an error if the bytes do not represent a valid ML-KEM private key
105    /// for the specified security level.
106    pub fn from_bytes(level: MLKEM, bytes: &[u8]) -> Result<Self> {
107        match level {
108            MLKEM::MLKEM512 =>
109                Ok(
110                    MLKEMPrivateKey::MLKEM512(
111                        Box::new(mlkem512::SecretKey::from_bytes(bytes).map_err(|e| anyhow!(e))?)
112                    )
113                ),
114            MLKEM::MLKEM768 =>
115                Ok(
116                    MLKEMPrivateKey::MLKEM768(
117                        Box::new(mlkem768::SecretKey::from_bytes(bytes).map_err(|e| anyhow!(e))?)
118                    )
119                ),
120            MLKEM::MLKEM1024 =>
121                Ok(
122                    MLKEMPrivateKey::MLKEM1024(
123                        Box::new(mlkem1024::SecretKey::from_bytes(bytes).map_err(|e| anyhow!(e))?)
124                    )
125                ),
126        }
127    }
128
129    /// Decapsulates a shared secret from a ciphertext using this private key.
130    ///
131    /// # Parameters
132    ///
133    /// * `ciphertext` - The ciphertext containing the encapsulated shared secret.
134    ///
135    /// # Returns
136    ///
137    /// A `SymmetricKey` containing the decapsulated shared secret, or an error
138    /// if decapsulation fails.
139    ///
140    /// # Errors
141    ///
142    /// Returns an error if the security level of the ciphertext doesn't match the
143    /// security level of this private key, or if decapsulation fails for any other reason.
144    ///
145    /// # Panics
146    ///
147    /// Panics if the security level of the ciphertext doesn't match the security
148    /// level of this private key.
149    pub fn decapsulate_shared_secret(&self, ciphertext: &MLKEMCiphertext) -> Result<SymmetricKey> {
150        match (self, ciphertext) {
151            (MLKEMPrivateKey::MLKEM512(sk), MLKEMCiphertext::MLKEM512(ct)) => {
152                let ss = mlkem512::decapsulate(ct.as_ref(), sk.as_ref());
153                SymmetricKey::from_data_ref(ss.as_bytes())
154            }
155            (MLKEMPrivateKey::MLKEM768(sk), MLKEMCiphertext::MLKEM768(ct)) => {
156                let ss = mlkem768::decapsulate(ct.as_ref(), sk.as_ref());
157                SymmetricKey::from_data_ref(ss.as_bytes())
158            }
159            (MLKEMPrivateKey::MLKEM1024(sk), MLKEMCiphertext::MLKEM1024(ct)) => {
160                let ss = mlkem1024::decapsulate(ct.as_ref(), sk.as_ref());
161                SymmetricKey::from_data_ref(ss.as_bytes())
162            }
163            _ => panic!("MLKEM level mismatch"),
164        }
165    }
166}
167
168/// Implements the `Decrypter` trait for ML-KEM private keys.
169impl Decrypter for MLKEMPrivateKey {
170    /// Returns this key as an `EncapsulationPrivateKey`.
171    fn encapsulation_private_key(&self) -> EncapsulationPrivateKey {
172        EncapsulationPrivateKey::MLKEM(self.clone())
173    }
174}
175
176/// Provides debug formatting for ML-KEM private keys.
177impl std::fmt::Debug for MLKEMPrivateKey {
178    /// Formats the private key as a string for debugging purposes.
179    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
180        match self {
181            MLKEMPrivateKey::MLKEM512(_) => f.write_str("MLKEM512PrivateKey"),
182            MLKEMPrivateKey::MLKEM768(_) => f.write_str("MLKEM768PrivateKey"),
183            MLKEMPrivateKey::MLKEM1024(_) => f.write_str("MLKEM1024PrivateKey"),
184        }
185    }
186}
187
188/// Defines CBOR tags for ML-KEM private keys.
189impl CBORTagged for MLKEMPrivateKey {
190    /// Returns the CBOR tag for ML-KEM private keys.
191    fn cbor_tags() -> Vec<Tag> {
192        tags_for_values(&[tags::TAG_MLKEM_PRIVATE_KEY])
193    }
194}
195
196/// Converts an `MLKEMPrivateKey` to CBOR.
197impl From<MLKEMPrivateKey> for CBOR {
198    /// Converts to tagged CBOR.
199    fn from(value: MLKEMPrivateKey) -> Self {
200        value.tagged_cbor()
201    }
202}
203
204/// Implements CBOR encoding for ML-KEM private keys.
205impl CBORTaggedEncodable for MLKEMPrivateKey {
206    /// Creates the untagged CBOR representation as an array with level and key bytes.
207    fn untagged_cbor(&self) -> CBOR {
208        vec![self.level().into(), CBOR::to_byte_string(self.as_bytes())].into()
209    }
210}
211
212/// Attempts to convert CBOR to an `MLKEMPrivateKey`.
213impl TryFrom<CBOR> for MLKEMPrivateKey {
214    type Error = dcbor::Error;
215
216    /// Converts from tagged CBOR.
217    fn try_from(cbor: CBOR) -> dcbor::Result<Self> {
218        Self::from_tagged_cbor(cbor)
219    }
220}
221
222/// Implements CBOR decoding for ML-KEM private keys.
223impl CBORTaggedDecodable for MLKEMPrivateKey {
224    /// Creates an `MLKEMPrivateKey` from untagged CBOR.
225    ///
226    /// # Errors
227    /// Returns an error if the CBOR value doesn't represent a valid ML-KEM private key.
228    fn from_untagged_cbor(untagged_cbor: CBOR) -> dcbor::Result<Self> {
229        match untagged_cbor.as_case() {
230            CBORCase::Array(elements) => {
231                if elements.len() != 2 {
232                    return Err("MLKEMPrivateKey must have two elements".into());
233                }
234
235                let level = MLKEM::try_from(elements[0].clone())?;
236                let data = CBOR::try_into_byte_string(elements[1].clone())?;
237                Ok(MLKEMPrivateKey::from_bytes(level, &data)?)
238            }
239            _ => {
240                return Err("MLKEMPrivateKey must be an array".into());
241            }
242        }
243    }
244}