use crate::{
mod_def,
traits::{Model, Repository},
types::Database,
};
use futures::future::try_join_all;
use sqlx::{Error, Transaction};
use std::future::Future;
use std::sync::Arc;
// Submodules that, judging by their names, hold transaction-scoped variants of
// the insert/update/delete/save operations. `mod_def!` is a crate-local macro;
// NOTE(review): `!export` presumably re-exports each module's public items from
// this module — confirm against the macro definition.
mod_def! {
!export
pub(crate) mod insert_tx;
pub(crate) mod update_tx;
pub(crate) mod delete_tx;
pub(crate) mod save_tx;
}
pub trait TransactionRepository<M>: Repository<M>
where
M: Model,
{
fn with_transaction<'a, 'b, F, Fut, R, E>(
&'a self,
callback: F,
) -> impl Future<Output = Result<R, E>> + Send + 'a
where
F: FnOnce(Transaction<'b, Database>) -> Fut + Send + 'a,
Fut: Future<Output = (Result<R, E>, Transaction<'b, Database>)> + Send,
R: Send + 'a,
E: From<Error> + Send,
{
async move {
let transaction = self.pool().begin().await.map_err(E::from)?;
let (ret, tx) = callback(transaction).await;
match ret {
Ok(val) => {
tx.commit().await.map_err(E::from)?;
Ok(val)
}
Err(err) => {
tx.rollback().await.map_err(E::from)?;
Err(err)
}
}
}
}
fn transaction_sequential<'a, 'b, I, F, Fut, R, E>(
&'a self,
actions: I,
) -> impl Future<Output = Result<Vec<R>, E>> + Send + 'a
where
I: IntoIterator<Item = F> + Send + 'a,
I::IntoIter: Send + 'a,
F: FnOnce(Transaction<'b, Database>) -> Fut + Send + 'a,
Fut: Future<Output = (Result<R, E>, Transaction<'b, Database>)> + Send,
R: Send + 'a,
E: From<Error> + Send + 'a,
{
async move {
let mut tx = self.pool().begin().await.map_err(E::from)?;
let mut results = Vec::new();
for action in actions {
let (result, new_tx) = action(tx).await;
tx = new_tx;
match result {
Ok(value) => results.push(value),
Err(e) => {
let _ = tx.rollback().await;
return Err(e);
}
}
}
tx.commit().await.map_err(E::from)?;
Ok(results)
}
}
fn transaction_concurrent<'a, 'b, I, F, Fut, R, E>(
&'a self,
actions: I,
) -> impl Future<Output = Result<Vec<R>, E>> + Send + 'a
where
I: IntoIterator<Item = F> + Send + 'a,
I::IntoIter: Send + 'a,
F: FnOnce(Arc<parking_lot::Mutex<Transaction<'b, Database>>>) -> Fut + Send + 'a,
Fut: Future<Output = Result<R, E>> + Send + 'a,
R: Send + 'a,
E: From<Error> + Send + 'a,
{
async move {
let tx = self.pool().begin().await.map_err(E::from)?;
let tx = Arc::new(parking_lot::Mutex::new(tx));
let futures: Vec<_> = actions
.into_iter()
.map(|action_fn| action_fn(tx.clone()))
.collect();
let results = try_join_all(futures).await;
match results {
Ok(values) => {
let tx = match Arc::into_inner(tx) {
Some(mutex) => mutex.into_inner(),
None => return Err(E::from(Error::PoolClosed)),
};
tx.commit().await.map_err(E::from)?;
Ok(values)
}
Err(e) => {
let tx = match Arc::into_inner(tx) {
Some(mutex) => mutex.into_inner(),
None => return Err(E::from(Error::PoolClosed)),
};
tx.rollback().await.map_err(E::from)?;
Err(e)
}
}
}
}
fn try_transaction<'a, 'b, I, F, Fut, R, E>(
&'a self,
actions: I,
) -> impl Future<Output = Result<Vec<R>, Vec<E>>> + Send + 'a
where
I: IntoIterator<Item = F> + Send + 'a,
I::IntoIter: Send + 'a,
F: FnOnce(Transaction<'b, Database>) -> Fut + Send + 'a,
Fut: Future<Output = (Result<R, E>, Transaction<'b, Database>)> + Send,
R: Send + 'a,
E: From<Error> + Send + 'a,
{
async move {
let mut tx = self.pool().begin().await.map_err(|e| vec![E::from(e)])?;
let mut results = Vec::new();
let mut errors = Vec::new();
for action in actions {
let (result, new_tx) = action(tx).await;
tx = new_tx;
match result {
Ok(result) => results.push(result),
Err(e) => errors.push(e),
}
}
if errors.is_empty() {
tx.commit().await.map_err(|e| vec![E::from(e)])?;
Ok(results)
} else {
let _ = tx.rollback().await;
Err(errors)
}
}
}
}
/// Blanket implementation: any type implementing [`Repository`] for a model
/// gets the transaction helpers for free, since every method of
/// [`TransactionRepository`] has a default body.
impl<T: Repository<M>, M: Model> TransactionRepository<M> for T {}