use std::collections::VecDeque;
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{mpsc, oneshot, Mutex, Notify};
/// A connection type that the pool knows how to open, inspect and recycle.
pub trait Poolable: Send + 'static {
    /// Error reported when opening a connection fails.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Opens a new connection to `addr` using the given `user`, `password`
    /// and `database`.
    fn connect(
        addr: &str,
        user: &str,
        password: &str,
        database: &str,
    ) -> impl std::future::Future<Output = Result<Self, Self::Error>> + Send
    where
        Self: Sized;

    /// Returns `true` when the connection still has pending data; the pool
    /// discards such connections instead of reusing them.
    fn has_pending_data(&self) -> bool;

    /// Prepares the connection for reuse when it is handed back to the pool.
    ///
    /// Resolving to `false` makes the pool discard the connection. The
    /// default implementation does nothing and always resolves to `true`.
    fn reset(&self) -> impl std::future::Future<Output = bool> + Send {
        std::future::ready(true)
    }
}
/// Errors reported by the connection pool, generic over the connection's own
/// error type `E`.
#[derive(Debug)]
#[non_exhaustive]
pub enum PoolError<E: std::error::Error> {
/// Opening a connection failed; carries the underlying connection error.
Connect(E),
/// The pool is draining and refuses new checkouts.
Draining,
/// No connection became available within the checkout timeout.
Timeout,
/// The waiting checkout was abandoned by the pool (its hand-off channel was dropped).
Closed,
/// A new connection could not be created because the pool is already at `max_size`.
AtCapacity,
}
/// Human-readable rendering of each pool error.
impl<E: std::error::Error> std::fmt::Display for PoolError<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only `Connect` needs interpolation; every other variant maps to a
        // fixed message.
        let fixed_text = match self {
            Self::Connect(cause) => return write!(f, "connection error: {cause}"),
            Self::Draining => "pool is draining",
            Self::Timeout => "checkout timeout",
            Self::Closed => "pool closed",
            Self::AtCapacity => "pool at max capacity",
        };
        f.write_str(fixed_text)
    }
}
/// Exposes the wrapped connection error as the error source.
impl<E: std::error::Error + 'static> std::error::Error for PoolError<E> {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        // Only `Connect` wraps an underlying cause.
        let Self::Connect(cause) = self else {
            return None;
        };
        Some(cause)
    }
}
/// Settings that control how the pool connects and how it sizes, ages and
/// maintains its connections.
#[derive(Clone)]
#[non_exhaustive]
pub struct ConnPoolConfig {
/// Address passed to `Poolable::connect`.
pub addr: String,
/// User name passed to `Poolable::connect`.
pub user: String,
/// Password passed to `Poolable::connect`.
pub password: String,
/// Database name passed to `Poolable::connect`.
pub database: String,
/// Number of connections pre-created at construction and the idle level the
/// maintenance task tries to restore.
pub min_idle: usize,
/// Upper bound on the number of connections tracked by the pool.
pub max_size: usize,
/// Base lifetime used to compute a connection's expiry instant.
pub max_lifetime: Duration,
/// Random spread applied around `max_lifetime` when computing an expiry.
pub max_lifetime_jitter: Duration,
/// How long a checkout waits for a connection before giving up.
pub checkout_timeout: Duration,
/// Period of the background maintenance task.
pub maintenance_interval: Duration,
/// When `true`, idle connections reporting pending data are discarded at checkout.
pub test_on_checkout: bool,
}
/// Debug output that never prints the configured password.
impl std::fmt::Debug for ConnPoolConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Destructuring forces this impl to be revisited whenever a field is
        // added; the password is deliberately ignored.
        let Self {
            addr,
            user,
            password: _,
            database,
            min_idle,
            max_size,
            max_lifetime,
            max_lifetime_jitter,
            checkout_timeout,
            maintenance_interval,
            test_on_checkout,
        } = self;

        let mut out = f.debug_struct("ConnPoolConfig");
        out.field("addr", addr);
        out.field("user", user);
        out.field("password", &"<redacted>");
        out.field("database", database);
        out.field("min_idle", min_idle);
        out.field("max_size", max_size);
        out.field("max_lifetime", max_lifetime);
        out.field("max_lifetime_jitter", max_lifetime_jitter);
        out.field("checkout_timeout", checkout_timeout);
        out.field("maintenance_interval", maintenance_interval);
        out.field("test_on_checkout", test_on_checkout);
        out.finish()
    }
}
/// Defaults: empty connection parameters, 1 idle / 10 max connections,
/// 30-minute lifetime with 60 s jitter, 5 s checkout timeout, 10 s maintenance
/// period, and checkout testing enabled.
impl Default for ConnPoolConfig {
    fn default() -> Self {
        const SECS_PER_MINUTE: u64 = 60;
        let blank = String::new;

        let max_lifetime = Duration::from_secs(30 * SECS_PER_MINUTE);
        let max_lifetime_jitter = Duration::from_secs(SECS_PER_MINUTE);
        let checkout_timeout = Duration::from_secs(5);
        let maintenance_interval = Duration::from_secs(10);

        Self {
            addr: blank(),
            user: blank(),
            password: blank(),
            database: blank(),
            min_idle: 1,
            max_size: 10,
            max_lifetime,
            max_lifetime_jitter,
            checkout_timeout,
            maintenance_interval,
            test_on_checkout: true,
        }
    }
}
/// Optional callback that receives a shared reference to a connection.
type ConnHook<C> = Option<Box<dyn Fn(&C) + Send + Sync>>;
/// Optional callback that takes no arguments.
type Hook = Option<Box<dyn Fn() + Send + Sync>>;
/// Optional callbacks the pool invokes at points in a connection's life cycle.
/// Every slot may be left as `None`.
#[non_exhaustive]
pub struct LifecycleHooks<C> {
/// Called with each freshly opened connection.
pub on_create: ConnHook<C>,
/// Called at the start of a checkout, before any connection is looked up.
pub before_acquire: Hook,
/// Called with the connection that is about to be handed to the caller.
pub on_checkout: ConnHook<C>,
/// Called with a returned connection once it has been reset successfully.
pub on_checkin: ConnHook<C>,
/// Called once the pool has finished processing a returned connection.
pub after_release: Hook,
/// Called when the pool discards a connection.
pub on_destroy: Hook,
}
/// Debug output showing, per hook, only whether a callback is installed
/// (closures themselves are not printable).
impl<C> std::fmt::Debug for LifecycleHooks<C> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            on_create,
            before_acquire,
            on_checkout,
            on_checkin,
            after_release,
            on_destroy,
        } = self;

        let mut out = f.debug_struct("LifecycleHooks");
        out.field("on_create", &on_create.is_some());
        out.field("before_acquire", &before_acquire.is_some());
        out.field("on_checkout", &on_checkout.is_some());
        out.field("on_checkin", &on_checkin.is_some());
        out.field("after_release", &after_release.is_some());
        out.field("on_destroy", &on_destroy.is_some());
        out.finish()
    }
}
impl<C> Default for LifecycleHooks<C> {
fn default() -> Self {
Self {
on_create: None,
before_acquire: None,
on_checkout: None,
on_checkin: None,
after_release: None,
on_destroy: None,
}
}
}
/// Point-in-time snapshot of the pool's counters.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct PoolMetrics {
/// Connections currently tracked by the pool.
pub total: usize,
/// Derived as `total - in_use` (saturating), not read from the idle queue.
pub idle: usize,
/// Connections currently checked out.
pub in_use: usize,
/// Checkouts currently waiting for a connection.
pub waiters: usize,
/// Cumulative number of successful checkouts.
pub total_checkouts: u64,
/// Cumulative number of connections opened.
pub total_created: u64,
/// Cumulative number of connections discarded.
pub total_destroyed: u64,
/// Cumulative number of checkouts that timed out.
pub total_timeouts: u64,
}
/// A connection parked in the idle queue together with its expiry instant.
struct IdleConn<C> {
/// The parked connection.
conn: C,
/// Instant from which the connection is treated as expired and discarded.
expires_at: Instant,
}
/// Guard that decrements a waiter counter when it goes out of scope.
struct WaiterCountGuard<'a> {
/// Counter to decrement on drop.
counter: &'a AtomicUsize,
}
/// Decrements the borrowed counter by one on drop.
impl Drop for WaiterCountGuard<'_> {
fn drop(&mut self) {
// Relaxed: the counter is only read for metrics.
self.counter.fetch_sub(1, Ordering::Relaxed);
}
}
/// A queued checkout request waiting for a connection.
struct Waiter<C> {
/// One-shot channel used to hand a connection to the waiting checkout.
tx: oneshot::Sender<C>,
}
/// Asynchronous connection pool for connections of type `C`.
pub struct ConnPool<C: Poolable> {
/// Settings the pool was built with.
config: ConnPoolConfig,
/// User-supplied life-cycle callbacks.
hooks: LifecycleHooks<C>,
/// Idle connections; popped from the front, pushed at the back.
idle: Mutex<VecDeque<IdleConn<C>>>,
/// Checkouts waiting for a connection, in arrival order.
waiters: Mutex<VecDeque<Waiter<C>>>,
/// Number of connections tracked by the pool (compared against `max_size`).
total_count: AtomicUsize,
/// Number of connections currently checked out.
in_use_count: AtomicUsize,
/// Number of checkouts currently waiting.
waiter_count: AtomicUsize,
/// Cumulative successful checkouts.
total_checkouts: AtomicU64,
/// Cumulative connections opened.
total_created: AtomicU64,
/// Cumulative connections discarded.
total_destroyed: AtomicU64,
/// Cumulative checkout timeouts.
total_timeouts: AtomicU64,
/// Set once `drain` has been called; new checkouts are then rejected.
draining: AtomicBool,
/// Signalled when the pool is draining and `total_count` reaches zero.
drain_complete: Notify,
/// Used by `drain` to tell the maintenance task to stop.
shutdown_tx: mpsc::Sender<()>,
}
/// Debug output with the configuration, a metrics snapshot and the draining flag.
impl<C: Poolable> std::fmt::Debug for ConnPool<C> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("ConnPool");
        out.field("config", &self.config);
        out.field("metrics", &self.metrics());
        // Relaxed is enough for a diagnostic read.
        out.field("draining", &self.draining.load(Ordering::Relaxed));
        out.finish()
    }
}
impl<C: Poolable> ConnPool<C> {
/// Builds a pool, pre-fills it with idle connections and starts the
/// background maintenance task.
///
/// Up to `min(config.min_idle, config.max_size)` connections are opened
/// eagerly. A failed pre-fill attempt is logged and skipped, so the pool may
/// start with fewer idle connections than requested.
///
/// # Errors
///
/// No error is currently returned; the `Result` is kept for API compatibility.
///
/// # Panics
///
/// Panics if called outside a Tokio runtime, because it spawns a task.
pub async fn new(
    config: ConnPoolConfig,
    hooks: LifecycleHooks<C>,
) -> Result<Arc<Self>, PoolError<C::Error>> {
    let (shutdown_tx, shutdown_rx) = mpsc::channel(1);

    // Never pre-fill past the capacity limit, even if `min_idle` is larger.
    let prefill = config.min_idle.min(config.max_size);
    let idle_capacity = config.max_size;

    let pool = Arc::new(Self {
        config,
        hooks,
        idle: Mutex::new(VecDeque::with_capacity(idle_capacity)),
        waiters: Mutex::new(VecDeque::new()),
        total_count: AtomicUsize::new(0),
        in_use_count: AtomicUsize::new(0),
        waiter_count: AtomicUsize::new(0),
        total_checkouts: AtomicU64::new(0),
        total_created: AtomicU64::new(0),
        total_destroyed: AtomicU64::new(0),
        total_timeouts: AtomicU64::new(0),
        draining: AtomicBool::new(false),
        drain_complete: Notify::new(),
        shutdown_tx,
    });

    for _ in 0..prefill {
        match pool.create_connection().await {
            Ok(idle_conn) => {
                pool.idle.lock().await.push_back(idle_conn);
                pool.total_count.fetch_add(1, Ordering::Release);
            }
            Err(e) => {
                tracing::warn!("Failed to pre-fill connection: {e}");
            }
        }
    }

    tokio::spawn(maintenance_task(Arc::clone(&pool), shutdown_rx));
    Ok(pool)
}
/// Checks a connection out of the pool.
///
/// Tries, in order: an idle connection, a newly created connection (while
/// below `max_size`), and finally waiting up to `checkout_timeout` for a
/// connection to be handed over by a returning guard.
///
/// # Errors
///
/// * [`PoolError::Draining`] if the pool is draining.
/// * [`PoolError::Timeout`] if no connection arrived within `checkout_timeout`.
/// * [`PoolError::Closed`] if the pool dropped this waiter's hand-off channel.
///
/// A failure to open a new connection is logged and the call falls back to
/// waiting.
pub async fn get(self: &Arc<Self>) -> Result<PoolGuard<C>, PoolError<C::Error>> {
    if self.draining.load(Ordering::Acquire) {
        return Err(PoolError::Draining);
    }
    if let Some(hook) = &self.hooks.before_acquire {
        hook();
    }

    if let Some(conn) = self.try_get_idle().await {
        return Ok(self.finish_checkout(conn));
    }

    if self.total_count.load(Ordering::Acquire) < self.config.max_size {
        match self.create_and_track().await {
            Ok(conn) => return Ok(self.finish_checkout(conn)),
            Err(e) => {
                tracing::warn!("Failed to create new connection: {e}");
            }
        }
    }

    // NOTE(review): a connection returned between the idle check above and the
    // registration below lands in the idle queue and does not wake this
    // caller; confirm whether a re-check is wanted.
    let (tx, mut rx) = oneshot::channel();
    {
        let mut waiters = self.waiters.lock().await;
        waiters.push_back(Waiter { tx });
        self.waiter_count.fetch_add(1, Ordering::Relaxed);
    }
    let _waiter_guard = WaiterCountGuard {
        counter: &self.waiter_count,
    };

    // Wait on `&mut rx` so the receiver survives a timeout and can be checked
    // for a connection that was sent just as the deadline fired.
    // NOTE(review): if this future is dropped while waiting, a connection sent
    // at that moment is still dropped without accounting.
    match tokio::time::timeout(self.config.checkout_timeout, &mut rx).await {
        Ok(Ok(conn)) => Ok(self.finish_checkout(conn)),
        Ok(Err(_)) => Err(PoolError::Closed),
        Err(_) => {
            // Closing first guarantees no later send can succeed; anything
            // already sent is then recovered instead of being leaked.
            rx.close();
            if let Ok(conn) = rx.try_recv() {
                return Ok(self.finish_checkout(conn));
            }
            self.total_timeouts.fetch_add(1, Ordering::Relaxed);
            {
                let mut waiters = self.waiters.lock().await;
                waiters.retain(|w| !w.tx.is_closed());
            }
            Err(PoolError::Timeout)
        }
    }
}

/// Records a successful checkout, runs `on_checkout` and wraps `conn` in a guard.
fn finish_checkout(self: &Arc<Self>, conn: C) -> PoolGuard<C> {
    self.in_use_count.fetch_add(1, Ordering::Release);
    self.total_checkouts.fetch_add(1, Ordering::Relaxed);
    if let Some(hook) = &self.hooks.on_checkout {
        hook(&conn);
    }
    PoolGuard {
        conn: Some(conn),
        pool: Arc::clone(self),
    }
}
async fn try_get_idle(&self) -> Option<C> {
let mut idle = self.idle.lock().await;
while let Some(entry) = idle.pop_front() {
if Instant::now() >= entry.expires_at {
self.destroy_conn_stats();
if let Some(ref hook) = self.hooks.on_destroy {
hook();
}
continue;
}
if self.config.test_on_checkout && entry.conn.has_pending_data() {
self.destroy_conn_stats();
if let Some(ref hook) = self.hooks.on_destroy {
hook();
}
continue;
}
return Some(entry.conn);
}
None
}
/// Opens one connection with the configured parameters and stamps it with a
/// jittered expiry.
///
/// Bumps `total_created` and runs `on_create` on success. Does not touch
/// `total_count`; callers account for the slot themselves.
async fn create_connection(&self) -> Result<IdleConn<C>, C::Error> {
    let cfg = &self.config;
    let conn = C::connect(&cfg.addr, &cfg.user, &cfg.password, &cfg.database).await?;

    self.total_created.fetch_add(1, Ordering::Relaxed);
    if let Some(hook) = self.hooks.on_create.as_ref() {
        hook(&conn);
    }

    let lifetime = jittered_duration(cfg.max_lifetime, cfg.max_lifetime_jitter);
    let expires_at = Instant::now() + lifetime;
    Ok(IdleConn { conn, expires_at })
}
async fn create_and_track(&self) -> Result<C, PoolError<C::Error>> {
let prev = self.total_count.fetch_add(1, Ordering::Release);
if prev >= self.config.max_size {
self.total_count.fetch_sub(1, Ordering::Release);
return Err(PoolError::AtCapacity);
}
match self.create_connection().await {
Ok(idle_conn) => Ok(idle_conn.conn),
Err(e) => {
self.total_count.fetch_sub(1, Ordering::Release);
Err(PoolError::Connect(e))
}
}
}
fn return_conn(pool: Arc<Self>, conn: C) {
tokio::spawn(async move {
pool.return_conn_async(conn).await;
});
}
/// Processes a connection coming back from a guard.
///
/// The connection is discarded when it has pending data, when `reset()`
/// reports failure, or when the pool is draining. Otherwise it is handed to
/// the oldest live waiter, or parked in the idle queue if nobody is waiting.
/// `after_release` runs on every path.
async fn return_conn_async(&self, conn: C) {
    self.in_use_count.fetch_sub(1, Ordering::Release);

    // `reset()` is only attempted on connections without pending data.
    if conn.has_pending_data() || !conn.reset().await {
        drop(conn);
        self.retire_returned().await;
        return;
    }

    if let Some(hook) = &self.hooks.on_checkin {
        hook(&conn);
    }

    if self.draining.load(Ordering::Acquire) {
        drop(conn);
        self.retire_returned().await;
        return;
    }

    // Offer the connection to waiters in arrival order, skipping those whose
    // receiver is already gone (timed out or cancelled).
    let mut conn = conn;
    let leftover = {
        let mut waiters = self.waiters.lock().await;
        loop {
            let Some(waiter) = waiters.pop_front() else {
                break Some(conn);
            };
            match waiter.tx.send(conn) {
                Ok(()) => break None,
                Err(returned) => conn = returned,
            }
        }
    };
    let Some(conn) = leftover else {
        if let Some(hook) = &self.hooks.after_release {
            hook();
        }
        return;
    };

    // NOTE(review): the expiry is recomputed on every return, so
    // `max_lifetime` bounds time since the last check-in rather than total
    // connection age; confirm this is intended.
    let lifetime = jittered_duration(self.config.max_lifetime, self.config.max_lifetime_jitter);
    let rejected = {
        let mut idle = self.idle.lock().await;
        // Re-check under the idle lock: `drain()` sets the flag before it
        // empties the queue, so a connection parked after that point would
        // never be destroyed and `drain()` would wait forever.
        if self.draining.load(Ordering::Acquire) {
            Some(conn)
        } else {
            idle.push_back(IdleConn {
                conn,
                expires_at: Instant::now() + lifetime,
            });
            None
        }
    };

    match rejected {
        Some(conn) => {
            drop(conn);
            self.retire_returned().await;
        }
        None => {
            if let Some(hook) = &self.hooks.after_release {
                hook();
            }
        }
    }
}

/// Bookkeeping for a returned connection that has been discarded: updates the
/// counters, runs `on_destroy`, tries to open a replacement for a waiter
/// (a no-op while draining), runs `after_release` and signals `drain()` if
/// this was the last connection.
async fn retire_returned(&self) {
    self.destroy_conn_stats();
    if let Some(hook) = &self.hooks.on_destroy {
        hook();
    }
    self.try_provision_for_waiter().await;
    if let Some(hook) = &self.hooks.after_release {
        hook();
    }
    self.maybe_notify_drain();
}
/// Signals `drain_complete` when the pool is draining and no connections
/// remain tracked.
fn maybe_notify_drain(&self) {
    if !self.draining.load(Ordering::Acquire) {
        return;
    }
    if self.total_count.load(Ordering::Acquire) != 0 {
        return;
    }
    self.drain_complete.notify_one();
}
/// Opens a replacement connection for a pending waiter after a connection
/// was discarded.
///
/// Does nothing while draining or when nobody is waiting. If every waiter has
/// gone away by the time the connection is ready, it is parked in the idle
/// queue, or destroyed if the pool started draining in the meantime.
async fn try_provision_for_waiter(&self) {
    if self.draining.load(Ordering::Acquire) {
        return;
    }
    let has_waiter = {
        let waiters = self.waiters.lock().await;
        !waiters.is_empty()
    };
    if !has_waiter {
        return;
    }

    let mut conn = match self.create_and_track().await {
        Ok(c) => c,
        Err(e) => {
            tracing::warn!(
                "failed to provision replacement for waiter after conn destroyed: {e}"
            );
            return;
        }
    };

    // Hand over to the oldest waiter whose receiver is still alive.
    let leftover = {
        let mut waiters = self.waiters.lock().await;
        loop {
            let Some(waiter) = waiters.pop_front() else {
                break Some(conn);
            };
            match waiter.tx.send(conn) {
                Ok(()) => break None,
                Err(returned) => conn = returned,
            }
        }
    };
    let Some(conn) = leftover else {
        return;
    };

    let lifetime = jittered_duration(self.config.max_lifetime, self.config.max_lifetime_jitter);
    let rejected = {
        let mut idle = self.idle.lock().await;
        // `drain()` may have started while we were connecting. Parking the
        // connection after the idle queue was emptied would leave
        // `total_count` above zero forever and hang `drain()`.
        if self.draining.load(Ordering::Acquire) {
            Some(conn)
        } else {
            idle.push_back(IdleConn {
                conn,
                expires_at: Instant::now() + lifetime,
            });
            None
        }
    };

    if let Some(conn) = rejected {
        drop(conn);
        self.destroy_conn_stats();
        if let Some(hook) = &self.hooks.on_destroy {
            hook();
        }
        self.maybe_notify_drain();
    }
}
/// Records that one tracked connection was discarded: decrements
/// `total_count` and increments `total_destroyed`.
fn destroy_conn_stats(&self) {
self.total_count.fetch_sub(1, Ordering::Release);
self.total_destroyed.fetch_add(1, Ordering::Relaxed);
}
/// Returns a snapshot of the pool counters.
///
/// Each counter is read independently, so the snapshot is not atomic as a
/// whole. `idle` is derived as `total - in_use` (saturating).
pub fn metrics(&self) -> PoolMetrics {
    let total = self.total_count.load(Ordering::Acquire);
    let in_use = self.in_use_count.load(Ordering::Acquire);
    let idle = total.saturating_sub(in_use);

    let waiters = self.waiter_count.load(Ordering::Relaxed);
    let total_checkouts = self.total_checkouts.load(Ordering::Relaxed);
    let total_created = self.total_created.load(Ordering::Relaxed);
    let total_destroyed = self.total_destroyed.load(Ordering::Relaxed);
    let total_timeouts = self.total_timeouts.load(Ordering::Relaxed);

    PoolMetrics {
        total,
        idle,
        in_use,
        waiters,
        total_checkouts,
        total_created,
        total_destroyed,
        total_timeouts,
    }
}
pub async fn warm_up(&self, target: usize) {
let current = self.metrics().total;
let to_create = target
.saturating_sub(current)
.min(self.config.max_size - current);
let mut created = 0;
for _ in 0..to_create {
match self.create_connection().await {
Ok(idle_conn) => {
self.idle.lock().await.push_back(idle_conn);
self.total_count.fetch_add(1, Ordering::Release);
created += 1;
}
Err(e) => {
tracing::warn!("warm_up: failed to create connection: {e}");
break;
}
}
}
if created > 0 {
tracing::info!(created, target, "pool warm-up complete");
}
}
/// Shuts the pool down gracefully.
///
/// Marks the pool as draining, destroys all idle connections, drops every
/// queued waiter (their checkouts fail with [`PoolError::Closed`]), then
/// waits until all checked-out connections have come back and been
/// destroyed. Finally the maintenance task is told to stop.
pub async fn drain(&self) {
    self.draining.store(true, Ordering::Release);

    let destroyed_count = {
        let mut idle = self.idle.lock().await;
        let count = idle.len();
        idle.clear();
        self.total_count.fetch_sub(count, Ordering::Release);
        self.total_destroyed
            .fetch_add(count as u64, Ordering::Relaxed);
        count
    };
    if let Some(hook) = &self.hooks.on_destroy {
        for _ in 0..destroyed_count {
            hook();
        }
    }

    // Dropping the senders wakes every waiter with a closed-channel error.
    // `waiter_count` is deliberately left alone: each waiting `get()` owns a
    // `WaiterCountGuard` that decrements it, so subtracting here as well
    // would count every waiter twice and underflow the counter.
    self.waiters.lock().await.clear();

    loop {
        // Register interest before checking so a notification sent between
        // the check and the await is not lost.
        let notified = self.drain_complete.notified();
        if self.total_count.load(Ordering::Acquire) == 0 {
            break;
        }
        notified.await;
    }

    // The maintenance task may already have exited; a send error is fine.
    let _ = self.shutdown_tx.send(()).await;
    tracing::info!("Connection pool drained");
}
pub fn status(&self) -> String {
let m = self.metrics();
format!(
"pool: total={} idle={} in_use={} created={} destroyed={} timeouts={}",
m.total, m.idle, m.in_use, m.total_created, m.total_destroyed, m.total_timeouts
)
}
}
/// Handle to a checked-out connection; dropping it hands the connection back
/// to the pool.
pub struct PoolGuard<C: Poolable> {
/// The checked-out connection; `None` once it has been moved out.
conn: Option<C>,
/// Pool the connection is returned to.
pool: Arc<ConnPool<C>>,
}
/// Debug output showing the guarded connection; the pool handle is omitted.
impl<C: Poolable + std::fmt::Debug> std::fmt::Debug for PoolGuard<C> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("PoolGuard");
        out.field("conn", &self.conn);
        out.finish_non_exhaustive()
    }
}
impl<C: Poolable> PoolGuard<C> {
pub fn conn(&self) -> &C {
self.conn
.as_ref()
.expect("PoolGuard: connection has been moved out via PoolGuard::take(); the guard is consumed by `take()` and must not be accessed afterwards (a logic bug in the caller)")
}
pub fn conn_mut(&mut self) -> &mut C {
self.conn
.as_mut()
.expect("PoolGuard: connection has been moved out via PoolGuard::take(); the guard is consumed by `take()` and must not be accessed afterwards (a logic bug in the caller)")
}
pub fn take(mut self) -> C {
let conn = self
.conn
.take()
.expect("PoolGuard: connection has been moved out via PoolGuard::take(); the guard is consumed by `take()` and must not be accessed afterwards (a logic bug in the caller)");
self.pool.in_use_count.fetch_sub(1, Ordering::Release);
self.pool.total_count.fetch_sub(1, Ordering::Release);
conn
}
}
/// Hands the connection back to the pool, unless it was already moved out.
impl<C: Poolable> Drop for PoolGuard<C> {
    fn drop(&mut self) {
        let Some(conn) = self.conn.take() else {
            return;
        };
        let pool = Arc::clone(&self.pool);
        ConnPool::return_conn(pool, conn);
    }
}
/// Lets a guard be used wherever `&C` is expected.
impl<C: Poolable> std::ops::Deref for PoolGuard<C> {
    type Target = C;

    /// Panics under the same condition, and with the same message, as
    /// `PoolGuard::conn`.
    fn deref(&self) -> &C {
        self.conn()
    }
}
/// Lets a guard be used wherever `&mut C` is expected.
impl<C: Poolable> std::ops::DerefMut for PoolGuard<C> {
    /// Panics under the same condition, and with the same message, as
    /// `PoolGuard::conn_mut`.
    fn deref_mut(&mut self) -> &mut C {
        self.conn_mut()
    }
}
async fn maintenance_task<C: Poolable>(
pool: Arc<ConnPool<C>>,
mut shutdown_rx: mpsc::Receiver<()>,
) {
let mut interval = tokio::time::interval(pool.config.maintenance_interval);
interval.tick().await;
loop {
tokio::select! {
_ = interval.tick() => {}
_ = shutdown_rx.recv() => {
tracing::debug!("Maintenance task shutting down");
return;
}
}
if pool.draining.load(Ordering::Acquire) {
return;
}
{
let mut idle = pool.idle.lock().await;
let now = Instant::now();
let before = idle.len();
idle.retain(|entry| now < entry.expires_at);
let evicted = before - idle.len();
if evicted > 0 {
pool.total_count.fetch_sub(evicted, Ordering::Release);
pool.total_destroyed
.fetch_add(evicted as u64, Ordering::Relaxed);
tracing::debug!("Evicted {evicted} expired connections");
}
}
let total = pool.total_count.load(Ordering::Acquire);
let in_use = pool.in_use_count.load(Ordering::Acquire);
let current_idle = total.saturating_sub(in_use);
if current_idle < pool.config.min_idle && total < pool.config.max_size {
let to_create = (pool.config.min_idle - current_idle).min(pool.config.max_size - total);
for _ in 0..to_create {
match pool.create_and_track().await {
Ok(conn) => {
let jitter = jittered_duration(
pool.config.max_lifetime,
pool.config.max_lifetime_jitter,
);
let mut idle = pool.idle.lock().await;
idle.push_back(IdleConn {
conn,
expires_at: Instant::now() + jitter,
});
}
Err(_) => break,
}
}
}
}
}
/// Returns `base` shifted by a random offset in `[-jitter, +jitter]`.
///
/// With a zero `jitter`, `base` is returned unchanged. Otherwise the result
/// has millisecond granularity, is at least 1 ms, and saturates at
/// `u64::MAX` milliseconds instead of wrapping for very large inputs.
fn jittered_duration(base: Duration, jitter: Duration) -> Duration {
    if jitter.is_zero() {
        return base;
    }

    // Saturate rather than truncate: an `as u64` cast would wrap huge
    // durations (e.g. `Duration::MAX`) into tiny ones.
    let jitter_ms = u64::try_from(jitter.as_millis()).unwrap_or(u64::MAX);

    // 2 * jitter_ms + 1 does not fit in u64 for large jitters, so the
    // arithmetic is done in u128.
    let span = u128::from(jitter_ms) * 2 + 1;
    let offset = u128::from(fastrand_u64()) % span;

    // base + offset - jitter, floored at zero here and at 1 ms below.
    let jittered_ms = (base.as_millis() + offset).saturating_sub(u128::from(jitter_ms));
    let clamped_ms = u64::try_from(jittered_ms).unwrap_or(u64::MAX);
    Duration::from_millis(clamped_ms.max(1))
}
/// Returns the next value of a per-thread xorshift generator.
///
/// The state is seeded from the wall-clock time in nanoseconds on first use
/// in each thread. The output is never zero. Not suitable for cryptographic
/// purposes.
fn fastrand_u64() -> u64 {
    use std::cell::Cell;
    thread_local! {
        static STATE: Cell<u64> = Cell::new(
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .as_nanos() as u64
        );
    }
    STATE.with(|cell| {
        let seed = cell.get();
        let step1 = seed ^ (seed << 13);
        let step2 = step1 ^ (step1 >> 7);
        let step3 = step2 ^ (step2 << 17);
        // Zero is a fixed point of xorshift; nudge the state off it.
        let next = if step3 == 0 { 1 } else { step3 };
        cell.set(next);
        next
    })
}