use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use log::debug;
/// Idle streams grouped by target address; each stream is tagged with the
/// instant it was put back into the pool.
type IdleConnections = HashMap<String, Vec<(TcpStream, Instant)>>;

/// An async pool of idle TCP connections, keyed by target address.
///
/// Connections whose idle time has reached `max_idle_time` are discarded
/// rather than handed out again.
#[derive(Debug)]
pub struct ConnectionPool {
    /// Shared idle-connection table, guarded by an async mutex.
    connections: Arc<Mutex<IdleConnections>>,
    /// How long a connection may sit idle before it is considered expired.
    max_idle_time: Duration,
}
impl ConnectionPool {
    /// Idle timeout applied by [`ConnectionPool::new`].
    const DEFAULT_IDLE_TIMEOUT: Duration = Duration::from_secs(30);

    /// Creates an empty pool whose connections expire after 30 seconds idle.
    pub fn new() -> Self {
        Self::with_idle_timeout(Self::DEFAULT_IDLE_TIMEOUT)
    }

    /// Creates an empty pool whose connections expire after `timeout` idle.
    pub fn with_idle_timeout(timeout: Duration) -> Self {
        Self {
            connections: Arc::new(Mutex::new(HashMap::new())),
            max_idle_time: timeout,
        }
    }

    /// Returns a pooled connection to `target_addr` when one is available,
    /// otherwise opens a fresh one.
    ///
    /// # Errors
    ///
    /// Returns the I/O error from [`TcpStream::connect`] when no pooled
    /// connection exists and a new connection cannot be established.
    pub async fn get_or_create(&self, target_addr: &str) -> Result<TcpStream, std::io::Error> {
        if let Some(stream) = self.get(target_addr).await {
            debug!("Reusing connection from pool for {}", target_addr);
            return Ok(stream);
        }
        debug!("Creating new connection to {}", target_addr);
        TcpStream::connect(target_addr).await
    }

    /// Takes the most recently returned, non-expired connection for
    /// `target_addr`, or `None` if there is none.
    ///
    /// Expired connections for that target are dropped as a side effect. The
    /// target's entry is removed once it holds no connections, so the map does
    /// not accumulate empty entries for addresses that are no longer pooled.
    /// The returned stream is not health-checked; the peer may have closed it
    /// while it was idle.
    pub async fn get(&self, target_addr: &str) -> Option<TcpStream> {
        let mut pool = self.connections.lock().await;
        let connections = pool.get_mut(target_addr)?;
        Self::prune_expired(connections, self.max_idle_time);
        // LIFO: the newest connection has been idle the shortest time.
        let stream = connections.pop().map(|(stream, _)| stream);
        if connections.is_empty() {
            pool.remove(target_addr);
        }
        stream
    }

    /// Returns `stream` to the pool as an idle connection for `target_addr`.
    ///
    /// Expired connections already pooled for the same target are dropped
    /// first. Otherwise, a target that is returned to but rarely drawn from
    /// would keep expired sockets open indefinitely.
    pub async fn put(&self, target_addr: String, stream: TcpStream) {
        let mut pool = self.connections.lock().await;
        let connections = pool.entry(target_addr).or_default();
        Self::prune_expired(connections, self.max_idle_time);
        connections.push((stream, Instant::now()));
        // The count is for this target only, not the whole pool.
        debug!("Returned connection to pool, idle for this target: {}", connections.len());
    }

    /// Drops every pooled connection for all targets.
    pub async fn clear(&self) {
        let mut pool = self.connections.lock().await;
        pool.clear();
    }

    /// Counts pooled connections across all targets.
    ///
    /// The count may include connections that have expired but have not yet
    /// been pruned by [`get`](Self::get) or [`put`](Self::put).
    pub async fn len(&self) -> usize {
        let pool = self.connections.lock().await;
        pool.values().map(|v| v.len()).sum()
    }

    /// Returns `true` when [`len`](Self::len) is zero.
    pub async fn is_empty(&self) -> bool {
        self.len().await == 0
    }

    /// Drops connections that have been idle for `max_idle_time` or longer.
    fn prune_expired(connections: &mut Vec<(TcpStream, Instant)>, max_idle_time: Duration) {
        let now = Instant::now();
        connections.retain(|(_, idle_since)| now.duration_since(*idle_since) < max_idle_time);
    }
}
impl Default for ConnectionPool {
fn default() -> Self {
Self::new()
}
}