mobc/
lib.rs

1//! A generic connection pool with async/await support.
2//!
3//! Opening a new database connection every time one is needed is both
4//! inefficient and can lead to resource exhaustion under high traffic
5//! conditions. A connection pool maintains a set of open connections to a
6//! database, handing them out for repeated use.
7//!
8//! mobc is agnostic to the connection type it is managing. Implementors of the
9//! `Manager` trait provide the database-specific logic to create and
10//! check the health of connections.
11//!
12//! # Example
13//!
14//! Using an imaginary "foodb" database.
15//!
16//! ```rust
17//!use mobc::{Manager, Pool, async_trait};
18//!
19//!#[derive(Debug)]
20//!struct FooError;
21//!
22//!struct FooConnection;
23//!
24//!impl FooConnection {
25//!    async fn query(&self) -> String {
26//!        "nori".to_string()
27//!    }
28//!}
29//!
30//!struct FooManager;
31//!
32//!#[async_trait]
33//!impl Manager for FooManager {
34//!    type Connection = FooConnection;
35//!    type Error = FooError;
36//!
37//!    async fn connect(&self) -> Result<Self::Connection, Self::Error> {
38//!        Ok(FooConnection)
39//!    }
40//!
41//!    async fn check(&self, conn: Self::Connection) -> Result<Self::Connection, Self::Error> {
42//!        Ok(conn)
43//!    }
44//!}
45//!
46//!#[tokio::main]
47//!async fn main() {
48//!    let pool = Pool::builder().max_open(15).build(FooManager);
49//!    let num: usize = 10000;
50//!    let (tx, mut rx) = tokio::sync::mpsc::channel::<()>(16);
51//!
52//!    for _ in 0..num {
53//!        let pool = pool.clone();
54//!        let mut tx = tx.clone();
55//!        tokio::spawn(async move {
56//!            let conn = pool.get().await.unwrap();
57//!            let name = conn.query().await;
58//!            assert_eq!(name, "nori".to_string());
59//!            tx.send(()).await.unwrap();
60//!        });
61//!    }
62//!
63//!    for _ in 0..num {
64//!        rx.recv().await.unwrap();
65//!    }
66//!}
67//! ```
68//!
69//! # Metrics
70//!
//! Mobc uses the `metrics` crate to expose the following metrics:
//!
//! 1. Active Connections - The number of connections in use.
//! 1. Idle Connections - The number of connections that are not being used.
//! 1. Wait Count - The number of tasks currently waiting for a connection.
//! 1. Wait Duration - A cumulative histogram of the wait time for a connection.
77//!
78
79#![cfg_attr(feature = "docs", feature(doc_cfg))]
80#![warn(missing_docs)]
81#![recursion_limit = "256"]
82mod config;
83
84mod conn;
85mod error;
86mod metrics_utils;
87#[cfg(feature = "unstable")]
88#[cfg_attr(feature = "docs", doc(cfg(unstable)))]
89pub mod runtime;
90mod spawn;
91mod time;
92
93pub use error::Error;
94
95pub use async_trait::async_trait;
96pub use config::Builder;
97use config::{Config, InternalConfig, ShareConfig};
98use conn::{ActiveConn, ConnState, IdleConn};
99use futures_channel::mpsc::{self, Receiver, Sender};
100use futures_util::lock::{Mutex, MutexGuard};
101use futures_util::select;
102use futures_util::FutureExt;
103use futures_util::SinkExt;
104use futures_util::StreamExt;
105use metrics::gauge;
106use metrics_utils::DurationHistogramGuard;
107pub use spawn::spawn;
108use std::fmt;
109use std::future::Future;
110use std::ops::{Deref, DerefMut};
111use std::sync::{
112    atomic::{AtomicU64, Ordering},
113    Arc, Weak,
114};
115use std::time::{Duration, Instant};
116#[doc(hidden)]
117pub use time::{delay_for, interval};
118use tokio::sync::{OwnedSemaphorePermit, Semaphore};
119
120use crate::metrics_utils::{GaugeGuard, IDLE_CONNECTIONS, WAIT_COUNT, WAIT_DURATION};
121
/// Number of semaphore permits used when `max_open` is configured as 0 ("unlimited").
const CONNECTION_REQUEST_QUEUE_SIZE: usize = 10_000;
123
124#[async_trait]
125/// A trait which provides connection-specific functionality.
126pub trait Manager: Send + Sync + 'static {
127    /// The connection type this manager deals with.
128    type Connection: Send + 'static;
129    /// The error type returned by `Connection`s.
130    type Error: Send + Sync + 'static;
131
132    /// Spawns a new asynchronous task.
133    fn spawn_task<T>(&self, task: T)
134    where
135        T: Future + Send + 'static,
136        T::Output: Send + 'static,
137    {
138        spawn(task);
139    }
140
141    /// Attempts to create a new connection.
142    async fn connect(&self) -> Result<Self::Connection, Self::Error>;
143
144    /// Determines if the connection is still connected to the database when check-out.
145    ///
146    /// A standard implementation would check if a simple query like `SELECT 1`
147    /// succeeds.
148    async fn check(&self, conn: Self::Connection) -> Result<Self::Connection, Self::Error>;
149
150    /// *Quickly* determines a connection is still valid when check-in.
151    #[inline]
152    fn validate(&self, _conn: &mut Self::Connection) -> bool {
153        true
154    }
155}
156
157struct SharedPool<M: Manager> {
158    config: ShareConfig,
159    manager: M,
160    internals: Mutex<PoolInternals<M::Connection>>,
161    state: PoolState,
162    semaphore: Arc<Semaphore>,
163}
164
165struct PoolInternals<C> {
166    config: InternalConfig,
167    free_conns: Vec<IdleConn<C>>,
168    wait_duration: Duration,
169    cleaner_ch: Option<Sender<()>>,
170}
171
/// Lock-free counters describing the pool.
struct PoolState {
    /// Number of open connections; the `Arc` is shared with each connection's `ConnState`.
    num_open: Arc<AtomicU64>,
    /// Connections closed by the cleaner for exceeding `max_lifetime`.
    max_lifetime_closed: AtomicU64,
    /// Connections closed because of `max_idle`; the `Arc` is shared with each `ConnState`.
    max_idle_closed: Arc<AtomicU64>,
    /// Callers currently waiting for a connection slot.
    wait_count: AtomicU64,
}
178
179impl<C> Drop for PoolInternals<C> {
180    fn drop(&mut self) {
181        log::debug!("Pool internal drop");
182    }
183}
184
185/// A generic connection pool.
186pub struct Pool<M: Manager>(Arc<SharedPool<M>>);
187
188/// Returns a new `Pool` referencing the same state as `self`.
189impl<M: Manager> Clone for Pool<M> {
190    fn clone(&self) -> Self {
191        Pool(self.0.clone())
192    }
193}
194
195impl<M: Manager> fmt::Debug for Pool<M> {
196    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
197        write!(f, "Pool")
198    }
199}
200
/// Information about the state of a `Pool`.
#[derive(Debug)]
pub struct State {
    /// Maximum number of open connections to the database (0 means unlimited).
    pub max_open: u64,

    // Pool Status
    /// The number of established connections both in use and idle.
    pub connections: u64,
    /// The number of connections currently in use.
    pub in_use: u64,
    /// The number of idle connections.
    pub idle: u64,

    // Counters
    /// The number of callers currently waiting for a connection.
    ///
    /// Unlike the other counters this is a gauge, not a running total. It is
    /// incremented when a caller starts waiting for a connection slot and
    /// decremented once that wait ends.
    pub wait_count: u64,
    /// The total time blocked waiting for a new connection.
    pub wait_duration: Duration,
    /// The total number of connections closed due to `max_idle`.
    pub max_idle_closed: u64,
    /// The total number of connections closed due to `max_lifetime`.
    pub max_lifetime_closed: u64,
}
225
226impl<M: Manager> Drop for Pool<M> {
227    fn drop(&mut self) {}
228}
229
230impl<M: Manager> Pool<M> {
231    /// Creates a new connection pool with a default configuration.
232    pub fn new(manager: M) -> Pool<M> {
233        Pool::builder().build(manager)
234    }
235
236    /// Returns a builder type to configure a new pool.
237    pub fn builder() -> Builder<M> {
238        Builder::new()
239    }
240
241    /// Sets the maximum number of connections managed by the pool.
242    ///
243    /// 0 means unlimited, defaults to 10.
244    pub async fn set_max_open_conns(&self, n: u64) {
245        let mut internals = self.0.internals.lock().await;
246        internals.config.max_open = n;
247        if n > 0 && internals.config.max_idle > n {
248            drop(internals);
249            self.set_max_idle_conns(n).await;
250        }
251    }
252
253    /// Sets the maximum idle connection count maintained by the pool.
254    ///
255    /// The pool will maintain at most this many idle connections
256    /// at all times, while respecting the value of `max_open`.
257    ///
258    /// 0 means unlimited (limited only by `max_open`), defaults to 2.
259    pub async fn set_max_idle_conns(&self, n: u64) {
260        let mut internals = self.0.internals.lock().await;
261        internals.config.max_idle =
262            if internals.config.max_open > 0 && n > internals.config.max_open {
263                internals.config.max_open
264            } else {
265                n
266            };
267
268        let max_idle = internals.config.max_idle as usize;
269        // Treat max_idle == 0 as unlimited
270        if max_idle > 0 && internals.free_conns.len() > max_idle {
271            internals.free_conns.truncate(max_idle);
272        }
273    }
274
275    /// Sets the maximum lifetime of connections in the pool.
276    ///
277    /// Expired connections may be closed lazily before reuse.
278    ///
279    /// None means reuse forever.
280    /// Defaults to None.
281    ///
282    /// # Panics
283    ///
284    /// Panics if `max_lifetime` is the zero `Duration`.
285    pub async fn set_conn_max_lifetime(&self, max_lifetime: Option<Duration>) {
286        assert_ne!(
287            max_lifetime,
288            Some(Duration::from_secs(0)),
289            "max_lifetime must be positive"
290        );
291        let mut internals = self.0.internals.lock().await;
292        internals.config.max_lifetime = max_lifetime;
293        if let Some(lifetime) = max_lifetime {
294            match internals.config.max_lifetime {
295                Some(prev) if lifetime < prev && internals.cleaner_ch.is_some() => {
296                    // FIXME
297                    let _ = internals.cleaner_ch.as_mut().unwrap().send(()).await;
298                }
299                _ => (),
300            }
301        }
302
303        if max_lifetime.is_some()
304            && self.0.state.num_open.load(Ordering::Relaxed) > 0
305            && internals.cleaner_ch.is_none()
306        {
307            log::debug!("run connection cleaner");
308            let shared1 = Arc::downgrade(&self.0);
309            let clean_rate = self.0.config.clean_rate;
310            let (cleaner_ch_sender, cleaner_ch) = mpsc::channel(1);
311            internals.cleaner_ch = Some(cleaner_ch_sender);
312            self.0.manager.spawn_task(async move {
313                connection_cleaner(shared1, cleaner_ch, clean_rate).await;
314            });
315        }
316    }
317
318    pub(crate) fn new_inner(manager: M, config: Config) -> Self {
319        let max_open = if config.max_open == 0 {
320            CONNECTION_REQUEST_QUEUE_SIZE
321        } else {
322            config.max_open as usize
323        };
324
325        gauge!(IDLE_CONNECTIONS).set(0.0);
326
327        let (share_config, internal_config) = config.split();
328        let internals = Mutex::new(PoolInternals {
329            config: internal_config,
330            free_conns: Vec::new(),
331            wait_duration: Duration::from_secs(0),
332            cleaner_ch: None,
333        });
334
335        let pool_state = PoolState {
336            num_open: Arc::new(AtomicU64::new(0)),
337            max_lifetime_closed: AtomicU64::new(0),
338            wait_count: AtomicU64::new(0),
339            max_idle_closed: Arc::new(AtomicU64::new(0)),
340        };
341
342        let shared = Arc::new(SharedPool {
343            config: share_config,
344            manager,
345            internals,
346            semaphore: Arc::new(Semaphore::new(max_open)),
347            state: pool_state,
348        });
349
350        Pool(shared)
351    }
352
353    /// Returns a single connection by either opening a new connection
354    /// or returning an existing connection from the connection pool. Conn will
355    /// block until either a connection is returned or timeout.
356    pub async fn get(&self) -> Result<Connection<M>, Error<M::Error>> {
357        match self.0.config.get_timeout {
358            Some(duration) => self.get_timeout(duration).await,
359            None => self.inner_get_with_retries().await,
360        }
361    }
362
363    /// Retrieves a connection from the pool, waiting for at most `timeout`
364    ///
365    /// The given timeout will be used instead of the configured connection
366    /// timeout.
367    pub async fn get_timeout(&self, duration: Duration) -> Result<Connection<M>, Error<M::Error>> {
368        time::timeout(duration, self.inner_get_with_retries()).await
369    }
370
371    async fn inner_get_with_retries(&self) -> Result<Connection<M>, Error<M::Error>> {
372        let mut try_times: u32 = 0;
373        let config = &self.0.config;
374        loop {
375            try_times += 1;
376            match self.get_connection().await {
377                Ok(conn) => return Ok(conn),
378                Err(Error::BadConn) => {
379                    if try_times == config.max_bad_conn_retries {
380                        return self.get_connection().await;
381                    }
382                    continue;
383                }
384                Err(err) => return Err(err),
385            }
386        }
387    }
388
389    async fn get_connection(&self) -> Result<Connection<M>, Error<M::Error>> {
390        let _guard = GaugeGuard::increment(WAIT_COUNT);
391        let c = self.get_or_create_conn().await?;
392
393        let conn = Connection {
394            pool: self.clone(),
395            conn: Some(c),
396        };
397
398        Ok(conn)
399    }
400
401    async fn validate_conn(
402        &self,
403        internal_config: InternalConfig,
404        conn: IdleConn<M::Connection>,
405    ) -> Option<IdleConn<M::Connection>> {
406        if conn.is_brand_new() {
407            return Some(conn);
408        }
409
410        if conn.expired(internal_config.max_lifetime) {
411            return None;
412        }
413
414        if conn.idle_expired(internal_config.max_idle_lifetime) {
415            return None;
416        }
417
418        let needs_health_check = self.0.config.health_check
419            && conn.needs_health_check(self.0.config.health_check_interval);
420
421        if needs_health_check {
422            let (raw, split) = conn.split_raw();
423            let checked_raw = self.0.manager.check(raw).await.ok()?;
424            let mut checked = split.restore(checked_raw);
425            checked.mark_checked();
426            return Some(checked);
427        }
428        Some(conn)
429    }
430
431    async fn get_or_create_conn(&self) -> Result<ActiveConn<M::Connection>, Error<M::Error>> {
432        self.0.state.wait_count.fetch_add(1, Ordering::Relaxed);
433        let wait_guard = DurationHistogramGuard::start(WAIT_DURATION);
434
435        let semaphore = Arc::clone(&self.0.semaphore);
436        let permit = semaphore
437            .acquire_owned()
438            .await
439            .map_err(|_| Error::PoolClosed)?;
440
441        self.0.state.wait_count.fetch_sub(1, Ordering::SeqCst);
442
443        let mut internals = self.0.internals.lock().await;
444
445        internals.wait_duration += wait_guard.into_elapsed();
446
447        let conn = internals.free_conns.pop();
448        let internal_config = internals.config.clone();
449        drop(internals);
450
451        if let Some(conn) = conn {
452            if let Some(valid_conn) = self.validate_conn(internal_config, conn).await {
453                return Ok(valid_conn.into_active(permit));
454            }
455        }
456
457        self.open_new_connection(permit).await
458    }
459
460    async fn open_new_connection(
461        &self,
462        permit: OwnedSemaphorePermit,
463    ) -> Result<ActiveConn<M::Connection>, Error<M::Error>> {
464        log::debug!("creating new connection from manager");
465        match self.0.manager.connect().await {
466            Ok(c) => {
467                self.0.state.num_open.fetch_add(1, Ordering::Relaxed);
468                let state = ConnState::new(
469                    Arc::clone(&self.0.state.num_open),
470                    Arc::clone(&self.0.state.max_idle_closed),
471                );
472                Ok(ActiveConn::new(c, permit, state))
473            }
474            Err(e) => Err(Error::Inner(e)),
475        }
476    }
477
478    /// Returns information about the current state of the pool.
479    /// It is better to use the metrics than this method, this method
480    /// requires a lock on the internals
481    pub async fn state(&self) -> State {
482        let internals = self.0.internals.lock().await;
483        let num_free_conns = internals.free_conns.len() as u64;
484        let wait_duration = internals.wait_duration;
485        let max_open = internals.config.max_open;
486        drop(internals);
487        State {
488            max_open,
489
490            connections: self.0.state.num_open.load(Ordering::Relaxed),
491            in_use: self.0.state.num_open.load(Ordering::Relaxed) - num_free_conns,
492            idle: num_free_conns,
493
494            wait_count: self.0.state.wait_count.load(Ordering::Relaxed),
495            wait_duration,
496            max_idle_closed: self.0.state.max_idle_closed.load(Ordering::Relaxed),
497            max_lifetime_closed: self.0.state.max_lifetime_closed.load(Ordering::Relaxed),
498        }
499    }
500}
501
502async fn recycle_conn<M: Manager>(
503    shared: &Arc<SharedPool<M>>,
504    mut conn: ActiveConn<M::Connection>,
505) {
506    if conn_still_valid(shared, &mut conn) {
507        conn.set_brand_new(false);
508        let internals = shared.internals.lock().await;
509        put_idle_conn::<M>(internals, conn);
510    }
511}
512
513fn conn_still_valid<M: Manager>(
514    shared: &Arc<SharedPool<M>>,
515    conn: &mut ActiveConn<M::Connection>,
516) -> bool {
517    if !shared.manager.validate(conn.as_raw_mut()) {
518        log::debug!("bad conn when check in");
519        return false;
520    }
521
522    true
523}
524
525fn put_idle_conn<M: Manager>(
526    mut internals: MutexGuard<'_, PoolInternals<M::Connection>>,
527    conn: ActiveConn<M::Connection>,
528) {
529    let idle_conn = conn.into_idle();
530    // Treat max_idle == 0 as unlimited idle connections.
531    if internals.config.max_idle == 0
532        || internals.config.max_idle > internals.free_conns.len() as u64
533    {
534        internals.free_conns.push(idle_conn);
535    }
536}
537
538async fn connection_cleaner<M: Manager>(
539    shared: Weak<SharedPool<M>>,
540    mut cleaner_ch: Receiver<()>,
541    clean_rate: Duration,
542) {
543    let mut interval = interval(clean_rate);
544    interval.tick().await;
545    loop {
546        select! {
547            _ = interval.tick().fuse() => (),
548            r = cleaner_ch.next().fuse() => match r{
549                Some(()) => (),
550                None=> return
551            },
552        }
553
554        if !clean_connection(&shared).await {
555            return;
556        }
557    }
558}
559
560async fn clean_connection<M: Manager>(shared: &Weak<SharedPool<M>>) -> bool {
561    let shared = match shared.upgrade() {
562        Some(shared) => shared,
563        None => {
564            log::debug!("Failed to clean connections");
565            return false;
566        }
567    };
568
569    log::debug!("Clean connections");
570
571    let mut internals = shared.internals.lock().await;
572    if shared.state.num_open.load(Ordering::Relaxed) == 0 || internals.config.max_lifetime.is_none()
573    {
574        internals.cleaner_ch.take();
575        return false;
576    }
577
578    let expired = Instant::now() - internals.config.max_lifetime.unwrap();
579    let mut closing = vec![];
580
581    let mut i = 0;
582    log::debug!(
583        "clean connections, idle conns {}",
584        internals.free_conns.len()
585    );
586
587    loop {
588        if i >= internals.free_conns.len() {
589            break;
590        }
591
592        if internals.free_conns[i].created_at() < expired {
593            let c = internals.free_conns.swap_remove(i);
594            closing.push(c);
595            continue;
596        }
597        i += 1;
598    }
599    drop(internals);
600
601    shared
602        .state
603        .max_lifetime_closed
604        .fetch_add(closing.len() as u64, Ordering::Relaxed);
605    true
606}
607
608/// A smart pointer wrapping a connection.
609pub struct Connection<M: Manager> {
610    pool: Pool<M>,
611    conn: Option<ActiveConn<M::Connection>>,
612}
613
614impl<M: Manager> Connection<M> {
615    /// Returns true is the connection is newly established.
616    pub fn is_brand_new(&self) -> bool {
617        self.conn.as_ref().unwrap().is_brand_new()
618    }
619
620    /// Unwraps the raw database connection.
621    pub fn into_inner(mut self) -> M::Connection {
622        self.conn.take().unwrap().into_raw()
623    }
624}
625
626impl<M: Manager> Drop for Connection<M> {
627    fn drop(&mut self) {
628        let Some(conn) = self.conn.take() else {
629            return;
630        };
631
632        let pool = Arc::clone(&self.pool.0);
633
634        self.pool.0.manager.spawn_task(async move {
635            recycle_conn(&pool, conn).await;
636        });
637    }
638}
639
640impl<M: Manager> Deref for Connection<M> {
641    type Target = M::Connection;
642    fn deref(&self) -> &Self::Target {
643        self.conn.as_ref().unwrap().as_raw_ref()
644    }
645}
646
647impl<M: Manager> DerefMut for Connection<M> {
648    fn deref_mut(&mut self) -> &mut M::Connection {
649        self.conn.as_mut().unwrap().as_raw_mut()
650    }
651}