#[cfg(not(feature = "std"))]
use alloc::{collections::BTreeMap, vec::Vec};
#[cfg(feature = "std")]
use std::collections::HashMap;
use chacha20poly1305::{
aead::{Aead, KeyInit, OsRng},
ChaCha20Poly1305, Nonce,
};
use rand_core::RngCore;
use super::peer_key::{KeyExchangeMessage, PeerIdentityKey, PeerSessionKey};
use super::EncryptionError;
use crate::NodeId;
/// Default idle period, in milliseconds, after which a session is treated as
/// expired (30 minutes).
pub const DEFAULT_SESSION_TIMEOUT_MS: u64 = 30 * 60 * 1_000;
/// Default number of peer sessions a manager keeps before evicting.
pub const DEFAULT_MAX_SESSIONS: usize = 16;
/// Lifecycle state of a [`PeerSession`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SessionState {
    /// Our key-exchange message has been produced; the peer's key has not
    /// arrived yet, so no session key exists.
    AwaitingPeerKey,
    /// A session key has been derived; messages can be encrypted/decrypted.
    Established,
    /// Marked closed via `PeerSession::close`; no longer counts as established.
    Closed,
}
/// Per-peer end-to-end encryption session tracked by `PeerSessionManager`.
// NOTE(review): the derived `Debug` prints `session_key`; confirm that
// `PeerSessionKey`'s `Debug` impl redacts key bytes before logging sessions.
#[derive(Debug)]
pub struct PeerSession {
    /// Node id of the remote peer this session talks to.
    pub peer_node_id: NodeId,
    /// Current lifecycle state of the session.
    pub state: SessionState,
    /// Derived symmetric session key; `None` until the handshake completes.
    session_key: Option<PeerSessionKey>,
    /// The peer's public key bytes from its key-exchange message; `None`
    /// until the handshake completes.
    peer_public_key: Option<[u8; 32]>,
    /// Caller-supplied timestamp (ms) at which the session was created.
    pub created_at_ms: u64,
    /// Caller-supplied timestamp (ms) of the most recent activity; drives expiry.
    pub last_activity_ms: u64,
    /// Counter value that will be attached to the next outgoing message.
    pub outbound_counter: u64,
    /// Lowest inbound counter value that will still be accepted.
    pub inbound_counter: u64,
}
impl PeerSession {
    /// Creates a session for a handshake that we started.
    ///
    /// No key is known yet, so the state is [`SessionState::AwaitingPeerKey`]
    /// and both counters start at zero.
    pub fn new_initiator(peer_node_id: NodeId, now_ms: u64) -> Self {
        Self {
            peer_node_id,
            state: SessionState::AwaitingPeerKey,
            session_key: None,
            peer_public_key: None,
            created_at_ms: now_ms,
            last_activity_ms: now_ms,
            outbound_counter: 0,
            inbound_counter: 0,
        }
    }

    /// Creates an already-established session for a handshake the peer
    /// started, from the key we derived using its public key.
    pub fn new_responder(
        peer_node_id: NodeId,
        session_key: PeerSessionKey,
        peer_public_key: [u8; 32],
        now_ms: u64,
    ) -> Self {
        Self {
            peer_node_id,
            state: SessionState::Established,
            session_key: Some(session_key),
            peer_public_key: Some(peer_public_key),
            created_at_ms: now_ms,
            last_activity_ms: now_ms,
            outbound_counter: 0,
            inbound_counter: 0,
        }
    }

    /// Completes an initiator-side handshake.
    ///
    /// Stores the derived key and the peer's public key, marks the session
    /// established and records activity at `now_ms`. Counters are untouched.
    pub fn complete_handshake(
        &mut self,
        session_key: PeerSessionKey,
        peer_public_key: [u8; 32],
        now_ms: u64,
    ) {
        self.state = SessionState::Established;
        self.session_key = Some(session_key);
        self.peer_public_key = Some(peer_public_key);
        self.last_activity_ms = now_ms;
    }

    /// Returns `true` when the session is `Established` and holds a key.
    pub fn is_established(&self) -> bool {
        self.state == SessionState::Established && self.session_key.is_some()
    }

    /// Returns `true` when strictly more than `timeout_ms` have elapsed since
    /// the last activity.
    ///
    /// A `now_ms` earlier than the last activity counts as zero elapsed time.
    pub fn is_expired(&self, now_ms: u64, timeout_ms: u64) -> bool {
        now_ms.saturating_sub(self.last_activity_ms) > timeout_ms
    }

    /// Returns the counter for the next outgoing message and advances it.
    ///
    /// The counter wraps to zero after `u64::MAX`. From then on, the peer's
    /// replay check rejects every further message.
    pub fn next_outbound_counter(&mut self) -> u64 {
        let counter = self.outbound_counter;
        self.outbound_counter = self.outbound_counter.wrapping_add(1);
        counter
    }

    /// Accepts `counter` if it is at least the lowest acceptable value, then
    /// advances that value to `counter + 1`. Returns `false` for replays.
    ///
    /// `u64::MAX` is always rejected because it has no representable
    /// successor. With saturating arithmetic, accepting it would leave
    /// `inbound_counter == u64::MAX`, and the very same counter would then be
    /// accepted a second time.
    pub fn validate_inbound_counter(&mut self, counter: u64) -> bool {
        match counter.checked_add(1) {
            Some(next) if counter >= self.inbound_counter => {
                self.inbound_counter = next;
                true
            }
            _ => false,
        }
    }

    /// Returns the derived session key, if the handshake has completed.
    pub fn session_key(&self) -> Option<&PeerSessionKey> {
        self.session_key.as_ref()
    }

    /// Records activity at `now_ms`, postponing expiry.
    pub fn touch(&mut self, now_ms: u64) {
        self.last_activity_ms = now_ms;
    }

    /// Marks the session closed.
    ///
    /// The key is kept, but [`is_established`](Self::is_established) returns
    /// `false` from now on.
    pub fn close(&mut self) {
        self.state = SessionState::Closed;
    }
}
/// A ChaCha20-Poly1305 encrypted message between two nodes, as produced by
/// `PeerSessionManager::encrypt_for_peer`.
#[derive(Debug, Clone)]
pub struct PeerEncryptedMessage {
    /// Node the message is addressed to.
    pub recipient_node_id: NodeId,
    /// Node that produced the message; the receiver uses it to pick the session.
    pub sender_node_id: NodeId,
    /// Per-session sequence number used for replay rejection.
    pub counter: u64,
    /// 12-byte nonce used for this message.
    pub nonce: [u8; 12],
    /// AEAD output: encrypted payload followed by the 16-byte authentication tag.
    pub ciphertext: Vec<u8>,
}
impl PeerEncryptedMessage {
    /// Bytes added on top of the plaintext: recipient id (4) + sender id (4)
    /// + counter (8) + nonce (12) + authentication tag (16).
    pub const OVERHEAD: usize = 4 + 4 + 8 + 12 + 16;

    /// Serializes the message into its wire form.
    ///
    /// The layout is recipient id (u32 LE), sender id (u32 LE), counter
    /// (u64 LE), nonce, then `ciphertext` verbatim.
    pub fn encode(&self) -> Vec<u8> {
        // The tag portion of OVERHEAD already lives inside `ciphertext`.
        let header_len = Self::OVERHEAD - 16;
        let recipient = self.recipient_node_id.as_u32().to_le_bytes();
        let sender = self.sender_node_id.as_u32().to_le_bytes();
        let counter = self.counter.to_le_bytes();

        let mut out = Vec::with_capacity(header_len + self.ciphertext.len());
        for part in [
            &recipient[..],
            &sender[..],
            &counter[..],
            &self.nonce[..],
            &self.ciphertext[..],
        ] {
            out.extend_from_slice(part);
        }
        out
    }

    /// Parses the wire form produced by [`encode`](Self::encode).
    ///
    /// Returns `None` when `data` is shorter than [`Self::OVERHEAD`], meaning
    /// the header plus at least a full authentication tag.
    pub fn decode(data: &[u8]) -> Option<Self> {
        if data.len() < Self::OVERHEAD {
            return None;
        }
        let (header, body) = data.split_at(Self::OVERHEAD - 16);

        let mut recipient = [0u8; 4];
        let mut sender = [0u8; 4];
        let mut counter = [0u8; 8];
        let mut nonce = [0u8; 12];
        recipient.copy_from_slice(&header[0..4]);
        sender.copy_from_slice(&header[4..8]);
        counter.copy_from_slice(&header[8..16]);
        nonce.copy_from_slice(&header[16..28]);

        Some(Self {
            recipient_node_id: NodeId::new(u32::from_le_bytes(recipient)),
            sender_node_id: NodeId::new(u32::from_le_bytes(sender)),
            counter: u64::from_le_bytes(counter),
            nonce,
            ciphertext: body.to_vec(),
        })
    }
}
/// Owns this node's identity key and the set of per-peer encrypted sessions.
pub struct PeerSessionManager {
    /// Our own node id; used as the sender on outgoing messages.
    our_node_id: NodeId,
    /// Identity key pair used for every key exchange this manager performs.
    identity_key: PeerIdentityKey,
    /// Sessions keyed by peer node id (`HashMap` when built with `std`).
    #[cfg(feature = "std")]
    sessions: HashMap<NodeId, PeerSession>,
    /// Sessions keyed by peer node id (`BTreeMap` in `no_std` builds).
    #[cfg(not(feature = "std"))]
    sessions: BTreeMap<NodeId, PeerSession>,
    /// Session-count limit. It is not hard: established sessions are never
    /// evicted, so the map can exceed it after `initiate_session`.
    max_sessions: usize,
    /// Idle time in milliseconds after which a session is considered expired.
    session_timeout_ms: u64,
}
impl PeerSessionManager {
pub fn new(our_node_id: NodeId) -> Self {
Self {
our_node_id,
identity_key: PeerIdentityKey::generate(),
#[cfg(feature = "std")]
sessions: HashMap::new(),
#[cfg(not(feature = "std"))]
sessions: BTreeMap::new(),
max_sessions: DEFAULT_MAX_SESSIONS,
session_timeout_ms: DEFAULT_SESSION_TIMEOUT_MS,
}
}
pub fn with_identity_key(our_node_id: NodeId, identity_key: PeerIdentityKey) -> Self {
Self {
our_node_id,
identity_key,
#[cfg(feature = "std")]
sessions: HashMap::new(),
#[cfg(not(feature = "std"))]
sessions: BTreeMap::new(),
max_sessions: DEFAULT_MAX_SESSIONS,
session_timeout_ms: DEFAULT_SESSION_TIMEOUT_MS,
}
}
pub fn with_max_sessions(mut self, max: usize) -> Self {
self.max_sessions = max;
self
}
pub fn with_session_timeout(mut self, timeout_ms: u64) -> Self {
self.session_timeout_ms = timeout_ms;
self
}
pub fn our_public_key(&self) -> [u8; 32] {
self.identity_key.public_key_bytes()
}
pub fn our_node_id(&self) -> NodeId {
self.our_node_id
}
pub fn initiate_session(&mut self, peer_node_id: NodeId, now_ms: u64) -> KeyExchangeMessage {
let session = PeerSession::new_initiator(peer_node_id, now_ms);
self.sessions.insert(peer_node_id, session);
self.enforce_session_limit(now_ms);
KeyExchangeMessage::new(
self.our_node_id,
self.identity_key.public_key_bytes(),
false,
)
}
pub fn handle_key_exchange(
&mut self,
msg: &KeyExchangeMessage,
now_ms: u64,
) -> Option<(KeyExchangeMessage, bool)> {
let peer_node_id = msg.sender_node_id;
let peer_public = x25519_dalek::PublicKey::from(msg.public_key);
let shared_secret = self.identity_key.exchange(&peer_public);
let session_key = shared_secret.derive_session_key(self.our_node_id, peer_node_id);
if let Some(session) = self.sessions.get_mut(&peer_node_id) {
if session.state == SessionState::AwaitingPeerKey {
session.complete_handshake(session_key, msg.public_key, now_ms);
return Some((
KeyExchangeMessage::new(
self.our_node_id,
self.identity_key.public_key_bytes(),
false,
),
true, ));
}
return None;
}
if self.sessions.len() >= self.max_sessions {
self.cleanup_expired(now_ms);
if self.sessions.len() >= self.max_sessions {
log::warn!(
"Cannot accept E2EE session from {:?}: max sessions reached",
peer_node_id
);
return None;
}
}
let session = PeerSession::new_responder(peer_node_id, session_key, msg.public_key, now_ms);
self.sessions.insert(peer_node_id, session);
Some((
KeyExchangeMessage::new(
self.our_node_id,
self.identity_key.public_key_bytes(),
false,
),
true, ))
}
pub fn has_session(&self, peer_node_id: NodeId) -> bool {
self.sessions
.get(&peer_node_id)
.is_some_and(|s| s.is_established())
}
pub fn session_state(&self, peer_node_id: NodeId) -> Option<SessionState> {
self.sessions.get(&peer_node_id).map(|s| s.state)
}
pub fn encrypt_for_peer(
&mut self,
peer_node_id: NodeId,
plaintext: &[u8],
now_ms: u64,
) -> Result<PeerEncryptedMessage, EncryptionError> {
let session = self
.sessions
.get_mut(&peer_node_id)
.ok_or(EncryptionError::EncryptionFailed)?;
if !session.is_established() {
return Err(EncryptionError::EncryptionFailed);
}
let session_key_bytes = *session
.session_key()
.ok_or(EncryptionError::EncryptionFailed)?
.as_bytes();
let counter = session.next_outbound_counter();
session.touch(now_ms);
let cipher = ChaCha20Poly1305::new_from_slice(&session_key_bytes)
.map_err(|_| EncryptionError::EncryptionFailed)?;
let mut nonce_bytes = [0u8; 12];
OsRng.fill_bytes(&mut nonce_bytes);
let nonce = Nonce::from_slice(&nonce_bytes);
let ciphertext = cipher
.encrypt(nonce, plaintext)
.map_err(|_| EncryptionError::EncryptionFailed)?;
Ok(PeerEncryptedMessage {
recipient_node_id: peer_node_id,
sender_node_id: self.our_node_id,
counter,
nonce: nonce_bytes,
ciphertext,
})
}
pub fn decrypt_from_peer(
&mut self,
msg: &PeerEncryptedMessage,
now_ms: u64,
) -> Result<Vec<u8>, EncryptionError> {
if msg.recipient_node_id != self.our_node_id {
return Err(EncryptionError::DecryptionFailed);
}
let session = self
.sessions
.get_mut(&msg.sender_node_id)
.ok_or(EncryptionError::DecryptionFailed)?;
if !session.is_established() {
return Err(EncryptionError::DecryptionFailed);
}
if !session.validate_inbound_counter(msg.counter) {
log::warn!(
"Replay attack detected from {:?}: counter {} < next expected {}",
msg.sender_node_id,
msg.counter,
session.inbound_counter
);
return Err(EncryptionError::DecryptionFailed);
}
let session_key_bytes = *session
.session_key()
.ok_or(EncryptionError::DecryptionFailed)?
.as_bytes();
session.touch(now_ms);
let cipher = ChaCha20Poly1305::new_from_slice(&session_key_bytes)
.map_err(|_| EncryptionError::DecryptionFailed)?;
let nonce = Nonce::from_slice(&msg.nonce);
cipher
.decrypt(nonce, msg.ciphertext.as_ref())
.map_err(|_| EncryptionError::DecryptionFailed)
}
pub fn close_session(&mut self, peer_node_id: NodeId) {
if let Some(session) = self.sessions.get_mut(&peer_node_id) {
session.close();
}
}
pub fn remove_session(&mut self, peer_node_id: NodeId) -> Option<PeerSession> {
self.sessions.remove(&peer_node_id)
}
pub fn cleanup_expired(&mut self, now_ms: u64) -> Vec<NodeId> {
let timeout = self.session_timeout_ms;
let expired: Vec<NodeId> = self
.sessions
.iter()
.filter(|(_, s)| s.is_expired(now_ms, timeout))
.map(|(id, _)| *id)
.collect();
for id in &expired {
self.sessions.remove(id);
}
expired
}
pub fn session_count(&self) -> usize {
self.sessions.len()
}
pub fn established_count(&self) -> usize {
self.sessions
.values()
.filter(|s| s.is_established())
.count()
}
fn enforce_session_limit(&mut self, now_ms: u64) {
self.cleanup_expired(now_ms);
while self.sessions.len() > self.max_sessions {
let oldest = self
.sessions
.iter()
.filter(|(_, s)| s.state == SessionState::Closed)
.min_by_key(|(_, s)| s.last_activity_ms)
.map(|(id, _)| *id);
if let Some(id) = oldest {
self.sessions.remove(&id);
} else {
let oldest = self
.sessions
.iter()
.filter(|(_, s)| !s.is_established())
.min_by_key(|(_, s)| s.last_activity_ms)
.map(|(id, _)| *id);
if let Some(id) = oldest {
self.sessions.remove(&id);
} else {
break; }
}
}
}
}
impl core::fmt::Debug for PeerSessionManager {
    /// Writes a summary of the manager: node id, session count and limit.
    /// The identity key and the individual sessions are not included.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let session_count = self.sessions.len();
        let mut out = f.debug_struct("PeerSessionManager");
        out.field("our_node_id", &self.our_node_id);
        out.field("session_count", &session_count);
        out.field("max_sessions", &self.max_sessions);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    const ALICE: u32 = 0x11111111;
    const BOB: u32 = 0x22222222;
    const CHARLIE: u32 = 0x33333333;
    const DAVE: u32 = 0x44444444;

    /// Runs a complete Alice -> Bob key exchange at t=1000 and returns both
    /// managers with established sessions.
    fn established_pair() -> (PeerSessionManager, PeerSessionManager) {
        let mut alice = PeerSessionManager::new(NodeId::new(ALICE));
        let mut bob = PeerSessionManager::new(NodeId::new(BOB));
        let alice_msg = alice.initiate_session(NodeId::new(BOB), 1000);
        let (bob_response, _) = bob.handle_key_exchange(&alice_msg, 1000).unwrap();
        alice.handle_key_exchange(&bob_response, 1000).unwrap();
        (alice, bob)
    }

    #[test]
    fn test_session_manager_creation() {
        let manager = PeerSessionManager::new(NodeId::new(ALICE));
        assert_eq!(manager.our_node_id().as_u32(), ALICE);
        assert_eq!(manager.session_count(), 0);
    }

    #[test]
    fn test_initiate_session() {
        let mut manager = PeerSessionManager::new(NodeId::new(ALICE));
        let msg = manager.initiate_session(NodeId::new(BOB), 1000);
        assert_eq!(msg.sender_node_id.as_u32(), ALICE);
        assert_eq!(manager.session_count(), 1);
        assert_eq!(
            manager.session_state(NodeId::new(BOB)),
            Some(SessionState::AwaitingPeerKey)
        );
    }

    #[test]
    fn test_full_key_exchange() {
        let mut alice = PeerSessionManager::new(NodeId::new(ALICE));
        let mut bob = PeerSessionManager::new(NodeId::new(BOB));
        let alice_msg = alice.initiate_session(NodeId::new(BOB), 1000);
        let (bob_response, bob_established) = bob.handle_key_exchange(&alice_msg, 1000).unwrap();
        assert!(bob_established);
        assert!(bob.has_session(NodeId::new(ALICE)));
        let (_, alice_established) = alice.handle_key_exchange(&bob_response, 1000).unwrap();
        assert!(alice_established);
        assert!(alice.has_session(NodeId::new(BOB)));
    }

    #[test]
    fn test_key_exchange_ignored_once_established() {
        let mut alice = PeerSessionManager::new(NodeId::new(ALICE));
        let mut bob = PeerSessionManager::new(NodeId::new(BOB));
        let alice_msg = alice.initiate_session(NodeId::new(BOB), 1000);
        bob.handle_key_exchange(&alice_msg, 1000).unwrap();
        assert!(bob.handle_key_exchange(&alice_msg, 1500).is_none());
        assert!(bob.has_session(NodeId::new(ALICE)));
    }

    #[test]
    fn test_encrypt_requires_established_session() {
        let mut alice = PeerSessionManager::new(NodeId::new(ALICE));
        assert!(alice.encrypt_for_peer(NodeId::new(BOB), b"x", 1000).is_err());
        alice.initiate_session(NodeId::new(BOB), 1000);
        assert!(alice.encrypt_for_peer(NodeId::new(BOB), b"x", 1000).is_err());
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let (mut alice, mut bob) = established_pair();
        let plaintext = b"Hello, Bob!";
        let encrypted = alice
            .encrypt_for_peer(NodeId::new(BOB), plaintext, 2000)
            .unwrap();
        let decrypted = bob.decrypt_from_peer(&encrypted, 2000).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_outbound_counters_increase() {
        let (mut alice, _bob) = established_pair();
        let first = alice.encrypt_for_peer(NodeId::new(BOB), b"1", 2000).unwrap();
        let second = alice.encrypt_for_peer(NodeId::new(BOB), b"2", 2000).unwrap();
        assert_eq!(first.counter, 0);
        assert_eq!(second.counter, 1);
    }

    #[test]
    fn test_wire_roundtrip_decrypts() {
        let (mut alice, mut bob) = established_pair();
        let plaintext = b"Over the wire";
        let encrypted = alice
            .encrypt_for_peer(NodeId::new(BOB), plaintext, 2000)
            .unwrap();
        let wire = encrypted.encode();
        assert_eq!(wire.len(), plaintext.len() + PeerEncryptedMessage::OVERHEAD);
        let decoded = PeerEncryptedMessage::decode(&wire).unwrap();
        assert_eq!(bob.decrypt_from_peer(&decoded, 2000).unwrap(), plaintext);
    }

    #[test]
    fn test_tampered_ciphertext_rejected() {
        let (mut alice, mut bob) = established_pair();
        let mut encrypted = alice
            .encrypt_for_peer(NodeId::new(BOB), b"Integrity", 2000)
            .unwrap();
        encrypted.ciphertext[0] ^= 0x01;
        assert!(bob.decrypt_from_peer(&encrypted, 2000).is_err());
    }

    #[test]
    fn test_bidirectional_communication() {
        let (mut alice, mut bob) = established_pair();
        let msg1 = alice
            .encrypt_for_peer(NodeId::new(BOB), b"From Alice", 2000)
            .unwrap();
        let dec1 = bob.decrypt_from_peer(&msg1, 2000).unwrap();
        assert_eq!(dec1, b"From Alice");
        let msg2 = bob
            .encrypt_for_peer(NodeId::new(ALICE), b"From Bob", 2000)
            .unwrap();
        let dec2 = alice.decrypt_from_peer(&msg2, 2000).unwrap();
        assert_eq!(dec2, b"From Bob");
    }

    #[test]
    fn test_replay_protection() {
        let (mut alice, mut bob) = established_pair();
        let encrypted = alice
            .encrypt_for_peer(NodeId::new(BOB), b"Message", 2000)
            .unwrap();
        assert!(bob.decrypt_from_peer(&encrypted, 2000).is_ok());
        assert!(bob.decrypt_from_peer(&encrypted, 2000).is_err());
    }

    #[test]
    fn test_wrong_recipient_rejected() {
        let (mut alice, _bob) = established_pair();
        let mut charlie = PeerSessionManager::new(NodeId::new(CHARLIE));
        let encrypted = alice
            .encrypt_for_peer(NodeId::new(BOB), b"For Bob", 2000)
            .unwrap();
        assert!(charlie.decrypt_from_peer(&encrypted, 2000).is_err());
    }

    #[test]
    fn test_session_expiry() {
        let mut manager =
            PeerSessionManager::new(NodeId::new(ALICE)).with_session_timeout(10_000);
        manager.initiate_session(NodeId::new(BOB), 1000);
        let expired = manager.cleanup_expired(5000);
        assert!(expired.is_empty());
        assert_eq!(manager.session_count(), 1);
        let expired = manager.cleanup_expired(20000);
        assert_eq!(expired.len(), 1);
        assert_eq!(manager.session_count(), 0);
    }

    #[test]
    fn test_max_sessions_limit() {
        let mut manager = PeerSessionManager::new(NodeId::new(ALICE)).with_max_sessions(2);
        manager.initiate_session(NodeId::new(BOB), 1000);
        manager.initiate_session(NodeId::new(CHARLIE), 2000);
        manager.initiate_session(NodeId::new(DAVE), 3000);
        assert!(manager.session_count() <= 2);
    }

    #[test]
    fn test_peer_encrypted_message_encode_decode() {
        let ciphertext = vec![
            0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88,
            0x99, 0x00,
        ];
        let msg = PeerEncryptedMessage {
            recipient_node_id: NodeId::new(BOB),
            sender_node_id: NodeId::new(ALICE),
            counter: 42,
            nonce: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
            ciphertext: ciphertext.clone(),
        };
        let encoded = msg.encode();
        let decoded = PeerEncryptedMessage::decode(&encoded).unwrap();
        assert_eq!(decoded.recipient_node_id.as_u32(), BOB);
        assert_eq!(decoded.sender_node_id.as_u32(), ALICE);
        assert_eq!(decoded.counter, 42);
        assert_eq!(decoded.nonce, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]);
        assert_eq!(decoded.ciphertext, ciphertext);
    }

    #[test]
    fn test_decode_rejects_short_input() {
        assert!(PeerEncryptedMessage::decode(&[]).is_none());
        assert!(
            PeerEncryptedMessage::decode(&[0u8; PeerEncryptedMessage::OVERHEAD - 1]).is_none()
        );
        assert!(PeerEncryptedMessage::decode(&[0u8; PeerEncryptedMessage::OVERHEAD]).is_some());
    }

    #[test]
    fn test_close_session() {
        let mut manager = PeerSessionManager::new(NodeId::new(ALICE));
        manager.initiate_session(NodeId::new(BOB), 1000);
        manager.close_session(NodeId::new(BOB));
        assert_eq!(
            manager.session_state(NodeId::new(BOB)),
            Some(SessionState::Closed)
        );
    }
}