stochastic-routing-extended 1.0.2

SRX (Stochastic Routing eXtended) — a next-generation VPN protocol with stochastic routing, DPI evasion, post-quantum cryptography, and multi-transport channel splitting
Documentation
//! Server-side connection pool for managing multiple peer sessions.
//!
//! `ConnectionPool` tracks per-peer state, enforces maximum connections,
//! and cleans up idle sessions.

use std::collections::HashMap;
use std::net::SocketAddr;
use std::time::{Duration, Instant};

/// Per-peer connection state tracked by the pool.
///
/// Every field is plain `Copy` data, so the type is cheap to clone. Callers can
/// snapshot a peer without keeping a borrow into the pool.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PeerState {
    /// SRX session identifier.
    pub session_id: u64,
    /// Remote address.
    pub addr: SocketAddr,
    /// When the peer connected.
    pub connected_at: Instant,
    /// Last activity timestamp (updated on send/recv).
    pub last_active: Instant,
    /// Total bytes sent to this peer (saturates at `u64::MAX`).
    pub bytes_sent: u64,
    /// Total bytes received from this peer (saturates at `u64::MAX`).
    pub bytes_recv: u64,
}

impl PeerState {
    /// Create a new peer state with the given session ID and address.
    ///
    /// `connected_at` and `last_active` are both set to the same instant, and
    /// both byte counters start at zero.
    pub fn new(session_id: u64, addr: SocketAddr) -> Self {
        let now = Instant::now();
        Self {
            session_id,
            addr,
            connected_at: now,
            last_active: now,
            bytes_sent: 0,
            bytes_recv: 0,
        }
    }

    /// Update `last_active` to now.
    pub fn touch(&mut self) {
        self.last_active = Instant::now();
    }

    /// Duration since last activity.
    pub fn idle_time(&self) -> Duration {
        self.last_active.elapsed()
    }

    /// Total connection uptime (time since `connected_at`).
    pub fn uptime(&self) -> Duration {
        self.connected_at.elapsed()
    }

    /// Record `bytes` sent to this peer and mark it active.
    ///
    /// The counter saturates at `u64::MAX`. A plain `+=` would overflow instead:
    /// it panics in debug builds and silently wraps in release builds.
    pub fn record_sent(&mut self, bytes: u64) {
        self.bytes_sent = self.bytes_sent.saturating_add(bytes);
        self.touch();
    }

    /// Record `bytes` received from this peer and mark it active.
    ///
    /// The counter saturates at `u64::MAX` rather than overflowing.
    pub fn record_recv(&mut self, bytes: u64) {
        self.bytes_recv = self.bytes_recv.saturating_add(bytes);
        self.touch();
    }
}

/// Error type for pool operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PoolError {
    /// Connection pool is at maximum capacity.
    Full,
    /// A peer with this session ID already exists.
    DuplicateSession,
    /// No peer found with the given session ID.
    NotFound,
}

impl std::fmt::Display for PoolError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            PoolError::Full => write!(f, "connection pool is full"),
            PoolError::DuplicateSession => write!(f, "duplicate session ID"),
            PoolError::NotFound => write!(f, "peer not found"),
        }
    }
}

/// Lets `PoolError` be boxed as `Box<dyn Error>` and propagated with `?`.
/// No variant wraps an underlying cause, so the default `source()` (`None`)
/// is correct.
impl std::error::Error for PoolError {}

/// Server-side connection pool managing multiple peer sessions.
pub struct ConnectionPool {
    peers: HashMap<u64, PeerState>,
    max_peers: usize,
    idle_timeout: Duration,
}

impl ConnectionPool {
    /// Create a new pool with the given limits.
    pub fn new(max_peers: usize, idle_timeout: Duration) -> Self {
        Self {
            peers: HashMap::new(),
            max_peers,
            idle_timeout,
        }
    }

    /// Add a peer to the pool. Fails if full or duplicate session ID.
    pub fn add_peer(&mut self, state: PeerState) -> std::result::Result<(), PoolError> {
        if self.peers.len() >= self.max_peers {
            return Err(PoolError::Full);
        }
        if self.peers.contains_key(&state.session_id) {
            return Err(PoolError::DuplicateSession);
        }
        self.peers.insert(state.session_id, state);
        Ok(())
    }

    /// Remove a peer by session ID.
    pub fn remove_peer(&mut self, session_id: u64) -> Option<PeerState> {
        self.peers.remove(&session_id)
    }

    /// Get a reference to a peer's state.
    pub fn get_peer(&self, session_id: u64) -> Option<&PeerState> {
        self.peers.get(&session_id)
    }

    /// Get a mutable reference to a peer's state.
    pub fn get_peer_mut(&mut self, session_id: u64) -> Option<&mut PeerState> {
        self.peers.get_mut(&session_id)
    }

    /// Update last_active for a peer.
    pub fn touch(&mut self, session_id: u64) {
        if let Some(peer) = self.peers.get_mut(&session_id) {
            peer.touch();
        }
    }

    /// Remove all peers that have been idle longer than `idle_timeout`.
    /// Returns the removed peers.
    pub fn cleanup_idle(&mut self) -> Vec<PeerState> {
        let timeout = self.idle_timeout;
        let mut removed = Vec::new();
        self.peers.retain(|_, peer| {
            if peer.idle_time() > timeout {
                removed.push(PeerState {
                    session_id: peer.session_id,
                    addr: peer.addr,
                    connected_at: peer.connected_at,
                    last_active: peer.last_active,
                    bytes_sent: peer.bytes_sent,
                    bytes_recv: peer.bytes_recv,
                });
                false
            } else {
                true
            }
        });
        removed
    }

    /// Current number of connected peers.
    pub fn peer_count(&self) -> usize {
        self.peers.len()
    }

    /// Whether the pool is at capacity.
    pub fn is_full(&self) -> bool {
        self.peers.len() >= self.max_peers
    }

    /// Whether the pool has no peers.
    pub fn is_empty(&self) -> bool {
        self.peers.is_empty()
    }

    /// Iterator over all active peer states.
    pub fn active_peers(&self) -> impl Iterator<Item = &PeerState> {
        self.peers.values()
    }

    /// Find a peer by remote address.
    pub fn find_by_addr(&self, addr: SocketAddr) -> Option<&PeerState> {
        self.peers.values().find(|p| p.addr == addr)
    }

    /// Total bytes sent across all peers.
    pub fn total_bytes_sent(&self) -> u64 {
        self.peers.values().map(|p| p.bytes_sent).sum()
    }

    /// Total bytes received across all peers.
    pub fn total_bytes_recv(&self) -> u64 {
        self.peers.values().map(|p| p.bytes_recv).sum()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, SocketAddrV4};

    /// Loopback socket address on the given port.
    fn addr(port: u16) -> SocketAddr {
        SocketAddrV4::new(Ipv4Addr::LOCALHOST, port).into()
    }

    /// Pool with room for `cap` peers and a one-minute idle timeout.
    fn roomy_pool(cap: usize) -> ConnectionPool {
        ConnectionPool::new(cap, Duration::from_secs(60))
    }

    /// Insert a fresh peer, panicking if the pool rejects it.
    fn admit(pool: &mut ConnectionPool, id: u64, port: u16) {
        pool.add_peer(PeerState::new(id, addr(port)))
            .expect("peer should be admitted");
    }

    #[test]
    fn add_and_get_peer() {
        let mut pool = roomy_pool(10);
        admit(&mut pool, 1, 8001);

        let peer = pool.get_peer(1).expect("peer 1 should be present");
        assert_eq!(peer.session_id, 1);
        assert_eq!(peer.addr.port(), 8001);
    }

    #[test]
    fn full_pool_rejects() {
        let mut pool = roomy_pool(2);
        for (id, port) in [(1, 8001), (2, 8002)] {
            admit(&mut pool, id, port);
        }

        let overflow = pool.add_peer(PeerState::new(3, addr(8003)));
        assert_eq!(overflow, Err(PoolError::Full));
    }

    #[test]
    fn duplicate_session_rejects() {
        let mut pool = roomy_pool(10);
        admit(&mut pool, 1, 8001);

        let again = pool.add_peer(PeerState::new(1, addr(8002)));
        assert_eq!(again, Err(PoolError::DuplicateSession));
    }

    #[test]
    fn remove_peer() {
        let mut pool = roomy_pool(10);
        admit(&mut pool, 1, 8001);

        let gone = pool.remove_peer(1).expect("peer 1 should be removable");
        assert_eq!(gone.session_id, 1);
        assert!(pool.is_empty());
    }

    #[test]
    fn cleanup_idle() {
        // A zero timeout means any measurable idle time exceeds it.
        let mut pool = ConnectionPool::new(10, Duration::from_secs(0));
        admit(&mut pool, 1, 8001);
        admit(&mut pool, 2, 8002);

        std::thread::sleep(Duration::from_millis(10));
        let evicted = pool.cleanup_idle();
        assert_eq!(evicted.len(), 2);
        assert!(pool.is_empty());
    }

    #[test]
    fn find_by_addr() {
        let mut pool = roomy_pool(10);
        admit(&mut pool, 1, 8001);
        admit(&mut pool, 2, 8002);

        let hit = pool
            .find_by_addr(addr(8002))
            .expect("a peer on port 8002 should exist");
        assert_eq!(hit.session_id, 2);
    }

    #[test]
    fn bytes_tracking() {
        let mut pool = roomy_pool(10);
        admit(&mut pool, 1, 8001);

        {
            let peer = pool.get_peer_mut(1).expect("peer 1 should be present");
            peer.record_sent(100);
            peer.record_recv(200);
        }

        assert_eq!(pool.total_bytes_sent(), 100);
        assert_eq!(pool.total_bytes_recv(), 200);
    }

    #[test]
    fn peer_count_and_is_full() {
        let mut pool = roomy_pool(2);
        assert_eq!(pool.peer_count(), 0);
        assert!(!pool.is_full());

        admit(&mut pool, 1, 8001);
        assert_eq!(pool.peer_count(), 1);

        admit(&mut pool, 2, 8002);
        assert!(pool.is_full());
    }
}