bc_components/mlkem/mlkem_public_key.rs
1use anyhow::{ anyhow, Result };
2use pqcrypto_mlkem::*;
3use pqcrypto_traits::kem::{ PublicKey, SharedSecret };
4use dcbor::prelude::*;
5
6use crate::{ tags, SymmetricKey };
7
8use super::{ MLKEMCiphertext, MLKEM };
9
10/// A public key for the ML-KEM post-quantum key encapsulation mechanism.
11///
12/// `MLKEMPublicKey` represents a public key that can be used to encapsulate shared secrets
13/// using the ML-KEM (Module Lattice-based Key Encapsulation Mechanism) post-quantum algorithm.
14/// It supports multiple security levels through the variants:
15///
16/// - `MLKEM512`: NIST security level 1 (roughly equivalent to AES-128), 800 bytes
17/// - `MLKEM768`: NIST security level 3 (roughly equivalent to AES-192), 1184 bytes
18/// - `MLKEM1024`: NIST security level 5 (roughly equivalent to AES-256), 1568 bytes
19///
20/// # Examples
21///
22/// ```
23/// use bc_components::MLKEM;
24///
25/// // Generate a keypair
26/// let (private_key, public_key) = MLKEM::MLKEM512.keypair();
27///
28/// // Encapsulate a shared secret using the public key
29/// let (shared_secret, ciphertext) = public_key.encapsulate_new_shared_secret();
30/// ```
31#[derive(Clone)]
32pub enum MLKEMPublicKey {
33 /// An ML-KEM-512 public key (NIST security level 1)
34 MLKEM512(Box<mlkem512::PublicKey>),
35 /// An ML-KEM-768 public key (NIST security level 3)
36 MLKEM768(Box<mlkem768::PublicKey>),
37 /// An ML-KEM-1024 public key (NIST security level 5)
38 MLKEM1024(Box<mlkem1024::PublicKey>),
39}
40
41/// Implements equality comparison for ML-KEM public keys.
42impl PartialEq for MLKEMPublicKey {
43 /// Compares two ML-KEM public keys for equality.
44 ///
45 /// Two ML-KEM public keys are equal if they have the same security level
46 /// and the same raw byte representation.
47 fn eq(&self, other: &Self) -> bool {
48 self.level() == other.level() && self.as_bytes() == other.as_bytes()
49 }
50}
51
52impl Eq for MLKEMPublicKey {}
53
54/// Implements hashing for ML-KEM public keys.
55impl std::hash::Hash for MLKEMPublicKey {
56 /// Hashes both the security level and the raw bytes of the public key.
57 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
58 self.level().hash(state);
59 self.as_bytes().hash(state);
60 }
61}
62
63impl MLKEMPublicKey {
64 /// Returns the security level of this ML-KEM public key.
65 pub fn level(&self) -> MLKEM {
66 match self {
67 MLKEMPublicKey::MLKEM512(_) => MLKEM::MLKEM512,
68 MLKEMPublicKey::MLKEM768(_) => MLKEM::MLKEM768,
69 MLKEMPublicKey::MLKEM1024(_) => MLKEM::MLKEM1024,
70 }
71 }
72
73 /// Returns the size of this ML-KEM public key in bytes.
74 pub fn size(&self) -> usize {
75 self.level().public_key_size()
76 }
77
78 /// Returns the raw bytes of this ML-KEM public key.
79 pub fn as_bytes(&self) -> &[u8] {
80 match self {
81 MLKEMPublicKey::MLKEM512(pk) => pk.as_ref().as_bytes(),
82 MLKEMPublicKey::MLKEM768(pk) => pk.as_ref().as_bytes(),
83 MLKEMPublicKey::MLKEM1024(pk) => pk.as_ref().as_bytes(),
84 }
85 }
86
87 /// Creates an ML-KEM public key from raw bytes and a security level.
88 ///
89 /// # Parameters
90 ///
91 /// * `level` - The security level of the key.
92 /// * `bytes` - The raw bytes of the key.
93 ///
94 /// # Returns
95 ///
96 /// An `MLKEMPublicKey` if the bytes represent a valid key for the given level,
97 /// or an error otherwise.
98 ///
99 /// # Errors
100 ///
101 /// Returns an error if the bytes do not represent a valid ML-KEM public key
102 /// for the specified security level.
103 pub fn from_bytes(level: MLKEM, bytes: &[u8]) -> Result<Self> {
104 match level {
105 MLKEM::MLKEM512 =>
106 Ok(
107 MLKEMPublicKey::MLKEM512(
108 Box::new(mlkem512::PublicKey::from_bytes(bytes).map_err(|e| anyhow!(e))?)
109 )
110 ),
111 MLKEM::MLKEM768 =>
112 Ok(
113 MLKEMPublicKey::MLKEM768(
114 Box::new(mlkem768::PublicKey::from_bytes(bytes).map_err(|e| anyhow!(e))?)
115 )
116 ),
117 MLKEM::MLKEM1024 =>
118 Ok(
119 MLKEMPublicKey::MLKEM1024(
120 Box::new(mlkem1024::PublicKey::from_bytes(bytes).map_err(|e| anyhow!(e))?)
121 )
122 ),
123 }
124 }
125
126 /// Encapsulates a new shared secret using this public key.
127 ///
128 /// This method generates a random shared secret and encapsulates it using
129 /// this public key, producing a ciphertext that can only be decapsulated
130 /// by the corresponding private key.
131 ///
132 /// # Returns
133 ///
134 /// A tuple containing:
135 /// - A `SymmetricKey` with the shared secret (32 bytes)
136 /// - An `MLKEMCiphertext` with the encapsulated shared secret
137 ///
138 /// # Examples
139 ///
140 /// ```
141 /// use bc_components::MLKEM;
142 ///
143 /// // Generate a keypair
144 /// let (private_key, public_key) = MLKEM::MLKEM512.keypair();
145 ///
146 /// // Encapsulate a shared secret
147 /// let (shared_secret, ciphertext) = public_key.encapsulate_new_shared_secret();
148 ///
149 /// // The private key holder can decapsulate the same shared secret
150 /// let decapsulated_secret = private_key.decapsulate_shared_secret(&ciphertext).unwrap();
151 /// assert_eq!(shared_secret, decapsulated_secret);
152 /// ```
153 pub fn encapsulate_new_shared_secret(&self) -> (SymmetricKey, MLKEMCiphertext) {
154 match self {
155 MLKEMPublicKey::MLKEM512(pk) => {
156 let (ss, ct) = mlkem512::encapsulate(pk.as_ref());
157 (
158 SymmetricKey::from_data_ref(ss.as_bytes()).unwrap(),
159 MLKEMCiphertext::MLKEM512(ct.into()),
160 )
161 }
162 MLKEMPublicKey::MLKEM768(pk) => {
163 let (ss, ct) = mlkem768::encapsulate(pk.as_ref());
164 (
165 SymmetricKey::from_data_ref(ss.as_bytes()).unwrap(),
166 MLKEMCiphertext::MLKEM768(ct.into()),
167 )
168 }
169 MLKEMPublicKey::MLKEM1024(pk) => {
170 let (ss, ct) = mlkem1024::encapsulate(pk.as_ref());
171 (
172 SymmetricKey::from_data_ref(ss.as_bytes()).unwrap(),
173 MLKEMCiphertext::MLKEM1024(ct.into()),
174 )
175 }
176 }
177 }
178}
179
180/// Provides debug formatting for ML-KEM public keys.
181impl std::fmt::Debug for MLKEMPublicKey {
182 /// Formats the public key as a string for debugging purposes.
183 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
184 match self {
185 MLKEMPublicKey::MLKEM512(_) => f.write_str("MLKEM512PublicKey"),
186 MLKEMPublicKey::MLKEM768(_) => f.write_str("MLKEM768PublicKey"),
187 MLKEMPublicKey::MLKEM1024(_) => f.write_str("MLKEM1024PublicKey"),
188 }
189 }
190}
191
192/// Defines CBOR tags for ML-KEM public keys.
193impl CBORTagged for MLKEMPublicKey {
194 /// Returns the CBOR tag for ML-KEM public keys.
195 fn cbor_tags() -> Vec<Tag> {
196 tags_for_values(&[tags::TAG_MLKEM_PUBLIC_KEY])
197 }
198}
199
200/// Converts an `MLKEMPublicKey` to CBOR.
201impl From<MLKEMPublicKey> for CBOR {
202 /// Converts to tagged CBOR.
203 fn from(value: MLKEMPublicKey) -> Self {
204 value.tagged_cbor()
205 }
206}
207
208/// Implements CBOR encoding for ML-KEM public keys.
209impl CBORTaggedEncodable for MLKEMPublicKey {
210 /// Creates the untagged CBOR representation as an array with level and key bytes.
211 fn untagged_cbor(&self) -> CBOR {
212 vec![self.level().into(), CBOR::to_byte_string(self.as_bytes())].into()
213 }
214}
215
216/// Attempts to convert CBOR to an `MLKEMPublicKey`.
217impl TryFrom<CBOR> for MLKEMPublicKey {
218 type Error = dcbor::Error;
219
220 /// Converts from tagged CBOR.
221 fn try_from(cbor: CBOR) -> dcbor::Result<Self> {
222 Self::from_tagged_cbor(cbor)
223 }
224}
225
226/// Implements CBOR decoding for ML-KEM public keys.
227impl CBORTaggedDecodable for MLKEMPublicKey {
228 /// Creates an `MLKEMPublicKey` from untagged CBOR.
229 ///
230 /// # Errors
231 /// Returns an error if the CBOR value doesn't represent a valid ML-KEM public key.
232 fn from_untagged_cbor(untagged_cbor: CBOR) -> dcbor::Result<Self> {
233 match untagged_cbor.as_case() {
234 CBORCase::Array(elements) => {
235 if elements.len() != 2 {
236 return Err("MLKEMPublicKey must have two elements".into());
237 }
238
239 let level = MLKEM::try_from(elements[0].clone())?;
240 let data = CBOR::try_into_byte_string(elements[1].clone())?;
241 Ok(MLKEMPublicKey::from_bytes(level, &data)?)
242 }
243 _ => {
244 return Err("MLKEMPublicKey must be an array".into());
245 }
246 }
247 }
248}