use rocket::figment::{Figment, providers::Serialized};
use rocket_db_pools::{Database, Pool};
use rocket::request::{FromRequest, Request, Outcome};
use std::ops::{Deref, DerefMut};
use rocket::{Ignite, Rocket, Sentinel};
use rocket::http::Status;
use rocket::async_trait;
/// Extension of [`Pool`] for pools that can hand out connections intended for
/// read-only queries.
///
/// This trait is public because it appears in the bounds of public items (the
/// `ReadConnection` request guard and `ReadConnection::from_pool`). If it were
/// private, downstream code could not name those bounds or call `get_read`.
#[async_trait]
pub trait PoolRead: Pool {
    /// Retrieves a connection intended for read-only queries.
    ///
    /// # Errors
    ///
    /// Returns the pool's error type if no connection could be obtained.
    async fn get_read(&self) -> Result<Self::Connection, Self::Error>;
}
/// A `Pool` wrapper that pairs a primary pool with an optional read pool.
///
/// `get` always uses `main`. `get_read` uses `read` when it was configured and
/// falls back to `main` otherwise.
#[derive(Debug, Clone)]
pub struct ReadPool<P>{
/// Primary pool; every read-write connection comes from here.
main: P,
/// Pool built from the `read` sub-table of the database configuration, if any.
read: Option<P>,
}
/// Builds the configuration figment for the read pool.
///
/// The `read` sub-table is focused so that its keys (`url`, `max_connections`,
/// ...) sit at the top level, which is where `P::init` looks for them. Required
/// settings that the sub-table omits are inherited from the primary
/// configuration. `connect_timeout` finally falls back to 5 (seconds, per
/// rocket_db_pools' `Config`).
fn read_pool_figment(figment: &Figment) -> Figment {
    let mut read = figment.focus("read");
    // Keys are relative after `focus`. Defaults must therefore NOT be prefixed
    // with `read.`: a `read.connect_timeout` default would end up at
    // `read.read.connect_timeout` and never be seen by `P::init`.
    for key in ["max_connections", "connect_timeout"] {
        if let Ok(value) = figment.find_value(key) {
            read = read.join(Serialized::default(key, value));
        }
    }
    read.join(Serialized::default("connect_timeout", 5))
}

#[rocket::async_trait]
impl<P> Pool for ReadPool<P>
where
    P: Pool,
{
    type Error = P::Error;
    type Connection = P::Connection;

    /// Initializes the primary pool from `figment`. When a `read` table is
    /// present, also initializes a second pool from that table.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `P::init` for either pool.
    async fn init(figment: &Figment) -> Result<Self, Self::Error> {
        let main = P::init(figment).await?;
        let read = if figment.contains("read") {
            Some(P::init(&read_pool_figment(figment)).await?)
        } else {
            None
        };
        Ok(ReadPool { main, read })
    }

    /// Read-write connections always come from the primary pool.
    async fn get(&self) -> Result<Self::Connection, Self::Error> {
        self.main.get().await
    }

    /// Closes the primary pool and, if one was configured, the read pool.
    async fn close(&self) {
        self.main.close().await;
        if let Some(read) = &self.read {
            read.close().await;
        }
    }
}
#[async_trait]
impl<P> PoolRead for ReadPool<P>
where
    P: Pool,
{
    /// Hands out a connection from the read pool when one was configured,
    /// otherwise from the primary pool.
    async fn get_read(&self) -> Result<P::Connection, P::Error> {
        match &self.read {
            Some(replica) => replica.get().await,
            None => self.main.get().await,
        }
    }
}
/// Request guard holding a connection obtained through `PoolRead::get_read`.
///
/// Dereferences to the underlying pool connection; use `into_inner` to take it.
pub struct ReadConnection<D: Database>(<D::Pool as Pool>::Connection);
impl<D: Database> ReadConnection<D> {
pub fn into_inner(self) -> <D::Pool as Pool>::Connection {
self.0
}
}
#[rocket::async_trait]
impl<'r, D: Database> FromRequest<'r> for ReadConnection<D> where D::Pool: PoolRead {
type Error = Option<<D::Pool as Pool>::Error>;
async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
match D::fetch(req.rocket()) {
Some(db) => match db.get_read().await {
Ok(conn) => Outcome::Success(ReadConnection(conn)),
Err(e) => Outcome::Error((Status::ServiceUnavailable, Some(e))),
},
None => Outcome::Error((Status::InternalServerError, None)),
}
}
}
impl<D: Database> Sentinel for ReadConnection<D> {
    /// Requests a launch abort when database `D` is not managed by this Rocket
    /// instance, because the guard would then always fail.
    fn abort(rocket: &Rocket<Ignite>) -> bool {
        let db = D::fetch(rocket);
        db.is_none()
    }
}
impl<D: Database> Deref for ReadConnection<D> {
type Target = <D::Pool as Pool>::Connection;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<D: Database> DerefMut for ReadConnection<D> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
// Identity conversions. They mirror the identically named methods on
// `RwConnection`, so either guard can be turned into a `ReadConnection` the
// same way.
impl<D: Database> ReadConnection<D> {
/// Returns `self` unchanged.
pub fn into_read_connection(self) -> ReadConnection<D>{
self
}
/// Returns `self` unchanged.
pub fn as_read_connection(&self) -> &ReadConnection<D>{
self
}
/// Returns `self` unchanged.
pub fn as_read_connection_mut(&mut self) -> &mut ReadConnection<D>{
self
}
}
#[allow(private_bounds)]
impl<D: Database> ReadConnection<D> where D::Pool: PoolRead {
pub async fn from_pool(pool: &D) -> Result<Self, <D::Pool as Pool>::Error>{
Ok(ReadConnection(pool.get_read().await?))
}
}
/// Request guard holding a connection obtained through `Pool::get`.
///
/// Wraps a `ReadConnection` so it can be downgraded to, or borrowed as, one.
/// Dereferences to the underlying pool connection.
pub struct RwConnection<D: Database>(ReadConnection<D>);
impl<D: Database> RwConnection<D> {
pub fn into_inner(self) -> <D::Pool as Pool>::Connection {
self.0.0
}
pub fn into_read_connection(self) -> ReadConnection<D>{
self.0
}
pub fn as_read_connection(&self) -> &ReadConnection<D>{
&self.0
}
pub fn as_read_connection_mut(&mut self) -> &mut ReadConnection<D>{
&mut self.0
}
}
impl<D: Database> RwConnection<D>{
pub async fn from_pool(pool: &D) -> Result<Self, <D::Pool as Pool>::Error>{
Ok(RwConnection(ReadConnection(pool.get().await?)))
}
}
#[rocket::async_trait]
impl<'r, D: Database> FromRequest<'r> for RwConnection<D> {
type Error = Option<<D::Pool as Pool>::Error>;
async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
match D::fetch(req.rocket()) {
Some(db) => match db.get().await {
Ok(conn) => Outcome::Success(RwConnection(ReadConnection(conn))),
Err(e) => Outcome::Error((Status::ServiceUnavailable, Some(e))),
},
None => Outcome::Error((Status::InternalServerError, None)),
}
}
}
impl<D: Database> Sentinel for RwConnection<D> {
    /// Requests a launch abort when database `D` is not managed by this Rocket
    /// instance, because the guard would then always fail.
    fn abort(rocket: &Rocket<Ignite>) -> bool {
        let db = D::fetch(rocket);
        db.is_none()
    }
}
impl<D: Database> Deref for RwConnection<D> {
type Target = <D::Pool as Pool>::Connection;
fn deref(&self) -> &Self::Target {
&self.0.0
}
}
impl<D: Database> DerefMut for RwConnection<D> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0.0
}
}
#[cfg(feature = "rocket_okapi")]
mod okapi {
    //! OpenAPI integration for the connection guards.
    //!
    //! Neither guard reads headers, query parameters, or the body; they only
    //! consult the managed database. Both therefore report no request input.

    use super::*;
    use rocket_okapi::{
        gen::OpenApiGenerator,
        request::{OpenApiFromRequest, RequestHeaderInput},
        OpenApiError,
    };

    impl<'r, D: Database> OpenApiFromRequest<'r> for RwConnection<D> {
        fn from_request_input(
            _gen: &mut OpenApiGenerator,
            _name: String,
            _required: bool,
        ) -> Result<RequestHeaderInput, OpenApiError> {
            Ok(RequestHeaderInput::None)
        }
    }

    impl<'r, D: Database> OpenApiFromRequest<'r> for ReadConnection<D>
    where
        D::Pool: PoolRead,
    {
        fn from_request_input(
            _gen: &mut OpenApiGenerator,
            _name: String,
            _required: bool,
        ) -> Result<RequestHeaderInput, OpenApiError> {
            Ok(RequestHeaderInput::None)
        }
    }
}