rocket_sync_db_pools_community/
poolable.rs

1#[allow(unused)]
2use std::time::Duration;
3
4use r2d2::ManageConnection;
5use rocket::{Build, Rocket};
6
7#[allow(unused_imports)]
8use crate::{Config, Error};
9
10/// Trait implemented by `r2d2`-based database adapters.
11///
12/// # Provided Implementations
13///
14/// Implementations of `Poolable` are provided for the following types:
15///
16///   * [`diesel::MysqlConnection`](diesel::MysqlConnection)
17///   * [`diesel::PgConnection`](diesel::PgConnection)
18///   * [`diesel::SqliteConnection`](diesel::SqliteConnection)
19///   * [`postgres::Client`](postgres::Client)
20///   * [`rusqlite::Connection`](rusqlite::Connection)
21///   * [`memcache::Client`](memcache::Client)
22///
23/// # Implementation Guide
24///
25/// As an r2d2-compatible database (or other resource) adapter provider,
26/// implementing `Poolable` in your own library will enable Rocket users to
27/// consume your adapter with its built-in connection pooling support.
28///
29/// ## Example
30///
31/// Consider a library `foo` with the following types:
32///
33///   * `foo::ConnectionManager`, which implements [`r2d2::ManageConnection`]
34///   * `foo::Connection`, the `Connection` associated type of
35///     `foo::ConnectionManager`
36///   * `foo::Error`, errors resulting from manager instantiation
37///
38/// In order for Rocket to generate the required code to automatically provision
39/// a r2d2 connection pool into application state, the `Poolable` trait needs to
40/// be implemented for the connection type. The following example implements
41/// `Poolable` for `foo::Connection`:
42///
43/// ```rust
44/// # extern crate rocket_sync_db_pools_community as rocket_sync_db_pools;
45/// # mod foo {
46/// #     use std::fmt;
47/// #     use rocket_sync_db_pools::r2d2;
48/// #     #[derive(Debug)] pub struct Error;
49/// #     impl std::error::Error for Error {  }
50/// #     impl fmt::Display for Error {
51/// #         fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { Ok(()) }
52/// #     }
53/// #
54/// #     pub struct Connection;
55/// #     pub struct ConnectionManager;
56/// #
57/// #     type Result<T> = std::result::Result<T, Error>;
58/// #
59/// #     impl ConnectionManager {
60/// #         pub fn new(url: &str) -> Result<Self> { Err(Error) }
61/// #     }
62/// #
63/// #     impl self::r2d2::ManageConnection for ConnectionManager {
64/// #          type Connection = Connection;
65/// #          type Error = Error;
66/// #          fn connect(&self) -> Result<Connection> { panic!() }
67/// #          fn is_valid(&self, _: &mut Connection) -> Result<()> { panic!() }
68/// #          fn has_broken(&self, _: &mut Connection) -> bool { panic!() }
69/// #     }
70/// # }
71/// use std::time::Duration;
72/// use rocket::{Rocket, Build};
73/// use rocket_sync_db_pools::{r2d2, Error, Config, Poolable, PoolResult};
74///
75/// impl Poolable for foo::Connection {
76///     type Manager = foo::ConnectionManager;
77///     type Error = foo::Error;
78///
79///     fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
80///         let config = Config::from(db_name, rocket)?;
81///         let manager = foo::ConnectionManager::new(&config.url).map_err(Error::Custom)?;
82///         Ok(r2d2::Pool::builder()
83///             .max_size(config.pool_size)
84///             .connection_timeout(Duration::from_secs(config.timeout as u64))
85///             .build(manager)?)
86///     }
87/// }
88/// ```
89///
90/// In this example, `ConnectionManager::new()` method returns a `foo::Error` on
91/// failure. The [`Error`] enum consolidates this type, the `r2d2::Error` type
92/// that can result from `r2d2::Pool::builder()`, and the
93/// [`figment::Error`](rocket::figment::Error) type from
94/// `database::Config::from()`.
95///
96/// In the event that a connection manager isn't fallible (as is the case with
97/// Diesel's r2d2 connection manager, for instance), the associated error type
98/// for the `Poolable` implementation should be `std::convert::Infallible`.
99///
100/// For more concrete example, consult Rocket's existing implementations of
101/// [`Poolable`].
102pub trait Poolable: Send + Sized + 'static {
103    /// The associated connection manager for the given connection type.
104    type Manager: ManageConnection<Connection = Self>;
105
106    /// The associated error type in the event that constructing the connection
107    /// manager and/or the connection pool fails.
108    type Error: std::fmt::Debug;
109
110    /// Creates an `r2d2` connection pool for `Manager::Connection`, returning
111    /// the pool on success.
112    #[allow(clippy::result_large_err)]
113    fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self>;
114}
115
116/// A type alias for the return type of [`Poolable::pool()`].
117#[allow(type_alias_bounds)]
118pub type PoolResult<P: Poolable> = Result<r2d2::Pool<P::Manager>, Error<P::Error>>;
119
#[cfg(feature = "diesel_sqlite_pool")]
impl Poolable for diesel::SqliteConnection {
    type Manager = diesel::r2d2::ConnectionManager<diesel::SqliteConnection>;
    type Error = std::convert::Infallible;

    /// Builds a pool of Diesel SQLite connections for `db_name`.
    ///
    /// A connection customizer runs a fixed PRAGMA batch in `on_acquire` for
    /// every connection the pool establishes: WAL journaling, a 5000 ms busy
    /// timeout, and foreign-key enforcement.
    fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
        use diesel::connection::SimpleConnection;
        use diesel::r2d2::{ConnectionManager, CustomizeConnection, Error, Pool};
        use diesel::SqliteConnection;

        // Executed as one batch; the `\` continuations keep it a single line
        // with no embedded newlines or indentation.
        const PRAGMAS: &str = "\
            PRAGMA journal_mode = WAL;\
            PRAGMA busy_timeout = 5000;\
            PRAGMA foreign_keys = ON;\
        ";

        // r2d2 requires customizers to be `Debug`.
        #[derive(Debug)]
        struct PragmaCustomizer;

        impl CustomizeConnection<SqliteConnection, Error> for PragmaCustomizer {
            fn on_acquire(&self, conn: &mut SqliteConnection) -> Result<(), Error> {
                conn.batch_execute(PRAGMAS).map_err(Error::QueryError)
            }
        }

        let config = Config::from(db_name, rocket)?;
        let timeout = Duration::from_secs(config.timeout as u64);
        let pool = Pool::builder()
            .connection_customizer(Box::new(PragmaCustomizer))
            .max_size(config.pool_size)
            .connection_timeout(timeout)
            .build(ConnectionManager::new(&config.url))?;

        Ok(pool)
    }
}
158
#[cfg(feature = "diesel_postgres_pool")]
impl Poolable for diesel::PgConnection {
    type Manager = diesel::r2d2::ConnectionManager<diesel::PgConnection>;
    type Error = std::convert::Infallible;

    /// Builds a pool of Diesel PostgreSQL connections for `db_name`, using the
    /// URL, pool size, and connection timeout from its [`Config`].
    fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
        let config = Config::from(db_name, rocket)?;
        let manager = diesel::r2d2::ConnectionManager::<diesel::PgConnection>::new(&config.url);
        let timeout = Duration::from_secs(config.timeout as u64);

        Ok(r2d2::Pool::builder()
            .max_size(config.pool_size)
            .connection_timeout(timeout)
            .build(manager)?)
    }
}
175
#[cfg(feature = "diesel_mysql_pool")]
impl Poolable for diesel::MysqlConnection {
    type Manager = diesel::r2d2::ConnectionManager<diesel::MysqlConnection>;
    type Error = std::convert::Infallible;

    /// Builds a pool of Diesel MySQL connections for `db_name`, using the URL,
    /// pool size, and connection timeout from its [`Config`].
    fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
        let config = Config::from(db_name, rocket)?;
        let manager =
            diesel::r2d2::ConnectionManager::<diesel::MysqlConnection>::new(&config.url);
        let timeout = Duration::from_secs(config.timeout as u64);

        Ok(r2d2::Pool::builder()
            .max_size(config.pool_size)
            .connection_timeout(timeout)
            .build(manager)?)
    }
}
192
// TODO: Add a feature to enable TLS in `postgres`; parse a suitable `config`.
#[cfg(feature = "postgres_pool")]
impl Poolable for postgres::Client {
    type Manager = r2d2_postgres::PostgresConnectionManager<postgres::tls::NoTls>;
    type Error = postgres::Error;

    /// Builds a pool of `postgres` clients for `db_name`.
    ///
    /// The configured URL is parsed into a `postgres` connection config; a
    /// parse failure is reported as [`Error::Custom`]. Connections are opened
    /// without TLS (`NoTls`).
    fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
        let config = Config::from(db_name, rocket)?;
        let timeout = Duration::from_secs(config.timeout as u64);
        let pg_config = config.url.parse().map_err(Error::Custom)?;
        let manager =
            r2d2_postgres::PostgresConnectionManager::new(pg_config, postgres::tls::NoTls);

        Ok(r2d2::Pool::builder()
            .max_size(config.pool_size)
            .connection_timeout(timeout)
            .build(manager)?)
    }
}
211
#[cfg(feature = "sqlite_pool")]
impl Poolable for rusqlite::Connection {
    type Manager = r2d2_sqlite::SqliteConnectionManager;
    type Error = std::convert::Infallible;

    /// Builds an `r2d2` pool of `rusqlite` connections for `db_name`.
    ///
    /// Besides the common [`Config`] keys, an optional `open_flags` list
    /// (snake_case names such as `"read_only"` or `"shared_cache"`) is read
    /// from the database's configuration; it defaults to an empty list.
    ///
    /// The listed flags are added to `rusqlite::OpenFlags::default()`
    /// (`READ_WRITE | CREATE | URI | NO_MUTEX`). Where a listed flag
    /// contradicts one of those defaults, the default is dropped so that the
    /// configured flag actually takes effect:
    ///
    ///   * `read_only` drops the default `READ_WRITE | CREATE`. SQLite rejects
    ///     `READ_ONLY | READ_WRITE` with `SQLITE_MISUSE`, so keeping the
    ///     defaults would make every connection attempt fail.
    ///   * `full_mutex` drops the default `NO_MUTEX`. SQLite checks `NO_MUTEX`
    ///     first, so keeping it would silently ignore `full_mutex`.
    fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
        use rocket::figment::providers::Serialized;
        use rusqlite::OpenFlags;

        // Configuration-facing names of the supported SQLite open flags.
        #[derive(Debug, serde::Deserialize, serde::Serialize)]
        #[serde(rename_all = "snake_case")]
        enum OpenFlag {
            ReadOnly,
            ReadWrite,
            Create,
            Uri,
            Memory,
            NoMutex,
            FullMutex,
            SharedCache,
            PrivateCache,
            Nofollow,
        }

        let figment = Config::figment(db_name, rocket);
        let config: Config = figment.extract()?;
        let open_flags: Vec<OpenFlag> = figment
            .join(Serialized::default("open_flags", <Vec<OpenFlag>>::new()))
            .extract_inner("open_flags")?;

        // Union of the flags explicitly requested in the configuration.
        let requested = open_flags.into_iter().fold(OpenFlags::empty(), |acc, flag| {
            acc | match flag {
                OpenFlag::ReadOnly => OpenFlags::SQLITE_OPEN_READ_ONLY,
                OpenFlag::ReadWrite => OpenFlags::SQLITE_OPEN_READ_WRITE,
                OpenFlag::Create => OpenFlags::SQLITE_OPEN_CREATE,
                OpenFlag::Uri => OpenFlags::SQLITE_OPEN_URI,
                OpenFlag::Memory => OpenFlags::SQLITE_OPEN_MEMORY,
                OpenFlag::NoMutex => OpenFlags::SQLITE_OPEN_NO_MUTEX,
                OpenFlag::FullMutex => OpenFlags::SQLITE_OPEN_FULL_MUTEX,
                OpenFlag::SharedCache => OpenFlags::SQLITE_OPEN_SHARED_CACHE,
                OpenFlag::PrivateCache => OpenFlags::SQLITE_OPEN_PRIVATE_CACHE,
                OpenFlag::Nofollow => OpenFlags::SQLITE_OPEN_NOFOLLOW,
            }
        });

        // Remove only the defaults that would contradict an explicit request;
        // every other configuration keeps exactly the previous flag set.
        let mut defaults = OpenFlags::default();
        if requested.contains(OpenFlags::SQLITE_OPEN_READ_ONLY) {
            defaults.remove(OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE);
        }
        if requested.contains(OpenFlags::SQLITE_OPEN_FULL_MUTEX) {
            defaults.remove(OpenFlags::SQLITE_OPEN_NO_MUTEX);
        }
        let flags = defaults | requested;

        let manager = r2d2_sqlite::SqliteConnectionManager::file(&*config.url).with_flags(flags);

        let pool = r2d2::Pool::builder()
            .max_size(config.pool_size)
            .connection_timeout(Duration::from_secs(config.timeout as u64))
            .build(manager)?;

        Ok(pool)
    }
}
269
#[cfg(feature = "memcache_pool")]
mod memcache_pool {
    use memcache::{Client, Connectable, MemcacheError};

    use super::*;

    /// An `r2d2` connection manager producing [`memcache::Client`]s.
    ///
    /// The manager stores the server URL list it was created with and opens a
    /// new client against that list for every connection the pool creates.
    #[derive(Debug)]
    pub struct ConnectionManager {
        // Server URLs passed to `Client::connect` each time a connection is made.
        urls: Vec<String>,
    }

    impl ConnectionManager {
        /// Creates a manager for the server(s) described by `target`.
        pub fn new<C: Connectable>(target: C) -> Self {
            let urls = target.get_urls();
            ConnectionManager { urls }
        }
    }

    impl r2d2::ManageConnection for ConnectionManager {
        type Connection = Client;
        type Error = MemcacheError;

        /// Opens a new client against the stored URL list.
        fn connect(&self) -> Result<Client, MemcacheError> {
            let urls = self.urls.clone();
            Client::connect(urls)
        }

        /// Checks liveness with a `version` request, discarding the reply.
        fn is_valid(&self, connection: &mut Client) -> Result<(), MemcacheError> {
            connection.version()?;
            Ok(())
        }

        /// Never reports a connection as broken; liveness is left to `is_valid`.
        fn has_broken(&self, _connection: &mut Client) -> bool {
            false
        }
    }

    impl super::Poolable for memcache::Client {
        type Manager = ConnectionManager;
        type Error = MemcacheError;

        /// Builds a pool of memcache clients for `db_name`; the configured URL
        /// is handed to [`ConnectionManager::new`].
        fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
            let config = Config::from(db_name, rocket)?;
            let timeout = Duration::from_secs(config.timeout as u64);
            let builder = r2d2::Pool::builder()
                .max_size(config.pool_size)
                .connection_timeout(timeout);

            Ok(builder.build(ConnectionManager::new(&*config.url))?)
        }
    }
}