1use std::collections::BTreeMap;
10#[cfg(not(any(test, target_arch = "wasm32")))]
11use std::time::{SystemTime, UNIX_EPOCH};
12
13use chacha20poly1305::{
14 XChaCha20Poly1305, XNonce,
15 aead::{Aead, KeyInit, Payload},
16};
17use tracing::{error, instrument};
18
19use super::ciphersuite::Ciphersuite;
20use super::packet::TransportPacket;
21use crate::packet::TransportPacketAad;
22use crate::{error::NoiseProtocolError, symmetric_key::SymmetricKey};
23
/// Oldest accepted message, in seconds (one day). Incoming messages whose AAD
/// timestamp is further in the past are rejected, and replay-cache entries older
/// than this are pruned.
const MAX_MESSAGE_AGE: u64 = 24 * 60 * 60;

/// How far, in seconds, an incoming AAD timestamp may lie ahead of the local clock.
const CLOCK_SKEW_TOLERANCE: u64 = 60;

/// Largest number of receive-chain rekey steps a single incoming message may demand;
/// a larger jump is reported as `Desynchronized`.
const MAX_REKEY_GAP: u64 = 1 << 10;

/// Send-side rekey interval, in seconds (one day), used by freshly created transports.
const REKEY_INTERVAL: u64 = 24 * 60 * 60;
35
/// Transport-phase state for one side of a session: an outgoing key chain, an
/// incoming key chain and a replay cache of nonces already accepted.
///
/// NOTE(review): the derived `Debug` includes `send_key` and `recv_key`; whether key
/// bytes are actually printed depends on `SymmetricKey`'s `Debug` impl — confirm it
/// redacts before logging this type.
#[derive(Clone, Debug)]
pub struct MultiDeviceTransport {
    /// Ciphersuite written into every outgoing AAD and required on every incoming one.
    ciphersuite: Ciphersuite,

    /// Key used for outgoing messages at send-chain position `send_rekey_counter`.
    send_key: SymmetricKey,
    /// Send-chain position; starts at 1, grows by one per rekey and is sent as the
    /// AAD `chain_counter`.
    send_rekey_counter: u64,
    /// Start (seconds since the Unix epoch) of the current send epoch; moved forward
    /// by exactly `rekey_interval` per rekey.
    last_rekeyed_time: u64,
    /// Seconds between send-chain rekeys.
    rekey_interval: u64,

    /// Key for receive-chain position `recv_rekey_counter`.
    recv_key: SymmetricKey,
    /// Receive-chain position; when a message with a higher chain counter arrives it
    /// is advanced to one less than that counter.
    recv_rekey_counter: u64,

    /// Replay cache: nonce of an accepted message -> that message's AAD timestamp.
    seen_nonces: BTreeMap<Vec<u8>, u64>,

    /// Clock source (seconds since the Unix epoch; settable under `cfg(test)`).
    timeprovider: Timeprovider,
}
69
/// Wall-clock source reporting whole seconds since the Unix epoch.
///
/// Native builds read `SystemTime`, wasm32 builds read `js_sys::Date`, and test
/// builds use a plain field that starts at 0 and is moved with `set_now`.
#[derive(Clone, Debug)]
struct Timeprovider {
    #[cfg(test)]
    now: u64,
}

impl Timeprovider {
    fn new() -> Self {
        Self {
            #[cfg(test)]
            now: u64::MIN,
        }
    }

    /// Current time in seconds since the Unix epoch.
    ///
    /// # Panics
    /// If the system clock reports a time before the Unix epoch.
    #[cfg(not(any(test, target_arch = "wasm32")))]
    fn now(&self) -> u64 {
        let since_epoch = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("System time before Unix epoch");
        since_epoch.as_secs()
    }

    /// Current time in seconds since the Unix epoch, from the JS `Date` clock
    /// (milliseconds truncated to seconds).
    #[cfg(all(not(test), target_arch = "wasm32"))]
    fn now(&self) -> u64 {
        let millis = js_sys::Date::now();
        (millis / 1000.0) as u64
    }

    /// Test clock: returns whatever `set_now` last stored.
    #[cfg(test)]
    fn now(&self) -> u64 {
        self.now
    }

    /// Test-only: moves the clock to `now`.
    #[cfg(test)]
    fn set_now(&mut self, now: u64) {
        self.now = now;
    }
}
107
108impl MultiDeviceTransport {
109 pub(crate) fn new(
111 ciphersuite: Ciphersuite,
112 send_key: SymmetricKey,
113 recv_key: SymmetricKey,
114 ) -> Self {
115 let timeprovider = Timeprovider::new();
116 Self {
117 ciphersuite,
118 send_key,
119 send_rekey_counter: 1,
120 last_rekeyed_time: timeprovider.now(),
121 rekey_interval: REKEY_INTERVAL,
122 recv_key,
123 recv_rekey_counter: 1,
124 seen_nonces: BTreeMap::new(),
125 timeprovider,
126 }
127 }
128
129 fn prune_old_nonces(&mut self) {
131 let now = self.timeprovider.now();
132 let cutoff = now.saturating_sub(MAX_MESSAGE_AGE);
133 self.seen_nonces
134 .retain(|_, &mut timestamp| timestamp >= cutoff);
135 }
136
137 fn check_and_record_nonce(
139 &mut self,
140 packet_aad: &TransportPacketAad,
141 packet: &TransportPacket,
142 ) -> Result<(), NoiseProtocolError> {
143 self.prune_old_nonces();
144 if self.seen_nonces.contains_key(&packet.nonce) {
145 return Err(NoiseProtocolError::ReplayDetected);
146 }
147 self.seen_nonces
148 .insert(packet.nonce.clone(), packet_aad.timestamp);
149 Ok(())
150 }
151
152 pub fn seen_nonces(&self) -> Vec<u8> {
154 let mut buf = Vec::new();
155 ciborium::ser::into_writer(&self.seen_nonces, &mut buf)
156 .expect("should serialize seen nonces");
157 buf
158 }
159
160 pub fn set_seen_nonces(&mut self, data: &[u8]) -> Result<(), NoiseProtocolError> {
162 let nonces: BTreeMap<Vec<u8>, u64> =
163 ciborium::de::from_reader(data).map_err(|_| NoiseProtocolError::CborDecodeFailed)?;
164 self.seen_nonces = nonces;
165 Ok(())
166 }
167
168 pub fn ciphersuite(&self) -> Ciphersuite {
170 self.ciphersuite
171 }
172
173 fn rekey_send_if_needed(&mut self) -> Result<(), NoiseProtocolError> {
174 let now = self.timeprovider.now();
176 while now.saturating_sub(self.last_rekeyed_time) >= self.rekey_interval {
177 self.send_key = XChaCha20Poly1305RandomNonceCipher::rekey(&mut self.send_key)?;
178 self.send_rekey_counter = self.send_rekey_counter.wrapping_add(1);
179 self.last_rekeyed_time += self.rekey_interval;
180 }
181
182 Ok(())
183 }
184
185 #[instrument(level = "debug", fields(ciphersuite = ?self.ciphersuite))]
189 pub fn encrypt(&mut self, plaintext: &[u8]) -> Result<TransportPacket, NoiseProtocolError> {
190 self.rekey_send_if_needed()?;
192
193 let aad = TransportPacketAad {
195 timestamp: self.timeprovider.now(),
196 chain_counter: self.send_rekey_counter,
197 ciphersuite: self.ciphersuite,
198 };
199
200 let (nonce, ciphertext) =
202 XChaCha20Poly1305RandomNonceCipher::encrypt(&self.send_key, plaintext, &aad.encode())?;
203
204 let packet = TransportPacket {
206 nonce: nonce.to_vec(),
207 ciphertext,
208 aad: aad.encode(),
209 };
210
211 Ok(packet)
212 }
213
214 #[instrument(level = "debug", fields(ciphersuite = ?self.ciphersuite))]
216 pub fn decrypt(&mut self, packet: &TransportPacket) -> Result<Vec<u8>, NoiseProtocolError> {
217 let packet_aad = TransportPacketAad::decode(&packet.aad)
218 .map_err(|_| NoiseProtocolError::DecryptionFailed)?;
219
220 if packet_aad.ciphersuite != self.ciphersuite {
222 error!("Ciphersuite mismatch detected");
223 return Err(NoiseProtocolError::CiphersuiteMismatch);
224 }
225 self.validate_message_timestamp(&packet_aad)?;
226 self.check_and_record_nonce(&packet_aad, packet)?;
227 self.rekey_receive(&packet_aad)?;
228
229 let packet_decryption_key = if packet_aad.chain_counter == self.recv_rekey_counter {
230 self.recv_key.clone()
231 } else {
232 XChaCha20Poly1305RandomNonceCipher::rekey(&mut self.recv_key)?.clone()
233 };
234
235 XChaCha20Poly1305RandomNonceCipher::decrypt(
237 &packet_decryption_key,
238 &XChaCha20Poly1305Nonce::from_slice(&packet.nonce),
239 &packet.ciphertext,
240 &packet.aad,
241 )
242 }
243
244 fn rekey_receive(&mut self, packet_aad: &TransportPacketAad) -> Result<(), NoiseProtocolError> {
245 if packet_aad.chain_counter < self.recv_rekey_counter {
260 Err(NoiseProtocolError::Desynchronized)
262 } else if packet_aad.chain_counter > self.recv_rekey_counter + MAX_REKEY_GAP {
263 Err(NoiseProtocolError::Desynchronized)
265 } else {
266 while self.recv_rekey_counter < packet_aad.chain_counter - 1 {
269 self.recv_key = XChaCha20Poly1305RandomNonceCipher::rekey(&mut self.recv_key)?;
270 self.recv_rekey_counter = self.recv_rekey_counter.wrapping_add(1);
271 }
272 Ok(())
273 }
274 }
275
276 fn validate_message_timestamp(
278 &mut self,
279 packet_aad: &TransportPacketAad,
280 ) -> Result<(), NoiseProtocolError> {
281 let now = self.timeprovider.now();
282 let packet_timestamp = packet_aad.timestamp;
283
284 if packet_timestamp < now.saturating_sub(MAX_MESSAGE_AGE) {
286 error!(
287 "Message too old: timestamp={}, now={}, age={}",
288 packet_timestamp,
289 now,
290 now.saturating_sub(packet_timestamp)
291 );
292 return Err(NoiseProtocolError::MessageTooOld {
293 timestamp: packet_timestamp,
294 now,
295 });
296 }
297
298 if packet_timestamp > now + CLOCK_SKEW_TOLERANCE {
300 error!(
301 "Message from future: timestamp={}, now={}",
302 packet_timestamp, now
303 );
304 return Err(NoiseProtocolError::MessageFromFuture {
305 timestamp: packet_timestamp,
306 now,
307 });
308 }
309
310 Ok(())
311 }
312
313 pub(crate) fn restore_from_state(
315 ciphersuite: Ciphersuite,
316 send_key: SymmetricKey,
317 recv_key: SymmetricKey,
318 send_rekey_counter: u64,
319 recv_rekey_counter: u64,
320 last_rekeyed_time: u64,
321 rekey_interval: u64,
322 ) -> Self {
323 Self {
324 ciphersuite,
325 send_key,
326 send_rekey_counter,
327 last_rekeyed_time,
328 rekey_interval,
329 recv_key,
330 recv_rekey_counter,
331 seen_nonces: BTreeMap::new(),
332 timeprovider: Timeprovider::new(),
333 }
334 }
335
336 pub fn send_rekey_counter(&self) -> u64 {
338 self.send_rekey_counter
339 }
340
341 pub fn recv_rekey_counter(&self) -> u64 {
343 self.recv_rekey_counter
344 }
345
346 #[cfg(test)]
347 pub(crate) fn set_last_rekeyed_time(&mut self, timestamp: u64) {
348 self.last_rekeyed_time = timestamp;
349 }
350
351 pub fn last_rekeyed_time(&self) -> u64 {
353 self.last_rekeyed_time
354 }
355
356 pub fn rekey_interval(&self) -> u64 {
358 self.rekey_interval
359 }
360
361 pub fn keys(&self) -> (SymmetricKey, SymmetricKey) {
363 (self.send_key.clone(), self.recv_key.clone())
364 }
365
366 #[cfg(test)]
367 pub(crate) fn set_send_rekey_counter(&mut self, counter: u64) {
368 self.send_rekey_counter = counter;
369 }
370
371 #[cfg(test)]
372 #[allow(unused)]
373 pub(crate) fn set_recv_rekey_counter(&mut self, counter: u64) {
374 self.recv_rekey_counter = counter;
375 }
376
377 #[cfg(test)]
378 pub(crate) fn send_key(&self) -> SymmetricKey {
379 self.send_key.clone()
380 }
381
382 #[cfg(test)]
383 pub(crate) fn recv_key(&self) -> SymmetricKey {
384 self.recv_key.clone()
385 }
386}
387
388struct XChaCha20Poly1305Nonce([u8; 24]);
389
390impl XChaCha20Poly1305Nonce {
391 fn from_slice(slice: &[u8]) -> Self {
392 let mut nonce = [0u8; 24];
393 nonce.copy_from_slice(slice);
394 XChaCha20Poly1305Nonce(nonce)
395 }
396
397 fn rekey_max_value() -> Self {
399 XChaCha20Poly1305Nonce([0xFF; 24])
400 }
401
402 fn generate() -> Self {
403 let mut nonce = [0u8; 24];
404 let mut rng = rand::thread_rng();
405 rand::RngCore::fill_bytes(&mut rng, &mut nonce);
406 XChaCha20Poly1305Nonce(nonce)
407 }
408
409 fn to_vec(&self) -> Vec<u8> {
410 self.0.to_vec()
411 }
412}
413
414impl From<&XChaCha20Poly1305Nonce> for XNonce {
415 fn from(nonce: &XChaCha20Poly1305Nonce) -> Self {
416 *XNonce::from_slice(&nonce.0)
417 }
418}
419
420struct XChaCha20Poly1305RandomNonceCipher;
421
422impl XChaCha20Poly1305RandomNonceCipher {
423 fn rekey(key: &mut SymmetricKey) -> Result<SymmetricKey, NoiseProtocolError> {
424 let nonce = XChaCha20Poly1305Nonce::rekey_max_value();
425 let empty_key = [0u8; 32];
426 let cipher = XChaCha20Poly1305::new(key.as_slice().into());
427 let derived = cipher
428 .encrypt(
429 &(&nonce).into(),
430 Payload {
431 msg: &empty_key,
432 aad: &[],
433 },
434 )
435 .map_err(|_| NoiseProtocolError::RekeyFailed)?;
436 let mut new_key = [0u8; 32];
437 new_key.copy_from_slice(&derived[..32]);
438 Ok(SymmetricKey::from_bytes(new_key))
439 }
440
441 fn encrypt(
442 key: &SymmetricKey,
443 plaintext: &[u8],
444 aad: &[u8],
445 ) -> Result<(XChaCha20Poly1305Nonce, Vec<u8>), NoiseProtocolError> {
446 let nonce = XChaCha20Poly1305Nonce::generate();
447 let cipher = XChaCha20Poly1305::new(&key.to_bytes().into());
448 let ciphertext = cipher
449 .encrypt(
450 &(&nonce).into(),
451 Payload {
452 msg: plaintext,
453 aad,
454 },
455 )
456 .map_err(|_| NoiseProtocolError::TransportEncryptionFailed)?;
457
458 Ok((nonce, ciphertext))
459 }
460
461 fn decrypt(
462 key: &SymmetricKey,
463 nonce: &XChaCha20Poly1305Nonce,
464 ciphertext: &[u8],
465 aad: &[u8],
466 ) -> Result<Vec<u8>, NoiseProtocolError> {
467 let cipher = XChaCha20Poly1305::new(&key.to_bytes().into());
468 let payload = Payload {
469 msg: ciphertext,
470 aad,
471 };
472
473 let plaintext = cipher
474 .decrypt(&(nonce).into(), payload)
475 .map_err(|_| NoiseProtocolError::TransportDecryptionFailed)?;
476
477 Ok(plaintext)
478 }
479}
480
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symmetric_key::{SYMMETRIC_KEY_TEST_VECTOR_1, SYMMETRIC_KEY_TEST_VECTOR_2};

    const PLAINTEXT_TEST_VECTOR: &[u8] = b"Test message for multi-device transport";

    /// Builds two transports wired to each other: the sender's send key is the
    /// receiver's recv key and vice versa. Both test clocks start at 0.
    fn setup_sender_receiver() -> (MultiDeviceTransport, MultiDeviceTransport) {
        let send_key = SYMMETRIC_KEY_TEST_VECTOR_1;
        let recv_key = SYMMETRIC_KEY_TEST_VECTOR_2;

        let sender = MultiDeviceTransport::new(
            Ciphersuite::ClassicalNNpsk2_25519_XChaCha20Poly1035,
            send_key.clone(),
            recv_key.clone(),
        );
        let receiver = MultiDeviceTransport::new(
            Ciphersuite::ClassicalNNpsk2_25519_XChaCha20Poly1035,
            recv_key.clone(),
            send_key.clone(),
        );

        (sender, receiver)
    }

    #[test]
    fn test_encrypt_decrypt() {
        let (mut sender, mut receiver) = setup_sender_receiver();
        let packet = sender
            .encrypt(PLAINTEXT_TEST_VECTOR)
            .expect("should encrypt");

        let decrypted = receiver.decrypt(&packet).expect("should decrypt");
        assert_eq!(PLAINTEXT_TEST_VECTOR, decrypted);
    }

    #[test]
    fn test_replay_detection() {
        let (mut sender, mut receiver) = setup_sender_receiver();
        let packet = sender
            .encrypt(PLAINTEXT_TEST_VECTOR)
            .expect("should encrypt");

        receiver.decrypt(&packet).expect("should decrypt");
        let err = receiver
            .decrypt(&packet)
            .expect_err("should detect replay");
        assert!(matches!(err, NoiseProtocolError::ReplayDetected));
    }

    #[test]
    fn test_message_too_old() {
        let (mut sender, mut receiver) = setup_sender_receiver();
        receiver.timeprovider.set_now(2_000_000_000);

        let packet = sender
            .encrypt(PLAINTEXT_TEST_VECTOR)
            .expect("should encrypt");

        let err = receiver
            .decrypt(&packet)
            .expect_err("should detect old message");
        assert!(matches!(
            err,
            NoiseProtocolError::MessageTooOld {
                timestamp: 0,
                now: 2_000_000_000
            }
        ));
    }

    #[test]
    fn test_message_from_future() {
        let (mut sender, mut receiver) = setup_sender_receiver();
        sender.timeprovider.set_now(2_000_000_000);

        let packet = sender
            .encrypt(PLAINTEXT_TEST_VECTOR)
            .expect("should encrypt");

        let err = receiver
            .decrypt(&packet)
            .expect_err("should detect future message");
        assert!(matches!(
            err,
            NoiseProtocolError::MessageFromFuture {
                timestamp: 2_000_000_000,
                now: 0
            }
        ));
    }

    #[test]
    fn test_tampered_ciphertext_is_rejected() {
        let (mut sender, mut receiver) = setup_sender_receiver();
        let mut packet = sender
            .encrypt(PLAINTEXT_TEST_VECTOR)
            .expect("should encrypt");
        packet.ciphertext[0] ^= 0x01;

        let err = receiver
            .decrypt(&packet)
            .expect_err("should reject tampered ciphertext");
        assert!(matches!(err, NoiseProtocolError::TransportDecryptionFailed));
    }

    #[test]
    fn test_send_rekey() {
        let (mut sender, mut receiver) = setup_sender_receiver();

        sender.timeprovider.set_now(0);
        let packet1 = sender
            .encrypt(PLAINTEXT_TEST_VECTOR)
            .expect("should encrypt");

        sender.timeprovider.set_now(REKEY_INTERVAL);
        let packet2 = sender
            .encrypt(PLAINTEXT_TEST_VECTOR)
            .expect("should encrypt after rekey");

        let first = receiver
            .decrypt(&packet1)
            .expect("should decrypt first message");
        assert_eq!(first, PLAINTEXT_TEST_VECTOR);

        receiver.timeprovider.set_now(REKEY_INTERVAL);
        let second = receiver
            .decrypt(&packet2)
            .expect("should decrypt second message");
        assert_eq!(second, PLAINTEXT_TEST_VECTOR);

        assert_eq!(sender.send_rekey_counter, 2);
    }

    #[test]
    fn test_receive_rekey_catchup() {
        let (mut sender, mut receiver) = setup_sender_receiver();

        // Let the receiver's send chain advance three epochs while `sender` hears nothing.
        receiver.timeprovider.set_now(REKEY_INTERVAL);
        receiver.encrypt(b"msg1").expect("should encrypt");
        receiver.timeprovider.set_now(REKEY_INTERVAL * 2);
        receiver.encrypt(b"msg2").expect("should encrypt");
        receiver.timeprovider.set_now(REKEY_INTERVAL * 3);
        let packet = receiver
            .encrypt(PLAINTEXT_TEST_VECTOR)
            .expect("should encrypt");

        sender.timeprovider.set_now(REKEY_INTERVAL * 3);
        let decrypted = sender
            .decrypt(&packet)
            .expect("should decrypt after catchup");
        assert_eq!(decrypted, PLAINTEXT_TEST_VECTOR);
    }

    #[test]
    fn test_desynchronization() {
        let (mut sender, mut receiver) = setup_sender_receiver();
        sender.set_send_rekey_counter(MAX_REKEY_GAP + 2);

        let packet = sender
            .encrypt(PLAINTEXT_TEST_VECTOR)
            .expect("should encrypt");

        let err = receiver
            .decrypt(&packet)
            .expect_err("should report desynchronization");
        assert!(matches!(err, NoiseProtocolError::Desynchronized));
    }

    #[test]
    fn test_device_group_out_of_order_answers() {
        let (single_device, receiver) = setup_sender_receiver();
        let mut single_device = single_device;
        let mut device_group_device_1 = receiver.clone();
        let mut device_group_device_2 = receiver;

        let packet1 = single_device
            .encrypt(PLAINTEXT_TEST_VECTOR)
            .expect("should encrypt request1");
        let packet2 = single_device
            .encrypt(PLAINTEXT_TEST_VECTOR)
            .expect("should encrypt request2");

        device_group_device_2
            .decrypt(&packet2)
            .expect("should decrypt request2");
        let response2 = device_group_device_2
            .encrypt(PLAINTEXT_TEST_VECTOR)
            .expect("should encrypt response2");
        let decrypted_response2 = single_device
            .decrypt(&response2)
            .expect("should decrypt response2");
        assert_eq!(decrypted_response2, PLAINTEXT_TEST_VECTOR);

        device_group_device_1
            .decrypt(&packet1)
            .expect("should decrypt request1");
        let response1 = device_group_device_1
            .encrypt(PLAINTEXT_TEST_VECTOR)
            .expect("should encrypt response1");
        let decrypted_response1 = single_device
            .decrypt(&response1)
            .expect("should decrypt response1");
        assert_eq!(decrypted_response1, PLAINTEXT_TEST_VECTOR);
    }
}