use std::future::poll_fn;
use std::ops::DerefMut;
use std::pin::Pin;
use std::task::{ready, Context, Poll};
use futures_core::Stream;
use sqlx::query::Query;
use sqlx::{Executor, Pool, Transaction};
use crate::futures_util::BoxStream;
#[macro_use]
#[path = "./cond_macros.rs"]
mod cond_macros;
#[cfg(feature = "mysql")]
use sqlx::{mysql, MySql};
#[cfg(feature = "postgres")]
use sqlx::{postgres, Postgres};
#[cfg(feature = "sqlite")]
use sqlx::{sqlite, Sqlite};
#[derive(Clone, Debug)]
pub enum AnyPool {
#[cfg(feature = "postgres")]
Postgres(Pool<Postgres>),
#[cfg(feature = "mysql")]
MySql(Pool<MySql>),
#[cfg(feature = "sqlite")]
Sqlite(Pool<Sqlite>),
}
impl AnyPool {
pub async fn begin(&self) -> sqlx::Result<AnyTransaction> {
match self {
#[cfg(feature = "postgres")]
Self::Postgres(pool) => pool.begin().await.map(AnyTransaction::Postgres),
#[cfg(feature = "mysql")]
Self::MySql(pool) => pool.begin().await.map(AnyTransaction::MySql),
#[cfg(feature = "sqlite")]
Self::Sqlite(pool) => pool.begin().await.map(AnyTransaction::Sqlite),
}
}
pub async fn close(&self) {
match self {
#[cfg(feature = "postgres")]
Self::Postgres(pool) => pool.close().await,
#[cfg(feature = "mysql")]
Self::MySql(pool) => pool.close().await,
#[cfg(feature = "sqlite")]
Self::Sqlite(pool) => pool.close().await,
}
}
pub fn is_closed(&self) -> bool {
match self {
#[cfg(feature = "postgres")]
Self::Postgres(pool) => pool.is_closed(),
#[cfg(feature = "mysql")]
Self::MySql(pool) => pool.is_closed(),
#[cfg(feature = "sqlite")]
Self::Sqlite(pool) => pool.is_closed(),
}
}
}
/// An open transaction on whichever backend the originating [`AnyPool`] uses.
///
/// Created by [`AnyPool::begin`]; finished by [`AnyTransaction::commit`] or
/// [`AnyTransaction::rollback`], both of which consume it.
pub enum AnyTransaction {
    /// Transaction on a PostgreSQL connection.
    #[cfg(feature = "postgres")]
    Postgres(Transaction<'static, Postgres>),
    /// Transaction on a MySQL connection.
    #[cfg(feature = "mysql")]
    MySql(Transaction<'static, MySql>),
    /// Transaction on a SQLite connection.
    #[cfg(feature = "sqlite")]
    Sqlite(Transaction<'static, Sqlite>),
}

impl AnyTransaction {
    /// Commits the transaction via the backend's `Transaction::commit`.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the backend, unchanged.
    pub async fn commit(self) -> sqlx::Result<()> {
        match self {
            #[cfg(feature = "postgres")]
            Self::Postgres(inner) => Transaction::commit(inner).await,
            #[cfg(feature = "mysql")]
            Self::MySql(inner) => Transaction::commit(inner).await,
            #[cfg(feature = "sqlite")]
            Self::Sqlite(inner) => Transaction::commit(inner).await,
        }
    }

    /// Rolls the transaction back via the backend's `Transaction::rollback`.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the backend, unchanged.
    pub async fn rollback(self) -> sqlx::Result<()> {
        match self {
            #[cfg(feature = "postgres")]
            Self::Postgres(inner) => Transaction::rollback(inner).await,
            #[cfg(feature = "mysql")]
            Self::MySql(inner) => Transaction::rollback(inner).await,
            #[cfg(feature = "sqlite")]
            Self::Sqlite(inner) => Transaction::rollback(inner).await,
        }
    }
}
/// A query tied to a specific backend and executor that has not been run yet.
///
/// The `*Pool` variants execute on a shared pool reference (built by the `AnyExecutor`
/// impl for `&AnyPool`); the `*Conn` variants execute on the connection borrowed from an
/// `AnyTransaction`. Add parameters with `AnyQuery::bind`, then run it with one of the
/// consuming `fetch_*` methods.
pub enum AnyQuery<'q> {
/// PostgreSQL query that will run on a pool.
#[cfg(feature = "postgres")]
PostgresPool(AnyQueryInner<'q, &'q Pool<Postgres>, postgres::PgArguments>),
/// MySQL query that will run on a pool.
#[cfg(feature = "mysql")]
MySqlPool(AnyQueryInner<'q, &'q Pool<MySql>, mysql::MySqlArguments>),
/// SQLite query that will run on a pool.
#[cfg(feature = "sqlite")]
SqlitePool(AnyQueryInner<'q, &'q Pool<Sqlite>, sqlite::SqliteArguments<'q>>),
/// PostgreSQL query that will run on a transaction's connection.
#[cfg(feature = "postgres")]
PostgresConn(AnyQueryInner<'q, &'q mut postgres::PgConnection, postgres::PgArguments>),
/// MySQL query that will run on a transaction's connection.
#[cfg(feature = "mysql")]
MySqlConn(AnyQueryInner<'q, &'q mut mysql::MySqlConnection, mysql::MySqlArguments>),
/// SQLite query that will run on a transaction's connection.
#[cfg(feature = "sqlite")]
SqliteConn(AnyQueryInner<'q, &'q mut sqlite::SqliteConnection, sqlite::SqliteArguments<'q>>),
}
/// Backend-specific payload of an `AnyQuery` variant: the executor plus the query
/// being built.
#[doc(hidden)]
pub struct AnyQueryInner<'q, E: Executor<'q>, A> {
/// Where the query runs: `&Pool<DB>` for the `*Pool` variants, a `&mut` connection for
/// the `*Conn` variants.
executor: E,
/// The query together with its bound arguments.
// Wrapped in `Option` so `AnyQuery::bind` can `take()` it, rebind it by value and store
// it back. It is constructed as `Some`, and the `fetch_*` methods unwrap it.
query: Option<Query<'q, E::Database, A>>,
}
impl<'q> AnyQuery<'q> {
pub fn bind<T>(&mut self, value: T)
where
T: 'q + Send + AnyEncode<'q> + AnyType,
{
match self {
#[cfg(feature = "postgres")]
Self::PostgresPool(AnyQueryInner { query, .. })
| Self::PostgresConn(AnyQueryInner { query, .. }) => {
*query = query.take().map(|query| query.bind(value))
}
#[cfg(feature = "mysql")]
Self::MySqlPool(AnyQueryInner { query, .. })
| Self::MySqlConn(AnyQueryInner { query, .. }) => {
*query = query.take().map(|query| query.bind(value))
}
#[cfg(feature = "sqlite")]
Self::SqlitePool(AnyQueryInner { query, .. })
| Self::SqliteConn(AnyQueryInner { query, .. }) => {
*query = query.take().map(|query| query.bind(value))
}
}
}
pub fn fetch_many(self) -> BoxStream<'q, sqlx::Result<sqlx::Either<AnyQueryResult, AnyRow>>> {
struct MappedStream<'stream, LI, LM, RI, RM> {
stream: BoxStream<'stream, sqlx::Result<sqlx::Either<LI, RI>>>,
map_left: LM,
map_right: RM,
}
impl<LI, LM, RI, RM> Unpin for MappedStream<'_, LI, LM, RI, RM> {}
impl<LI, LM, RI, RM> Stream for MappedStream<'_, LI, LM, RI, RM>
where
LM: Fn(LI) -> AnyQueryResult,
RM: Fn(RI) -> AnyRow,
{
type Item = sqlx::Result<sqlx::Either<AnyQueryResult, AnyRow>>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
let opt = ready!(self.stream.as_mut().poll_next(cx));
Poll::Ready(opt.map(|res| {
res.map(|either| either.map_left(&self.map_left).map_right(&self.map_right))
}))
}
}
macro_rules! match_impl {
($($variant:ident, $db:ident),+) => {
match self {$(
Self::$variant(AnyQueryInner { executor, query }) => {
Box::pin(MappedStream {
stream: executor.fetch_many(query.unwrap()),
map_left: AnyQueryResult::$db,
map_right: AnyRow::$db,
})
}
)+}
}
}
expand_fetch_impl!(match_impl)
}
pub async fn fetch_all(self) -> sqlx::Result<Vec<AnyRow>> {
macro_rules! match_impl {
($($variant:ident, $db:ident),+) => {
match self {$(
Self::$variant(AnyQueryInner { executor, query }) => {
let mut stream = executor.fetch_many(query.unwrap());
let mut vec = Vec::new();
while let Some(either) = poll_fn(|ctx| stream.as_mut().poll_next(ctx)).await.transpose()? {
if let Some(row) = either.right() {
vec.push(AnyRow::$db(row));
}
}
Ok(vec)
}
)+}
}
}
expand_fetch_impl!(match_impl)
}
pub async fn fetch_optional(self) -> sqlx::Result<Option<AnyRow>> {
macro_rules! match_impl {
($($variant:ident, $db:ident),+) => {
match self {$(
Self::$variant(AnyQueryInner { executor, query }) => executor
.fetch_optional(query.unwrap())
.await
.map(|option| option.map(AnyRow::$db)),
)+}
}
}
expand_fetch_impl!(match_impl)
}
pub async fn fetch_affected_rows(self) -> sqlx::Result<u64> {
macro_rules! match_impl {
($($variant:ident, $db:ident),+) => {
match self {$(
Self::$variant(AnyQueryInner { executor, query }) => {
let mut stream = executor.fetch_many(query.unwrap());
let mut count = 0;
while let Some(either) = poll_fn(|ctx| stream.as_mut().poll_next(ctx)).await.transpose()? {
match either {
sqlx::Either::Left(result) => count += result.rows_affected(),
sqlx::Either::Right(_row) => {},
}
}
Ok(count)
}
)+}
}
}
expand_fetch_impl!(match_impl)
}
}
/// One result row from whichever backend executed the query.
///
/// Each variant wraps that backend's native row type. Produced by
/// `AnyQuery::fetch_many` (as `Right` items), `AnyQuery::fetch_all` and
/// `AnyQuery::fetch_optional`.
pub enum AnyRow {
/// Row returned by PostgreSQL.
#[cfg(feature = "postgres")]
Postgres(postgres::PgRow),
/// Row returned by MySQL.
#[cfg(feature = "mysql")]
MySql(mysql::MySqlRow),
/// Row returned by SQLite.
#[cfg(feature = "sqlite")]
Sqlite(sqlite::SqliteRow),
}
/// The outcome of executing a statement on whichever backend ran the query.
///
/// Produced as the `Left` items of `AnyQuery::fetch_many`. Each variant wraps the
/// backend's native query-result type; all of them implement `Debug`, so this enum does
/// too, matching `AnyPool`.
#[derive(Debug)]
pub enum AnyQueryResult {
    /// Result reported by PostgreSQL.
    #[cfg(feature = "postgres")]
    Postgres(postgres::PgQueryResult),
    /// Result reported by MySQL.
    #[cfg(feature = "mysql")]
    MySql(mysql::MySqlQueryResult),
    /// Result reported by SQLite.
    #[cfg(feature = "sqlite")]
    Sqlite(sqlite::SqliteQueryResult),
}

impl AnyQueryResult {
    /// Number of rows affected by the statement, as reported by the wrapped backend
    /// result's own `rows_affected`.
    #[must_use]
    pub fn rows_affected(&self) -> u64 {
        match self {
            #[cfg(feature = "postgres")]
            Self::Postgres(result) => result.rows_affected(),
            #[cfg(feature = "mysql")]
            Self::MySql(result) => result.rows_affected(),
            #[cfg(feature = "sqlite")]
            Self::Sqlite(result) => result.rows_affected(),
        }
    }
}
/// A handle that can start an [`AnyQuery`].
///
/// This file implements it for `&AnyPool` (the query runs on the pool) and for
/// `&mut AnyTransaction` (the query runs on the transaction's connection).
pub trait AnyExecutor<'e> {
    /// Creates a query for the SQL text `query`, to be run on this executor.
    ///
    /// The `'e: 'q` bound keeps the returned query from outliving the executor borrow.
    /// Add parameters with [`AnyQuery::bind`] and run it with one of the `fetch_*`
    /// methods.
    fn query<'q>(self, query: &'q str) -> AnyQuery<'q>
    where
        'e: 'q;
}

/// Queries started from a pool produce the `*Pool` variants of [`AnyQuery`].
impl<'e> AnyExecutor<'e> for &'e AnyPool {
    fn query<'q>(self, query: &'q str) -> AnyQuery<'q>
    where
        'e: 'q,
    {
        match *self {
            #[cfg(feature = "postgres")]
            AnyPool::Postgres(ref pool) => AnyQuery::PostgresPool(AnyQueryInner {
                query: Some(sqlx::query(query)),
                executor: pool,
            }),
            #[cfg(feature = "mysql")]
            AnyPool::MySql(ref pool) => AnyQuery::MySqlPool(AnyQueryInner {
                query: Some(sqlx::query(query)),
                executor: pool,
            }),
            #[cfg(feature = "sqlite")]
            AnyPool::Sqlite(ref pool) => AnyQuery::SqlitePool(AnyQueryInner {
                query: Some(sqlx::query(query)),
                executor: pool,
            }),
        }
    }
}

/// Queries started from a transaction produce the `*Conn` variants of [`AnyQuery`].
impl<'e> AnyExecutor<'e> for &'e mut AnyTransaction {
    fn query<'q>(self, query: &'q str) -> AnyQuery<'q>
    where
        'e: 'q,
    {
        // The query runs on the transaction's underlying connection, reached through the
        // transaction's `DerefMut` impl.
        match self {
            #[cfg(feature = "postgres")]
            AnyTransaction::Postgres(tx) => AnyQuery::PostgresConn(AnyQueryInner {
                query: Some(sqlx::query(query)),
                executor: DerefMut::deref_mut(tx),
            }),
            #[cfg(feature = "mysql")]
            AnyTransaction::MySql(tx) => AnyQuery::MySqlConn(AnyQueryInner {
                query: Some(sqlx::query(query)),
                executor: DerefMut::deref_mut(tx),
            }),
            #[cfg(feature = "sqlite")]
            AnyTransaction::Sqlite(tx) => AnyQuery::SqliteConn(AnyQueryInner {
                query: Some(sqlx::query(query)),
                executor: DerefMut::deref_mut(tx),
            }),
        }
    }
}
/// Declares `pub trait $name` as an alias for a list of bounds and blanket-implements it
/// for every type that satisfies all of them.
///
/// Input form: optional `///` docs, then `trait Name<'lt>: Bound1, Bound2, ...,` where
/// the single lifetime parameter is optional and every bound (including the last) is
/// followed by a comma. Each listed bound is emitted as-is; no feature filtering happens
/// here.
macro_rules! uncond_trait_alias {
    (
        $(#[doc = $text:literal])*
        trait $name:ident $(<$lt:lifetime>)?: $($super_bound:path,)+
    ) => {
        $(#[doc = $text])*
        pub trait $name $(<$lt>)?
        where
            $(Self: $super_bound,)+
        {}

        // Blanket impl: any type meeting every bound implements the alias automatically.
        impl<$($lt,)? T> $name $(<$lt>)? for T
        where
            $(T: $super_bound,)+
        {}
    };
}
// `AnyEncode`: values that can be bound as query parameters on every backend; required,
// together with `AnyType`, by `AnyQuery::bind`.
// NOTE(review): `Postgres`/`MySql`/`Sqlite` are only imported when their Cargo feature is
// enabled, so `trait_alias!` (from cond_macros.rs) presumably drops the bounds of disabled
// backends before forwarding to `uncond_trait_alias!` — confirm against cond_macros.rs.
trait_alias!(
trait AnyEncode<'q>: sqlx::Encode<'q, Postgres>, sqlx::Encode<'q, MySql>, sqlx::Encode<'q, Sqlite>,
);
// `AnyDecode`: bundles `sqlx::Decode` for every backend. It is not referenced in the code
// visible here; presumably used for reading row values elsewhere — confirm.
trait_alias!(
trait AnyDecode<'r>: sqlx::Decode<'r, Postgres>, sqlx::Decode<'r, MySql>, sqlx::Decode<'r, Sqlite>,
);
// `AnyType`: bundles `sqlx::Type` for every backend; required by `AnyQuery::bind`.
trait_alias!(
trait AnyType: sqlx::Type<Postgres>, sqlx::Type<MySql>, sqlx::Type<Sqlite>,
);