use crate::error::{Error, Result};
use async_trait::async_trait;
use sqlx::{MySql, Transaction};
use std::future::Future;
/// Async MySQL query interface, generic over the row type each fetch decodes into.
///
/// Every method takes the SQL as a plain string and offers no way to pass bind
/// parameters. Untrusted input must therefore never be spliced into `query`.
#[async_trait]
pub trait Executor {
    /// Runs `query` and returns the driver's [`sqlx::mysql::MySqlQueryResult`].
    async fn execute(&mut self, query: &str) -> Result<sqlx::mysql::MySqlQueryResult>;

    /// Runs `query` and returns all resulting rows decoded as `Row`.
    async fn fetch_all<'q, Row>(&mut self, query: &'q str) -> Result<Vec<Row>>
    where
        Row: Send + Unpin + for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow>;

    /// Runs `query` and returns a single row decoded as `Row`.
    async fn fetch_one<'q, Row>(&mut self, query: &'q str) -> Result<Row>
    where
        Row: Send + Unpin + for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow>;

    /// Runs `query` and returns at most one row decoded as `Row`, or `None`.
    async fn fetch_optional<'q, Row>(&mut self, query: &'q str) -> Result<Option<Row>>
    where
        Row: Send + Unpin + for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow>;
}
/// Backend-facing query trait, implemented for both a MySQL pool and a MySQL transaction.
///
/// No method references `'c`. The parameter exists so an implementation can be tied to a
/// borrowed lifetime; the `Transaction<'c, MySql>` impl uses it that way. Implementors
/// must be `Send + Sync`.
#[async_trait]
pub trait DbExecutor<'c>: Send + Sync {
    /// Runs `query` verbatim (no bind parameters) and returns the driver's result summary.
    async fn execute_query(&mut self, query: &str) -> Result<sqlx::mysql::MySqlQueryResult>;

    /// Runs `query` verbatim and returns all rows decoded as `Row`.
    async fn fetch_all_query<'q, Row>(&mut self, query: &'q str) -> Result<Vec<Row>>
    where
        Row: Send + Unpin + for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow>;

    /// Runs `query` verbatim and returns a single row decoded as `Row`.
    async fn fetch_one_query<'q, Row>(&mut self, query: &'q str) -> Result<Row>
    where
        Row: Send + Unpin + for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow>;

    /// Runs `query` verbatim and returns at most one row decoded as `Row`, or `None`.
    async fn fetch_optional_query<'q, Row>(&mut self, query: &'q str) -> Result<Option<Row>>
    where
        Row: Send + Unpin + for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow>;
}
#[async_trait]
impl<'c> DbExecutor<'c> for sqlx::Pool<MySql> {
async fn execute_query(&mut self, query: &str) -> Result<sqlx::mysql::MySqlQueryResult> {
sqlx::query(query)
.execute(&*self)
.await
.map_err(|e| Error::Query(e.to_string()))
}
async fn fetch_all_query<'q, T>(&mut self, query: &'q str) -> Result<Vec<T>>
where
T: Send + Unpin + for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow>,
{
sqlx::query_as::<_, T>(query)
.fetch_all(&*self)
.await
.map_err(|e| Error::Query(e.to_string()))
}
async fn fetch_one_query<'q, T>(&mut self, query: &'q str) -> Result<T>
where
T: Send + Unpin + for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow>,
{
sqlx::query_as::<_, T>(query)
.fetch_one(&*self)
.await
.map_err(|e| Error::Query(e.to_string()))
}
async fn fetch_optional_query<'q, T>(&mut self, query: &'q str) -> Result<Option<T>>
where
T: Send + Unpin + for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow>,
{
sqlx::query_as::<_, T>(query)
.fetch_optional(&*self)
.await
.map_err(|e| Error::Query(e.to_string()))
}
}
#[async_trait]
impl<'c> DbExecutor<'c> for Transaction<'c, MySql> {
async fn execute_query(&mut self, query: &str) -> Result<sqlx::mysql::MySqlQueryResult> {
sqlx::query(query)
.execute(&mut **self)
.await
.map_err(|e| Error::Query(e.to_string()))
}
async fn fetch_all_query<'q, T>(&mut self, query: &'q str) -> Result<Vec<T>>
where
T: Send + Unpin + for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow>,
{
sqlx::query_as::<_, T>(query)
.fetch_all(&mut **self)
.await
.map_err(|e| Error::Query(e.to_string()))
}
async fn fetch_one_query<'q, T>(&mut self, query: &'q str) -> Result<T>
where
T: Send + Unpin + for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow>,
{
sqlx::query_as::<_, T>(query)
.fetch_one(&mut **self)
.await
.map_err(|e| Error::Query(e.to_string()))
}
async fn fetch_optional_query<'q, T>(&mut self, query: &'q str) -> Result<Option<T>>
where
T: Send + Unpin + for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow>,
{
sqlx::query_as::<_, T>(query)
.fetch_optional(&mut **self)
.await
.map_err(|e| Error::Query(e.to_string()))
}
}
/// Blanket adapter: any [`DbExecutor`] is usable as an [`Executor`].
///
/// Each method forwards unchanged to its `*_query` counterpart. The calls are spelled
/// with fully-qualified syntax so the dispatch target is explicit.
#[async_trait]
impl<'c, T> Executor for T
where
    T: DbExecutor<'c> + Send + Sync,
{
    async fn execute(&mut self, query: &str) -> Result<sqlx::mysql::MySqlQueryResult> {
        <T as DbExecutor<'c>>::execute_query(self, query).await
    }

    async fn fetch_all<'q, Row>(&mut self, query: &'q str) -> Result<Vec<Row>>
    where
        Row: Send + Unpin + for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow>,
    {
        <T as DbExecutor<'c>>::fetch_all_query(self, query).await
    }

    async fn fetch_one<'q, Row>(&mut self, query: &'q str) -> Result<Row>
    where
        Row: Send + Unpin + for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow>,
    {
        <T as DbExecutor<'c>>::fetch_one_query(self, query).await
    }

    async fn fetch_optional<'q, Row>(&mut self, query: &'q str) -> Result<Option<Row>>
    where
        Row: Send + Unpin + for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow>,
    {
        <T as DbExecutor<'c>>::fetch_optional_query(self, query).await
    }
}
/// Holds one MySQL transaction and hands it to [`TransactionManager::execute`] exactly once.
pub struct TransactionManager<'c> {
    // `Some` until `execute` takes the transaction out; `None` afterwards, which makes any
    // further `execute` call fail with `Error::Transaction`.
    tx: Option<Transaction<'c, MySql>>,
}
impl<'c> TransactionManager<'c> {
    /// Wraps `tx` so that [`TransactionManager::execute`] can run it and then commit or
    /// roll it back.
    pub fn new(tx: Transaction<'c, MySql>) -> Self {
        Self { tx: Some(tx) }
    }

    /// Runs `f` against the wrapped transaction. Commits if `f` returns `Ok`; rolls back if
    /// it returns `Err`.
    ///
    /// The transaction is taken out of the manager on the first call, so each
    /// `TransactionManager` can execute at most once.
    ///
    /// # Errors
    ///
    /// * [`Error::Transaction`] in either of two cases:
    ///   - `execute` was already called on this manager (`"Transaction already used"`);
    ///   - the `COMMIT` issued after a successful `f` failed.
    /// * `f`'s own error, converted with `Into<Error>`, when `f` fails.
    ///   - It is returned even if the rollback that follows also fails.
    ///   - Callers therefore always see the root cause, including any specific `Error`
    ///     variant they match on, rather than a generic rollback failure.
    pub async fn execute<F, T, E>(&mut self, f: F) -> Result<T>
    where
        F: for<'a> FnOnce(
            &'a mut Transaction<'c, MySql>,
        ) -> std::pin::Pin<
            Box<dyn Future<Output = std::result::Result<T, E>> + Send + 'a>,
        >,
        E: Into<Error>,
        T: Send,
    {
        // Taking the transaction leaves `None` behind, which is what makes a second call fail.
        let mut tx = self
            .tx
            .take()
            .ok_or_else(|| Error::Transaction("Transaction already used".to_string()))?;

        match f(&mut tx).await {
            Ok(value) => {
                tx.commit()
                    .await
                    .map_err(|e| Error::Transaction(e.to_string()))?;
                Ok(value)
            }
            Err(err) => {
                // The rollback is best-effort, and its failure is deliberately not surfaced.
                // Reporting it in place of `err` would hide why the transaction failed.
                // `commit` is never issued on this path, so `f`'s work is not committed
                // either way.
                let _ = tx.rollback().await;
                Err(err.into())
            }
        }
    }
}