#[cfg(not(feature = "std"))]
use alloc::{collections::BTreeMap, vec::Vec};
#[cfg(feature = "std")]
use std::collections::HashMap;

use chacha20poly1305::{
    aead::{Aead, KeyInit, OsRng, Payload},
    ChaCha20Poly1305, Nonce,
};
use rand_core::RngCore;

use super::peer_key::{KeyExchangeMessage, PeerIdentityKey, PeerSessionKey};
use super::EncryptionError;
use crate::NodeId;
23
/// Default idle timeout in milliseconds (30 minutes).
///
/// A session whose last activity is older than this is dropped by
/// `PeerSessionManager::cleanup_expired`.
pub const DEFAULT_SESSION_TIMEOUT_MS: u64 = 30 * 60_000;

/// Default upper bound on the number of sessions a `PeerSessionManager` tracks.
pub const DEFAULT_MAX_SESSIONS: usize = 16;
29
/// Lifecycle stage of a single peer session.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SessionState {
    /// We sent our public key and are waiting for the peer's key exchange.
    AwaitingPeerKey,
    /// A session key has been derived; messages can be encrypted and decrypted.
    Established,
    /// The session was closed locally and no longer encrypts or decrypts.
    Closed,
}
40
41#[derive(Debug)]
43pub struct PeerSession {
44 pub peer_node_id: NodeId,
46 pub state: SessionState,
48 session_key: Option<PeerSessionKey>,
50 peer_public_key: Option<[u8; 32]>,
52 pub created_at_ms: u64,
54 pub last_activity_ms: u64,
56 pub outbound_counter: u64,
58 pub inbound_counter: u64,
60}
61
62impl PeerSession {
63 pub fn new_initiator(peer_node_id: NodeId, now_ms: u64) -> Self {
65 Self {
66 peer_node_id,
67 state: SessionState::AwaitingPeerKey,
68 session_key: None,
69 peer_public_key: None,
70 created_at_ms: now_ms,
71 last_activity_ms: now_ms,
72 outbound_counter: 0,
73 inbound_counter: 0,
74 }
75 }
76
77 pub fn new_responder(
79 peer_node_id: NodeId,
80 session_key: PeerSessionKey,
81 peer_public_key: [u8; 32],
82 now_ms: u64,
83 ) -> Self {
84 Self {
85 peer_node_id,
86 state: SessionState::Established,
87 session_key: Some(session_key),
88 peer_public_key: Some(peer_public_key),
89 created_at_ms: now_ms,
90 last_activity_ms: now_ms,
91 outbound_counter: 0,
92 inbound_counter: 0,
93 }
94 }
95
96 pub fn complete_handshake(
98 &mut self,
99 session_key: PeerSessionKey,
100 peer_public_key: [u8; 32],
101 now_ms: u64,
102 ) {
103 self.state = SessionState::Established;
104 self.session_key = Some(session_key);
105 self.peer_public_key = Some(peer_public_key);
106 self.last_activity_ms = now_ms;
107 }
108
109 pub fn is_established(&self) -> bool {
111 self.state == SessionState::Established && self.session_key.is_some()
112 }
113
114 pub fn is_expired(&self, now_ms: u64, timeout_ms: u64) -> bool {
116 now_ms.saturating_sub(self.last_activity_ms) > timeout_ms
117 }
118
119 pub fn next_outbound_counter(&mut self) -> u64 {
121 let counter = self.outbound_counter;
122 self.outbound_counter = self.outbound_counter.wrapping_add(1);
123 counter
124 }
125
126 pub fn validate_inbound_counter(&mut self, counter: u64) -> bool {
131 if counter >= self.inbound_counter {
134 self.inbound_counter = counter.saturating_add(1);
135 true
136 } else {
137 false
138 }
139 }
140
141 pub fn session_key(&self) -> Option<&PeerSessionKey> {
143 self.session_key.as_ref()
144 }
145
146 pub fn touch(&mut self, now_ms: u64) {
148 self.last_activity_ms = now_ms;
149 }
150
151 pub fn close(&mut self) {
153 self.state = SessionState::Closed;
154 }
155}
156
157#[derive(Debug, Clone)]
159pub struct PeerEncryptedMessage {
160 pub recipient_node_id: NodeId,
162 pub sender_node_id: NodeId,
164 pub counter: u64,
166 pub nonce: [u8; 12],
168 pub ciphertext: Vec<u8>,
170}
171
172impl PeerEncryptedMessage {
173 pub const OVERHEAD: usize = 4 + 4 + 8 + 12 + 16;
175
176 pub fn encode(&self) -> Vec<u8> {
180 let mut buf = Vec::with_capacity(28 + self.ciphertext.len());
181 buf.extend_from_slice(&self.recipient_node_id.as_u32().to_le_bytes());
182 buf.extend_from_slice(&self.sender_node_id.as_u32().to_le_bytes());
183 buf.extend_from_slice(&self.counter.to_le_bytes());
184 buf.extend_from_slice(&self.nonce);
185 buf.extend_from_slice(&self.ciphertext);
186 buf
187 }
188
189 pub fn decode(data: &[u8]) -> Option<Self> {
191 if data.len() < 44 {
193 return None;
194 }
195
196 let recipient_node_id =
197 NodeId::new(u32::from_le_bytes([data[0], data[1], data[2], data[3]]));
198 let sender_node_id = NodeId::new(u32::from_le_bytes([data[4], data[5], data[6], data[7]]));
199 let counter = u64::from_le_bytes([
200 data[8], data[9], data[10], data[11], data[12], data[13], data[14], data[15],
201 ]);
202
203 let mut nonce = [0u8; 12];
204 nonce.copy_from_slice(&data[16..28]);
205
206 let ciphertext = data[28..].to_vec();
207
208 Some(Self {
209 recipient_node_id,
210 sender_node_id,
211 counter,
212 nonce,
213 ciphertext,
214 })
215 }
216}
217
218pub struct PeerSessionManager {
220 our_node_id: NodeId,
222 identity_key: PeerIdentityKey,
224 #[cfg(feature = "std")]
226 sessions: HashMap<NodeId, PeerSession>,
227 #[cfg(not(feature = "std"))]
228 sessions: BTreeMap<NodeId, PeerSession>,
229 max_sessions: usize,
231 session_timeout_ms: u64,
233}
234
235impl PeerSessionManager {
236 pub fn new(our_node_id: NodeId) -> Self {
238 Self {
239 our_node_id,
240 identity_key: PeerIdentityKey::generate(),
241 #[cfg(feature = "std")]
242 sessions: HashMap::new(),
243 #[cfg(not(feature = "std"))]
244 sessions: BTreeMap::new(),
245 max_sessions: DEFAULT_MAX_SESSIONS,
246 session_timeout_ms: DEFAULT_SESSION_TIMEOUT_MS,
247 }
248 }
249
250 pub fn with_identity_key(our_node_id: NodeId, identity_key: PeerIdentityKey) -> Self {
252 Self {
253 our_node_id,
254 identity_key,
255 #[cfg(feature = "std")]
256 sessions: HashMap::new(),
257 #[cfg(not(feature = "std"))]
258 sessions: BTreeMap::new(),
259 max_sessions: DEFAULT_MAX_SESSIONS,
260 session_timeout_ms: DEFAULT_SESSION_TIMEOUT_MS,
261 }
262 }
263
264 pub fn with_max_sessions(mut self, max: usize) -> Self {
266 self.max_sessions = max;
267 self
268 }
269
270 pub fn with_session_timeout(mut self, timeout_ms: u64) -> Self {
272 self.session_timeout_ms = timeout_ms;
273 self
274 }
275
276 pub fn our_public_key(&self) -> [u8; 32] {
278 self.identity_key.public_key_bytes()
279 }
280
281 pub fn our_node_id(&self) -> NodeId {
283 self.our_node_id
284 }
285
286 pub fn initiate_session(&mut self, peer_node_id: NodeId, now_ms: u64) -> KeyExchangeMessage {
290 let session = PeerSession::new_initiator(peer_node_id, now_ms);
292 self.sessions.insert(peer_node_id, session);
293
294 self.enforce_session_limit(now_ms);
296
297 KeyExchangeMessage::new(
299 self.our_node_id,
300 self.identity_key.public_key_bytes(),
301 false,
302 )
303 }
304
305 pub fn handle_key_exchange(
311 &mut self,
312 msg: &KeyExchangeMessage,
313 now_ms: u64,
314 ) -> Option<(KeyExchangeMessage, bool)> {
315 let peer_node_id = msg.sender_node_id;
316 let peer_public = x25519_dalek::PublicKey::from(msg.public_key);
317
318 let shared_secret = self.identity_key.exchange(&peer_public);
320 let session_key = shared_secret.derive_session_key(self.our_node_id, peer_node_id);
321
322 if let Some(session) = self.sessions.get_mut(&peer_node_id) {
324 if session.state == SessionState::AwaitingPeerKey {
325 session.complete_handshake(session_key, msg.public_key, now_ms);
327 return Some((
328 KeyExchangeMessage::new(
329 self.our_node_id,
330 self.identity_key.public_key_bytes(),
331 false,
332 ),
333 true, ));
335 }
336 return None;
338 }
339
340 if self.sessions.len() >= self.max_sessions {
342 self.cleanup_expired(now_ms);
344 if self.sessions.len() >= self.max_sessions {
345 log::warn!(
346 "Cannot accept E2EE session from {:?}: max sessions reached",
347 peer_node_id
348 );
349 return None;
350 }
351 }
352
353 let session = PeerSession::new_responder(peer_node_id, session_key, msg.public_key, now_ms);
354 self.sessions.insert(peer_node_id, session);
355
356 Some((
358 KeyExchangeMessage::new(
359 self.our_node_id,
360 self.identity_key.public_key_bytes(),
361 false,
362 ),
363 true, ))
365 }
366
367 pub fn has_session(&self, peer_node_id: NodeId) -> bool {
369 self.sessions
370 .get(&peer_node_id)
371 .is_some_and(|s| s.is_established())
372 }
373
374 pub fn session_state(&self, peer_node_id: NodeId) -> Option<SessionState> {
376 self.sessions.get(&peer_node_id).map(|s| s.state)
377 }
378
379 pub fn encrypt_for_peer(
383 &mut self,
384 peer_node_id: NodeId,
385 plaintext: &[u8],
386 now_ms: u64,
387 ) -> Result<PeerEncryptedMessage, EncryptionError> {
388 let session = self
389 .sessions
390 .get_mut(&peer_node_id)
391 .ok_or(EncryptionError::EncryptionFailed)?;
392
393 if !session.is_established() {
394 return Err(EncryptionError::EncryptionFailed);
395 }
396
397 let session_key_bytes = *session
399 .session_key()
400 .ok_or(EncryptionError::EncryptionFailed)?
401 .as_bytes();
402 let counter = session.next_outbound_counter();
403 session.touch(now_ms);
404
405 let cipher = ChaCha20Poly1305::new_from_slice(&session_key_bytes)
407 .map_err(|_| EncryptionError::EncryptionFailed)?;
408
409 let mut nonce_bytes = [0u8; 12];
411 OsRng.fill_bytes(&mut nonce_bytes);
412 let nonce = Nonce::from_slice(&nonce_bytes);
413
414 let ciphertext = cipher
416 .encrypt(nonce, plaintext)
417 .map_err(|_| EncryptionError::EncryptionFailed)?;
418
419 Ok(PeerEncryptedMessage {
420 recipient_node_id: peer_node_id,
421 sender_node_id: self.our_node_id,
422 counter,
423 nonce: nonce_bytes,
424 ciphertext,
425 })
426 }
427
428 pub fn decrypt_from_peer(
432 &mut self,
433 msg: &PeerEncryptedMessage,
434 now_ms: u64,
435 ) -> Result<Vec<u8>, EncryptionError> {
436 if msg.recipient_node_id != self.our_node_id {
438 return Err(EncryptionError::DecryptionFailed);
439 }
440
441 let session = self
442 .sessions
443 .get_mut(&msg.sender_node_id)
444 .ok_or(EncryptionError::DecryptionFailed)?;
445
446 if !session.is_established() {
447 return Err(EncryptionError::DecryptionFailed);
448 }
449
450 if !session.validate_inbound_counter(msg.counter) {
452 log::warn!(
453 "Replay attack detected from {:?}: counter {} < next expected {}",
454 msg.sender_node_id,
455 msg.counter,
456 session.inbound_counter
457 );
458 return Err(EncryptionError::DecryptionFailed);
459 }
460
461 let session_key_bytes = *session
463 .session_key()
464 .ok_or(EncryptionError::DecryptionFailed)?
465 .as_bytes();
466 session.touch(now_ms);
467
468 let cipher = ChaCha20Poly1305::new_from_slice(&session_key_bytes)
470 .map_err(|_| EncryptionError::DecryptionFailed)?;
471
472 let nonce = Nonce::from_slice(&msg.nonce);
473
474 cipher
476 .decrypt(nonce, msg.ciphertext.as_ref())
477 .map_err(|_| EncryptionError::DecryptionFailed)
478 }
479
480 pub fn close_session(&mut self, peer_node_id: NodeId) {
482 if let Some(session) = self.sessions.get_mut(&peer_node_id) {
483 session.close();
484 }
485 }
486
487 pub fn remove_session(&mut self, peer_node_id: NodeId) -> Option<PeerSession> {
489 self.sessions.remove(&peer_node_id)
490 }
491
492 pub fn cleanup_expired(&mut self, now_ms: u64) -> Vec<NodeId> {
494 let timeout = self.session_timeout_ms;
495 let expired: Vec<NodeId> = self
496 .sessions
497 .iter()
498 .filter(|(_, s)| s.is_expired(now_ms, timeout))
499 .map(|(id, _)| *id)
500 .collect();
501
502 for id in &expired {
503 self.sessions.remove(id);
504 }
505
506 expired
507 }
508
509 pub fn session_count(&self) -> usize {
511 self.sessions.len()
512 }
513
514 pub fn established_count(&self) -> usize {
516 self.sessions
517 .values()
518 .filter(|s| s.is_established())
519 .count()
520 }
521
522 fn enforce_session_limit(&mut self, now_ms: u64) {
524 self.cleanup_expired(now_ms);
526
527 while self.sessions.len() > self.max_sessions {
529 let oldest = self
530 .sessions
531 .iter()
532 .filter(|(_, s)| s.state == SessionState::Closed)
533 .min_by_key(|(_, s)| s.last_activity_ms)
534 .map(|(id, _)| *id);
535
536 if let Some(id) = oldest {
537 self.sessions.remove(&id);
538 } else {
539 let oldest = self
541 .sessions
542 .iter()
543 .filter(|(_, s)| !s.is_established())
544 .min_by_key(|(_, s)| s.last_activity_ms)
545 .map(|(id, _)| *id);
546
547 if let Some(id) = oldest {
548 self.sessions.remove(&id);
549 } else {
550 break; }
552 }
553 }
554 }
555}
556
557impl core::fmt::Debug for PeerSessionManager {
558 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
559 f.debug_struct("PeerSessionManager")
560 .field("our_node_id", &self.our_node_id)
561 .field("session_count", &self.sessions.len())
562 .field("max_sessions", &self.max_sessions)
563 .finish()
564 }
565}
566
#[cfg(test)]
mod tests {
    use super::*;

    const ALICE: u32 = 0x1111_1111;
    const BOB: u32 = 0x2222_2222;
    const CHARLIE: u32 = 0x3333_3333;
    const DAVE: u32 = 0x4444_4444;

    /// Shorthand for building a `NodeId` from a raw id.
    fn node(raw: u32) -> NodeId {
        NodeId::new(raw)
    }

    /// Runs the Alice-initiated handshake at t = 1000 ms and returns
    /// `(alice, bob)` with a session established on both sides.
    fn handshaken_pair() -> (PeerSessionManager, PeerSessionManager) {
        let mut alice = PeerSessionManager::new(node(ALICE));
        let mut bob = PeerSessionManager::new(node(BOB));

        let offer = alice.initiate_session(node(BOB), 1000);
        let (answer, _) = bob.handle_key_exchange(&offer, 1000).unwrap();
        alice.handle_key_exchange(&answer, 1000).unwrap();

        (alice, bob)
    }

    #[test]
    fn test_session_manager_creation() {
        let manager = PeerSessionManager::new(node(ALICE));
        assert_eq!(manager.our_node_id().as_u32(), ALICE);
        assert_eq!(manager.session_count(), 0);
    }

    #[test]
    fn test_initiate_session() {
        let mut manager = PeerSessionManager::new(node(ALICE));
        let offer = manager.initiate_session(node(BOB), 1000);

        assert_eq!(offer.sender_node_id.as_u32(), ALICE);
        assert_eq!(manager.session_count(), 1);
        assert_eq!(
            manager.session_state(node(BOB)),
            Some(SessionState::AwaitingPeerKey)
        );
    }

    #[test]
    fn test_full_key_exchange() {
        let mut alice = PeerSessionManager::new(node(ALICE));
        let mut bob = PeerSessionManager::new(node(BOB));

        let offer = alice.initiate_session(node(BOB), 1000);

        let (answer, bob_established) = bob.handle_key_exchange(&offer, 1000).unwrap();
        assert!(bob_established);
        assert!(bob.has_session(node(ALICE)));

        let (_, alice_established) = alice.handle_key_exchange(&answer, 1000).unwrap();
        assert!(alice_established);
        assert!(alice.has_session(node(BOB)));
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let (mut alice, mut bob) = handshaken_pair();

        let plaintext = b"Hello, Bob!";
        let sealed = alice.encrypt_for_peer(node(BOB), plaintext, 2000).unwrap();
        let opened = bob.decrypt_from_peer(&sealed, 2000).unwrap();

        assert_eq!(opened, plaintext);
    }

    #[test]
    fn test_bidirectional_communication() {
        let (mut alice, mut bob) = handshaken_pair();

        let to_bob = alice
            .encrypt_for_peer(node(BOB), b"From Alice", 2000)
            .unwrap();
        assert_eq!(bob.decrypt_from_peer(&to_bob, 2000).unwrap(), b"From Alice");

        let to_alice = bob
            .encrypt_for_peer(node(ALICE), b"From Bob", 2000)
            .unwrap();
        assert_eq!(alice.decrypt_from_peer(&to_alice, 2000).unwrap(), b"From Bob");
    }

    #[test]
    fn test_replay_protection() {
        let (mut alice, mut bob) = handshaken_pair();

        let sealed = alice
            .encrypt_for_peer(node(BOB), b"Message", 2000)
            .unwrap();

        assert!(bob.decrypt_from_peer(&sealed, 2000).is_ok());
        // Delivering the identical message a second time must be refused.
        assert!(bob.decrypt_from_peer(&sealed, 2000).is_err());
    }

    #[test]
    fn test_wrong_recipient_rejected() {
        let (mut alice, _bob) = handshaken_pair();
        let mut charlie = PeerSessionManager::new(node(CHARLIE));

        let sealed = alice
            .encrypt_for_peer(node(BOB), b"For Bob", 2000)
            .unwrap();

        assert!(charlie.decrypt_from_peer(&sealed, 2000).is_err());
    }

    #[test]
    fn test_session_expiry() {
        let mut manager = PeerSessionManager::new(node(ALICE)).with_session_timeout(10_000);
        manager.initiate_session(node(BOB), 1000);

        // 4 s idle: still within the 10 s timeout.
        assert!(manager.cleanup_expired(5000).is_empty());
        assert_eq!(manager.session_count(), 1);

        // 19 s idle: past the timeout.
        assert_eq!(manager.cleanup_expired(20000).len(), 1);
        assert_eq!(manager.session_count(), 0);
    }

    #[test]
    fn test_max_sessions_limit() {
        let mut manager = PeerSessionManager::new(node(ALICE)).with_max_sessions(2);

        for (peer, at) in [(BOB, 1000), (CHARLIE, 2000), (DAVE, 3000)] {
            manager.initiate_session(node(peer), at);
        }

        assert!(manager.session_count() <= 2);
    }

    #[test]
    fn test_peer_encrypted_message_encode_decode() {
        let ciphertext = vec![
            0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88,
            0x99, 0x00,
        ];
        let nonce = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
        let original = PeerEncryptedMessage {
            recipient_node_id: node(BOB),
            sender_node_id: node(ALICE),
            counter: 42,
            nonce,
            ciphertext: ciphertext.clone(),
        };

        let decoded = PeerEncryptedMessage::decode(&original.encode()).unwrap();

        assert_eq!(decoded.recipient_node_id.as_u32(), BOB);
        assert_eq!(decoded.sender_node_id.as_u32(), ALICE);
        assert_eq!(decoded.counter, 42);
        assert_eq!(decoded.nonce, nonce);
        assert_eq!(decoded.ciphertext, ciphertext);
    }

    #[test]
    fn test_close_session() {
        let mut manager = PeerSessionManager::new(node(ALICE));
        manager.initiate_session(node(BOB), 1000);

        manager.close_session(node(BOB));

        assert_eq!(manager.session_state(node(BOB)), Some(SessionState::Closed));
    }
}