mobc/
lib.rs

1//! A generic connection pool with async/await support.
2//!
3//! Opening a new database connection every time one is needed is both
4//! inefficient and can lead to resource exhaustion under high traffic
5//! conditions. A connection pool maintains a set of open connections to a
6//! database, handing them out for repeated use.
7//!
8//! mobc is agnostic to the connection type it is managing. Implementors of the
9//! `Manager` trait provide the database-specific logic to create and
10//! check the health of connections.
11//!
12//! # Example
13//!
14//! Using an imaginary "foodb" database.
15//!
16//! ```rust
17//!use mobc::{Manager, Pool, async_trait};
18//!
19//!#[derive(Debug)]
20//!struct FooError;
21//!
22//!struct FooConnection;
23//!
24//!impl FooConnection {
25//!    async fn query(&self) -> String {
26//!        "nori".to_string()
27//!    }
28//!}
29//!
30//!struct FooManager;
31//!
32//!#[async_trait]
33//!impl Manager for FooManager {
34//!    type Connection = FooConnection;
35//!    type Error = FooError;
36//!
37//!    async fn connect(&self) -> Result<Self::Connection, Self::Error> {
38//!        Ok(FooConnection)
39//!    }
40//!
41//!    async fn check(&self, conn: Self::Connection) -> Result<Self::Connection, Self::Error> {
42//!        Ok(conn)
43//!    }
44//!}
45//!
46//!#[tokio::main]
47//!async fn main() {
48//!    let pool = Pool::builder().max_open(15).build(FooManager);
49//!    let num: usize = 10000;
50//!    let (tx, mut rx) = tokio::sync::mpsc::channel::<()>(16);
51//!
52//!    for _ in 0..num {
53//!        let pool = pool.clone();
54//!        let mut tx = tx.clone();
55//!        tokio::spawn(async move {
56//!            let conn = pool.get().await.unwrap();
57//!            let name = conn.query().await;
58//!            assert_eq!(name, "nori".to_string());
59//!            tx.send(()).await.unwrap();
60//!        });
61//!    }
62//!
63//!    for _ in 0..num {
64//!        rx.recv().await.unwrap();
65//!    }
66//!}
67//! ```
68//!
69//! # Metrics
70//!
//! mobc uses the `metrics` crate to expose the following metrics:
//!
//! 1. Active Connections - The number of connections in use.
//! 1. Idle Connections - The number of connections that are not being used.
//! 1. Wait Count - The number of callers currently waiting for a connection.
//! 1. Wait Duration - A cumulative histogram of the wait time for a connection.
77//!
78
79#![cfg_attr(feature = "docs", feature(doc_cfg))]
80#![warn(missing_docs)]
81#![recursion_limit = "256"]
82mod config;
83
84mod conn;
85mod error;
86mod metrics_utils;
87#[cfg(feature = "unstable")]
88#[cfg_attr(feature = "docs", doc(cfg(unstable)))]
89pub mod runtime;
90mod spawn;
91mod time;
92
93pub use error::Error;
94
95pub use async_trait::async_trait;
96pub use config::Builder;
97use config::{Config, InternalConfig, ShareConfig};
98use conn::{ActiveConn, ConnState, IdleConn};
99use futures_channel::mpsc::{self, Receiver, Sender};
100use futures_util::lock::{Mutex, MutexGuard};
101use futures_util::select;
102use futures_util::FutureExt;
103use futures_util::SinkExt;
104use futures_util::StreamExt;
105use metrics::gauge;
106use metrics_utils::DurationHistogramGuard;
107pub use spawn::spawn;
108use std::fmt;
109use std::future::Future;
110use std::ops::{Deref, DerefMut};
111use std::sync::{
112    atomic::{AtomicU64, Ordering},
113    Arc, Weak,
114};
115use std::time::{Duration, Instant};
116#[doc(hidden)]
117pub use time::{delay_for, interval};
118use tokio::sync::{OwnedSemaphorePermit, Semaphore};
119
120use crate::metrics_utils::{GaugeGuard, IDLE_CONNECTIONS, WAIT_COUNT, WAIT_DURATION};
121
/// Number of semaphore permits used when `max_open` is `0` ("unlimited"):
/// at most this many connections can be checked out concurrently.
const CONNECTION_REQUEST_QUEUE_SIZE: usize = 10_000;
123
124#[async_trait]
125/// A trait which provides connection-specific functionality.
126pub trait Manager: Send + Sync + 'static {
127    /// The connection type this manager deals with.
128    type Connection: Send + 'static;
129    /// The error type returned by `Connection`s.
130    type Error: Send + Sync + 'static;
131
132    /// Spawns a new asynchronous task.
133    fn spawn_task<T>(&self, task: T)
134    where
135        T: Future + Send + 'static,
136        T::Output: Send + 'static,
137    {
138        spawn(task);
139    }
140
141    /// Attempts to create a new connection.
142    async fn connect(&self) -> Result<Self::Connection, Self::Error>;
143
144    /// Determines if the connection is still connected to the database when check-out.
145    ///
146    /// A standard implementation would check if a simple query like `SELECT 1`
147    /// succeeds.
148    async fn check(&self, conn: Self::Connection) -> Result<Self::Connection, Self::Error>;
149
150    /// *Quickly* determines a connection is still valid when check-in.
151    #[inline]
152    fn validate(&self, _conn: &mut Self::Connection) -> bool {
153        true
154    }
155}
156
157struct SharedPool<M: Manager> {
158    config: ShareConfig,
159    manager: M,
160    internals: Mutex<PoolInternals<M::Connection>>,
161    state: PoolState,
162    semaphore: Arc<Semaphore>,
163}
164
165struct PoolInternals<C> {
166    config: InternalConfig,
167    free_conns: Vec<IdleConn<C>>,
168    wait_duration: Duration,
169    cleaner_ch: Option<Sender<()>>,
170}
171
/// Lock-free counters describing the pool.
struct PoolState {
    /// Open connections (idle and in use); shared with every connection's
    /// `ConnState`.
    num_open: Arc<AtomicU64>,
    /// Connections closed by the cleaner because they exceeded `max_lifetime`.
    max_lifetime_closed: AtomicU64,
    /// Connections closed because of `max_idle`; shared with every
    /// connection's `ConnState`.
    max_idle_closed: Arc<AtomicU64>,
    /// Callers currently waiting for a semaphore permit.
    wait_count: AtomicU64,
}
178
179impl<C> Drop for PoolInternals<C> {
180    fn drop(&mut self) {
181        log::debug!("Pool internal drop");
182    }
183}
184
185/// A generic connection pool.
186pub struct Pool<M: Manager>(Arc<SharedPool<M>>);
187
188/// Returns a new `Pool` referencing the same state as `self`.
189impl<M: Manager> Clone for Pool<M> {
190    fn clone(&self) -> Self {
191        Pool(self.0.clone())
192    }
193}
194
195impl<M: Manager> fmt::Debug for Pool<M> {
196    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
197        write!(f, "Pool")
198    }
199}
200
/// Information about the state of a `Pool`.
#[derive(Debug)]
pub struct State {
    /// Maximum number of open connections to the database (`0` = unlimited).
    pub max_open: u64,

    // Pool Status
    /// The number of established connections both in use and idle.
    pub connections: u64,
    /// The number of connections currently in use.
    pub in_use: u64,
    /// The number of idle connections.
    pub idle: u64,

    // Counters
    /// The number of callers currently waiting for a connection permit.
    ///
    /// This is a gauge, not a running total: a caller is removed from the
    /// count as soon as it obtains a permit.
    pub wait_count: u64,
    /// The total time blocked waiting for a new connection.
    pub wait_duration: Duration,
    /// The total number of connections closed due to `max_idle`.
    pub max_idle_closed: u64,
    /// The total number of connections closed due to `max_lifetime`.
    pub max_lifetime_closed: u64,
}
225
226impl<M: Manager> Drop for Pool<M> {
227    fn drop(&mut self) {}
228}
229
230impl<M: Manager> Pool<M> {
231    /// Creates a new connection pool with a default configuration.
232    pub fn new(manager: M) -> Pool<M> {
233        Pool::builder().build(manager)
234    }
235
236    /// Returns a builder type to configure a new pool.
237    pub fn builder() -> Builder<M> {
238        Builder::new()
239    }
240
241    /// Sets the maximum number of connections managed by the pool.
242    ///
243    /// 0 means unlimited, defaults to 10.
244    pub async fn set_max_open_conns(&self, n: u64) {
245        let mut internals = self.0.internals.lock().await;
246        internals.config.max_open = n;
247        if n > 0 && internals.config.max_idle > n {
248            drop(internals);
249            self.set_max_idle_conns(n).await;
250        }
251    }
252
253    /// Sets the maximum idle connection count maintained by the pool.
254    ///
255    /// The pool will maintain at most this many idle connections
256    /// at all times, while respecting the value of `max_open`.
257    ///
258    /// 0 means unlimited (limited only by `max_open`), defaults to 2.
259    pub async fn set_max_idle_conns(&self, n: u64) {
260        let mut internals = self.0.internals.lock().await;
261        internals.config.max_idle =
262            if internals.config.max_open > 0 && n > internals.config.max_open {
263                internals.config.max_open
264            } else {
265                n
266            };
267
268        let max_idle = internals.config.max_idle as usize;
269        // Treat max_idle == 0 as unlimited
270        if max_idle > 0 && internals.free_conns.len() > max_idle {
271            internals.free_conns.truncate(max_idle);
272        }
273    }
274
275    /// Sets the maximum lifetime of connections in the pool.
276    ///
277    /// Expired connections may be closed lazily before reuse.
278    ///
279    /// None meas reuse forever.
280    /// Defaults to None.
281    ///
282    /// # Panics
283    ///
284    /// Panics if `max_lifetime` is the zero `Duration`.
285    pub async fn set_conn_max_lifetime(&self, max_lifetime: Option<Duration>) {
286        assert_ne!(
287            max_lifetime,
288            Some(Duration::from_secs(0)),
289            "max_lifetime must be positive"
290        );
291        let mut internals = self.0.internals.lock().await;
292        internals.config.max_lifetime = max_lifetime;
293        if let Some(lifetime) = max_lifetime {
294            match internals.config.max_lifetime {
295                Some(prev) if lifetime < prev && internals.cleaner_ch.is_some() => {
296                    // FIXME
297                    let _ = internals.cleaner_ch.as_mut().unwrap().send(()).await;
298                }
299                _ => (),
300            }
301        }
302
303        if max_lifetime.is_some()
304            && self.0.state.num_open.load(Ordering::Relaxed) > 0
305            && internals.cleaner_ch.is_none()
306        {
307            log::debug!("run connection cleaner");
308            let shared1 = Arc::downgrade(&self.0);
309            let clean_rate = self.0.config.clean_rate;
310            let (cleaner_ch_sender, cleaner_ch) = mpsc::channel(1);
311            internals.cleaner_ch = Some(cleaner_ch_sender);
312            self.0.manager.spawn_task(async move {
313                connection_cleaner(shared1, cleaner_ch, clean_rate).await;
314            });
315        }
316    }
317
318    pub(crate) fn new_inner(manager: M, config: Config) -> Self {
319        let max_open = if config.max_open == 0 {
320            CONNECTION_REQUEST_QUEUE_SIZE
321        } else {
322            config.max_open as usize
323        };
324
325        gauge!(IDLE_CONNECTIONS).set(0.0);
326
327        let (share_config, internal_config) = config.split();
328        let internals = Mutex::new(PoolInternals {
329            config: internal_config,
330            free_conns: Vec::new(),
331            wait_duration: Duration::from_secs(0),
332            cleaner_ch: None,
333        });
334
335        let pool_state = PoolState {
336            num_open: Arc::new(AtomicU64::new(0)),
337            max_lifetime_closed: AtomicU64::new(0),
338            wait_count: AtomicU64::new(0),
339            max_idle_closed: Arc::new(AtomicU64::new(0)),
340        };
341
342        let shared = Arc::new(SharedPool {
343            config: share_config,
344            manager,
345            internals,
346            semaphore: Arc::new(Semaphore::new(max_open)),
347            state: pool_state,
348        });
349
350        Pool(shared)
351    }
352
353    /// Returns a single connection by either opening a new connection
354    /// or returning an existing connection from the connection pool. Conn will
355    /// block until either a connection is returned or timeout.
356    pub async fn get(&self) -> Result<Connection<M>, Error<M::Error>> {
357        match self.0.config.get_timeout {
358            Some(duration) => self.get_timeout(duration).await,
359            None => self.inner_get_with_retries().await,
360        }
361    }
362
363    /// Retrieves a connection from the pool, waiting for at most `timeout`
364    ///
365    /// The given timeout will be used instead of the configured connection
366    /// timeout.
367    pub async fn get_timeout(&self, duration: Duration) -> Result<Connection<M>, Error<M::Error>> {
368        time::timeout(duration, self.inner_get_with_retries()).await
369    }
370
371    async fn inner_get_with_retries(&self) -> Result<Connection<M>, Error<M::Error>> {
372        let mut try_times: u32 = 0;
373        let config = &self.0.config;
374        loop {
375            try_times += 1;
376            match self.get_connection().await {
377                Ok(conn) => return Ok(conn),
378                Err(Error::BadConn) => {
379                    if try_times == config.max_bad_conn_retries {
380                        return self.get_connection().await;
381                    }
382                    continue;
383                }
384                Err(err) => return Err(err),
385            }
386        }
387    }
388
389    async fn get_connection(&self) -> Result<Connection<M>, Error<M::Error>> {
390        let _guard = GaugeGuard::increment(WAIT_COUNT);
391        let c = self.get_or_create_conn().await?;
392
393        let conn = Connection {
394            pool: self.clone(),
395            conn: Some(c),
396        };
397
398        Ok(conn)
399    }
400
401    async fn validate_conn(
402        &self,
403        internal_config: InternalConfig,
404        conn: IdleConn<M::Connection>,
405    ) -> Option<IdleConn<M::Connection>> {
406        if conn.is_brand_new() {
407            return Some(conn);
408        }
409
410        if conn.expired(internal_config.max_lifetime) {
411            return None;
412        }
413
414        if conn.idle_expired(internal_config.max_idle_lifetime) {
415            return None;
416        }
417
418        let needs_health_check = self.0.config.health_check
419            && conn.needs_health_check(self.0.config.health_check_interval);
420
421        if needs_health_check {
422            let (raw, split) = conn.split_raw();
423            let checked_raw = self.0.manager.check(raw).await.ok()?;
424            let mut checked = split.restore(checked_raw);
425            checked.mark_checked();
426            return Some(checked);
427        }
428        Some(conn)
429    }
430
431    async fn get_or_create_conn(&self) -> Result<ActiveConn<M::Connection>, Error<M::Error>> {
432        self.0.state.wait_count.fetch_add(1, Ordering::Relaxed);
433        let wait_guard = DurationHistogramGuard::start(WAIT_DURATION);
434
435        let semaphore = Arc::clone(&self.0.semaphore);
436        let permit = semaphore
437            .acquire_owned()
438            .await
439            .map_err(|_| Error::PoolClosed)?;
440
441        self.0.state.wait_count.fetch_sub(1, Ordering::SeqCst);
442
443        let mut internals = self.0.internals.lock().await;
444
445        internals.wait_duration += wait_guard.into_elapsed();
446
447        let conn = internals.free_conns.pop();
448        let internal_config = internals.config.clone();
449        drop(internals);
450
451        if let Some(conn) = conn {
452            if let Some(valid_conn) = self.validate_conn(internal_config, conn).await {
453                return Ok(valid_conn.into_active(permit));
454            }
455        }
456
457        let create_r = self.open_new_connection(permit).await;
458
459        create_r
460    }
461
462    async fn open_new_connection(
463        &self,
464        permit: OwnedSemaphorePermit,
465    ) -> Result<ActiveConn<M::Connection>, Error<M::Error>> {
466        log::debug!("creating new connection from manager");
467        match self.0.manager.connect().await {
468            Ok(c) => {
469                self.0.state.num_open.fetch_add(1, Ordering::Relaxed);
470                let state = ConnState::new(
471                    Arc::clone(&self.0.state.num_open),
472                    Arc::clone(&self.0.state.max_idle_closed),
473                );
474                Ok(ActiveConn::new(c, permit, state))
475            }
476            Err(e) => Err(Error::Inner(e)),
477        }
478    }
479
480    /// Returns information about the current state of the pool.
481    /// It is better to use the metrics than this method, this method
482    /// requires a lock on the internals
483    pub async fn state(&self) -> State {
484        let internals = self.0.internals.lock().await;
485        let num_free_conns = internals.free_conns.len() as u64;
486        let wait_duration = internals.wait_duration;
487        let max_open = internals.config.max_open;
488        drop(internals);
489        State {
490            max_open,
491
492            connections: self.0.state.num_open.load(Ordering::Relaxed),
493            in_use: self.0.state.num_open.load(Ordering::Relaxed) - num_free_conns,
494            idle: num_free_conns,
495
496            wait_count: self.0.state.wait_count.load(Ordering::Relaxed),
497            wait_duration,
498            max_idle_closed: self.0.state.max_idle_closed.load(Ordering::Relaxed),
499            max_lifetime_closed: self.0.state.max_lifetime_closed.load(Ordering::Relaxed),
500        }
501    }
502}
503
504async fn recycle_conn<M: Manager>(
505    shared: &Arc<SharedPool<M>>,
506    mut conn: ActiveConn<M::Connection>,
507) {
508    if conn_still_valid(shared, &mut conn) {
509        conn.set_brand_new(false);
510        let internals = shared.internals.lock().await;
511        put_idle_conn::<M>(internals, conn);
512    }
513}
514
515fn conn_still_valid<M: Manager>(
516    shared: &Arc<SharedPool<M>>,
517    conn: &mut ActiveConn<M::Connection>,
518) -> bool {
519    if !shared.manager.validate(conn.as_raw_mut()) {
520        log::debug!("bad conn when check in");
521        return false;
522    }
523
524    true
525}
526
527fn put_idle_conn<M: Manager>(
528    mut internals: MutexGuard<'_, PoolInternals<M::Connection>>,
529    conn: ActiveConn<M::Connection>,
530) {
531    let idle_conn = conn.into_idle();
532    // Treat max_idle == 0 as unlimited idle connections.
533    if internals.config.max_idle == 0
534        || internals.config.max_idle > internals.free_conns.len() as u64
535    {
536        internals.free_conns.push(idle_conn);
537    }
538}
539
540async fn connection_cleaner<M: Manager>(
541    shared: Weak<SharedPool<M>>,
542    mut cleaner_ch: Receiver<()>,
543    clean_rate: Duration,
544) {
545    let mut interval = interval(clean_rate);
546    interval.tick().await;
547    loop {
548        select! {
549            _ = interval.tick().fuse() => (),
550            r = cleaner_ch.next().fuse() => match r{
551                Some(()) => (),
552                None=> return
553            },
554        }
555
556        if !clean_connection(&shared).await {
557            return;
558        }
559    }
560}
561
562async fn clean_connection<M: Manager>(shared: &Weak<SharedPool<M>>) -> bool {
563    let shared = match shared.upgrade() {
564        Some(shared) => shared,
565        None => {
566            log::debug!("Failed to clean connections");
567            return false;
568        }
569    };
570
571    log::debug!("Clean connections");
572
573    let mut internals = shared.internals.lock().await;
574    if shared.state.num_open.load(Ordering::Relaxed) == 0 || internals.config.max_lifetime.is_none()
575    {
576        internals.cleaner_ch.take();
577        return false;
578    }
579
580    let expired = Instant::now() - internals.config.max_lifetime.unwrap();
581    let mut closing = vec![];
582
583    let mut i = 0;
584    log::debug!(
585        "clean connections, idle conns {}",
586        internals.free_conns.len()
587    );
588
589    loop {
590        if i >= internals.free_conns.len() {
591            break;
592        }
593
594        if internals.free_conns[i].created_at() < expired {
595            let c = internals.free_conns.swap_remove(i);
596            closing.push(c);
597            continue;
598        }
599        i += 1;
600    }
601    drop(internals);
602
603    shared
604        .state
605        .max_lifetime_closed
606        .fetch_add(closing.len() as u64, Ordering::Relaxed);
607    true
608}
609
610/// A smart pointer wrapping a connection.
611pub struct Connection<M: Manager> {
612    pool: Pool<M>,
613    conn: Option<ActiveConn<M::Connection>>,
614}
615
616impl<M: Manager> Connection<M> {
617    /// Returns true is the connection is newly established.
618    pub fn is_brand_new(&self) -> bool {
619        self.conn.as_ref().unwrap().is_brand_new()
620    }
621
622    /// Unwraps the raw database connection.
623    pub fn into_inner(mut self) -> M::Connection {
624        self.conn.take().unwrap().into_raw()
625    }
626}
627
628impl<M: Manager> Drop for Connection<M> {
629    fn drop(&mut self) {
630        let Some(conn) = self.conn.take() else {
631            return;
632        };
633
634        let pool = Arc::clone(&self.pool.0);
635
636        self.pool.0.manager.spawn_task(async move {
637            recycle_conn(&pool, conn).await;
638        });
639    }
640}
641
642impl<M: Manager> Deref for Connection<M> {
643    type Target = M::Connection;
644    fn deref(&self) -> &Self::Target {
645        self.conn.as_ref().unwrap().as_raw_ref()
646    }
647}
648
649impl<M: Manager> DerefMut for Connection<M> {
650    fn deref_mut(&mut self) -> &mut M::Connection {
651        self.conn.as_mut().unwrap().as_raw_mut()
652    }
653}