use async_trait::async_trait;
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter};
use std::collections::HashMap;
use std::hash::Hash;
use crate::database::DB;
use crate::error::FrameworkError;
/// Batch-loads models of an entity, keyed by a single key type.
///
/// The `impl_batch_load!` macro in this module generates an implementation from an
/// entity type, a key type and the name of the model field that holds the key.
#[async_trait]
pub trait BatchLoad: EntityTrait + Sized
where
Self::Model: Send + Sync,
{
/// Type of the key that identifies one model (the primary-key field's type in
/// implementations generated by `impl_batch_load!`).
type Key: Clone + Eq + Hash + Send + Sync;
/// Returns the key of `model`.
fn extract_pk(model: &Self::Model) -> Self::Key;
/// Fetches the models for `ids` and returns them in a map keyed by `Self::Key`.
///
/// NOTE(review): the trait itself does not specify how duplicate or unknown ids are
/// handled. The `impl_batch_load!` implementation deduplicates ids, omits ids that
/// match no row, and returns an empty map for empty input without querying.
///
/// # Errors
///
/// Returns a `FrameworkError` when the models cannot be loaded.
async fn batch_load<I>(ids: I) -> Result<HashMap<Self::Key, Self::Model>, FrameworkError>
where
I: IntoIterator<Item = Self::Key> + Send,
I::IntoIter: Send;
}
/// Batch-loads the "many" side of a one-to-many relation, grouped by foreign key.
///
/// The `impl_batch_load_many!` macro in this module generates an implementation that
/// delegates to `batch_load_has_many`.
#[async_trait]
pub trait BatchLoadMany: EntityTrait + Sized
where
Self::Model: Send + Sync + Clone,
{
/// Type of the foreign-key value used to group models.
type ForeignKey: Clone + Eq + Hash + Send + Sync + 'static;
/// Returns the foreign-key value stored in `model`.
fn extract_fk(model: &Self::Model) -> Self::ForeignKey;
/// Fetches all models whose foreign key is one of `fk_values`, grouped by that key.
///
/// `fk_column` names the column to filter on.
/// NOTE(review): implementations generated by `impl_batch_load_many!` ignore this
/// argument and always filter on the column given to the macro.
///
/// # Errors
///
/// Returns a `FrameworkError` when the models cannot be loaded.
async fn batch_load_many<I>(
fk_values: I,
fk_column: Self::Column,
) -> Result<HashMap<Self::ForeignKey, Vec<Self::Model>>, FrameworkError>
where
I: IntoIterator<Item = Self::ForeignKey> + Send,
I::IntoIter: Send,
Self::Column: ColumnTrait + Send + Sync,
sea_orm::Value: From<Self::ForeignKey>;
}
pub async fn batch_load_by_id<E, K, C>(
ids: impl IntoIterator<Item = K> + Send,
pk_column: C,
) -> Result<HashMap<K, E::Model>, FrameworkError>
where
E: EntityTrait,
E::Model: Send + Sync,
K: Clone + Eq + Hash + Send + Sync + 'static,
C: ColumnTrait + Send + Sync,
sea_orm::Value: From<K>,
{
let ids_vec: Vec<K> = ids.into_iter().collect();
if ids_vec.is_empty() {
return Ok(HashMap::new());
}
let unique_ids: Vec<K> = ids_vec
.iter()
.cloned()
.collect::<std::collections::HashSet<K>>()
.into_iter()
.collect();
let values: Vec<sea_orm::Value> = unique_ids.iter().cloned().map(Into::into).collect();
let db = DB::connection()?;
let _entities = E::find()
.filter(pk_column.is_in(values))
.all(db.inner())
.await
.map_err(|e| FrameworkError::database(e.to_string()))?;
Ok(HashMap::new())
}
/// Loads every row of `E` whose `fk_column` value is one of `fk_values`, grouped by key.
///
/// Each distinct foreign key is sent to the database once, and an empty input
/// returns an empty map without querying. Rows are grouped under
/// `fk_extractor(&row)`; foreign keys that match no rows get no entry in the map.
///
/// # Errors
///
/// Returns a [`FrameworkError`] if `DB::connection()` fails or the query fails.
pub async fn batch_load_has_many<E, K, C, F>(
    fk_values: impl IntoIterator<Item = K> + Send,
    fk_column: C,
    fk_extractor: F,
) -> Result<HashMap<K, Vec<E::Model>>, FrameworkError>
where
    E: EntityTrait,
    E::Model: Send + Sync + Clone,
    K: Clone + Eq + Hash + Send + Sync + 'static,
    C: ColumnTrait + Send + Sync,
    F: Fn(&E::Model) -> K + Send + Sync,
    sea_orm::Value: From<K>,
{
    let distinct: std::collections::HashSet<K> = fk_values.into_iter().collect();
    if distinct.is_empty() {
        return Ok(HashMap::new());
    }

    let in_list: Vec<sea_orm::Value> = distinct.into_iter().map(sea_orm::Value::from).collect();

    let connection = DB::connection()?;
    let rows = E::find()
        .filter(fk_column.is_in(in_list))
        .all(connection.inner())
        .await
        .map_err(|err| FrameworkError::database(err.to_string()))?;

    // Bucket rows by the key the caller's extractor reports, preserving row order.
    let grouped = rows.into_iter().fold(
        HashMap::<K, Vec<E::Model>>::new(),
        |mut groups, row| {
            groups.entry(fk_extractor(&row)).or_insert_with(Vec::new).push(row);
            groups
        },
    );
    Ok(grouped)
}
/// Implements `BatchLoad` for a SeaORM entity.
///
/// Arguments:
/// - `$entity`: the entity type to implement the trait for.
/// - `$key_type`: the key type (`BatchLoad::Key`).
/// - `$pk_field`: the model field holding the key; it is cloned, so non-`Copy`
///   keys such as `String` are supported.
///
/// The generated `batch_load` deduplicates ids, returns an empty map for empty input
/// without querying, and filters on the entity's first primary-key column.
/// Composite primary keys are therefore not supported.
#[macro_export]
macro_rules! impl_batch_load {
    ($entity:ty, $key_type:ty, $pk_field:ident) => {
        #[async_trait::async_trait]
        impl $crate::database::BatchLoad for $entity {
            type Key = $key_type;

            fn extract_pk(model: &Self::Model) -> Self::Key {
                // Clone instead of moving out of `&Model`, so non-`Copy` keys compile.
                ::std::clone::Clone::clone(&model.$pk_field)
            }

            async fn batch_load<I>(
                ids: I,
            ) -> Result<
                std::collections::HashMap<Self::Key, Self::Model>,
                $crate::error::FrameworkError,
            >
            where
                I: IntoIterator<Item = Self::Key> + Send,
                I::IntoIter: Send,
            {
                // `PrimaryKeyToColumn` is needed for `into_column`; import it here
                // rather than relying on the expansion site's imports.
                use sea_orm::{ColumnTrait, EntityTrait, Iterable, PrimaryKeyToColumn, QueryFilter};
                use $crate::database::DB;

                // Deduplicate so each id appears once in the `IN (...)` list.
                let values: Vec<sea_orm::Value> = {
                    let mut seen = std::collections::HashSet::new();
                    ids.into_iter()
                        .filter(|id| seen.insert(id.clone()))
                        .map(Into::into)
                        .collect()
                };
                if values.is_empty() {
                    return Ok(std::collections::HashMap::new());
                }

                let db = DB::connection()?;
                // Only the first primary-key column is used (no composite-key support).
                let pk_col = <Self as EntityTrait>::PrimaryKey::iter()
                    .next()
                    .expect("SeaORM entities always declare at least one primary-key column")
                    .into_column();
                let entities = Self::find()
                    .filter(pk_col.is_in(values))
                    .all(db.inner())
                    .await
                    .map_err(|e| $crate::error::FrameworkError::database(e.to_string()))?;

                Ok(entities
                    .into_iter()
                    .map(|entity| (Self::extract_pk(&entity), entity))
                    .collect())
            }
        }
    };
}
/// Implements `BatchLoadMany` for `$entity` by delegating to `batch_load_has_many`.
///
/// Arguments:
/// - `$entity`: the entity type to implement the trait for.
/// - `$fk_type`: the foreign-key type (`BatchLoadMany::ForeignKey`).
/// - `$fk_extractor`: a callable taking `&Model` and returning the foreign key.
///   It is expanded twice (in `extract_fk` and in `batch_load_many`), so pass a
///   closure or function path, not an expression with side effects.
/// - `$fk_column`: the column used in the `IN (...)` filter.
///
/// NOTE(review): the generated `batch_load_many` ignores its runtime `fk_column`
/// argument and always filters on `$fk_column`. This presumably keeps the filter
/// column consistent with `$fk_extractor`; confirm callers do not expect otherwise.
#[macro_export]
macro_rules! impl_batch_load_many {
($entity:ty, $fk_type:ty, $fk_extractor:expr, $fk_column:expr) => {
#[async_trait::async_trait]
impl $crate::database::BatchLoadMany for $entity {
type ForeignKey = $fk_type;
fn extract_fk(model: &Self::Model) -> Self::ForeignKey {
// `$fk_extractor` is called directly here. NOTE(review): an untyped closure
// may need an explicit `&Model` parameter annotation for this to infer — confirm.
$fk_extractor(model)
}
async fn batch_load_many<I>(
fk_values: I,
_fk_column: Self::Column,
) -> Result<
std::collections::HashMap<Self::ForeignKey, Vec<Self::Model>>,
$crate::error::FrameworkError,
>
where
I: IntoIterator<Item = Self::ForeignKey> + Send,
I::IntoIter: Send,
Self::Column: sea_orm::ColumnTrait + Send + Sync,
sea_orm::Value: From<Self::ForeignKey>,
{
// The macro's `$fk_column`, not the `_fk_column` argument, selects the filter column.
// `$fk_extractor` is passed straight to the generic call so its closure
// signature can be inferred from the `Fn(&E::Model) -> K` bound.
$crate::database::batch_load_has_many::<Self, _, _, _>(
fk_values,
$fk_column,
$fk_extractor,
)
.await
}
}
};
}
pub use impl_batch_load;
pub use impl_batch_load_many;