use std::marker::PhantomData;
use std::ops::DerefMut;
use rocket::fairing::{self, Fairing, Info, Kind};
use rocket::http::Status;
use rocket::request::{FromRequest, Outcome, Request};
use rocket::{error, info_, Build, Ignite, Phase, Rocket, Sentinel};
use rocket::figment::providers::Serialized;
use rocket::yansi::Paint;
#[cfg(feature = "rocket_okapi")]
use rocket_okapi::{
gen::OpenApiGenerator,
request::{OpenApiFromRequest, RequestHeaderInput},
};
use crate::Pool;
/// A database whose connection pool lives in Rocket's managed state.
///
/// Implementors are constructible from their pool (`From<Self::Pool>`) and
/// dereference (mutably) to it, so pool methods are reachable directly on the
/// database value.
pub trait Database:
    From<Self::Pool> + DerefMut<Target = Self::Pool> + Send + Sync + 'static
{
    /// The pool type wrapped by this database.
    type Pool: Pool;

    /// Name of this database's configuration section; its settings are read
    /// from `databases.<NAME>` in the Rocket figment.
    const NAME: &'static str;

    /// Returns the fairing that initializes this database on ignite.
    ///
    /// The fairing has no explicit name, so it reports its type name.
    fn init() -> Initializer<Self> {
        Initializer::new()
    }

    /// Looks up this database in `rocket`'s managed state.
    ///
    /// Returns `None` when no value of type `Self` is managed. In that case it
    /// first logs an error naming the type and a hint that the `init()`
    /// fairing must be attached.
    fn fetch<P: Phase>(rocket: &Rocket<P>) -> Option<&Self> {
        let found = rocket.state::<Self>();

        if found.is_none() {
            let type_name = std::any::type_name::<Self>();
            error!(
                "Attempted to fetch unattached database `{}`.",
                Paint::new(type_name).bold()
            );
            info_!(
                "`{}` fairing must be attached prior to using this database.",
                Paint::new(format!("{type_name}::init()")).bold()
            );
        }

        found
    }
}
/// Fairing that builds `D`'s pool during ignite and adds `D` to managed state.
///
/// Field `0` is an optional name override reported by the fairing's `info`.
/// Field `1` is a type marker. `fn() -> D` is used so the marker is always
/// `Send + Sync` and never owns a `D`.
pub struct Initializer<D>(Option<&'static str>, PhantomData<fn() -> D>)
where
    D: Database;

/// Request guard holding a reference to a connection borrowed from the managed
/// database `D`, valid for the lifetime `'a`.
pub struct Connection<'a, D>(&'a <D::Pool as Pool>::Connection)
where
    D: Database;
impl<D: Database> Initializer<D> {
#[allow(clippy::new_without_default)]
pub fn new() -> Self {
Self(None, std::marker::PhantomData)
}
pub fn with_name(name: &'static str) -> Self {
Self(Some(name), std::marker::PhantomData)
}
}
impl<'a, D: Database> Connection<'a, D> {
pub fn into_inner(self) -> &'a <D::Pool as Pool>::Connection {
self.0
}
}
#[cfg(feature = "rocket_okapi")]
impl<'r, D: Database> OpenApiFromRequest<'r> for Connection<'r, D> {
    /// Adds nothing to the generated OpenAPI request description.
    ///
    /// The guard is satisfied from Rocket's managed state rather than from
    /// anything the client sends.
    fn from_request_input(
        _: &mut OpenApiGenerator,
        _: String,
        _: bool,
    ) -> rocket_okapi::Result<RequestHeaderInput> {
        let input = RequestHeaderInput::None;
        Ok(input)
    }
}
#[rocket::async_trait]
impl<D: Database> Fairing for Initializer<D> {
    /// Describes this fairing as an ignite fairing.
    ///
    /// The name is the one given to [`Initializer::with_name`], or the
    /// initializer's type name when none was given.
    fn info(&self) -> Info {
        Info {
            name: self.0.unwrap_or_else(std::any::type_name::<Self>),
            kind: Kind::Ignite,
        }
    }

    /// Builds `D`'s pool from the `databases.<D::NAME>` configuration and
    /// places the resulting `D` in managed state.
    ///
    /// When the configuration does not set them, these defaults are used:
    /// - `max_connections`: four per configured worker (saturating),
    /// - `connect_timeout`: `5`,
    /// - `sqlx_logging`: `true`.
    ///
    /// # Errors
    ///
    /// If the pool fails to initialize, logs the error and returns
    /// `Err(rocket)`, signalling that ignition should fail.
    async fn on_ignite(&self, rocket: Rocket<Build>) -> fairing::Result {
        // Fall back to Rocket's default worker count when `workers` is absent
        // or cannot be read as a `usize`.
        let workers: usize = rocket
            .figment()
            .extract_inner(rocket::Config::WORKERS)
            .unwrap_or_else(|_| rocket::Config::default().workers);

        // `workers` is user-configurable; saturate rather than overflow
        // (overflow panics in debug builds and wraps silently in release).
        let default_max_connections = workers.saturating_mul(4);

        // `Serialized::default` emits into the `default` profile. `join` keeps
        // values already present, so these act as real defaults. `merge` would
        // let them overwrite user settings in that same profile, e.g.
        // `[default.databases.<NAME>] sqlx_logging = false`.
        let figment = rocket
            .figment()
            .focus(&format!("databases.{}", D::NAME))
            .join(Serialized::default("max_connections", default_max_connections))
            .join(Serialized::default("connect_timeout", 5))
            .join(Serialized::default("sqlx_logging", true));

        match <D::Pool>::init(&figment).await {
            Ok(pool) => Ok(rocket.manage(D::from(pool))),
            Err(e) => {
                error!("failed to initialize database: {}", e);
                Err(rocket)
            }
        }
    }
}
#[rocket::async_trait]
impl<'r, D: Database> FromRequest<'r> for Connection<'r, D> {
type Error = Option<<D::Pool as Pool>::Error>;
async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
match D::fetch(req.rocket()) {
Some(pool) => Outcome::Success(Connection(pool.borrow())),
None => Outcome::Error((Status::InternalServerError, None)),
}
}
}
impl<D: Database> Sentinel for Connection<'_, D> {
    /// Requests a launch abort when `D` is not in managed state.
    ///
    /// This is typically because its `init()` fairing was not attached.
    fn abort(rocket: &Rocket<Ignite>) -> bool {
        let attached = D::fetch(rocket).is_some();
        !attached
    }
}