mpool/
lib.rs

1//! A generic connection pool.
2//!
3//! Implementors of the `ManageConnection` trait provide the specific
4//! logic to create and check the health of connections.
5//!
6//! # Example
7//!
8//! ```rust,ignore
9//! use std::io;
10//! use std::net::SocketAddr;
11//!
12//! use async_trait::async_trait;
13//! use mpool::{ManageConnection, Pool};
14//! use tokio::net::TcpStream;
15//!
16//! struct MyPool {
17//!     addr: SocketAddr,
18//! }
19//!
20//! #[async_trait]
21//! impl ManageConnection for MyPool {
22//!     type Connection = TcpStream;
23//!
24//!     async fn connect(&self) -> io::Result<Self::Connection> {
25//!         TcpStream::connect(self.addr).await
26//!     }
27//!
28//!     async fn check(&self, _conn: &mut Self::Connection) -> io::Result<()> {
29//!         Ok(())
30//!     }
31//! }
32//! ```
33
34use std::collections::LinkedList;
35use std::fmt;
36use std::io;
37use std::marker::PhantomData;
38use std::ops::{Add, Deref, DerefMut};
39use std::sync::{Arc, Mutex, MutexGuard};
40use std::time::{Duration, Instant};
41
42use async_trait::async_trait;
43use tokio::time::{delay_for, timeout};
44
45#[cfg(test)]
46mod test;
47
48/// A trait which provides connection-specific functionality.
49#[async_trait]
50pub trait ManageConnection: Send + Sync + 'static {
51    /// The connection type this manager deals with.
52    type Connection: Send + 'static;
53
54    /// Attempts to create a new connection.
55    async fn connect(&self) -> io::Result<Self::Connection>;
56
57    /// Check if the connection is still valid, check background every `check_interval`.
58    ///
59    /// A standard implementation would check if a simple query like `PING` succee,
60    /// if the `Connection` is broken, error should return.
61    async fn check(&self, conn: &mut Self::Connection) -> io::Result<()>;
62}
63
/// Builds an `io::Error` of kind `Other` whose message is `msg`.
fn other(msg: &str) -> io::Error {
    let kind = io::ErrorKind::Other;
    io::Error::new(kind, msg.to_owned())
}
67
68/// A builder for a connection pool.
69pub struct Builder<M>
70where
71    M: ManageConnection,
72{
73    pub max_lifetime: Option<Duration>,
74    pub idle_timeout: Option<Duration>,
75    pub connection_timeout: Option<Duration>,
76    pub max_size: u32,
77    pub check_interval: Option<Duration>,
78    _pd: PhantomData<M>,
79}
80
81impl<M> fmt::Debug for Builder<M>
82where
83    M: ManageConnection,
84{
85    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
86        fmt.debug_struct("Builder")
87            .field("max_size", &self.max_size)
88            .field("max_lifetime", &self.max_lifetime)
89            .field("idle_timeout", &self.idle_timeout)
90            .field("connection_timeout", &self.connection_timeout)
91            .finish()
92    }
93}
94
95impl<M> Default for Builder<M>
96where
97    M: ManageConnection,
98{
99    fn default() -> Self {
100        Builder {
101            max_lifetime: Some(Duration::from_secs(60 * 30)),
102            idle_timeout: Some(Duration::from_secs(3 * 60)),
103            connection_timeout: Some(Duration::from_secs(3)),
104            check_interval: Some(Duration::from_secs(3)),
105            max_size: 0,
106            _pd: PhantomData,
107        }
108    }
109}
110
111impl<M> Builder<M>
112where
113    M: ManageConnection,
114{
115    // Constructs a new `Builder`.
116    ///
117    /// Parameters are initialized with their default values.
118    pub fn new() -> Self {
119        Builder::default()
120    }
121
122    /// Sets the maximum lifetime of connections in the pool.
123    ///
124    /// If a connection reaches its maximum lifetime while checked out it will
125    /// be closed when it is returned to the pool.
126    ///
127    /// Defaults to 30 minutes.
128    ///
129    /// use default if `max_lifetime` is the zero `Duration`.
130    pub fn max_lifetime(mut self, max_lifetime: Option<Duration>) -> Self {
131        if max_lifetime == Some(Duration::from_secs(0)) {
132            self
133        } else {
134            self.max_lifetime = max_lifetime;
135            self
136        }
137    }
138
139    /// Sets the idle timeout used by the pool.
140    ///
141    /// If set, connections will be closed after exceed idle time.
142    ///
143    /// Defaults to 3 minutes.
144    ///
145    /// use default if `idle_timeout` is the zero `Duration`.
146    pub fn idle_timeout(mut self, idle_timeout: Option<Duration>) -> Self {
147        if idle_timeout == Some(Duration::from_secs(0)) {
148            self
149        } else {
150            self.idle_timeout = idle_timeout;
151            self
152        }
153    }
154
155    /// Sets the connection timeout used by the pool.
156    ///
157    /// Calls to `Pool::get` will wait this long for a connection to become
158    /// available before returning an error.
159    ///
160    /// Defaults to 3 seconds.
161    /// don't timeout if `connection_timeout` is the zero duration
162    pub fn connection_timeout(mut self, connection_timeout: Option<Duration>) -> Self {
163        if connection_timeout == Some(Duration::from_secs(0)) {
164            self
165        } else {
166            self.connection_timeout = connection_timeout;
167            self
168        }
169    }
170
171    /// Sets the maximum number of connections managed by the pool.
172    ///
173    /// Defaults to 10.
174    ///
175    /// no limited if `max_size` is 0.
176    pub fn max_size(mut self, max_size: u32) -> Self {
177        self.max_size = max_size;
178        self
179    }
180
181    /// Sets the check interval of connections managed by the pool use the `ManageConnection::check`.
182    ///
183    /// Defaults to 3s.
184    pub fn check_interval(mut self, interval: Option<Duration>) -> Self {
185        self.check_interval = interval;
186        self
187    }
188
189    /// Consumes the builder, returning a new, initialized pool.
190    pub fn build(&self, manager: M) -> Pool<M>
191    where
192        M: ManageConnection,
193    {
194        let intervals = PoolInternals {
195            conns: LinkedList::new(),
196            active: 0,
197        };
198
199        let shared = SharedPool {
200            intervals: Mutex::new(intervals),
201            max_lifetime: self.max_lifetime,
202            idle_timeout: self.idle_timeout,
203            connection_timeout: self.connection_timeout,
204            max_size: self.max_size,
205            check_interval: self.check_interval,
206            manager,
207        };
208
209        let pool = Pool(Arc::new(shared));
210        tokio::spawn(pool.clone().check());
211        pool
212    }
213}
214
215/// A smart pointer wrapping a connection.
216pub struct Connection<M>
217where
218    M: ManageConnection,
219{
220    conn: Option<IdleConn<M::Connection>>,
221    pool: Pool<M>,
222}
223
224impl<M> Drop for Connection<M>
225where
226    M: ManageConnection,
227{
228    fn drop(&mut self) {
229        if self.conn.is_some() {
230            self.pool.put(self.conn.take().unwrap());
231        }
232    }
233}
234
235impl<M> Deref for Connection<M>
236where
237    M: ManageConnection,
238{
239    type Target = M::Connection;
240
241    fn deref(&self) -> &M::Connection {
242        &self.conn.as_ref().unwrap().conn
243    }
244}
245
246impl<M> DerefMut for Connection<M>
247where
248    M: ManageConnection,
249{
250    fn deref_mut(&mut self) -> &mut M::Connection {
251        &mut self.conn.as_mut().unwrap().conn
252    }
253}
254
255/// A generic connection pool.
256pub struct Pool<M>(Arc<SharedPool<M>>)
257where
258    M: ManageConnection;
259
260impl<M> Clone for Pool<M>
261where
262    M: ManageConnection,
263{
264    fn clone(&self) -> Pool<M> {
265        Pool(self.0.clone())
266    }
267}
268
269impl<M> Pool<M>
270where
271    M: ManageConnection,
272{
273    /// Creates a new connection pool with a default configuration.
274    pub fn new(manager: M) -> Pool<M> {
275        Pool::builder().build(manager)
276    }
277
278    /// Returns a builder type to configure a new pool.
279    pub fn builder() -> Builder<M> {
280        Builder::new()
281    }
282
283    pub(crate) fn interval<'a>(&'a self) -> MutexGuard<'a, PoolInternals<M::Connection>> {
284        self.0.intervals.lock().unwrap()
285    }
286
287    fn idle_count(&self) -> usize {
288        self.interval().conns.len()
289    }
290
291    fn incr_active(&self) {
292        self.interval().active += 1;
293    }
294
295    fn decr_active(&self) {
296        self.interval().active -= 1;
297    }
298
299    fn pop_front(&self) -> Option<IdleConn<M::Connection>> {
300        self.interval().conns.pop_front()
301    }
302
303    fn push_back(&mut self, conn: IdleConn<M::Connection>) {
304        self.interval().conns.push_back(conn);
305    }
306
307    fn exceed_idle_timeout(&self, conn: &IdleConn<M::Connection>) -> bool {
308        if let Some(idle_timeout) = self.0.idle_timeout {
309            if idle_timeout.as_micros() > 0 && conn.last_visited.add(idle_timeout) < Instant::now()
310            {
311                return true;
312            }
313        }
314
315        false
316    }
317
318    fn exceed_max_lifetime(&self, conn: &IdleConn<M::Connection>) -> bool {
319        if let Some(max_lifetime) = self.0.max_lifetime {
320            if max_lifetime.as_micros() > 0 && conn.created.add(max_lifetime) < Instant::now() {
321                return true;
322            }
323        }
324
325        false
326    }
327
328    async fn check(mut self) {
329        if let Some(interval) = self.0.check_interval {
330            loop {
331                delay_for(interval).await;
332                let n = self.idle_count();
333                for _ in 0..n {
334                    if let Some(mut conn) = self.pop_front() {
335                        if self.exceed_idle_timeout(&conn) || self.exceed_max_lifetime(&conn) {
336                            self.decr_active();
337                            continue;
338                        }
339                        match self.0.manager.check(&mut conn.conn).await {
340                            Ok(_) => {
341                                self.push_back(conn);
342                                continue;
343                            }
344                            Err(_) => {
345                                self.decr_active();
346                            }
347                        }
348                        continue;
349                    }
350                    break;
351                }
352            }
353        }
354    }
355
356    fn exceed_limit(&self) -> bool {
357        let max_size = self.0.max_size;
358        if max_size > 0 && self.interval().active > max_size {
359            true
360        } else {
361            false
362        }
363    }
364
365    /// Retrieves a connection from the pool.
366    ///
367    /// Waits for at most the connection timeout before returning an error.
368    pub async fn get_timeout(
369        &self,
370        connection_timeout: Option<Duration>,
371    ) -> io::Result<M::Connection> {
372        if let Some(connection_timeout) = connection_timeout {
373            let conn = match timeout(connection_timeout, self.0.manager.connect()).await {
374                Ok(s) => match s {
375                    Ok(s) => s,
376                    Err(e) => {
377                        return Err(other(&e.to_string()));
378                    }
379                },
380                Err(e) => {
381                    return Err(other(&e.to_string()));
382                }
383            };
384
385            Ok(conn)
386        } else {
387            let conn = self.0.manager.connect().await?;
388            Ok(conn)
389        }
390    }
391
392    /// Retrieves a connection from the pool.
393    ///
394    /// Waits for at most the configured connection timeout before returning an
395    /// error.
396    pub async fn get(&self) -> io::Result<Connection<M>> {
397        if let Some(conn) = self.pop_front() {
398            return Ok(Connection {
399                conn: Some(conn),
400                pool: self.clone(),
401            });
402        }
403
404        self.incr_active();
405        if self.exceed_limit() {
406            self.decr_active();
407            return Err(other("exceed limit"));
408        }
409
410        let conn = self
411            .get_timeout(self.0.connection_timeout)
412            .await
413            .map_err(|e| {
414                self.decr_active();
415                e
416            })?;
417
418        return Ok(Connection {
419            conn: Some(IdleConn {
420                conn,
421                last_visited: Instant::now(),
422                created: Instant::now(),
423            }),
424            pool: self.clone(),
425        });
426    }
427
428    fn put(&mut self, mut conn: IdleConn<M::Connection>) {
429        conn.last_visited = Instant::now();
430        self.push_back(conn);
431    }
432}
433
434struct SharedPool<M>
435where
436    M: ManageConnection,
437{
438    intervals: Mutex<PoolInternals<M::Connection>>,
439    max_lifetime: Option<Duration>,
440    idle_timeout: Option<Duration>,
441    connection_timeout: Option<Duration>,
442    max_size: u32,
443    check_interval: Option<Duration>,
444    manager: M,
445}
446
/// A pooled connection together with its bookkeeping timestamps.
struct IdleConn<C> {
    // The raw connection produced by the manager.
    conn: C,
    // When the connection was last handed back to or created by the pool.
    last_visited: Instant,
    // When the connection was opened.
    created: Instant,
}
452
453struct PoolInternals<C> {
454    conns: LinkedList<IdleConn<C>>,
455    active: u32,
456}