1use std::collections::VecDeque;
4use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7
8use tokio::sync::{mpsc, oneshot, Mutex, Notify};
9
/// A connection type that [`ConnPool`] can open, validate and recycle.
///
/// Implementors must be `Send + 'static`: connections are moved into spawned
/// tasks when they are handed back to the pool.
pub trait Poolable: Send + 'static {
    /// Error produced when [`Poolable::connect`] fails.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Opens a new connection with the address and credentials taken from
    /// [`ConnPoolConfig`].
    fn connect(
        addr: &str,
        user: &str,
        password: &str,
        database: &str,
    ) -> impl std::future::Future<Output = Result<Self, Self::Error>> + Send
    where
        Self: Sized;

    /// Reports whether unread data is buffered on the connection.
    ///
    /// The pool destroys a connection for which this returns `true` when it is
    /// returned, and also at checkout when `test_on_checkout` is enabled.
    fn has_pending_data(&self) -> bool;

    /// Restores the connection to a clean state before it is reused.
    ///
    /// Returning `false` makes the pool destroy the connection instead of
    /// recycling it. The default does no work and always reports success.
    fn reset(&self) -> impl std::future::Future<Output = bool> + Send {
        std::future::ready(true)
    }
}
44
/// Errors returned by [`ConnPool`] operations.
#[derive(Debug)]
#[non_exhaustive]
pub enum PoolError<E: std::error::Error> {
    /// Opening a new connection failed with the driver's error.
    Connect(E),
    /// The pool is draining and accepts no new checkouts.
    Draining,
    /// No connection became available within `checkout_timeout`.
    Timeout,
    /// The wait queue was torn down while this request was waiting.
    Closed,
    /// The pool already holds `max_size` connections.
    AtCapacity,
}

impl<E: std::error::Error> std::fmt::Display for PoolError<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            // Only this variant needs formatting; the rest are fixed strings.
            Self::Connect(cause) => return write!(f, "connection error: {cause}"),
            Self::Draining => "pool is draining",
            Self::Timeout => "checkout timeout",
            Self::Closed => "pool closed",
            Self::AtCapacity => "pool at max capacity",
        };
        f.write_str(message)
    }
}

impl<E: std::error::Error + 'static> std::error::Error for PoolError<E> {
    /// Exposes the driver error behind [`PoolError::Connect`]; other variants
    /// have no underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        if let Self::Connect(cause) = self {
            Some(cause)
        } else {
            None
        }
    }
}
85
/// Configuration for [`ConnPool`].
///
/// The struct is `#[non_exhaustive]`, so code outside this crate starts from
/// [`ConnPoolConfig::default`] and overrides individual fields.
#[derive(Clone)]
#[non_exhaustive]
pub struct ConnPoolConfig {
    /// Address passed to [`Poolable::connect`].
    pub addr: String,
    /// User name passed to [`Poolable::connect`].
    pub user: String,
    /// Password passed to [`Poolable::connect`]; never shown by `Debug`.
    pub password: String,
    /// Database name passed to [`Poolable::connect`].
    pub database: String,
    /// Number of idle connections the pool tries to keep open.
    pub min_idle: usize,
    /// Hard cap on the number of connections the pool tracks.
    pub max_size: usize,
    /// Base lifetime of a connection before it is evicted.
    pub max_lifetime: Duration,
    /// Each connection's lifetime is offset randomly by up to ± this amount.
    pub max_lifetime_jitter: Duration,
    /// How long `get` waits in the queue before failing with a timeout.
    pub checkout_timeout: Duration,
    /// Period of the background maintenance pass.
    pub maintenance_interval: Duration,
    /// Whether idle connections are probed for pending data at checkout.
    pub test_on_checkout: bool,
}

impl std::fmt::Debug for ConnPoolConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Destructure so a newly added field cannot silently go unprinted.
        let Self {
            addr,
            user,
            password: _,
            database,
            min_idle,
            max_size,
            max_lifetime,
            max_lifetime_jitter,
            checkout_timeout,
            maintenance_interval,
            test_on_checkout,
        } = self;
        let mut out = f.debug_struct("ConnPoolConfig");
        out.field("addr", addr);
        out.field("user", user);
        // The password itself is never written to logs.
        out.field("password", &"<redacted>");
        out.field("database", database);
        out.field("min_idle", min_idle);
        out.field("max_size", max_size);
        out.field("max_lifetime", max_lifetime);
        out.field("max_lifetime_jitter", max_lifetime_jitter);
        out.field("checkout_timeout", checkout_timeout);
        out.field("maintenance_interval", maintenance_interval);
        out.field("test_on_checkout", test_on_checkout);
        out.finish()
    }
}

impl Default for ConnPoolConfig {
    /// Empty connection parameters, 1..=10 connections, a 30 minute lifetime
    /// with ±60s jitter, a 5s checkout timeout, maintenance every 10s, and
    /// checkout probing enabled.
    fn default() -> Self {
        let secs = Duration::from_secs;
        ConnPoolConfig {
            addr: String::new(),
            user: String::new(),
            password: String::new(),
            database: String::new(),
            min_idle: 1,
            max_size: 10,
            max_lifetime: secs(30 * 60),
            max_lifetime_jitter: secs(60),
            checkout_timeout: secs(5),
            maintenance_interval: secs(10),
            test_on_checkout: true,
        }
    }
}
157
/// Optional callback that receives the connection the event concerns.
type ConnHook<C> = Option<Box<dyn Fn(&C) + Send + Sync>>;
/// Optional callback that takes no arguments.
type Hook = Option<Box<dyn Fn() + Send + Sync>>;

/// Optional callbacks that [`ConnPool`] invokes at connection lifecycle events.
///
/// Hooks are called inline by the pool operation that triggers them. They
/// should be cheap and should not block.
#[non_exhaustive]
pub struct LifecycleHooks<C> {
    /// Called right after a new connection has been opened.
    pub on_create: ConnHook<C>,
    /// Called at the start of each `get` on a pool that is not draining.
    pub before_acquire: Hook,
    /// Called when a connection is handed to a `get` caller.
    pub on_checkout: ConnHook<C>,
    /// Called when a returned connection passes its pending-data check and reset.
    pub on_checkin: ConnHook<C>,
    /// Called after a returned connection has been recycled or destroyed.
    pub after_release: Hook,
    /// Called when the pool destroys a connection.
    pub on_destroy: Hook,
}

impl<C> std::fmt::Debug for LifecycleHooks<C> {
    /// Shows only whether each hook is installed; closures are not printable.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let LifecycleHooks {
            on_create,
            before_acquire,
            on_checkout,
            on_checkin,
            after_release,
            on_destroy,
        } = self;
        let mut out = f.debug_struct("LifecycleHooks");
        out.field("on_create", &on_create.is_some());
        out.field("before_acquire", &before_acquire.is_some());
        out.field("on_checkout", &on_checkout.is_some());
        out.field("on_checkin", &on_checkin.is_some());
        out.field("after_release", &after_release.is_some());
        out.field("on_destroy", &on_destroy.is_some());
        out.finish()
    }
}

impl<C> Default for LifecycleHooks<C> {
    /// No hooks installed.
    ///
    /// This is written by hand because `#[derive(Default)]` would add a
    /// `C: Default` bound.
    fn default() -> Self {
        LifecycleHooks {
            on_create: None,
            before_acquire: None,
            on_checkout: None,
            on_checkin: None,
            after_release: None,
            on_destroy: None,
        }
    }
}
217
/// Point-in-time snapshot of the pool's counters, produced by
/// [`ConnPool::metrics`].
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct PoolMetrics {
    /// Connections counted against `max_size`.
    pub total: usize,
    /// Connections not checked out: `total - in_use`, saturating at zero.
    pub idle: usize,
    /// Connections currently checked out.
    pub in_use: usize,
    /// `get` calls currently queued for a connection.
    pub waiters: usize,
    /// Successful checkouts since the pool was created.
    pub total_checkouts: u64,
    /// Connections opened since the pool was created.
    pub total_created: u64,
    /// Connections destroyed since the pool was created.
    pub total_destroyed: u64,
    /// Checkouts that failed with a timeout.
    pub total_timeouts: u64,
}
246
/// A connection parked in the idle queue.
///
/// `expires_at` is the instant from which the connection must no longer be
/// handed out.
struct IdleConn<C> {
    conn: C,
    expires_at: Instant,
}
255
/// Scope guard that decrements the pool's waiter counter exactly once when
/// dropped.
///
/// Because the decrement happens in `Drop`, the count stays correct no matter
/// how a waiting `get` call exits: success, timeout, closure or cancellation.
struct WaiterCountGuard<'a> {
    counter: &'a AtomicUsize,
}

impl Drop for WaiterCountGuard<'_> {
    fn drop(&mut self) {
        // The counter only feeds metrics, so relaxed ordering is sufficient.
        let _previous = self.counter.fetch_sub(1, Ordering::Relaxed);
    }
}
267
268struct Waiter<C> {
269 tx: oneshot::Sender<C>,
270}
271
272pub struct ConnPool<C: Poolable> {
282 config: ConnPoolConfig,
283 hooks: LifecycleHooks<C>,
284 idle: Mutex<VecDeque<IdleConn<C>>>,
285 waiters: Mutex<VecDeque<Waiter<C>>>,
286 total_count: AtomicUsize,
287 in_use_count: AtomicUsize,
288 waiter_count: AtomicUsize,
289 total_checkouts: AtomicU64,
290 total_created: AtomicU64,
291 total_destroyed: AtomicU64,
292 total_timeouts: AtomicU64,
293 draining: AtomicBool,
294 drain_complete: Notify,
295 shutdown_tx: mpsc::Sender<()>,
296}
297
298impl<C: Poolable> std::fmt::Debug for ConnPool<C> {
299 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
300 f.debug_struct("ConnPool")
301 .field("config", &self.config)
302 .field("metrics", &self.metrics())
303 .field("draining", &self.draining.load(Ordering::Relaxed))
304 .finish()
305 }
306}
307
308impl<C: Poolable> ConnPool<C> {
309 pub async fn new(
311 config: ConnPoolConfig,
312 hooks: LifecycleHooks<C>,
313 ) -> Result<Arc<Self>, PoolError<C::Error>> {
314 let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
315
316 let pool = Arc::new(Self {
317 config: config.clone(),
318 hooks,
319 idle: Mutex::new(VecDeque::with_capacity(config.max_size)),
320 waiters: Mutex::new(VecDeque::new()),
321 total_count: AtomicUsize::new(0),
322 in_use_count: AtomicUsize::new(0),
323 waiter_count: AtomicUsize::new(0),
324 total_checkouts: AtomicU64::new(0),
325 total_created: AtomicU64::new(0),
326 total_destroyed: AtomicU64::new(0),
327 total_timeouts: AtomicU64::new(0),
328 draining: AtomicBool::new(false),
329 drain_complete: Notify::new(),
330 shutdown_tx,
331 });
332
333 for _ in 0..config.min_idle {
334 match pool.create_connection().await {
335 Ok(idle_conn) => {
336 pool.idle.lock().await.push_back(idle_conn);
337 pool.total_count.fetch_add(1, Ordering::Release);
338 }
339 Err(e) => {
340 tracing::warn!("Failed to pre-fill connection: {e}");
341 }
342 }
343 }
344
345 {
346 let pool_ref = Arc::clone(&pool);
347 tokio::spawn(maintenance_task(pool_ref, shutdown_rx));
348 }
349
350 Ok(pool)
351 }
352
353 pub async fn get(self: &Arc<Self>) -> Result<PoolGuard<C>, PoolError<C::Error>> {
355 if self.draining.load(Ordering::Acquire) {
356 return Err(PoolError::Draining);
357 }
358
359 if let Some(ref hook) = self.hooks.before_acquire {
360 hook();
361 }
362
363 if let Some(conn) = self.try_get_idle().await {
364 self.in_use_count.fetch_add(1, Ordering::Release);
365 self.total_checkouts.fetch_add(1, Ordering::Relaxed);
366 if let Some(ref hook) = self.hooks.on_checkout {
367 hook(&conn);
368 }
369 return Ok(PoolGuard {
370 conn: Some(conn),
371 pool: Arc::clone(self),
372 });
373 }
374
375 if self.total_count.load(Ordering::Acquire) < self.config.max_size {
376 match self.create_and_track().await {
377 Ok(conn) => {
378 self.in_use_count.fetch_add(1, Ordering::Release);
379 self.total_checkouts.fetch_add(1, Ordering::Relaxed);
380 if let Some(ref hook) = self.hooks.on_checkout {
381 hook(&conn);
382 }
383 return Ok(PoolGuard {
384 conn: Some(conn),
385 pool: Arc::clone(self),
386 });
387 }
388 Err(e) => {
389 tracing::warn!("Failed to create new connection: {e}");
390 }
391 }
392 }
393
394 let (tx, rx) = oneshot::channel();
395 {
396 let mut waiters = self.waiters.lock().await;
397 waiters.push_back(Waiter { tx });
398 self.waiter_count.fetch_add(1, Ordering::Relaxed);
399 }
400
401 let _waiter_guard = WaiterCountGuard {
405 counter: &self.waiter_count,
406 };
407
408 match tokio::time::timeout(self.config.checkout_timeout, rx).await {
409 Ok(Ok(conn)) => {
410 self.in_use_count.fetch_add(1, Ordering::Release);
411 self.total_checkouts.fetch_add(1, Ordering::Relaxed);
412 if let Some(ref hook) = self.hooks.on_checkout {
413 hook(&conn);
414 }
415 Ok(PoolGuard {
416 conn: Some(conn),
417 pool: Arc::clone(self),
418 })
419 }
420 Ok(Err(_)) => Err(PoolError::Closed),
421 Err(_) => {
422 self.total_timeouts.fetch_add(1, Ordering::Relaxed);
423 {
427 let mut waiters = self.waiters.lock().await;
428 waiters.retain(|w| !w.tx.is_closed());
430 }
431 Err(PoolError::Timeout)
432 }
433 }
434 }
435
436 async fn try_get_idle(&self) -> Option<C> {
437 let mut idle = self.idle.lock().await;
438 while let Some(entry) = idle.pop_front() {
439 if Instant::now() >= entry.expires_at {
440 self.destroy_conn_stats();
441 if let Some(ref hook) = self.hooks.on_destroy {
442 hook();
443 }
444 continue;
445 }
446 if self.config.test_on_checkout && entry.conn.has_pending_data() {
447 self.destroy_conn_stats();
448 if let Some(ref hook) = self.hooks.on_destroy {
449 hook();
450 }
451 continue;
452 }
453 return Some(entry.conn);
454 }
455 None
456 }
457
458 async fn create_connection(&self) -> Result<IdleConn<C>, C::Error> {
459 let conn = C::connect(
460 &self.config.addr,
461 &self.config.user,
462 &self.config.password,
463 &self.config.database,
464 )
465 .await?;
466
467 self.total_created.fetch_add(1, Ordering::Relaxed);
468 if let Some(ref hook) = self.hooks.on_create {
469 hook(&conn);
470 }
471
472 let jitter = jittered_duration(self.config.max_lifetime, self.config.max_lifetime_jitter);
473 Ok(IdleConn {
474 conn,
475 expires_at: Instant::now() + jitter,
476 })
477 }
478
479 async fn create_and_track(&self) -> Result<C, PoolError<C::Error>> {
480 let prev = self.total_count.fetch_add(1, Ordering::Release);
481 if prev >= self.config.max_size {
482 self.total_count.fetch_sub(1, Ordering::Release);
483 return Err(PoolError::AtCapacity);
484 }
485
486 match self.create_connection().await {
487 Ok(idle_conn) => Ok(idle_conn.conn),
488 Err(e) => {
489 self.total_count.fetch_sub(1, Ordering::Release);
490 Err(PoolError::Connect(e))
491 }
492 }
493 }
494
495 fn return_conn(pool: Arc<Self>, conn: C) {
496 tokio::spawn(async move {
497 pool.return_conn_async(conn).await;
498 });
499 }
500
501 async fn return_conn_async(&self, conn: C) {
502 self.in_use_count.fetch_sub(1, Ordering::Release);
506
507 if conn.has_pending_data() {
508 self.destroy_conn_stats();
509 if let Some(ref hook) = self.hooks.on_destroy {
510 hook();
511 }
512 self.try_provision_for_waiter().await;
513 if let Some(ref hook) = self.hooks.after_release {
514 hook();
515 }
516 self.maybe_notify_drain();
517 return;
518 }
519
520 if !conn.reset().await {
522 self.destroy_conn_stats();
523 if let Some(ref hook) = self.hooks.on_destroy {
524 hook();
525 }
526 self.try_provision_for_waiter().await;
527 if let Some(ref hook) = self.hooks.after_release {
528 hook();
529 }
530 self.maybe_notify_drain();
531 return;
532 }
533
534 if let Some(ref hook) = self.hooks.on_checkin {
535 hook(&conn);
536 }
537
538 if self.draining.load(Ordering::Acquire) {
539 self.destroy_conn_stats();
540 if let Some(ref hook) = self.hooks.on_destroy {
541 hook();
542 }
543 if let Some(ref hook) = self.hooks.after_release {
544 hook();
545 }
546 self.maybe_notify_drain();
547 return;
548 }
549
550 let mut conn = conn;
551 {
552 let mut waiters = self.waiters.lock().await;
553 while let Some(waiter) = waiters.pop_front() {
554 match waiter.tx.send(conn) {
555 Ok(()) => {
556 if let Some(ref hook) = self.hooks.after_release {
557 hook();
558 }
559 return;
560 }
561 Err(returned_conn) => {
562 conn = returned_conn;
563 continue;
564 }
565 }
566 }
567 }
568
569 let jitter = jittered_duration(self.config.max_lifetime, self.config.max_lifetime_jitter);
570 let mut idle = self.idle.lock().await;
571 idle.push_back(IdleConn {
572 conn,
573 expires_at: Instant::now() + jitter,
574 });
575 if let Some(ref hook) = self.hooks.after_release {
576 hook();
577 }
578 }
579
580 fn maybe_notify_drain(&self) {
581 if self.draining.load(Ordering::Acquire) && self.total_count.load(Ordering::Acquire) == 0 {
582 self.drain_complete.notify_one();
583 }
584 }
585
586 async fn try_provision_for_waiter(&self) {
595 if self.draining.load(Ordering::Acquire) {
596 return;
597 }
598 let has_waiter = {
599 let waiters = self.waiters.lock().await;
600 !waiters.is_empty()
601 };
602 if !has_waiter {
603 return;
604 }
605
606 let mut conn = match self.create_and_track().await {
607 Ok(c) => c,
608 Err(e) => {
609 tracing::warn!(
610 "failed to provision replacement for waiter after conn destroyed: {e}"
611 );
612 return;
613 }
614 };
615
616 {
617 let mut waiters = self.waiters.lock().await;
618 while let Some(waiter) = waiters.pop_front() {
619 match waiter.tx.send(conn) {
620 Ok(()) => return,
621 Err(returned) => {
622 conn = returned;
623 continue;
624 }
625 }
626 }
627 }
628
629 let jitter = jittered_duration(self.config.max_lifetime, self.config.max_lifetime_jitter);
631 let mut idle = self.idle.lock().await;
632 idle.push_back(IdleConn {
633 conn,
634 expires_at: Instant::now() + jitter,
635 });
636 }
637
638 fn destroy_conn_stats(&self) {
639 self.total_count.fetch_sub(1, Ordering::Release);
640 self.total_destroyed.fetch_add(1, Ordering::Relaxed);
641 }
642
643 pub fn metrics(&self) -> PoolMetrics {
650 let total = self.total_count.load(Ordering::Acquire);
651 let in_use = self.in_use_count.load(Ordering::Acquire);
652 PoolMetrics {
653 total,
654 idle: total.saturating_sub(in_use),
655 in_use,
656 waiters: self.waiter_count.load(Ordering::Relaxed),
657 total_checkouts: self.total_checkouts.load(Ordering::Relaxed),
658 total_created: self.total_created.load(Ordering::Relaxed),
659 total_destroyed: self.total_destroyed.load(Ordering::Relaxed),
660 total_timeouts: self.total_timeouts.load(Ordering::Relaxed),
661 }
662 }
663
664 pub async fn warm_up(&self, target: usize) {
667 let current = self.metrics().total;
668 let to_create = target
669 .saturating_sub(current)
670 .min(self.config.max_size - current);
671 let mut created = 0;
672 for _ in 0..to_create {
673 match self.create_connection().await {
674 Ok(idle_conn) => {
675 self.idle.lock().await.push_back(idle_conn);
676 self.total_count.fetch_add(1, Ordering::Release);
677 created += 1;
678 }
679 Err(e) => {
680 tracing::warn!("warm_up: failed to create connection: {e}");
681 break;
682 }
683 }
684 }
685 if created > 0 {
686 tracing::info!(created, target, "pool warm-up complete");
687 }
688 }
689
690 pub async fn drain(&self) {
692 self.draining.store(true, Ordering::Release);
693
694 let destroyed_count = {
697 let mut idle = self.idle.lock().await;
698 let count = idle.len();
699 idle.clear();
700 self.total_count.fetch_sub(count, Ordering::Release);
701 self.total_destroyed
702 .fetch_add(count as u64, Ordering::Relaxed);
703 count
704 };
705 if destroyed_count > 0 {
707 if let Some(ref hook) = self.hooks.on_destroy {
708 for _ in 0..destroyed_count {
709 hook();
710 }
711 }
712 }
713
714 {
715 let mut waiters = self.waiters.lock().await;
716 let waiter_count = waiters.len();
717 waiters.clear();
718 self.waiter_count.fetch_sub(waiter_count, Ordering::Relaxed);
719 }
720
721 loop {
722 let notified = self.drain_complete.notified();
723 if self.total_count.load(Ordering::Acquire) == 0 {
724 break;
725 }
726 notified.await;
727 }
728
729 let _ = self.shutdown_tx.send(()).await;
730 tracing::info!("Connection pool drained");
731 }
732
733 pub fn status(&self) -> String {
735 let m = self.metrics();
736 format!(
737 "pool: total={} idle={} in_use={} created={} destroyed={} timeouts={}",
738 m.total, m.idle, m.in_use, m.total_created, m.total_destroyed, m.total_timeouts
739 )
740 }
741}
742
743pub struct PoolGuard<C: Poolable> {
749 conn: Option<C>,
750 pool: Arc<ConnPool<C>>,
751}
752
753impl<C: Poolable + std::fmt::Debug> std::fmt::Debug for PoolGuard<C> {
754 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
755 f.debug_struct("PoolGuard")
756 .field("conn", &self.conn)
757 .finish_non_exhaustive()
758 }
759}
760
761impl<C: Poolable> PoolGuard<C> {
762 pub fn conn(&self) -> &C {
764 self.conn
765 .as_ref()
766 .expect("PoolGuard: connection has been moved out via PoolGuard::take(); the guard is consumed by `take()` and must not be accessed afterwards (a logic bug in the caller)")
767 }
768
769 pub fn conn_mut(&mut self) -> &mut C {
771 self.conn
772 .as_mut()
773 .expect("PoolGuard: connection has been moved out via PoolGuard::take(); the guard is consumed by `take()` and must not be accessed afterwards (a logic bug in the caller)")
774 }
775
776 pub fn take(mut self) -> C {
779 let conn = self
780 .conn
781 .take()
782 .expect("PoolGuard: connection has been moved out via PoolGuard::take(); the guard is consumed by `take()` and must not be accessed afterwards (a logic bug in the caller)");
783 self.pool.in_use_count.fetch_sub(1, Ordering::Release);
784 self.pool.total_count.fetch_sub(1, Ordering::Release);
785 conn
786 }
787}
788
789impl<C: Poolable> Drop for PoolGuard<C> {
790 fn drop(&mut self) {
791 if let Some(conn) = self.conn.take() {
792 ConnPool::return_conn(Arc::clone(&self.pool), conn);
793 }
794 }
795}
796
797impl<C: Poolable> std::ops::Deref for PoolGuard<C> {
798 type Target = C;
799 fn deref(&self) -> &Self::Target {
800 self.conn
801 .as_ref()
802 .expect("PoolGuard: connection has been moved out via PoolGuard::take(); the guard is consumed by `take()` and must not be accessed afterwards (a logic bug in the caller)")
803 }
804}
805
806impl<C: Poolable> std::ops::DerefMut for PoolGuard<C> {
807 fn deref_mut(&mut self) -> &mut Self::Target {
808 self.conn
809 .as_mut()
810 .expect("PoolGuard: connection has been moved out via PoolGuard::take(); the guard is consumed by `take()` and must not be accessed afterwards (a logic bug in the caller)")
811 }
812}
813
814async fn maintenance_task<C: Poolable>(
819 pool: Arc<ConnPool<C>>,
820 mut shutdown_rx: mpsc::Receiver<()>,
821) {
822 let mut interval = tokio::time::interval(pool.config.maintenance_interval);
823 interval.tick().await;
824 loop {
825 tokio::select! {
826 _ = interval.tick() => {}
827 _ = shutdown_rx.recv() => {
828 tracing::debug!("Maintenance task shutting down");
829 return;
830 }
831 }
832
833 if pool.draining.load(Ordering::Acquire) {
834 return;
835 }
836
837 {
838 let mut idle = pool.idle.lock().await;
839 let now = Instant::now();
840 let before = idle.len();
841 idle.retain(|entry| now < entry.expires_at);
842 let evicted = before - idle.len();
843 if evicted > 0 {
844 pool.total_count.fetch_sub(evicted, Ordering::Release);
845 pool.total_destroyed
846 .fetch_add(evicted as u64, Ordering::Relaxed);
847 tracing::debug!("Evicted {evicted} expired connections");
848 }
849 }
850
851 let total = pool.total_count.load(Ordering::Acquire);
852 let in_use = pool.in_use_count.load(Ordering::Acquire);
853 let current_idle = total.saturating_sub(in_use);
854
855 if current_idle < pool.config.min_idle && total < pool.config.max_size {
856 let to_create = (pool.config.min_idle - current_idle).min(pool.config.max_size - total);
857 for _ in 0..to_create {
858 match pool.create_and_track().await {
859 Ok(conn) => {
860 let jitter = jittered_duration(
861 pool.config.max_lifetime,
862 pool.config.max_lifetime_jitter,
863 );
864 let mut idle = pool.idle.lock().await;
865 idle.push_back(IdleConn {
866 conn,
867 expires_at: Instant::now() + jitter,
868 });
869 }
870 Err(_) => break,
871 }
872 }
873 }
874 }
875}
876
877fn jittered_duration(base: Duration, jitter: Duration) -> Duration {
882 if jitter.is_zero() {
883 return base;
884 }
885 let jitter_ms = jitter.as_millis() as u64;
886 let offset = fastrand_u64() % (jitter_ms * 2 + 1);
887 let jittered = base.as_millis() as i128 + offset as i128 - jitter_ms as i128;
888 Duration::from_millis(jittered.max(1) as u64)
889}
890
/// Cheap, non-cryptographic per-thread pseudo-random `u64` (xorshift64).
///
/// Used only to spread connection lifetimes. The result is never zero.
fn fastrand_u64() -> u64 {
    use std::cell::Cell;
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};

    thread_local! {
        static STATE: Cell<u64> = Cell::new({
            let nanos = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .as_nanos() as u64;
            // A clock-only seed gives threads started in the same clock tick
            // identical sequences, which defeats the jitter. Each `RandomState`
            // is randomly keyed, so mixing the time through it yields distinct
            // per-thread seeds.
            let mut hasher = RandomState::new().build_hasher();
            hasher.write_u64(nanos);
            hasher.finish()
        });
    }

    STATE.with(|state| {
        let mut x = state.get();
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        // Zero is xorshift's fixed point, so never let the state reach it.
        if x == 0 {
            x = 1;
        }
        state.set(x);
        x
    })
}