use std::collections::HashMap;
use std::collections::VecDeque;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tokio_rustls::client::TlsStream;
use tracing::debug;
/// Tuning knobs for a `ConnectionPool`.
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Upper bound on idle connections kept for a single host key.
    pub max_idle_per_host: usize,
    /// A pooled connection whose last use is older than this is discarded.
    pub idle_timeout: Duration,
    /// A pooled connection whose wrapper was created longer ago than this is discarded.
    pub max_lifetime: Duration,
    /// NOTE(review): not read anywhere in this file; presumably used by the code
    /// that dials new connections — confirm against callers.
    pub connection_timeout: Duration,
}

impl Default for PoolConfig {
    /// 10 idle connections per host, 90 s idle timeout, 10 min maximum
    /// lifetime and a 30 s connection timeout.
    fn default() -> Self {
        let secs = Duration::from_secs;
        Self {
            max_idle_per_host: 10,
            idle_timeout: secs(90),
            max_lifetime: secs(10 * 60),
            connection_timeout: secs(30),
        }
    }
}
/// A TLS stream held by the pool, together with the timestamps used to decide
/// whether it may still be reused.
pub struct PooledConnection {
    /// The underlying TLS-over-TCP stream.
    pub stream: TlsStream<TcpStream>,
    /// Instant at which this wrapper was built by [`PooledConnection::new`].
    created_at: Instant,
    /// Instant of the last `new`/`mark_used` call.
    last_used: Instant,
}

impl PooledConnection {
    /// Wraps `stream`, stamping both the creation and last-use time with "now".
    pub fn new(stream: TlsStream<TcpStream>) -> Self {
        let stamp = Instant::now();
        Self {
            stream,
            last_used: stamp,
            created_at: stamp,
        }
    }

    /// True once strictly more than `max_lifetime` has passed since creation.
    pub fn is_expired(&self, max_lifetime: Duration) -> bool {
        max_lifetime < self.created_at.elapsed()
    }

    /// True once strictly more than `idle_timeout` has passed since the last use.
    pub fn is_idle_timeout(&self, idle_timeout: Duration) -> bool {
        idle_timeout < self.last_used.elapsed()
    }

    /// Resets the last-use timestamp to the current instant.
    pub fn mark_used(&mut self) {
        let now = Instant::now();
        self.last_used = now;
    }
}
/// Per-host pool of idle TLS connections.
///
/// The host map and the statistics each sit behind a `tokio::sync::Mutex`;
/// methods that need both locks acquire `pools` before `stats`.
pub struct ConnectionPool {
/// Idle connections keyed by the host string given to `put`/`get`; `put`
/// appends at the back and `get` takes from the front.
pools: Arc<Mutex<HashMap<String, VecDeque<PooledConnection>>>>,
/// Limits applied when storing and reusing connections.
config: PoolConfig,
/// Counters reported by `stats()`.
stats: Arc<Mutex<PoolStats>>,
}
/// Snapshot of pool counters, as returned by `ConnectionPool::stats`.
#[derive(Clone, Debug, Default)]
pub struct PoolStats {
    /// Number of successful reuses of a pooled connection.
    pub hits: u64,
    /// Number of misses reported by callers through `record_miss`.
    pub misses: u64,
    /// Idle connections across all hosts at the time of the snapshot.
    pub total_connections: usize,
    /// Connections discarded for exceeding the lifetime or idle limit.
    pub evictions: u64,
}
impl ConnectionPool {
pub fn new() -> Self {
Self::with_config(PoolConfig::default())
}
pub fn with_config(config: PoolConfig) -> Self {
Self {
pools: Arc::new(Mutex::new(HashMap::new())),
config,
stats: Arc::new(Mutex::new(PoolStats::default())),
}
}
pub async fn get(&self, host: &str) -> Option<TlsStream<TcpStream>> {
let mut pools = self.pools.lock().await;
let mut stats = self.stats.lock().await;
let pool = pools.get_mut(host)?;
while let Some(mut conn) = pool.pop_front() {
if conn.is_expired(self.config.max_lifetime) {
debug!(host = %host, "Connection expired, discarding");
stats.evictions += 1;
continue;
}
if conn.is_idle_timeout(self.config.idle_timeout) {
debug!(host = %host, "Connection idle timeout, discarding");
stats.evictions += 1;
continue;
}
conn.mark_used();
stats.hits += 1;
debug!(host = %host, "Reusing pooled connection (hit)");
return Some(conn.stream);
}
None
}
pub async fn put(&self, host: String, stream: TlsStream<TcpStream>) {
let mut pools = self.pools.lock().await;
let pool = pools.entry(host.clone()).or_insert_with(VecDeque::new);
if pool.len() >= self.config.max_idle_per_host {
debug!(host = %host, "Pool full, dropping connection");
return;
}
pool.push_back(PooledConnection::new(stream));
debug!(host = %host, pool_size = pool.len(), "Connection returned to pool");
}
pub async fn record_miss(&self) {
let mut stats = self.stats.lock().await;
stats.misses += 1;
}
pub async fn stats(&self) -> PoolStats {
let pools = self.pools.lock().await;
let mut stats = self.stats.lock().await;
stats.total_connections = pools.values().map(|p| p.len()).sum();
stats.clone()
}
pub async fn cleanup(&self) {
let mut pools = self.pools.lock().await;
let mut stats = self.stats.lock().await;
let mut total_removed = 0;
for (host, pool) in pools.iter_mut() {
let original_len = pool.len();
pool.retain(|conn| {
let valid = !conn.is_expired(self.config.max_lifetime)
&& !conn.is_idle_timeout(self.config.idle_timeout);
if !valid {
total_removed += 1;
stats.evictions += 1;
}
valid
});
let removed = original_len - pool.len();
if removed > 0 {
debug!(host = %host, removed, "Cleaned up expired connections");
}
}
if total_removed > 0 {
debug!(total_removed, "Pool cleanup complete");
}
}
pub async fn idle_count(&self, host: &str) -> usize {
let pools = self.pools.lock().await;
pools.get(host).map(|p| p.len()).unwrap_or(0)
}
pub async fn total_idle(&self) -> usize {
let pools = self.pools.lock().await;
pools.values().map(|p| p.len()).sum()
}
pub async fn clear(&self) {
let mut pools = self.pools.lock().await;
pools.clear();
debug!("Connection pool cleared");
}
}
impl Default for ConnectionPool {
fn default() -> Self {
Self::new()
}
}
pub fn start_cleanup_task(pool: Arc<ConnectionPool>) {
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(60));
loop {
interval.tick().await;
pool.cleanup().await;
}
});
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Stub TLS stream factory. It always panics, so the tests that depend on
    /// it are marked `#[ignore]` until a real mock exists.
    fn create_mock_tls_stream() -> TlsStream<TcpStream> {
        unimplemented!("Mock TLS stream creation for unit tests")
    }

    #[test]
    fn test_pool_config_defaults() {
        let PoolConfig {
            max_idle_per_host,
            idle_timeout,
            max_lifetime,
            connection_timeout,
        } = PoolConfig::default();
        assert_eq!(max_idle_per_host, 10);
        assert_eq!(idle_timeout, Duration::from_secs(90));
        assert_eq!(max_lifetime, Duration::from_secs(600));
        assert_eq!(connection_timeout, Duration::from_secs(30));
    }

    #[test]
    fn test_pool_config_custom() {
        let secs = Duration::from_secs;
        let config = PoolConfig {
            max_idle_per_host: 5,
            idle_timeout: secs(30),
            max_lifetime: secs(300),
            connection_timeout: secs(15),
        };
        assert_eq!(config.max_idle_per_host, 5);
        assert_eq!(config.idle_timeout, secs(30));
        assert_eq!(config.max_lifetime, secs(300));
        assert_eq!(config.connection_timeout, secs(15));
    }

    #[test]
    #[ignore]
    fn test_pooled_connection_expiration() {
        let limit = Duration::from_secs(600);
        let mut conn = PooledConnection::new(create_mock_tls_stream());
        assert!(!conn.is_expired(limit));
        conn.created_at = Instant::now() - Duration::from_secs(700);
        assert!(conn.is_expired(limit));
    }

    #[test]
    #[ignore]
    fn test_pooled_connection_idle_timeout() {
        let limit = Duration::from_secs(90);
        let mut conn = PooledConnection::new(create_mock_tls_stream());
        assert!(!conn.is_idle_timeout(limit));
        conn.last_used = Instant::now() - Duration::from_secs(100);
        assert!(conn.is_idle_timeout(limit));
    }

    #[tokio::test]
    async fn test_connection_pool_creation() {
        let pool = ConnectionPool::new();
        let PoolStats {
            hits,
            misses,
            total_connections,
            evictions,
        } = pool.stats().await;
        assert_eq!(hits, 0);
        assert_eq!(misses, 0);
        assert_eq!(total_connections, 0);
        assert_eq!(evictions, 0);
    }

    #[tokio::test]
    async fn test_pool_get_empty() {
        let pool = ConnectionPool::new();
        assert!(pool.get("example.com:443").await.is_none());
    }

    #[tokio::test]
    async fn test_pool_idle_count() {
        let pool = ConnectionPool::new();
        let per_host = pool.idle_count("example.com:443").await;
        let overall = pool.total_idle().await;
        assert_eq!(per_host, 0);
        assert_eq!(overall, 0);
    }
}