#![allow(clippy::upper_case_acronyms)] #![cfg_attr(feature = "docs", feature(doc_cfg))]
use std::fmt::{self, Debug};
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use async_std::sync::{RwLock, RwLockWriteGuard};
use sqlx::pool::{Pool, PoolConnection};
use sqlx::{Database, Transaction};
use tide::utils::async_trait;
use tide::{http::Method, Middleware, Next, Request, Result};
#[cfg(all(feature = "tracing", debug_assertions))]
use tracing_crate::debug_span;
#[cfg(feature = "tracing")]
use tracing_crate::{info_span, Instrument};
// Abort compilation when the test configuration is built without the `postgres` feature.
// NOTE(review): the message names `--features=test` while the condition checks `postgres`;
// presumably a `test` Cargo feature turns `postgres` on -- confirm against Cargo.toml.
#[cfg(all(test, not(feature = "postgres")))]
compile_error!("The tests must be run with --features=test");
// The `postgres` module is only compiled when the `postgres` feature is enabled; when the
// `docs` feature is also on, rustdoc is told to label the module as requiring that feature.
#[cfg(feature = "postgres")]
#[cfg_attr(feature = "docs", doc(cfg(feature = "postgres")))]
pub mod postgres;
/// The per-request database handle created by the middleware: either an open transaction
/// or a plain pooled connection.
///
/// Both variants dereference (immutably and mutably) to the underlying `DB::Connection`.
#[doc(hidden)]
pub enum ConnectionWrapInner<DB>
where
DB: Database,
DB::Connection: Send + Sync + 'static,
{
/// A transaction obtained from the pool with `Pool::begin`; the middleware uses this
/// variant for every request method other than `GET` and `HEAD`.
Transacting(Transaction<'static, DB>),
/// A connection obtained from the pool with `Pool::acquire`; the middleware uses this
/// variant for `GET` and `HEAD` requests.
Plain(PoolConnection<DB>),
}
impl<DB> Debug for ConnectionWrapInner<DB>
where
DB: Database,
DB::Connection: Send + Sync + 'static,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Transacting(_) => f.debug_struct("ConnectionWrapInner::Transacting").finish(),
Self::Plain(_) => f.debug_struct("ConnectionWrapInner::Plain").finish(),
}
}
}
impl<DB> Deref for ConnectionWrapInner<DB>
where
    DB: Database,
    DB::Connection: Send + Sync + 'static,
{
    type Target = DB::Connection;

    /// Borrows the underlying database connection, regardless of whether it is wrapped in
    /// a transaction or is a plain pooled connection.
    fn deref(&self) -> &Self::Target {
        match self {
            // Both wrappers themselves dereference to `DB::Connection`.
            Self::Transacting(transaction) => &**transaction,
            Self::Plain(pooled) => &**pooled,
        }
    }
}
impl<DB> DerefMut for ConnectionWrapInner<DB>
where
    DB: Database,
    DB::Connection: Send + Sync + 'static,
{
    /// Mutably borrows the underlying database connection, regardless of whether it is
    /// wrapped in a transaction or is a plain pooled connection.
    fn deref_mut(&mut self) -> &mut Self::Target {
        match self {
            // Both wrappers themselves mutably dereference to `DB::Connection`.
            Self::Transacting(transaction) => &mut **transaction,
            Self::Plain(pooled) => &mut **pooled,
        }
    }
}
/// Reference-counted, `RwLock`-guarded handle to a `ConnectionWrapInner`.
///
/// This is the exact type the middleware stores in the request extensions and the type
/// `sqlx_conn` looks up there.
#[doc(hidden)]
pub type ConnectionWrap<DB> = Arc<RwLock<ConnectionWrapInner<DB>>>;
/// Tide middleware that owns an SQLx connection pool and attaches a connection (or a
/// transaction) from that pool to each request it handles.
///
/// It can be built from a connection URL with `new`, or from an existing pool via `From`.
#[derive(Debug, Clone)]
pub struct SQLxMiddleware<DB>
where
DB: Database,
DB::Connection: Send + Sync + 'static,
{
// Pool from which the per-request connections and transactions are obtained.
pool: Pool<DB>,
}
impl<DB> SQLxMiddleware<DB>
where
    DB: Database,
    DB::Connection: Send + Sync + 'static,
{
    /// Connects a new pool using `pgurl` and wraps it in the middleware.
    ///
    /// # Errors
    ///
    /// Returns the `sqlx::Error` produced by `Pool::connect` if the pool cannot be created.
    pub async fn new(pgurl: &'_ str) -> std::result::Result<Self, sqlx::Error> {
        Pool::<DB>::connect(pgurl)
            .await
            .map(|pool| Self { pool })
    }
}
impl<DB> AsRef<Pool<DB>> for SQLxMiddleware<DB>
where
    DB: Database,
    DB::Connection: Send + Sync + 'static,
{
    /// Borrows the connection pool held by the middleware.
    fn as_ref(&self) -> &Pool<DB> {
        let Self { pool } = self;
        pool
    }
}
impl<DB> From<Pool<DB>> for SQLxMiddleware<DB>
where
DB: Database,
DB::Connection: Send + Sync + 'static,
{
fn from(pool: Pool<DB>) -> Self {
Self { pool }
}
}
#[async_trait]
impl<State, DB> Middleware<State> for SQLxMiddleware<DB>
where
    State: Clone + Send + Sync + 'static,
    DB: Database,
    DB::Connection: Send + Sync + 'static,
{
    /// Attaches a database handle to the request, runs the rest of the chain, and commits
    /// the transaction (if one was opened) when the response carries no error.
    ///
    /// * If the request already has a `ConnectionWrap<DB>` extension, the request is passed
    ///   straight through and nothing is acquired or committed here.
    /// * `GET` and `HEAD` requests receive a plain pooled connection; every other method
    ///   receives a transaction started with `Pool::begin`.
    /// * When the response carries an error, the handle is dropped without an explicit
    ///   commit (sqlx documents that an uncommitted `Transaction` rolls back on drop).
    ///
    /// # Errors
    ///
    /// Propagates failures from acquiring the connection, beginning the transaction, or
    /// committing it.
    ///
    /// # Panics
    ///
    /// Panics if, after a response without an error, another clone of the shared handle is
    /// still alive so that it cannot be unwrapped for the commit.
    async fn handle(&self, mut req: Request<State>, next: Next<'_, State>) -> Result {
        // A handle is already present on the request: do not attach a second one.
        if req.ext::<ConnectionWrap<DB>>().is_some() {
            return Ok(next.run(req).await);
        }

        let needs_transaction = !matches!(req.method(), Method::Get | Method::Head);
        let handle = if needs_transaction {
            let begin_fut = self.pool.begin();
            #[cfg(feature = "tracing")]
            let begin_fut =
                begin_fut.instrument(info_span!("Acquiring database transaction", "COMMIT"));
            ConnectionWrapInner::Transacting(begin_fut.await?)
        } else {
            let acquire_fut = self.pool.acquire();
            #[cfg(feature = "tracing")]
            let acquire_fut = acquire_fut.instrument(info_span!("Acquiring database connection"));
            ConnectionWrapInner::Plain(acquire_fut.await?)
        };

        // Keep one reference here so the handle can be reclaimed after the handler ran.
        let shared = Arc::new(RwLock::new(handle));
        req.set_ext(Arc::clone(&shared));

        let res = next.run(req).await;
        if res.error().is_some() {
            // No commit on error; the handle is simply dropped.
            return Ok(res);
        }

        match Arc::try_unwrap(shared) {
            Ok(lock) => {
                if let ConnectionWrapInner::Transacting(transaction) = lock.into_inner() {
                    let commit_fut = transaction.commit();
                    #[cfg(feature = "tracing")]
                    let commit_fut = commit_fut
                        .instrument(info_span!("Commiting database transaction", "COMMIT"));
                    commit_fut.await?;
                }
            }
            // Some other owner still holds the handle, so the transaction cannot be taken
            // back for the commit.
            Err(_) => {
                panic!("We have err'd egregiously! Could not unwrap refcounted SQLx connection for COMMIT; handler may be storing connection or request inappropiately?")
            }
        }

        Ok(res)
    }
}
/// Extension trait that lets a request hand out the SQLx connection attached to it.
///
/// This file implements it for `tide::Request`.
#[async_trait]
pub trait SQLxRequestExt {
/// Waits for exclusive access to the request's connection wrapper for database `DB` and
/// returns the write guard, which borrows from the request for `'req`.
///
/// The implementation for `tide::Request` in this file panics when the request carries
/// no `ConnectionWrap<DB>` extension.
async fn sqlx_conn<'req, DB>(&'req self) -> RwLockWriteGuard<'req, ConnectionWrapInner<DB>>
where
DB: Database,
DB::Connection: Send + Sync + 'static;
}
#[async_trait]
impl<T: Send + Sync + 'static> SQLxRequestExt for Request<T> {
    /// Looks up the `ConnectionWrap<DB>` stored in the request extensions and waits for its
    /// write lock, returning the guard.
    ///
    /// # Panics
    ///
    /// Panics if no `ConnectionWrap<DB>` extension is present on the request.
    async fn sqlx_conn<'req, DB>(&'req self) -> RwLockWriteGuard<'req, ConnectionWrapInner<DB>>
    where
        DB: Database,
        DB::Connection: Send + Sync + 'static,
    {
        let guard_fut = self
            .ext::<ConnectionWrap<DB>>()
            .expect("You must install SQLx middleware providing ConnectionWrap")
            .write();
        // The lock wait is only traced in debug builds with the `tracing` feature.
        #[cfg(all(feature = "tracing", debug_assertions))]
        let guard_fut =
            guard_fut.instrument(debug_span!("Database connection RwLockWriteGuard acquire"));
        guard_fut.await
    }
}