// pg_pool/pool.rs

1//! Generic connection pool implementation.
2
3use std::collections::VecDeque;
4use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7
8use tokio::sync::{mpsc, oneshot, Mutex, Notify};
9
10// ---------------------------------------------------------------------------
11// Poolable trait
12// ---------------------------------------------------------------------------
13
/// Trait for connection types that can be managed by the pool.
pub trait Poolable: Send + 'static {
    /// The error type for connection operations.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Create a new connection to the database.
    ///
    /// `addr` is the `host:port` pair; the remaining parameters are the
    /// credentials and target database, passed through from [`ConnPoolConfig`].
    fn connect(
        addr: &str,
        user: &str,
        password: &str,
        database: &str,
    ) -> impl std::future::Future<Output = Result<Self, Self::Error>> + Send
    where
        Self: Sized;

    /// Check if the connection has unconsumed data (is in a corrupted state).
    /// The pool destroys such connections instead of reusing them.
    fn has_pending_data(&self) -> bool;

    /// Reset the connection to a clean state before returning to the pool.
    /// Implementations should send `DISCARD ALL` or equivalent to clear
    /// session state (transactions, SET variables, temp tables, prepared statements).
    /// Returns false if the reset failed and the connection should be destroyed.
    ///
    /// **Must not panic.** A panic in `reset()` will cause the spawned return
    /// task to abort, but `in_use_count` is decremented before `reset()` is
    /// called so the pool remains consistent.
    fn reset(&self) -> impl std::future::Future<Output = bool> + Send {
        async { true } // default: no-op (backward compatible)
    }
}
44
45// ---------------------------------------------------------------------------
46// Pool error
47// ---------------------------------------------------------------------------
48
/// Errors returned by pool operations.
///
/// `#[non_exhaustive]` so new failure modes can be added without breaking
/// downstream matches.
#[derive(Debug)]
#[non_exhaustive]
pub enum PoolError<E: std::error::Error> {
    /// Connection creation failed (wraps the connection type's own error).
    Connect(E),
    /// Pool is draining (shutting down); no new checkouts are accepted.
    Draining,
    /// Checkout timed out waiting for an available connection.
    Timeout,
    /// Pool is closed (e.g. the waiter was dropped during drain).
    Closed,
    /// Pool is at maximum capacity.
    AtCapacity,
}
64
65impl<E: std::error::Error> std::fmt::Display for PoolError<E> {
66    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67        match self {
68            Self::Connect(e) => write!(f, "connection error: {e}"),
69            Self::Draining => write!(f, "pool is draining"),
70            Self::Timeout => write!(f, "checkout timeout"),
71            Self::Closed => write!(f, "pool closed"),
72            Self::AtCapacity => write!(f, "pool at max capacity"),
73        }
74    }
75}
76
77impl<E: std::error::Error + 'static> std::error::Error for PoolError<E> {
78    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
79        match self {
80            Self::Connect(e) => Some(e),
81            _ => None,
82        }
83    }
84}
85
86// ---------------------------------------------------------------------------
87// Configuration
88// ---------------------------------------------------------------------------
89
/// Connection pool configuration.
///
/// Construct via [`ConnPoolConfig::default`] and update only the fields you care
/// about. Marked `#[non_exhaustive]` so adding new tuning knobs in future minor
/// releases is not a breaking change.
#[derive(Clone)]
#[non_exhaustive]
pub struct ConnPoolConfig {
    /// Address (host:port).
    pub addr: String,
    /// User.
    pub user: String,
    /// Password. Redacted by the manual `Debug` impl.
    pub password: String,
    /// Database.
    pub database: String,
    /// Minimum idle connections to maintain (pre-filled on construction and
    /// replenished by the maintenance task).
    pub min_idle: usize,
    /// Maximum total connections (idle + in-use).
    pub max_size: usize,
    /// Maximum lifetime per connection (with jitter applied).
    pub max_lifetime: Duration,
    /// Jitter range for max_lifetime (± this value), so connections don't all
    /// expire at once.
    pub max_lifetime_jitter: Duration,
    /// Timeout waiting for a connection from the pool.
    pub checkout_timeout: Duration,
    /// How often to run the maintenance task.
    pub maintenance_interval: Duration,
    /// Whether to verify connections (via `has_pending_data`) on checkout.
    pub test_on_checkout: bool,
}
121
122impl std::fmt::Debug for ConnPoolConfig {
123    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
124        f.debug_struct("ConnPoolConfig")
125            .field("addr", &self.addr)
126            .field("user", &self.user)
127            .field("password", &"<redacted>")
128            .field("database", &self.database)
129            .field("min_idle", &self.min_idle)
130            .field("max_size", &self.max_size)
131            .field("max_lifetime", &self.max_lifetime)
132            .field("max_lifetime_jitter", &self.max_lifetime_jitter)
133            .field("checkout_timeout", &self.checkout_timeout)
134            .field("maintenance_interval", &self.maintenance_interval)
135            .field("test_on_checkout", &self.test_on_checkout)
136            .finish()
137    }
138}
139
140impl Default for ConnPoolConfig {
141    fn default() -> Self {
142        Self {
143            addr: String::new(),
144            user: String::new(),
145            password: String::new(),
146            database: String::new(),
147            min_idle: 1,
148            max_size: 10,
149            max_lifetime: Duration::from_secs(30 * 60),
150            max_lifetime_jitter: Duration::from_secs(60),
151            checkout_timeout: Duration::from_secs(5),
152            maintenance_interval: Duration::from_secs(10),
153            test_on_checkout: true,
154        }
155    }
156}
157
158// ---------------------------------------------------------------------------
159// Lifecycle hooks
160// ---------------------------------------------------------------------------
161
/// A hook that receives a reference to the affected connection.
type ConnHook<C> = Option<Box<dyn Fn(&C) + Send + Sync>>;
/// A hook with no parameters (used where no specific connection applies).
type Hook = Option<Box<dyn Fn() + Send + Sync>>;
166
/// Lifecycle hook callbacks. All hooks are optional.
///
/// Connection-aware hooks (`on_create`, `on_checkout`, `on_checkin`) receive a `&C`
/// reference to the connection. Non-connection hooks (`before_acquire`, `after_release`,
/// `on_destroy`) take no parameters — `on_destroy` because the connection may be invalid,
/// and `before_acquire`/`after_release` because no specific connection is involved yet.
///
/// Hooks must be fast, non-blocking, and must not call back into the pool
/// (see the warning on [`ConnPool`]).
///
/// Marked `#[non_exhaustive]` so additional hooks can be introduced in future
/// minor releases without breaking downstream construction.
#[non_exhaustive]
pub struct LifecycleHooks<C> {
    /// Called after a new connection is created.
    pub on_create: ConnHook<C>,
    /// Called before attempting to acquire a connection (checkout starts).
    pub before_acquire: Hook,
    /// Called when a connection is checked out and ready to use.
    pub on_checkout: ConnHook<C>,
    /// Called when a connection passes health checks and is about to return to the pool.
    pub on_checkin: ConnHook<C>,
    /// Called after a connection is fully released (all exit paths from return).
    pub after_release: Hook,
    /// Called when a connection is destroyed (expired, invalid, or during drain).
    pub on_destroy: Hook,
}
191
192impl<C> std::fmt::Debug for LifecycleHooks<C> {
193    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194        f.debug_struct("LifecycleHooks")
195            .field("on_create", &self.on_create.is_some())
196            .field("before_acquire", &self.before_acquire.is_some())
197            .field("on_checkout", &self.on_checkout.is_some())
198            .field("on_checkin", &self.on_checkin.is_some())
199            .field("after_release", &self.after_release.is_some())
200            .field("on_destroy", &self.on_destroy.is_some())
201            .finish()
202    }
203}
204
impl<C> Default for LifecycleHooks<C> {
    /// All hooks disabled (every slot `None`).
    fn default() -> Self {
        Self {
            on_create: None,
            before_acquire: None,
            on_checkout: None,
            on_checkin: None,
            after_release: None,
            on_destroy: None,
        }
    }
}
217
218// ---------------------------------------------------------------------------
219// Pool metrics
220// ---------------------------------------------------------------------------
221
/// Snapshot of pool metrics.
///
/// Values are read from independent atomic counters (see
/// [`ConnPool::metrics`]), so a snapshot may be slightly inconsistent under
/// concurrency.
///
/// Marked `#[non_exhaustive]` so new counters can be added without breaking
/// downstream pattern matches or struct construction.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct PoolMetrics {
    /// Total number of connections currently held by the pool (idle + in-use).
    pub total: usize,
    /// Number of connections currently parked on the idle stack and ready to hand out.
    pub idle: usize,
    /// Number of connections currently checked out via [`PoolGuard`].
    pub in_use: usize,
    /// Number of tasks currently parked waiting for a connection.
    pub waiters: usize,
    /// Cumulative count of successful checkouts since the pool was created.
    pub total_checkouts: u64,
    /// Cumulative count of new connections opened by the pool.
    pub total_created: u64,
    /// Cumulative count of connections destroyed (closed, evicted, or failed validation).
    pub total_destroyed: u64,
    /// Cumulative count of `get()` calls that timed out before acquiring a connection.
    pub total_timeouts: u64,
}
246
247// ---------------------------------------------------------------------------
248// Internal types
249// ---------------------------------------------------------------------------
250
/// An idle connection parked in the pool, tagged with its (jittered)
/// expiration deadline.
struct IdleConn<C> {
    conn: C,
    // Absolute deadline; checked on checkout and by the maintenance task.
    expires_at: Instant,
}

/// Decrements `waiter_count` when dropped. Ensures the gauge stays accurate
/// even if the `get()` future is cancelled while parked on the waiter queue.
struct WaiterCountGuard<'a> {
    counter: &'a AtomicUsize,
}

impl Drop for WaiterCountGuard<'_> {
    fn drop(&mut self) {
        // Relaxed: this is a statistics gauge, not a synchronization point.
        self.counter.fetch_sub(1, Ordering::Relaxed);
    }
}

/// A parked `get()` call. The oneshot sender hands a connection directly to
/// the waiter; if the receiver was already dropped (timeout/cancellation),
/// `send` fails and the connection is re-routed by the caller.
struct Waiter<C> {
    tx: oneshot::Sender<C>,
}
271
272// ---------------------------------------------------------------------------
273// ConnPool
274// ---------------------------------------------------------------------------
275
/// Production-grade connection pool, generic over connection type `C`.
///
/// **Hook safety:** Lifecycle hooks must not call back into the pool (e.g.,
/// calling `get()` from inside a hook will deadlock). Hooks should be fast
/// and non-blocking.
pub struct ConnPool<C: Poolable> {
    config: ConnPoolConfig,
    hooks: LifecycleHooks<C>,
    // Idle connections, popped from the front on checkout.
    idle: Mutex<VecDeque<IdleConn<C>>>,
    // Parked get() calls, served FIFO on connection return.
    waiters: Mutex<VecDeque<Waiter<C>>>,
    // Gauges (see PoolMetrics for semantics).
    total_count: AtomicUsize,
    in_use_count: AtomicUsize,
    waiter_count: AtomicUsize,
    // Cumulative counters.
    total_checkouts: AtomicU64,
    total_created: AtomicU64,
    total_destroyed: AtomicU64,
    total_timeouts: AtomicU64,
    // Set once by drain(); checked on checkout and return paths.
    draining: AtomicBool,
    // Signalled when the last connection is destroyed during drain.
    drain_complete: Notify,
    // Tells the maintenance task to stop (sent at the end of drain()).
    shutdown_tx: mpsc::Sender<()>,
}
297
298impl<C: Poolable> std::fmt::Debug for ConnPool<C> {
299    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
300        f.debug_struct("ConnPool")
301            .field("config", &self.config)
302            .field("metrics", &self.metrics())
303            .field("draining", &self.draining.load(Ordering::Relaxed))
304            .finish()
305    }
306}
307
308impl<C: Poolable> ConnPool<C> {
309    /// Create a new connection pool and spawn the maintenance task.
310    pub async fn new(
311        config: ConnPoolConfig,
312        hooks: LifecycleHooks<C>,
313    ) -> Result<Arc<Self>, PoolError<C::Error>> {
314        let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
315
316        let pool = Arc::new(Self {
317            config: config.clone(),
318            hooks,
319            idle: Mutex::new(VecDeque::with_capacity(config.max_size)),
320            waiters: Mutex::new(VecDeque::new()),
321            total_count: AtomicUsize::new(0),
322            in_use_count: AtomicUsize::new(0),
323            waiter_count: AtomicUsize::new(0),
324            total_checkouts: AtomicU64::new(0),
325            total_created: AtomicU64::new(0),
326            total_destroyed: AtomicU64::new(0),
327            total_timeouts: AtomicU64::new(0),
328            draining: AtomicBool::new(false),
329            drain_complete: Notify::new(),
330            shutdown_tx,
331        });
332
333        for _ in 0..config.min_idle {
334            match pool.create_connection().await {
335                Ok(idle_conn) => {
336                    pool.idle.lock().await.push_back(idle_conn);
337                    pool.total_count.fetch_add(1, Ordering::Release);
338                }
339                Err(e) => {
340                    tracing::warn!("Failed to pre-fill connection: {e}");
341                }
342            }
343        }
344
345        {
346            let pool_ref = Arc::clone(&pool);
347            tokio::spawn(maintenance_task(pool_ref, shutdown_rx));
348        }
349
350        Ok(pool)
351    }
352
353    /// Check out a connection from the pool.
354    pub async fn get(self: &Arc<Self>) -> Result<PoolGuard<C>, PoolError<C::Error>> {
355        if self.draining.load(Ordering::Acquire) {
356            return Err(PoolError::Draining);
357        }
358
359        if let Some(ref hook) = self.hooks.before_acquire {
360            hook();
361        }
362
363        if let Some(conn) = self.try_get_idle().await {
364            self.in_use_count.fetch_add(1, Ordering::Release);
365            self.total_checkouts.fetch_add(1, Ordering::Relaxed);
366            if let Some(ref hook) = self.hooks.on_checkout {
367                hook(&conn);
368            }
369            return Ok(PoolGuard {
370                conn: Some(conn),
371                pool: Arc::clone(self),
372            });
373        }
374
375        if self.total_count.load(Ordering::Acquire) < self.config.max_size {
376            match self.create_and_track().await {
377                Ok(conn) => {
378                    self.in_use_count.fetch_add(1, Ordering::Release);
379                    self.total_checkouts.fetch_add(1, Ordering::Relaxed);
380                    if let Some(ref hook) = self.hooks.on_checkout {
381                        hook(&conn);
382                    }
383                    return Ok(PoolGuard {
384                        conn: Some(conn),
385                        pool: Arc::clone(self),
386                    });
387                }
388                Err(e) => {
389                    tracing::warn!("Failed to create new connection: {e}");
390                }
391            }
392        }
393
394        let (tx, rx) = oneshot::channel();
395        {
396            let mut waiters = self.waiters.lock().await;
397            waiters.push_back(Waiter { tx });
398            self.waiter_count.fetch_add(1, Ordering::Relaxed);
399        }
400
401        // Decrement waiter_count on every exit path, including future cancellation.
402        // Without this, a caller that drops the get() future (e.g. via tokio::select!
403        // or an outer timeout) would leak the counter, eventually saturating it.
404        let _waiter_guard = WaiterCountGuard {
405            counter: &self.waiter_count,
406        };
407
408        match tokio::time::timeout(self.config.checkout_timeout, rx).await {
409            Ok(Ok(conn)) => {
410                self.in_use_count.fetch_add(1, Ordering::Release);
411                self.total_checkouts.fetch_add(1, Ordering::Relaxed);
412                if let Some(ref hook) = self.hooks.on_checkout {
413                    hook(&conn);
414                }
415                Ok(PoolGuard {
416                    conn: Some(conn),
417                    pool: Arc::clone(self),
418                })
419            }
420            Ok(Err(_)) => Err(PoolError::Closed),
421            Err(_) => {
422                self.total_timeouts.fetch_add(1, Ordering::Relaxed);
423                // Clean up our dead waiter from the queue to prevent unbounded growth.
424                // The sender (tx) is dropped by the timeout, so return_conn_async will
425                // skip it, but we should remove the entry to free memory.
426                {
427                    let mut waiters = self.waiters.lock().await;
428                    // Remove waiters whose receiver has been dropped (tx.is_closed()).
429                    waiters.retain(|w| !w.tx.is_closed());
430                }
431                Err(PoolError::Timeout)
432            }
433        }
434    }
435
436    async fn try_get_idle(&self) -> Option<C> {
437        let mut idle = self.idle.lock().await;
438        while let Some(entry) = idle.pop_front() {
439            if Instant::now() >= entry.expires_at {
440                self.destroy_conn_stats();
441                if let Some(ref hook) = self.hooks.on_destroy {
442                    hook();
443                }
444                continue;
445            }
446            if self.config.test_on_checkout && entry.conn.has_pending_data() {
447                self.destroy_conn_stats();
448                if let Some(ref hook) = self.hooks.on_destroy {
449                    hook();
450                }
451                continue;
452            }
453            return Some(entry.conn);
454        }
455        None
456    }
457
458    async fn create_connection(&self) -> Result<IdleConn<C>, C::Error> {
459        let conn = C::connect(
460            &self.config.addr,
461            &self.config.user,
462            &self.config.password,
463            &self.config.database,
464        )
465        .await?;
466
467        self.total_created.fetch_add(1, Ordering::Relaxed);
468        if let Some(ref hook) = self.hooks.on_create {
469            hook(&conn);
470        }
471
472        let jitter = jittered_duration(self.config.max_lifetime, self.config.max_lifetime_jitter);
473        Ok(IdleConn {
474            conn,
475            expires_at: Instant::now() + jitter,
476        })
477    }
478
479    async fn create_and_track(&self) -> Result<C, PoolError<C::Error>> {
480        let prev = self.total_count.fetch_add(1, Ordering::Release);
481        if prev >= self.config.max_size {
482            self.total_count.fetch_sub(1, Ordering::Release);
483            return Err(PoolError::AtCapacity);
484        }
485
486        match self.create_connection().await {
487            Ok(idle_conn) => Ok(idle_conn.conn),
488            Err(e) => {
489                self.total_count.fetch_sub(1, Ordering::Release);
490                Err(PoolError::Connect(e))
491            }
492        }
493    }
494
495    fn return_conn(pool: Arc<Self>, conn: C) {
496        tokio::spawn(async move {
497            pool.return_conn_async(conn).await;
498        });
499    }
500
    /// Asynchronous half of connection return. Exactly one terminal path runs:
    /// destroy (corrupt / failed reset / draining), hand-off to a waiter, or
    /// park on the idle queue. The `after_release` hook fires on every path;
    /// `maybe_notify_drain` only on destroy paths (the others leave
    /// `total_count` unchanged).
    async fn return_conn_async(&self, conn: C) {
        // Decrement in_use immediately — the connection is no longer "in use"
        // regardless of whether it goes back to idle, to a waiter, or is destroyed.
        // This ensures the counter stays accurate even if reset() or a hook panics.
        self.in_use_count.fetch_sub(1, Ordering::Release);

        // Corrupted connection (unconsumed wire data): destroy, never reuse.
        if conn.has_pending_data() {
            self.destroy_conn_stats();
            if let Some(ref hook) = self.hooks.on_destroy {
                hook();
            }
            // A waiter may be parked at capacity; try to open a replacement
            // for the slot this destroy just freed.
            self.try_provision_for_waiter().await;
            if let Some(ref hook) = self.hooks.after_release {
                hook();
            }
            self.maybe_notify_drain();
            return;
        }

        // Reset connection state (DISCARD ALL) to prevent dirty state leaking.
        if !conn.reset().await {
            self.destroy_conn_stats();
            if let Some(ref hook) = self.hooks.on_destroy {
                hook();
            }
            self.try_provision_for_waiter().await;
            if let Some(ref hook) = self.hooks.after_release {
                hook();
            }
            self.maybe_notify_drain();
            return;
        }

        if let Some(ref hook) = self.hooks.on_checkin {
            hook(&conn);
        }

        // Pool is shutting down: destroy rather than re-park.
        if self.draining.load(Ordering::Acquire) {
            self.destroy_conn_stats();
            if let Some(ref hook) = self.hooks.on_destroy {
                hook();
            }
            if let Some(ref hook) = self.hooks.after_release {
                hook();
            }
            self.maybe_notify_drain();
            return;
        }

        // Prefer handing the connection directly to a parked waiter. A failed
        // send means that waiter timed out or was cancelled; reclaim the
        // connection from the send error and try the next waiter.
        let mut conn = conn;
        {
            let mut waiters = self.waiters.lock().await;
            while let Some(waiter) = waiters.pop_front() {
                match waiter.tx.send(conn) {
                    Ok(()) => {
                        // NOTE(review): this hook runs while the waiters lock
                        // is still held, unlike drain()/checkout which release
                        // locks before hooks — confirm hooks never touch the
                        // pool, or move this call below the lock scope.
                        if let Some(ref hook) = self.hooks.after_release {
                            hook();
                        }
                        return;
                    }
                    Err(returned_conn) => {
                        conn = returned_conn;
                        continue;
                    }
                }
            }
        }

        // No live waiter: park on the idle queue with a fresh jittered expiry.
        let jitter = jittered_duration(self.config.max_lifetime, self.config.max_lifetime_jitter);
        let mut idle = self.idle.lock().await;
        idle.push_back(IdleConn {
            conn,
            expires_at: Instant::now() + jitter,
        });
        if let Some(ref hook) = self.hooks.after_release {
            hook();
        }
    }
579
580    fn maybe_notify_drain(&self) {
581        if self.draining.load(Ordering::Acquire) && self.total_count.load(Ordering::Acquire) == 0 {
582            self.drain_complete.notify_one();
583        }
584    }
585
    /// After destroying a broken/unresettable connection on return, opportunistically
    /// provision a fresh connection for a parked waiter. Without this, waiters parked
    /// at `max_size` capacity would sleep until `checkout_timeout` even though the
    /// destroyed connection just freed a slot.
    ///
    /// No-ops if there are no waiters, the pool is draining, or capacity is full.
    /// Failures to create a replacement are logged and dropped: the waiter will
    /// time out normally, which is the same outcome as before this hook existed.
    async fn try_provision_for_waiter(&self) {
        if self.draining.load(Ordering::Acquire) {
            return;
        }
        // Cheap pre-check; the waiter may still vanish before the connection
        // is established — the hand-off loop below copes with that.
        let has_waiter = {
            let waiters = self.waiters.lock().await;
            !waiters.is_empty()
        };
        if !has_waiter {
            return;
        }

        // create_and_track enforces max_size, so this cannot oversubscribe.
        let mut conn = match self.create_and_track().await {
            Ok(c) => c,
            Err(e) => {
                tracing::warn!(
                    "failed to provision replacement for waiter after conn destroyed: {e}"
                );
                return;
            }
        };

        {
            let mut waiters = self.waiters.lock().await;
            while let Some(waiter) = waiters.pop_front() {
                match waiter.tx.send(conn) {
                    Ok(()) => return,
                    Err(returned) => {
                        // Receiver already dropped (timeout/cancel); reclaim
                        // the connection and try the next waiter.
                        conn = returned;
                        continue;
                    }
                }
            }
        }

        // No live waiter received the conn; park it on the idle stack.
        let jitter = jittered_duration(self.config.max_lifetime, self.config.max_lifetime_jitter);
        let mut idle = self.idle.lock().await;
        idle.push_back(IdleConn {
            conn,
            expires_at: Instant::now() + jitter,
        });
    }
637
    /// Bookkeeping for a destroyed connection: frees its capacity slot and
    /// bumps the cumulative destroyed counter. Does NOT run the `on_destroy`
    /// hook — callers invoke it themselves (ideally outside any pool lock).
    fn destroy_conn_stats(&self) {
        self.total_count.fetch_sub(1, Ordering::Release);
        self.total_destroyed.fetch_add(1, Ordering::Relaxed);
    }
642
    /// Get a snapshot of pool metrics.
    ///
    /// Note: metrics are read from atomic counters without a global lock,
    /// so values may be slightly inconsistent during high concurrency
    /// (e.g., `in_use` could briefly exceed `total`). Use `saturating_sub`
    /// for derived values.
    pub fn metrics(&self) -> PoolMetrics {
        let total = self.total_count.load(Ordering::Acquire);
        let in_use = self.in_use_count.load(Ordering::Acquire);
        PoolMetrics {
            total,
            // Derived, not tracked: idle = total - in_use (saturating, see above).
            idle: total.saturating_sub(in_use),
            in_use,
            waiters: self.waiter_count.load(Ordering::Relaxed),
            total_checkouts: self.total_checkouts.load(Ordering::Relaxed),
            total_created: self.total_created.load(Ordering::Relaxed),
            total_destroyed: self.total_destroyed.load(Ordering::Relaxed),
            total_timeouts: self.total_timeouts.load(Ordering::Relaxed),
        }
    }
663
664    /// Pre-populate the pool to a target number of idle connections.
665    /// Useful for warming up on startup to avoid first-request latency.
666    pub async fn warm_up(&self, target: usize) {
667        let current = self.metrics().total;
668        let to_create = target
669            .saturating_sub(current)
670            .min(self.config.max_size - current);
671        let mut created = 0;
672        for _ in 0..to_create {
673            match self.create_connection().await {
674                Ok(idle_conn) => {
675                    self.idle.lock().await.push_back(idle_conn);
676                    self.total_count.fetch_add(1, Ordering::Release);
677                    created += 1;
678                }
679                Err(e) => {
680                    tracing::warn!("warm_up: failed to create connection: {e}");
681                    break;
682                }
683            }
684        }
685        if created > 0 {
686            tracing::info!(created, target, "pool warm-up complete");
687        }
688    }
689
    /// Initiate graceful drain.
    ///
    /// Rejects new checkouts (`get` returns [`PoolError::Draining`]), destroys
    /// all idle connections, drops all parked waiters (their pending `get`
    /// calls resolve to [`PoolError::Closed`]), then blocks until every
    /// outstanding connection has been returned and destroyed. Finally signals
    /// the maintenance task to stop.
    ///
    /// NOTE(review): a concurrent `return_conn_async` that read `draining ==
    /// false` just before the store below can still park its connection on the
    /// idle queue after this method clears it; that connection would keep
    /// `total_count` above zero with no further notification, and this method
    /// would wait forever. Confirm callers quiesce in-flight returns before
    /// draining, or add a second idle sweep to the wait loop.
    pub async fn drain(&self) {
        self.draining.store(true, Ordering::Release);

        // Clear idle connections — release the lock BEFORE calling hooks
        // to prevent deadlocks if a hook interacts with the pool.
        let destroyed_count = {
            let mut idle = self.idle.lock().await;
            let count = idle.len();
            idle.clear();
            self.total_count.fetch_sub(count, Ordering::Release);
            self.total_destroyed
                .fetch_add(count as u64, Ordering::Relaxed);
            count
        };
        // Hooks called outside the lock.
        if destroyed_count > 0 {
            if let Some(ref hook) = self.hooks.on_destroy {
                for _ in 0..destroyed_count {
                    hook();
                }
            }
        }

        // Drop all parked waiters; dropping the oneshot senders makes each
        // pending get() observe a closed channel.
        {
            let mut waiters = self.waiters.lock().await;
            let waiter_count = waiters.len();
            waiters.clear();
            self.waiter_count.fetch_sub(waiter_count, Ordering::Relaxed);
        }

        // Wait for in-flight connections to be returned and destroyed.
        // Creating `notified` BEFORE re-checking the count means a permit
        // stored by maybe_notify_drain between the check and the await is
        // not lost.
        loop {
            let notified = self.drain_complete.notified();
            if self.total_count.load(Ordering::Acquire) == 0 {
                break;
            }
            notified.await;
        }

        // Stop the maintenance task.
        let _ = self.shutdown_tx.send(()).await;
        tracing::info!("Connection pool drained");
    }
732
    /// Current pool status string.
    ///
    /// Intended for logs; `waiters` and `total_checkouts` are omitted here —
    /// use [`ConnPool::metrics`] for the full snapshot.
    pub fn status(&self) -> String {
        let m = self.metrics();
        format!(
            "pool: total={} idle={} in_use={} created={} destroyed={} timeouts={}",
            m.total, m.idle, m.in_use, m.total_created, m.total_destroyed, m.total_timeouts
        )
    }
741}
742
743// ---------------------------------------------------------------------------
744// PoolGuard
745// ---------------------------------------------------------------------------
746
/// A checked-out connection that returns itself to the pool on drop.
pub struct PoolGuard<C: Poolable> {
    // `None` only after take() has moved the connection out.
    conn: Option<C>,
    pool: Arc<ConnPool<C>>,
}
752
753impl<C: Poolable + std::fmt::Debug> std::fmt::Debug for PoolGuard<C> {
754    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
755        f.debug_struct("PoolGuard")
756            .field("conn", &self.conn)
757            .finish_non_exhaustive()
758    }
759}
760
761impl<C: Poolable> PoolGuard<C> {
762    /// Borrow the wrapped connection. Panics if [`PoolGuard::take`] was already called.
763    pub fn conn(&self) -> &C {
764        self.conn
765            .as_ref()
766            .expect("PoolGuard: connection has been moved out via PoolGuard::take(); the guard is consumed by `take()` and must not be accessed afterwards (a logic bug in the caller)")
767    }
768
769    /// Mutably borrow the wrapped connection. Panics if [`PoolGuard::take`] was already called.
770    pub fn conn_mut(&mut self) -> &mut C {
771        self.conn
772            .as_mut()
773            .expect("PoolGuard: connection has been moved out via PoolGuard::take(); the guard is consumed by `take()` and must not be accessed afterwards (a logic bug in the caller)")
774    }
775
776    /// Take ownership of the connection, removing it from the pool.
777    /// After calling this, the guard must not be used — it will panic.
778    pub fn take(mut self) -> C {
779        let conn = self
780            .conn
781            .take()
782            .expect("PoolGuard: connection has been moved out via PoolGuard::take(); the guard is consumed by `take()` and must not be accessed afterwards (a logic bug in the caller)");
783        self.pool.in_use_count.fetch_sub(1, Ordering::Release);
784        self.pool.total_count.fetch_sub(1, Ordering::Release);
785        conn
786    }
787}
788
789impl<C: Poolable> Drop for PoolGuard<C> {
790    fn drop(&mut self) {
791        if let Some(conn) = self.conn.take() {
792            ConnPool::return_conn(Arc::clone(&self.pool), conn);
793        }
794    }
795}
796
797impl<C: Poolable> std::ops::Deref for PoolGuard<C> {
798    type Target = C;
799    fn deref(&self) -> &Self::Target {
800        self.conn
801            .as_ref()
802            .expect("PoolGuard: connection has been moved out via PoolGuard::take(); the guard is consumed by `take()` and must not be accessed afterwards (a logic bug in the caller)")
803    }
804}
805
806impl<C: Poolable> std::ops::DerefMut for PoolGuard<C> {
807    fn deref_mut(&mut self) -> &mut Self::Target {
808        self.conn
809            .as_mut()
810            .expect("PoolGuard: connection has been moved out via PoolGuard::take(); the guard is consumed by `take()` and must not be accessed afterwards (a logic bug in the caller)")
811    }
812}
813
814// ---------------------------------------------------------------------------
815// Maintenance task
816// ---------------------------------------------------------------------------
817
818async fn maintenance_task<C: Poolable>(
819    pool: Arc<ConnPool<C>>,
820    mut shutdown_rx: mpsc::Receiver<()>,
821) {
822    let mut interval = tokio::time::interval(pool.config.maintenance_interval);
823    interval.tick().await;
824    loop {
825        tokio::select! {
826            _ = interval.tick() => {}
827            _ = shutdown_rx.recv() => {
828                tracing::debug!("Maintenance task shutting down");
829                return;
830            }
831        }
832
833        if pool.draining.load(Ordering::Acquire) {
834            return;
835        }
836
837        {
838            let mut idle = pool.idle.lock().await;
839            let now = Instant::now();
840            let before = idle.len();
841            idle.retain(|entry| now < entry.expires_at);
842            let evicted = before - idle.len();
843            if evicted > 0 {
844                pool.total_count.fetch_sub(evicted, Ordering::Release);
845                pool.total_destroyed
846                    .fetch_add(evicted as u64, Ordering::Relaxed);
847                tracing::debug!("Evicted {evicted} expired connections");
848            }
849        }
850
851        let total = pool.total_count.load(Ordering::Acquire);
852        let in_use = pool.in_use_count.load(Ordering::Acquire);
853        let current_idle = total.saturating_sub(in_use);
854
855        if current_idle < pool.config.min_idle && total < pool.config.max_size {
856            let to_create = (pool.config.min_idle - current_idle).min(pool.config.max_size - total);
857            for _ in 0..to_create {
858                match pool.create_and_track().await {
859                    Ok(conn) => {
860                        let jitter = jittered_duration(
861                            pool.config.max_lifetime,
862                            pool.config.max_lifetime_jitter,
863                        );
864                        let mut idle = pool.idle.lock().await;
865                        idle.push_back(IdleConn {
866                            conn,
867                            expires_at: Instant::now() + jitter,
868                        });
869                    }
870                    Err(_) => break,
871                }
872            }
873        }
874    }
875}
876
877// ---------------------------------------------------------------------------
878// Helpers
879// ---------------------------------------------------------------------------
880
881fn jittered_duration(base: Duration, jitter: Duration) -> Duration {
882    if jitter.is_zero() {
883        return base;
884    }
885    let jitter_ms = jitter.as_millis() as u64;
886    let offset = fastrand_u64() % (jitter_ms * 2 + 1);
887    let jittered = base.as_millis() as i128 + offset as i128 - jitter_ms as i128;
888    Duration::from_millis(jittered.max(1) as u64)
889}
890
/// Minimal thread-local xorshift64 PRNG (non-cryptographic), seeded from the
/// wall clock. Used only to jitter connection lifetimes, so quality and
/// cross-thread reproducibility are irrelevant; zero is mapped to one because
/// xorshift is stuck at the all-zero state.
fn fastrand_u64() -> u64 {
    use std::cell::Cell;
    thread_local! {
        static STATE: Cell<u64> = Cell::new(
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .as_nanos() as u64,
        );
    }
    STATE.with(|state| {
        let mut value = state.get();
        value ^= value << 13;
        value ^= value >> 7;
        value ^= value << 17;
        // Escape the absorbing zero state (possible if the seed was zero).
        let next = if value == 0 { 1 } else { value };
        state.set(next);
        next
    })
}