use crate::auth::types::PeerSession;
use crate::{Error, Result};
use std::collections::{HashMap, HashSet};
/// In-memory registry of peer sessions.
///
/// Sessions are stored by nonce, with a secondary index from the hex encoding of a
/// session's peer identity key to the set of nonces carrying that key. The two maps
/// are maintained together by the methods on this type and must stay in sync.
#[derive(Default, Debug)]
pub struct SessionManager {
    /// Primary store: session nonce -> session.
    session_nonce_to_session: HashMap<String, PeerSession>,
    /// Secondary index: hex-encoded peer identity key -> nonces of sessions with that key.
    identity_key_to_nonces: HashMap<String, HashSet<String>>,
}
impl SessionManager {
pub fn new() -> Self {
Self::default()
}
pub fn add_session(&mut self, session: PeerSession) -> Result<()> {
let nonce = session
.session_nonce
.as_ref()
.ok_or_else(|| Error::AuthError("Session must have a nonce".into()))?;
if self.session_nonce_to_session.contains_key(nonce) {
return Err(Error::AuthError(format!(
"Session nonce already exists: {}",
nonce
)));
}
if let Some(ref identity_key) = session.peer_identity_key {
let key_hex = identity_key.to_hex();
self.identity_key_to_nonces
.entry(key_hex)
.or_default()
.insert(nonce.clone());
}
self.session_nonce_to_session.insert(nonce.clone(), session);
Ok(())
}
pub fn update_session(&mut self, session: PeerSession) {
if let Some(ref nonce) = session.session_nonce {
if let Some(old_session) = self.session_nonce_to_session.get(nonce) {
let old_key_hex = old_session.peer_identity_key.as_ref().map(|k| k.to_hex());
let new_key_hex = session.peer_identity_key.as_ref().map(|k| k.to_hex());
if old_key_hex != new_key_hex {
if let Some(old_hex) = old_key_hex {
if let Some(nonces) = self.identity_key_to_nonces.get_mut(&old_hex) {
nonces.remove(nonce);
if nonces.is_empty() {
self.identity_key_to_nonces.remove(&old_hex);
}
}
}
if let Some(new_hex) = new_key_hex {
self.identity_key_to_nonces
.entry(new_hex)
.or_default()
.insert(nonce.clone());
}
}
self.session_nonce_to_session.insert(nonce.clone(), session);
}
}
}
pub fn get_session(&self, identifier: &str) -> Option<&PeerSession> {
if let Some(session) = self.session_nonce_to_session.get(identifier) {
return Some(session);
}
if let Some(nonces) = self.identity_key_to_nonces.get(identifier) {
return self.select_best_session(nonces);
}
None
}
pub fn get_session_mut(&mut self, session_nonce: &str) -> Option<&mut PeerSession> {
self.session_nonce_to_session.get_mut(session_nonce)
}
pub fn remove_session(&mut self, session: &PeerSession) {
if let Some(ref nonce) = session.session_nonce {
self.session_nonce_to_session.remove(nonce);
if let Some(ref identity_key) = session.peer_identity_key {
let key_hex = identity_key.to_hex();
if let Some(nonces) = self.identity_key_to_nonces.get_mut(&key_hex) {
nonces.remove(nonce);
if nonces.is_empty() {
self.identity_key_to_nonces.remove(&key_hex);
}
}
}
}
}
pub fn remove_by_nonce(&mut self, session_nonce: &str) {
if let Some(session) = self.session_nonce_to_session.remove(session_nonce) {
if let Some(ref identity_key) = session.peer_identity_key {
let key_hex = identity_key.to_hex();
if let Some(nonces) = self.identity_key_to_nonces.get_mut(&key_hex) {
nonces.remove(session_nonce);
if nonces.is_empty() {
self.identity_key_to_nonces.remove(&key_hex);
}
}
}
}
}
pub fn has_session(&self, identifier: &str) -> bool {
self.get_session(identifier).is_some()
}
pub fn len(&self) -> usize {
self.session_nonce_to_session.len()
}
pub fn is_empty(&self) -> bool {
self.session_nonce_to_session.is_empty()
}
pub fn get_sessions_for_identity(&self, identity_key_hex: &str) -> Vec<&PeerSession> {
self.identity_key_to_nonces
.get(identity_key_hex)
.map(|nonces| {
nonces
.iter()
.filter_map(|n| self.session_nonce_to_session.get(n))
.collect()
})
.unwrap_or_default()
}
pub fn iter(&self) -> impl Iterator<Item = &PeerSession> {
self.session_nonce_to_session.values()
}
pub fn clear(&mut self) {
self.session_nonce_to_session.clear();
self.identity_key_to_nonces.clear();
}
fn select_best_session(&self, nonces: &HashSet<String>) -> Option<&PeerSession> {
let mut best: Option<&PeerSession> = None;
for nonce in nonces {
if let Some(session) = self.session_nonce_to_session.get(nonce) {
best = match best {
None => Some(session),
Some(current) => {
if session.is_authenticated && !current.is_authenticated {
Some(session)
} else if !session.is_authenticated && current.is_authenticated {
Some(current)
} else if session.last_update > current.last_update {
Some(session)
} else {
Some(current)
}
}
};
}
}
best
}
pub fn prune_stale_sessions(&mut self, max_age_ms: u64) -> usize {
let now = crate::auth::types::current_time_ms();
let cutoff = now.saturating_sub(max_age_ms);
let stale_nonces: Vec<String> = self
.session_nonce_to_session
.iter()
.filter(|(_, session)| session.last_update < cutoff)
.map(|(nonce, _)| nonce.clone())
.collect();
let count = stale_nonces.len();
for nonce in stale_nonces {
self.remove_by_nonce(&nonce);
}
count
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::types::current_time_ms;
    use crate::primitives::PrivateKey;

    /// Builds a session with the given nonce and optional identity key, stamped "now".
    fn make_session(
        nonce: &str,
        identity_key: Option<&crate::primitives::PublicKey>,
    ) -> PeerSession {
        PeerSession {
            session_nonce: Some(nonce.to_string()),
            peer_identity_key: identity_key.cloned(),
            last_update: current_time_ms(),
            ..Default::default()
        }
    }

    #[test]
    fn test_add_and_get_session() {
        let mut mgr = SessionManager::new();
        let key = PrivateKey::random().public_key();
        let session = make_session("nonce123", Some(&key));
        mgr.add_session(session).unwrap();
        assert!(mgr.get_session("nonce123").is_some());
        assert!(mgr.get_session(&key.to_hex()).is_some());
        assert!(mgr.get_session("nonexistent").is_none());
    }

    #[test]
    fn test_duplicate_nonce_rejected() {
        let mut mgr = SessionManager::new();
        let key = PrivateKey::random().public_key();
        let session1 = make_session("nonce123", Some(&key));
        let session2 = make_session("nonce123", Some(&key));
        mgr.add_session(session1).unwrap();
        assert!(mgr.add_session(session2).is_err());
    }

    #[test]
    fn test_session_without_nonce_rejected() {
        let mut mgr = SessionManager::new();
        let session = PeerSession::new();
        assert!(mgr.add_session(session).is_err());
    }

    #[test]
    fn test_prefers_authenticated() {
        let mut mgr = SessionManager::new();
        let key = PrivateKey::random().public_key();
        let mut s1 = make_session("nonce1", Some(&key));
        s1.is_authenticated = false;
        s1.last_update = current_time_ms() + 1000;
        mgr.add_session(s1).unwrap();
        let mut s2 = make_session("nonce2", Some(&key));
        s2.is_authenticated = true;
        s2.last_update = current_time_ms();
        mgr.add_session(s2).unwrap();
        let session = mgr.get_session(&key.to_hex()).unwrap();
        assert!(session.is_authenticated);
        assert_eq!(session.session_nonce.as_deref(), Some("nonce2"));
    }

    #[test]
    fn test_prefers_newer_when_same_auth_status() {
        let mut mgr = SessionManager::new();
        let key = PrivateKey::random().public_key();
        let mut s1 = make_session("nonce1", Some(&key));
        s1.is_authenticated = true;
        s1.last_update = current_time_ms();
        mgr.add_session(s1).unwrap();
        let mut s2 = make_session("nonce2", Some(&key));
        s2.is_authenticated = true;
        s2.last_update = current_time_ms() + 1000;
        mgr.add_session(s2).unwrap();
        let session = mgr.get_session(&key.to_hex()).unwrap();
        assert_eq!(session.session_nonce.as_deref(), Some("nonce2"));
    }

    #[test]
    fn test_update_session() {
        let mut mgr = SessionManager::new();
        let key = PrivateKey::random().public_key();
        let session = make_session("nonce123", Some(&key));
        mgr.add_session(session).unwrap();
        let mut updated = mgr.get_session("nonce123").unwrap().clone();
        updated.is_authenticated = true;
        mgr.update_session(updated);
        let session = mgr.get_session("nonce123").unwrap();
        assert!(session.is_authenticated);
    }

    #[test]
    fn test_update_session_moves_identity_index() {
        let mut mgr = SessionManager::new();
        let key_a = PrivateKey::random().public_key();
        let key_b = PrivateKey::random().public_key();
        mgr.add_session(make_session("nonce1", Some(&key_a))).unwrap();
        let mut updated = mgr.get_session("nonce1").unwrap().clone();
        updated.peer_identity_key = Some(key_b.clone());
        mgr.update_session(updated);
        assert!(mgr.get_session(&key_a.to_hex()).is_none());
        assert!(mgr.get_sessions_for_identity(&key_a.to_hex()).is_empty());
        assert_eq!(
            mgr.get_session(&key_b.to_hex())
                .and_then(|s| s.session_nonce.as_deref()),
            Some("nonce1")
        );
    }

    #[test]
    fn test_update_unknown_nonce_is_noop() {
        let mut mgr = SessionManager::new();
        mgr.update_session(make_session("ghost", None));
        assert!(mgr.is_empty());
        assert!(!mgr.has_session("ghost"));
    }

    #[test]
    fn test_remove_session() {
        let mut mgr = SessionManager::new();
        let key = PrivateKey::random().public_key();
        let session = make_session("nonce123", Some(&key));
        mgr.add_session(session.clone()).unwrap();
        assert!(mgr.has_session("nonce123"));
        assert!(mgr.has_session(&key.to_hex()));
        mgr.remove_session(&session);
        assert!(!mgr.has_session("nonce123"));
        assert!(!mgr.has_session(&key.to_hex()));
    }

    #[test]
    fn test_remove_by_nonce() {
        let mut mgr = SessionManager::new();
        let key = PrivateKey::random().public_key();
        let session = make_session("nonce123", Some(&key));
        mgr.add_session(session).unwrap();
        mgr.remove_by_nonce("nonce123");
        assert!(!mgr.has_session("nonce123"));
        assert!(!mgr.has_session(&key.to_hex()));
        assert!(mgr.get_sessions_for_identity(&key.to_hex()).is_empty());
    }

    #[test]
    fn test_multiple_sessions_per_identity() {
        let mut mgr = SessionManager::new();
        let key = PrivateKey::random().public_key();
        let s1 = make_session("nonce1", Some(&key));
        let s2 = make_session("nonce2", Some(&key));
        let s3 = make_session("nonce3", Some(&key));
        mgr.add_session(s1).unwrap();
        mgr.add_session(s2).unwrap();
        mgr.add_session(s3).unwrap();
        let sessions = mgr.get_sessions_for_identity(&key.to_hex());
        assert_eq!(sessions.len(), 3);
    }

    #[test]
    fn test_prune_stale_sessions() {
        let mut mgr = SessionManager::new();
        let key = PrivateKey::random().public_key();
        let mut stale = make_session("stale", Some(&key));
        stale.last_update = 0;
        mgr.add_session(stale).unwrap();
        mgr.add_session(make_session("fresh", Some(&key))).unwrap();
        assert_eq!(mgr.prune_stale_sessions(60_000), 1);
        assert!(!mgr.has_session("stale"));
        assert!(mgr.has_session("fresh"));
        assert_eq!(mgr.get_sessions_for_identity(&key.to_hex()).len(), 1);
    }

    #[test]
    fn test_len_and_is_empty() {
        let mut mgr = SessionManager::new();
        assert!(mgr.is_empty());
        assert_eq!(mgr.len(), 0);
        let session = make_session("nonce123", None);
        mgr.add_session(session).unwrap();
        assert!(!mgr.is_empty());
        assert_eq!(mgr.len(), 1);
    }

    #[test]
    fn test_clear() {
        let mut mgr = SessionManager::new();
        let key = PrivateKey::random().public_key();
        mgr.add_session(make_session("nonce1", Some(&key))).unwrap();
        mgr.add_session(make_session("nonce2", Some(&key))).unwrap();
        assert_eq!(mgr.len(), 2);
        mgr.clear();
        assert!(mgr.is_empty());
    }
}