1use crate::kem::common::kem_info::KemInfo;
2use crate::kem::common::kem_trait::Kem;
3use crate::kem::common::kem_type::KemType;
4use crate::QuantCryptError;
5use ml_kem::kem::Decapsulate;
6use ml_kem::kem::Encapsulate;
7use ml_kem::*;
8use rand_chacha::ChaCha20Rng;
9use rand_core::CryptoRngCore;
10use rand_core::SeedableRng;
11
/// Generates a key pair for the ML-KEM parameter set `$curve` using `$rng`.
///
/// Expands to a `(Vec<u8>, Vec<u8>)` tuple holding the encoded encapsulation key
/// followed by the encoded decapsulation key.
macro_rules! key_gen_ml {
    ($rng:expr, $curve:ident) => {{
        // `generate` yields the decapsulation key first; the public half is returned first.
        let (decap_key, encap_key) = $curve::generate($rng);
        let public_key: Vec<u8> = encap_key.as_bytes().to_vec();
        let secret_key: Vec<u8> = decap_key.as_bytes().to_vec();
        (public_key, secret_key)
    }};
}
18
/// Encapsulates to the public key `$pk` for the ML-KEM parameter set `$curve`.
///
/// `$rng` must be a mutable place (it is borrowed as `&mut $rng`). Parsing errors from
/// `get_encapsulation_key_obj` are propagated with `?`, so this must be expanded inside a
/// function returning this module's `Result`. Evaluates to `Ok((shared_secret, ciphertext))`.
macro_rules! encapsulate_ml {
    ($rng:expr, $curve:ident, $pk:expr) => {{
        let encap_key = get_encapsulation_key_obj::<$curve>($pk.to_vec())?;
        // NOTE(review): unwrap assumes encapsulation with an already-parsed key cannot
        // fail in `ml_kem` — confirm against the crate's `Encapsulate::Error` type.
        let (ciphertext, shared_secret) = encap_key.encapsulate(&mut $rng).unwrap();
        let ct_bytes: Vec<u8> = ciphertext.as_slice().to_vec();
        let ss_bytes: Vec<u8> = shared_secret.as_slice().to_vec();
        Ok((ss_bytes, ct_bytes))
    }};
}
28
29type Result<T> = std::result::Result<T, QuantCryptError>;
30
31fn get_encapsulation_key_obj<K: KemCore>(pk: Vec<u8>) -> Result<K::EncapsulationKey> {
41 let pk = Encoded::<K::EncapsulationKey>::try_from(pk.as_slice())
43 .map_err(|_| QuantCryptError::InvalidPublicKey)?;
44 Ok(K::EncapsulationKey::from_bytes(&pk))
45}
46
47fn get_decapsulation_key_obj<K: KemCore>(sk: &[u8]) -> Result<K::DecapsulationKey> {
57 let sk = Encoded::<K::DecapsulationKey>::try_from(sk)
59 .map_err(|_| QuantCryptError::InvalidPrivateKey)?;
60 Ok(K::DecapsulationKey::from_bytes(&sk))
61}
62
63fn decapsulate<K: KemCore>(sk: &[u8], ct: &[u8]) -> Result<Vec<u8>> {
74 let c = Ciphertext::<K>::try_from(ct).map_err(|_| QuantCryptError::InvalidCiphertext)?;
75 let dk = get_decapsulation_key_obj::<K>(sk)?;
76 let session_key = dk
77 .decapsulate(&c)
78 .map_err(|_| QuantCryptError::DecapFailed)?;
79 Ok(session_key.as_slice().to_vec())
80}
81
82pub struct MlKemManager {
84 kem_info: KemInfo,
85}
86
87impl Kem for MlKemManager {
88 fn new(kem_type: KemType) -> Result<Self> {
98 let kem_info = KemInfo::new(kem_type);
99 Ok(Self { kem_info })
100 }
101
102 fn key_gen_with_rng(&mut self, rng: &mut impl CryptoRngCore) -> Result<(Vec<u8>, Vec<u8>)> {
112 match self.kem_info.kem_type {
113 KemType::MlKem512 => Ok(key_gen_ml!(rng, MlKem512)),
114 KemType::MlKem768 => Ok(key_gen_ml!(rng, MlKem768)),
115 KemType::MlKem1024 => Ok(key_gen_ml!(rng, MlKem1024)),
116 _ => {
117 panic!("Not implemented");
118 }
119 }
120 }
121
122 fn key_gen(&mut self) -> Result<(Vec<u8>, Vec<u8>)> {
128 let mut rng = ChaCha20Rng::from_entropy();
129 self.key_gen_with_rng(&mut rng)
130 }
131
132 fn encap(&mut self, pk: &[u8]) -> Result<(Vec<u8>, Vec<u8>)> {
142 let mut rng = ChaCha20Rng::from_entropy();
143 match self.kem_info.kem_type {
144 KemType::MlKem512 => {
145 encapsulate_ml!(rng, MlKem512, pk)
146 }
147 KemType::MlKem768 => {
148 encapsulate_ml!(rng, MlKem768, pk)
149 }
150 KemType::MlKem1024 => {
151 encapsulate_ml!(rng, MlKem1024, pk)
152 }
153 _ => {
154 panic!("Not implemented");
155 }
156 }
157 }
158
159 fn decap(&self, sk: &[u8], ct: &[u8]) -> Result<Vec<u8>> {
170 match self.kem_info.kem_type {
171 KemType::MlKem512 => decapsulate::<MlKem512>(sk, ct),
172 KemType::MlKem768 => decapsulate::<MlKem768>(sk, ct),
173 KemType::MlKem1024 => decapsulate::<MlKem1024>(sk, ct),
174 _ => Err(QuantCryptError::NotImplemented),
175 }
176 }
177
178 fn get_kem_info(&self) -> KemInfo {
187 self.kem_info.clone()
188 }
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kem::common::kem_type::KemType;
    use crate::kem::common::macros::test_kem;

    // One test per supported ML-KEM parameter set. Each test passes the `Result` returned
    // by `MlKemManager::new` as-is to the shared `test_kem!` harness.
    // NOTE(review): `test_kem!` is defined in `kem::common::macros`; it presumably unwraps
    // the manager and runs a keygen/encap/decap round trip — confirm there.

    /// Runs ML-KEM-512 through the shared `test_kem!` harness.
    #[test]
    fn test_ml_kem_512() {
        let kem = MlKemManager::new(KemType::MlKem512);
        test_kem!(kem);
    }

    /// Runs ML-KEM-768 through the shared `test_kem!` harness.
    #[test]
    fn test_ml_kem_768() {
        let kem = MlKemManager::new(KemType::MlKem768);
        test_kem!(kem);
    }

    /// Runs ML-KEM-1024 through the shared `test_kem!` harness.
    #[test]
    fn test_ml_kem_1024() {
        let kem = MlKemManager::new(KemType::MlKem1024);
        test_kem!(kem);
    }
}