use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tokio::time::timeout;
use tracing::debug;
/// Upper bound on idle connections cached per `(host, port)`; further returns are dropped.
const MAX_CONNECTIONS_PER_HOST: usize = 4;
/// Deadline for establishing a brand-new TCP connection.
const CONNECTION_TIMEOUT: Duration = Duration::from_millis(10_000);
/// Pooled connections idle for at least this long are considered stale.
const IDLE_TIMEOUT: Duration = Duration::from_millis(60_000);
/// An idle socket parked in the pool, stamped with the moment it was handed back.
#[derive(Debug)]
struct PooledConnection {
    /// The parked socket itself.
    stream: TcpStream,
    /// When the socket was returned; compared against `IDLE_TIMEOUT` to detect staleness.
    last_used: Instant,
}

/// Every parked connection, grouped by destination.
type PoolMap = HashMap<HostKey, ConnectionVec>;
/// Parked connections for one destination; new ones are pushed onto the end.
type ConnectionVec = Vec<PooledConnection>;
/// Destination key: the `(host, port)` pair exactly as supplied by callers.
type HostKey = (String, u16);
/// A pool of idle TCP connections keyed by `(host, port)`.
///
/// Cloning is cheap: all clones share the same underlying map through an `Arc`,
/// and access to it is serialized by an async (`tokio::sync`) mutex.
#[derive(Clone, Debug)]
pub struct ConnectionPool {
    /// Idle connections per destination, shared between clones.
    pools: Arc<Mutex<PoolMap>>,
}
impl ConnectionPool {
    /// Creates an empty pool with no cached connections.
    pub fn new() -> Self {
        Self {
            pools: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Returns a connection to `host:port`, reusing an idle pooled one when possible.
    ///
    /// Pooled connections are tried most-recently-returned first. A candidate is
    /// discarded unless it has been idle for less than `IDLE_TIMEOUT` *and* the
    /// non-blocking probe [`Self::is_connection_alive`] says it can be reused.
    /// If no pooled connection qualifies, a fresh one is opened, bounded by
    /// `CONNECTION_TIMEOUT`.
    ///
    /// # Errors
    ///
    /// * `"Connection timeout"`: the new connect did not finish within `CONNECTION_TIMEOUT`.
    /// * `"Connection failed"`: the new connect returned an I/O error.
    pub async fn get_connection(&self, host: &str, port: u16) -> Result<TcpStream, &'static str> {
        let key = (host.to_string(), port);
        {
            // The liveness probe is synchronous, so this guard is never held across an
            // `.await`; other tasks are only blocked for the duration of the scan.
            let mut pools = self.pools.lock().await;
            if let Some(pool) = pools.get_mut(&key) {
                while let Some(conn) = pool.pop() {
                    let fresh = conn.last_used.elapsed() < IDLE_TIMEOUT;
                    if fresh && Self::is_connection_alive(&conn.stream) {
                        debug!("Reusing connection to {}:{}", host, port);
                        return Ok(conn.stream);
                    }
                    debug!("Dropping stale connection to {}:{}", host, port);
                }
            }
        }
        debug!("Creating new connection to {}:{}", host, port);
        timeout(CONNECTION_TIMEOUT, TcpStream::connect((host, port)))
            .await
            .map_err(|_| "Connection timeout")?
            .map_err(|_| "Connection failed")
    }

    /// Hands a connection back to the pool for later reuse.
    ///
    /// The connection is stamped with the current time. If the pool for
    /// `(host, port)` already holds `MAX_CONNECTIONS_PER_HOST` connections,
    /// `stream` is dropped (closed) instead.
    pub async fn return_connection(&self, host: String, port: u16, stream: TcpStream) {
        // `host` is still needed for logging after the key is moved into the map.
        let key = (host.clone(), port);
        let mut pools = self.pools.lock().await;
        let pool = pools.entry(key).or_default();
        if pool.len() < MAX_CONNECTIONS_PER_HOST {
            debug!("Returning connection to pool for {}:{}", host, port);
            pool.push(PooledConnection {
                stream,
                last_used: Instant::now(),
            });
        } else {
            debug!("Pool full for {}:{}, dropping connection", host, port);
        }
    }

    /// Non-blocking check of whether an idle pooled connection can be reused.
    ///
    /// Performs one `try_read` into a one-byte buffer:
    /// * `WouldBlock`: no pending data and no EOF observed, so the connection is treated as alive.
    /// * `Ok(0)`: the peer closed its side, so it is not alive.
    /// * `Ok(n > 0)`: unsolicited bytes arrived while idle. One byte has now been
    ///   consumed, so the stream is no longer in a clean state and is not reusable.
    /// * any other error: not alive.
    ///
    /// This must not await: it runs while the pool lock is held. Awaiting
    /// `readable()` here would block forever on a healthy idle socket. The probe
    /// is best-effort and reflects only readiness the tokio I/O driver has
    /// already observed.
    fn is_connection_alive(stream: &TcpStream) -> bool {
        let mut probe = [0u8; 1];
        match stream.try_read(&mut probe) {
            Err(e) => e.kind() == std::io::ErrorKind::WouldBlock,
            Ok(_) => false,
        }
    }

    /// Removes every pooled connection idle for at least `IDLE_TIMEOUT`.
    ///
    /// Entries left empty afterwards are removed from the map entirely.
    pub async fn cleanup_stale_connections(&self) {
        let mut pools = self.pools.lock().await;
        let now = Instant::now();
        for ((host, port), pool) in pools.iter_mut() {
            pool.retain(|conn| {
                let is_fresh = now.duration_since(conn.last_used) < IDLE_TIMEOUT;
                if !is_fresh {
                    debug!("Removing stale connection to {}:{}", host, port);
                }
                is_fresh
            });
        }
        pools.retain(|_, pool| !pool.is_empty());
    }

    /// Returns a snapshot of how many idle connections are pooled per `(host, port)`.
    ///
    /// Keys whose pool has been drained by `get_connection` may appear with a count of 0.
    pub async fn stats(&self) -> HashMap<HostKey, usize> {
        let pools = self.pools.lock().await;
        pools.iter().map(|(key, pool)| (key.clone(), pool.len())).collect()
    }
}
impl Default for ConnectionPool {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::net::TcpListener;

    /// Opens a real loopback connection so test assertions always execute.
    ///
    /// Returns the listener (kept so later connects to its port succeed), the
    /// client side, and the accepted server side.
    async fn loopback_pair() -> (TcpListener, TcpStream, TcpStream) {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind loopback listener");
        let addr = listener.local_addr().expect("listener local address");
        let client = TcpStream::connect(addr)
            .await
            .expect("connect to loopback listener");
        let (server, _) = listener
            .accept()
            .await
            .expect("accept loopback connection");
        (listener, client, server)
    }

    #[tokio::test]
    async fn test_connection_pool_basic() {
        let pool = ConnectionPool::new();
        let stats = pool.stats().await;
        assert!(stats.is_empty());
    }

    #[tokio::test]
    async fn test_connection_pool_return() {
        let pool = ConnectionPool::new();
        let (_listener, stream, _server) = loopback_pair().await;
        pool.return_connection("test.com".to_string(), 80, stream).await;
        let stats = pool.stats().await;
        let key = ("test.com".to_string(), 80);
        assert_eq!(stats.get(&key), Some(&1));
    }

    #[tokio::test]
    async fn test_idle_connection_is_reused() {
        let pool = ConnectionPool::new();
        let (listener, client, _server) = loopback_pair().await;
        let port = listener.local_addr().unwrap().port();
        let client_addr = client.local_addr().unwrap();
        pool.return_connection("127.0.0.1".to_string(), port, client).await;

        // Bounded so a blocking liveness check shows up as a failure rather than a hang.
        let reused = timeout(Duration::from_secs(1), pool.get_connection("127.0.0.1", port))
            .await
            .expect("get_connection blocked on an idle pooled connection")
            .expect("get_connection failed");
        assert_eq!(reused.local_addr().unwrap(), client_addr);
    }

    #[tokio::test]
    async fn test_closed_connection_is_not_reused() {
        let pool = ConnectionPool::new();
        let (listener, client, server) = loopback_pair().await;
        let port = listener.local_addr().unwrap().port();
        pool.return_connection("127.0.0.1".to_string(), port, client).await;

        // Close the peer and give the I/O driver a moment to observe the FIN.
        drop(server);
        tokio::time::sleep(Duration::from_millis(50)).await;

        let _fresh = timeout(Duration::from_secs(1), pool.get_connection("127.0.0.1", port))
            .await
            .expect("get_connection hung")
            .expect("new connection to the still-open listener");

        // A brand-new connection must have been opened instead of reusing the dead one.
        timeout(Duration::from_secs(1), listener.accept())
            .await
            .expect("no fresh connection was opened")
            .expect("accept fresh connection");
        let key = ("127.0.0.1".to_string(), port);
        assert_eq!(pool.stats().await.get(&key), Some(&0));
    }
}