1use std::collections::BTreeMap;
10#[cfg(not(any(test, target_arch = "wasm32")))]
11use std::time::{SystemTime, UNIX_EPOCH};
12
13use chacha20poly1305::{
14 XChaCha20Poly1305, XNonce,
15 aead::{Aead, KeyInit, Payload},
16};
17use tracing::{error, instrument};
18
19use super::ciphersuite::Ciphersuite;
20use super::packet::TransportPacket;
21use crate::packet::TransportPacketAad;
22use crate::{error::NoiseProtocolError, symmetric_key::SymmetricKey};
23
/// Maximum age, in seconds, of an incoming message before it is rejected (one day).
const MAX_MESSAGE_AGE: u64 = 24 * 60 * 60;
/// How many seconds a message timestamp may lie ahead of the local clock.
const CLOCK_SKEW_TOLERANCE: u64 = 60;

/// Maximum distance a packet's chain counter may be ahead of the local receive chain.
const MAX_REKEY_GAP: u64 = 1 << 10;

/// Default number of seconds between send-chain rekeys (one day).
const REKEY_INTERVAL: u64 = 24 * 60 * 60;
35
36#[derive(Clone, Debug)]
46pub struct MultiDeviceTransport {
47 ciphersuite: Ciphersuite,
49
50 send_key: SymmetricKey,
52 send_rekey_counter: u64,
54 last_rekeyed_time: u64,
56 rekey_interval: u64,
58
59 recv_key: SymmetricKey,
61 recv_rekey_counter: u64,
63
64 seen_nonces: BTreeMap<Vec<u8>, u64>,
66
67 timeprovider: Timeprovider,
68}
69
/// Source of the current time in whole seconds since the Unix epoch.
///
/// In test builds the clock is a plain settable field starting at 0, so tests
/// can move time explicitly.
#[derive(Clone, Debug)]
struct Timeprovider {
    /// Mock "current time", present only in test builds.
    #[cfg(test)]
    now: u64,
}

impl Timeprovider {
    /// Creates a clock. In test builds the mock time starts at 0.
    fn new() -> Self {
        Self {
            #[cfg(test)]
            now: 0,
        }
    }

    /// Seconds since the Unix epoch, read from the system clock.
    ///
    /// # Panics
    /// If the system clock reports a time before the Unix epoch.
    #[cfg(not(any(test, target_arch = "wasm32")))]
    fn now(&self) -> u64 {
        let since_epoch = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("System time before Unix epoch");
        since_epoch.as_secs()
    }

    /// Seconds since the Unix epoch, from JavaScript's `Date.now()` (milliseconds),
    /// truncated toward zero.
    #[cfg(all(not(test), target_arch = "wasm32"))]
    fn now(&self) -> u64 {
        let millis = js_sys::Date::now();
        (millis / 1000.0) as u64
    }

    /// Returns the mock time last set with `set_now` (0 initially).
    #[cfg(test)]
    fn now(&self) -> u64 {
        self.now
    }

    /// Overrides the mock time returned by `now` in test builds.
    #[cfg(test)]
    fn set_now(&mut self, seconds: u64) {
        self.now = seconds;
    }
}
107
108impl MultiDeviceTransport {
109 pub(crate) fn new(
111 ciphersuite: Ciphersuite,
112 send_key: SymmetricKey,
113 recv_key: SymmetricKey,
114 ) -> Self {
115 let timeprovider = Timeprovider::new();
116 Self {
117 ciphersuite,
118 send_key,
119 send_rekey_counter: 1,
120 last_rekeyed_time: timeprovider.now(),
121 rekey_interval: REKEY_INTERVAL,
122 recv_key,
123 recv_rekey_counter: 1,
124 seen_nonces: BTreeMap::new(),
125 timeprovider,
126 }
127 }
128
129 fn prune_old_nonces(&mut self) {
131 let now = self.timeprovider.now();
132 let cutoff = now.saturating_sub(MAX_MESSAGE_AGE);
133 self.seen_nonces
134 .retain(|_, &mut timestamp| timestamp >= cutoff);
135 }
136
137 fn check_and_record_nonce(
139 &mut self,
140 packet_aad: &TransportPacketAad,
141 packet: &TransportPacket,
142 ) -> Result<(), NoiseProtocolError> {
143 self.prune_old_nonces();
144 if self.seen_nonces.contains_key(&packet.nonce) {
145 return Err(NoiseProtocolError::ReplayDetected);
146 }
147 self.seen_nonces
148 .insert(packet.nonce.clone(), packet_aad.timestamp);
149 Ok(())
150 }
151
152 pub fn seen_nonces(&self) -> Vec<u8> {
154 let mut buf = Vec::new();
155 ciborium::ser::into_writer(&self.seen_nonces, &mut buf)
156 .expect("should serialize seen nonces");
157 buf
158 }
159
160 pub fn set_seen_nonces(&mut self, data: &[u8]) -> Result<(), NoiseProtocolError> {
162 let nonces: BTreeMap<Vec<u8>, u64> =
163 ciborium::de::from_reader(data).map_err(|_| NoiseProtocolError::CborDecodeFailed)?;
164 self.seen_nonces = nonces;
165 Ok(())
166 }
167
168 pub fn ciphersuite(&self) -> Ciphersuite {
170 self.ciphersuite
171 }
172
173 fn rekey_send_if_needed(&mut self) -> Result<(), NoiseProtocolError> {
174 let now = self.timeprovider.now();
176 while now.saturating_sub(self.last_rekeyed_time) >= self.rekey_interval {
177 self.send_key = XChaCha20Poly1305RandomNonceCipher::rekey(&mut self.send_key)?;
178 self.send_rekey_counter = self.send_rekey_counter.wrapping_add(1);
179 self.last_rekeyed_time += self.rekey_interval;
180 }
181
182 Ok(())
183 }
184
185 #[instrument(level = "debug", fields(ciphersuite = ?self.ciphersuite))]
189 pub fn encrypt(&mut self, plaintext: &[u8]) -> Result<TransportPacket, NoiseProtocolError> {
190 self.rekey_send_if_needed()?;
192
193 let aad = TransportPacketAad {
195 timestamp: self.timeprovider.now(),
196 chain_counter: self.send_rekey_counter,
197 ciphersuite: self.ciphersuite,
198 };
199
200 let (nonce, ciphertext) =
202 XChaCha20Poly1305RandomNonceCipher::encrypt(&self.send_key, plaintext, &aad.encode())?;
203
204 let packet = TransportPacket {
206 nonce: nonce.to_vec(),
207 ciphertext,
208 aad: aad.encode(),
209 };
210
211 Ok(packet)
212 }
213
214 #[instrument(level = "debug", fields(ciphersuite = ?self.ciphersuite))]
216 pub fn decrypt(&mut self, packet: &TransportPacket) -> Result<Vec<u8>, NoiseProtocolError> {
217 let packet_aad = TransportPacketAad::decode(&packet.aad)
218 .map_err(|_| NoiseProtocolError::DecryptionFailed)?;
219
220 if packet_aad.ciphersuite != self.ciphersuite {
222 error!("Ciphersuite mismatch detected");
223 return Err(NoiseProtocolError::CiphersuiteMismatch);
224 }
225 self.validate_message_timestamp(&packet_aad)?;
226 self.check_and_record_nonce(&packet_aad, packet)?;
227 self.rekey_receive(&packet_aad)?;
228
229 let packet_decryption_key = if packet_aad.chain_counter == self.recv_rekey_counter {
230 self.recv_key.clone()
231 } else {
232 XChaCha20Poly1305RandomNonceCipher::rekey(&mut self.recv_key)?.clone()
233 };
234
235 XChaCha20Poly1305RandomNonceCipher::decrypt(
237 &packet_decryption_key,
238 &XChaCha20Poly1305Nonce::from_slice(&packet.nonce)?,
239 &packet.ciphertext,
240 &packet.aad,
241 )
242 }
243
244 fn rekey_receive(&mut self, packet_aad: &TransportPacketAad) -> Result<(), NoiseProtocolError> {
245 if packet_aad.chain_counter < self.recv_rekey_counter {
260 Err(NoiseProtocolError::Desynchronized)
262 } else if packet_aad.chain_counter > self.recv_rekey_counter + MAX_REKEY_GAP {
263 Err(NoiseProtocolError::Desynchronized)
265 } else {
266 while self.recv_rekey_counter < packet_aad.chain_counter - 1 {
269 self.recv_key = XChaCha20Poly1305RandomNonceCipher::rekey(&mut self.recv_key)?;
270 self.recv_rekey_counter = self.recv_rekey_counter.wrapping_add(1);
271 }
272 Ok(())
273 }
274 }
275
276 fn validate_message_timestamp(
278 &mut self,
279 packet_aad: &TransportPacketAad,
280 ) -> Result<(), NoiseProtocolError> {
281 let now = self.timeprovider.now();
282 let packet_timestamp = packet_aad.timestamp;
283
284 if packet_timestamp < now.saturating_sub(MAX_MESSAGE_AGE) {
286 error!(
287 "Message too old: timestamp={}, now={}, age={}",
288 packet_timestamp,
289 now,
290 now.saturating_sub(packet_timestamp)
291 );
292 return Err(NoiseProtocolError::MessageTooOld {
293 timestamp: packet_timestamp,
294 now,
295 });
296 }
297
298 if packet_timestamp > now + CLOCK_SKEW_TOLERANCE {
300 error!(
301 "Message from future: timestamp={}, now={}",
302 packet_timestamp, now
303 );
304 return Err(NoiseProtocolError::MessageFromFuture {
305 timestamp: packet_timestamp,
306 now,
307 });
308 }
309
310 Ok(())
311 }
312
313 pub(crate) fn restore_from_state(
315 ciphersuite: Ciphersuite,
316 send_key: SymmetricKey,
317 recv_key: SymmetricKey,
318 send_rekey_counter: u64,
319 recv_rekey_counter: u64,
320 last_rekeyed_time: u64,
321 rekey_interval: u64,
322 ) -> Self {
323 Self {
324 ciphersuite,
325 send_key,
326 send_rekey_counter,
327 last_rekeyed_time,
328 rekey_interval,
329 recv_key,
330 recv_rekey_counter,
331 seen_nonces: BTreeMap::new(),
332 timeprovider: Timeprovider::new(),
333 }
334 }
335
336 pub fn send_rekey_counter(&self) -> u64 {
338 self.send_rekey_counter
339 }
340
341 pub fn recv_rekey_counter(&self) -> u64 {
343 self.recv_rekey_counter
344 }
345
346 #[cfg(test)]
347 pub(crate) fn set_last_rekeyed_time(&mut self, timestamp: u64) {
348 self.last_rekeyed_time = timestamp;
349 }
350
351 pub fn last_rekeyed_time(&self) -> u64 {
353 self.last_rekeyed_time
354 }
355
356 pub fn rekey_interval(&self) -> u64 {
358 self.rekey_interval
359 }
360
361 pub fn keys(&self) -> (SymmetricKey, SymmetricKey) {
363 (self.send_key.clone(), self.recv_key.clone())
364 }
365
366 #[cfg(test)]
367 pub(crate) fn set_send_rekey_counter(&mut self, counter: u64) {
368 self.send_rekey_counter = counter;
369 }
370
371 #[cfg(test)]
372 #[allow(unused)]
373 pub(crate) fn set_recv_rekey_counter(&mut self, counter: u64) {
374 self.recv_rekey_counter = counter;
375 }
376
377 #[cfg(test)]
378 pub(crate) fn send_key(&self) -> SymmetricKey {
379 self.send_key.clone()
380 }
381
382 #[cfg(test)]
383 pub(crate) fn recv_key(&self) -> SymmetricKey {
384 self.recv_key.clone()
385 }
386}
387
388struct XChaCha20Poly1305Nonce([u8; 24]);
389
390impl XChaCha20Poly1305Nonce {
391 fn from_slice(slice: &[u8]) -> Result<Self, NoiseProtocolError> {
392 if slice.len() != 24 {
393 return Err(NoiseProtocolError::DecryptionFailed);
394 }
395 let mut nonce = [0u8; 24];
396 nonce.copy_from_slice(slice);
397 Ok(XChaCha20Poly1305Nonce(nonce))
398 }
399
400 fn rekey_max_value() -> Self {
402 XChaCha20Poly1305Nonce([0xFF; 24])
403 }
404
405 fn generate() -> Self {
406 let mut nonce = [0u8; 24];
407 let mut rng = rand::thread_rng();
408 rand::RngCore::fill_bytes(&mut rng, &mut nonce);
409 XChaCha20Poly1305Nonce(nonce)
410 }
411
412 fn to_vec(&self) -> Vec<u8> {
413 self.0.to_vec()
414 }
415}
416
417impl From<&XChaCha20Poly1305Nonce> for XNonce {
418 fn from(nonce: &XChaCha20Poly1305Nonce) -> Self {
419 *XNonce::from_slice(&nonce.0)
420 }
421}
422
423struct XChaCha20Poly1305RandomNonceCipher;
424
425impl XChaCha20Poly1305RandomNonceCipher {
426 fn rekey(key: &mut SymmetricKey) -> Result<SymmetricKey, NoiseProtocolError> {
427 let nonce = XChaCha20Poly1305Nonce::rekey_max_value();
428 let empty_key = [0u8; 32];
429 let cipher = XChaCha20Poly1305::new(key.as_slice().into());
430 let derived = cipher
431 .encrypt(
432 &(&nonce).into(),
433 Payload {
434 msg: &empty_key,
435 aad: &[],
436 },
437 )
438 .map_err(|_| NoiseProtocolError::RekeyFailed)?;
439 let mut new_key = [0u8; 32];
440 new_key.copy_from_slice(&derived[..32]);
441 Ok(SymmetricKey::from_bytes(new_key))
442 }
443
444 fn encrypt(
445 key: &SymmetricKey,
446 plaintext: &[u8],
447 aad: &[u8],
448 ) -> Result<(XChaCha20Poly1305Nonce, Vec<u8>), NoiseProtocolError> {
449 let nonce = XChaCha20Poly1305Nonce::generate();
450 let cipher = XChaCha20Poly1305::new(&key.to_bytes().into());
451 let ciphertext = cipher
452 .encrypt(
453 &(&nonce).into(),
454 Payload {
455 msg: plaintext,
456 aad,
457 },
458 )
459 .map_err(|_| NoiseProtocolError::TransportEncryptionFailed)?;
460
461 Ok((nonce, ciphertext))
462 }
463
464 fn decrypt(
465 key: &SymmetricKey,
466 nonce: &XChaCha20Poly1305Nonce,
467 ciphertext: &[u8],
468 aad: &[u8],
469 ) -> Result<Vec<u8>, NoiseProtocolError> {
470 let cipher = XChaCha20Poly1305::new(&key.to_bytes().into());
471 let payload = Payload {
472 msg: ciphertext,
473 aad,
474 };
475
476 let plaintext = cipher
477 .decrypt(&(nonce).into(), payload)
478 .map_err(|_| NoiseProtocolError::TransportDecryptionFailed)?;
479
480 Ok(plaintext)
481 }
482}
483
#[cfg(test)]
mod tests {
    use crate::symmetric_key::{SYMMETRIC_KEY_TEST_VECTOR_1, SYMMETRIC_KEY_TEST_VECTOR_2};
    const PLAINTEXT_TEST_VECTOR: &[u8] = b"Test message for multi-device transport";

    use super::*;

    /// Builds a mirrored transport pair: the sender's send key is the receiver's
    /// receive key and vice versa, so each side can decrypt what the other
    /// encrypts. Both mock clocks start at 0 (see `Timeprovider::new`).
    fn setup_sender_receiver() -> (MultiDeviceTransport, MultiDeviceTransport) {
        let send_key = SYMMETRIC_KEY_TEST_VECTOR_1;
        let recv_key = SYMMETRIC_KEY_TEST_VECTOR_2;

        let sender_send_key = send_key.clone();
        let sender_recv_key = recv_key.clone();
        let receiver_send_key = recv_key.clone();
        let receiver_recv_key = send_key.clone();

        let sender = MultiDeviceTransport::new(
            Ciphersuite::ClassicalNNpsk2_25519_XChaCha20Poly1035,
            sender_send_key,
            sender_recv_key,
        );

        let receiver = MultiDeviceTransport::new(
            Ciphersuite::ClassicalNNpsk2_25519_XChaCha20Poly1035,
            receiver_send_key,
            receiver_recv_key,
        );

        (sender, receiver)
    }

    /// A packet encrypted by one side decrypts to the same plaintext on the other.
    #[test]
    fn test_encrypt_decrypt() {
        let (mut sender, mut receiver) = setup_sender_receiver();
        let packet = sender
            .encrypt(PLAINTEXT_TEST_VECTOR)
            .expect("should encrypt");

        let decrypted = receiver.decrypt(&packet).expect("should decrypt");
        assert_eq!(PLAINTEXT_TEST_VECTOR, decrypted);
    }

    /// Decrypting the same packet twice is rejected the second time.
    #[test]
    fn test_replay_detection() {
        let (mut sender, mut receiver) = setup_sender_receiver();
        let packet = sender
            .encrypt(PLAINTEXT_TEST_VECTOR)
            .expect("should encrypt");

        let _ = receiver.decrypt(&packet).expect("should decrypt");
        let _ = receiver.decrypt(&packet).expect_err("should detect replay");
    }

    /// The receiver's clock is far ahead of the sender's (which stays at 0), so
    /// the packet timestamp falls outside `MAX_MESSAGE_AGE`.
    #[test]
    fn test_message_too_old() {
        let (mut sender, mut receiver) = setup_sender_receiver();

        receiver.timeprovider.set_now(2000000000);
        let packet = sender
            .encrypt(PLAINTEXT_TEST_VECTOR)
            .expect("should encrypt");

        let _ = receiver
            .decrypt(&packet)
            .expect_err("should detect old message");
    }

    /// The sender's clock is far ahead of the receiver's, beyond
    /// `CLOCK_SKEW_TOLERANCE`. Note: before encrypting, the sender first performs
    /// roughly 2_000_000_000 / REKEY_INTERVAL (~23k) send-chain rekeys.
    #[test]
    fn test_message_from_future() {
        let (mut sender, mut receiver) = setup_sender_receiver();

        sender.timeprovider.set_now(2000000000);
        let packet = sender
            .encrypt(PLAINTEXT_TEST_VECTOR)
            .expect("should encrypt");

        let _ = receiver
            .decrypt(&packet)
            .expect_err("should detect future message");
    }

    /// Crossing one `REKEY_INTERVAL` between packets moves the sender to chain
    /// counter 2; the receiver decrypts both the pre-rekey and post-rekey packet.
    #[test]
    fn test_send_rekey() {
        let (mut sender, mut receiver) = setup_sender_receiver();

        sender.timeprovider.set_now(0);
        let packet1 = sender
            .encrypt(PLAINTEXT_TEST_VECTOR)
            .expect("should encrypt");

        sender.timeprovider.set_now(REKEY_INTERVAL);
        let packet2 = sender
            .encrypt(PLAINTEXT_TEST_VECTOR)
            .expect("should encrypt after rekey");

        receiver
            .decrypt(&packet1)
            .expect("should decrypt first message");
        // Move the receiver's clock too so packet2's timestamp is not "from the future".
        receiver.timeprovider.set_now(REKEY_INTERVAL);
        receiver
            .decrypt(&packet2)
            .expect("should decrypt second message");

        assert_eq!(sender.send_rekey_counter, 2);
    }

    /// The decrypting side catches up across several missed epochs.
    ///
    /// Roles are reversed relative to the variable names: `receiver` encrypts
    /// (its send chain advances once per interval, reaching counter 4) and
    /// `sender` decrypts, starting from receive counter 1.
    #[test]
    fn test_receive_rekey_catchup() {
        let (mut sender, mut receiver) = setup_sender_receiver();

        receiver.timeprovider.set_now(REKEY_INTERVAL);
        receiver.encrypt(b"msg1").expect("should encrypt");
        receiver.timeprovider.set_now(REKEY_INTERVAL * 2);
        receiver.encrypt(b"msg2").expect("should encrypt");
        receiver.timeprovider.set_now(REKEY_INTERVAL * 3);
        let packet = receiver
            .encrypt(PLAINTEXT_TEST_VECTOR)
            .expect("should encrypt");
        // Align the decrypting side's clock with the packet timestamp.
        sender.timeprovider.set_now(REKEY_INTERVAL * 3);

        let decrypted = sender
            .decrypt(&packet)
            .expect("should decrypt after catchup");
        assert_eq!(decrypted, PLAINTEXT_TEST_VECTOR);
    }

    /// A chain counter of `MAX_REKEY_GAP + 2` is more than `MAX_REKEY_GAP` ahead
    /// of the receiver's counter (1) and must be rejected as `Desynchronized`.
    #[test]
    fn test_desynchronization() {
        let (mut sender, mut receiver) = setup_sender_receiver();
        sender.set_send_rekey_counter(MAX_REKEY_GAP + 2);

        let packet = sender
            .encrypt(PLAINTEXT_TEST_VECTOR)
            .expect("should encrypt");

        let result = receiver.decrypt(&packet);
        assert!(result.is_err());
        assert!(matches!(
            result.err(),
            Some(NoiseProtocolError::Desynchronized)
        ));
    }

    /// Two devices sharing one cloned receiver state answer requests in reverse
    /// order; the single device must decrypt both answers.
    #[test]
    fn test_device_group_out_of_order_answers() {
        let (sender, receiver) = setup_sender_receiver();
        let mut single_device = sender;
        let mut device_group_device_1 = receiver.clone();
        let mut device_group_device_2 = receiver;

        let packet1 = single_device
            .encrypt(PLAINTEXT_TEST_VECTOR)
            .expect("should encrypt request1");
        let packet2 = single_device
            .encrypt(PLAINTEXT_TEST_VECTOR)
            .expect("should encrypt request2");

        // Device 2 handles the later request first.
        let _ = device_group_device_2
            .decrypt(&packet2)
            .expect("should decrypt request2");
        let response2 = device_group_device_2
            .encrypt(PLAINTEXT_TEST_VECTOR)
            .expect("should encrypt response2");
        let decrypted_response2 = single_device
            .decrypt(&response2)
            .expect("should decrypt response2");
        assert_eq!(decrypted_response2, PLAINTEXT_TEST_VECTOR);

        // Device 1 answers the earlier request afterwards.
        let _ = device_group_device_1
            .decrypt(&packet1)
            .expect("should decrypt request1");
        let response1 = device_group_device_1
            .encrypt(PLAINTEXT_TEST_VECTOR)
            .expect("should encrypt response1");
        let decrypted_response1 = single_device
            .decrypt(&response1)
            .expect("should decrypt response1");
        assert_eq!(decrypted_response1, PLAINTEXT_TEST_VECTOR);
    }
}