use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tokio_rustls::client::TlsStream;
/// Transport carried by a [`PooledConnection`].
#[derive(Debug)]
pub enum StreamWrapper {
    /// Unencrypted TCP stream.
    Plain(TcpStream),
    /// TCP stream wrapped in a client-side TLS session (`tokio_rustls`).
    Tls(TlsStream<TcpStream>),
    /// Socket-less placeholder that only exists in test builds, letting the
    /// pool logic be exercised without opening real connections.
    #[cfg(test)]
    Dummy,
}
/// Bucket key for the pool: an idle connection is only reused for a request
/// with exactly the same host and port.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct PoolKey {
    /// Host the connection targets.
    pub host: String,
    /// Port the connection targets.
    pub port: u16,
}

impl PoolKey {
    /// Builds a key from an owned host string and a port number.
    pub fn new(host: String, port: u16) -> Self {
        Self { port, host }
    }
}
/// A connection plus the timestamps the pool uses for idle expiry.
#[derive(Debug)]
pub struct PooledConnection {
    /// Underlying transport.
    pub stream: StreamWrapper,
    /// Instant at which this wrapper was constructed.
    pub created_at: Instant,
    /// Instant of the most recent `mark_used` (initially the creation time);
    /// `is_expired` measures idleness from here.
    pub last_used: Instant,
    /// Liveness flag, `true` on construction; not consulted by `is_expired`.
    pub is_active: bool,
}

impl PooledConnection {
    /// Wraps `stream`, stamping both timestamps with the current instant.
    pub fn new(stream: StreamWrapper) -> Self {
        let stamp = Instant::now();
        Self {
            is_active: true,
            last_used: stamp,
            created_at: stamp,
            stream,
        }
    }

    /// Test-only constructor with caller-chosen timestamps and no real socket.
    #[cfg(test)]
    pub fn mock(created_at: Instant, last_used: Instant) -> Self {
        Self {
            is_active: true,
            last_used,
            created_at,
            stream: StreamWrapper::Dummy,
        }
    }

    /// Resets the idle clock by setting `last_used` to now.
    pub fn mark_used(&mut self) {
        let now = Instant::now();
        self.last_used = now;
    }

    /// Returns `true` once strictly more than `timeout` has elapsed since
    /// `last_used`; an idle time exactly equal to `timeout` is still fresh.
    pub fn is_expired(&self, timeout: Duration) -> bool {
        let idle_for = self.last_used.elapsed();
        timeout < idle_for
    }
}
/// Tuning knobs for a [`ConnectionPool`].
#[derive(Clone, Debug)]
pub struct PoolConfig {
    /// Cap on idle connections kept for a single [`PoolKey`]; `release` does
    /// not pool a connection once its key's bucket has reached this size.
    pub max_connections_per_host: usize,
    /// Idle connections unused for longer than this are treated as expired.
    pub idle_timeout: Duration,
    /// Intended cap on idle connections across all hosts combined.
    pub max_idle_connections: usize,
}

impl Default for PoolConfig {
    /// 100 idle connections per host, a 90-second idle timeout and 1000 idle
    /// connections overall.
    fn default() -> Self {
        let idle_timeout = Duration::from_secs(90);
        Self {
            idle_timeout,
            max_idle_connections: 1_000,
            max_connections_per_host: 100,
        }
    }
}
#[derive(Debug)]
pub struct ConnectionPool {
config: PoolConfig,
connections: Arc<Mutex<HashMap<PoolKey, Vec<PooledConnection>>>>,
stats: Arc<Mutex<PoolStats>>,
}
#[derive(Debug, Default)]
pub struct PoolStats {
pub total_acquired: u64,
pub total_released: u64,
pub total_hits: u64,
pub total_misses: u64,
pub current_connections: usize,
}
impl ConnectionPool {
pub fn new(config: PoolConfig) -> Self {
ConnectionPool {
config,
connections: Arc::new(Mutex::new(HashMap::new())),
stats: Arc::new(Mutex::new(PoolStats::default())),
}
}
pub async fn acquire(&self, key: PoolKey) -> Option<PooledConnection> {
let mut connections = self.connections.lock().await;
let mut stats = self.stats.lock().await;
self.cleanup_expired(&mut connections, &key);
if let Some(pool) = connections.get_mut(&key) {
if let Some(mut conn) = pool.pop() {
conn.mark_used();
stats.total_acquired += 1;
stats.total_hits += 1;
stats.current_connections -= 1;
return Some(conn);
}
}
stats.total_misses += 1;
None
}
pub async fn release(&self, key: PoolKey, connection: PooledConnection) {
let mut connections = self.connections.lock().await;
let mut stats = self.stats.lock().await;
let pool = connections.entry(key).or_insert_with(Vec::new);
if pool.len() < self.config.max_connections_per_host {
pool.push(connection);
stats.total_released += 1;
stats.current_connections += 1;
}
}
fn cleanup_expired(&self, connections: &mut HashMap<PoolKey, Vec<PooledConnection>>, key: &PoolKey) {
if let Some(pool) = connections.get_mut(key) {
pool.retain(|conn| !conn.is_expired(self.config.idle_timeout));
}
}
pub async fn cleanup_all(&self) {
let mut connections = self.connections.lock().await;
for (_, pool) in connections.iter_mut() {
pool.retain(|conn| !conn.is_expired(self.config.idle_timeout));
}
}
pub async fn stats(&self) -> PoolStats {
let stats = self.stats.lock().await;
PoolStats {
total_acquired: stats.total_acquired,
total_released: stats.total_released,
total_hits: stats.total_hits,
total_misses: stats.total_misses,
current_connections: stats.current_connections,
}
}
pub async fn hit_rate(&self) -> f64 {
let stats = self.stats.lock().await;
let total = stats.total_hits + stats.total_misses;
if total == 0 {
0.0
} else {
stats.total_hits as f64 / total as f64
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Mock connection whose `created_at` and `last_used` are both "now".
    fn fresh_conn() -> PooledConnection {
        let now = Instant::now();
        PooledConnection::mock(now, now)
    }

    fn example_key() -> PoolKey {
        PoolKey::new("example.com".to_string(), 443)
    }

    #[test]
    fn test_pool_key() {
        let key1 = example_key();
        let key2 = PoolKey::new("example.com".to_string(), 443);
        let key3 = PoolKey::new("example.com".to_string(), 80);
        assert_eq!(key1, key2);
        assert_ne!(key1, key3);
    }

    #[test]
    fn test_pooled_connection_expiry() {
        let mut conn = fresh_conn();
        assert!(!conn.is_expired(Duration::from_secs(90)));
        // `Instant - Duration` panics when the platform clock cannot represent
        // the result, so let real time pass and use a timeout shorter than
        // the wait instead of back-dating the timestamp.
        std::thread::sleep(Duration::from_millis(5));
        assert!(conn.is_expired(Duration::from_millis(1)));
        conn.mark_used();
        assert!(!conn.is_expired(Duration::from_secs(90)));
    }

    #[test]
    fn test_pool_config_default() {
        let config = PoolConfig::default();
        assert_eq!(config.max_connections_per_host, 100);
        assert_eq!(config.idle_timeout.as_secs(), 90);
        assert_eq!(config.max_idle_connections, 1000);
    }

    #[test]
    fn test_pooled_connection_mark_used() {
        let mut conn = fresh_conn();
        let original_last_used = conn.last_used;
        std::thread::sleep(Duration::from_millis(10));
        conn.mark_used();
        assert!(conn.last_used > original_last_used);
        // Touching a connection must not rewrite its creation time.
        assert_eq!(conn.created_at, original_last_used);
    }

    #[test]
    fn test_pooled_connection_is_active() {
        let mut conn = fresh_conn();
        assert!(conn.is_active);
        conn.is_active = false;
        assert!(!conn.is_active);
    }

    #[tokio::test]
    async fn test_connection_pool_acquire_release() {
        let pool = ConnectionPool::new(PoolConfig::default());
        let key = example_key();
        pool.release(key.clone(), fresh_conn()).await;
        let acquired = pool.acquire(key.clone()).await;
        assert!(acquired.is_some());
        let stats = pool.stats().await;
        assert_eq!(stats.total_released, 1);
        assert_eq!(stats.total_acquired, 1);
        assert_eq!(stats.current_connections, 0);
    }

    #[tokio::test]
    async fn test_connection_pool_acquire_empty() {
        let pool = ConnectionPool::new(PoolConfig::default());
        let acquired = pool.acquire(example_key()).await;
        assert!(acquired.is_none());
        assert_eq!(pool.stats().await.total_misses, 1);
    }

    #[tokio::test]
    async fn test_connection_pool_cleanup_expired() {
        let config = PoolConfig {
            idle_timeout: Duration::from_millis(1),
            ..PoolConfig::default()
        };
        let pool = ConnectionPool::new(config);
        let key = example_key();
        pool.release(key.clone(), fresh_conn()).await;
        // Wait well past the 1 ms idle timeout so the pooled connection expires.
        tokio::time::sleep(Duration::from_millis(10)).await;
        pool.cleanup_all().await;
        let acquired = pool.acquire(key).await;
        assert!(acquired.is_none(), "Expired connection should be cleaned up");
    }

    #[tokio::test]
    async fn test_connection_pool_stats() {
        let pool = ConnectionPool::new(PoolConfig::default());
        let key = example_key();
        let stats1 = pool.stats().await;
        assert_eq!(stats1.total_acquired, 0);
        assert_eq!(stats1.total_released, 0);
        pool.release(key.clone(), fresh_conn()).await;
        let _ = pool.acquire(key).await;
        let stats2 = pool.stats().await;
        assert_eq!(stats2.total_released, 1);
        assert_eq!(stats2.total_acquired, 1);
        assert_eq!(stats2.total_hits, 1);
        assert_eq!(stats2.total_misses, 0);
    }

    #[tokio::test]
    async fn test_connection_pool_hit_rate() {
        let pool = ConnectionPool::new(PoolConfig::default());
        let key = example_key();
        assert!(pool.hit_rate().await.abs() < f64::EPSILON);
        pool.release(key.clone(), fresh_conn()).await;
        assert!(pool.acquire(key.clone()).await.is_some(), "first acquire should hit");
        assert!(pool.acquire(key.clone()).await.is_none(), "second acquire should miss");
        let hit_rate = pool.hit_rate().await;
        assert!(
            (hit_rate - 0.5).abs() < f64::EPSILON,
            "expected 1 hit out of 2 acquires, got {}",
            hit_rate
        );
    }

    #[test]
    fn test_pool_config_custom() {
        let config = PoolConfig {
            max_connections_per_host: 50,
            idle_timeout: Duration::from_secs(60),
            max_idle_connections: 500,
        };
        assert_eq!(config.max_connections_per_host, 50);
        assert_eq!(config.idle_timeout.as_secs(), 60);
        assert_eq!(config.max_idle_connections, 500);
    }

    #[test]
    fn test_pool_key_different_hosts() {
        let key1 = PoolKey::new("example.com".to_string(), 443);
        let key2 = PoolKey::new("other.com".to_string(), 443);
        let key3 = PoolKey::new("example.com".to_string(), 80);
        assert_ne!(key1, key2);
        assert_ne!(key1, key3);
        assert_ne!(key2, key3);
    }

    #[tokio::test]
    async fn test_connection_pool_concurrent() {
        let pool = Arc::new(ConnectionPool::new(PoolConfig::default()));
        let key = example_key();
        let mut handles = Vec::with_capacity(10);
        for _ in 0..10 {
            let pool = Arc::clone(&pool);
            let key = key.clone();
            handles.push(tokio::spawn(async move {
                let _ = pool.acquire(key.clone()).await;
                tokio::time::sleep(Duration::from_millis(10)).await;
                pool.release(key, fresh_conn()).await;
            }));
        }
        for handle in handles {
            handle.await.unwrap();
        }
        let stats = pool.stats().await;
        assert_eq!(stats.total_released, 10);
        // Every acquire is either a hit or a miss, whatever the interleaving.
        assert_eq!(stats.total_hits + stats.total_misses, 10);
    }

    #[test]
    fn test_pooled_connection_creation_times() {
        let conn = fresh_conn();
        assert_eq!(conn.created_at, conn.last_used);
        assert!(conn.is_active);
    }
}