1use crate::core::{CryptoError, HASH_SIZE, REPLAY_WINDOW_SIZE};
11
12use super::{
13 aead::{construct_aad, decrypt, encrypt, SessionKey},
14 nonce::{construct_nonce, Direction},
15 rekey::{OldKeyRetention, RekeyState},
16 Role, SessionId,
17};
18
19pub struct ReplayWindow {
27 bitmap: [u64; REPLAY_WINDOW_SIZE / 64],
29 highest: u64,
31 initialized: bool,
33}
34
35impl ReplayWindow {
36 pub fn new() -> Self {
38 Self {
39 bitmap: [0; REPLAY_WINDOW_SIZE / 64],
40 highest: 0,
41 initialized: false,
42 }
43 }
44
45 pub fn is_replay(&self, nonce: u64) -> bool {
47 if !self.initialized {
48 return false;
49 }
50
51 if nonce > self.highest {
52 return false;
53 }
54
55 let diff = self.highest - nonce;
56 if diff >= REPLAY_WINDOW_SIZE as u64 {
57 return true; }
59
60 let bit_index = diff as usize;
61 let word_index = bit_index / 64;
62 let bit_offset = bit_index % 64;
63 (self.bitmap[word_index] & (1 << bit_offset)) != 0
64 }
65
66 pub fn check_and_update(&mut self, nonce: u64) -> Result<(), CryptoError> {
73 if !self.initialized {
74 self.highest = nonce;
76 self.mark_seen(nonce);
77 self.initialized = true;
78 return Ok(());
79 }
80
81 if nonce > self.highest {
82 let shift = nonce - self.highest;
84 self.shift_window(shift);
85 self.highest = nonce;
86 self.mark_seen(nonce);
87 Ok(())
88 } else {
89 let diff = self.highest - nonce;
90 if diff >= REPLAY_WINDOW_SIZE as u64 {
91 return Err(CryptoError::ReplayDetected);
93 }
94
95 if self.is_seen(nonce) {
97 return Err(CryptoError::ReplayDetected);
98 }
99
100 self.mark_seen(nonce);
102 Ok(())
103 }
104 }
105
106 fn is_seen(&self, nonce: u64) -> bool {
108 if nonce > self.highest {
109 return false;
110 }
111 let diff = self.highest - nonce;
112 if diff >= REPLAY_WINDOW_SIZE as u64 {
113 return true; }
115 let bit_index = diff as usize;
116 let word_index = bit_index / 64;
117 let bit_offset = bit_index % 64;
118 (self.bitmap[word_index] & (1 << bit_offset)) != 0
119 }
120
121 fn mark_seen(&mut self, nonce: u64) {
123 if nonce > self.highest {
124 return; }
126 let diff = self.highest - nonce;
127 if diff >= REPLAY_WINDOW_SIZE as u64 {
128 return; }
130 let bit_index = diff as usize;
131 let word_index = bit_index / 64;
132 let bit_offset = bit_index % 64;
133 self.bitmap[word_index] |= 1 << bit_offset;
134 }
135
136 fn shift_window(&mut self, shift: u64) {
142 if shift >= REPLAY_WINDOW_SIZE as u64 {
143 self.bitmap = [0; REPLAY_WINDOW_SIZE / 64];
145 return;
146 }
147
148 let shift_words = (shift / 64) as usize;
149 let shift_bits = (shift % 64) as u32;
150
151 if shift_words > 0 {
153 for i in (shift_words..self.bitmap.len()).rev() {
155 self.bitmap[i] = self.bitmap[i - shift_words];
156 }
157 for word in self.bitmap.iter_mut().take(shift_words) {
159 *word = 0;
160 }
161 }
162
163 if shift_bits > 0 {
165 let mut carry = 0u64;
166 for i in (0..self.bitmap.len()).rev() {
168 let new_carry = self.bitmap[i] >> (64 - shift_bits);
169 self.bitmap[i] = (self.bitmap[i] << shift_bits) | carry;
170 carry = new_carry;
171 }
172 }
173 }
174
175 pub fn reset(&mut self) {
177 self.bitmap = [0; REPLAY_WINDOW_SIZE / 64];
178 self.highest = 0;
179 self.initialized = false;
180 }
181}
182
183impl Default for ReplayWindow {
184 fn default() -> Self {
185 Self::new()
186 }
187}
188
189pub struct CryptoSession {
194 session_id: SessionId,
196 role: Role,
198 send_key: SessionKey,
200 recv_key: SessionKey,
202 rekey_state: RekeyState,
204 replay_window: ReplayWindow,
206 old_keys: OldKeyRetention,
208 #[allow(dead_code)]
210 handshake_hash: [u8; HASH_SIZE],
211 rekey_auth_key: [u8; HASH_SIZE],
216}
217
218impl CryptoSession {
219 pub fn new(
229 session_id: SessionId,
230 role: Role,
231 send_key: SessionKey,
232 recv_key: SessionKey,
233 handshake_hash: [u8; HASH_SIZE],
234 rekey_auth_key: [u8; HASH_SIZE],
235 ) -> Self {
236 Self {
237 session_id,
238 role,
239 send_key,
240 recv_key,
241 rekey_state: RekeyState::new(),
242 replay_window: ReplayWindow::new(),
243 old_keys: OldKeyRetention::new(),
244 handshake_hash,
245 rekey_auth_key,
246 }
247 }
248
249 pub fn session_id(&self) -> &SessionId {
251 &self.session_id
252 }
253
254 pub fn role(&self) -> Role {
256 self.role
257 }
258
259 pub fn epoch(&self) -> u32 {
261 self.rekey_state.epoch()
262 }
263
264 pub fn should_rekey(&self) -> bool {
266 self.rekey_state.should_rekey()
267 }
268
269 pub fn keys_expired(&self) -> bool {
271 self.rekey_state.keys_expired()
272 }
273
274 fn send_direction(&self) -> Direction {
276 match self.role {
277 Role::Initiator => Direction::InitiatorToResponder,
278 Role::Responder => Direction::ResponderToInitiator,
279 }
280 }
281
282 fn recv_direction(&self) -> Direction {
284 self.send_direction().opposite()
285 }
286
287 pub fn encrypt_frame(
291 &mut self,
292 frame_type: u8,
293 flags: u8,
294 plaintext: &[u8],
295 ) -> Result<(u64, Vec<u8>), CryptoError> {
296 let counter = self.rekey_state.increment_send()?;
298 let nonce = construct_nonce(self.rekey_state.epoch(), self.send_direction(), counter);
299
300 let aad = construct_aad(frame_type, flags, self.session_id.as_bytes(), counter);
302
303 let ciphertext = encrypt(&self.send_key, &nonce, &aad, plaintext)?;
305
306 Ok((counter, ciphertext))
307 }
308
309 pub fn decrypt_frame(
313 &mut self,
314 frame_type: u8,
315 flags: u8,
316 nonce_counter: u64,
317 ciphertext: &[u8],
318 ) -> Result<Vec<u8>, CryptoError> {
319 if self.replay_window.is_replay(nonce_counter) {
321 return Err(CryptoError::ReplayDetected);
322 }
323
324 let nonce = construct_nonce(self.rekey_state.epoch(), self.recv_direction(), nonce_counter);
326 let aad = construct_aad(frame_type, flags, self.session_id.as_bytes(), nonce_counter);
327
328 if let Ok(plaintext) = decrypt(&self.recv_key, &nonce, &aad, ciphertext) {
330 let _ = self.replay_window.check_and_update(nonce_counter);
332 self.rekey_state.record_recv(nonce_counter);
333 return Ok(plaintext);
334 }
335
336 self.old_keys.clear_if_expired();
338 if let Some(old_recv_key) = self.get_old_recv_key() {
339 let old_epoch = self.rekey_state.epoch().saturating_sub(1);
341 let old_nonce = construct_nonce(old_epoch, self.recv_direction(), nonce_counter);
342
343 if let Ok(plaintext) = decrypt(old_recv_key, &old_nonce, &aad, ciphertext) {
344 return Ok(plaintext);
347 }
348 }
349
350 Err(CryptoError::DecryptionFailed)
351 }
352
353 fn get_old_recv_key(&self) -> Option<&SessionKey> {
355 match self.role {
356 Role::Initiator => self.old_keys.old_responder_key(),
357 Role::Responder => self.old_keys.old_initiator_key(),
358 }
359 }
360
361 pub fn rekey(&mut self, ephemeral_dh: &[u8; 32]) -> Result<(), CryptoError> {
376 use super::rekey::derive_rekey_keys;
377
378 self.old_keys
380 .retain(self.send_key.clone(), self.recv_key.clone());
381
382 self.rekey_state.advance_epoch()?;
384
385 let (new_initiator_key, new_responder_key) =
391 derive_rekey_keys(ephemeral_dh, &self.rekey_auth_key, self.rekey_state.epoch())?;
392
393 match self.role {
395 Role::Initiator => {
396 self.send_key = new_initiator_key;
397 self.recv_key = new_responder_key;
398 }
399 Role::Responder => {
400 self.send_key = new_responder_key;
401 self.recv_key = new_initiator_key;
402 }
403 }
404
405 self.replay_window.reset();
407
408 Ok(())
409 }
410}
411
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an initiator/responder pair that share a session id and hold
    /// mirrored send/receive keys.
    fn session_pair() -> (CryptoSession, CryptoSession) {
        let session_id = SessionId::generate();
        let initiator_key = SessionKey::from_bytes([0x01; 32]);
        let responder_key = SessionKey::from_bytes([0x02; 32]);
        let handshake_hash = [0x42; 32];
        let rekey_auth_key = [0x33; 32];

        let initiator = CryptoSession::new(
            session_id,
            Role::Initiator,
            initiator_key.clone(),
            responder_key.clone(),
            handshake_hash,
            rekey_auth_key,
        );
        let responder = CryptoSession::new(
            session_id,
            Role::Responder,
            responder_key,
            initiator_key,
            handshake_hash,
            rekey_auth_key,
        );

        (initiator, responder)
    }

    #[test]
    fn test_replay_window_basic() {
        let mut window = ReplayWindow::new();

        // First sighting is accepted, an immediate repeat is not.
        assert!(window.check_and_update(0).is_ok());
        assert!(window.check_and_update(0).is_err());

        assert!(window.check_and_update(1).is_ok());

        // Jump ahead, then fill the gap out of order.
        for &nonce in &[5u64, 3, 4, 2] {
            assert!(window.check_and_update(nonce).is_ok());
        }

        // Everything recorded above is now a replay.
        for &nonce in &[0u64, 3, 5] {
            assert!(window.check_and_update(nonce).is_err());
        }
    }

    #[test]
    fn test_replay_window_large_gap() {
        let mut window = ReplayWindow::new();

        for &nonce in &[0u64, 1] {
            assert!(window.check_and_update(nonce).is_ok());
        }

        assert!(window.check_and_update(1000).is_ok());

        // The early counters stay rejected after the jump.
        for &nonce in &[0u64, 1] {
            assert!(window.check_and_update(nonce).is_err());
        }

        // Unseen counters just below the new maximum are still welcome.
        for &nonce in &[999u64, 998] {
            assert!(window.check_and_update(nonce).is_ok());
        }
    }

    #[test]
    fn test_replay_window_full_reset() {
        let mut window = ReplayWindow::new();

        for nonce in 0..100 {
            assert!(window.check_and_update(nonce).is_ok());
        }

        // Advance far enough that all earlier counters fall out of the window.
        let far_ahead = 100 + REPLAY_WINDOW_SIZE as u64;
        assert!(window.check_and_update(far_ahead).is_ok());

        for nonce in 0..100 {
            assert!(window.check_and_update(nonce).is_err());
        }
    }

    #[test]
    fn test_crypto_session_roundtrip() {
        let (mut initiator, mut responder) = session_pair();

        // Initiator -> responder.
        let plaintext = b"Hello, NOMAD!";
        let (counter, ciphertext) = initiator.encrypt_frame(0x03, 0x00, plaintext).unwrap();
        let decrypted = responder
            .decrypt_frame(0x03, 0x00, counter, &ciphertext)
            .unwrap();
        assert_eq!(decrypted, plaintext);

        // Responder -> initiator.
        let reply = b"Hello back!";
        let (reply_counter, reply_ciphertext) =
            responder.encrypt_frame(0x03, 0x00, reply).unwrap();
        let decrypted_reply = initiator
            .decrypt_frame(0x03, 0x00, reply_counter, &reply_ciphertext)
            .unwrap();
        assert_eq!(decrypted_reply, reply);
    }

    #[test]
    fn test_crypto_session_replay_detection() {
        let (mut initiator, mut responder) = session_pair();

        let plaintext = b"test";
        let (counter, ciphertext) = initiator.encrypt_frame(0x03, 0x00, plaintext).unwrap();

        // The first delivery succeeds; delivering the same frame again fails.
        let first = responder.decrypt_frame(0x03, 0x00, counter, &ciphertext);
        assert!(first.is_ok());

        let second = responder.decrypt_frame(0x03, 0x00, counter, &ciphertext);
        assert!(second.is_err());
    }

    #[test]
    fn test_crypto_session_wrong_aad() {
        let (mut initiator, mut responder) = session_pair();

        let plaintext = b"test";
        let (counter, ciphertext) = initiator.encrypt_frame(0x03, 0x00, plaintext).unwrap();

        // A different frame type changes the AAD, so authentication must fail.
        let tampered = responder.decrypt_frame(0x04, 0x00, counter, &ciphertext);
        assert!(tampered.is_err());
    }
}