use std::{future::Future, pin::Pin};
use sea_orm::{
AccessMode, ColumnTrait, ConnectionTrait, DatabaseConnection, DatabaseTransaction, EntityTrait,
IsolationLevel, QueryFilter, TransactionTrait, sea_query::Expr,
};
use uuid::Uuid;
use crate::secure::tx_error::{InfraError, TxError};
use modkit_security::AccessScope;
use crate::secure::tx_config::TxConfig;
use crate::secure::{ScopableEntity, ScopeError, Scoped, SecureEntityExt, SecureSelect};
use crate::secure::db_ops::{SecureDeleteExt, SecureDeleteMany, SecureUpdateExt, SecureUpdateMany};
/// Handle passed to the closures run by `SecureConn`'s transaction methods.
///
/// It only borrows the sea-orm `DatabaseTransaction`; the `SecureConn` method
/// that created it keeps ownership and commits or rolls the transaction back
/// after the closure's future has completed.
pub struct SecureTx<'a> {
    /// The borrowed transaction this handle wraps.
    pub(crate) tx: &'a DatabaseTransaction,
}

impl<'txn> SecureTx<'txn> {
    /// Wraps a borrowed transaction in a `SecureTx`.
    #[must_use]
    pub(crate) fn new(inner: &'txn DatabaseTransaction) -> Self {
        Self { tx: inner }
    }
}
/// Wrapper around a sea-orm `DatabaseConnection` whose query and mutation
/// methods take an `AccessScope`.
///
/// Those methods either build scoped statements or delegate to the crate's
/// secure insert/update helpers. Transactions are run through closures that
/// receive a `SecureTx`.
pub struct SecureConn {
    /// Underlying sea-orm connection (also exposed in-crate through `conn_internal()`).
    pub(crate) conn: DatabaseConnection,
}
impl SecureConn {
    /// Returns the wrapped sea-orm connection.
    ///
    /// Crate-internal escape hatch: this type applies no `AccessScope` to
    /// statements executed directly on the returned connection.
    #[must_use]
    pub(crate) fn conn_internal(&self) -> &DatabaseConnection {
        &self.conn
    }

    /// Returns the backend behind this connection as a lowercase name:
    /// `"postgres"`, `"mysql"`, or `"sqlite"`.
    #[must_use]
    pub fn db_engine(&self) -> &'static str {
        use sea_orm::DbBackend;
        match self.conn.get_database_backend() {
            DbBackend::Postgres => "postgres",
            DbBackend::MySql => "mysql",
            DbBackend::Sqlite => "sqlite",
        }
    }

    /// Builds a `SELECT` over `E` restricted to `scope`.
    ///
    /// Nothing is executed here; run the returned builder to query.
    // NOTE: `self` is unused; the method form is presumably kept for symmetry
    // with the methods on this type that do execute statements.
    #[allow(clippy::unused_self)]
    pub fn find<E>(&self, scope: &AccessScope) -> SecureSelect<E, Scoped>
    where
        E: ScopableEntity + EntityTrait,
        E::Column: ColumnTrait + Copy,
    {
        E::find().secure().scope_with(scope)
    }

    /// Builds a scoped `SELECT` over `E`, further narrowed to the row with `id`.
    ///
    /// # Errors
    ///
    /// Propagates the `ScopeError` returned by `SecureSelect::and_id`.
    pub fn find_by_id<E>(
        &self,
        scope: &AccessScope,
        id: Uuid,
    ) -> Result<SecureSelect<E, Scoped>, ScopeError>
    where
        E: ScopableEntity + EntityTrait,
        E::Column: ColumnTrait + Copy,
    {
        self.find::<E>(scope).and_id(id)
    }

    /// Builds an `UPDATE` over `E` restricted to `scope`.
    ///
    /// Nothing is executed here; run the returned builder to apply it.
    // NOTE: `self` is unused; see the note on `find`.
    #[allow(clippy::unused_self)]
    #[must_use]
    pub fn update_many<E>(&self, scope: &AccessScope) -> SecureUpdateMany<E, Scoped>
    where
        E: ScopableEntity + EntityTrait,
        E::Column: ColumnTrait + Copy,
    {
        E::update_many().secure().scope_with(scope)
    }

    /// Builds a `DELETE` over `E` restricted to `scope`.
    ///
    /// Nothing is executed here; run the returned builder to apply it.
    // NOTE: `self` is unused; see the note on `find`.
    #[allow(clippy::unused_self)]
    #[must_use]
    pub fn delete_many<E>(&self, scope: &AccessScope) -> SecureDeleteMany<E, Scoped>
    where
        E: ScopableEntity + EntityTrait,
        E::Column: ColumnTrait + Copy,
    {
        E::delete_many().secure().scope_with(scope)
    }

    /// Builds a single-row `INSERT` of `am` into `E`, scoped through
    /// `scope_with_model(scope, &am)`.
    ///
    /// Presumably `scope_with_model` checks the model's scope-relevant fields
    /// against `scope`; confirm in `SecureInsertExt`.
    ///
    /// # Errors
    ///
    /// Propagates the `ScopeError` returned by `scope_with_model`.
    // `am` is only cloned and borrowed below; the by-value parameter is part of
    // the public signature, so the lint is silenced rather than changing it.
    #[allow(clippy::needless_pass_by_value)]
    pub fn insert_one<E>(
        &self,
        scope: &AccessScope,
        am: E::ActiveModel,
    ) -> Result<crate::secure::SecureInsertOne<E::ActiveModel, Scoped>, ScopeError>
    where
        E: ScopableEntity + EntityTrait,
        E::Column: ColumnTrait + Copy,
        E::ActiveModel: sea_orm::ActiveModelTrait<Entity = E> + Send,
    {
        use crate::secure::SecureInsertExt;
        // The insert builder consumes one copy; the original is still needed
        // by reference for `scope_with_model`.
        E::insert(am.clone()).secure().scope_with_model(scope, &am)
    }

    /// Inserts `am` into `E` under `scope` and returns the resulting model.
    ///
    /// Delegates to `crate::secure::secure_insert`.
    ///
    /// # Errors
    ///
    /// Returns whatever `ScopeError` `secure_insert` produces.
    pub async fn insert<E>(
        &self,
        scope: &AccessScope,
        am: E::ActiveModel,
    ) -> Result<E::Model, ScopeError>
    where
        E: ScopableEntity + EntityTrait,
        E::Column: ColumnTrait + Copy,
        E::ActiveModel: sea_orm::ActiveModelTrait<Entity = E> + Send,
        E::Model: sea_orm::IntoActiveModel<E::ActiveModel>,
    {
        crate::secure::secure_insert::<E>(am, scope, self).await
    }

    /// Updates the row of `E` identified by `id` with `am`, under `scope`, and
    /// returns the resulting model.
    ///
    /// Delegates to `crate::secure::secure_update_with_scope`.
    ///
    /// # Errors
    ///
    /// Returns whatever `ScopeError` `secure_update_with_scope` produces.
    pub async fn update_with_ctx<E>(
        &self,
        scope: &AccessScope,
        id: Uuid,
        am: E::ActiveModel,
    ) -> Result<E::Model, ScopeError>
    where
        E: ScopableEntity + EntityTrait,
        E::Column: ColumnTrait + Copy,
        E::ActiveModel: sea_orm::ActiveModelTrait<Entity = E> + Send,
        E::Model: sea_orm::IntoActiveModel<E::ActiveModel> + sea_orm::ModelTrait<Entity = E>,
    {
        crate::secure::secure_update_with_scope::<E>(am, scope, id, self).await
    }

    /// Deletes the rows of `E` whose resource column equals `id` and that fall
    /// within `scope`.
    ///
    /// Returns `Ok(true)` if at least one row was deleted, `Ok(false)` otherwise.
    ///
    /// # Errors
    ///
    /// - `ScopeError::Invalid` if `E::resource_col()` returns `None`.
    /// - Any error from executing the scoped delete, converted into `ScopeError`.
    pub async fn delete_by_id<E>(&self, scope: &AccessScope, id: Uuid) -> Result<bool, ScopeError>
    where
        E: ScopableEntity + EntityTrait,
        E::Column: ColumnTrait + Copy,
    {
        let resource_col = E::resource_col().ok_or_else(|| {
            ScopeError::Invalid("Entity must have a resource_col to use delete_by_id()")
        })?;
        // The id filter is added first; `scope_with` then layers the access
        // scope on top of it.
        let result = E::delete_many()
            .filter(sea_orm::Condition::all().add(Expr::col(resource_col).eq(id)))
            .secure()
            .scope_with(scope)
            .exec(self)
            .await?;
        Ok(result.rows_affected > 0)
    }

    /// Runs `f` inside a new database transaction and hands `self` back
    /// together with the result.
    ///
    /// The transaction is committed when `f` returns `Ok(())` and rolled back
    /// (best effort) when it returns `Err`.
    ///
    /// # Errors
    ///
    /// The returned result is `Err` if the transaction cannot be started, if
    /// `f` fails (its error is returned unchanged), or if the commit fails.
    pub async fn transaction<F>(self, f: F) -> (Self, anyhow::Result<()>)
    where
        F: for<'a> FnOnce(
                &'a SecureTx<'a>,
            ) -> Pin<Box<dyn Future<Output = anyhow::Result<()>> + Send + 'a>>
            + Send,
    {
        // Identical semantics to `transaction_with` with `T = ()`.
        self.transaction_with::<(), F>(f).await
    }

    /// Runs `f` inside a new database transaction and returns its value,
    /// handing `self` back alongside the result.
    ///
    /// The transaction is committed when `f` returns `Ok` and rolled back
    /// (best effort) when it returns `Err`.
    ///
    /// # Errors
    ///
    /// The returned result is `Err` if the transaction cannot be started, if
    /// `f` fails (its error is returned unchanged), or if the commit fails.
    pub async fn transaction_with<T, F>(self, f: F) -> (Self, anyhow::Result<T>)
    where
        T: Send + 'static,
        F: for<'a> FnOnce(
                &'a SecureTx<'a>,
            ) -> Pin<Box<dyn Future<Output = anyhow::Result<T>> + Send + 'a>>
            + Send,
    {
        let txn = match self.conn_internal().begin().await {
            Ok(t) => t,
            Err(e) => return (self, Err(e.into())),
        };
        let result = Self::run_in_txn::<T, anyhow::Error, F>(txn, f)
            .await
            .unwrap_or_else(|commit_err| Err(commit_err.into()));
        (self, result)
    }

    /// Like `transaction_with`, but starts the transaction with the isolation
    /// level and access mode from `cfg`.
    ///
    /// `cfg.isolation` and `cfg.access_mode` are converted into sea-orm's
    /// `IsolationLevel` / `AccessMode`. A `None` is passed through unchanged,
    /// so nothing is requested explicitly for that option.
    ///
    /// # Errors
    ///
    /// The returned result is `Err` if the transaction cannot be started, if
    /// `f` fails (its error is returned unchanged), or if the commit fails.
    pub async fn transaction_with_config<T, F>(
        self,
        cfg: TxConfig,
        f: F,
    ) -> (Self, anyhow::Result<T>)
    where
        T: Send + 'static,
        F: for<'a> FnOnce(
                &'a SecureTx<'a>,
            ) -> Pin<Box<dyn Future<Output = anyhow::Result<T>> + Send + 'a>>
            + Send,
    {
        let isolation: Option<IsolationLevel> = cfg.isolation.map(Into::into);
        let access_mode: Option<AccessMode> = cfg.access_mode.map(Into::into);
        let txn = match self
            .conn_internal()
            .begin_with_config(isolation, access_mode)
            .await
        {
            Ok(t) => t,
            Err(e) => return (self, Err(e.into())),
        };
        let result = Self::run_in_txn::<T, anyhow::Error, F>(txn, f)
            .await
            .unwrap_or_else(|commit_err| Err(commit_err.into()));
        (self, result)
    }

    /// Runs `f` inside a new database transaction, keeping domain failures
    /// separate from infrastructure failures.
    ///
    /// The transaction is committed when `f` returns `Ok` and rolled back
    /// (best effort) when it returns `Err`.
    ///
    /// # Errors
    ///
    /// - `TxError::Domain(e)` when `f` returns `Err(e)`.
    /// - `TxError::Infra(_)` when beginning or committing the transaction
    ///   fails. It carries the database error's rendered message.
    pub async fn in_transaction<T, E, F>(self, f: F) -> (Self, Result<T, TxError<E>>)
    where
        T: Send + 'static,
        E: std::fmt::Debug + std::fmt::Display + Send + 'static,
        F: for<'a> FnOnce(
                &'a SecureTx<'a>,
            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'a>>
            + Send,
    {
        let txn = match self.conn_internal().begin().await {
            Ok(t) => t,
            Err(e) => return (self, Err(TxError::Infra(InfraError::new(e.to_string())))),
        };
        let result = match Self::run_in_txn::<T, E, F>(txn, f).await {
            Ok(Ok(value)) => Ok(value),
            Ok(Err(domain_err)) => Err(TxError::Domain(domain_err)),
            Err(commit_err) => Err(TxError::Infra(InfraError::new(commit_err.to_string()))),
        };
        (self, result)
    }

    /// Like `in_transaction`, but folds infrastructure failures into the
    /// caller's domain error type with `map_infra` (via `TxError::into_domain`).
    ///
    /// # Errors
    ///
    /// Returns `f`'s error unchanged, or `map_infra` applied to the
    /// infrastructure error when beginning or committing the transaction fails.
    pub async fn in_transaction_mapped<T, E, F, M>(self, map_infra: M, f: F) -> (Self, Result<T, E>)
    where
        T: Send + 'static,
        E: std::fmt::Debug + std::fmt::Display + Send + 'static,
        M: FnOnce(InfraError) -> E + Send,
        F: for<'a> FnOnce(
                &'a SecureTx<'a>,
            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'a>>
            + Send,
    {
        let (conn, result) = self.in_transaction(f).await;
        (conn, result.map_err(|tx_err| tx_err.into_domain(map_infra)))
    }

    /// Shared body of the transaction methods: runs `f` on a `SecureTx`
    /// wrapping `txn`, then commits on `Ok` or rolls back on `Err`.
    ///
    /// The layered result lets each caller map failures its own way:
    /// - `Ok(Ok(v))`: `f` succeeded and the commit succeeded.
    /// - `Ok(Err(e))`: `f` failed. A rollback was attempted and `e` is returned as-is.
    /// - `Err(db)`: `f` succeeded but the commit failed.
    async fn run_in_txn<T, E, F>(
        txn: DatabaseTransaction,
        f: F,
    ) -> Result<Result<T, E>, sea_orm::DbErr>
    where
        T: Send,
        E: Send,
        F: for<'a> FnOnce(
                &'a SecureTx<'a>,
            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'a>>
            + Send,
    {
        let tx = SecureTx::new(&txn);
        // `tx` only borrows `txn`. The borrow ends after this await, which
        // allows `txn` to be moved into `commit`/`rollback` below.
        let outcome = f(&tx).await;
        match outcome {
            Ok(value) => txn.commit().await.map(|()| Ok(value)),
            Err(err) => {
                // Best-effort rollback. The closure's error is what the caller
                // needs to see, so a rollback failure is deliberately discarded.
                _ = txn.rollback().await;
                Ok(Err(err))
            }
        }
    }
}