1use std::time::Instant;
11
12use hkdf::Hkdf;
13use sha2::Sha256;
14use crate::core::{
15 CryptoError, MAX_EPOCH, OLD_KEY_RETENTION, REJECT_AFTER_MESSAGES, REJECT_AFTER_TIME,
16 REKEY_AFTER_MESSAGES, REKEY_AFTER_TIME,
17};
18use zeroize::Zeroize;
19
20use super::{SessionKey, SESSION_KEY_SIZE};
21
22#[derive(Debug)]
24pub struct RekeyState {
25 epoch: u32,
27 epoch_start: Instant,
29 send_count: u64,
31 recv_count: u64,
33}
34
35impl RekeyState {
36 pub fn new() -> Self {
38 Self {
39 epoch: 0,
40 epoch_start: Instant::now(),
41 send_count: 0,
42 recv_count: 0,
43 }
44 }
45
46 pub fn epoch(&self) -> u32 {
48 self.epoch
49 }
50
51 pub fn send_count(&self) -> u64 {
53 self.send_count
54 }
55
56 pub fn recv_count(&self) -> u64 {
58 self.recv_count
59 }
60
61 pub fn increment_send(&mut self) -> Result<u64, CryptoError> {
68 if self.send_count == REJECT_AFTER_MESSAGES {
69 return Err(CryptoError::CounterExhaustion);
70 }
71 let counter = self.send_count;
72 self.send_count += 1;
73 Ok(counter)
74 }
75
76 pub fn record_recv(&mut self, counter: u64) {
80 if counter >= self.recv_count {
81 self.recv_count = counter + 1;
82 }
83 }
84
85 pub fn should_rekey(&self) -> bool {
87 let time_exceeded = self.epoch_start.elapsed() >= REKEY_AFTER_TIME;
88 let messages_exceeded = self.send_count >= REKEY_AFTER_MESSAGES;
89 time_exceeded || messages_exceeded
90 }
91
92 pub fn keys_expired(&self) -> bool {
94 self.epoch_start.elapsed() >= REJECT_AFTER_TIME
95 }
96
97 pub fn can_rekey(&self) -> bool {
99 self.epoch < MAX_EPOCH
100 }
101
102 pub fn advance_epoch(&mut self) -> Result<(), CryptoError> {
109 if self.epoch == MAX_EPOCH {
110 return Err(CryptoError::EpochExhaustion);
111 }
112 self.epoch += 1;
113 self.epoch_start = Instant::now();
114 self.send_count = 0;
115 self.recv_count = 0;
116 Ok(())
117 }
118}
119
120impl Default for RekeyState {
121 fn default() -> Self {
122 Self::new()
123 }
124}
125
126pub struct OldKeyRetention {
128 initiator_key: Option<SessionKey>,
130 responder_key: Option<SessionKey>,
132 retained_at: Option<Instant>,
134}
135
136impl OldKeyRetention {
137 pub fn new() -> Self {
139 Self {
140 initiator_key: None,
141 responder_key: None,
142 retained_at: None,
143 }
144 }
145
146 pub fn retain(&mut self, initiator_key: SessionKey, responder_key: SessionKey) {
148 self.initiator_key = Some(initiator_key);
149 self.responder_key = Some(responder_key);
150 self.retained_at = Some(Instant::now());
151 }
152
153 pub fn old_initiator_key(&self) -> Option<&SessionKey> {
155 if self.within_retention_window() {
156 self.initiator_key.as_ref()
157 } else {
158 None
159 }
160 }
161
162 pub fn old_responder_key(&self) -> Option<&SessionKey> {
164 if self.within_retention_window() {
165 self.responder_key.as_ref()
166 } else {
167 None
168 }
169 }
170
171 pub fn within_retention_window(&self) -> bool {
173 self.retained_at
174 .is_some_and(|t| t.elapsed() < OLD_KEY_RETENTION)
175 }
176
177 pub fn clear(&mut self) {
179 self.initiator_key = None;
180 self.responder_key = None;
181 self.retained_at = None;
182 }
183
184 pub fn should_clear(&self) -> bool {
186 self.retained_at
187 .is_some_and(|t| t.elapsed() >= OLD_KEY_RETENTION)
188 }
189
190 pub fn clear_if_expired(&mut self) {
192 if self.should_clear() {
193 self.clear();
194 }
195 }
196}
197
198impl Default for OldKeyRetention {
199 fn default() -> Self {
200 Self::new()
201 }
202}
203
204pub fn derive_rekey_keys(
226 ephemeral_dh: &[u8; 32],
227 rekey_auth_key: &[u8; 32],
228 epoch: u32,
229) -> Result<(SessionKey, SessionKey), CryptoError> {
230 let mut ikm = [0u8; 64];
233 ikm[..32].copy_from_slice(ephemeral_dh);
234 ikm[32..].copy_from_slice(rekey_auth_key);
235
236 let label = b"nomad v1 rekey";
238 let epoch_bytes = epoch.to_le_bytes();
239 let mut info = Vec::with_capacity(label.len() + 4);
240 info.extend_from_slice(label);
241 info.extend_from_slice(&epoch_bytes);
242
243 let hk = Hkdf::<Sha256>::from_prk(&ikm)
246 .map_err(|_| CryptoError::KeyDerivationFailed)?;
247 let mut key_material = [0u8; 64];
248 hk.expand(&info, &mut key_material)
249 .map_err(|_| CryptoError::KeyDerivationFailed)?;
250
251 let mut initiator_key = [0u8; SESSION_KEY_SIZE];
252 let mut responder_key = [0u8; SESSION_KEY_SIZE];
253 initiator_key.copy_from_slice(&key_material[..32]);
254 responder_key.copy_from_slice(&key_material[32..]);
255
256 ikm.zeroize();
258 key_material.zeroize();
259
260 Ok((
261 SessionKey::from_bytes(initiator_key),
262 SessionKey::from_bytes(responder_key),
263 ))
264}
265
266pub fn derive_rekey_auth_key(static_dh_secret: &[u8; 32]) -> [u8; 32] {
281 let info = b"nomad v1 rekey auth";
282
283 let hk = Hkdf::<Sha256>::from_prk(static_dh_secret)
286 .expect("32 bytes is valid PRK length for SHA-256 HKDF");
287 let mut rekey_auth_key = [0u8; 32];
288 hk.expand(info, &mut rekey_auth_key)
289 .expect("32 bytes is valid output length for SHA-256 HKDF");
290
291 rekey_auth_key
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    /// `rekey_auth_key` shared by every published test vector below (it is also
    /// the expected output of the `rekey_auth_key` derivation vector).
    const VECTOR_REKEY_AUTH_KEY: &str =
        "48c391a58d3e6fe3e5c463cd874b4565b752da33d63b9d93f9a469549ebbbe09";

    /// Decodes a hex string into bytes; panics on malformed input.
    fn hex_to_bytes(hex: &str) -> Vec<u8> {
        assert_eq!(hex.len() % 2, 0, "hex string must have an even number of digits");
        (0..hex.len())
            .step_by(2)
            .map(|i| u8::from_str_radix(&hex[i..i + 2], 16).unwrap())
            .collect()
    }

    /// Decodes a hex string that must describe exactly 32 bytes.
    fn hex_to_array32(hex: &str) -> [u8; 32] {
        let mut out = [0u8; 32];
        out.copy_from_slice(&hex_to_bytes(hex));
        out
    }

    /// Checks `derive_rekey_keys` against one published vector for `epoch`.
    fn assert_rekey_vector(
        epoch: u32,
        ephemeral_dh_hex: &str,
        expected_initiator_hex: &str,
        expected_responder_hex: &str,
    ) {
        let ephemeral_dh = hex_to_array32(ephemeral_dh_hex);
        let rekey_auth_key = hex_to_array32(VECTOR_REKEY_AUTH_KEY);
        let expected_initiator_key = hex_to_bytes(expected_initiator_hex);
        let expected_responder_key = hex_to_bytes(expected_responder_hex);

        let (initiator_key, responder_key) =
            derive_rekey_keys(&ephemeral_dh, &rekey_auth_key, epoch).unwrap();

        assert_eq!(
            initiator_key.as_bytes(),
            expected_initiator_key.as_slice(),
            "epoch {epoch} initiator key doesn't match test vector"
        );
        assert_eq!(
            responder_key.as_bytes(),
            expected_responder_key.as_slice(),
            "epoch {epoch} responder key doesn't match test vector"
        );
    }

    #[test]
    fn test_rekey_state_new() {
        let state = RekeyState::new();
        assert_eq!(state.epoch(), 0);
        assert_eq!(state.send_count(), 0);
        assert_eq!(state.recv_count(), 0);
        assert!(!state.should_rekey());
        assert!(!state.keys_expired());
        assert!(state.can_rekey());
    }

    #[test]
    fn test_rekey_state_default() {
        let state = RekeyState::default();
        assert_eq!(state.epoch(), 0);
        assert_eq!(state.send_count(), 0);
        assert_eq!(state.recv_count(), 0);
    }

    #[test]
    fn test_increment_send() {
        let mut state = RekeyState::new();

        for i in 0..10 {
            let counter = state.increment_send().unwrap();
            assert_eq!(counter, i);
        }
        assert_eq!(state.send_count(), 10);
    }

    #[test]
    fn test_record_recv() {
        let mut state = RekeyState::new();

        state.record_recv(5);
        assert_eq!(state.recv_count(), 6);

        // An older (reordered) counter must not move recv_count backwards.
        state.record_recv(3);
        assert_eq!(state.recv_count(), 6);

        state.record_recv(10);
        assert_eq!(state.recv_count(), 11);
    }

    #[test]
    fn test_advance_epoch() {
        let mut state = RekeyState::new();
        state.increment_send().unwrap();
        state.increment_send().unwrap();
        state.record_recv(7);

        state.advance_epoch().unwrap();

        assert_eq!(state.epoch(), 1);
        assert_eq!(state.send_count(), 0);
        assert_eq!(state.recv_count(), 0);
        // Send counters restart from zero in the new epoch.
        assert_eq!(state.increment_send().unwrap(), 0);
    }

    #[test]
    fn test_old_key_retention() {
        let mut retention = OldKeyRetention::new();

        assert!(retention.old_initiator_key().is_none());
        assert!(!retention.within_retention_window());

        let key1 = SessionKey::from_bytes([0x01; SESSION_KEY_SIZE]);
        let key2 = SessionKey::from_bytes([0x02; SESSION_KEY_SIZE]);

        retention.retain(key1, key2);

        assert!(retention.within_retention_window());
        assert!(retention.old_initiator_key().is_some());
        assert!(retention.old_responder_key().is_some());
    }

    #[test]
    fn test_old_key_retention_clear() {
        let mut retention = OldKeyRetention::new();
        retention.retain(
            SessionKey::from_bytes([0x01; SESSION_KEY_SIZE]),
            SessionKey::from_bytes([0x02; SESSION_KEY_SIZE]),
        );

        // Still inside the window, so nothing is cleared yet.
        assert!(!retention.should_clear());
        retention.clear_if_expired();
        assert!(retention.old_initiator_key().is_some());

        retention.clear();
        assert!(!retention.within_retention_window());
        assert!(!retention.should_clear());
        assert!(retention.old_initiator_key().is_none());
        assert!(retention.old_responder_key().is_none());
    }

    #[test]
    fn test_derive_rekey_keys() {
        let ephemeral_dh = [0x42u8; 32];
        let rekey_auth_key = [0x33u8; 32];

        let (key1_epoch0, key2_epoch0) =
            derive_rekey_keys(&ephemeral_dh, &rekey_auth_key, 0).unwrap();
        let (key1_epoch1, key2_epoch1) =
            derive_rekey_keys(&ephemeral_dh, &rekey_auth_key, 1).unwrap();

        // The epoch is part of the HKDF info, so keys must differ across epochs.
        assert_ne!(key1_epoch0.as_bytes(), key1_epoch1.as_bytes());
        assert_ne!(key2_epoch0.as_bytes(), key2_epoch1.as_bytes());

        // Derivation is deterministic.
        let (key1_epoch0_again, key2_epoch0_again) =
            derive_rekey_keys(&ephemeral_dh, &rekey_auth_key, 0).unwrap();
        assert_eq!(key1_epoch0.as_bytes(), key1_epoch0_again.as_bytes());
        assert_eq!(key2_epoch0.as_bytes(), key2_epoch0_again.as_bytes());
    }

    #[test]
    fn test_derive_rekey_keys_directional_keys_differ() {
        let (initiator_key, responder_key) =
            derive_rekey_keys(&[0x42u8; 32], &[0x33u8; 32], 0).unwrap();

        assert_ne!(initiator_key.as_bytes(), responder_key.as_bytes());
    }

    #[test]
    fn test_derive_rekey_keys_different_ephemeral_dh() {
        let ephemeral_dh1 = [0x01u8; 32];
        let ephemeral_dh2 = [0x02u8; 32];
        let rekey_auth_key = [0x33u8; 32];

        let (key1_dh1, _) = derive_rekey_keys(&ephemeral_dh1, &rekey_auth_key, 0).unwrap();
        let (key1_dh2, _) = derive_rekey_keys(&ephemeral_dh2, &rekey_auth_key, 0).unwrap();

        assert_ne!(key1_dh1.as_bytes(), key1_dh2.as_bytes());
    }

    #[test]
    fn test_derive_rekey_keys_pcs() {
        // Same ephemeral DH, different auth keys: the output must depend on the auth key.
        let ephemeral_dh = [0x42u8; 32];
        let auth_key1 = [0x01u8; 32];
        let auth_key2 = [0x02u8; 32];

        let (key1_auth1, _) = derive_rekey_keys(&ephemeral_dh, &auth_key1, 0).unwrap();
        let (key1_auth2, _) = derive_rekey_keys(&ephemeral_dh, &auth_key2, 0).unwrap();

        assert_ne!(key1_auth1.as_bytes(), key1_auth2.as_bytes());
    }

    #[test]
    fn test_derive_rekey_auth_key() {
        let static_dh1 = [0x01u8; 32];
        let static_dh2 = [0x02u8; 32];

        let auth_key1 = derive_rekey_auth_key(&static_dh1);
        let auth_key2 = derive_rekey_auth_key(&static_dh2);

        assert_ne!(auth_key1, auth_key2);

        let auth_key1_again = derive_rekey_auth_key(&static_dh1);
        assert_eq!(auth_key1, auth_key1_again);
    }

    #[test]
    fn test_pcs_property() {
        // Knowing the ephemeral DH output alone is not enough: an auth key
        // derived from a wrong static DH secret yields different session keys.
        let ephemeral_dh = [0x42u8; 32];
        let real_static_dh = [0xABu8; 32];
        let attacker_guess_dh = [0xCDu8; 32];

        let real_auth_key = derive_rekey_auth_key(&real_static_dh);
        let attacker_auth_key = derive_rekey_auth_key(&attacker_guess_dh);

        let (real_key1, _) = derive_rekey_keys(&ephemeral_dh, &real_auth_key, 1).unwrap();
        let (attacker_key1, _) =
            derive_rekey_keys(&ephemeral_dh, &attacker_auth_key, 1).unwrap();

        assert_ne!(real_key1.as_bytes(), attacker_key1.as_bytes());
    }

    #[test]
    fn test_vector_rekey_auth_key() {
        let static_dh =
            hex_to_array32("57fbeea357c6ca4af3654988d78e020ccc6f4bc56db385bff4a46084b1187266");
        let expected_auth_key = hex_to_bytes(VECTOR_REKEY_AUTH_KEY);

        let auth_key = derive_rekey_auth_key(&static_dh);

        assert_eq!(
            auth_key.as_slice(),
            expected_auth_key.as_slice(),
            "rekey_auth_key derivation doesn't match test vector"
        );
    }

    #[test]
    fn test_vector_epoch_1() {
        assert_rekey_vector(
            1,
            "813c560b94aec760c9a8d12a09bb4c2be3bfc35eb6983ceb264a13046d3aaa75",
            "ba7ba9959a0338866994033dc46c15df92e6a08b4d5041d5e52070001187c312",
            "91f2e4123a04abe6343003d6ff5793af7aae75ede7fdc6737aaf24964d9285f8",
        );
    }

    #[test]
    fn test_vector_epoch_2() {
        assert_rekey_vector(
            2,
            "7efd5673c47236ad6f9bf85e945074615c1943c528a87cc0dc9084ad278d266e",
            "206c3c4f0838aaf5b039bad2ecd1a387d6f784afbf1d283dc0a438ad45f4db3e",
            "786554075c38e73a735b26cbfd650c9fd0f8909227e498487007fc2adfec661d",
        );
    }

    #[test]
    fn test_vector_epoch_100() {
        assert_rekey_vector(
            100,
            "0038038a95c66833de6cd4a4743226d03d952d35d1885876f63b95deea271e3f",
            "dda7dd785c4c5f75096c0ea88023b1558e26bb84f4c4eb72ba7977c6947abc1a",
            "110c7c42998204153892f1ac84634c355ed1b279174befd2f27936073567e54f",
        );
    }
}