nomad_protocol/crypto/
rekey.rs

1//! Session rekeying for forward secrecy
2//!
3//! Per 1-SECURITY.md, sessions MUST rekey periodically:
4//! - `REKEY_AFTER_TIME` (120s): Initiate rekey after this time
5//! - `REKEY_AFTER_MESSAGES` (2^60): Initiate rekey after this many frames
6//! - `REJECT_AFTER_TIME` (180s): Hard limit, reject old keys
7//! - `REJECT_AFTER_MESSAGES` (2^64-1): MUST terminate session
8//! - `OLD_KEY_RETENTION` (5s): Keep old keys for late packets
9
10use std::time::Instant;
11
12use hkdf::Hkdf;
13use sha2::Sha256;
14use crate::core::{
15    CryptoError, MAX_EPOCH, OLD_KEY_RETENTION, REJECT_AFTER_MESSAGES, REJECT_AFTER_TIME,
16    REKEY_AFTER_MESSAGES, REKEY_AFTER_TIME,
17};
18use zeroize::Zeroize;
19
20use super::{SessionKey, SESSION_KEY_SIZE};
21
22/// Tracks the current key epoch and when rekeying is needed.
23#[derive(Debug)]
24pub struct RekeyState {
25    /// Current epoch number (increments on each rekey)
26    epoch: u32,
27    /// Time when current epoch started
28    epoch_start: Instant,
29    /// Number of messages sent in current epoch
30    send_count: u64,
31    /// Number of messages received in current epoch
32    recv_count: u64,
33}
34
35impl RekeyState {
36    /// Create a new rekey state starting at epoch 0.
37    pub fn new() -> Self {
38        Self {
39            epoch: 0,
40            epoch_start: Instant::now(),
41            send_count: 0,
42            recv_count: 0,
43        }
44    }
45
46    /// Get the current epoch.
47    pub fn epoch(&self) -> u32 {
48        self.epoch
49    }
50
51    /// Get the send counter for the current epoch.
52    pub fn send_count(&self) -> u64 {
53        self.send_count
54    }
55
56    /// Get the receive counter for the current epoch.
57    pub fn recv_count(&self) -> u64 {
58        self.recv_count
59    }
60
61    /// Increment the send counter.
62    ///
63    /// Returns the counter value to use for this message.
64    ///
65    /// # Errors
66    /// Returns `CounterExhaustion` if the counter has reached the hard limit.
67    pub fn increment_send(&mut self) -> Result<u64, CryptoError> {
68        if self.send_count == REJECT_AFTER_MESSAGES {
69            return Err(CryptoError::CounterExhaustion);
70        }
71        let counter = self.send_count;
72        self.send_count += 1;
73        Ok(counter)
74    }
75
76    /// Record a received message counter.
77    ///
78    /// Note: Actual replay detection is handled by the replay window.
79    pub fn record_recv(&mut self, counter: u64) {
80        if counter >= self.recv_count {
81            self.recv_count = counter + 1;
82        }
83    }
84
85    /// Check if we should initiate a rekey (soft limit reached).
86    pub fn should_rekey(&self) -> bool {
87        let time_exceeded = self.epoch_start.elapsed() >= REKEY_AFTER_TIME;
88        let messages_exceeded = self.send_count >= REKEY_AFTER_MESSAGES;
89        time_exceeded || messages_exceeded
90    }
91
92    /// Check if the current keys are expired (hard limit reached).
93    pub fn keys_expired(&self) -> bool {
94        self.epoch_start.elapsed() >= REJECT_AFTER_TIME
95    }
96
97    /// Check if we can perform another rekey (epoch limit).
98    pub fn can_rekey(&self) -> bool {
99        self.epoch < MAX_EPOCH
100    }
101
102    /// Advance to the next epoch.
103    ///
104    /// Resets counters and updates epoch start time.
105    ///
106    /// # Errors
107    /// Returns `EpochExhaustion` if the epoch counter has reached the limit.
108    pub fn advance_epoch(&mut self) -> Result<(), CryptoError> {
109        if self.epoch == MAX_EPOCH {
110            return Err(CryptoError::EpochExhaustion);
111        }
112        self.epoch += 1;
113        self.epoch_start = Instant::now();
114        self.send_count = 0;
115        self.recv_count = 0;
116        Ok(())
117    }
118}
119
120impl Default for RekeyState {
121    fn default() -> Self {
122        Self::new()
123    }
124}
125
126/// Manages old keys during the transition period after a rekey.
127pub struct OldKeyRetention {
128    /// The old initiator key
129    initiator_key: Option<SessionKey>,
130    /// The old responder key
131    responder_key: Option<SessionKey>,
132    /// When the old keys were retained
133    retained_at: Option<Instant>,
134}
135
136impl OldKeyRetention {
137    /// Create a new retention manager with no old keys.
138    pub fn new() -> Self {
139        Self {
140            initiator_key: None,
141            responder_key: None,
142            retained_at: None,
143        }
144    }
145
146    /// Retain the current keys as old keys.
147    pub fn retain(&mut self, initiator_key: SessionKey, responder_key: SessionKey) {
148        self.initiator_key = Some(initiator_key);
149        self.responder_key = Some(responder_key);
150        self.retained_at = Some(Instant::now());
151    }
152
153    /// Get the old initiator key if still within retention window.
154    pub fn old_initiator_key(&self) -> Option<&SessionKey> {
155        if self.within_retention_window() {
156            self.initiator_key.as_ref()
157        } else {
158            None
159        }
160    }
161
162    /// Get the old responder key if still within retention window.
163    pub fn old_responder_key(&self) -> Option<&SessionKey> {
164        if self.within_retention_window() {
165            self.responder_key.as_ref()
166        } else {
167            None
168        }
169    }
170
171    /// Check if we're within the retention window.
172    pub fn within_retention_window(&self) -> bool {
173        self.retained_at
174            .is_some_and(|t| t.elapsed() < OLD_KEY_RETENTION)
175    }
176
177    /// Clear old keys (call after retention window expires or explicitly).
178    pub fn clear(&mut self) {
179        self.initiator_key = None;
180        self.responder_key = None;
181        self.retained_at = None;
182    }
183
184    /// Check if old keys should be cleared due to expired retention.
185    pub fn should_clear(&self) -> bool {
186        self.retained_at
187            .is_some_and(|t| t.elapsed() >= OLD_KEY_RETENTION)
188    }
189
190    /// Clear old keys if retention has expired.
191    pub fn clear_if_expired(&mut self) {
192        if self.should_clear() {
193            self.clear();
194        }
195    }
196}
197
198impl Default for OldKeyRetention {
199    fn default() -> Self {
200        Self::new()
201    }
202}
203
204/// Derive new session keys after a rekey with PCS protection.
205///
206/// This function provides Post-Compromise Security (PCS) by mixing in the
207/// `rekey_auth_key` which is derived from the static DH during handshake.
208/// An attacker who compromises session keys cannot derive future rekey keys
209/// without knowing the static DH secret.
210///
211/// Per 1-SECURITY.md (PCS fix):
212/// ```text
213/// ikm = ephemeral_dh || rekey_auth_key
214/// (new_initiator_key, new_responder_key) = HKDF-Expand(
215///     ikm,
216///     "nomad v1 rekey" || LE32(epoch),
217///     64
218/// )
219/// ```
220///
221/// # Arguments
222/// * `ephemeral_dh` - The DH result from the rekey ephemeral exchange
223/// * `rekey_auth_key` - Key derived from static DH during initial handshake
224/// * `epoch` - The new epoch number
225pub fn derive_rekey_keys(
226    ephemeral_dh: &[u8; 32],
227    rekey_auth_key: &[u8; 32],
228    epoch: u32,
229) -> Result<(SessionKey, SessionKey), CryptoError> {
230    // Concatenate ephemeral_dh || rekey_auth_key as IKM
231    // This ensures PCS: attacker needs fresh ephemeral DH AND static DH secret
232    let mut ikm = [0u8; 64];
233    ikm[..32].copy_from_slice(ephemeral_dh);
234    ikm[32..].copy_from_slice(rekey_auth_key);
235
236    // Build info: "nomad v1 rekey" || LE32(epoch)
237    let label = b"nomad v1 rekey";
238    let epoch_bytes = epoch.to_le_bytes();
239    let mut info = Vec::with_capacity(label.len() + 4);
240    info.extend_from_slice(label);
241    info.extend_from_slice(&epoch_bytes);
242
243    // HKDF-Expand only (no Extract step) with SHA-256
244    // The ikm is treated as a PRK directly per the spec
245    let hk = Hkdf::<Sha256>::from_prk(&ikm)
246        .map_err(|_| CryptoError::KeyDerivationFailed)?;
247    let mut key_material = [0u8; 64];
248    hk.expand(&info, &mut key_material)
249        .map_err(|_| CryptoError::KeyDerivationFailed)?;
250
251    let mut initiator_key = [0u8; SESSION_KEY_SIZE];
252    let mut responder_key = [0u8; SESSION_KEY_SIZE];
253    initiator_key.copy_from_slice(&key_material[..32]);
254    responder_key.copy_from_slice(&key_material[32..]);
255
256    // Zeroize intermediate material
257    ikm.zeroize();
258    key_material.zeroize();
259
260    Ok((
261        SessionKey::from_bytes(initiator_key),
262        SessionKey::from_bytes(responder_key),
263    ))
264}
265
266/// Derive the rekey authentication key from static DH secret.
267///
268/// This key is derived during handshake completion and used for PCS.
269/// It ensures that even if session keys are compromised, an attacker
270/// cannot derive future rekey keys without the static DH secret.
271///
272/// Per 1-SECURITY.md (PCS fix):
273/// ```text
274/// rekey_auth_key = HKDF-Expand(
275///     static_dh_secret,   // DH(s_initiator, S_responder)
276///     "nomad v1 rekey auth",
277///     32
278/// )
279/// ```
280pub fn derive_rekey_auth_key(static_dh_secret: &[u8; 32]) -> [u8; 32] {
281    let info = b"nomad v1 rekey auth";
282
283    // HKDF-Expand only (no Extract step) with SHA-256
284    // The static_dh_secret is treated as a PRK directly per the spec
285    let hk = Hkdf::<Sha256>::from_prk(static_dh_secret)
286        .expect("32 bytes is valid PRK length for SHA-256 HKDF");
287    let mut rekey_auth_key = [0u8; 32];
288    hk.expand(info, &mut rekey_auth_key)
289        .expect("32 bytes is valid output length for SHA-256 HKDF");
290
291    rekey_auth_key
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rekey_state_new() {
        let state = RekeyState::new();
        assert_eq!(state.epoch(), 0);
        assert_eq!(state.send_count(), 0);
        assert_eq!(state.recv_count(), 0);
        assert!(!state.should_rekey());
        assert!(!state.keys_expired());
        assert!(state.can_rekey());
    }

    #[test]
    fn test_increment_send() {
        let mut state = RekeyState::new();

        for i in 0..10 {
            let counter = state.increment_send().unwrap();
            assert_eq!(counter, i);
        }
        assert_eq!(state.send_count(), 10);
    }

    #[test]
    fn test_increment_send_exhaustion() {
        // Jump straight to the hard limit instead of sending 2^64-1 frames.
        let mut state = RekeyState::new();
        state.send_count = REJECT_AFTER_MESSAGES;

        assert!(matches!(
            state.increment_send(),
            Err(CryptoError::CounterExhaustion)
        ));
        // A rejected send must not advance the counter.
        assert_eq!(state.send_count(), REJECT_AFTER_MESSAGES);
    }

    #[test]
    fn test_should_rekey_after_messages() {
        let mut state = RekeyState::new();
        state.send_count = REKEY_AFTER_MESSAGES;
        assert!(state.should_rekey());
    }

    #[test]
    fn test_record_recv() {
        let mut state = RekeyState::new();

        state.record_recv(5);
        assert_eq!(state.recv_count(), 6); // max + 1

        state.record_recv(3); // Out of order, should not decrease
        assert_eq!(state.recv_count(), 6);

        state.record_recv(10);
        assert_eq!(state.recv_count(), 11);
    }

    #[test]
    fn test_advance_epoch() {
        let mut state = RekeyState::new();
        state.increment_send().unwrap();
        state.increment_send().unwrap();

        state.advance_epoch().unwrap();

        assert_eq!(state.epoch(), 1);
        assert_eq!(state.send_count(), 0);
        assert_eq!(state.recv_count(), 0);
    }

    #[test]
    fn test_advance_epoch_exhaustion() {
        let mut state = RekeyState::new();
        state.epoch = MAX_EPOCH;

        assert!(!state.can_rekey());
        assert!(matches!(
            state.advance_epoch(),
            Err(CryptoError::EpochExhaustion)
        ));
        assert_eq!(state.epoch(), MAX_EPOCH);
    }

    #[test]
    fn test_old_key_retention() {
        let mut retention = OldKeyRetention::new();

        assert!(retention.old_initiator_key().is_none());
        assert!(!retention.within_retention_window());
        assert!(!retention.should_clear());

        let key1 = SessionKey::from_bytes([0x01; SESSION_KEY_SIZE]);
        let key2 = SessionKey::from_bytes([0x02; SESSION_KEY_SIZE]);

        retention.retain(key1, key2);

        assert!(retention.within_retention_window());
        assert!(retention.old_initiator_key().is_some());
        assert!(retention.old_responder_key().is_some());

        retention.clear();

        assert!(!retention.within_retention_window());
        assert!(retention.old_initiator_key().is_none());
        assert!(retention.old_responder_key().is_none());
    }

    #[test]
    fn test_derive_rekey_keys() {
        let ephemeral_dh = [0x42u8; 32];
        let rekey_auth_key = [0x33u8; 32];

        let (key1_epoch0, key2_epoch0) =
            derive_rekey_keys(&ephemeral_dh, &rekey_auth_key, 0).unwrap();
        let (key1_epoch1, key2_epoch1) =
            derive_rekey_keys(&ephemeral_dh, &rekey_auth_key, 1).unwrap();

        // Different epochs should produce different keys
        assert_ne!(key1_epoch0.as_bytes(), key1_epoch1.as_bytes());
        assert_ne!(key2_epoch0.as_bytes(), key2_epoch1.as_bytes());

        // Same epoch should produce same keys
        let (key1_epoch0_again, key2_epoch0_again) =
            derive_rekey_keys(&ephemeral_dh, &rekey_auth_key, 0).unwrap();
        assert_eq!(key1_epoch0.as_bytes(), key1_epoch0_again.as_bytes());
        assert_eq!(key2_epoch0.as_bytes(), key2_epoch0_again.as_bytes());
    }

    #[test]
    fn test_derive_rekey_keys_different_ephemeral_dh() {
        let ephemeral_dh1 = [0x01u8; 32];
        let ephemeral_dh2 = [0x02u8; 32];
        let rekey_auth_key = [0x33u8; 32];

        let (key1_dh1, _) = derive_rekey_keys(&ephemeral_dh1, &rekey_auth_key, 0).unwrap();
        let (key1_dh2, _) = derive_rekey_keys(&ephemeral_dh2, &rekey_auth_key, 0).unwrap();

        // Different ephemeral DH should produce different keys
        assert_ne!(key1_dh1.as_bytes(), key1_dh2.as_bytes());
    }

    #[test]
    fn test_derive_rekey_keys_pcs() {
        // Different rekey_auth_keys must produce different rekey keys: knowing
        // session keys but not rekey_auth_key means you cannot derive future keys.
        let ephemeral_dh = [0x42u8; 32];
        let auth_key1 = [0x01u8; 32];
        let auth_key2 = [0x02u8; 32];

        let (key1_auth1, _) = derive_rekey_keys(&ephemeral_dh, &auth_key1, 0).unwrap();
        let (key1_auth2, _) = derive_rekey_keys(&ephemeral_dh, &auth_key2, 0).unwrap();

        assert_ne!(key1_auth1.as_bytes(), key1_auth2.as_bytes());
    }

    #[test]
    fn test_derive_rekey_auth_key() {
        let static_dh1 = [0x01u8; 32];
        let static_dh2 = [0x02u8; 32];

        let auth_key1 = derive_rekey_auth_key(&static_dh1);
        let auth_key2 = derive_rekey_auth_key(&static_dh2);

        // Different static DH secrets should produce different auth keys
        assert_ne!(auth_key1, auth_key2);

        // Same static DH secret should produce same auth key (deterministic)
        let auth_key1_again = derive_rekey_auth_key(&static_dh1);
        assert_eq!(auth_key1, auth_key1_again);
    }

    #[test]
    fn test_pcs_property() {
        // Simulate the PCS attack scenario:
        // Attacker knows: ephemeral_dh for the rekey
        // Attacker doesn't know: rekey_auth_key (derived from static DH)
        // Attacker cannot derive: new rekey keys
        let ephemeral_dh = [0x42u8; 32];
        let real_static_dh = [0xABu8; 32];
        let attacker_guess_dh = [0xCDu8; 32];

        let real_auth_key = derive_rekey_auth_key(&real_static_dh);
        let attacker_auth_key = derive_rekey_auth_key(&attacker_guess_dh);

        let (real_key1, _) = derive_rekey_keys(&ephemeral_dh, &real_auth_key, 1).unwrap();
        let (attacker_key1, _) =
            derive_rekey_keys(&ephemeral_dh, &attacker_auth_key, 1).unwrap();

        // Keys must be different - attacker cannot derive the correct keys
        assert_ne!(real_key1.as_bytes(), attacker_key1.as_bytes());
    }

    // ===== Test Vector Validation =====
    // These tests validate against the official NOMAD protocol test vectors
    // from nomad-specs/tests/vectors/rekey_vectors.json5

    /// Static DH secret from `intermediate_values` in rekey_vectors.json5.
    const STATIC_DH_HEX: &str =
        "57fbeea357c6ca4af3654988d78e020ccc6f4bc56db385bff4a46084b1187266";
    /// `rekey_auth_key` derived from [`STATIC_DH_HEX`]; shared by all epoch vectors.
    const REKEY_AUTH_KEY_HEX: &str =
        "48c391a58d3e6fe3e5c463cd874b4565b752da33d63b9d93f9a469549ebbbe09";

    /// Decode a hex string into bytes.
    ///
    /// Panics on odd-length input or non-hex digits (test vectors are trusted).
    fn hex_to_bytes(hex: &str) -> Vec<u8> {
        assert!(hex.len() % 2 == 0, "hex string must have even length");
        (0..hex.len())
            .step_by(2)
            .map(|i| u8::from_str_radix(&hex[i..i + 2], 16).unwrap())
            .collect()
    }

    /// Decode a 64-character hex string into a 32-byte array.
    fn hex_to_32(hex: &str) -> [u8; 32] {
        hex_to_bytes(hex)
            .try_into()
            .expect("test vector must decode to exactly 32 bytes")
    }

    /// Check `derive_rekey_keys` for one epoch vector from rekey_vectors.json5.
    fn check_epoch_vector(
        epoch: u32,
        ephemeral_dh_hex: &str,
        expected_initiator_hex: &str,
        expected_responder_hex: &str,
    ) {
        let ephemeral_dh = hex_to_32(ephemeral_dh_hex);
        let rekey_auth_key = hex_to_32(REKEY_AUTH_KEY_HEX);

        let (initiator_key, responder_key) =
            derive_rekey_keys(&ephemeral_dh, &rekey_auth_key, epoch).unwrap();

        assert_eq!(
            initiator_key.as_bytes(),
            hex_to_bytes(expected_initiator_hex).as_slice(),
            "epoch {epoch} initiator key doesn't match test vector"
        );
        assert_eq!(
            responder_key.as_bytes(),
            hex_to_bytes(expected_responder_hex).as_slice(),
            "epoch {epoch} responder key doesn't match test vector"
        );
    }

    #[test]
    fn test_vector_rekey_auth_key() {
        let auth_key = derive_rekey_auth_key(&hex_to_32(STATIC_DH_HEX));

        assert_eq!(
            auth_key,
            hex_to_32(REKEY_AUTH_KEY_HEX),
            "rekey_auth_key derivation doesn't match test vector"
        );
    }

    #[test]
    fn test_vector_epoch_1() {
        // epoch_0_to_1 vector
        check_epoch_vector(
            1,
            "813c560b94aec760c9a8d12a09bb4c2be3bfc35eb6983ceb264a13046d3aaa75",
            "ba7ba9959a0338866994033dc46c15df92e6a08b4d5041d5e52070001187c312",
            "91f2e4123a04abe6343003d6ff5793af7aae75ede7fdc6737aaf24964d9285f8",
        );
    }

    #[test]
    fn test_vector_epoch_2() {
        // epoch_1_to_2_pcs_case vector
        check_epoch_vector(
            2,
            "7efd5673c47236ad6f9bf85e945074615c1943c528a87cc0dc9084ad278d266e",
            "206c3c4f0838aaf5b039bad2ecd1a387d6f784afbf1d283dc0a438ad45f4db3e",
            "786554075c38e73a735b26cbfd650c9fd0f8909227e498487007fc2adfec661d",
        );
    }

    #[test]
    fn test_vector_epoch_100() {
        // epoch_high_number vector
        check_epoch_vector(
            100,
            "0038038a95c66833de6cd4a4743226d03d952d35d1885876f63b95deea271e3f",
            "dda7dd785c4c5f75096c0ea88023b1558e26bb84f4c4eb72ba7977c6947abc1a",
            "110c7c42998204153892f1ac84634c355ed1b279174befd2f27936073567e54f",
        );
    }
}