// rocket_db_pools_community/database.rs

1use std::marker::PhantomData;
2use std::ops::{Deref, DerefMut};
3
4use rocket::fairing::{self, Fairing, Info, Kind};
5use rocket::figment::providers::Serialized;
6use rocket::http::Status;
7use rocket::request::{FromRequest, Outcome, Request};
8use rocket::{error, Build, Ignite, Orbit, Phase, Rocket, Sentinel};
9
10use crate::Pool;
11
12/// Derivable trait which ties a database [`Pool`] with a configuration name.
13///
14/// This trait should rarely, if ever, be implemented manually. Instead, it
15/// should be derived:
16///
17/// ```rust
18/// # extern crate rocket_db_pools_community as rocket_db_pools;
19/// # #[cfg(feature = "deadpool_redis")] mod _inner {
20/// # use rocket::launch;
21/// use rocket_db_pools::{deadpool_redis, Database};
22///
23/// #[derive(Database)]
24/// #[database("memdb")]
25/// struct Db(deadpool_redis::Pool);
26///
27/// #[launch]
28/// fn rocket() -> _ {
29///     rocket::build().attach(Db::init())
30/// }
31/// # }
32/// ```
33///
34/// See the [`Database` derive](derive@crate::Database) for details.
35pub trait Database:
36    From<Self::Pool> + DerefMut<Target = Self::Pool> + Send + Sync + 'static
37{
38    /// The [`Pool`] type of connections to this database.
39    ///
40    /// When `Database` is derived, this takes the value of the `Inner` type in
41    /// `struct Db(Inner)`.
42    type Pool: Pool;
43
44    /// The configuration name for this database.
45    ///
46    /// When `Database` is derived, this takes the value `"name"` in the
47    /// `#[database("name")]` attribute.
48    const NAME: &'static str;
49
50    /// Returns a fairing that initializes the database and its connection pool.
51    ///
52    /// # Example
53    ///
54    /// ```rust
55    /// # extern crate rocket_db_pools_community as rocket_db_pools;
56    /// # #[cfg(feature = "deadpool_postgres")] mod _inner {
57    /// # use rocket::launch;
58    /// use rocket_db_pools::{deadpool_postgres, Database};
59    ///
60    /// #[derive(Database)]
61    /// #[database("pg_db")]
62    /// struct Db(deadpool_postgres::Pool);
63    ///
64    /// #[launch]
65    /// fn rocket() -> _ {
66    ///     rocket::build().attach(Db::init())
67    /// }
68    /// # }
69    /// ```
70    fn init() -> Initializer<Self> {
71        Initializer::new()
72    }
73
74    /// Returns a reference to the initialized database in `rocket`. The
75    /// initializer fairing returned by `init()` must have already executed for
76    /// `Option` to be `Some`. This is guaranteed to be the case if the fairing
77    /// is attached and either:
78    ///
79    ///   * Rocket is in the [`Orbit`](rocket::Orbit) phase. That is, the
80    ///     application is running. This is always the case in request guards
81    ///     and liftoff fairings,
82    ///   * _or_ Rocket is in the [`Build`](rocket::Build) or
83    ///     [`Ignite`](rocket::Ignite) phase and the `Initializer` fairing has
84    ///     already been run. This is the case in all fairing callbacks
85    ///     corresponding to fairings attached _after_ the `Initializer`
86    ///     fairing.
87    ///
88    /// # Example
89    ///
90    /// Run database migrations in an ignite fairing. It is imperative that the
91    /// migration fairing be registered _after_ the `init()` fairing.
92    ///
93    /// ```rust
94    /// # extern crate rocket_db_pools_community as rocket_db_pools;
95    /// # #[cfg(feature = "sqlx_sqlite")] mod _inner {
96    /// # use rocket::launch;
97    /// use rocket::{Rocket, Build};
98    /// use rocket::fairing::{self, AdHoc};
99    ///
100    /// use rocket_db_pools::{sqlx, Database};
101    ///
102    /// #[derive(Database)]
103    /// #[database("sqlite_db")]
104    /// struct Db(sqlx::SqlitePool);
105    ///
106    /// async fn run_migrations(rocket: Rocket<Build>) -> fairing::Result {
107    ///     if let Some(db) = Db::fetch(&rocket) {
108    ///         // run migrations using `db`. get the inner type with &db.0.
109    ///         Ok(rocket)
110    ///     } else {
111    ///         Err(rocket)
112    ///     }
113    /// }
114    ///
115    /// #[launch]
116    /// fn rocket() -> _ {
117    ///     rocket::build()
118    ///         .attach(Db::init())
119    ///         .attach(AdHoc::try_on_ignite("DB Migrations", run_migrations))
120    /// }
121    /// # }
122    /// ```
123    fn fetch<P: Phase>(rocket: &Rocket<P>) -> Option<&Self> {
124        if let Some(db) = rocket.state() {
125            return Some(db);
126        }
127
128        let conn = std::any::type_name::<Self>();
129        error!(
130            "`{conn}::init()` is not attached\n\
131            the fairing must be attached to use `{conn}` in routes."
132        );
133
134        None
135    }
136}
137
138/// A [`Fairing`] which initializes a [`Database`] and its connection pool.
139///
140/// A value of this type can be created for any type `D` that implements
141/// [`Database`] via the [`Database::init()`] method on the type. Normally, a
142/// value of this type _never_ needs to be constructed directly. This
143/// documentation exists purely as a reference.
144///
145/// This fairing initializes a database pool. Specifically, it:
146///
147///   1. Reads the configuration at `database.db_name`, where `db_name` is
148///      [`Database::NAME`].
149///
150///   2. Sets [`Config`](crate::Config) defaults on the configuration figment.
151///
152///   3. Calls [`Pool::init()`].
153///
154///   4. Stores the database instance in managed storage, retrievable via
155///      [`Database::fetch()`].
156///
157/// The name of the fairing itself is `Initializer<D>`, with `D` replaced with
158/// the type name `D` unless a name is explicitly provided via
159/// [`Self::with_name()`].
160pub struct Initializer<D: Database>(Option<&'static str>, PhantomData<fn() -> D>);
161
162/// A request guard which retrieves a single connection to a [`Database`].
163///
164/// For a database type of `Db`, a request guard of `Connection<Db>` retrieves a
165/// single connection to `Db`.
166///
167/// The request guard succeeds if the database was initialized by the
168/// [`Initializer`] fairing and a connection is available within
169/// [`connect_timeout`](crate::Config::connect_timeout) seconds.
170///   * If the `Initializer` fairing was _not_ attached, the guard _fails_ with
171///     status `InternalServerError`. A [`Sentinel`] guards this condition, and so
172///     this type of error is unlikely to occur. A `None` error is returned.
173///   * If a connection is not available within `connect_timeout` seconds or
174///     another error occurs, the guard _fails_ with status `ServiceUnavailable`
175///     and the error is returned in `Some`.
176///
177/// ## Deref
178///
179/// A type of `Connection<Db>` dereferences, mutably and immutably, to the
180/// native database connection type. The [driver table](crate#supported-drivers)
181/// lists the concrete native `Deref` types.
182///
183/// # Example
184///
185/// ```rust
186/// # extern crate rocket_db_pools_community as rocket_db_pools;
187/// # #[cfg(feature = "sqlx_sqlite")] mod _inner {
188/// # use rocket::get;
189/// # type Pool = rocket_db_pools::sqlx::SqlitePool;
190/// use rocket_db_pools::{Database, Connection};
191///
192/// #[derive(Database)]
193/// #[database("db")]
194/// struct Db(Pool);
195///
196/// #[get("/")]
197/// async fn db_op(db: Connection<Db>) {
198///     // use `&*db` to get an immutable borrow to the native connection type
199///     // use `&mut *db` to get a mutable borrow to the native connection type
200/// }
201/// # }
202/// ```
203pub struct Connection<D: Database>(<D::Pool as Pool>::Connection);
204
205impl<D: Database> Initializer<D> {
206    /// Returns a database initializer fairing for `D`.
207    ///
208    /// This method should never need to be called manually. See the [crate
209    /// docs](crate) for usage information.
210    pub fn new() -> Self {
211        Self(None, std::marker::PhantomData)
212    }
213
214    /// Returns a database initializer fairing for `D` with name `name`.
215    ///
216    /// This method should never need to be called manually. See the [crate
217    /// docs](crate) for usage information.
218    pub fn with_name(name: &'static str) -> Self {
219        Self(Some(name), std::marker::PhantomData)
220    }
221}
222
223impl<D: Database> Default for Initializer<D> {
224    fn default() -> Self {
225        Self::new()
226    }
227}
228
229impl<D: Database> Connection<D> {
230    /// Returns the internal connection value. See the [`Connection` Deref
231    /// column](crate#supported-drivers) for the expected type of this value.
232    ///
233    /// Note that `Connection<D>` derefs to the internal connection type, so
234    /// using this method is likely unnecessary. See [deref](Connection#deref)
235    /// for examples.
236    ///
237    /// # Example
238    ///
239    /// ```rust
240    /// # extern crate rocket_db_pools_community as rocket_db_pools;
241    /// # #[cfg(feature = "sqlx_sqlite")] mod _inner {
242    /// # use rocket::get;
243    /// # type Pool = rocket_db_pools::sqlx::SqlitePool;
244    /// use rocket_db_pools::{Database, Connection};
245    ///
246    /// #[derive(Database)]
247    /// #[database("db")]
248    /// struct Db(Pool);
249    ///
250    /// #[get("/")]
251    /// async fn db_op(db: Connection<Db>) {
252    ///     let inner = db.into_inner();
253    /// }
254    /// # }
255    /// ```
256    pub fn into_inner(self) -> <D::Pool as Pool>::Connection {
257        self.0
258    }
259}
260
261#[rocket::async_trait]
262impl<D: Database> Fairing for Initializer<D> {
263    fn info(&self) -> Info {
264        Info {
265            name: self.0.unwrap_or(std::any::type_name::<Self>()),
266            kind: Kind::Ignite | Kind::Shutdown,
267        }
268    }
269
270    async fn on_ignite(&self, rocket: Rocket<Build>) -> fairing::Result {
271        let workers: usize = rocket
272            .figment()
273            .extract_inner(rocket::Config::WORKERS)
274            .unwrap_or_else(|_| rocket::Config::default().workers);
275
276        let figment = rocket
277            .figment()
278            .focus(&format!("databases.{}", D::NAME))
279            .join(Serialized::default("max_connections", workers * 4))
280            .join(Serialized::default("connect_timeout", 5));
281
282        match <D::Pool>::init(&figment).await {
283            Ok(pool) => Ok(rocket.manage(D::from(pool))),
284            Err(e) => {
285                error!("database initialization failed: {e}");
286                Err(rocket)
287            }
288        }
289    }
290
291    async fn on_shutdown(&self, rocket: &Rocket<Orbit>) {
292        if let Some(db) = D::fetch(rocket) {
293            db.close().await;
294        }
295    }
296}
297
298#[rocket::async_trait]
299impl<'r, D: Database> FromRequest<'r> for Connection<D> {
300    type Error = Option<<D::Pool as Pool>::Error>;
301
302    async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
303        match D::fetch(req.rocket()) {
304            Some(db) => match db.get().await {
305                Ok(conn) => Outcome::Success(Connection(conn)),
306                Err(e) => Outcome::Error((Status::ServiceUnavailable, Some(e))),
307            },
308            None => Outcome::Error((Status::InternalServerError, None)),
309        }
310    }
311}
312
313impl<D: Database> Sentinel for Connection<D> {
314    fn abort(rocket: &Rocket<Ignite>) -> bool {
315        D::fetch(rocket).is_none()
316    }
317}
318
319impl<D: Database> Deref for Connection<D> {
320    type Target = <D::Pool as Pool>::Connection;
321
322    fn deref(&self) -> &Self::Target {
323        &self.0
324    }
325}
326
327impl<D: Database> DerefMut for Connection<D> {
328    fn deref_mut(&mut self) -> &mut Self::Target {
329        &mut self.0
330    }
331}