1use pqcrypto_falcon::{
2 falconpadded1024::{self},
3 ffi::{
4 PQCLEAN_FALCONPADDED1024_CLEAN_CRYPTO_BYTES,
5 PQCLEAN_FALCONPADDED1024_CLEAN_CRYPTO_PUBLICKEYBYTES,
6 PQCLEAN_FALCONPADDED1024_CLEAN_CRYPTO_SECRETKEYBYTES,
7 },
8};
9use pqcrypto_mlkem::{
10 ffi::{
11 PQCLEAN_MLKEM1024_CLEAN_CRYPTO_BYTES, PQCLEAN_MLKEM1024_CLEAN_CRYPTO_CIPHERTEXTBYTES,
12 PQCLEAN_MLKEM1024_CLEAN_CRYPTO_PUBLICKEYBYTES,
13 PQCLEAN_MLKEM1024_CLEAN_CRYPTO_SECRETKEYBYTES,
14 },
15 mlkem1024::{self, SharedSecret},
16};
17use pqcrypto_traits::kem::{Ciphertext, PublicKey};
18use pqcrypto_traits::sign::SignedMessage;
19use rand::RngCore;
20
21use crate::{
22 errors::CryptoError,
23 exchange::{
24 encryptor,
25 pair::{self, KEMPair, b2ss, ss2b},
26 },
27 signatures::keypair::{SignerPair, VerifierPair, ViewOperations},
28};
29
30const MAX_NONCE_COUNTER: u64 = u64::MAX - 1;
32
33pub struct MessageSession {
43 kem_pair: pair::KEMPair,
45 ds_pair: SignerPair,
47 shared_secret: SharedSecret,
49 target_verifier: VerifierPair,
51 current_nonce: [u8; 24],
53}
54
55impl MessageSession {
56 pub fn to_bytes(&self) -> Result<Vec<u8>, CryptoError> {
65 let mut bytes = Vec::new();
66
67 bytes.extend_from_slice(self.kem_pair.to_bytes_uniform().as_slice());
69
70 bytes.extend_from_slice(self.ds_pair.to_bytes_uniform().as_slice());
72
73 bytes.extend_from_slice(&ss2b(&self.shared_secret));
75
76 bytes.extend_from_slice(&self.target_verifier.to_bytes());
78
79 bytes.extend_from_slice(&self.current_nonce[..]);
81 Ok(bytes)
82 }
83
84 pub fn from_bytes(bytes: &[u8]) -> Result<Self, CryptoError> {
95 let expected_length = PQCLEAN_MLKEM1024_CLEAN_CRYPTO_PUBLICKEYBYTES
97 + PQCLEAN_MLKEM1024_CLEAN_CRYPTO_SECRETKEYBYTES
98 + PQCLEAN_FALCONPADDED1024_CLEAN_CRYPTO_PUBLICKEYBYTES
99 + PQCLEAN_FALCONPADDED1024_CLEAN_CRYPTO_SECRETKEYBYTES
100 + PQCLEAN_MLKEM1024_CLEAN_CRYPTO_BYTES
101 + PQCLEAN_FALCONPADDED1024_CLEAN_CRYPTO_PUBLICKEYBYTES
102 + 24;
103
104 if bytes.len() != expected_length {
105 return Err(CryptoError::IncongruentLength(expected_length, bytes.len()));
106 }
107
108 let mut idx = 0;
109
110 let kem_pair = pair::KEMPair::from_bytes_uniform(
112 &bytes[idx..idx
113 + PQCLEAN_MLKEM1024_CLEAN_CRYPTO_PUBLICKEYBYTES
114 + PQCLEAN_MLKEM1024_CLEAN_CRYPTO_SECRETKEYBYTES],
115 )?;
116
117 idx += PQCLEAN_MLKEM1024_CLEAN_CRYPTO_PUBLICKEYBYTES
118 + PQCLEAN_MLKEM1024_CLEAN_CRYPTO_SECRETKEYBYTES;
119
120 let ds_pair = SignerPair::from_bytes_uniform(
122 &bytes[idx..idx
123 + PQCLEAN_FALCONPADDED1024_CLEAN_CRYPTO_PUBLICKEYBYTES
124 + PQCLEAN_FALCONPADDED1024_CLEAN_CRYPTO_SECRETKEYBYTES],
125 )?;
126
127 idx += PQCLEAN_FALCONPADDED1024_CLEAN_CRYPTO_PUBLICKEYBYTES
128 + PQCLEAN_FALCONPADDED1024_CLEAN_CRYPTO_SECRETKEYBYTES;
129
130 let ss_bytes = &bytes[idx..idx + PQCLEAN_MLKEM1024_CLEAN_CRYPTO_BYTES];
132 let shared_secret = b2ss(parse_ss(ss_bytes)?);
133 idx += PQCLEAN_MLKEM1024_CLEAN_CRYPTO_BYTES;
134
135 let target_verifier = VerifierPair::from_bytes(
137 &bytes[idx..idx + PQCLEAN_FALCONPADDED1024_CLEAN_CRYPTO_PUBLICKEYBYTES],
138 )?;
139 idx += PQCLEAN_FALCONPADDED1024_CLEAN_CRYPTO_PUBLICKEYBYTES;
140
141 let current_nonce = bytes[idx..idx + 24].try_into().unwrap();
143 idx += 24;
144
145 if idx != bytes.len() {
147 return Err(CryptoError::IncongruentLength(bytes.len(), idx));
148 }
149
150 Ok(Self {
151 kem_pair,
152 ds_pair,
153 shared_secret,
154 target_verifier,
155 current_nonce,
156 })
157 }
158
159 pub fn new_initiator(
172 my_keypair: KEMPair, my_signer: SignerPair, base_nonce: [u8; 24], target_pubkey: &[u8; PQCLEAN_MLKEM1024_CLEAN_CRYPTO_PUBLICKEYBYTES], target_verifier: &[u8; PQCLEAN_FALCONPADDED1024_CLEAN_CRYPTO_PUBLICKEYBYTES], ) -> Result<(Self, [u8; PQCLEAN_MLKEM1024_CLEAN_CRYPTO_CIPHERTEXTBYTES]), CryptoError> {
178 let pubkey = mlkem1024::PublicKey::from_bytes(target_pubkey)?;
179
180 let (shared_secret, ciphertext) = my_keypair.encapsulate(&pubkey);
182
183 let target_verifier = VerifierPair::from_bytes(target_verifier)?;
186
187 Ok((
189 Self {
190 kem_pair: my_keypair,
191 ds_pair: my_signer,
192 shared_secret,
193 target_verifier,
194 current_nonce: base_nonce,
195 },
196 ct2b(&ciphertext)?,
197 ))
198 }
199
200 pub fn new_responder(
212 my_keypair: KEMPair, my_signer: SignerPair, base_nonce: [u8; 24], ciphertext_bytes: &[u8; PQCLEAN_MLKEM1024_CLEAN_CRYPTO_CIPHERTEXTBYTES], sender_verifier: &[u8; PQCLEAN_FALCONPADDED1024_CLEAN_CRYPTO_PUBLICKEYBYTES], ) -> Result<Self, CryptoError> {
218 let ciphertext = Ciphertext::from_bytes(ciphertext_bytes)?;
222 let shared_secret = my_keypair.decapsulate(&ciphertext)?;
223
224 let target_verifier = VerifierPair::from_bytes(sender_verifier)?;
226
227 Ok(Self {
228 kem_pair: my_keypair,
229 ds_pair: my_signer,
230 shared_secret,
231 target_verifier,
232 current_nonce: base_nonce,
233 })
234 }
235
236 pub fn craft_message(&mut self, message: &[u8]) -> Result<Vec<u8>, CryptoError> {
248 let sig = self.ds_pair.sign(message);
250
251 self.increment_nonce();
253
254 encryptor::Encryptor::new(self.shared_secret).encrypt(&sig.as_bytes(), &self.current_nonce)
256 }
257
258 pub fn validate_message(&mut self, ciphertext: &[u8]) -> Result<Vec<u8>, CryptoError> {
270 self.increment_nonce();
272
273 let decrypted_message = encryptor::Encryptor::new(self.shared_secret)
275 .decrypt(ciphertext, &self.current_nonce)?;
276
277 if decrypted_message.len() < PQCLEAN_FALCONPADDED1024_CLEAN_CRYPTO_BYTES {
279 return Err(CryptoError::FalconSignatureTooShort(
280 decrypted_message.len(),
281 ));
282 }
283
284 let sm = falconpadded1024::SignedMessage::from_bytes(&decrypted_message)?;
286 let msg = self.target_verifier.verify_message(&sm)?;
287
288 Ok(msg)
290 }
291
292 fn increment_nonce(&mut self) {
299 let mut counter = u64::from_le_bytes(self.current_nonce[16..24].try_into().unwrap());
300
301 if counter >= MAX_NONCE_COUNTER {
303 counter = 0;
306 } else {
307 counter += 1;
308 }
309
310 self.current_nonce[16..24].copy_from_slice(&counter.to_le_bytes());
311 }
312
313 pub fn get_counter(&self) -> u64 {
318 u64::from_le_bytes(self.current_nonce[16..24].try_into().unwrap())
319 }
320}
321
322fn ct2b(
331 ct: &mlkem1024::Ciphertext,
332) -> Result<[u8; PQCLEAN_MLKEM1024_CLEAN_CRYPTO_CIPHERTEXTBYTES], CryptoError> {
333 let slice = ct.as_bytes();
334
335 if slice.len() == PQCLEAN_MLKEM1024_CLEAN_CRYPTO_CIPHERTEXTBYTES {
336 let ptr = slice.as_ptr() as *const [u8; PQCLEAN_MLKEM1024_CLEAN_CRYPTO_CIPHERTEXTBYTES];
337 unsafe { Ok(*ptr) }
338 } else {
339 Err(CryptoError::IncongruentLength(
340 PQCLEAN_MLKEM1024_CLEAN_CRYPTO_CIPHERTEXTBYTES,
341 slice.len(),
342 ))
343 }
344}
345
346pub fn parse_ss<T>(slice: &[T]) -> Result<&[T; PQCLEAN_MLKEM1024_CLEAN_CRYPTO_BYTES], CryptoError> {
355 if slice.len() == PQCLEAN_MLKEM1024_CLEAN_CRYPTO_BYTES {
356 let ptr = slice.as_ptr() as *const [T; PQCLEAN_MLKEM1024_CLEAN_CRYPTO_BYTES];
357 unsafe { Ok(&*ptr) }
358 } else {
359 Err(CryptoError::IncongruentLength(
360 PQCLEAN_MLKEM1024_CLEAN_CRYPTO_BYTES,
361 slice.len(),
362 ))
363 }
364}
365
366pub fn gen_session_id() -> [u8; 16] {
372 let mut session_id = [0u8; 16];
373 rand::rng().fill_bytes(&mut session_id);
374
375 session_id
376}
377
/// Builds a 24-byte nonce from a 16-byte session id and a counter.
///
/// Layout: bytes `0..16` are `session_id`; bytes `16..24` hold `counter`
/// encoded little-endian.
pub fn create_nonce(session_id: &[u8; 16], counter: u64) -> [u8; 24] {
    let mut out = [0u8; 24];
    let (id_part, counter_part) = out.split_at_mut(session_id.len());
    id_part.copy_from_slice(session_id);
    counter_part.copy_from_slice(&counter.to_le_bytes());
    out
}
393
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a standalone initiator session (against throwaway peer keys)
    /// whose nonce counter starts at `counter`.
    fn initiator_with_counter(counter: u64) -> MessageSession {
        let kem_pair = pair::KEMPair::create();
        let ds_pair = SignerPair::create();
        let target_kem_pair = pair::KEMPair::create();
        let target_ds_pair = SignerPair::create();

        let base_nonce = create_nonce(&gen_session_id(), counter);

        let (session, _) = MessageSession::new_initiator(
            kem_pair,
            ds_pair,
            base_nonce,
            &target_kem_pair.to_bytes().unwrap().0,
            &target_ds_pair.to_bytes().unwrap().0,
        )
        .unwrap();

        session
    }

    /// Runs the initiator/responder handshake between fresh Alice and Bob
    /// identities and returns `(alice, bob)`, both starting at counter 0.
    fn alice_and_bob() -> (MessageSession, MessageSession) {
        let alice_kem = pair::KEMPair::create();
        let alice_ds = SignerPair::create();
        let bob_kem = pair::KEMPair::create();
        let bob_ds = SignerPair::create();

        let base_nonce = create_nonce(&gen_session_id(), 0);

        let (alice, ciphertext) = MessageSession::new_initiator(
            alice_kem,
            alice_ds.clone(),
            base_nonce,
            &bob_kem.to_bytes().unwrap().0,
            &bob_ds.to_bytes().unwrap().0,
        )
        .unwrap();

        let bob = MessageSession::new_responder(
            bob_kem,
            bob_ds,
            base_nonce,
            &ciphertext,
            &alice_ds.to_bytes().unwrap().0,
        )
        .unwrap();

        (alice, bob)
    }

    #[test]
    fn test_message_session_serialization() {
        let session = initiator_with_counter(0);

        let serialized = session.to_bytes().unwrap();
        let deserialized = MessageSession::from_bytes(&serialized).unwrap();

        assert_eq!(session.current_nonce, deserialized.current_nonce);
        assert_eq!(
            ss2b(&session.shared_secret),
            ss2b(&deserialized.shared_secret)
        );
        // Every component must survive the round trip, not just the nonce.
        assert_eq!(serialized, deserialized.to_bytes().unwrap());
    }

    #[test]
    fn test_restored_session_keeps_exchanging() {
        let (mut alice, bob) = alice_and_bob();
        let mut bob = MessageSession::from_bytes(&bob.to_bytes().unwrap()).unwrap();

        let message = b"sent after Bob restored his session";
        let encrypted = alice.craft_message(message).unwrap();

        assert_eq!(bob.validate_message(&encrypted).unwrap(), message);
    }

    #[test]
    fn test_from_bytes_rejects_wrong_length() {
        let serialized = initiator_with_counter(0).to_bytes().unwrap();

        assert!(MessageSession::from_bytes(&serialized[1..]).is_err());
        assert!(MessageSession::from_bytes(&[]).is_err());

        let mut extended = serialized;
        extended.push(0);
        assert!(MessageSession::from_bytes(&extended).is_err());
    }

    #[test]
    fn test_full_message_exchange() {
        let (mut alice_session, mut bob_session) = alice_and_bob();

        assert_eq!(
            ss2b(&alice_session.shared_secret),
            ss2b(&bob_session.shared_secret)
        );

        let message = b"Hello, Bob! This is a secret message.";
        let encrypted_message = alice_session.craft_message(message).unwrap();

        assert_eq!(alice_session.get_counter(), 1);
        assert_eq!(bob_session.get_counter(), 0);

        let raw_message = bob_session.validate_message(&encrypted_message).unwrap();

        assert_eq!(bob_session.get_counter(), 1);
        assert_eq!(raw_message, message);

        let reply = b"Hello, Alice! I received your message safely.";
        let encrypted_reply = bob_session.craft_message(reply).unwrap();
        let raw_reply = alice_session.validate_message(&encrypted_reply).unwrap();

        assert_eq!(alice_session.current_nonce, bob_session.current_nonce);
        assert_eq!(raw_reply, reply);
    }

    #[test]
    fn test_nonce_desync() {
        let (mut alice_session, mut bob_session) = alice_and_bob();

        assert_eq!(
            ss2b(&alice_session.shared_secret),
            ss2b(&bob_session.shared_secret)
        );

        let encrypted_message = alice_session
            .craft_message(b"Hello, Bob! This is a secret message.")
            .unwrap();

        assert_eq!(alice_session.get_counter(), 1);
        assert_eq!(bob_session.get_counter(), 0);

        // Push Bob one step ahead so he expects counter 2 for Alice's counter-1 message.
        bob_session.increment_nonce();
        assert_eq!(bob_session.get_counter(), 1);

        assert!(bob_session.validate_message(&encrypted_message).is_err());
    }

    #[test]
    fn test_nonce_increment_and_counter() {
        let initial_counter = 42;
        let mut session = initiator_with_counter(initial_counter);

        assert_eq!(session.get_counter(), initial_counter);

        let before = session.current_nonce;
        session.increment_nonce();

        assert_eq!(session.get_counter(), initial_counter + 1);
        // The session-id prefix must never change.
        assert_eq!(session.current_nonce[..16], before[..16]);
    }

    #[test]
    fn test_counter_wraparound() {
        let mut session = initiator_with_counter(MAX_NONCE_COUNTER);

        assert_eq!(session.get_counter(), MAX_NONCE_COUNTER);

        session.increment_nonce();
        assert_eq!(session.get_counter(), 0);
    }

    #[test]
    fn test_shared_secret_consistency() {
        let alice_kem = pair::KEMPair::create();
        let bob_kem = pair::KEMPair::create();

        let pubkey = mlkem1024::PublicKey::from_bytes(&bob_kem.to_bytes().unwrap().0).unwrap();
        let (alice_ss, ciphertext) = alice_kem.encapsulate(&pubkey);
        let ciphertext_bytes = ct2b(&ciphertext).unwrap();

        let ciphertext_received = mlkem1024::Ciphertext::from_bytes(&ciphertext_bytes).unwrap();
        let bob_ss = bob_kem.decapsulate(&ciphertext_received).unwrap();

        assert_eq!(ss2b(&alice_ss), ss2b(&bob_ss));
    }
}