// bc_components/mlkem/mlkem_public_key.rs

1use dcbor::prelude::*;
2use pqcrypto_mlkem::*;
3use pqcrypto_traits::kem::{PublicKey, SharedSecret};
4
5use super::{MLKEM, MLKEMCiphertext};
6use crate::{
7    Digest, Error, Reference, ReferenceProvider, Result, SymmetricKey, tags,
8};
9
10/// A public key for the ML-KEM post-quantum key encapsulation mechanism.
11///
12/// `MLKEMPublicKey` represents a public key that can be used to encapsulate
13/// shared secrets using the ML-KEM (Module Lattice-based Key Encapsulation
14/// Mechanism) post-quantum algorithm. It supports multiple security levels
15/// through the variants:
16///
17/// - `MLKEM512`: NIST security level 1 (roughly equivalent to AES-128), 800
18///   bytes
19/// - `MLKEM768`: NIST security level 3 (roughly equivalent to AES-192), 1184
20///   bytes
21/// - `MLKEM1024`: NIST security level 5 (roughly equivalent to AES-256), 1568
22///   bytes
23///
24/// # Examples
25///
26/// ```
27/// use bc_components::MLKEM;
28///
29/// // Generate a keypair
30/// let (private_key, public_key) = MLKEM::MLKEM512.keypair();
31///
32/// // Encapsulate a shared secret using the public key
33/// let (shared_secret, ciphertext) =
34///     public_key.encapsulate_new_shared_secret();
35/// ```
36#[derive(Clone)]
37pub enum MLKEMPublicKey {
38    /// An ML-KEM-512 public key (NIST security level 1)
39    MLKEM512(Box<mlkem512::PublicKey>),
40    /// An ML-KEM-768 public key (NIST security level 3)
41    MLKEM768(Box<mlkem768::PublicKey>),
42    /// An ML-KEM-1024 public key (NIST security level 5)
43    MLKEM1024(Box<mlkem1024::PublicKey>),
44}
45
46/// Implements equality comparison for ML-KEM public keys.
47impl PartialEq for MLKEMPublicKey {
48    /// Compares two ML-KEM public keys for equality.
49    ///
50    /// Two ML-KEM public keys are equal if they have the same security level
51    /// and the same raw byte representation.
52    fn eq(&self, other: &Self) -> bool {
53        self.level() == other.level() && self.as_bytes() == other.as_bytes()
54    }
55}
56
57impl Eq for MLKEMPublicKey {}
58
59/// Implements hashing for ML-KEM public keys.
60impl std::hash::Hash for MLKEMPublicKey {
61    /// Hashes both the security level and the raw bytes of the public key.
62    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
63        self.level().hash(state);
64        self.as_bytes().hash(state);
65    }
66}
67
68impl MLKEMPublicKey {
69    /// Returns the security level of this ML-KEM public key.
70    pub fn level(&self) -> MLKEM {
71        match self {
72            MLKEMPublicKey::MLKEM512(_) => MLKEM::MLKEM512,
73            MLKEMPublicKey::MLKEM768(_) => MLKEM::MLKEM768,
74            MLKEMPublicKey::MLKEM1024(_) => MLKEM::MLKEM1024,
75        }
76    }
77
78    /// Returns the size of this ML-KEM public key in bytes.
79    pub fn size(&self) -> usize { self.level().public_key_size() }
80
81    /// Returns the raw bytes of this ML-KEM public key.
82    pub fn as_bytes(&self) -> &[u8] { self.as_ref() }
83
84    /// Creates an ML-KEM public key from raw bytes and a security level.
85    ///
86    /// # Parameters
87    ///
88    /// * `level` - The security level of the key.
89    /// * `bytes` - The raw bytes of the key.
90    ///
91    /// # Returns
92    ///
93    /// An `MLKEMPublicKey` if the bytes represent a valid key for the given
94    /// level, or an error otherwise.
95    ///
96    /// # Errors
97    ///
98    /// Returns an error if the bytes do not represent a valid ML-KEM public key
99    /// for the specified security level.
100    pub fn from_bytes(level: MLKEM, bytes: &[u8]) -> Result<Self> {
101        match level {
102            MLKEM::MLKEM512 => Ok(MLKEMPublicKey::MLKEM512(Box::new(
103                mlkem512::PublicKey::from_bytes(bytes)
104                    .map_err(|e| Error::post_quantum(e.to_string()))?,
105            ))),
106            MLKEM::MLKEM768 => Ok(MLKEMPublicKey::MLKEM768(Box::new(
107                mlkem768::PublicKey::from_bytes(bytes)
108                    .map_err(|e| Error::post_quantum(e.to_string()))?,
109            ))),
110            MLKEM::MLKEM1024 => Ok(MLKEMPublicKey::MLKEM1024(Box::new(
111                mlkem1024::PublicKey::from_bytes(bytes)
112                    .map_err(|e| Error::post_quantum(e.to_string()))?,
113            ))),
114        }
115    }
116
117    /// Encapsulates a new shared secret using this public key.
118    ///
119    /// This method generates a random shared secret and encapsulates it using
120    /// this public key, producing a ciphertext that can only be decapsulated
121    /// by the corresponding private key.
122    ///
123    /// # Returns
124    ///
125    /// A tuple containing:
126    /// - A `SymmetricKey` with the shared secret (32 bytes)
127    /// - An `MLKEMCiphertext` with the encapsulated shared secret
128    ///
129    /// # Examples
130    ///
131    /// ```
132    /// use bc_components::MLKEM;
133    ///
134    /// // Generate a keypair
135    /// let (private_key, public_key) = MLKEM::MLKEM512.keypair();
136    ///
137    /// // Encapsulate a shared secret
138    /// let (shared_secret, ciphertext) =
139    ///     public_key.encapsulate_new_shared_secret();
140    ///
141    /// // The private key holder can decapsulate the same shared secret
142    /// let decapsulated_secret =
143    ///     private_key.decapsulate_shared_secret(&ciphertext).unwrap();
144    /// assert_eq!(shared_secret, decapsulated_secret);
145    /// ```
146    pub fn encapsulate_new_shared_secret(
147        &self,
148    ) -> (SymmetricKey, MLKEMCiphertext) {
149        match self {
150            MLKEMPublicKey::MLKEM512(pk) => {
151                let (ss, ct) = mlkem512::encapsulate(pk.as_ref());
152                (
153                    SymmetricKey::from_data_ref(ss.as_bytes()).unwrap(),
154                    MLKEMCiphertext::MLKEM512(ct.into()),
155                )
156            }
157            MLKEMPublicKey::MLKEM768(pk) => {
158                let (ss, ct) = mlkem768::encapsulate(pk.as_ref());
159                (
160                    SymmetricKey::from_data_ref(ss.as_bytes()).unwrap(),
161                    MLKEMCiphertext::MLKEM768(ct.into()),
162                )
163            }
164            MLKEMPublicKey::MLKEM1024(pk) => {
165                let (ss, ct) = mlkem1024::encapsulate(pk.as_ref());
166                (
167                    SymmetricKey::from_data_ref(ss.as_bytes()).unwrap(),
168                    MLKEMCiphertext::MLKEM1024(ct.into()),
169                )
170            }
171        }
172    }
173}
174
175impl AsRef<[u8]> for MLKEMPublicKey {
176    /// Returns the raw bytes of the public key.
177    fn as_ref(&self) -> &[u8] {
178        match self {
179            MLKEMPublicKey::MLKEM512(pk) => pk.as_ref().as_bytes(),
180            MLKEMPublicKey::MLKEM768(pk) => pk.as_ref().as_bytes(),
181            MLKEMPublicKey::MLKEM1024(pk) => pk.as_ref().as_bytes(),
182        }
183    }
184}
185
186/// Provides debug formatting for ML-KEM public keys.
187impl std::fmt::Debug for MLKEMPublicKey {
188    /// Formats the public key as a string for debugging purposes.
189    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190        match self {
191            MLKEMPublicKey::MLKEM512(_) => f.write_str("MLKEM512PublicKey"),
192            MLKEMPublicKey::MLKEM768(_) => f.write_str("MLKEM768PublicKey"),
193            MLKEMPublicKey::MLKEM1024(_) => f.write_str("MLKEM1024PublicKey"),
194        }
195    }
196}
197
198/// Defines CBOR tags for ML-KEM public keys.
199impl CBORTagged for MLKEMPublicKey {
200    /// Returns the CBOR tag for ML-KEM public keys.
201    fn cbor_tags() -> Vec<Tag> {
202        tags_for_values(&[tags::TAG_MLKEM_PUBLIC_KEY])
203    }
204}
205
206/// Converts an `MLKEMPublicKey` to CBOR.
207impl From<MLKEMPublicKey> for CBOR {
208    /// Converts to tagged CBOR.
209    fn from(value: MLKEMPublicKey) -> Self { value.tagged_cbor() }
210}
211
212/// Implements CBOR encoding for ML-KEM public keys.
213impl CBORTaggedEncodable for MLKEMPublicKey {
214    /// Creates the untagged CBOR representation as an array with level and key
215    /// bytes.
216    fn untagged_cbor(&self) -> CBOR {
217        vec![self.level().into(), CBOR::to_byte_string(self.as_bytes())].into()
218    }
219}
220
221/// Attempts to convert CBOR to an `MLKEMPublicKey`.
222impl TryFrom<CBOR> for MLKEMPublicKey {
223    type Error = dcbor::Error;
224
225    /// Converts from tagged CBOR.
226    fn try_from(cbor: CBOR) -> dcbor::Result<Self> {
227        Self::from_tagged_cbor(cbor)
228    }
229}
230
231/// Implements CBOR decoding for ML-KEM public keys.
232impl CBORTaggedDecodable for MLKEMPublicKey {
233    /// Creates an `MLKEMPublicKey` from untagged CBOR.
234    ///
235    /// # Errors
236    /// Returns an error if the CBOR value doesn't represent a valid ML-KEM
237    /// public key.
238    fn from_untagged_cbor(untagged_cbor: CBOR) -> dcbor::Result<Self> {
239        match untagged_cbor.as_case() {
240            CBORCase::Array(elements) => {
241                if elements.len() != 2 {
242                    return Err("MLKEMPublicKey must have two elements".into());
243                }
244
245                let level = MLKEM::try_from(elements[0].clone())?;
246                let data = CBOR::try_into_byte_string(elements[1].clone())?;
247                Ok(MLKEMPublicKey::from_bytes(level, &data)?)
248            }
249            _ => Err("MLKEMPublicKey must be an array".into()),
250        }
251    }
252}
253
254impl ReferenceProvider for MLKEMPublicKey {
255    fn reference(&self) -> Reference {
256        Reference::from_digest(Digest::from_image(
257            self.tagged_cbor().to_cbor_data(),
258        ))
259    }
260}
261
262impl std::fmt::Display for MLKEMPublicKey {
263    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
264        match self {
265            MLKEMPublicKey::MLKEM512(_) => {
266                write!(f, "MLKEM512PublicKey({})", self.ref_hex_short())
267            }
268            MLKEMPublicKey::MLKEM768(_) => {
269                write!(f, "MLKEM768PublicKey({})", self.ref_hex_short())
270            }
271            MLKEMPublicKey::MLKEM1024(_) => {
272                write!(f, "MLKEM1024PublicKey({})", self.ref_hex_short())
273            }
274        }
275    }
276}