nomad_protocol/crypto/
session.rs

1//! Crypto session management with anti-replay protection
2//!
3//! This module combines all cryptographic primitives into a high-level
4//! CryptoSession that handles:
5//! - Sending and receiving encrypted frames
6//! - Nonce management
7//! - Anti-replay protection via sliding window
8//! - Epoch/counter tracking
9
10use crate::core::{CryptoError, HASH_SIZE, REPLAY_WINDOW_SIZE};
11
12use super::{
13    aead::{construct_aad, decrypt, encrypt, SessionKey},
14    nonce::{construct_nonce, Direction},
15    rekey::{OldKeyRetention, RekeyState},
16    Role, SessionId,
17};
18
19/// Anti-replay sliding window.
20///
21/// Per 1-SECURITY.md:
22/// - Window size: 2048 bits minimum
23/// - Below window: MUST reject
24/// - Seen nonce: MUST reject
25/// - Above highest: Update window
26pub struct ReplayWindow {
27    /// Bitmap for tracking seen nonces
28    bitmap: [u64; REPLAY_WINDOW_SIZE / 64],
29    /// Highest nonce seen so far
30    highest: u64,
31    /// Whether we've seen any packets yet
32    initialized: bool,
33}
34
35impl ReplayWindow {
36    /// Create a new replay window.
37    pub fn new() -> Self {
38        Self {
39            bitmap: [0; REPLAY_WINDOW_SIZE / 64],
40            highest: 0,
41            initialized: false,
42        }
43    }
44
45    /// Check if a nonce is a replay (without updating).
46    pub fn is_replay(&self, nonce: u64) -> bool {
47        if !self.initialized {
48            return false;
49        }
50
51        if nonce > self.highest {
52            return false;
53        }
54
55        let diff = self.highest - nonce;
56        if diff >= REPLAY_WINDOW_SIZE as u64 {
57            return true; // Below window
58        }
59
60        let bit_index = diff as usize;
61        let word_index = bit_index / 64;
62        let bit_offset = bit_index % 64;
63        (self.bitmap[word_index] & (1 << bit_offset)) != 0
64    }
65
66    /// Check if a nonce is a replay and update the window.
67    ///
68    /// Returns Ok(()) if the nonce is valid (not seen before).
69    /// Returns Err(ReplayDetected) if the nonce is a replay.
70    ///
71    /// Per 1-SECURITY.md, replay check MUST occur BEFORE AEAD verification.
72    pub fn check_and_update(&mut self, nonce: u64) -> Result<(), CryptoError> {
73        if !self.initialized {
74            // First packet - initialize
75            self.highest = nonce;
76            self.mark_seen(nonce);
77            self.initialized = true;
78            return Ok(());
79        }
80
81        if nonce > self.highest {
82            // Advance the window
83            let shift = nonce - self.highest;
84            self.shift_window(shift);
85            self.highest = nonce;
86            self.mark_seen(nonce);
87            Ok(())
88        } else {
89            let diff = self.highest - nonce;
90            if diff >= REPLAY_WINDOW_SIZE as u64 {
91                // Too old - below window
92                return Err(CryptoError::ReplayDetected);
93            }
94
95            // Check if already seen
96            if self.is_seen(nonce) {
97                return Err(CryptoError::ReplayDetected);
98            }
99
100            // Mark as seen
101            self.mark_seen(nonce);
102            Ok(())
103        }
104    }
105
106    /// Check if a nonce has been seen (internal helper).
107    fn is_seen(&self, nonce: u64) -> bool {
108        if nonce > self.highest {
109            return false;
110        }
111        let diff = self.highest - nonce;
112        if diff >= REPLAY_WINDOW_SIZE as u64 {
113            return true; // Treat below-window as "seen" (rejected)
114        }
115        let bit_index = diff as usize;
116        let word_index = bit_index / 64;
117        let bit_offset = bit_index % 64;
118        (self.bitmap[word_index] & (1 << bit_offset)) != 0
119    }
120
121    /// Mark a nonce as seen.
122    fn mark_seen(&mut self, nonce: u64) {
123        if nonce > self.highest {
124            return; // Will be marked after shift
125        }
126        let diff = self.highest - nonce;
127        if diff >= REPLAY_WINDOW_SIZE as u64 {
128            return; // Too old
129        }
130        let bit_index = diff as usize;
131        let word_index = bit_index / 64;
132        let bit_offset = bit_index % 64;
133        self.bitmap[word_index] |= 1 << bit_offset;
134    }
135
136    /// Shift the window forward.
137    ///
138    /// When we receive a new highest nonce, we need to shift all existing bits
139    /// to make room for the new highest at position 0.
140    /// Bit position represents (highest - nonce), so older nonces have higher bit positions.
141    fn shift_window(&mut self, shift: u64) {
142        if shift >= REPLAY_WINDOW_SIZE as u64 {
143            // Complete reset - all previous nonces fall outside the window
144            self.bitmap = [0; REPLAY_WINDOW_SIZE / 64];
145            return;
146        }
147
148        let shift_words = (shift / 64) as usize;
149        let shift_bits = (shift % 64) as u32;
150
151        // Shift whole words (towards higher indices = older nonces)
152        if shift_words > 0 {
153            // Shift from high to low to avoid overwriting
154            for i in (shift_words..self.bitmap.len()).rev() {
155                self.bitmap[i] = self.bitmap[i - shift_words];
156            }
157            // Clear the newly freed low words
158            for word in self.bitmap.iter_mut().take(shift_words) {
159                *word = 0;
160            }
161        }
162
163        // Shift remaining bits within words (towards higher bit positions)
164        if shift_bits > 0 {
165            let mut carry = 0u64;
166            // Process from highest index to lowest (shift bits up within each word)
167            for i in (0..self.bitmap.len()).rev() {
168                let new_carry = self.bitmap[i] >> (64 - shift_bits);
169                self.bitmap[i] = (self.bitmap[i] << shift_bits) | carry;
170                carry = new_carry;
171            }
172        }
173    }
174
175    /// Reset the window (e.g., after rekey).
176    pub fn reset(&mut self) {
177        self.bitmap = [0; REPLAY_WINDOW_SIZE / 64];
178        self.highest = 0;
179        self.initialized = false;
180    }
181}
182
183impl Default for ReplayWindow {
184    fn default() -> Self {
185        Self::new()
186    }
187}
188
189/// A complete crypto session for secure communication.
190///
191/// Combines key management, nonce construction, AEAD, and anti-replay
192/// into a single interface.
193pub struct CryptoSession {
194    /// Session ID
195    session_id: SessionId,
196    /// Our role (initiator or responder)
197    role: Role,
198    /// Current send key
199    send_key: SessionKey,
200    /// Current receive key
201    recv_key: SessionKey,
202    /// Rekey state (epoch, counters)
203    rekey_state: RekeyState,
204    /// Replay window for incoming packets
205    replay_window: ReplayWindow,
206    /// Old key retention for late packets during rekey
207    old_keys: OldKeyRetention,
208    /// Handshake hash for key derivation (kept for session resumption)
209    #[allow(dead_code)]
210    handshake_hash: [u8; HASH_SIZE],
211    /// Rekey authentication key for PCS (derived from static DH).
212    /// This key is mixed into rekey KDF to ensure post-compromise security.
213    /// Without this key, an attacker who compromises session keys cannot
214    /// derive future rekey keys.
215    rekey_auth_key: [u8; HASH_SIZE],
216}
217
218impl CryptoSession {
219    /// Create a new crypto session after handshake completion.
220    ///
221    /// # Arguments
222    /// * `session_id` - Unique session identifier
223    /// * `role` - Our role (initiator or responder)
224    /// * `send_key` - Initial send key
225    /// * `recv_key` - Initial receive key
226    /// * `handshake_hash` - Hash of the handshake transcript
227    /// * `rekey_auth_key` - Key derived from static DH for PCS during rekey
228    pub fn new(
229        session_id: SessionId,
230        role: Role,
231        send_key: SessionKey,
232        recv_key: SessionKey,
233        handshake_hash: [u8; HASH_SIZE],
234        rekey_auth_key: [u8; HASH_SIZE],
235    ) -> Self {
236        Self {
237            session_id,
238            role,
239            send_key,
240            recv_key,
241            rekey_state: RekeyState::new(),
242            replay_window: ReplayWindow::new(),
243            old_keys: OldKeyRetention::new(),
244            handshake_hash,
245            rekey_auth_key,
246        }
247    }
248
249    /// Get the session ID.
250    pub fn session_id(&self) -> &SessionId {
251        &self.session_id
252    }
253
254    /// Get the current role.
255    pub fn role(&self) -> Role {
256        self.role
257    }
258
259    /// Get the current epoch.
260    pub fn epoch(&self) -> u32 {
261        self.rekey_state.epoch()
262    }
263
264    /// Check if we should initiate a rekey.
265    pub fn should_rekey(&self) -> bool {
266        self.rekey_state.should_rekey()
267    }
268
269    /// Check if keys are expired (session must terminate).
270    pub fn keys_expired(&self) -> bool {
271        self.rekey_state.keys_expired()
272    }
273
274    /// Get the direction for sending based on our role.
275    fn send_direction(&self) -> Direction {
276        match self.role {
277            Role::Initiator => Direction::InitiatorToResponder,
278            Role::Responder => Direction::ResponderToInitiator,
279        }
280    }
281
282    /// Get the direction for receiving based on our role.
283    fn recv_direction(&self) -> Direction {
284        self.send_direction().opposite()
285    }
286
287    /// Encrypt a frame for sending.
288    ///
289    /// Returns (nonce_counter, ciphertext).
290    pub fn encrypt_frame(
291        &mut self,
292        frame_type: u8,
293        flags: u8,
294        plaintext: &[u8],
295    ) -> Result<(u64, Vec<u8>), CryptoError> {
296        // Get counter and construct nonce
297        let counter = self.rekey_state.increment_send()?;
298        let nonce = construct_nonce(self.rekey_state.epoch(), self.send_direction(), counter);
299
300        // Construct AAD
301        let aad = construct_aad(frame_type, flags, self.session_id.as_bytes(), counter);
302
303        // Encrypt
304        let ciphertext = encrypt(&self.send_key, &nonce, &aad, plaintext)?;
305
306        Ok((counter, ciphertext))
307    }
308
309    /// Decrypt a received frame.
310    ///
311    /// Performs replay check BEFORE decryption per spec.
312    pub fn decrypt_frame(
313        &mut self,
314        frame_type: u8,
315        flags: u8,
316        nonce_counter: u64,
317        ciphertext: &[u8],
318    ) -> Result<Vec<u8>, CryptoError> {
319        // 1. Replay check FIRST (cheap, prevents DoS)
320        if self.replay_window.is_replay(nonce_counter) {
321            return Err(CryptoError::ReplayDetected);
322        }
323
324        // Construct nonce and AAD
325        let nonce = construct_nonce(self.rekey_state.epoch(), self.recv_direction(), nonce_counter);
326        let aad = construct_aad(frame_type, flags, self.session_id.as_bytes(), nonce_counter);
327
328        // 2. Try current keys first
329        if let Ok(plaintext) = decrypt(&self.recv_key, &nonce, &aad, ciphertext) {
330            // 3. Update replay window only after successful verification
331            let _ = self.replay_window.check_and_update(nonce_counter);
332            self.rekey_state.record_recv(nonce_counter);
333            return Ok(plaintext);
334        }
335
336        // 4. Try old keys if within retention window
337        self.old_keys.clear_if_expired();
338        if let Some(old_recv_key) = self.get_old_recv_key() {
339            // Try with previous epoch's nonce
340            let old_epoch = self.rekey_state.epoch().saturating_sub(1);
341            let old_nonce = construct_nonce(old_epoch, self.recv_direction(), nonce_counter);
342
343            if let Ok(plaintext) = decrypt(old_recv_key, &old_nonce, &aad, ciphertext) {
344                // Note: Don't update replay window for old epoch packets
345                // (they have their own counter space)
346                return Ok(plaintext);
347            }
348        }
349
350        Err(CryptoError::DecryptionFailed)
351    }
352
353    /// Get the old receive key based on role.
354    fn get_old_recv_key(&self) -> Option<&SessionKey> {
355        match self.role {
356            Role::Initiator => self.old_keys.old_responder_key(),
357            Role::Responder => self.old_keys.old_initiator_key(),
358        }
359    }
360
361    /// Perform a rekey operation with the given ephemeral DH result.
362    ///
363    /// Advances the epoch and derives new keys using PCS-secure derivation.
364    /// The caller is responsible for performing the ephemeral key exchange
365    /// and computing the DH shared secret.
366    ///
367    /// # Arguments
368    /// * `ephemeral_dh` - The result of DH(my_ephemeral, their_ephemeral_public)
369    ///
370    /// # Security
371    /// The rekey_auth_key (derived from static DH during handshake) is mixed
372    /// into the KDF along with ephemeral_dh. This ensures:
373    /// - Forward secrecy from the fresh ephemeral exchange
374    /// - Post-compromise security from the static DH-derived auth key
375    pub fn rekey(&mut self, ephemeral_dh: &[u8; 32]) -> Result<(), CryptoError> {
376        use super::rekey::derive_rekey_keys;
377
378        // Retain current keys
379        self.old_keys
380            .retain(self.send_key.clone(), self.recv_key.clone());
381
382        // Advance epoch
383        self.rekey_state.advance_epoch()?;
384
385        // Derive new keys with PCS protection
386        // IKM = ephemeral_dh || rekey_auth_key
387        // This ensures that an attacker needs BOTH:
388        // 1. To intercept the ephemeral exchange (forward secrecy)
389        // 2. To know the static DH secret (post-compromise security)
390        let (new_initiator_key, new_responder_key) =
391            derive_rekey_keys(ephemeral_dh, &self.rekey_auth_key, self.rekey_state.epoch())?;
392
393        // Update keys based on role
394        match self.role {
395            Role::Initiator => {
396                self.send_key = new_initiator_key;
397                self.recv_key = new_responder_key;
398            }
399            Role::Responder => {
400                self.send_key = new_responder_key;
401                self.recv_key = new_initiator_key;
402            }
403        }
404
405        // Reset replay window for new epoch
406        self.replay_window.reset();
407
408        Ok(())
409    }
410}
411
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an initiator/responder pair sharing one session ID, handshake
    /// hash and rekey auth key, with mirrored send/receive keys.
    fn session_pair() -> (CryptoSession, CryptoSession) {
        let session_id = SessionId::generate();
        let initiator_send = SessionKey::from_bytes([0x01; 32]);
        let initiator_recv = SessionKey::from_bytes([0x02; 32]);
        let handshake_hash = [0x42; 32];
        let rekey_auth_key = [0x33; 32]; // PCS rekey authentication key

        let initiator = CryptoSession::new(
            session_id,
            Role::Initiator,
            initiator_send.clone(),
            initiator_recv.clone(),
            handshake_hash,
            rekey_auth_key,
        );
        // The responder's directions are the mirror image of the initiator's.
        let responder = CryptoSession::new(
            session_id,
            Role::Responder,
            initiator_recv,
            initiator_send,
            handshake_hash,
            rekey_auth_key,
        );
        (initiator, responder)
    }

    #[test]
    fn test_replay_window_basic() {
        let mut window = ReplayWindow::new();

        // A fresh nonce is accepted once; an immediate repeat is a replay.
        assert!(window.check_and_update(0).is_ok());
        assert!(window.check_and_update(0).is_err());

        // The next nonce in sequence is accepted.
        assert!(window.check_and_update(1).is_ok());

        // Out-of-order delivery inside the window is accepted.
        for nonce in [5, 3, 4, 2] {
            assert!(window.check_and_update(nonce).is_ok());
        }

        // Anything already accepted is rejected.
        for nonce in [0, 3, 5] {
            assert!(window.check_and_update(nonce).is_err());
        }
    }

    #[test]
    fn test_replay_window_large_gap() {
        let mut window = ReplayWindow::new();
        for nonce in [0, 1] {
            assert!(window.check_and_update(nonce).is_ok());
        }

        // Large forward jump.
        assert!(window.check_and_update(1000).is_ok());

        // 0 and 1 were already accepted, so they stay rejected after the
        // window moves, whether or not they are still tracked in it.
        for nonce in [0, 1] {
            assert!(window.check_and_update(nonce).is_err());
        }

        // Unseen nonces just below the new highest are still accepted.
        for nonce in [999, 998] {
            assert!(window.check_and_update(nonce).is_ok());
        }
    }

    #[test]
    fn test_replay_window_full_reset() {
        let mut window = ReplayWindow::new();
        let early = 0..100u64;

        for nonce in early.clone() {
            assert!(window.check_and_update(nonce).is_ok());
        }

        // Jump far enough that every earlier nonce leaves the window.
        let jump = early.end + REPLAY_WINDOW_SIZE as u64;
        assert!(window.check_and_update(jump).is_ok());

        for nonce in early {
            assert!(window.check_and_update(nonce).is_err());
        }
    }

    #[test]
    fn test_crypto_session_roundtrip() {
        let (mut initiator, mut responder) = session_pair();

        // Initiator -> responder.
        let request = b"Hello, NOMAD!";
        let (counter, sealed) = initiator.encrypt_frame(0x03, 0x00, request).unwrap();
        let opened = responder
            .decrypt_frame(0x03, 0x00, counter, &sealed)
            .unwrap();
        assert_eq!(opened, request);

        // Responder -> initiator.
        let reply = b"Hello back!";
        let (counter, sealed) = responder.encrypt_frame(0x03, 0x00, reply).unwrap();
        let opened = initiator
            .decrypt_frame(0x03, 0x00, counter, &sealed)
            .unwrap();
        assert_eq!(opened, reply);
    }

    #[test]
    fn test_crypto_session_replay_detection() {
        let (mut initiator, mut responder) = session_pair();

        let (counter, sealed) = initiator.encrypt_frame(0x03, 0x00, b"test").unwrap();

        // The first delivery is accepted; the identical second one is not.
        let first = responder.decrypt_frame(0x03, 0x00, counter, &sealed);
        assert!(first.is_ok());
        let second = responder.decrypt_frame(0x03, 0x00, counter, &sealed);
        assert!(second.is_err());
    }

    #[test]
    fn test_crypto_session_wrong_aad() {
        let (mut initiator, mut responder) = session_pair();

        let (counter, sealed) = initiator.encrypt_frame(0x03, 0x00, b"test").unwrap();

        // The frame type is bound into the AAD, so changing it must fail.
        let tampered = responder.decrypt_frame(0x04, 0x00, counter, &sealed);
        assert!(tampered.is_err());
    }
}