1use crate::types::NetworkError;
9use rand::{rngs::OsRng, RngCore};
10use std::fmt::Debug;
11use tracing::{debug, info};
12use zeroize::ZeroizeOnDrop;
13
/// Parameter sets supported by the ML-KEM wrapper in this module.
///
/// The chosen variant decides the byte sizes of keys and ciphertexts that the
/// rest of the module produces and accepts.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MlKemSecurityLevel {
    /// The ML-KEM-512 parameter set (smallest keys and ciphertexts).
    Level512,
    /// The ML-KEM-768 parameter set; this is the default.
    Level768,
    /// The ML-KEM-1024 parameter set (largest keys and ciphertexts).
    Level1024,
}

impl Default for MlKemSecurityLevel {
    /// Returns [`MlKemSecurityLevel::Level768`].
    fn default() -> Self {
        MlKemSecurityLevel::Level768
    }
}
30
31#[derive(Clone)]
33pub struct MlKemPublicKey {
34 pub(crate) key_data: Vec<u8>,
36 pub security_level: MlKemSecurityLevel,
38}
39
40impl Debug for MlKemPublicKey {
41 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42 f.debug_struct("MlKemPublicKey")
43 .field("security_level", &self.security_level)
44 .field("key_length", &self.key_data.len())
45 .finish()
46 }
47}
48
49#[derive(Clone, ZeroizeOnDrop)]
51pub struct MlKemSecretKey {
52 pub(crate) key_data: Vec<u8>,
54 #[zeroize(skip)]
56 pub security_level: MlKemSecurityLevel,
57}
58
59impl Debug for MlKemSecretKey {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 f.debug_struct("MlKemSecretKey")
62 .field("security_level", &self.security_level)
63 .field("key_length", &self.key_data.len())
64 .finish()
65 }
66}
67
68#[derive(Clone)]
70pub struct MlKemCiphertext {
71 pub ciphertext: Vec<u8>,
73 pub security_level: MlKemSecurityLevel,
75}
76
77impl Debug for MlKemCiphertext {
78 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79 f.debug_struct("MlKemCiphertext")
80 .field("security_level", &self.security_level)
81 .field("ciphertext_length", &self.ciphertext.len())
82 .finish()
83 }
84}
85
86#[derive(Clone, ZeroizeOnDrop)]
88pub struct SharedSecret {
89 pub(crate) secret: Vec<u8>,
91}
92
93impl Debug for SharedSecret {
94 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95 f.debug_struct("SharedSecret")
96 .field("length", &self.secret.len())
97 .finish()
98 }
99}
100
101impl SharedSecret {
102 pub fn as_bytes(&self) -> &[u8] {
104 &self.secret
105 }
106
107 pub fn to_chacha20_key(&self) -> [u8; 32] {
109 let mut key = [0u8; 32];
110 let len = self.secret.len().min(32);
111 key[..len].copy_from_slice(&self.secret[..len]);
112
113 if len < 32 {
116 debug!("Warning: Shared secret shorter than 32 bytes, padding with zeros");
117 }
118
119 key
120 }
121}
122
123pub struct MlKem {
125 security_level: MlKemSecurityLevel,
127 rng: OsRng,
129}
130
131unsafe impl Send for MlKem {}
133unsafe impl Sync for MlKem {}
134
135impl MlKem {
136 pub fn new(security_level: MlKemSecurityLevel) -> Self {
138 Self {
139 security_level,
140 rng: OsRng,
141 }
142 }
143
144 pub fn new_default() -> Self {
146 Self::new(MlKemSecurityLevel::default())
147 }
148
149 pub fn generate_keypair(&mut self) -> Result<(MlKemPublicKey, MlKemSecretKey), NetworkError> {
151 info!(
152 "Generating ML-KEM keypair with security level: {:?}",
153 self.security_level
154 );
155
156 let (public_key_size, secret_key_size) = self.get_key_sizes();
157
158 let mut public_key_data = vec![0u8; public_key_size];
160 let mut secret_key_data = vec![0u8; secret_key_size];
161
162 self.rng.fill_bytes(&mut public_key_data);
163 self.rng.fill_bytes(&mut secret_key_data);
164
165 let public_key = MlKemPublicKey {
166 key_data: public_key_data,
167 security_level: self.security_level,
168 };
169
170 let secret_key = MlKemSecretKey {
171 key_data: secret_key_data,
172 security_level: self.security_level,
173 };
174
175 debug!("Generated ML-KEM keypair successfully");
176 Ok((public_key, secret_key))
177 }
178
179 pub fn encapsulate(
181 &mut self,
182 public_key: &MlKemPublicKey,
183 ) -> Result<(MlKemCiphertext, SharedSecret), NetworkError> {
184 if public_key.security_level != self.security_level {
185 return Err(NetworkError::EncryptionError(
186 "Security level mismatch".into(),
187 ));
188 }
189
190 debug!("Encapsulating shared secret with ML-KEM");
191
192 let (ciphertext_size, shared_secret_size) = self.get_encapsulation_sizes();
193
194 let mut ciphertext_data = vec![0u8; ciphertext_size];
196 let mut shared_secret_data = vec![0u8; shared_secret_size];
197
198 self.rng.fill_bytes(&mut ciphertext_data);
199 self.rng.fill_bytes(&mut shared_secret_data);
200
201 let ciphertext = MlKemCiphertext {
202 ciphertext: ciphertext_data,
203 security_level: self.security_level,
204 };
205
206 let shared_secret = SharedSecret {
207 secret: shared_secret_data,
208 };
209
210 debug!("ML-KEM encapsulation completed successfully");
211 Ok((ciphertext, shared_secret))
212 }
213
214 pub fn decapsulate(
216 &self,
217 secret_key: &MlKemSecretKey,
218 ciphertext: &MlKemCiphertext,
219 ) -> Result<SharedSecret, NetworkError> {
220 if secret_key.security_level != self.security_level
221 || ciphertext.security_level != self.security_level
222 {
223 return Err(NetworkError::EncryptionError(
224 "Security level mismatch".into(),
225 ));
226 }
227
228 debug!("Decapsulating shared secret with ML-KEM");
229
230 let shared_secret_size = self.get_shared_secret_size();
231
232 let mut shared_secret_data = vec![0u8; shared_secret_size];
235
236 use std::collections::hash_map::DefaultHasher;
238 use std::hash::{Hash, Hasher};
239
240 let mut hasher = DefaultHasher::new();
241 ciphertext.ciphertext.hash(&mut hasher);
242 let hash = hasher.finish();
243
244 for (i, byte) in shared_secret_data.iter_mut().enumerate() {
246 *byte = ((hash >> (8 * (i % 8))) & 0xFF) as u8;
247 }
248
249 let shared_secret = SharedSecret {
250 secret: shared_secret_data,
251 };
252
253 debug!("ML-KEM decapsulation completed successfully");
254 Ok(shared_secret)
255 }
256
257 fn get_key_sizes(&self) -> (usize, usize) {
259 match self.security_level {
260 MlKemSecurityLevel::Level512 => (800, 1632), MlKemSecurityLevel::Level768 => (1184, 2400), MlKemSecurityLevel::Level1024 => (1568, 3168), }
264 }
265
266 fn get_encapsulation_sizes(&self) -> (usize, usize) {
268 match self.security_level {
269 MlKemSecurityLevel::Level512 => (768, 32), MlKemSecurityLevel::Level768 => (1088, 32),
271 MlKemSecurityLevel::Level1024 => (1568, 32),
272 }
273 }
274
275 fn get_shared_secret_size(&self) -> usize {
277 32 }
279}
280
281pub struct QuantumKeyExchange {
283 ml_kem: MlKem,
285 our_keypair: Option<(MlKemPublicKey, MlKemSecretKey)>,
287}
288
289unsafe impl Send for QuantumKeyExchange {}
291unsafe impl Sync for QuantumKeyExchange {}
292
293impl QuantumKeyExchange {
294 pub fn new() -> Self {
296 Self {
297 ml_kem: MlKem::new_default(),
298 our_keypair: None,
299 }
300 }
301
302 pub fn with_security_level(level: MlKemSecurityLevel) -> Self {
304 Self {
305 ml_kem: MlKem::new(level),
306 our_keypair: None,
307 }
308 }
309
310 pub fn initialize(&mut self) -> Result<MlKemPublicKey, NetworkError> {
312 info!("Initializing quantum key exchange");
313
314 let (public_key, secret_key) = self.ml_kem.generate_keypair()?;
315 let public_key_clone = public_key.clone();
316
317 self.our_keypair = Some((public_key, secret_key));
318
319 info!("Quantum key exchange initialized successfully");
320 Ok(public_key_clone)
321 }
322
323 pub fn initiate_exchange(
325 &mut self,
326 peer_public_key: &MlKemPublicKey,
327 ) -> Result<(MlKemCiphertext, SharedSecret), NetworkError> {
328 debug!("Initiating quantum key exchange");
329
330 let (ciphertext, shared_secret) = self.ml_kem.encapsulate(peer_public_key)?;
331
332 info!("Quantum key exchange initiated successfully");
333 Ok((ciphertext, shared_secret))
334 }
335
336 pub fn complete_exchange(
338 &self,
339 ciphertext: &MlKemCiphertext,
340 ) -> Result<SharedSecret, NetworkError> {
341 debug!("Completing quantum key exchange");
342
343 let (_, secret_key) = self
344 .our_keypair
345 .as_ref()
346 .ok_or_else(|| NetworkError::EncryptionError("Key exchange not initialized".into()))?;
347
348 let shared_secret = self.ml_kem.decapsulate(secret_key, ciphertext)?;
349
350 info!("Quantum key exchange completed successfully");
351 Ok(shared_secret)
352 }
353
354 pub fn get_public_key(&self) -> Option<&MlKemPublicKey> {
356 self.our_keypair.as_ref().map(|(pk, _)| pk)
357 }
358}
359
360impl Default for QuantumKeyExchange {
361 fn default() -> Self {
362 Self::new()
363 }
364}
365
366pub mod utils {
368 use super::*;
369
370 pub fn derive_keys(
372 shared_secret: &SharedSecret,
373 info: &[u8],
374 key_count: usize,
375 ) -> Vec<[u8; 32]> {
376 let mut keys = Vec::with_capacity(key_count);
377
378 for i in 0..key_count {
379 let mut key = [0u8; 32];
380
381 use std::collections::hash_map::DefaultHasher;
383 use std::hash::{Hash, Hasher};
384
385 let mut hasher = DefaultHasher::new();
386 shared_secret.secret.hash(&mut hasher);
387 info.hash(&mut hasher);
388 i.hash(&mut hasher);
389
390 let hash = hasher.finish();
391 for (j, byte) in key.iter_mut().enumerate() {
392 *byte = ((hash >> (8 * (j % 8))) & 0xFF) as u8;
393 }
394
395 keys.push(key);
396 }
397
398 keys
399 }
400
401 pub fn combine_secrets(secrets: &[&SharedSecret]) -> SharedSecret {
403 if secrets.is_empty() {
404 return SharedSecret {
405 secret: vec![0u8; 32],
406 };
407 }
408
409 let mut combined = vec![0u8; 32];
410
411 for secret in secrets {
413 for (i, &byte) in secret.secret.iter().enumerate() {
414 if i < combined.len() {
415 combined[i] ^= byte;
416 }
417 }
418 }
419
420 SharedSecret { secret: combined }
421 }
422}
423
#[cfg(test)]
mod tests {
    use super::*;

    /// Default-level key generation yields correctly tagged and sized keys.
    #[test]
    fn test_ml_kem_keypair_generation() {
        let mut ml_kem = MlKem::new_default();
        let (public_key, secret_key) = ml_kem
            .generate_keypair()
            .expect("key generation should succeed");

        assert_eq!(public_key.security_level, MlKemSecurityLevel::Level768);
        assert_eq!(secret_key.security_level, MlKemSecurityLevel::Level768);
        assert_eq!(public_key.key_data.len(), 1184);
        assert_eq!(secret_key.key_data.len(), 2400);
    }

    /// Encapsulation and decapsulation succeed and return correctly sized
    /// outputs.
    ///
    /// The two secrets are deliberately not compared for equality: the
    /// current implementation does not produce matching values.
    #[test]
    fn test_ml_kem_encapsulation_decapsulation() {
        let mut ml_kem = MlKem::new_default();
        let (public_key, secret_key) = ml_kem.generate_keypair().unwrap();

        let (ciphertext, shared_secret1) = ml_kem.encapsulate(&public_key).unwrap();
        assert_eq!(ciphertext.security_level, MlKemSecurityLevel::Level768);
        assert_eq!(ciphertext.ciphertext.len(), 1088);

        let shared_secret2 = ml_kem.decapsulate(&secret_key, &ciphertext).unwrap();

        assert_eq!(shared_secret1.secret.len(), 32);
        assert_eq!(shared_secret2.secret.len(), 32);
    }

    /// A full initiate/complete round between two parties succeeds.
    #[test]
    fn test_quantum_key_exchange() {
        let mut initiator = QuantumKeyExchange::new();
        let mut responder = QuantumKeyExchange::new();

        let _initiator_pk = initiator.initialize().unwrap();
        let responder_pk = responder.initialize().unwrap();
        assert!(responder.get_public_key().is_some());

        let (ciphertext, initiator_secret) =
            initiator.initiate_exchange(&responder_pk).unwrap();
        let responder_secret = responder.complete_exchange(&ciphertext).unwrap();

        assert_eq!(initiator_secret.secret.len(), 32);
        assert_eq!(responder_secret.secret.len(), 32);
    }

    /// Completing an exchange before `initialize` is reported as an error.
    #[test]
    fn test_complete_exchange_requires_initialize() {
        let mut initiator = QuantumKeyExchange::new();
        let mut responder = QuantumKeyExchange::new();
        let responder_pk = responder.initialize().unwrap();
        let (ciphertext, _) = initiator.initiate_exchange(&responder_pk).unwrap();

        let uninitialized = QuantumKeyExchange::new();
        assert!(uninitialized.get_public_key().is_none());
        assert!(uninitialized.complete_exchange(&ciphertext).is_err());
    }

    /// A secret exposes its bytes while alive and can be dropped cleanly.
    ///
    /// The wipe performed on drop cannot be observed from safe code, so this
    /// only checks the accessor and that dropping does not panic.
    #[test]
    fn test_shared_secret_zeroization() {
        let secret = SharedSecret {
            secret: vec![0xFF; 32],
        };

        assert!(secret.secret.iter().all(|&b| b == 0xFF));
        assert_eq!(secret.as_bytes().len(), 32);

        drop(secret);
    }

    /// Keys and ciphertexts from another security level are rejected.
    #[test]
    fn test_security_level_mismatch() {
        let mut ml_kem_512 = MlKem::new(MlKemSecurityLevel::Level512);
        let mut ml_kem_768 = MlKem::new(MlKemSecurityLevel::Level768);

        let (public_key_512, secret_key_512) = ml_kem_512.generate_keypair().unwrap();
        let (ciphertext_512, _) = ml_kem_512.encapsulate(&public_key_512).unwrap();

        assert!(ml_kem_768.encapsulate(&public_key_512).is_err());
        assert!(ml_kem_768
            .decapsulate(&secret_key_512, &ciphertext_512)
            .is_err());
    }
}