quantcrypt/kem/ml_kem.rs

1use crate::kem::common::kem_info::KemInfo;
2use crate::kem::common::kem_trait::Kem;
3use crate::kem::common::kem_type::KemType;
4use crate::QuantCryptError;
5use ml_kem::kem::Decapsulate;
6use ml_kem::kem::Encapsulate;
7use ml_kem::*;
8use rand_chacha::ChaCha20Rng;
9use rand_core::CryptoRngCore;
10use rand_core::SeedableRng;
11
/// Generate a keypair for the ML-KEM parameter set `$curve` using the RNG `$rng`.
///
/// Expands to a `(Vec<u8>, Vec<u8>)` tuple holding the encoded encapsulation (public)
/// key first and the encoded decapsulation (secret) key second.
macro_rules! key_gen_ml {
    ($rng:expr, $curve:ident) => {{
        // `generate` yields the decapsulation key first, then the encapsulation key.
        let (decap_key, encap_key) = $curve::generate($rng);
        let public_bytes = encap_key.as_bytes().to_vec();
        let secret_bytes = decap_key.as_bytes().to_vec();
        (public_bytes, secret_bytes)
    }};
}
18
/// Encapsulate a fresh shared secret to the encoded public key `$pk` for the ML-KEM
/// parameter set `$curve`, drawing randomness from `$rng`.
///
/// Expands to an `Ok((shared_secret, ciphertext))` value. A public key that cannot be
/// decoded is propagated with `?` out of the enclosing function.
macro_rules! encapsulate_ml {
    ($rng:expr, $curve:ident, $pk:expr) => {{
        let encap_key = get_encapsulation_key_obj::<$curve>($pk.to_vec())?;
        // NOTE(review): `unwrap` assumes ml_kem's encapsulation error type cannot occur
        // (believed to be `Infallible`) — confirm against the pinned ml_kem version.
        let (ciphertext, shared_secret) = encap_key.encapsulate(&mut $rng).unwrap();
        Ok((
            shared_secret.as_slice().to_vec(),
            ciphertext.as_slice().to_vec(),
        ))
    }};
}
28
29type Result<T> = std::result::Result<T, QuantCryptError>;
30
31// Get the encapsulated key object for the post quantum key encapsulation mechanism
32///
33/// # Arguments
34///
35/// * `pk` - The public key
36///
37/// # Returns
38///
39/// The encapsulated key object
40fn get_encapsulation_key_obj<K: KemCore>(pk: Vec<u8>) -> Result<K::EncapsulationKey> {
41    // Deserialize the public key
42    let pk = Encoded::<K::EncapsulationKey>::try_from(pk.as_slice())
43        .map_err(|_| QuantCryptError::InvalidPublicKey)?;
44    Ok(K::EncapsulationKey::from_bytes(&pk))
45}
46
47/// Get the decapsulation key object for the post quantum key encapsulation mechanism
48///
49/// # Arguments
50///
51/// * `sk` - The secret key
52///
53/// # Returns
54///
55/// The decapsulation key object
56fn get_decapsulation_key_obj<K: KemCore>(sk: &[u8]) -> Result<K::DecapsulationKey> {
57    // Deserialize the public key
58    let sk = Encoded::<K::DecapsulationKey>::try_from(sk)
59        .map_err(|_| QuantCryptError::InvalidPrivateKey)?;
60    Ok(K::DecapsulationKey::from_bytes(&sk))
61}
62
63/// Decapsulate a ciphertext
64///
65/// # Arguments
66///
67/// * `sk` - The secret key to decapsulate with
68/// * `ct` - The encapsulated key to decapsulate
69///
70/// # Returns
71///
72/// The shared secret (ss)
73fn decapsulate<K: KemCore>(sk: &[u8], ct: &[u8]) -> Result<Vec<u8>> {
74    let c = Ciphertext::<K>::try_from(ct).map_err(|_| QuantCryptError::InvalidCiphertext)?;
75    let dk = get_decapsulation_key_obj::<K>(sk)?;
76    let session_key = dk
77        .decapsulate(&c)
78        .map_err(|_| QuantCryptError::DecapFailed)?;
79    Ok(session_key.as_slice().to_vec())
80}
81
/// A KEM manager for the MlKem method
///
/// Stores the metadata for the selected [`KemType`]. The `Kem` implementation dispatches
/// on that type to the `MlKem512`, `MlKem768` or `MlKem1024` parameter set; other
/// `KemType` values are not supported by this manager.
pub struct MlKemManager {
    // Metadata for this manager's KEM variant, including the selected `kem_type`.
    kem_info: KemInfo,
}
86
87impl Kem for MlKemManager {
88    /// Create a new KEM instance
89    ///
90    /// # Arguments
91    ///
92    /// * `kem_type` - The type of KEM to create
93    ///
94    /// # Returns
95    ///
96    /// A new KEM instance
97    fn new(kem_type: KemType) -> Result<Self> {
98        let kem_info = KemInfo::new(kem_type);
99        Ok(Self { kem_info })
100    }
101
102    /// Generate a keypair
103    ///
104    /// # Arguments
105    ///
106    /// * `rng` - A random number generator
107    ///
108    /// # Returns
109    ///
110    /// A tuple containing the public and secret keys (pk, sk)
111    fn key_gen_with_rng(&mut self, rng: &mut impl CryptoRngCore) -> Result<(Vec<u8>, Vec<u8>)> {
112        match self.kem_info.kem_type {
113            KemType::MlKem512 => Ok(key_gen_ml!(rng, MlKem512)),
114            KemType::MlKem768 => Ok(key_gen_ml!(rng, MlKem768)),
115            KemType::MlKem1024 => Ok(key_gen_ml!(rng, MlKem1024)),
116            _ => {
117                panic!("Not implemented");
118            }
119        }
120    }
121
122    /// Generate a keypair using the default RNG ChaCha20Rng
123    ///
124    /// # Returns
125    ///
126    /// A tuple containing the public and secret keys (pk, sk)
127    fn key_gen(&mut self) -> Result<(Vec<u8>, Vec<u8>)> {
128        let mut rng = ChaCha20Rng::from_entropy();
129        self.key_gen_with_rng(&mut rng)
130    }
131
132    /// Encapsulate a public key
133    ///
134    /// # Arguments
135    ///
136    /// * `pk` - The public key to encapsulate
137    ///
138    /// # Returns
139    ///
140    /// A tuple containing the shared secret and ciphertext (ss, ct)
141    fn encap(&mut self, pk: &[u8]) -> Result<(Vec<u8>, Vec<u8>)> {
142        let mut rng = ChaCha20Rng::from_entropy();
143        match self.kem_info.kem_type {
144            KemType::MlKem512 => {
145                encapsulate_ml!(rng, MlKem512, pk)
146            }
147            KemType::MlKem768 => {
148                encapsulate_ml!(rng, MlKem768, pk)
149            }
150            KemType::MlKem1024 => {
151                encapsulate_ml!(rng, MlKem1024, pk)
152            }
153            _ => {
154                panic!("Not implemented");
155            }
156        }
157    }
158
159    /// Decapsulate a ciphertext
160    ///
161    /// # Arguments
162    ///
163    /// * `sk` - The secret key to decapsulate with
164    /// * `ct` - The ciphertext to decapsulate
165    ///
166    /// # Returns
167    ///
168    /// The shared secret
169    fn decap(&self, sk: &[u8], ct: &[u8]) -> Result<Vec<u8>> {
170        match self.kem_info.kem_type {
171            KemType::MlKem512 => decapsulate::<MlKem512>(sk, ct),
172            KemType::MlKem768 => decapsulate::<MlKem768>(sk, ct),
173            KemType::MlKem1024 => decapsulate::<MlKem1024>(sk, ct),
174            _ => Err(QuantCryptError::NotImplemented),
175        }
176    }
177
178    /// Get KEM metadata information such as the key lengths,
179    /// size of ciphertext, etc.
180    ///
181    /// These values are also used to test the correctness of the KEM
182    ///
183    /// # Returns
184    ///
185    /// A structure containing metadata about the KEM
186    fn get_kem_info(&self) -> KemInfo {
187        self.kem_info.clone()
188    }
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kem::common::kem_type::KemType;
    use crate::kem::common::macros::test_kem;

    /// Runs the shared `test_kem!` checks against the ML-KEM-512 parameter set.
    #[test]
    fn test_ml_kem_512() {
        // `new` returns a `Result<MlKemManager>`, which is handed to `test_kem!` as-is
        // (presumably the macro unwraps it — confirm in `kem::common::macros`).
        let kem = MlKemManager::new(KemType::MlKem512);
        test_kem!(kem);
    }

    /// Runs the shared `test_kem!` checks against the ML-KEM-768 parameter set.
    #[test]
    fn test_ml_kem_768() {
        let kem = MlKemManager::new(KemType::MlKem768);
        test_kem!(kem);
    }

    /// Runs the shared `test_kem!` checks against the ML-KEM-1024 parameter set.
    #[test]
    fn test_ml_kem_1024() {
        let kem = MlKemManager::new(KemType::MlKem1024);
        test_kem!(kem);
    }
}