use crate::connection::VCLConnection;
use crate::error::VCLError;
use crate::packet::VCLPacket;
use std::collections::HashMap;
use tracing::{debug, info, warn};
/// Identifier handed out by [`VCLPool::bind`] for each pooled connection.
pub type ConnectionId = u64;

/// A bounded collection of [`VCLConnection`]s addressed by [`ConnectionId`].
pub struct VCLPool {
    /// Capacity limit checked by `bind` before a new connection is added.
    max_connections: usize,
    /// ID that will be assigned to the next successfully bound connection.
    next_id: ConnectionId,
    /// Connections currently held by the pool, keyed by the ID `bind` returned.
    connections: HashMap<ConnectionId, VCLConnection>,
}
impl VCLPool {
    /// Creates an empty pool that holds at most `max_connections` connections.
    ///
    /// With `max_connections == 0` every call to [`bind`](Self::bind) is rejected.
    pub fn new(max_connections: usize) -> Self {
        info!(max_connections, "VCLPool created");
        VCLPool {
            connections: HashMap::new(),
            next_id: 0,
            max_connections,
        }
    }

    /// Binds a new [`VCLConnection`] to `addr` and adds it to the pool.
    ///
    /// Returns the [`ConnectionId`] under which the connection is stored. IDs are
    /// assigned from a monotonically increasing counter starting at 0, so an ID is
    /// not handed out again after its connection is closed.
    ///
    /// # Errors
    ///
    /// Returns `VCLError::InvalidPacket` if the pool is already full, or propagates
    /// any error from `VCLConnection::bind`. In the error case no ID is consumed.
    pub async fn bind(&mut self, addr: &str) -> Result<ConnectionId, VCLError> {
        if self.is_full() {
            warn!(
                current = self.connections.len(),
                max = self.max_connections,
                "Pool is at maximum capacity"
            );
            // NOTE(review): pool-level failures reuse the `InvalidPacket` variant;
            // a dedicated `VCLError` variant would describe them more accurately.
            return Err(VCLError::InvalidPacket(format!(
                "Pool is full: max {} connections",
                self.max_connections
            )));
        }
        let conn = VCLConnection::bind(addr).await?;
        let id = self.next_id;
        self.next_id += 1;
        self.connections.insert(id, conn);
        info!(id, addr, "Connection added to pool");
        Ok(id)
    }

    /// Connects pooled connection `id` to the peer at `addr`.
    ///
    /// # Errors
    ///
    /// Returns `VCLError::InvalidPacket` if `id` is not in the pool; otherwise
    /// returns the result of `VCLConnection::connect`.
    pub async fn connect(&mut self, id: ConnectionId, addr: &str) -> Result<(), VCLError> {
        let conn = self.get_mut(id)?;
        debug!(id, peer = %addr, "Pool: connecting");
        conn.connect(addr).await
    }

    /// Runs the accepting side of the handshake on pooled connection `id`.
    ///
    /// # Errors
    ///
    /// Returns `VCLError::InvalidPacket` if `id` is not in the pool; otherwise
    /// returns the result of `VCLConnection::accept_handshake`.
    pub async fn accept_handshake(&mut self, id: ConnectionId) -> Result<(), VCLError> {
        let conn = self.get_mut(id)?;
        debug!(id, "Pool: accepting handshake");
        conn.accept_handshake().await
    }

    /// Sends `data` over pooled connection `id`.
    ///
    /// # Errors
    ///
    /// Returns `VCLError::InvalidPacket` if `id` is not in the pool; otherwise
    /// returns the result of `VCLConnection::send`.
    pub async fn send(&mut self, id: ConnectionId, data: &[u8]) -> Result<(), VCLError> {
        let conn = self.get_mut(id)?;
        debug!(id, size = data.len(), "Pool: sending");
        conn.send(data).await
    }

    /// Waits for the next packet on pooled connection `id`.
    ///
    /// # Errors
    ///
    /// Returns `VCLError::InvalidPacket` if `id` is not in the pool; otherwise
    /// returns the result of `VCLConnection::recv`.
    pub async fn recv(&mut self, id: ConnectionId) -> Result<VCLPacket, VCLError> {
        let conn = self.get_mut(id)?;
        debug!(id, "Pool: waiting for packet");
        conn.recv().await
    }

    /// Sends a ping on pooled connection `id`.
    ///
    /// # Errors
    ///
    /// Returns `VCLError::InvalidPacket` if `id` is not in the pool; otherwise
    /// returns the result of `VCLConnection::ping`.
    pub async fn ping(&mut self, id: ConnectionId) -> Result<(), VCLError> {
        let conn = self.get_mut(id)?;
        debug!(id, "Pool: sending ping");
        conn.ping().await
    }

    /// Rotates the keys of pooled connection `id`.
    ///
    /// # Errors
    ///
    /// Returns `VCLError::InvalidPacket` if `id` is not in the pool; otherwise
    /// returns the result of `VCLConnection::rotate_keys`.
    pub async fn rotate_keys(&mut self, id: ConnectionId) -> Result<(), VCLError> {
        let conn = self.get_mut(id)?;
        debug!(id, "Pool: rotating keys");
        conn.rotate_keys().await
    }

    /// Closes connection `id` and removes it from the pool.
    ///
    /// The connection is removed even if its own `close()` fails. Without this, a
    /// connection that can never be closed cleanly would occupy a pool slot forever.
    /// The same policy is used by [`close_all`](Self::close_all).
    ///
    /// # Errors
    ///
    /// Returns `VCLError::InvalidPacket` if `id` is not in the pool, or propagates
    /// the error from `VCLConnection::close` (the entry is already removed then).
    pub fn close(&mut self, id: ConnectionId) -> Result<(), VCLError> {
        // A single `remove` both looks the entry up and frees the slot.
        match self.connections.remove(&id) {
            Some(mut conn) => {
                let outcome = conn.close();
                info!(id, "Connection removed from pool");
                outcome?;
                Ok(())
            }
            None => {
                warn!(id, "close() called with unknown connection ID");
                Err(Self::not_found(id))
            }
        }
    }

    /// Closes every pooled connection and empties the pool.
    ///
    /// Errors from individual `close()` calls are logged and otherwise ignored, so
    /// one failing connection does not prevent the rest from being closed.
    pub fn close_all(&mut self) {
        info!(count = self.connections.len(), "Closing all pool connections");
        // `drain` hands out owned connections and leaves the map empty afterwards.
        for (id, mut conn) in self.connections.drain() {
            if let Err(e) = conn.close() {
                warn!(id, error = %e, "Error closing connection during close_all");
            }
        }
    }

    /// Returns the number of connections currently in the pool.
    pub fn len(&self) -> usize {
        self.connections.len()
    }

    /// Returns `true` if the pool holds no connections.
    pub fn is_empty(&self) -> bool {
        self.connections.is_empty()
    }

    /// Returns `true` if another [`bind`](Self::bind) would be rejected for lack of capacity.
    pub fn is_full(&self) -> bool {
        self.connections.len() >= self.max_connections
    }

    /// Returns the IDs of all pooled connections, in unspecified (hash-map) order.
    pub fn connection_ids(&self) -> Vec<ConnectionId> {
        self.connections.keys().copied().collect()
    }

    /// Returns `true` if `id` names a connection currently in the pool.
    pub fn contains(&self, id: ConnectionId) -> bool {
        self.connections.contains_key(&id)
    }

    /// Borrows pooled connection `id`.
    ///
    /// # Errors
    ///
    /// Returns `VCLError::InvalidPacket` if `id` is not in the pool.
    pub fn get(&self, id: ConnectionId) -> Result<&VCLConnection, VCLError> {
        self.connections
            .get(&id)
            .ok_or_else(|| Self::not_found(id))
    }

    /// Mutably borrows pooled connection `id`, or returns the "not found" error.
    fn get_mut(&mut self, id: ConnectionId) -> Result<&mut VCLConnection, VCLError> {
        self.connections
            .get_mut(&id)
            .ok_or_else(|| Self::not_found(id))
    }

    /// Builds the error reported when `id` does not name a pooled connection.
    fn not_found(id: ConnectionId) -> VCLError {
        VCLError::InvalidPacket(format!("Connection ID {} not found in pool", id))
    }
}
impl Drop for VCLPool {
    /// Closes any connections still held when the pool goes out of scope.
    ///
    /// Does nothing when the pool is already empty, so `close_all` is only
    /// invoked if there is something left to close.
    fn drop(&mut self) {
        if self.connections.is_empty() {
            return;
        }
        self.close_all();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Local address with port 0, so each bind gets its own OS-assigned port.
    const ANY_LOCAL: &str = "127.0.0.1:0";

    #[test]
    fn test_pool_new() {
        let fresh = VCLPool::new(5);
        assert!(fresh.is_empty());
        assert!(!fresh.is_full());
        assert_eq!(fresh.len(), 0);
    }

    #[tokio::test]
    async fn test_pool_bind() {
        let mut pool = VCLPool::new(5);
        let bound = pool.bind(ANY_LOCAL).await.unwrap();
        assert!(!pool.is_empty());
        assert!(pool.contains(bound));
        assert_eq!(pool.len(), 1);
    }

    #[tokio::test]
    async fn test_pool_max_capacity() {
        let mut pool = VCLPool::new(2);
        for _ in 0..2 {
            pool.bind(ANY_LOCAL).await.unwrap();
        }
        assert!(pool.is_full());
        // A third bind must be refused once capacity is reached.
        assert!(pool.bind(ANY_LOCAL).await.is_err());
    }

    #[tokio::test]
    async fn test_pool_close() {
        let mut pool = VCLPool::new(5);
        let bound = pool.bind(ANY_LOCAL).await.unwrap();
        assert_eq!(pool.len(), 1);

        pool.close(bound).unwrap();

        assert!(!pool.contains(bound));
        assert_eq!(pool.len(), 0);
    }

    #[tokio::test]
    async fn test_pool_close_all() {
        let mut pool = VCLPool::new(5);
        for _ in 0..3 {
            pool.bind(ANY_LOCAL).await.unwrap();
        }
        assert_eq!(pool.len(), 3);

        pool.close_all();

        assert!(pool.is_empty());
    }

    #[tokio::test]
    async fn test_pool_unknown_id() {
        assert!(VCLPool::new(5).close(999).is_err());
    }

    #[tokio::test]
    async fn test_pool_connection_ids() {
        let mut pool = VCLPool::new(5);
        // Elements are evaluated left to right, so these are in bind order.
        let expected = vec![
            pool.bind(ANY_LOCAL).await.unwrap(),
            pool.bind(ANY_LOCAL).await.unwrap(),
        ];
        let mut actual = pool.connection_ids();
        actual.sort_unstable();
        assert_eq!(actual, expected);
    }
}