rocket_db_pools_community/pool.rs
1use rocket::figment::Figment;
2
3#[allow(unused_imports)]
4use {
5 crate::{Config, Error},
6 std::time::Duration,
7};
8
9/// Generic [`Database`](crate::Database) driver connection pool trait.
10///
11/// This trait provides a generic interface to various database pooling
12/// implementations in the Rust ecosystem. It can be implemented by anyone, but
13/// this crate provides implementations for common drivers.
14///
15/// **Implementations of this trait outside of this crate should be rare. You
16/// _do not_ need to implement this trait or understand its specifics to use
17/// this crate.**
18///
19/// ## Async Trait
20///
21/// [`Pool`] is an _async_ trait. Implementations of `Pool` must be decorated
22/// with an attribute of `#[async_trait]`:
23///
24/// ```rust
25/// # #[macro_use] extern crate rocket;
26/// # extern crate rocket_db_pools_community as rocket_db_pools;
27/// use rocket::figment::Figment;
28/// use rocket_db_pools::Pool;
29///
30/// # struct MyPool;
31/// # type Connection = ();
32/// # type Error = std::convert::Infallible;
33/// #[rocket::async_trait]
34/// impl Pool for MyPool {
35/// type Connection = Connection;
36///
37/// type Error = Error;
38///
39/// async fn init(figment: &Figment) -> Result<Self, Self::Error> {
40/// todo!("initialize and return an instance of the pool");
41/// }
42///
43/// async fn get(&self) -> Result<Self::Connection, Self::Error> {
44/// todo!("fetch one connection from the pool");
45/// }
46///
47/// async fn close(&self) {
48/// todo!("gracefully shutdown connection pool");
49/// }
50/// }
51/// ```
52///
53/// ## Implementing
54///
55/// Implementations of `Pool` typically trace the following outline:
56///
57/// 1. The `Error` associated type is set to [`Error`].
58///
59/// 2. A [`Config`] is [extracted](Figment::extract()) from the `figment`
60/// passed to init.
61///
62/// 3. The pool is initialized and returned in `init()`, wrapping
63/// initialization errors in [`Error::Init`].
64///
65/// 4. A connection is retrieved in `get()`, wrapping errors in
66/// [`Error::Get`].
67///
68/// Concretely, this looks like:
69///
70/// ```rust
71/// # extern crate rocket_db_pools_community as rocket_db_pools;
72/// use rocket::figment::Figment;
73/// use rocket_db_pools::{Pool, Config, Error};
74/// #
75/// # type InitError = std::convert::Infallible;
76/// # type GetError = std::convert::Infallible;
77/// # type Connection = ();
78/// #
79/// # struct MyPool(Config);
80/// # impl MyPool {
81/// # fn new(c: Config) -> Result<Self, InitError> {
82/// # Ok(Self(c))
83/// # }
84/// #
85/// # fn acquire(&self) -> Result<Connection, GetError> {
86/// # Ok(())
87/// # }
88/// #
89/// # async fn shutdown(&self) { }
90/// # }
91///
92/// #[rocket::async_trait]
93/// impl Pool for MyPool {
94/// type Connection = Connection;
95///
96/// type Error = Error<InitError, GetError>;
97///
98/// async fn init(figment: &Figment) -> Result<Self, Self::Error> {
99/// // Extract the config from `figment`.
100/// let config: Config = figment.extract()?;
101///
102/// // Read config values, initialize `MyPool`. Map errors of type
103/// // `InitError` to `Error<InitError, _>` with `Error::Init`.
104/// let pool = MyPool::new(config).map_err(Error::Init)?;
105///
106/// // Return the fully initialized pool.
107/// Ok(pool)
108/// }
109///
110/// async fn get(&self) -> Result<Self::Connection, Self::Error> {
111/// // Get one connection from the pool, here via an `acquire()` method.
112/// // Map errors of type `GetError` to `Error<_, GetError>`.
113/// self.acquire().map_err(Error::Get)
114/// }
115///
116/// async fn close(&self) {
117/// self.shutdown().await;
118/// }
119/// }
120/// ```
121#[rocket::async_trait]
122pub trait Pool: Sized + Send + 'static {
123 /// The connection type managed by this pool, returned by [`Self::get()`].
124 type Connection;
125
126 /// The error type returned by [`Self::init()`] and [`Self::get()`].
127 type Error: std::error::Error;
128
129 /// Constructs a pool from a [Value](rocket::figment::value::Value).
130 ///
131 /// It is up to each implementor of `Pool` to define its accepted
132 /// configuration value(s) via the `Config` associated type. Most
133 /// integrations provided in `rocket_db_pools` use [`Config`], which
134 /// accepts a (required) `url` and an (optional) `pool_size`.
135 ///
136 /// ## Errors
137 ///
138 /// This method returns an error if the configuration is not compatible, or
139 /// if creating a pool failed due to an unavailable database server,
140 /// insufficient resources, or another database-specific error.
141 async fn init(figment: &Figment) -> Result<Self, Self::Error>;
142
143 /// Asynchronously retrieves a connection from the factory or pool.
144 ///
145 /// ## Errors
146 ///
147 /// This method returns an error if a connection could not be retrieved,
148 /// such as a preconfigured timeout elapsing or when the database server is
149 /// unavailable.
150 async fn get(&self) -> Result<Self::Connection, Self::Error>;
151
152 /// Shutdown the connection pool, disallowing any new connections from being
153 /// retrieved and waking up any tasks with active connections.
154 ///
155 /// The returned future may either resolve when all connections are known to
156 /// have closed or at any point prior. Details are implementation specific.
157 async fn close(&self);
158}
159
160#[cfg(feature = "deadpool")]
161mod deadpool_postgres {
162 use super::{Config, Duration, Error, Figment};
163 use deadpool::{
164 managed::{Manager, Object, Pool, PoolError},
165 Runtime,
166 };
167
168 #[cfg(feature = "diesel")]
169 use diesel_async::pooled_connection::AsyncDieselConnectionManager;
170
171 pub trait DeadManager: Manager + Sized + Send + 'static {
172 fn new(config: &Config) -> Result<Self, Self::Error>;
173 }
174
175 #[cfg(feature = "deadpool_postgres")]
176 impl DeadManager for deadpool_postgres::Manager {
177 fn new(config: &Config) -> Result<Self, Self::Error> {
178 Ok(Self::new(
179 config.url.parse()?,
180 deadpool_postgres::tokio_postgres::NoTls,
181 ))
182 }
183 }
184
185 #[cfg(feature = "deadpool_redis")]
186 impl DeadManager for deadpool_redis::Manager {
187 fn new(config: &Config) -> Result<Self, Self::Error> {
188 Self::new(config.url.as_str())
189 }
190 }
191
192 #[cfg(feature = "diesel_postgres")]
193 impl DeadManager for AsyncDieselConnectionManager<diesel_async::AsyncPgConnection> {
194 fn new(config: &Config) -> Result<Self, Self::Error> {
195 Ok(Self::new(config.url.as_str()))
196 }
197 }
198
199 #[cfg(feature = "diesel_mysql")]
200 impl DeadManager for AsyncDieselConnectionManager<diesel_async::AsyncMysqlConnection> {
201 fn new(config: &Config) -> Result<Self, Self::Error> {
202 Ok(Self::new(config.url.as_str()))
203 }
204 }
205
206 #[cfg(feature = "diesel_sqlite")]
207 impl DeadManager
208 for AsyncDieselConnectionManager<
209 diesel_async::sync_connection_wrapper::SyncConnectionWrapper<diesel::SqliteConnection>,
210 >
211 {
212 fn new(config: &Config) -> Result<Self, Self::Error> {
213 Ok(Self::new(config.url.as_str()))
214 }
215 }
216
217 #[rocket::async_trait]
218 impl<M: DeadManager, C: From<Object<M>>> crate::Pool for Pool<M, C>
219 where
220 M::Type: Send,
221 C: Send + 'static,
222 M::Error: std::error::Error,
223 {
224 type Error = Error<PoolError<M::Error>>;
225
226 type Connection = C;
227
228 async fn init(figment: &Figment) -> Result<Self, Self::Error> {
229 let config: Config = figment.extract()?;
230 let manager = M::new(&config).map_err(|e| Error::Init(e.into()))?;
231
232 Pool::builder(manager)
233 .max_size(config.max_connections)
234 .wait_timeout(Some(Duration::from_secs(config.connect_timeout)))
235 .create_timeout(Some(Duration::from_secs(config.connect_timeout)))
236 .recycle_timeout(config.idle_timeout.map(Duration::from_secs))
237 .runtime(Runtime::Tokio1)
238 .build()
239 .map_err(|_| Error::Init(PoolError::NoRuntimeSpecified))
240 }
241
242 async fn get(&self) -> Result<Self::Connection, Self::Error> {
243 self.get().await.map_err(Error::Get)
244 }
245
246 async fn close(&self) {
247 <Pool<M, C>>::close(self)
248 }
249 }
250}
251
252#[cfg(feature = "sqlx")]
253mod sqlx {
254 use super::{Config, Duration, Error, Figment};
255 use rocket::tracing::level_filters::LevelFilter;
256 use sqlx::ConnectOptions;
257
258 type Options<D> = <<D as sqlx::Database>::Connection as sqlx::Connection>::Options;
259
260 // Provide specialized configuration for particular databases.
261 fn specialize(__options: &mut dyn std::any::Any, __config: &Config) {
262 #[cfg(feature = "sqlx_sqlite")]
263 if let Some(o) = __options.downcast_mut::<sqlx::sqlite::SqliteConnectOptions>() {
264 *o = std::mem::take(o)
265 .busy_timeout(Duration::from_secs(__config.connect_timeout))
266 .create_if_missing(true);
267
268 if let Some(ref exts) = __config.extensions {
269 for ext in exts {
270 *o = std::mem::take(o).extension(ext.clone());
271 }
272 }
273 }
274 }
275
276 #[rocket::async_trait]
277 impl<D: sqlx::Database> crate::Pool for sqlx::Pool<D> {
278 type Error = Error<sqlx::Error>;
279
280 type Connection = sqlx::pool::PoolConnection<D>;
281
282 async fn init(figment: &Figment) -> Result<Self, Self::Error> {
283 let config = figment.extract::<Config>()?;
284 let mut opts = config.url.parse::<Options<D>>().map_err(Error::Init)?;
285 specialize(&mut opts, &config);
286
287 opts = opts.disable_statement_logging();
288 if let Ok(value) = figment.find_value(rocket::Config::LOG_LEVEL) {
289 if let Some(level) = value.as_str().and_then(|v| v.parse().ok()) {
290 let log_level = match level {
291 LevelFilter::OFF => log::LevelFilter::Off,
292 LevelFilter::ERROR => log::LevelFilter::Error,
293 LevelFilter::WARN => log::LevelFilter::Warn,
294 LevelFilter::INFO => log::LevelFilter::Info,
295 LevelFilter::DEBUG => log::LevelFilter::Debug,
296 LevelFilter::TRACE => log::LevelFilter::Trace,
297 };
298
299 opts = opts
300 .log_statements(log_level)
301 .log_slow_statements(log_level, Duration::default());
302 }
303 }
304
305 Ok(sqlx::pool::PoolOptions::new()
306 .max_connections(config.max_connections as u32)
307 .acquire_timeout(Duration::from_secs(config.connect_timeout))
308 .idle_timeout(config.idle_timeout.map(Duration::from_secs))
309 .min_connections(config.min_connections.unwrap_or_default())
310 .connect_lazy_with(opts))
311 }
312
313 async fn get(&self) -> Result<Self::Connection, Self::Error> {
314 self.acquire().await.map_err(Error::Get)
315 }
316
317 async fn close(&self) {
318 <sqlx::Pool<D>>::close(self).await;
319 }
320 }
321}
322
323#[cfg(feature = "mongodb")]
324mod mongodb {
325 use super::{Config, Duration, Error, Figment};
326 use mongodb::{options::ClientOptions, Client};
327
328 #[rocket::async_trait]
329 impl crate::Pool for Client {
330 type Error = Error<mongodb::error::Error, std::convert::Infallible>;
331
332 type Connection = Client;
333
334 async fn init(figment: &Figment) -> Result<Self, Self::Error> {
335 let config = figment.extract::<Config>()?;
336 let mut opts = ClientOptions::parse(&config.url)
337 .await
338 .map_err(Error::Init)?;
339 opts.min_pool_size = config.min_connections;
340 opts.max_pool_size = Some(config.max_connections as u32);
341 opts.max_idle_time = config.idle_timeout.map(Duration::from_secs);
342 opts.connect_timeout = Some(Duration::from_secs(config.connect_timeout));
343 opts.server_selection_timeout = Some(Duration::from_secs(config.connect_timeout));
344 Client::with_options(opts).map_err(Error::Init)
345 }
346
347 async fn get(&self) -> Result<Self::Connection, Self::Error> {
348 Ok(self.clone())
349 }
350
351 async fn close(&self) {
352 // nothing to do for mongodb
353 }
354 }
355}