rocket_db_pools_community/
pool.rs

1use rocket::figment::Figment;
2
3#[allow(unused_imports)]
4use {
5    crate::{Config, Error},
6    std::time::Duration,
7};
8
9/// Generic [`Database`](crate::Database) driver connection pool trait.
10///
11/// This trait provides a generic interface to various database pooling
12/// implementations in the Rust ecosystem. It can be implemented by anyone, but
13/// this crate provides implementations for common drivers.
14///
15/// **Implementations of this trait outside of this crate should be rare. You
16/// _do not_ need to implement this trait or understand its specifics to use
17/// this crate.**
18///
19/// ## Async Trait
20///
21/// [`Pool`] is an _async_ trait. Implementations of `Pool` must be decorated
22/// with an attribute of `#[async_trait]`:
23///
24/// ```rust
25/// # #[macro_use] extern crate rocket;
26/// # extern crate rocket_db_pools_community as rocket_db_pools;
27/// use rocket::figment::Figment;
28/// use rocket_db_pools::Pool;
29///
30/// # struct MyPool;
31/// # type Connection = ();
32/// # type Error = std::convert::Infallible;
33/// #[rocket::async_trait]
34/// impl Pool for MyPool {
35///     type Connection = Connection;
36///
37///     type Error = Error;
38///
39///     async fn init(figment: &Figment) -> Result<Self, Self::Error> {
40///         todo!("initialize and return an instance of the pool");
41///     }
42///
43///     async fn get(&self) -> Result<Self::Connection, Self::Error> {
44///         todo!("fetch one connection from the pool");
45///     }
46///
47///     async fn close(&self) {
48///         todo!("gracefully shutdown connection pool");
49///     }
50/// }
51/// ```
52///
53/// ## Implementing
54///
55/// Implementations of `Pool` typically trace the following outline:
56///
57///   1. The `Error` associated type is set to [`Error`].
58///
59///   2. A [`Config`] is [extracted](Figment::extract()) from the `figment`
60///      passed to init.
61///
62///   3. The pool is initialized and returned in `init()`, wrapping
63///      initialization errors in [`Error::Init`].
64///
65///   4. A connection is retrieved in `get()`, wrapping errors in
66///      [`Error::Get`].
67///
68/// Concretely, this looks like:
69///
70/// ```rust
71/// # extern crate rocket_db_pools_community as rocket_db_pools;
72/// use rocket::figment::Figment;
73/// use rocket_db_pools::{Pool, Config, Error};
74/// #
75/// # type InitError = std::convert::Infallible;
76/// # type GetError = std::convert::Infallible;
77/// # type Connection = ();
78/// #
79/// # struct MyPool(Config);
80/// # impl MyPool {
81/// #    fn new(c: Config) -> Result<Self, InitError> {
82/// #        Ok(Self(c))
83/// #    }
84/// #
85/// #    fn acquire(&self) -> Result<Connection, GetError> {
86/// #        Ok(())
87/// #    }
88/// #
89/// #   async fn shutdown(&self) { }
90/// # }
91///
92/// #[rocket::async_trait]
93/// impl Pool for MyPool {
94///     type Connection = Connection;
95///
96///     type Error = Error<InitError, GetError>;
97///
98///     async fn init(figment: &Figment) -> Result<Self, Self::Error> {
99///         // Extract the config from `figment`.
100///         let config: Config = figment.extract()?;
101///
102///         // Read config values, initialize `MyPool`. Map errors of type
103///         // `InitError` to `Error<InitError, _>` with `Error::Init`.
104///         let pool = MyPool::new(config).map_err(Error::Init)?;
105///
106///         // Return the fully initialized pool.
107///         Ok(pool)
108///     }
109///
110///     async fn get(&self) -> Result<Self::Connection, Self::Error> {
111///         // Get one connection from the pool, here via an `acquire()` method.
112///         // Map errors of type `GetError` to `Error<_, GetError>`.
113///         self.acquire().map_err(Error::Get)
114///     }
115///
116///     async fn close(&self) {
117///         self.shutdown().await;
118///     }
119/// }
120/// ```
121#[rocket::async_trait]
122pub trait Pool: Sized + Send + 'static {
123    /// The connection type managed by this pool, returned by [`Self::get()`].
124    type Connection;
125
126    /// The error type returned by [`Self::init()`] and [`Self::get()`].
127    type Error: std::error::Error;
128
129    /// Constructs a pool from a [Value](rocket::figment::value::Value).
130    ///
131    /// It is up to each implementor of `Pool` to define its accepted
132    /// configuration value(s) via the `Config` associated type.  Most
133    /// integrations provided in `rocket_db_pools` use [`Config`], which
134    /// accepts a (required) `url` and an (optional) `pool_size`.
135    ///
136    /// ## Errors
137    ///
138    /// This method returns an error if the configuration is not compatible, or
139    /// if creating a pool failed due to an unavailable database server,
140    /// insufficient resources, or another database-specific error.
141    async fn init(figment: &Figment) -> Result<Self, Self::Error>;
142
143    /// Asynchronously retrieves a connection from the factory or pool.
144    ///
145    /// ## Errors
146    ///
147    /// This method returns an error if a connection could not be retrieved,
148    /// such as a preconfigured timeout elapsing or when the database server is
149    /// unavailable.
150    async fn get(&self) -> Result<Self::Connection, Self::Error>;
151
152    /// Shutdown the connection pool, disallowing any new connections from being
153    /// retrieved and waking up any tasks with active connections.
154    ///
155    /// The returned future may either resolve when all connections are known to
156    /// have closed or at any point prior. Details are implementation specific.
157    async fn close(&self);
158}
159
#[cfg(feature = "deadpool")]
mod deadpool_postgres {
    //! [`crate::Pool`] implementations backed by `deadpool` managed pools.
    //!
    //! Despite the module name, this covers every `deadpool`-based manager
    //! enabled by features: `deadpool_postgres`, `deadpool_redis`, and the
    //! `diesel_*` managers from `diesel_async`.

    use super::{Config, Duration, Error, Figment};
    use deadpool::{
        managed::{Manager, Object, Pool, PoolError},
        Runtime,
    };

    #[cfg(feature = "diesel")]
    use diesel_async::pooled_connection::AsyncDieselConnectionManager;

    /// A `deadpool` [`Manager`] that can be constructed from a [`Config`].
    pub trait DeadManager: Manager + Sized + Send + 'static {
        /// Creates a manager from `config`. The implementations in this module
        /// read only `config.url`.
        fn new(config: &Config) -> Result<Self, Self::Error>;
    }

    #[cfg(feature = "deadpool_postgres")]
    impl DeadManager for deadpool_postgres::Manager {
        fn new(config: &Config) -> Result<Self, Self::Error> {
            // `Self::new` resolves to the inherent two-argument constructor.
            // Connections are made without TLS (`NoTls`).
            Ok(Self::new(
                config.url.parse()?,
                deadpool_postgres::tokio_postgres::NoTls,
            ))
        }
    }

    #[cfg(feature = "deadpool_redis")]
    impl DeadManager for deadpool_redis::Manager {
        fn new(config: &Config) -> Result<Self, Self::Error> {
            Self::new(config.url.as_str())
        }
    }

    #[cfg(feature = "diesel_postgres")]
    impl DeadManager for AsyncDieselConnectionManager<diesel_async::AsyncPgConnection> {
        fn new(config: &Config) -> Result<Self, Self::Error> {
            Ok(Self::new(config.url.as_str()))
        }
    }

    #[cfg(feature = "diesel_mysql")]
    impl DeadManager for AsyncDieselConnectionManager<diesel_async::AsyncMysqlConnection> {
        fn new(config: &Config) -> Result<Self, Self::Error> {
            Ok(Self::new(config.url.as_str()))
        }
    }

    #[cfg(feature = "diesel_sqlite")]
    impl DeadManager
        for AsyncDieselConnectionManager<
            diesel_async::sync_connection_wrapper::SyncConnectionWrapper<diesel::SqliteConnection>,
        >
    {
        fn new(config: &Config) -> Result<Self, Self::Error> {
            Ok(Self::new(config.url.as_str()))
        }
    }

    #[rocket::async_trait]
    impl<M: DeadManager, C: From<Object<M>>> crate::Pool for Pool<M, C>
    where
        M::Type: Send,
        C: Send + 'static,
        M::Error: std::error::Error,
    {
        type Error = Error<PoolError<M::Error>>;

        type Connection = C;

        /// Extracts a [`Config`], builds the manager, and builds a Tokio-backed
        /// pool sized by `max_connections`.
        async fn init(figment: &Figment) -> Result<Self, Self::Error> {
            let config: Config = figment.extract()?;
            let manager = M::new(&config).map_err(|e| Error::Init(e.into()))?;

            // `connect_timeout` bounds both waiting for a pooled object and
            // creating a new one.
            let connect_timeout = Some(Duration::from_secs(config.connect_timeout));

            // NOTE(review): `idle_timeout` is passed to deadpool's *recycle*
            // timeout, which may not match the meaning of an idle timeout —
            // confirm this mapping is intended.
            Pool::builder(manager)
                .max_size(config.max_connections)
                .wait_timeout(connect_timeout)
                .create_timeout(connect_timeout)
                .recycle_timeout(config.idle_timeout.map(Duration::from_secs))
                .runtime(Runtime::Tokio1)
                .build()
                // The build error itself is discarded; it is always reported as
                // `NoRuntimeSpecified`.
                .map_err(|_| Error::Init(PoolError::NoRuntimeSpecified))
        }

        async fn get(&self) -> Result<Self::Connection, Self::Error> {
            // Name deadpool's inherent `get` explicitly. A bare `self.get()` only
            // avoids recursing into this trait method because inherent methods
            // win method resolution.
            <Pool<M, C>>::get(self).await.map_err(Error::Get)
        }

        async fn close(&self) {
            // deadpool's `close` is synchronous, so there is nothing to await.
            <Pool<M, C>>::close(self)
        }
    }
}
251
#[cfg(feature = "sqlx")]
mod sqlx {
    use super::{Config, Duration, Error, Figment};
    use rocket::tracing::level_filters::LevelFilter;
    use sqlx::ConnectOptions;

    /// The connect-options type of database driver `D`.
    type Options<D> = <<D as sqlx::Database>::Connection as sqlx::Connection>::Options;

    /// Applies database-specific settings to the connect options in `__options`.
    ///
    /// The options are taken as `dyn Any` so this can be called for every
    /// driver. Only option types with a matching `downcast_mut` arm are changed;
    /// all others are left untouched. The leading underscores keep the
    /// parameters from being reported as unused when no specialization feature
    /// is enabled.
    fn specialize(__options: &mut dyn std::any::Any, __config: &Config) {
        // SQLite: use `connect_timeout` as the busy timeout, enable
        // `create_if_missing`, and register every configured extension.
        #[cfg(feature = "sqlx_sqlite")]
        if let Some(o) = __options.downcast_mut::<sqlx::sqlite::SqliteConnectOptions>() {
            *o = std::mem::take(o)
                .busy_timeout(Duration::from_secs(__config.connect_timeout))
                .create_if_missing(true);

            if let Some(ref exts) = __config.extensions {
                for ext in exts {
                    *o = std::mem::take(o).extension(ext.clone());
                }
            }
        }
    }

    #[rocket::async_trait]
    impl<D: sqlx::Database> crate::Pool for sqlx::Pool<D> {
        type Error = Error<sqlx::Error>;

        type Connection = sqlx::pool::PoolConnection<D>;

        /// Parses `url` into driver options, applies driver-specific settings
        /// and Rocket's log level, then builds the pool with `connect_lazy_with`.
        async fn init(figment: &Figment) -> Result<Self, Self::Error> {
            let config = figment.extract::<Config>()?;
            let mut opts = config.url.parse::<Options<D>>().map_err(Error::Init)?;
            specialize(&mut opts, &config);

            // Statement logging stays disabled unless Rocket's `log_level` is a
            // string that parses as a tracing `LevelFilter`. In that case it is
            // re-enabled at the equivalent `log` level.
            opts = opts.disable_statement_logging();
            if let Ok(value) = figment.find_value(rocket::Config::LOG_LEVEL) {
                if let Some(level) = value.as_str().and_then(|v| v.parse().ok()) {
                    let log_level = match level {
                        LevelFilter::OFF => log::LevelFilter::Off,
                        LevelFilter::ERROR => log::LevelFilter::Error,
                        LevelFilter::WARN => log::LevelFilter::Warn,
                        LevelFilter::INFO => log::LevelFilter::Info,
                        LevelFilter::DEBUG => log::LevelFilter::Debug,
                        LevelFilter::TRACE => log::LevelFilter::Trace,
                    };

                    // A zero slow-statement threshold presumably makes every
                    // statement count as "slow" — confirm against sqlx docs.
                    opts = opts
                        .log_statements(log_level)
                        .log_slow_statements(log_level, Duration::default());
                }
            }

            // sqlx limits are `u32`. Saturate instead of silently truncating a
            // `usize` that does not fit (an `as` cast would wrap 2^32 to 0).
            let max_connections = u32::try_from(config.max_connections).unwrap_or(u32::MAX);

            Ok(sqlx::pool::PoolOptions::new()
                .max_connections(max_connections)
                .acquire_timeout(Duration::from_secs(config.connect_timeout))
                .idle_timeout(config.idle_timeout.map(Duration::from_secs))
                .min_connections(config.min_connections.unwrap_or_default())
                .connect_lazy_with(opts))
        }

        async fn get(&self) -> Result<Self::Connection, Self::Error> {
            self.acquire().await.map_err(Error::Get)
        }

        async fn close(&self) {
            <sqlx::Pool<D>>::close(self).await;
        }
    }
}
322
#[cfg(feature = "mongodb")]
mod mongodb {
    use super::{Config, Duration, Error, Figment};
    use mongodb::{options::ClientOptions, Client};

    #[rocket::async_trait]
    impl crate::Pool for Client {
        // `get` never fails, so the `Get` error type is `Infallible`.
        type Error = Error<mongodb::error::Error, std::convert::Infallible>;

        type Connection = Client;

        /// Parses `url` into `ClientOptions`, overlays the pool settings from
        /// [`Config`], and constructs the client.
        async fn init(figment: &Figment) -> Result<Self, Self::Error> {
            let config = figment.extract::<Config>()?;
            let mut opts = ClientOptions::parse(&config.url)
                .await
                .map_err(Error::Init)?;

            opts.min_pool_size = config.min_connections;
            // Saturate instead of silently truncating a `usize` that exceeds
            // `u32::MAX` (an `as` cast would wrap 2^32 to 0).
            opts.max_pool_size = Some(u32::try_from(config.max_connections).unwrap_or(u32::MAX));
            opts.max_idle_time = config.idle_timeout.map(Duration::from_secs);

            // `connect_timeout` bounds both connection establishment and
            // server selection.
            let connect_timeout = Some(Duration::from_secs(config.connect_timeout));
            opts.connect_timeout = connect_timeout;
            opts.server_selection_timeout = connect_timeout;

            Client::with_options(opts).map_err(Error::Init)
        }

        async fn get(&self) -> Result<Self::Connection, Self::Error> {
            // Each "connection" is a clone of the client. Presumably cheap: the
            // driver documents `Client` as shared/`Arc`-backed — confirm.
            Ok(self.clone())
        }

        async fn close(&self) {
            // nothing to do for mongodb
        }
    }
}