use std::collections::HashMap;
use std::net::SocketAddr;
use std::time::{Duration, Instant};
/// Per-connection bookkeeping for a single peer held in a `ConnectionPool`.
///
/// Timestamps are monotonic [`Instant`]s. The byte counters saturate at
/// `u64::MAX` instead of overflowing.
#[derive(Debug, Clone)]
pub struct PeerState {
    /// Identifier the pool uses as this peer's key.
    pub session_id: u64,
    /// Socket address associated with this peer.
    pub addr: SocketAddr,
    /// When this state was created by [`PeerState::new`].
    pub connected_at: Instant,
    /// Last time the peer was touched or had traffic recorded.
    pub last_active: Instant,
    /// Total bytes recorded via [`PeerState::record_sent`] (saturating).
    pub bytes_sent: u64,
    /// Total bytes recorded via [`PeerState::record_recv`] (saturating).
    pub bytes_recv: u64,
}

impl PeerState {
    /// Creates state for a newly connected peer.
    ///
    /// `connected_at` and `last_active` are both set to the same "now" instant,
    /// and both byte counters start at zero.
    pub fn new(session_id: u64, addr: SocketAddr) -> Self {
        let now = Instant::now();
        Self {
            session_id,
            addr,
            connected_at: now,
            last_active: now,
            bytes_sent: 0,
            bytes_recv: 0,
        }
    }

    /// Marks the peer as active at the current instant.
    pub fn touch(&mut self) {
        self.last_active = Instant::now();
    }

    /// Time elapsed since the peer was last active.
    pub fn idle_time(&self) -> Duration {
        self.last_active.elapsed()
    }

    /// Time elapsed since the peer connected.
    pub fn uptime(&self) -> Duration {
        self.connected_at.elapsed()
    }

    /// Adds `bytes` to the sent counter and marks the peer active.
    ///
    /// The counter saturates at `u64::MAX` rather than panicking (debug) or
    /// wrapping (release) on overflow.
    pub fn record_sent(&mut self, bytes: u64) {
        self.bytes_sent = self.bytes_sent.saturating_add(bytes);
        self.touch();
    }

    /// Adds `bytes` to the received counter and marks the peer active.
    ///
    /// The counter saturates at `u64::MAX` rather than panicking (debug) or
    /// wrapping (release) on overflow.
    pub fn record_recv(&mut self, bytes: u64) {
        self.bytes_recv = self.bytes_recv.saturating_add(bytes);
        self.touch();
    }
}
/// Errors reported by `ConnectionPool` operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PoolError {
    /// The pool has reached its peer limit.
    Full,
    /// A peer with the same session ID is already registered.
    DuplicateSession,
    /// The requested peer is not present in the pool.
    NotFound,
}

impl std::fmt::Display for PoolError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            PoolError::Full => "connection pool is full",
            PoolError::DuplicateSession => "duplicate session ID",
            PoolError::NotFound => "peer not found",
        };
        f.write_str(message)
    }
}

// Lets `PoolError` be boxed as `dyn Error` and propagated with `?` into
// generic error types; it wraps no underlying cause, so `source()` stays `None`.
impl std::error::Error for PoolError {}
/// A bounded collection of [`PeerState`]s keyed by session ID, with
/// idle-timeout eviction via [`ConnectionPool::cleanup_idle`].
pub struct ConnectionPool {
    /// Registered peers, keyed by their `session_id`.
    peers: HashMap<u64, PeerState>,
    /// Maximum number of peers `add_peer` will accept.
    max_peers: usize,
    /// Peers idle for strictly longer than this are evicted by `cleanup_idle`.
    idle_timeout: Duration,
}

impl ConnectionPool {
    /// Creates an empty pool that holds at most `max_peers` peers and evicts
    /// peers idle for longer than `idle_timeout` when `cleanup_idle` runs.
    pub fn new(max_peers: usize, idle_timeout: Duration) -> Self {
        Self {
            peers: HashMap::new(),
            max_peers,
            idle_timeout,
        }
    }

    /// Registers `state` under its `session_id`.
    ///
    /// # Errors
    ///
    /// - [`PoolError::Full`] if the pool already holds `max_peers` peers. This
    ///   is checked first, so it takes precedence over a duplicate ID.
    /// - [`PoolError::DuplicateSession`] if a peer with the same session ID is
    ///   already registered.
    ///
    /// On error the pool is left unchanged and `state` is dropped.
    pub fn add_peer(&mut self, state: PeerState) -> std::result::Result<(), PoolError> {
        use std::collections::hash_map::Entry;

        if self.is_full() {
            return Err(PoolError::Full);
        }
        // One hash lookup serves both the duplicate check and the insert.
        match self.peers.entry(state.session_id) {
            Entry::Occupied(_) => Err(PoolError::DuplicateSession),
            Entry::Vacant(slot) => {
                slot.insert(state);
                Ok(())
            }
        }
    }

    /// Removes and returns the peer with `session_id`, or `None` if absent.
    pub fn remove_peer(&mut self, session_id: u64) -> Option<PeerState> {
        self.peers.remove(&session_id)
    }

    /// Returns the peer with `session_id`, if present.
    pub fn get_peer(&self, session_id: u64) -> Option<&PeerState> {
        self.peers.get(&session_id)
    }

    /// Returns a mutable reference to the peer with `session_id`, if present.
    pub fn get_peer_mut(&mut self, session_id: u64) -> Option<&mut PeerState> {
        self.peers.get_mut(&session_id)
    }

    /// Marks the peer with `session_id` as active now.
    ///
    /// Does nothing if no such peer is registered.
    pub fn touch(&mut self, session_id: u64) {
        if let Some(peer) = self.peers.get_mut(&session_id) {
            peer.touch();
        }
    }

    /// Evicts every peer whose idle time strictly exceeds the pool's
    /// `idle_timeout`, and returns the evicted states (order unspecified).
    pub fn cleanup_idle(&mut self) -> Vec<PeerState> {
        let timeout = self.idle_timeout;
        // Gather keys first so the original `PeerState` values can be moved
        // out intact instead of being rebuilt field by field.
        let expired: Vec<u64> = self
            .peers
            .iter()
            .filter(|(_, peer)| peer.idle_time() > timeout)
            .map(|(&id, _)| id)
            .collect();
        expired
            .iter()
            .filter_map(|id| self.peers.remove(id))
            .collect()
    }

    /// Number of registered peers.
    pub fn peer_count(&self) -> usize {
        self.peers.len()
    }

    /// Whether the pool has reached `max_peers` and will reject new peers.
    pub fn is_full(&self) -> bool {
        self.peers.len() >= self.max_peers
    }

    /// Whether the pool holds no peers.
    pub fn is_empty(&self) -> bool {
        self.peers.is_empty()
    }

    /// Iterates over every registered peer, in unspecified order.
    ///
    /// No idle filtering is applied; peers past the timeout are included until
    /// `cleanup_idle` removes them.
    pub fn active_peers(&self) -> impl Iterator<Item = &PeerState> {
        self.peers.values()
    }

    /// Returns a peer whose address equals `addr`, using a linear scan.
    ///
    /// If several peers share the address, which one is returned is unspecified.
    pub fn find_by_addr(&self, addr: SocketAddr) -> Option<&PeerState> {
        self.peers.values().find(|p| p.addr == addr)
    }

    /// Sum of `bytes_sent` across all peers, saturating at `u64::MAX`.
    pub fn total_bytes_sent(&self) -> u64 {
        self.peers
            .values()
            .fold(0u64, |acc, p| acc.saturating_add(p.bytes_sent))
    }

    /// Sum of `bytes_recv` across all peers, saturating at `u64::MAX`.
    pub fn total_bytes_recv(&self) -> u64 {
        self.peers
            .values()
            .fold(0u64, |acc, p| acc.saturating_add(p.bytes_recv))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, SocketAddrV4};

    /// Builds a loopback socket address on `port`.
    fn addr(port: u16) -> SocketAddr {
        SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port))
    }

    #[test]
    fn add_and_get_peer() {
        let mut pool = ConnectionPool::new(10, Duration::from_secs(60));
        pool.add_peer(PeerState::new(1, addr(8001))).unwrap();
        let peer = pool.get_peer(1).unwrap();
        assert_eq!(peer.session_id, 1);
        assert_eq!(peer.addr.port(), 8001);
    }

    #[test]
    fn full_pool_rejects() {
        let mut pool = ConnectionPool::new(2, Duration::from_secs(60));
        pool.add_peer(PeerState::new(1, addr(8001))).unwrap();
        pool.add_peer(PeerState::new(2, addr(8002))).unwrap();
        assert_eq!(
            pool.add_peer(PeerState::new(3, addr(8003))),
            Err(PoolError::Full)
        );
    }

    #[test]
    fn full_is_reported_before_duplicate() {
        let mut pool = ConnectionPool::new(1, Duration::from_secs(60));
        pool.add_peer(PeerState::new(1, addr(8001))).unwrap();
        assert_eq!(
            pool.add_peer(PeerState::new(1, addr(8002))),
            Err(PoolError::Full)
        );
    }

    #[test]
    fn duplicate_session_rejects() {
        let mut pool = ConnectionPool::new(10, Duration::from_secs(60));
        pool.add_peer(PeerState::new(1, addr(8001))).unwrap();
        assert_eq!(
            pool.add_peer(PeerState::new(1, addr(8002))),
            Err(PoolError::DuplicateSession)
        );
    }

    #[test]
    fn remove_peer() {
        let mut pool = ConnectionPool::new(10, Duration::from_secs(60));
        pool.add_peer(PeerState::new(1, addr(8001))).unwrap();
        let removed = pool.remove_peer(1).unwrap();
        assert_eq!(removed.session_id, 1);
        assert!(pool.is_empty());
    }

    #[test]
    fn missing_peer_operations() {
        let mut pool = ConnectionPool::new(10, Duration::from_secs(60));
        assert!(pool.get_peer(1).is_none());
        assert!(pool.get_peer_mut(1).is_none());
        assert!(pool.remove_peer(1).is_none());
        assert!(pool.find_by_addr(addr(8001)).is_none());
        // Touching an unknown session is a silent no-op.
        pool.touch(1);
        assert!(pool.is_empty());
    }

    #[test]
    fn cleanup_idle() {
        // A zero timeout makes every peer idle as soon as any time has passed.
        let mut pool = ConnectionPool::new(10, Duration::from_secs(0));
        pool.add_peer(PeerState::new(1, addr(8001))).unwrap();
        pool.add_peer(PeerState::new(2, addr(8002))).unwrap();
        std::thread::sleep(Duration::from_millis(10));
        let removed = pool.cleanup_idle();
        assert_eq!(removed.len(), 2);
        assert!(pool.is_empty());
    }

    #[test]
    fn cleanup_idle_keeps_recent_peers() {
        let mut pool = ConnectionPool::new(10, Duration::from_secs(60));
        pool.add_peer(PeerState::new(1, addr(8001))).unwrap();
        assert!(pool.cleanup_idle().is_empty());
        assert_eq!(pool.peer_count(), 1);
    }

    #[test]
    fn find_by_addr() {
        let mut pool = ConnectionPool::new(10, Duration::from_secs(60));
        pool.add_peer(PeerState::new(1, addr(8001))).unwrap();
        pool.add_peer(PeerState::new(2, addr(8002))).unwrap();
        let found = pool.find_by_addr(addr(8002)).unwrap();
        assert_eq!(found.session_id, 2);
    }

    #[test]
    fn bytes_tracking() {
        let mut pool = ConnectionPool::new(10, Duration::from_secs(60));
        pool.add_peer(PeerState::new(1, addr(8001))).unwrap();
        pool.get_peer_mut(1).unwrap().record_sent(100);
        pool.get_peer_mut(1).unwrap().record_recv(200);
        assert_eq!(pool.total_bytes_sent(), 100);
        assert_eq!(pool.total_bytes_recv(), 200);
    }

    #[test]
    fn peer_count_and_is_full() {
        let mut pool = ConnectionPool::new(2, Duration::from_secs(60));
        assert_eq!(pool.peer_count(), 0);
        assert!(!pool.is_full());
        pool.add_peer(PeerState::new(1, addr(8001))).unwrap();
        assert_eq!(pool.peer_count(), 1);
        pool.add_peer(PeerState::new(2, addr(8002))).unwrap();
        assert!(pool.is_full());
    }

    #[test]
    fn error_messages() {
        assert_eq!(PoolError::Full.to_string(), "connection pool is full");
        assert_eq!(
            PoolError::DuplicateSession.to_string(),
            "duplicate session ID"
        );
        assert_eq!(PoolError::NotFound.to_string(), "peer not found");
    }
}