1use std::collections::VecDeque;
4use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7
8use tokio::sync::{mpsc, oneshot, Mutex, Notify};
9
/// A connection type that the pool can open, hand out, and recycle.
///
/// Implementors must be `Send + 'static` because the pool moves connections
/// into spawned tasks when they are returned.
pub trait Poolable: Send + 'static {
    /// Error reported when opening a connection fails.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Opens a brand-new connection using the given endpoint and credentials.
    ///
    /// The pool passes the `addr`, `user`, `password` and `database` values
    /// from its configuration unchanged.
    fn connect(
        addr: &str,
        user: &str,
        password: &str,
        database: &str,
    ) -> impl std::future::Future<Output = Result<Self, Self::Error>> + Send
    where
        Self: Sized;

    /// Reports whether the connection still has unread data.
    ///
    /// The pool treats `true` as "not reusable": such a connection is
    /// discarded when it is returned, and (if checkout testing is enabled)
    /// when it is taken from the idle queue.
    fn has_pending_data(&self) -> bool;

    /// Restores the connection to a clean state before it is reused.
    ///
    /// Returning `false` makes the pool discard the connection instead of
    /// recycling it. The default implementation does nothing and reports
    /// success.
    fn reset(&self) -> impl std::future::Future<Output = bool> + Send {
        std::future::ready(true)
    }
}
44
/// Errors produced by the connection pool.
///
/// `E` is the error type of the underlying connection.
#[non_exhaustive]
#[derive(Debug)]
pub enum PoolError<E>
where
    E: std::error::Error,
{
    /// Opening a new connection failed; carries the connection's own error.
    Connect(E),
    /// The pool is draining and no longer hands out connections.
    Draining,
    /// No connection became available within the checkout timeout.
    Timeout,
    /// The wait for a connection was abandoned by the pool (the hand-off
    /// channel was dropped before a connection was sent).
    Closed,
    /// The pool already tracks `max_size` connections.
    AtCapacity,
}
64
65impl<E: std::error::Error> std::fmt::Display for PoolError<E> {
66 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67 match self {
68 Self::Connect(e) => write!(f, "connection error: {e}"),
69 Self::Draining => write!(f, "pool is draining"),
70 Self::Timeout => write!(f, "checkout timeout"),
71 Self::Closed => write!(f, "pool closed"),
72 Self::AtCapacity => write!(f, "pool at max capacity"),
73 }
74 }
75}
76
77impl<E: std::error::Error + 'static> std::error::Error for PoolError<E> {
78 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
79 match self {
80 Self::Connect(e) => Some(e),
81 _ => None,
82 }
83 }
84}
85
/// Settings that control how the pool connects and how many connections it keeps.
#[non_exhaustive]
#[derive(Clone)]
pub struct ConnPoolConfig {
    /// Address handed to `Poolable::connect`.
    pub addr: String,
    /// User name handed to `Poolable::connect`.
    pub user: String,
    /// Password handed to `Poolable::connect`; redacted in `Debug` output.
    pub password: String,
    /// Database name handed to `Poolable::connect`.
    pub database: String,
    /// Number of connections created up front and replenished by maintenance.
    pub min_idle: usize,
    /// Upper bound on the number of connections the pool tracks.
    pub max_size: usize,
    /// Base lifetime of a connection before it is considered expired.
    pub max_lifetime: Duration,
    /// Random spread (plus or minus) applied to `max_lifetime` per connection.
    pub max_lifetime_jitter: Duration,
    /// How long `get` waits for a connection once the pool is exhausted.
    pub checkout_timeout: Duration,
    /// Period of the background maintenance task.
    pub maintenance_interval: Duration,
    /// When set, idle connections reporting pending data are discarded at checkout.
    pub test_on_checkout: bool,
}
121
122impl std::fmt::Debug for ConnPoolConfig {
123 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
124 f.debug_struct("ConnPoolConfig")
125 .field("addr", &self.addr)
126 .field("user", &self.user)
127 .field("password", &"<redacted>")
128 .field("database", &self.database)
129 .field("min_idle", &self.min_idle)
130 .field("max_size", &self.max_size)
131 .field("max_lifetime", &self.max_lifetime)
132 .field("max_lifetime_jitter", &self.max_lifetime_jitter)
133 .field("checkout_timeout", &self.checkout_timeout)
134 .field("maintenance_interval", &self.maintenance_interval)
135 .field("test_on_checkout", &self.test_on_checkout)
136 .finish()
137 }
138}
139
140impl Default for ConnPoolConfig {
141 fn default() -> Self {
142 Self {
143 addr: String::new(),
144 user: String::new(),
145 password: String::new(),
146 database: String::new(),
147 min_idle: 1,
148 max_size: 10,
149 max_lifetime: Duration::from_secs(30 * 60),
150 max_lifetime_jitter: Duration::from_secs(60),
151 checkout_timeout: Duration::from_secs(5),
152 maintenance_interval: Duration::from_secs(10),
153 test_on_checkout: true,
154 }
155 }
156}
157
/// Optional callback that receives a shared reference to a connection.
type ConnHook<C> = Option<Box<dyn Fn(&C) + Send + Sync + 'static>>;
/// Optional callback that takes no arguments.
type Hook = Option<Box<dyn Fn() + Send + Sync + 'static>>;
166
167#[non_exhaustive]
177pub struct LifecycleHooks<C> {
178 pub on_create: ConnHook<C>,
180 pub before_acquire: Hook,
182 pub on_checkout: ConnHook<C>,
184 pub on_checkin: ConnHook<C>,
186 pub after_release: Hook,
188 pub on_destroy: Hook,
190}
191
192impl<C> std::fmt::Debug for LifecycleHooks<C> {
193 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194 f.debug_struct("LifecycleHooks")
195 .field("on_create", &self.on_create.is_some())
196 .field("before_acquire", &self.before_acquire.is_some())
197 .field("on_checkout", &self.on_checkout.is_some())
198 .field("on_checkin", &self.on_checkin.is_some())
199 .field("after_release", &self.after_release.is_some())
200 .field("on_destroy", &self.on_destroy.is_some())
201 .finish()
202 }
203}
204
205impl<C> Default for LifecycleHooks<C> {
206 fn default() -> Self {
207 Self {
208 on_create: None,
209 before_acquire: None,
210 on_checkout: None,
211 on_checkin: None,
212 after_release: None,
213 on_destroy: None,
214 }
215 }
216}
217
/// Point-in-time snapshot of the pool's counters.
///
/// The individual values are read one after another, so under concurrent
/// activity they may not be mutually consistent.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub struct PoolMetrics {
    /// Connections currently tracked by the pool.
    pub total: usize,
    /// Tracked connections that are not checked out (`total - in_use`).
    pub idle: usize,
    /// Connections currently checked out.
    pub in_use: usize,
    /// Callers currently waiting for a connection.
    pub waiters: usize,
    /// Successful checkouts since the pool was created.
    pub total_checkouts: u64,
    /// Connections opened since the pool was created.
    pub total_created: u64,
    /// Connections discarded since the pool was created.
    pub total_destroyed: u64,
    /// Checkouts that ended in a timeout since the pool was created.
    pub total_timeouts: u64,
}
246
/// An idle connection together with the deadline after which it must be discarded.
struct IdleConn<C> {
    /// The parked connection.
    conn: C,
    /// Instant at or after which the connection counts as expired.
    expires_at: Instant,
}
255
/// Scope guard that decrements a waiter counter when dropped, so the count is
/// released on every exit path of a waiting checkout.
struct WaiterCountGuard<'a> {
    /// Counter to decrement on drop; the guard never increments it.
    counter: &'a AtomicUsize,
}

impl Drop for WaiterCountGuard<'_> {
    fn drop(&mut self) {
        let Self { counter } = self;
        counter.fetch_sub(1, Ordering::Relaxed);
    }
}
267
268struct Waiter<C> {
269 tx: oneshot::Sender<C>,
270}
271
272pub struct ConnPool<C: Poolable> {
282 config: ConnPoolConfig,
283 hooks: LifecycleHooks<C>,
284 idle: Mutex<VecDeque<IdleConn<C>>>,
285 waiters: Mutex<VecDeque<Waiter<C>>>,
286 total_count: AtomicUsize,
287 in_use_count: AtomicUsize,
288 waiter_count: AtomicUsize,
289 total_checkouts: AtomicU64,
290 total_created: AtomicU64,
291 total_destroyed: AtomicU64,
292 total_timeouts: AtomicU64,
293 draining: AtomicBool,
294 drain_complete: Notify,
295 shutdown_tx: mpsc::Sender<()>,
296}
297
298impl<C: Poolable> std::fmt::Debug for ConnPool<C> {
299 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
300 f.debug_struct("ConnPool")
301 .field("config", &self.config)
302 .field("metrics", &self.metrics())
303 .field("draining", &self.draining.load(Ordering::Relaxed))
304 .finish()
305 }
306}
307
308impl<C: Poolable> ConnPool<C> {
309 pub async fn new(
311 config: ConnPoolConfig,
312 hooks: LifecycleHooks<C>,
313 ) -> Result<Arc<Self>, PoolError<C::Error>> {
314 let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
315
316 let pool = Arc::new(Self {
317 config: config.clone(),
318 hooks,
319 idle: Mutex::new(VecDeque::with_capacity(config.max_size)),
320 waiters: Mutex::new(VecDeque::new()),
321 total_count: AtomicUsize::new(0),
322 in_use_count: AtomicUsize::new(0),
323 waiter_count: AtomicUsize::new(0),
324 total_checkouts: AtomicU64::new(0),
325 total_created: AtomicU64::new(0),
326 total_destroyed: AtomicU64::new(0),
327 total_timeouts: AtomicU64::new(0),
328 draining: AtomicBool::new(false),
329 drain_complete: Notify::new(),
330 shutdown_tx,
331 });
332
333 for _ in 0..config.min_idle {
334 match pool.create_connection().await {
335 Ok(idle_conn) => {
336 pool.idle.lock().await.push_back(idle_conn);
337 pool.total_count.fetch_add(1, Ordering::Release);
338 }
339 Err(e) => {
340 tracing::warn!("Failed to pre-fill connection: {e}");
341 }
342 }
343 }
344
345 {
346 let pool_ref = Arc::clone(&pool);
347 tokio::spawn(maintenance_task(pool_ref, shutdown_rx));
348 }
349
350 Ok(pool)
351 }
352
353 pub async fn get(self: &Arc<Self>) -> Result<PoolGuard<C>, PoolError<C::Error>> {
355 if self.draining.load(Ordering::Acquire) {
356 return Err(PoolError::Draining);
357 }
358
359 if let Some(ref hook) = self.hooks.before_acquire {
360 hook();
361 }
362
363 if let Some(conn) = self.try_get_idle().await {
364 self.in_use_count.fetch_add(1, Ordering::Release);
365 self.total_checkouts.fetch_add(1, Ordering::Relaxed);
366 if let Some(ref hook) = self.hooks.on_checkout {
367 hook(&conn);
368 }
369 return Ok(PoolGuard {
370 conn: Some(conn),
371 pool: Arc::clone(self),
372 });
373 }
374
375 if self.total_count.load(Ordering::Acquire) < self.config.max_size {
376 match self.create_and_track().await {
377 Ok(conn) => {
378 self.in_use_count.fetch_add(1, Ordering::Release);
379 self.total_checkouts.fetch_add(1, Ordering::Relaxed);
380 if let Some(ref hook) = self.hooks.on_checkout {
381 hook(&conn);
382 }
383 return Ok(PoolGuard {
384 conn: Some(conn),
385 pool: Arc::clone(self),
386 });
387 }
388 Err(e) => {
389 tracing::warn!("Failed to create new connection: {e}");
390 }
391 }
392 }
393
394 let (tx, rx) = oneshot::channel();
395 {
396 let mut waiters = self.waiters.lock().await;
397 waiters.push_back(Waiter { tx });
398 self.waiter_count.fetch_add(1, Ordering::Relaxed);
399 }
400
401 let _waiter_guard = WaiterCountGuard {
405 counter: &self.waiter_count,
406 };
407
408 match tokio::time::timeout(self.config.checkout_timeout, rx).await {
409 Ok(Ok(conn)) => {
410 self.in_use_count.fetch_add(1, Ordering::Release);
411 self.total_checkouts.fetch_add(1, Ordering::Relaxed);
412 if let Some(ref hook) = self.hooks.on_checkout {
413 hook(&conn);
414 }
415 Ok(PoolGuard {
416 conn: Some(conn),
417 pool: Arc::clone(self),
418 })
419 }
420 Ok(Err(_)) => Err(PoolError::Closed),
421 Err(_) => {
422 self.total_timeouts.fetch_add(1, Ordering::Relaxed);
423 {
427 let mut waiters = self.waiters.lock().await;
428 waiters.retain(|w| !w.tx.is_closed());
430 }
431 Err(PoolError::Timeout)
432 }
433 }
434 }
435
436 async fn try_get_idle(&self) -> Option<C> {
437 let mut idle = self.idle.lock().await;
438 while let Some(entry) = idle.pop_front() {
439 if Instant::now() >= entry.expires_at {
440 self.destroy_conn_stats();
441 if let Some(ref hook) = self.hooks.on_destroy {
442 hook();
443 }
444 continue;
445 }
446 if self.config.test_on_checkout && entry.conn.has_pending_data() {
447 self.destroy_conn_stats();
448 if let Some(ref hook) = self.hooks.on_destroy {
449 hook();
450 }
451 continue;
452 }
453 return Some(entry.conn);
454 }
455 None
456 }
457
458 async fn create_connection(&self) -> Result<IdleConn<C>, C::Error> {
459 let conn = C::connect(
460 &self.config.addr,
461 &self.config.user,
462 &self.config.password,
463 &self.config.database,
464 )
465 .await?;
466
467 self.total_created.fetch_add(1, Ordering::Relaxed);
468 if let Some(ref hook) = self.hooks.on_create {
469 hook(&conn);
470 }
471
472 let jitter = jittered_duration(self.config.max_lifetime, self.config.max_lifetime_jitter);
473 Ok(IdleConn {
474 conn,
475 expires_at: Instant::now() + jitter,
476 })
477 }
478
479 async fn create_and_track(&self) -> Result<C, PoolError<C::Error>> {
480 let prev = self.total_count.fetch_add(1, Ordering::Release);
481 if prev >= self.config.max_size {
482 self.total_count.fetch_sub(1, Ordering::Release);
483 return Err(PoolError::AtCapacity);
484 }
485
486 match self.create_connection().await {
487 Ok(idle_conn) => Ok(idle_conn.conn),
488 Err(e) => {
489 self.total_count.fetch_sub(1, Ordering::Release);
490 Err(PoolError::Connect(e))
491 }
492 }
493 }
494
495 fn return_conn(pool: Arc<Self>, conn: C) {
496 tokio::spawn(async move {
497 pool.return_conn_async(conn).await;
498 });
499 }
500
501 async fn return_conn_async(&self, conn: C) {
502 self.in_use_count.fetch_sub(1, Ordering::Release);
506
507 if conn.has_pending_data() {
508 self.destroy_conn_stats();
509 if let Some(ref hook) = self.hooks.on_destroy {
510 hook();
511 }
512 if let Some(ref hook) = self.hooks.after_release {
513 hook();
514 }
515 self.maybe_notify_drain();
516 return;
517 }
518
519 if !conn.reset().await {
521 self.destroy_conn_stats();
522 if let Some(ref hook) = self.hooks.on_destroy {
523 hook();
524 }
525 if let Some(ref hook) = self.hooks.after_release {
526 hook();
527 }
528 self.maybe_notify_drain();
529 return;
530 }
531
532 if let Some(ref hook) = self.hooks.on_checkin {
533 hook(&conn);
534 }
535
536 if self.draining.load(Ordering::Acquire) {
537 self.destroy_conn_stats();
538 if let Some(ref hook) = self.hooks.on_destroy {
539 hook();
540 }
541 if let Some(ref hook) = self.hooks.after_release {
542 hook();
543 }
544 self.maybe_notify_drain();
545 return;
546 }
547
548 let mut conn = conn;
549 {
550 let mut waiters = self.waiters.lock().await;
551 while let Some(waiter) = waiters.pop_front() {
552 match waiter.tx.send(conn) {
553 Ok(()) => {
554 if let Some(ref hook) = self.hooks.after_release {
555 hook();
556 }
557 return;
558 }
559 Err(returned_conn) => {
560 conn = returned_conn;
561 continue;
562 }
563 }
564 }
565 }
566
567 let jitter = jittered_duration(self.config.max_lifetime, self.config.max_lifetime_jitter);
568 let mut idle = self.idle.lock().await;
569 idle.push_back(IdleConn {
570 conn,
571 expires_at: Instant::now() + jitter,
572 });
573 if let Some(ref hook) = self.hooks.after_release {
574 hook();
575 }
576 }
577
578 fn maybe_notify_drain(&self) {
579 if self.draining.load(Ordering::Acquire) && self.total_count.load(Ordering::Acquire) == 0 {
580 self.drain_complete.notify_one();
581 }
582 }
583
584 fn destroy_conn_stats(&self) {
585 self.total_count.fetch_sub(1, Ordering::Release);
586 self.total_destroyed.fetch_add(1, Ordering::Relaxed);
587 }
588
589 pub fn metrics(&self) -> PoolMetrics {
596 let total = self.total_count.load(Ordering::Acquire);
597 let in_use = self.in_use_count.load(Ordering::Acquire);
598 PoolMetrics {
599 total,
600 idle: total.saturating_sub(in_use),
601 in_use,
602 waiters: self.waiter_count.load(Ordering::Relaxed),
603 total_checkouts: self.total_checkouts.load(Ordering::Relaxed),
604 total_created: self.total_created.load(Ordering::Relaxed),
605 total_destroyed: self.total_destroyed.load(Ordering::Relaxed),
606 total_timeouts: self.total_timeouts.load(Ordering::Relaxed),
607 }
608 }
609
610 pub async fn warm_up(&self, target: usize) {
613 let current = self.metrics().total;
614 let to_create = target
615 .saturating_sub(current)
616 .min(self.config.max_size - current);
617 let mut created = 0;
618 for _ in 0..to_create {
619 match self.create_connection().await {
620 Ok(idle_conn) => {
621 self.idle.lock().await.push_back(idle_conn);
622 self.total_count.fetch_add(1, Ordering::Release);
623 created += 1;
624 }
625 Err(e) => {
626 tracing::warn!("warm_up: failed to create connection: {e}");
627 break;
628 }
629 }
630 }
631 if created > 0 {
632 tracing::info!(created, target, "pool warm-up complete");
633 }
634 }
635
636 pub async fn drain(&self) {
638 self.draining.store(true, Ordering::Release);
639
640 let destroyed_count = {
643 let mut idle = self.idle.lock().await;
644 let count = idle.len();
645 idle.clear();
646 self.total_count.fetch_sub(count, Ordering::Release);
647 self.total_destroyed
648 .fetch_add(count as u64, Ordering::Relaxed);
649 count
650 };
651 if destroyed_count > 0 {
653 if let Some(ref hook) = self.hooks.on_destroy {
654 for _ in 0..destroyed_count {
655 hook();
656 }
657 }
658 }
659
660 {
661 let mut waiters = self.waiters.lock().await;
662 let waiter_count = waiters.len();
663 waiters.clear();
664 self.waiter_count.fetch_sub(waiter_count, Ordering::Relaxed);
665 }
666
667 loop {
668 let notified = self.drain_complete.notified();
669 if self.total_count.load(Ordering::Acquire) == 0 {
670 break;
671 }
672 notified.await;
673 }
674
675 let _ = self.shutdown_tx.send(()).await;
676 tracing::info!("Connection pool drained");
677 }
678
679 pub fn status(&self) -> String {
681 let m = self.metrics();
682 format!(
683 "pool: total={} idle={} in_use={} created={} destroyed={} timeouts={}",
684 m.total, m.idle, m.in_use, m.total_created, m.total_destroyed, m.total_timeouts
685 )
686 }
687}
688
689pub struct PoolGuard<C: Poolable> {
695 conn: Option<C>,
696 pool: Arc<ConnPool<C>>,
697}
698
699impl<C: Poolable + std::fmt::Debug> std::fmt::Debug for PoolGuard<C> {
700 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
701 f.debug_struct("PoolGuard")
702 .field("conn", &self.conn)
703 .finish_non_exhaustive()
704 }
705}
706
707impl<C: Poolable> PoolGuard<C> {
708 pub fn conn(&self) -> &C {
710 self.conn
711 .as_ref()
712 .expect("PoolGuard: connection has been moved out via PoolGuard::take(); the guard is consumed by `take()` and must not be accessed afterwards (a logic bug in the caller)")
713 }
714
715 pub fn conn_mut(&mut self) -> &mut C {
717 self.conn
718 .as_mut()
719 .expect("PoolGuard: connection has been moved out via PoolGuard::take(); the guard is consumed by `take()` and must not be accessed afterwards (a logic bug in the caller)")
720 }
721
722 pub fn take(mut self) -> C {
725 let conn = self
726 .conn
727 .take()
728 .expect("PoolGuard: connection has been moved out via PoolGuard::take(); the guard is consumed by `take()` and must not be accessed afterwards (a logic bug in the caller)");
729 self.pool.in_use_count.fetch_sub(1, Ordering::Release);
730 self.pool.total_count.fetch_sub(1, Ordering::Release);
731 conn
732 }
733}
734
735impl<C: Poolable> Drop for PoolGuard<C> {
736 fn drop(&mut self) {
737 if let Some(conn) = self.conn.take() {
738 ConnPool::return_conn(Arc::clone(&self.pool), conn);
739 }
740 }
741}
742
743impl<C: Poolable> std::ops::Deref for PoolGuard<C> {
744 type Target = C;
745 fn deref(&self) -> &Self::Target {
746 self.conn
747 .as_ref()
748 .expect("PoolGuard: connection has been moved out via PoolGuard::take(); the guard is consumed by `take()` and must not be accessed afterwards (a logic bug in the caller)")
749 }
750}
751
752impl<C: Poolable> std::ops::DerefMut for PoolGuard<C> {
753 fn deref_mut(&mut self) -> &mut Self::Target {
754 self.conn
755 .as_mut()
756 .expect("PoolGuard: connection has been moved out via PoolGuard::take(); the guard is consumed by `take()` and must not be accessed afterwards (a logic bug in the caller)")
757 }
758}
759
760async fn maintenance_task<C: Poolable>(
765 pool: Arc<ConnPool<C>>,
766 mut shutdown_rx: mpsc::Receiver<()>,
767) {
768 let mut interval = tokio::time::interval(pool.config.maintenance_interval);
769 interval.tick().await;
770 loop {
771 tokio::select! {
772 _ = interval.tick() => {}
773 _ = shutdown_rx.recv() => {
774 tracing::debug!("Maintenance task shutting down");
775 return;
776 }
777 }
778
779 if pool.draining.load(Ordering::Acquire) {
780 return;
781 }
782
783 {
784 let mut idle = pool.idle.lock().await;
785 let now = Instant::now();
786 let before = idle.len();
787 idle.retain(|entry| now < entry.expires_at);
788 let evicted = before - idle.len();
789 if evicted > 0 {
790 pool.total_count.fetch_sub(evicted, Ordering::Release);
791 pool.total_destroyed
792 .fetch_add(evicted as u64, Ordering::Relaxed);
793 tracing::debug!("Evicted {evicted} expired connections");
794 }
795 }
796
797 let total = pool.total_count.load(Ordering::Acquire);
798 let in_use = pool.in_use_count.load(Ordering::Acquire);
799 let current_idle = total.saturating_sub(in_use);
800
801 if current_idle < pool.config.min_idle && total < pool.config.max_size {
802 let to_create = (pool.config.min_idle - current_idle).min(pool.config.max_size - total);
803 for _ in 0..to_create {
804 match pool.create_and_track().await {
805 Ok(conn) => {
806 let jitter = jittered_duration(
807 pool.config.max_lifetime,
808 pool.config.max_lifetime_jitter,
809 );
810 let mut idle = pool.idle.lock().await;
811 idle.push_back(IdleConn {
812 conn,
813 expires_at: Instant::now() + jitter,
814 });
815 }
816 Err(_) => break,
817 }
818 }
819 }
820 }
821}
822
823fn jittered_duration(base: Duration, jitter: Duration) -> Duration {
828 if jitter.is_zero() {
829 return base;
830 }
831 let jitter_ms = jitter.as_millis() as u64;
832 let offset = fastrand_u64() % (jitter_ms * 2 + 1);
833 let jittered = base.as_millis() as i128 + offset as i128 - jitter_ms as i128;
834 Duration::from_millis(jittered.max(1) as u64)
835}
836
/// Returns the next value of a small per-thread xorshift generator.
///
/// Each thread's state is seeded from the wall clock (nanoseconds since the
/// Unix epoch, or zero if the clock is before the epoch). The state is never
/// allowed to stay at zero, because a zero state would repeat forever. Not
/// suitable for anything security-sensitive.
fn fastrand_u64() -> u64 {
    use std::cell::Cell;
    use std::time::{SystemTime, UNIX_EPOCH};

    thread_local! {
        static STATE: Cell<u64> = Cell::new(
            SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .map_or(0, |since_epoch| since_epoch.as_nanos() as u64)
        );
    }

    STATE.with(|cell| {
        let seed = cell.get();
        let step1 = seed ^ (seed << 13);
        let step2 = step1 ^ (step1 >> 7);
        let step3 = step2 ^ (step2 << 17);
        let next = if step3 == 0 { 1 } else { step3 };
        cell.set(next);
        next
    })
}