use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use dashmap::DashMap;
use parking_lot::RwLock;
use tokio::sync::broadcast;
use tracing::{debug, info, warn};
use crate::crypto::NoiseSession;
use crate::types::{SessionId, TrafficStats};
use super::AuthorizedKey;
/// Tunables for session limits, lifetimes and the periodic cleanup sweep.
#[derive(Debug, Clone)]
pub struct SessionConfig {
    /// Intended cap on concurrent sessions per user.
    ///
    /// NOTE(review): `SessionManager::create_session_for_key` enforces
    /// `AuthorizedKey::max_connections` and does not read this field —
    /// confirm whether it should also apply.
    pub max_sessions_per_user: u32,
    /// A session whose last activity is older than this is considered expired.
    pub idle_timeout: Duration,
    /// A session older than this is considered expired regardless of activity.
    pub absolute_timeout: Duration,
    /// Period between cleanup sweeps.
    pub cleanup_interval: Duration,
}

impl Default for SessionConfig {
    /// 10 sessions per user, 5-minute idle timeout, 24-hour absolute timeout,
    /// and a cleanup sweep once per minute.
    fn default() -> Self {
        const MINUTE_SECS: u64 = 60;
        const HOUR_SECS: u64 = 60 * MINUTE_SECS;
        Self {
            max_sessions_per_user: 10,
            idle_timeout: Duration::from_secs(5 * MINUTE_SECS),
            absolute_timeout: Duration::from_secs(24 * HOUR_SECS),
            cleanup_interval: Duration::from_secs(MINUTE_SECS),
        }
    }
}
/// Lifecycle state of a [`ServerSession`].
///
/// `ServerSession::new` starts a session in `Handshaking`, and
/// `ServerSession::set_noise` moves it to `Active`. No code shown here sets
/// `Closing` or `Closed` — NOTE(review): presumably driven elsewhere; confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionState {
    /// Initial state; no Noise session installed yet.
    Handshaking,
    /// A Noise session has been installed (`ServerSession::is_active` is true only here).
    Active,
    Closing,
    Closed,
}
/// Server-side state for one client session.
///
/// Mutable fields are each wrapped in a `parking_lot::RwLock`, so a session can
/// be shared behind an `Arc` (as `SessionManager` does) and updated via `&self`.
pub struct ServerSession {
    /// Session identifier; also the key under which `SessionManager` stores it.
    pub id: SessionId,
    /// Authorized key supplied via `with_key`; `None` for sessions built with `new`.
    pub key: Option<AuthorizedKey>,
    /// Remote addresses seen for this session (`add_remote_addr` skips duplicates).
    pub remote_addrs: RwLock<Vec<SocketAddr>>,
    /// Current lifecycle state; starts as `Handshaking`.
    pub state: RwLock<SessionState>,
    /// Noise transport state; `None` until `set_noise` is called.
    pub noise: RwLock<Option<NoiseSession>>,
    /// Byte/packet counters updated by `record_sent` / `record_received`.
    pub stats: RwLock<TrafficStats>,
    /// Creation time; basis for the absolute timeout and `age`.
    pub created_at: Instant,
    /// Time of the last `touch`; basis for the idle timeout and `idle_time`.
    pub last_activity: RwLock<Instant>,
    /// NOTE(review): not read or written by the code shown here; contents unknown.
    pub uplinks: RwLock<Vec<String>>,
    /// Free-form string key/value metadata; not used by the code shown here.
    pub metadata: RwLock<HashMap<String, String>>,
}
impl ServerSession {
    /// Builds an unauthenticated session in the `Handshaking` state.
    ///
    /// `created_at` and `last_activity` both start at the same instant.
    pub fn new(id: SessionId) -> Self {
        let started = Instant::now();
        Self {
            id,
            key: None,
            remote_addrs: RwLock::new(Vec::new()),
            state: RwLock::new(SessionState::Handshaking),
            noise: RwLock::new(None),
            stats: RwLock::new(TrafficStats::default()),
            created_at: started,
            last_activity: RwLock::new(started),
            uplinks: RwLock::new(Vec::new()),
            metadata: RwLock::new(HashMap::new()),
        }
    }

    /// Builds a session like [`ServerSession::new`] with `key` attached.
    pub fn with_key(id: SessionId, key: AuthorizedKey) -> Self {
        let mut keyed = Self::new(id);
        keyed.key = Some(key);
        keyed
    }

    /// Marks the session as active right now (resets the idle clock).
    pub fn touch(&self) {
        let stamp = Instant::now();
        *self.last_activity.write() = stamp;
    }

    /// Returns `true` once the session has been idle longer than
    /// `idle_timeout`, or has existed longer than `absolute_timeout`.
    pub fn is_expired(&self, idle_timeout: Duration, absolute_timeout: Duration) -> bool {
        let now = Instant::now();
        let idle_for = now.duration_since(*self.last_activity.read());
        idle_for > idle_timeout || now.duration_since(self.created_at) > absolute_timeout
    }

    /// Records `addr` as a remote address of this session, ignoring repeats.
    pub fn add_remote_addr(&self, addr: SocketAddr) {
        let mut known = self.remote_addrs.write();
        if known.iter().all(|existing| *existing != addr) {
            known.push(addr);
        }
    }

    /// Installs the Noise transport and switches the session to `Active`.
    pub fn set_noise(&self, noise: NoiseSession) {
        let installed = Some(noise);
        *self.noise.write() = installed;
        *self.state.write() = SessionState::Active;
    }

    /// `true` only while the state is exactly `Active`.
    pub fn is_active(&self) -> bool {
        matches!(*self.state.read(), SessionState::Active)
    }

    /// Time elapsed since the session was created.
    pub fn age(&self) -> Duration {
        self.created_at.elapsed()
    }

    /// Time elapsed since the last `touch` (or creation).
    pub fn idle_time(&self) -> Duration {
        let now = Instant::now();
        now.duration_since(*self.last_activity.read())
    }

    /// Counts one outgoing packet of `bytes` bytes, then refreshes activity.
    pub fn record_sent(&self, bytes: u64) {
        {
            let mut counters = self.stats.write();
            counters.bytes_sent += bytes;
            counters.packets_sent += 1;
        }
        self.touch();
    }

    /// Counts one incoming packet of `bytes` bytes, then refreshes activity.
    pub fn record_received(&self, bytes: u64) {
        {
            let mut counters = self.stats.write();
            counters.bytes_received += bytes;
            counters.packets_received += 1;
        }
        self.touch();
    }
}
/// Notifications broadcast by [`SessionManager`] to receivers from `subscribe`.
#[derive(Debug, Clone)]
pub enum SessionEvent {
    /// A session was registered with the manager.
    Created(SessionId),
    /// A session was associated with an authorized key.
    Authenticated {
        session_id: SessionId,
        /// The key's `public_key` string.
        user_id: String,
    },
    /// A session was removed from the manager.
    Closed(SessionId),
    /// Emitted by `cleanup_expired` for a session past its idle or absolute
    /// timeout; normally accompanied by a `Closed` event for the same id.
    Expired(SessionId),
}
/// Registry of live [`ServerSession`]s with secondary indexes by user key and
/// by remote address, plus a broadcast channel of [`SessionEvent`]s.
///
/// Locking discipline: a DashMap guard must not be held while the same thread
/// re-locks that map, and two threads nesting guards on two maps in opposite
/// orders can deadlock. Methods here release each map guard before locking
/// another map, with one exception. `authenticate_session` holds a `sessions`
/// read guard while it updates `sessions_by_user` / `session_users`.
/// No method takes a `sessions` guard while holding a guard on those maps.
pub struct SessionManager {
    config: SessionConfig,
    /// Every live session, keyed by id.
    sessions: DashMap<SessionId, Arc<ServerSession>>,
    /// Session ids per user public key; the per-key limit is checked against
    /// the length of this list. Empty lists are removed.
    sessions_by_user: DashMap<String, Vec<SessionId>>,
    /// Session most recently associated with each remote address.
    sessions_by_addr: DashMap<SocketAddr, SessionId>,
    /// Reverse of `sessions_by_user`: the user list each session is filed in.
    /// `authenticate_session` cannot write `ServerSession::key` through a shared
    /// `Arc`, so closing relies on this map to find the right user entry.
    session_users: DashMap<SessionId, String>,
    /// Event sender; send errors (no live receivers) are deliberately ignored.
    event_tx: broadcast::Sender<SessionEvent>,
    /// Number of sessions ever created; never decremented.
    total_sessions: std::sync::atomic::AtomicU64,
}

impl SessionManager {
    /// Creates an empty manager whose event channel has capacity 256.
    pub fn new(config: SessionConfig) -> Self {
        // The initial receiver is dropped; listeners attach via `subscribe`.
        let (event_tx, _) = broadcast::channel(256);
        Self {
            config,
            sessions: DashMap::new(),
            sessions_by_user: DashMap::new(),
            sessions_by_addr: DashMap::new(),
            session_users: DashMap::new(),
            event_tx,
            total_sessions: std::sync::atomic::AtomicU64::new(0),
        }
    }

    /// Returns a receiver for events sent after this call.
    pub fn subscribe(&self) -> broadcast::Receiver<SessionEvent> {
        self.event_tx.subscribe()
    }

    /// Registers a new unauthenticated session and emits `Created`.
    pub fn create_session(&self) -> Arc<ServerSession> {
        let id = SessionId::generate();
        let session = Arc::new(ServerSession::new(id));
        self.sessions.insert(id, Arc::clone(&session));
        self.total_sessions
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        let _ = self.event_tx.send(SessionEvent::Created(id));
        debug!("Created session {}", id);
        session
    }

    /// Registers a new session bound to `key`, unless the key already has
    /// `key.max_connections` sessions filed under it.
    ///
    /// On success emits `Created` then `Authenticated`. Returns `None` (and
    /// logs a warning) when the limit has been reached.
    pub fn create_session_for_key(&self, key: &AuthorizedKey) -> Option<Arc<ServerSession>> {
        // `max_connections` is a `u32`; widening it to `usize` is lossless on
        // 32/64-bit targets, unlike narrowing the list length to `u32`.
        let limit = key.max_connections as usize;
        let id = {
            // Check and reserve under a single entry guard so two concurrent
            // callers for the same key cannot both pass the check.
            let mut user_sessions = self
                .sessions_by_user
                .entry(key.public_key.clone())
                .or_default();
            let current_sessions = user_sessions.len();
            if current_sessions >= limit {
                // Release the shard before `remove_if` locks it again.
                drop(user_sessions);
                // `or_default` may have just inserted an empty list; don't keep it.
                self.sessions_by_user
                    .remove_if(&key.public_key, |_, ids| ids.is_empty());
                warn!(
                    "Key {} has reached session limit ({}/{})",
                    key.short_id(),
                    current_sessions,
                    key.max_connections
                );
                return None;
            }
            let id = SessionId::generate();
            user_sessions.push(id);
            id
        };
        let session = Arc::new(ServerSession::with_key(id, key.clone()));
        self.sessions.insert(id, Arc::clone(&session));
        self.session_users.insert(id, key.public_key.clone());
        self.total_sessions
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        let _ = self.event_tx.send(SessionEvent::Created(id));
        let _ = self.event_tx.send(SessionEvent::Authenticated {
            session_id: id,
            user_id: key.public_key.clone(),
        });
        debug!("Created session {} for key {}", id, key.short_id());
        Some(session)
    }

    /// Looks up a live session by id.
    pub fn get_session(&self, id: SessionId) -> Option<Arc<ServerSession>> {
        self.sessions.get(&id).map(|r| Arc::clone(&r))
    }

    /// Looks up the session most recently associated with `addr`.
    pub fn get_session_by_addr(&self, addr: SocketAddr) -> Option<Arc<ServerSession>> {
        // Copy the id out so the address-index guard is released before
        // `sessions` is locked.
        let id = *self.sessions_by_addr.get(&addr)?;
        self.get_session(id)
    }

    /// Associates `addr` with an existing session (no-op for unknown ids).
    /// Any previous mapping of `addr` to another session is overwritten.
    pub fn associate_addr(&self, session_id: SessionId, addr: SocketAddr) {
        if let Some(session) = self.get_session(session_id) {
            session.add_remote_addr(addr);
            self.sessions_by_addr.insert(addr, session_id);
        }
    }

    /// Files an existing session under `key`'s user entry and emits
    /// `Authenticated`. Returns `false` if `session_id` is not registered.
    ///
    /// Re-authenticating with the same key does not add a duplicate entry.
    /// If the session was filed under a different key, it is moved. The
    /// session's own `key` field is left unchanged.
    /// NOTE(review): unlike `create_session_for_key`, this does not enforce
    /// `max_connections` — confirm that is intended.
    pub fn authenticate_session(&self, session_id: SessionId, key: AuthorizedKey) -> bool {
        // Hold a read guard on the session for the whole update. `close_session`
        // needs the shard's write lock to remove it, so it cannot interleave and
        // leave stale index entries behind.
        let session_guard = match self.sessions.get(&session_id) {
            Some(guard) => guard,
            None => return false,
        };
        let user_id = key.public_key.clone();
        {
            let mut ids = self.sessions_by_user.entry(user_id.clone()).or_default();
            if !ids.contains(&session_id) {
                ids.push(session_id);
            }
        }
        if let Some(previous) = self.session_users.insert(session_id, user_id.clone()) {
            if previous != user_id {
                self.detach_from_user(&previous, session_id);
            }
        }
        drop(session_guard);
        let _ = self.event_tx.send(SessionEvent::Authenticated {
            session_id,
            user_id,
        });
        true
    }

    /// Removes a session, drops it from the address and user indexes, and
    /// emits `Closed`. Unknown ids are ignored.
    pub fn close_session(&self, id: SessionId) {
        if let Some((_, session)) = self.sessions.remove(&id) {
            self.unindex_session(id, &session);
            let _ = self.event_tx.send(SessionEvent::Closed(id));
            debug!("Closed session {}", id);
        }
    }

    /// Number of sessions currently registered.
    pub fn active_count(&self) -> usize {
        self.sessions.len()
    }

    /// Number of sessions ever created by this manager.
    pub fn total_count(&self) -> u64 {
        self.total_sessions
            .load(std::sync::atomic::Ordering::Relaxed)
    }

    /// Live sessions filed under `user_id` (a key's `public_key`).
    pub fn get_user_sessions(&self, user_id: &str) -> Vec<Arc<ServerSession>> {
        // Snapshot the ids so the user-index guard is not held while `sessions`
        // is locked; holding both in this order could deadlock against
        // `authenticate_session`.
        let ids = self
            .sessions_by_user
            .get(user_id)
            .map(|ids| ids.value().clone())
            .unwrap_or_default();
        ids.into_iter()
            .filter_map(|id| self.get_session(id))
            .collect()
    }

    /// Removes every session past its idle or absolute timeout. For each one it
    /// emits `Expired` followed by `Closed`, and it returns how many were removed.
    pub fn cleanup_expired(&self) -> usize {
        let idle = self.config.idle_timeout;
        let absolute = self.config.absolute_timeout;
        // Collect first: removing while iterating would re-lock a shard the
        // iterator already holds.
        let candidates: Vec<SessionId> = self
            .sessions
            .iter()
            .filter(|entry| entry.value().is_expired(idle, absolute))
            .map(|entry| *entry.key())
            .collect();
        let mut removed = 0;
        for id in candidates {
            // Re-check under the shard lock. A session touched since the scan
            // survives, and one closed concurrently is not reported twice.
            if let Some((_, session)) = self
                .sessions
                .remove_if(&id, |_, candidate| candidate.is_expired(idle, absolute))
            {
                let _ = self.event_tx.send(SessionEvent::Expired(id));
                self.unindex_session(id, &session);
                let _ = self.event_tx.send(SessionEvent::Closed(id));
                debug!("Closed session {}", id);
                removed += 1;
            }
        }
        if removed > 0 {
            info!("Cleaned up {} expired sessions", removed);
        }
        removed
    }

    /// Spawns a task that calls `cleanup_expired` every `cleanup_interval`
    /// (the first tick fires immediately). The loop never exits on its own;
    /// abort the returned handle to stop it.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime (via `tokio::spawn`).
    pub fn start_cleanup_task(self: Arc<Self>) -> tokio::task::JoinHandle<()> {
        let interval = self.config.cleanup_interval;
        tokio::spawn(async move {
            let mut ticker = tokio::time::interval(interval);
            loop {
                ticker.tick().await;
                self.cleanup_expired();
            }
        })
    }

    /// Snapshot of all live sessions.
    pub fn all_sessions(&self) -> Vec<Arc<ServerSession>> {
        self.sessions.iter().map(|r| Arc::clone(&r)).collect()
    }

    /// Sums traffic counters over the currently live sessions (closed sessions
    /// no longer contribute).
    pub fn aggregate_stats(&self) -> TrafficStats {
        let mut total = TrafficStats::default();
        for session in self.sessions.iter() {
            let stats = session.stats.read();
            total.bytes_sent += stats.bytes_sent;
            total.bytes_received += stats.bytes_received;
            total.packets_sent += stats.packets_sent;
            total.packets_received += stats.packets_received;
            total.packets_dropped += stats.packets_dropped;
            total.packets_retransmitted += stats.packets_retransmitted;
        }
        total
    }

    /// Drops a session that has already left `sessions` from the address and
    /// user indexes.
    fn unindex_session(&self, id: SessionId, session: &ServerSession) {
        for addr in session.remote_addrs.read().iter() {
            // The address may since have been re-associated with another
            // session; only drop the mapping if it still points at this one.
            self.sessions_by_addr.remove_if(addr, |_, owner| *owner == id);
        }
        if let Some((_, user_id)) = self.session_users.remove(&id) {
            self.detach_from_user(&user_id, id);
        }
    }

    /// Removes `id` from `user_id`'s list and discards the list if it is now empty.
    fn detach_from_user(&self, user_id: &str, id: SessionId) {
        if let Some(mut ids) = self.sessions_by_user.get_mut(user_id) {
            ids.retain(|s| *s != id);
        }
        // The `get_mut` guard is released above, before `remove_if` re-locks the shard.
        self.sessions_by_user
            .remove_if(user_id, |_, ids| ids.is_empty());
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_session_creation() {
        let manager = SessionManager::new(SessionConfig::default());
        let fresh = manager.create_session();
        // A brand-new session is still handshaking, hence not active.
        assert!(!fresh.is_active());
        assert_eq!(manager.active_count(), 1);
    }

    #[test]
    fn test_session_by_addr() {
        let manager = SessionManager::new(SessionConfig::default());
        let session = manager.create_session();
        let peer = "127.0.0.1:12345".parse::<SocketAddr>().unwrap();
        manager.associate_addr(session.id, peer);
        let resolved = manager.get_session_by_addr(peer).map(|s| s.id);
        assert_eq!(resolved, Some(session.id));
    }

    #[test]
    fn test_session_expiry() {
        let idle = Duration::from_millis(10);
        let absolute = Duration::from_secs(3600);
        let manager = SessionManager::new(SessionConfig {
            idle_timeout: idle,
            absolute_timeout: absolute,
            ..Default::default()
        });
        let session = manager.create_session();
        std::thread::sleep(Duration::from_millis(20));
        assert!(session.is_expired(idle, absolute));
    }

    #[test]
    fn test_key_session_limit() {
        let manager = SessionManager::new(SessionConfig::default());
        let key = super::AuthorizedKey::new("test_key_base64").with_max_connections(2);
        let granted: Vec<bool> = (0..3)
            .map(|_| manager.create_session_for_key(&key).is_some())
            .collect();
        assert_eq!(granted, [true, true, false]);
    }
}