use dashmap::DashMap;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::time::{Duration, Instant};
use tokio::net::TcpStream;
/// Builds the pool lookup key for a backend identity.
///
/// The three components are concatenated in the order `node`, `user`, `database`, separated
/// by a NUL (`\0`) byte, e.g. `pool_key("n", "u", "d") == "n\0u\0d"`.
///
/// NOTE(review): keys stay distinct only as long as no component itself contains a NUL byte;
/// presumably callers guarantee that — confirm at the call sites.
pub fn pool_key(node: &str, user: &str, database: &str) -> String {
    let components = [node, user, database];
    components.join("\0")
}
/// A pool of idle backend TCP connections, bucketed by a string key and bounded both per key
/// and across all keys.
///
/// All methods take `&self`; the mutable state lives in a `DashMap` and in atomic counters.
pub struct BackendIdlePool {
    // ---- Limits (fixed at construction) ----
    /// Maximum number of connections parked under a single key.
    max_idle_per_key: usize,
    /// Maximum number of connections parked across all keys.
    max_total_idle: usize,

    // ---- Live state ----
    /// Parked connections per key, each stored with the instant it was parked.
    idle: DashMap<String, Vec<(TcpStream, Instant)>>,
    /// Number of connections currently counted as parked; checked against `max_total_idle`.
    total_idle: AtomicUsize,

    // ---- Counters (only ever incremented) ----
    /// Check-ins that were accepted into the pool.
    parked: AtomicU64,
    /// Checkouts that handed a parked connection back to a caller.
    reuses: AtomicU64,
    /// Check-ins rejected because the per-key or the global limit was reached.
    over_capacity: AtomicU64,
    /// Connections discarded at checkout because the liveness probe failed.
    stale_evicted: AtomicU64,
    /// Connections discarded by the reaper for exceeding the maximum idle age.
    reaped: AtomicU64,
}
impl BackendIdlePool {
    /// Creates an empty pool.
    ///
    /// `max_idle_per_key` bounds the connections parked under one key and `max_total_idle`
    /// bounds them across all keys. Both limits are clamped to at least 1, so passing 0 still
    /// allows a single parked connection.
    pub fn new(max_idle_per_key: usize, max_total_idle: usize) -> Self {
        Self {
            idle: DashMap::new(),
            max_idle_per_key: max_idle_per_key.max(1),
            max_total_idle: max_total_idle.max(1),
            total_idle: AtomicUsize::new(0),
            reuses: AtomicU64::new(0),
            parked: AtomicU64::new(0),
            over_capacity: AtomicU64::new(0),
            stale_evicted: AtomicU64::new(0),
            reaped: AtomicU64::new(0),
        }
    }

    /// Takes a parked connection for `key`, most recently parked first.
    ///
    /// Each candidate is probed before being handed out. Candidates that fail the probe are
    /// dropped and counted in `stale_evicted`, and the next one is tried. Returns `None` when
    /// the key is unknown or no parked connection passes the probe.
    pub fn checkout(&self, key: &str) -> Option<TcpStream> {
        let mut bucket = self.idle.get_mut(key)?;
        while let Some((stream, _parked_at)) = bucket.pop() {
            // The connection has left the pool whether or not it turns out to be usable.
            self.total_idle.fetch_sub(1, Ordering::Relaxed);
            if Self::probe_alive(&stream) {
                self.reuses.fetch_add(1, Ordering::Relaxed);
                return Some(stream);
            }
            self.stale_evicted.fetch_add(1, Ordering::Relaxed);
        }
        // A drained bucket is left in place here; `reap_idle` prunes empty buckets.
        None
    }

    /// Parks `stream` under `key` for later reuse.
    ///
    /// Returns `true` if the connection was parked. Returns `false` — dropping `stream` and
    /// incrementing `over_capacity` — when the global limit or the per-key limit is already
    /// reached.
    pub fn checkin(&self, key: &str, stream: TcpStream) -> bool {
        // Reserve a global slot first; the reservation is rolled back on either rejection path.
        if self.total_idle.fetch_add(1, Ordering::Relaxed) >= self.max_total_idle {
            self.total_idle.fetch_sub(1, Ordering::Relaxed);
            self.over_capacity.fetch_add(1, Ordering::Relaxed);
            return false;
        }
        let mut bucket = self.idle.entry(key.to_string()).or_default();
        if bucket.len() >= self.max_idle_per_key {
            self.total_idle.fetch_sub(1, Ordering::Relaxed);
            self.over_capacity.fetch_add(1, Ordering::Relaxed);
            return false;
        }
        bucket.push((stream, Instant::now()));
        self.parked.fetch_add(1, Ordering::Relaxed);
        true
    }

    /// Drops every parked connection that has been idle for `max_age` or longer and returns
    /// how many were dropped.
    ///
    /// Buckets that are empty afterwards — whether emptied here or earlier by `checkout` — are
    /// removed from the map, so the map does not accumulate one empty entry per key ever seen.
    pub fn reap_idle(&self, max_age: Duration) -> usize {
        let mut reaped = 0usize;
        self.idle.retain(|_key, bucket| {
            let before = bucket.len();
            bucket.retain(|(_, parked_at)| parked_at.elapsed() < max_age);
            reaped += before - bucket.len();
            // Keep only buckets that still hold at least one connection.
            !bucket.is_empty()
        });
        if reaped > 0 {
            self.total_idle.fetch_sub(reaped, Ordering::Relaxed);
            // usize -> u64 is a widening (or same-width) conversion on supported targets.
            self.reaped.fetch_add(reaped as u64, Ordering::Relaxed);
        }
        reaped
    }

    /// Reports whether a parked connection still looks usable.
    ///
    /// Attempts a non-blocking one-byte read. Only a `WouldBlock` error counts as alive: a
    /// successful read (including a zero-length one) or any other error marks the connection
    /// as not reusable.
    fn probe_alive(stream: &TcpStream) -> bool {
        let mut probe = [0u8; 1];
        matches!(
            stream.try_read(&mut probe),
            Err(e) if e.kind() == std::io::ErrorKind::WouldBlock
        )
    }

    /// Number of connections currently counted as parked.
    pub fn idle_count(&self) -> usize {
        self.total_idle.load(Ordering::Relaxed)
    }

    /// The global limit on parked connections (after clamping to at least 1).
    pub fn max_total_idle(&self) -> usize {
        self.max_total_idle
    }

    /// Total connections dropped by `reap_idle` so far.
    pub fn reaped(&self) -> u64 {
        self.reaped.load(Ordering::Relaxed)
    }

    /// Total checkouts that returned a parked connection.
    pub fn reuses(&self) -> u64 {
        self.reuses.load(Ordering::Relaxed)
    }

    /// Total check-ins that were accepted.
    pub fn parked(&self) -> u64 {
        self.parked.load(Ordering::Relaxed)
    }

    /// Total check-ins rejected because a limit was reached.
    pub fn over_capacity(&self) -> u64 {
        self.over_capacity.load(Ordering::Relaxed)
    }

    /// Total connections discarded at checkout because the liveness probe failed.
    pub fn stale_evicted(&self) -> u64 {
        self.stale_evicted.load(Ordering::Relaxed)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::net::TcpListener;

    /// Connects a client to `listener` and returns the client half.
    ///
    /// The accepted server half is dropped when this function returns.
    async fn live_stream(listener: &TcpListener) -> TcpStream {
        let addr = listener.local_addr().unwrap();
        let connect = TcpStream::connect(addr);
        let accept = listener.accept();
        let (client, _server) = tokio::join!(connect, accept);
        client.unwrap()
    }

    #[test]
    fn pool_key_is_nul_delimited_and_distinct() {
        assert_eq!(pool_key("n", "u", "d"), "n\0u\0d");
        assert_ne!(pool_key("n", "ud", ""), pool_key("n", "u", "d"));
    }

    #[tokio::test]
    async fn checkin_then_checkout_reuses_same_connection() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let pool = BackendIdlePool::new(4, 1000);
        let key = pool_key("127.0.0.1:5432", "bench", "benchdb");
        let s = live_stream(&listener).await;
        let parked_addr = s.local_addr().unwrap();
        assert!(pool.checkin(&key, s));
        assert_eq!(pool.idle_count(), 1);
        let got = pool.checkout(&key).expect("a parked connection is reusable");
        assert_eq!(got.local_addr().unwrap(), parked_addr, "same socket reused");
        assert_eq!(pool.reuses(), 1);
        assert_eq!(pool.idle_count(), 0);
        assert!(pool.checkout(&key).is_none());
    }

    #[tokio::test]
    async fn distinct_identities_do_not_share() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let pool = BackendIdlePool::new(4, 1000);
        let alice = pool_key("n", "alice", "db");
        let bob = pool_key("n", "bob", "db");
        // Assert the setup so a failed park is reported here rather than at a later check.
        assert!(pool.checkin(&alice, live_stream(&listener).await));
        assert!(pool.checkout(&bob).is_none());
        assert!(pool.checkout(&alice).is_some());
    }

    #[tokio::test]
    async fn per_key_cap_sheds_excess() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let pool = BackendIdlePool::new(2, 1000);
        let key = pool_key("n", "u", "d");
        assert!(pool.checkin(&key, live_stream(&listener).await));
        assert!(pool.checkin(&key, live_stream(&listener).await));
        assert!(!pool.checkin(&key, live_stream(&listener).await));
        assert_eq!(pool.over_capacity(), 1);
        assert_eq!(pool.idle_count(), 2);
    }

    #[tokio::test]
    async fn checkout_evicts_a_closed_connection() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let pool = BackendIdlePool::new(4, 1000);
        let key = pool_key("n", "u", "d");
        let addr = listener.local_addr().unwrap();
        let client = TcpStream::connect(addr).await.unwrap();
        let (server, _) = listener.accept().await.unwrap();
        assert!(pool.checkin(&key, client));
        // Close the peer while the client is parked, then give the runtime a moment before
        // probing the parked connection at checkout.
        drop(server);
        tokio::task::yield_now().await;
        tokio::time::sleep(Duration::from_millis(20)).await;
        assert!(pool.checkout(&key).is_none());
        assert_eq!(pool.stale_evicted(), 1);
    }

    #[tokio::test]
    async fn global_cap_sheds_across_distinct_identities() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let pool = BackendIdlePool::new(10, 2);
        assert!(pool.checkin(&pool_key("n", "a", "d"), live_stream(&listener).await));
        assert!(pool.checkin(&pool_key("n", "b", "d"), live_stream(&listener).await));
        assert!(!pool.checkin(&pool_key("n", "c", "d"), live_stream(&listener).await));
        assert_eq!(pool.idle_count(), 2);
        assert_eq!(pool.over_capacity(), 1);
    }

    #[tokio::test]
    async fn reaper_drops_aged_idle_connections() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let pool = BackendIdlePool::new(4, 100);
        let key = pool_key("n", "u", "d");
        assert!(pool.checkin(&key, live_stream(&listener).await));
        assert_eq!(pool.idle_count(), 1);
        // A generous max age must leave the freshly parked connection alone.
        assert_eq!(pool.reap_idle(Duration::from_secs(60)), 0);
        assert_eq!(pool.idle_count(), 1);
        // Let the connection age past a short max age, then reap it.
        tokio::time::sleep(Duration::from_millis(15)).await;
        assert_eq!(pool.reap_idle(Duration::from_millis(5)), 1);
        assert_eq!(pool.idle_count(), 0);
        assert_eq!(pool.reaped(), 1);
        assert!(pool.checkout(&key).is_none());
    }
}