use async_trait::async_trait;
use sea_orm::{
ActiveModelBehavior, ActiveModelTrait, ColumnTrait, EntityTrait, IntoActiveModel, ModelTrait,
PaginatorTrait, PrimaryKeyTrait, TryIntoModel,
};
use crate::database::{QueryBuilder, DB};
use crate::error::FrameworkError;
/// Read-only query helpers for SeaORM entities that manage their own connection.
///
/// Every method obtains a connection with `DB::connection()?` (propagating its
/// error with `?`) and maps any SeaORM query error to
/// `FrameworkError::database(e.to_string())`.
#[async_trait]
pub trait Model: EntityTrait + Sized
where
    Self::Model: ModelTrait<Entity = Self> + Send + Sync,
{
    /// Fetches every row of this entity's table.
    ///
    /// # Errors
    /// Fails if no connection can be obtained or the query fails.
    async fn all() -> Result<Vec<Self::Model>, FrameworkError> {
        let db = DB::connection()?;
        Self::find()
            .all(db.inner())
            .await
            .map_err(|e| FrameworkError::database(e.to_string()))
    }

    /// Looks up a single row by primary key, returning `Ok(None)` when no row matches.
    ///
    /// # Errors
    /// Fails if no connection can be obtained or the query fails.
    async fn find_by_pk<K>(id: K) -> Result<Option<Self::Model>, FrameworkError>
    where
        K: Into<<Self::PrimaryKey as PrimaryKeyTrait>::ValueType> + Send,
    {
        let db = DB::connection()?;
        Self::find_by_id(id)
            .one(db.inner())
            .await
            .map_err(|e| FrameworkError::database(e.to_string()))
    }

    /// Like [`Model::find_by_pk`], but treats a missing row as an error.
    ///
    /// # Errors
    /// Besides connection/query failures, returns a `FrameworkError::database`
    /// error whose message names the entity type (via `std::any::type_name`)
    /// and the `Debug` form of `id` when no row matches.
    async fn find_or_fail<K>(id: K) -> Result<Self::Model, FrameworkError>
    where
        K: Into<<Self::PrimaryKey as PrimaryKeyTrait>::ValueType> + Send + std::fmt::Debug + Copy,
    {
        // `K: Copy` lets `id` be moved into the lookup and still be formatted
        // in the error message afterwards.
        Self::find_by_pk(id).await?.ok_or_else(|| {
            FrameworkError::database(format!(
                "{} with id {:?} not found",
                std::any::type_name::<Self>(),
                id
            ))
        })
    }

    /// Counts every row of this entity's table.
    ///
    /// # Errors
    /// Fails if no connection can be obtained or the query fails.
    async fn count_all() -> Result<u64, FrameworkError> {
        let db = DB::connection()?;
        Self::find()
            .count(db.inner())
            .await
            .map_err(|e| FrameworkError::database(e.to_string()))
    }

    /// Returns `true` if the table contains at least one row.
    ///
    /// Implemented as a single-row fetch ([`Model::first`]) rather than a full
    /// `COUNT(*)`, so the database can stop after the first row it finds.
    ///
    /// # Errors
    /// Fails if no connection can be obtained or the query fails.
    async fn exists_any() -> Result<bool, FrameworkError> {
        Ok(Self::first().await?.is_some())
    }

    /// Fetches one row of the table, or `Ok(None)` if the table is empty.
    ///
    /// No ordering is applied, so which row is returned is left to the database.
    ///
    /// # Errors
    /// Fails if no connection can be obtained or the query fails.
    async fn first() -> Result<Option<Self::Model>, FrameworkError> {
        let db = DB::connection()?;
        Self::find()
            .one(db.inner())
            .await
            .map_err(|e| FrameworkError::database(e.to_string()))
    }
}
/// Write helpers (insert, update, delete, save) layered on top of [`Model`].
///
/// Each method acquires its own connection via `DB::connection()?` and turns
/// SeaORM failures into `FrameworkError::database` carrying the error text.
#[async_trait]
pub trait ModelMut: Model
where
    Self::Model: ModelTrait<Entity = Self> + IntoActiveModel<Self::ActiveModel> + Send + Sync,
    Self::ActiveModel: ActiveModelTrait<Entity = Self> + ActiveModelBehavior + Send,
{
    /// Inserts `model` and returns the model SeaORM's `insert` yields.
    ///
    /// # Errors
    /// Fails if no connection can be obtained or the insert fails.
    async fn insert_one(model: Self::ActiveModel) -> Result<Self::Model, FrameworkError> {
        let conn = DB::connection()?;
        let outcome = model.insert(conn.inner()).await;
        outcome.map_err(|err| FrameworkError::database(err.to_string()))
    }

    /// Updates the row described by `model` and returns the model SeaORM's `update` yields.
    ///
    /// # Errors
    /// Fails if no connection can be obtained or the update fails.
    async fn update_one(model: Self::ActiveModel) -> Result<Self::Model, FrameworkError> {
        let conn = DB::connection()?;
        let outcome = model.update(conn.inner()).await;
        outcome.map_err(|err| FrameworkError::database(err.to_string()))
    }

    /// Deletes by primary key and reports how many rows were affected.
    ///
    /// # Errors
    /// Fails if no connection can be obtained or the delete fails.
    async fn delete_by_pk<K>(id: K) -> Result<u64, FrameworkError>
    where
        K: Into<<Self::PrimaryKey as PrimaryKeyTrait>::ValueType> + Send,
    {
        let conn = DB::connection()?;
        Self::delete_by_id(id)
            .exec(conn.inner())
            .await
            .map(|outcome| outcome.rows_affected)
            .map_err(|err| FrameworkError::database(err.to_string()))
    }

    /// Saves `model` through SeaORM's `save`, then converts the returned active
    /// model back into a plain model.
    ///
    /// # Errors
    /// Fails if no connection can be obtained, the save fails, or the saved
    /// active model cannot be converted by `try_into_model`.
    async fn save_one(model: Self::ActiveModel) -> Result<Self::Model, FrameworkError>
    where
        Self::ActiveModel: TryIntoModel<Self::Model>,
    {
        let conn = DB::connection()?;
        model
            .save(conn.inner())
            .await
            .map_err(|err| FrameworkError::database(err.to_string()))?
            .try_into_model()
            .map_err(|err| FrameworkError::database(err.to_string()))
    }
}
/// Entities that expose a named scope type for building reusable, filtered queries.
///
/// The `define_scopes!` macro in this module generates an implementation of
/// this trait together with the scope enum.
pub trait ScopedQuery: EntityTrait + Sized
where
    Self::Model: Send + Sync,
{
    /// The scope values that can be applied to a query on this entity.
    type Scope: Scope<Self>;

    /// Starts a new query for this entity with `scope` already applied.
    fn scoped(scope: Self::Scope) -> ScopedQueryBuilder<Self> {
        let fresh = QueryBuilder::<Self>::new();
        let inner = scope.apply(fresh);
        ScopedQueryBuilder { inner }
    }

    /// Starts a new query filtered on `column` being equal to `owner_id`.
    fn for_owner<C, V>(owner_id: V, column: C) -> QueryBuilder<Self>
    where
        C: ColumnTrait,
        V: Into<sea_orm::Value>,
    {
        let owned_by = column.eq(owner_id);
        QueryBuilder::<Self>::new().filter(owned_by)
    }
}
/// A reusable refinement that can be applied to a query on entity `E`.
///
/// Implementations consume the scope value and return the (refined) builder.
pub trait Scope<E>
where
    E: EntityTrait,
    E::Model: Send + Sync,
{
    /// Applies this scope to `builder`, returning the resulting builder.
    fn apply(self, builder: QueryBuilder<E>) -> QueryBuilder<E>;
}
/// A `QueryBuilder` wrapper produced by `ScopedQuery::scoped`, to which further
/// scopes and filters can be chained before executing.
pub struct ScopedQueryBuilder<E: EntityTrait>
where
    E::Model: Send + Sync,
{
    // Not `pub`: outside this module the query is reached through the
    // wrapper's methods or `into_query`.
    inner: QueryBuilder<E>,
}
impl<E> ScopedQueryBuilder<E>
where
    E: EntityTrait,
    E::Model: Send + Sync,
{
    /// Applies another scope on top of those already applied.
    ///
    /// Returns the refined builder; the original is consumed.
    #[must_use = "the scope is applied to the returned builder; dropping it discards the query"]
    pub fn and<S: Scope<E>>(self, scope: S) -> Self {
        Self {
            inner: scope.apply(self.inner),
        }
    }

    /// Adds `filter` to the underlying query via `QueryBuilder::filter`.
    ///
    /// Returns the refined builder; the original is consumed.
    #[must_use = "the filter is applied to the returned builder; dropping it discards the query"]
    pub fn filter<F>(self, filter: F) -> Self
    where
        F: sea_orm::sea_query::IntoCondition,
    {
        Self {
            inner: self.inner.filter(filter),
        }
    }

    /// Unwraps into the underlying `QueryBuilder`, e.g. to use builder methods
    /// this wrapper does not expose.
    #[must_use]
    pub fn into_query(self) -> QueryBuilder<E> {
        self.inner
    }

    /// Executes the query and returns all matching rows (delegates to `QueryBuilder::all`).
    pub async fn all(self) -> Result<Vec<E::Model>, FrameworkError> {
        self.inner.all().await
    }

    /// Executes the query and returns the first row, if any (delegates to `QueryBuilder::first`).
    pub async fn first(self) -> Result<Option<E::Model>, FrameworkError> {
        self.inner.first().await
    }

    /// Executes the query and returns the first row (delegates to
    /// `QueryBuilder::first_or_fail`, which decides the error when none matches).
    pub async fn first_or_fail(self) -> Result<E::Model, FrameworkError> {
        self.inner.first_or_fail().await
    }

    /// Counts matching rows (delegates to `QueryBuilder::count`).
    pub async fn count(self) -> Result<u64, FrameworkError> {
        self.inner.count().await
    }

    /// Reports whether any row matches (delegates to `QueryBuilder::exists`).
    pub async fn exists(self) -> Result<bool, FrameworkError> {
        self.inner.exists().await
    }
}
/// Generates a `Scope` enum for an entity and implements `ScopedQuery` for it.
///
/// Each entry maps a variant, optionally carrying named and typed arguments,
/// to a filter expression. The argument names are bound inside that
/// expression. Attributes and doc comments written before a variant are
/// forwarded to the generated variant, and argument lists may end with a
/// trailing comma.
///
/// ```ignore
/// define_scopes!(users::Entity {
///     /// Only active users.
///     Active => users::Column::Active.eq(true),
///     OlderThan(age: i32,) => users::Column::Age.gt(age),
/// });
/// ```
#[macro_export]
macro_rules! define_scopes {
    (
        $entity:ty {
            $(
                $(#[$variant_attr:meta])*
                $scope_name:ident $(( $($arg:ident : $arg_ty:ty),* $(,)? ))? => $filter:expr
            ),* $(,)?
        }
    ) => {
        #[doc = concat!("Named query scopes for `", stringify!($entity), "`, generated by `define_scopes!`.")]
        pub enum Scope {
            $(
                $(#[$variant_attr])*
                $scope_name $(( $($arg_ty),* ))?,
            )*
        }

        impl $crate::database::Scope<$entity> for Scope {
            fn apply(
                self,
                query: $crate::database::QueryBuilder<$entity>,
            ) -> $crate::database::QueryBuilder<$entity> {
                match self {
                    // Variant fields are bound under the caller-supplied argument
                    // names, so each `$filter` expression can refer to them directly.
                    $(Self::$scope_name $(( $($arg),* ))? => query.filter($filter),)*
                }
            }
        }

        impl $crate::database::ScopedQuery for $entity {
            type Scope = Scope;
        }
    };
}
pub use define_scopes;