#![cfg_attr(feature = "docs", feature(doc_cfg))]
use std::fmt::{self, Debug};
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use async_std::sync::{RwLock, RwLockWriteGuard};
use sqlx::pool::{Pool, PoolConnection};
use sqlx::{Database, Transaction};
use tide::utils::async_trait;
use tide::{http::Method, Middleware, Next, Request, Result};
// The test build is rejected unless the `postgres` feature is enabled.
// NOTE(review): the message asks for `--features=test` while the condition checks
// `postgres`; confirm a `test` feature exists in Cargo.toml and that it enables `postgres`.
#[cfg(all(test, not(feature = "postgres")))]
compile_error!("The tests must be run with --features=test");
// Backend-specific module, compiled (and documented as feature-gated) only with `postgres`.
#[cfg(feature = "postgres")]
#[cfg_attr(feature = "docs", doc(cfg(feature = "postgres")))]
pub mod postgres;
/// The database connection attached to a single request.
///
/// [`SQLxMiddleware`] stores a plain pooled connection for `GET`/`HEAD` requests and an
/// open transaction for every other method. Either variant derefs to `DB::Connection`.
#[doc(hidden)]
pub enum ConnectionWrapInner<DB: Database>
where
    DB::Connection: Send + Sync + 'static,
{
    /// A connection inside an open transaction.
    Transacting(Transaction<'static, DB>),
    /// A connection checked out of the pool, with no transaction.
    Plain(PoolConnection<DB>),
}
impl<DB> Debug for ConnectionWrapInner<DB>
where
    DB: Database,
    DB::Connection: Send + Sync + 'static,
{
    /// Prints only the name of the active variant; the wrapped connection is not shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Plain(_) => "ConnectionWrapInner::Plain",
            Self::Transacting(_) => "ConnectionWrapInner::Transacting",
        };
        f.debug_struct(label).finish()
    }
}
/// Lets the wrapper be used wherever `&DB::Connection` is expected, whichever
/// variant (transaction or plain pooled connection) it currently holds.
impl<DB> Deref for ConnectionWrapInner<DB>
where
    DB: Database,
    DB::Connection: Send + Sync + 'static,
{
    type Target = DB::Connection;

    fn deref(&self) -> &Self::Target {
        // Each arm relies on the inner type's own deref to the raw connection.
        let connection: &DB::Connection = match self {
            Self::Transacting(tx) => tx,
            Self::Plain(pooled) => pooled,
        };
        connection
    }
}
/// Mutable counterpart of the `Deref` impl: exposes `&mut DB::Connection`
/// so queries can be executed through the wrapper.
impl<DB> DerefMut for ConnectionWrapInner<DB>
where
    DB: Database,
    DB::Connection: Send + Sync + 'static,
{
    fn deref_mut(&mut self) -> &mut Self::Target {
        let connection: &mut DB::Connection = match self {
            Self::Transacting(tx) => tx,
            Self::Plain(pooled) => pooled,
        };
        connection
    }
}
/// Shared, lockable handle to a request's [`ConnectionWrapInner`]; this is the value
/// [`SQLxMiddleware`] stores in the request extensions.
#[doc(hidden)]
pub type ConnectionWrap<DB> = Arc<RwLock<ConnectionWrapInner<DB>>>;
/// Tide middleware that attaches an SQLx database connection to every request.
///
/// `GET` and `HEAD` requests receive a plain connection from the pool. All other
/// methods run inside a transaction that is committed once the downstream response
/// carries no error. Handlers reach the connection via [`SQLxRequestExt::sqlx_conn`].
#[derive(Debug)]
pub struct SQLxMiddleware<DB>
where
    DB: Database,
    DB::Connection: Send + Sync + 'static,
{
    /// Pool from which per-request connections and transactions are acquired.
    pool: Pool<DB>,
}

// Implemented by hand: `#[derive(Clone)]` would add a `DB: Clone` bound even though
// only the pool is cloned, and `Pool<DB>` is `Clone` for every `DB: Database`.
impl<DB> Clone for SQLxMiddleware<DB>
where
    DB: Database,
    DB::Connection: Send + Sync + 'static,
{
    fn clone(&self) -> Self {
        Self {
            pool: self.pool.clone(),
        }
    }
}
impl<DB> SQLxMiddleware<DB>
where
    DB: Database,
    DB::Connection: Send + Sync + 'static,
{
    /// Opens a connection pool for the database URL `pgurl` and wraps it in the middleware.
    ///
    /// # Errors
    /// Returns the [`sqlx::Error`] produced by `Pool::connect` if the pool cannot be created.
    pub async fn new(pgurl: &'_ str) -> std::result::Result<Self, sqlx::Error> {
        Pool::<DB>::connect(pgurl)
            .await
            .map(|pool| Self { pool })
    }
}
impl<DB> From<Pool<DB>> for SQLxMiddleware<DB>
where
DB: Database,
DB::Connection: Send + Sync + 'static,
{
fn from(pool: Pool<DB>) -> Self {
Self { pool }
}
}
#[async_trait]
impl<State, DB> Middleware<State> for SQLxMiddleware<DB>
where
State: Clone + Send + Sync + 'static,
DB: Database,
DB::Connection: Send + Sync + 'static,
{
async fn handle(&self, mut req: Request<State>, next: Next<'_, State>) -> Result {
if req.ext::<ConnectionWrap<DB>>().is_some() {
return Ok(next.run(req).await);
}
let is_safe = match req.method() {
Method::Get => true,
Method::Head => true,
_ => false,
};
let conn_wrap_inner = if is_safe {
ConnectionWrapInner::Plain(self.pool.acquire().await?)
} else {
ConnectionWrapInner::Transacting(self.pool.begin().await?)
};
let conn_wrap = Arc::new(RwLock::new(conn_wrap_inner));
req.set_ext(conn_wrap.clone());
let res = next.run(req).await;
if res.error().is_none() {
if let Ok(conn_wrap_inner) = Arc::try_unwrap(conn_wrap) {
if let ConnectionWrapInner::Transacting(connection) = conn_wrap_inner.into_inner() {
connection.commit().await?;
}
} else {
panic!("We have err'd egregiously! Could not unwrap refcounted SQLx connection for COMMIT; handler may be storing connection or request inappropiately?")
}
}
Ok(res)
}
}
/// Extension trait giving request handlers access to the connection attached by
/// [`SQLxMiddleware`].
#[async_trait]
pub trait SQLxRequestExt {
/// Acquires the write lock on the request's connection and returns the guard.
///
/// The guard derefs to [`ConnectionWrapInner`], which itself derefs to `DB::Connection`.
/// NOTE(review): the `tide::Request` implementation in this file panics when no
/// connection for `DB` is attached (i.e. the middleware is not installed).
async fn sqlx_conn<'req, DB>(&'req self) -> RwLockWriteGuard<'req, ConnectionWrapInner<DB>>
where
DB: Database,
DB::Connection: Send + Sync + 'static;
}
#[async_trait]
impl<T: Send + Sync + 'static> SQLxRequestExt for Request<T> {
    /// Looks up the [`ConnectionWrap`] stored in the request extensions and awaits
    /// its write lock.
    ///
    /// # Panics
    /// Panics if no `ConnectionWrap<DB>` is present, i.e. the SQLx middleware is not installed.
    async fn sqlx_conn<'req, DB>(&'req self) -> RwLockWriteGuard<'req, ConnectionWrapInner<DB>>
    where
        DB: Database,
        DB::Connection: Send + Sync + 'static,
    {
        self.ext::<ConnectionWrap<DB>>()
            .expect("You must install SQLx middleware providing ConnectionWrap")
            .write()
            .await
    }
}