bc_components/mlkem/
mlkem_public_key.rs

1use anyhow::{Result, anyhow};
2use dcbor::prelude::*;
3use pqcrypto_mlkem::*;
4use pqcrypto_traits::kem::{PublicKey, SharedSecret};
5
6use super::{MLKEM, MLKEMCiphertext};
7use crate::{SymmetricKey, tags};
8
9/// A public key for the ML-KEM post-quantum key encapsulation mechanism.
10///
11/// `MLKEMPublicKey` represents a public key that can be used to encapsulate
12/// shared secrets using the ML-KEM (Module Lattice-based Key Encapsulation
13/// Mechanism) post-quantum algorithm. It supports multiple security levels
14/// through the variants:
15///
16/// - `MLKEM512`: NIST security level 1 (roughly equivalent to AES-128), 800
17///   bytes
18/// - `MLKEM768`: NIST security level 3 (roughly equivalent to AES-192), 1184
19///   bytes
20/// - `MLKEM1024`: NIST security level 5 (roughly equivalent to AES-256), 1568
21///   bytes
22///
23/// # Examples
24///
25/// ```
26/// use bc_components::MLKEM;
27///
28/// // Generate a keypair
29/// let (private_key, public_key) = MLKEM::MLKEM512.keypair();
30///
31/// // Encapsulate a shared secret using the public key
32/// let (shared_secret, ciphertext) =
33///     public_key.encapsulate_new_shared_secret();
34/// ```
35#[derive(Clone)]
36pub enum MLKEMPublicKey {
37    /// An ML-KEM-512 public key (NIST security level 1)
38    MLKEM512(Box<mlkem512::PublicKey>),
39    /// An ML-KEM-768 public key (NIST security level 3)
40    MLKEM768(Box<mlkem768::PublicKey>),
41    /// An ML-KEM-1024 public key (NIST security level 5)
42    MLKEM1024(Box<mlkem1024::PublicKey>),
43}
44
45/// Implements equality comparison for ML-KEM public keys.
46impl PartialEq for MLKEMPublicKey {
47    /// Compares two ML-KEM public keys for equality.
48    ///
49    /// Two ML-KEM public keys are equal if they have the same security level
50    /// and the same raw byte representation.
51    fn eq(&self, other: &Self) -> bool {
52        self.level() == other.level() && self.as_bytes() == other.as_bytes()
53    }
54}
55
56impl Eq for MLKEMPublicKey {}
57
58/// Implements hashing for ML-KEM public keys.
59impl std::hash::Hash for MLKEMPublicKey {
60    /// Hashes both the security level and the raw bytes of the public key.
61    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
62        self.level().hash(state);
63        self.as_bytes().hash(state);
64    }
65}
66
67impl MLKEMPublicKey {
68    /// Returns the security level of this ML-KEM public key.
69    pub fn level(&self) -> MLKEM {
70        match self {
71            MLKEMPublicKey::MLKEM512(_) => MLKEM::MLKEM512,
72            MLKEMPublicKey::MLKEM768(_) => MLKEM::MLKEM768,
73            MLKEMPublicKey::MLKEM1024(_) => MLKEM::MLKEM1024,
74        }
75    }
76
77    /// Returns the size of this ML-KEM public key in bytes.
78    pub fn size(&self) -> usize { self.level().public_key_size() }
79
80    /// Returns the raw bytes of this ML-KEM public key.
81    pub fn as_bytes(&self) -> &[u8] { self.as_ref() }
82
83    /// Creates an ML-KEM public key from raw bytes and a security level.
84    ///
85    /// # Parameters
86    ///
87    /// * `level` - The security level of the key.
88    /// * `bytes` - The raw bytes of the key.
89    ///
90    /// # Returns
91    ///
92    /// An `MLKEMPublicKey` if the bytes represent a valid key for the given
93    /// level, or an error otherwise.
94    ///
95    /// # Errors
96    ///
97    /// Returns an error if the bytes do not represent a valid ML-KEM public key
98    /// for the specified security level.
99    pub fn from_bytes(level: MLKEM, bytes: &[u8]) -> Result<Self> {
100        match level {
101            MLKEM::MLKEM512 => Ok(MLKEMPublicKey::MLKEM512(Box::new(
102                mlkem512::PublicKey::from_bytes(bytes)
103                    .map_err(|e| anyhow!(e))?,
104            ))),
105            MLKEM::MLKEM768 => Ok(MLKEMPublicKey::MLKEM768(Box::new(
106                mlkem768::PublicKey::from_bytes(bytes)
107                    .map_err(|e| anyhow!(e))?,
108            ))),
109            MLKEM::MLKEM1024 => Ok(MLKEMPublicKey::MLKEM1024(Box::new(
110                mlkem1024::PublicKey::from_bytes(bytes)
111                    .map_err(|e| anyhow!(e))?,
112            ))),
113        }
114    }
115
116    /// Encapsulates a new shared secret using this public key.
117    ///
118    /// This method generates a random shared secret and encapsulates it using
119    /// this public key, producing a ciphertext that can only be decapsulated
120    /// by the corresponding private key.
121    ///
122    /// # Returns
123    ///
124    /// A tuple containing:
125    /// - A `SymmetricKey` with the shared secret (32 bytes)
126    /// - An `MLKEMCiphertext` with the encapsulated shared secret
127    ///
128    /// # Examples
129    ///
130    /// ```
131    /// use bc_components::MLKEM;
132    ///
133    /// // Generate a keypair
134    /// let (private_key, public_key) = MLKEM::MLKEM512.keypair();
135    ///
136    /// // Encapsulate a shared secret
137    /// let (shared_secret, ciphertext) =
138    ///     public_key.encapsulate_new_shared_secret();
139    ///
140    /// // The private key holder can decapsulate the same shared secret
141    /// let decapsulated_secret =
142    ///     private_key.decapsulate_shared_secret(&ciphertext).unwrap();
143    /// assert_eq!(shared_secret, decapsulated_secret);
144    /// ```
145    pub fn encapsulate_new_shared_secret(
146        &self,
147    ) -> (SymmetricKey, MLKEMCiphertext) {
148        match self {
149            MLKEMPublicKey::MLKEM512(pk) => {
150                let (ss, ct) = mlkem512::encapsulate(pk.as_ref());
151                (
152                    SymmetricKey::from_data_ref(ss.as_bytes()).unwrap(),
153                    MLKEMCiphertext::MLKEM512(ct.into()),
154                )
155            }
156            MLKEMPublicKey::MLKEM768(pk) => {
157                let (ss, ct) = mlkem768::encapsulate(pk.as_ref());
158                (
159                    SymmetricKey::from_data_ref(ss.as_bytes()).unwrap(),
160                    MLKEMCiphertext::MLKEM768(ct.into()),
161                )
162            }
163            MLKEMPublicKey::MLKEM1024(pk) => {
164                let (ss, ct) = mlkem1024::encapsulate(pk.as_ref());
165                (
166                    SymmetricKey::from_data_ref(ss.as_bytes()).unwrap(),
167                    MLKEMCiphertext::MLKEM1024(ct.into()),
168                )
169            }
170        }
171    }
172}
173
174impl AsRef<[u8]> for MLKEMPublicKey {
175    /// Returns the raw bytes of the public key.
176    fn as_ref(&self) -> &[u8] {
177        match self {
178            MLKEMPublicKey::MLKEM512(pk) => pk.as_ref().as_bytes(),
179            MLKEMPublicKey::MLKEM768(pk) => pk.as_ref().as_bytes(),
180            MLKEMPublicKey::MLKEM1024(pk) => pk.as_ref().as_bytes(),
181        }
182    }
183}
184
185/// Provides debug formatting for ML-KEM public keys.
186impl std::fmt::Debug for MLKEMPublicKey {
187    /// Formats the public key as a string for debugging purposes.
188    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
189        match self {
190            MLKEMPublicKey::MLKEM512(_) => f.write_str("MLKEM512PublicKey"),
191            MLKEMPublicKey::MLKEM768(_) => f.write_str("MLKEM768PublicKey"),
192            MLKEMPublicKey::MLKEM1024(_) => f.write_str("MLKEM1024PublicKey"),
193        }
194    }
195}
196
197/// Defines CBOR tags for ML-KEM public keys.
198impl CBORTagged for MLKEMPublicKey {
199    /// Returns the CBOR tag for ML-KEM public keys.
200    fn cbor_tags() -> Vec<Tag> {
201        tags_for_values(&[tags::TAG_MLKEM_PUBLIC_KEY])
202    }
203}
204
205/// Converts an `MLKEMPublicKey` to CBOR.
206impl From<MLKEMPublicKey> for CBOR {
207    /// Converts to tagged CBOR.
208    fn from(value: MLKEMPublicKey) -> Self { value.tagged_cbor() }
209}
210
211/// Implements CBOR encoding for ML-KEM public keys.
212impl CBORTaggedEncodable for MLKEMPublicKey {
213    /// Creates the untagged CBOR representation as an array with level and key
214    /// bytes.
215    fn untagged_cbor(&self) -> CBOR {
216        vec![self.level().into(), CBOR::to_byte_string(self.as_bytes())].into()
217    }
218}
219
220/// Attempts to convert CBOR to an `MLKEMPublicKey`.
221impl TryFrom<CBOR> for MLKEMPublicKey {
222    type Error = dcbor::Error;
223
224    /// Converts from tagged CBOR.
225    fn try_from(cbor: CBOR) -> dcbor::Result<Self> {
226        Self::from_tagged_cbor(cbor)
227    }
228}
229
230/// Implements CBOR decoding for ML-KEM public keys.
231impl CBORTaggedDecodable for MLKEMPublicKey {
232    /// Creates an `MLKEMPublicKey` from untagged CBOR.
233    ///
234    /// # Errors
235    /// Returns an error if the CBOR value doesn't represent a valid ML-KEM
236    /// public key.
237    fn from_untagged_cbor(untagged_cbor: CBOR) -> dcbor::Result<Self> {
238        match untagged_cbor.as_case() {
239            CBORCase::Array(elements) => {
240                if elements.len() != 2 {
241                    return Err("MLKEMPublicKey must have two elements".into());
242                }
243
244                let level = MLKEM::try_from(elements[0].clone())?;
245                let data = CBOR::try_into_byte_string(elements[1].clone())?;
246                Ok(MLKEMPublicKey::from_bytes(level, &data)?)
247            }
248            _ => Err("MLKEMPublicKey must be an array".into()),
249        }
250    }
251}