1#[cfg(not(feature = "std"))]
25use alloc::{collections::BTreeMap, vec::Vec};
26#[cfg(feature = "std")]
27use std::collections::HashMap;
28
29use chacha20poly1305::{
30 aead::{Aead, KeyInit, OsRng},
31 ChaCha20Poly1305, Nonce,
32};
33use rand_core::RngCore;
34
35use super::peer_key::{KeyExchangeMessage, PeerIdentityKey, PeerSessionKey};
36use super::EncryptionError;
37use crate::NodeId;
38
/// Default idle time, in milliseconds, after which a session is considered expired (30 minutes).
pub const DEFAULT_SESSION_TIMEOUT_MS: u64 = 1_800_000;

/// Default upper bound on the number of sessions a manager keeps at once.
pub const DEFAULT_MAX_SESSIONS: usize = 16;

/// Lifecycle stage of a peer session.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SessionState {
    /// We started the handshake and are still waiting for the peer's public key.
    AwaitingPeerKey,
    /// Handshake finished; a session key is available for traffic.
    Established,
    /// Session was closed locally and no longer counts as established.
    Closed,
}
55
56#[derive(Debug)]
58pub struct PeerSession {
59 pub peer_node_id: NodeId,
61 pub state: SessionState,
63 session_key: Option<PeerSessionKey>,
65 peer_public_key: Option<[u8; 32]>,
67 pub created_at_ms: u64,
69 pub last_activity_ms: u64,
71 pub outbound_counter: u64,
73 pub inbound_counter: u64,
75}
76
77impl PeerSession {
78 pub fn new_initiator(peer_node_id: NodeId, now_ms: u64) -> Self {
80 Self {
81 peer_node_id,
82 state: SessionState::AwaitingPeerKey,
83 session_key: None,
84 peer_public_key: None,
85 created_at_ms: now_ms,
86 last_activity_ms: now_ms,
87 outbound_counter: 0,
88 inbound_counter: 0,
89 }
90 }
91
92 pub fn new_responder(
94 peer_node_id: NodeId,
95 session_key: PeerSessionKey,
96 peer_public_key: [u8; 32],
97 now_ms: u64,
98 ) -> Self {
99 Self {
100 peer_node_id,
101 state: SessionState::Established,
102 session_key: Some(session_key),
103 peer_public_key: Some(peer_public_key),
104 created_at_ms: now_ms,
105 last_activity_ms: now_ms,
106 outbound_counter: 0,
107 inbound_counter: 0,
108 }
109 }
110
111 pub fn complete_handshake(
113 &mut self,
114 session_key: PeerSessionKey,
115 peer_public_key: [u8; 32],
116 now_ms: u64,
117 ) {
118 self.state = SessionState::Established;
119 self.session_key = Some(session_key);
120 self.peer_public_key = Some(peer_public_key);
121 self.last_activity_ms = now_ms;
122 }
123
124 pub fn is_established(&self) -> bool {
126 self.state == SessionState::Established && self.session_key.is_some()
127 }
128
129 pub fn is_expired(&self, now_ms: u64, timeout_ms: u64) -> bool {
131 now_ms.saturating_sub(self.last_activity_ms) > timeout_ms
132 }
133
134 pub fn next_outbound_counter(&mut self) -> u64 {
136 let counter = self.outbound_counter;
137 self.outbound_counter = self.outbound_counter.wrapping_add(1);
138 counter
139 }
140
141 pub fn validate_inbound_counter(&mut self, counter: u64) -> bool {
146 if counter >= self.inbound_counter {
149 self.inbound_counter = counter.saturating_add(1);
150 true
151 } else {
152 false
153 }
154 }
155
156 pub fn session_key(&self) -> Option<&PeerSessionKey> {
158 self.session_key.as_ref()
159 }
160
161 pub fn touch(&mut self, now_ms: u64) {
163 self.last_activity_ms = now_ms;
164 }
165
166 pub fn close(&mut self) {
168 self.state = SessionState::Closed;
169 }
170}
171
172#[derive(Debug, Clone)]
174pub struct PeerEncryptedMessage {
175 pub recipient_node_id: NodeId,
177 pub sender_node_id: NodeId,
179 pub counter: u64,
181 pub nonce: [u8; 12],
183 pub ciphertext: Vec<u8>,
185}
186
187impl PeerEncryptedMessage {
188 pub const OVERHEAD: usize = 4 + 4 + 8 + 12 + 16;
190
191 pub fn encode(&self) -> Vec<u8> {
195 let mut buf = Vec::with_capacity(28 + self.ciphertext.len());
196 buf.extend_from_slice(&self.recipient_node_id.as_u32().to_le_bytes());
197 buf.extend_from_slice(&self.sender_node_id.as_u32().to_le_bytes());
198 buf.extend_from_slice(&self.counter.to_le_bytes());
199 buf.extend_from_slice(&self.nonce);
200 buf.extend_from_slice(&self.ciphertext);
201 buf
202 }
203
204 pub fn decode(data: &[u8]) -> Option<Self> {
206 if data.len() < 44 {
208 return None;
209 }
210
211 let recipient_node_id =
212 NodeId::new(u32::from_le_bytes([data[0], data[1], data[2], data[3]]));
213 let sender_node_id = NodeId::new(u32::from_le_bytes([data[4], data[5], data[6], data[7]]));
214 let counter = u64::from_le_bytes([
215 data[8], data[9], data[10], data[11], data[12], data[13], data[14], data[15],
216 ]);
217
218 let mut nonce = [0u8; 12];
219 nonce.copy_from_slice(&data[16..28]);
220
221 let ciphertext = data[28..].to_vec();
222
223 Some(Self {
224 recipient_node_id,
225 sender_node_id,
226 counter,
227 nonce,
228 ciphertext,
229 })
230 }
231}
232
233pub struct PeerSessionManager {
235 our_node_id: NodeId,
237 identity_key: PeerIdentityKey,
239 #[cfg(feature = "std")]
241 sessions: HashMap<NodeId, PeerSession>,
242 #[cfg(not(feature = "std"))]
243 sessions: BTreeMap<NodeId, PeerSession>,
244 max_sessions: usize,
246 session_timeout_ms: u64,
248}
249
250impl PeerSessionManager {
251 pub fn new(our_node_id: NodeId) -> Self {
253 Self {
254 our_node_id,
255 identity_key: PeerIdentityKey::generate(),
256 #[cfg(feature = "std")]
257 sessions: HashMap::new(),
258 #[cfg(not(feature = "std"))]
259 sessions: BTreeMap::new(),
260 max_sessions: DEFAULT_MAX_SESSIONS,
261 session_timeout_ms: DEFAULT_SESSION_TIMEOUT_MS,
262 }
263 }
264
265 pub fn with_identity_key(our_node_id: NodeId, identity_key: PeerIdentityKey) -> Self {
267 Self {
268 our_node_id,
269 identity_key,
270 #[cfg(feature = "std")]
271 sessions: HashMap::new(),
272 #[cfg(not(feature = "std"))]
273 sessions: BTreeMap::new(),
274 max_sessions: DEFAULT_MAX_SESSIONS,
275 session_timeout_ms: DEFAULT_SESSION_TIMEOUT_MS,
276 }
277 }
278
279 pub fn with_max_sessions(mut self, max: usize) -> Self {
281 self.max_sessions = max;
282 self
283 }
284
285 pub fn with_session_timeout(mut self, timeout_ms: u64) -> Self {
287 self.session_timeout_ms = timeout_ms;
288 self
289 }
290
291 pub fn our_public_key(&self) -> [u8; 32] {
293 self.identity_key.public_key_bytes()
294 }
295
296 pub fn our_node_id(&self) -> NodeId {
298 self.our_node_id
299 }
300
301 pub fn initiate_session(&mut self, peer_node_id: NodeId, now_ms: u64) -> KeyExchangeMessage {
305 let session = PeerSession::new_initiator(peer_node_id, now_ms);
307 self.sessions.insert(peer_node_id, session);
308
309 self.enforce_session_limit(now_ms);
311
312 KeyExchangeMessage::new(
314 self.our_node_id,
315 self.identity_key.public_key_bytes(),
316 false,
317 )
318 }
319
320 pub fn handle_key_exchange(
326 &mut self,
327 msg: &KeyExchangeMessage,
328 now_ms: u64,
329 ) -> Option<(KeyExchangeMessage, bool)> {
330 let peer_node_id = msg.sender_node_id;
331 let peer_public = x25519_dalek::PublicKey::from(msg.public_key);
332
333 let shared_secret = self.identity_key.exchange(&peer_public);
335 let session_key = shared_secret.derive_session_key(self.our_node_id, peer_node_id);
336
337 if let Some(session) = self.sessions.get_mut(&peer_node_id) {
339 if session.state == SessionState::AwaitingPeerKey {
340 session.complete_handshake(session_key, msg.public_key, now_ms);
342 return Some((
343 KeyExchangeMessage::new(
344 self.our_node_id,
345 self.identity_key.public_key_bytes(),
346 false,
347 ),
348 true, ));
350 }
351 return None;
353 }
354
355 if self.sessions.len() >= self.max_sessions {
357 self.cleanup_expired(now_ms);
359 if self.sessions.len() >= self.max_sessions {
360 log::warn!(
361 "Cannot accept E2EE session from {:?}: max sessions reached",
362 peer_node_id
363 );
364 return None;
365 }
366 }
367
368 let session = PeerSession::new_responder(peer_node_id, session_key, msg.public_key, now_ms);
369 self.sessions.insert(peer_node_id, session);
370
371 Some((
373 KeyExchangeMessage::new(
374 self.our_node_id,
375 self.identity_key.public_key_bytes(),
376 false,
377 ),
378 true, ))
380 }
381
382 pub fn has_session(&self, peer_node_id: NodeId) -> bool {
384 self.sessions
385 .get(&peer_node_id)
386 .is_some_and(|s| s.is_established())
387 }
388
389 pub fn session_state(&self, peer_node_id: NodeId) -> Option<SessionState> {
391 self.sessions.get(&peer_node_id).map(|s| s.state)
392 }
393
394 pub fn encrypt_for_peer(
398 &mut self,
399 peer_node_id: NodeId,
400 plaintext: &[u8],
401 now_ms: u64,
402 ) -> Result<PeerEncryptedMessage, EncryptionError> {
403 let session = self
404 .sessions
405 .get_mut(&peer_node_id)
406 .ok_or(EncryptionError::EncryptionFailed)?;
407
408 if !session.is_established() {
409 return Err(EncryptionError::EncryptionFailed);
410 }
411
412 let session_key_bytes = *session
414 .session_key()
415 .ok_or(EncryptionError::EncryptionFailed)?
416 .as_bytes();
417 let counter = session.next_outbound_counter();
418 session.touch(now_ms);
419
420 let cipher = ChaCha20Poly1305::new_from_slice(&session_key_bytes)
422 .map_err(|_| EncryptionError::EncryptionFailed)?;
423
424 let mut nonce_bytes = [0u8; 12];
426 OsRng.fill_bytes(&mut nonce_bytes);
427 let nonce = Nonce::from_slice(&nonce_bytes);
428
429 let ciphertext = cipher
431 .encrypt(nonce, plaintext)
432 .map_err(|_| EncryptionError::EncryptionFailed)?;
433
434 Ok(PeerEncryptedMessage {
435 recipient_node_id: peer_node_id,
436 sender_node_id: self.our_node_id,
437 counter,
438 nonce: nonce_bytes,
439 ciphertext,
440 })
441 }
442
443 pub fn decrypt_from_peer(
447 &mut self,
448 msg: &PeerEncryptedMessage,
449 now_ms: u64,
450 ) -> Result<Vec<u8>, EncryptionError> {
451 if msg.recipient_node_id != self.our_node_id {
453 return Err(EncryptionError::DecryptionFailed);
454 }
455
456 let session = self
457 .sessions
458 .get_mut(&msg.sender_node_id)
459 .ok_or(EncryptionError::DecryptionFailed)?;
460
461 if !session.is_established() {
462 return Err(EncryptionError::DecryptionFailed);
463 }
464
465 if !session.validate_inbound_counter(msg.counter) {
467 log::warn!(
468 "Replay attack detected from {:?}: counter {} < next expected {}",
469 msg.sender_node_id,
470 msg.counter,
471 session.inbound_counter
472 );
473 return Err(EncryptionError::DecryptionFailed);
474 }
475
476 let session_key_bytes = *session
478 .session_key()
479 .ok_or(EncryptionError::DecryptionFailed)?
480 .as_bytes();
481 session.touch(now_ms);
482
483 let cipher = ChaCha20Poly1305::new_from_slice(&session_key_bytes)
485 .map_err(|_| EncryptionError::DecryptionFailed)?;
486
487 let nonce = Nonce::from_slice(&msg.nonce);
488
489 cipher
491 .decrypt(nonce, msg.ciphertext.as_ref())
492 .map_err(|_| EncryptionError::DecryptionFailed)
493 }
494
495 pub fn close_session(&mut self, peer_node_id: NodeId) {
497 if let Some(session) = self.sessions.get_mut(&peer_node_id) {
498 session.close();
499 }
500 }
501
502 pub fn remove_session(&mut self, peer_node_id: NodeId) -> Option<PeerSession> {
504 self.sessions.remove(&peer_node_id)
505 }
506
507 pub fn cleanup_expired(&mut self, now_ms: u64) -> Vec<NodeId> {
509 let timeout = self.session_timeout_ms;
510 let expired: Vec<NodeId> = self
511 .sessions
512 .iter()
513 .filter(|(_, s)| s.is_expired(now_ms, timeout))
514 .map(|(id, _)| *id)
515 .collect();
516
517 for id in &expired {
518 self.sessions.remove(id);
519 }
520
521 expired
522 }
523
524 pub fn session_count(&self) -> usize {
526 self.sessions.len()
527 }
528
529 pub fn established_count(&self) -> usize {
531 self.sessions
532 .values()
533 .filter(|s| s.is_established())
534 .count()
535 }
536
537 fn enforce_session_limit(&mut self, now_ms: u64) {
539 self.cleanup_expired(now_ms);
541
542 while self.sessions.len() > self.max_sessions {
544 let oldest = self
545 .sessions
546 .iter()
547 .filter(|(_, s)| s.state == SessionState::Closed)
548 .min_by_key(|(_, s)| s.last_activity_ms)
549 .map(|(id, _)| *id);
550
551 if let Some(id) = oldest {
552 self.sessions.remove(&id);
553 } else {
554 let oldest = self
556 .sessions
557 .iter()
558 .filter(|(_, s)| !s.is_established())
559 .min_by_key(|(_, s)| s.last_activity_ms)
560 .map(|(id, _)| *id);
561
562 if let Some(id) = oldest {
563 self.sessions.remove(&id);
564 } else {
565 break; }
567 }
568 }
569 }
570}
571
572impl core::fmt::Debug for PeerSessionManager {
573 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
574 f.debug_struct("PeerSessionManager")
575 .field("our_node_id", &self.our_node_id)
576 .field("session_count", &self.sessions.len())
577 .field("max_sessions", &self.max_sessions)
578 .finish()
579 }
580}
581
#[cfg(test)]
mod tests {
    use super::*;

    const ALICE: u32 = 0x11111111;
    const BOB: u32 = 0x22222222;

    /// Runs a complete handshake between two fresh managers and returns `(alice, bob)`.
    fn established_pair() -> (PeerSessionManager, PeerSessionManager) {
        let mut alice = PeerSessionManager::new(NodeId::new(ALICE));
        let mut bob = PeerSessionManager::new(NodeId::new(BOB));

        let alice_msg = alice.initiate_session(NodeId::new(BOB), 1000);
        let (bob_response, _) = bob.handle_key_exchange(&alice_msg, 1000).unwrap();
        alice.handle_key_exchange(&bob_response, 1000).unwrap();

        (alice, bob)
    }

    #[test]
    fn test_session_manager_creation() {
        let manager = PeerSessionManager::new(NodeId::new(ALICE));
        assert_eq!(manager.our_node_id().as_u32(), ALICE);
        assert_eq!(manager.session_count(), 0);
    }

    #[test]
    fn test_initiate_session() {
        let mut manager = PeerSessionManager::new(NodeId::new(ALICE));
        let msg = manager.initiate_session(NodeId::new(BOB), 1000);

        assert_eq!(msg.sender_node_id.as_u32(), ALICE);
        assert_eq!(manager.session_count(), 1);
        assert_eq!(
            manager.session_state(NodeId::new(BOB)),
            Some(SessionState::AwaitingPeerKey)
        );
    }

    #[test]
    fn test_full_key_exchange() {
        let mut alice = PeerSessionManager::new(NodeId::new(ALICE));
        let mut bob = PeerSessionManager::new(NodeId::new(BOB));

        let alice_msg = alice.initiate_session(NodeId::new(BOB), 1000);

        let (bob_response, bob_established) = bob.handle_key_exchange(&alice_msg, 1000).unwrap();
        assert!(bob_established);
        assert!(bob.has_session(NodeId::new(ALICE)));

        let (_, alice_established) = alice.handle_key_exchange(&bob_response, 1000).unwrap();
        assert!(alice_established);
        assert!(alice.has_session(NodeId::new(BOB)));
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let (mut alice, mut bob) = established_pair();

        let plaintext = b"Hello, Bob!";
        let encrypted = alice
            .encrypt_for_peer(NodeId::new(BOB), plaintext, 2000)
            .unwrap();

        let decrypted = bob.decrypt_from_peer(&encrypted, 2000).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_bidirectional_communication() {
        let (mut alice, mut bob) = established_pair();

        let msg1 = alice
            .encrypt_for_peer(NodeId::new(BOB), b"From Alice", 2000)
            .unwrap();
        assert_eq!(bob.decrypt_from_peer(&msg1, 2000).unwrap(), b"From Alice");

        let msg2 = bob
            .encrypt_for_peer(NodeId::new(ALICE), b"From Bob", 2000)
            .unwrap();
        assert_eq!(alice.decrypt_from_peer(&msg2, 2000).unwrap(), b"From Bob");
    }

    #[test]
    fn test_replay_protection() {
        let (mut alice, mut bob) = established_pair();

        let encrypted = alice
            .encrypt_for_peer(NodeId::new(BOB), b"Message", 2000)
            .unwrap();

        assert!(bob.decrypt_from_peer(&encrypted, 2000).is_ok());
        assert!(bob.decrypt_from_peer(&encrypted, 2000).is_err());
    }

    #[test]
    fn test_older_counter_rejected() {
        let (mut alice, mut bob) = established_pair();

        let first = alice.encrypt_for_peer(NodeId::new(BOB), b"one", 2000).unwrap();
        let second = alice.encrypt_for_peer(NodeId::new(BOB), b"two", 2000).unwrap();

        assert_eq!(bob.decrypt_from_peer(&second, 2000).unwrap(), b"two");
        assert!(bob.decrypt_from_peer(&first, 2000).is_err());
    }

    #[test]
    fn test_tampered_ciphertext_rejected() {
        let (mut alice, mut bob) = established_pair();

        let mut encrypted = alice
            .encrypt_for_peer(NodeId::new(BOB), b"Message", 2000)
            .unwrap();
        encrypted.ciphertext[0] ^= 0x01;

        assert!(bob.decrypt_from_peer(&encrypted, 2000).is_err());
    }

    #[test]
    fn test_encrypt_without_session_fails() {
        let mut manager = PeerSessionManager::new(NodeId::new(ALICE));
        assert!(manager.encrypt_for_peer(NodeId::new(BOB), b"x", 1000).is_err());

        // Pending handshake is not enough.
        manager.initiate_session(NodeId::new(BOB), 1000);
        assert!(manager.encrypt_for_peer(NodeId::new(BOB), b"x", 1000).is_err());
    }

    #[test]
    fn test_wrong_recipient_rejected() {
        let (mut alice, _bob) = established_pair();
        let mut charlie = PeerSessionManager::new(NodeId::new(0x33333333));

        let encrypted = alice
            .encrypt_for_peer(NodeId::new(BOB), b"For Bob", 2000)
            .unwrap();

        assert!(charlie.decrypt_from_peer(&encrypted, 2000).is_err());
    }

    #[test]
    fn test_session_expiry() {
        let mut manager =
            PeerSessionManager::new(NodeId::new(ALICE)).with_session_timeout(10_000);

        manager.initiate_session(NodeId::new(BOB), 1000);

        let expired = manager.cleanup_expired(5000);
        assert!(expired.is_empty());
        assert_eq!(manager.session_count(), 1);

        let expired = manager.cleanup_expired(20000);
        assert_eq!(expired.len(), 1);
        assert_eq!(manager.session_count(), 0);
    }

    #[test]
    fn test_max_sessions_limit() {
        let mut manager = PeerSessionManager::new(NodeId::new(ALICE)).with_max_sessions(2);

        manager.initiate_session(NodeId::new(BOB), 1000);
        manager.initiate_session(NodeId::new(0x33333333), 2000);
        manager.initiate_session(NodeId::new(0x44444444), 3000);

        assert!(manager.session_count() <= 2);
    }

    #[test]
    fn test_peer_encrypted_message_encode_decode() {
        let ciphertext = vec![
            0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88,
            0x99, 0x00,
        ];
        let msg = PeerEncryptedMessage {
            recipient_node_id: NodeId::new(BOB),
            sender_node_id: NodeId::new(ALICE),
            counter: 42,
            nonce: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
            ciphertext: ciphertext.clone(),
        };

        let decoded = PeerEncryptedMessage::decode(&msg.encode()).unwrap();

        assert_eq!(decoded.recipient_node_id.as_u32(), BOB);
        assert_eq!(decoded.sender_node_id.as_u32(), ALICE);
        assert_eq!(decoded.counter, 42);
        assert_eq!(decoded.nonce, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]);
        assert_eq!(decoded.ciphertext, ciphertext);
    }

    #[test]
    fn test_encoded_length_matches_overhead_and_roundtrips() {
        let (mut alice, mut bob) = established_pair();

        let plaintext = b"payload";
        let encrypted = alice
            .encrypt_for_peer(NodeId::new(BOB), plaintext, 2000)
            .unwrap();
        let wire = encrypted.encode();
        assert_eq!(wire.len(), PeerEncryptedMessage::OVERHEAD + plaintext.len());

        let decoded = PeerEncryptedMessage::decode(&wire).unwrap();
        assert_eq!(bob.decrypt_from_peer(&decoded, 2000).unwrap(), plaintext);
    }

    #[test]
    fn test_decode_rejects_truncated_input() {
        assert!(PeerEncryptedMessage::decode(&[]).is_none());
        assert!(
            PeerEncryptedMessage::decode(&[0u8; PeerEncryptedMessage::OVERHEAD - 1]).is_none()
        );
        assert!(PeerEncryptedMessage::decode(&[0u8; PeerEncryptedMessage::OVERHEAD]).is_some());
    }

    #[test]
    fn test_close_session() {
        let mut manager = PeerSessionManager::new(NodeId::new(ALICE));
        manager.initiate_session(NodeId::new(BOB), 1000);

        manager.close_session(NodeId::new(BOB));
        assert_eq!(
            manager.session_state(NodeId::new(BOB)),
            Some(SessionState::Closed)
        );
    }
}