bc_components/mlkem/
mlkem_public_key.rs

1use anyhow::{anyhow, bail, Error, Result};
2use dcbor::prelude::*;
3use pqcrypto_mlkem::*;
4use pqcrypto_traits::kem::{PublicKey, SharedSecret};
5
6use crate::{tags, SymmetricKey};
7
8use super::{MLKEMCiphertext, MLKEM};
9
10/// A public key for the ML-KEM post-quantum key encapsulation mechanism.
11///
12/// `MLKEMPublicKey` represents a public key that can be used to encapsulate shared secrets
13/// using the ML-KEM (Module Lattice-based Key Encapsulation Mechanism) post-quantum algorithm.
14/// It supports multiple security levels through the variants:
15///
16/// - `MLKEM512`: NIST security level 1 (roughly equivalent to AES-128), 800 bytes
17/// - `MLKEM768`: NIST security level 3 (roughly equivalent to AES-192), 1184 bytes
18/// - `MLKEM1024`: NIST security level 5 (roughly equivalent to AES-256), 1568 bytes
19///
20/// # Examples
21///
22/// ```
23/// use bc_components::MLKEM;
24///
25/// // Generate a keypair
26/// let (private_key, public_key) = MLKEM::MLKEM512.keypair();
27///
28/// // Encapsulate a shared secret using the public key
29/// let (shared_secret, ciphertext) = public_key.encapsulate_new_shared_secret();
30/// ```
31#[derive(Clone)]
32pub enum MLKEMPublicKey {
33    /// An ML-KEM-512 public key (NIST security level 1)
34    MLKEM512(Box<mlkem512::PublicKey>),
35    /// An ML-KEM-768 public key (NIST security level 3)
36    MLKEM768(Box<mlkem768::PublicKey>),
37    /// An ML-KEM-1024 public key (NIST security level 5)
38    MLKEM1024(Box<mlkem1024::PublicKey>),
39}
40
41/// Implements equality comparison for ML-KEM public keys.
42impl PartialEq for MLKEMPublicKey {
43    /// Compares two ML-KEM public keys for equality.
44    ///
45    /// Two ML-KEM public keys are equal if they have the same security level
46    /// and the same raw byte representation.
47    fn eq(&self, other: &Self) -> bool {
48        self.level() == other.level() && self.as_bytes() == other.as_bytes()
49    }
50}
51
52impl Eq for MLKEMPublicKey {}
53
54/// Implements hashing for ML-KEM public keys.
55impl std::hash::Hash for MLKEMPublicKey {
56    /// Hashes both the security level and the raw bytes of the public key.
57    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
58        self.level().hash(state);
59        self.as_bytes().hash(state);
60    }
61}
62
63impl MLKEMPublicKey {
64    /// Returns the security level of this ML-KEM public key.
65    pub fn level(&self) -> MLKEM {
66        match self {
67            MLKEMPublicKey::MLKEM512(_) => MLKEM::MLKEM512,
68            MLKEMPublicKey::MLKEM768(_) => MLKEM::MLKEM768,
69            MLKEMPublicKey::MLKEM1024(_) => MLKEM::MLKEM1024,
70        }
71    }
72
73    /// Returns the size of this ML-KEM public key in bytes.
74    pub fn size(&self) -> usize {
75        self.level().public_key_size()
76    }
77
78    /// Returns the raw bytes of this ML-KEM public key.
79    pub fn as_bytes(&self) -> &[u8] {
80        match self {
81            MLKEMPublicKey::MLKEM512(pk) => pk.as_ref().as_bytes(),
82            MLKEMPublicKey::MLKEM768(pk) => pk.as_ref().as_bytes(),
83            MLKEMPublicKey::MLKEM1024(pk) => pk.as_ref().as_bytes(),
84        }
85    }
86
87    /// Creates an ML-KEM public key from raw bytes and a security level.
88    ///
89    /// # Parameters
90    ///
91    /// * `level` - The security level of the key.
92    /// * `bytes` - The raw bytes of the key.
93    ///
94    /// # Returns
95    ///
96    /// An `MLKEMPublicKey` if the bytes represent a valid key for the given level,
97    /// or an error otherwise.
98    ///
99    /// # Errors
100    ///
101    /// Returns an error if the bytes do not represent a valid ML-KEM public key
102    /// for the specified security level.
103    pub fn from_bytes(level: MLKEM, bytes: &[u8]) -> Result<Self> {
104        match level {
105            MLKEM::MLKEM512 => Ok(MLKEMPublicKey::MLKEM512(Box::new(
106                mlkem512::PublicKey::from_bytes(bytes).map_err(|e| anyhow!(e))?,
107            ))),
108            MLKEM::MLKEM768 => Ok(MLKEMPublicKey::MLKEM768(Box::new(
109                mlkem768::PublicKey::from_bytes(bytes).map_err(|e| anyhow!(e))?,
110            ))),
111            MLKEM::MLKEM1024 => Ok(MLKEMPublicKey::MLKEM1024(Box::new(
112                mlkem1024::PublicKey::from_bytes(bytes).map_err(|e| anyhow!(e))?,
113            ))),
114        }
115    }
116
117    /// Encapsulates a new shared secret using this public key.
118    ///
119    /// This method generates a random shared secret and encapsulates it using
120    /// this public key, producing a ciphertext that can only be decapsulated
121    /// by the corresponding private key.
122    ///
123    /// # Returns
124    ///
125    /// A tuple containing:
126    /// - A `SymmetricKey` with the shared secret (32 bytes)
127    /// - An `MLKEMCiphertext` with the encapsulated shared secret
128    ///
129    /// # Examples
130    ///
131    /// ```
132    /// use bc_components::MLKEM;
133    ///
134    /// // Generate a keypair
135    /// let (private_key, public_key) = MLKEM::MLKEM512.keypair();
136    ///
137    /// // Encapsulate a shared secret
138    /// let (shared_secret, ciphertext) = public_key.encapsulate_new_shared_secret();
139    ///
140    /// // The private key holder can decapsulate the same shared secret
141    /// let decapsulated_secret = private_key.decapsulate_shared_secret(&ciphertext).unwrap();
142    /// assert_eq!(shared_secret, decapsulated_secret);
143    /// ```
144    pub fn encapsulate_new_shared_secret(&self) -> (SymmetricKey, MLKEMCiphertext) {
145        match self {
146            MLKEMPublicKey::MLKEM512(pk) => {
147                let (ss, ct) = mlkem512::encapsulate(pk.as_ref());
148                (
149                    SymmetricKey::from_data_ref(ss.as_bytes()).unwrap(),
150                    MLKEMCiphertext::MLKEM512(ct.into()),
151                )
152            }
153            MLKEMPublicKey::MLKEM768(pk) => {
154                let (ss, ct) = mlkem768::encapsulate(pk.as_ref());
155                (
156                    SymmetricKey::from_data_ref(ss.as_bytes()).unwrap(),
157                    MLKEMCiphertext::MLKEM768(ct.into()),
158                )
159            }
160            MLKEMPublicKey::MLKEM1024(pk) => {
161                let (ss, ct) = mlkem1024::encapsulate(pk.as_ref());
162                (
163                    SymmetricKey::from_data_ref(ss.as_bytes()).unwrap(),
164                    MLKEMCiphertext::MLKEM1024(ct.into()),
165                )
166            }
167        }
168    }
169}
170
171/// Provides debug formatting for ML-KEM public keys.
172impl std::fmt::Debug for MLKEMPublicKey {
173    /// Formats the public key as a string for debugging purposes.
174    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
175        match self {
176            MLKEMPublicKey::MLKEM512(_) => f.write_str("MLKEM512PublicKey"),
177            MLKEMPublicKey::MLKEM768(_) => f.write_str("MLKEM768PublicKey"),
178            MLKEMPublicKey::MLKEM1024(_) => f.write_str("MLKEM1024PublicKey"),
179        }
180    }
181}
182
183/// Defines CBOR tags for ML-KEM public keys.
184impl CBORTagged for MLKEMPublicKey {
185    /// Returns the CBOR tag for ML-KEM public keys.
186    fn cbor_tags() -> Vec<Tag> {
187        tags_for_values(&[tags::TAG_MLKEM_PUBLIC_KEY])
188    }
189}
190
191/// Converts an `MLKEMPublicKey` to CBOR.
192impl From<MLKEMPublicKey> for CBOR {
193    /// Converts to tagged CBOR.
194    fn from(value: MLKEMPublicKey) -> Self {
195        value.tagged_cbor()
196    }
197}
198
199/// Implements CBOR encoding for ML-KEM public keys.
200impl CBORTaggedEncodable for MLKEMPublicKey {
201    /// Creates the untagged CBOR representation as an array with level and key bytes.
202    fn untagged_cbor(&self) -> CBOR {
203        vec![self.level().into(), CBOR::to_byte_string(self.as_bytes())].into()
204    }
205}
206
207/// Attempts to convert CBOR to an `MLKEMPublicKey`.
208impl TryFrom<CBOR> for MLKEMPublicKey {
209    type Error = Error;
210
211    /// Converts from tagged CBOR.
212    fn try_from(cbor: CBOR) -> Result<Self, Self::Error> {
213        Self::from_tagged_cbor(cbor)
214    }
215}
216
217/// Implements CBOR decoding for ML-KEM public keys.
218impl CBORTaggedDecodable for MLKEMPublicKey {
219    /// Creates an `MLKEMPublicKey` from untagged CBOR.
220    ///
221    /// # Errors
222    /// Returns an error if the CBOR value doesn't represent a valid ML-KEM public key.
223    fn from_untagged_cbor(untagged_cbor: CBOR) -> Result<Self> {
224        match untagged_cbor.as_case() {
225            CBORCase::Array(elements) => {
226                if elements.len() != 2 {
227                    bail!("MLKEMPublicKey must have two elements");
228                }
229
230                let level = MLKEM::try_from(elements[0].clone())?;
231                let data = CBOR::try_into_byte_string(elements[1].clone())?;
232                MLKEMPublicKey::from_bytes(level, &data)
233            }
234            _ => bail!("MLKEMPublicKey must be an array"),
235        }
236    }
237}