use proc_macro2::TokenStream;
use quote::quote;
use std::collections::HashMap;
use syn::{Ident, Visibility, punctuated::Punctuated, token::Comma};
/// Input to [`expand_entity_loader`]: the list of relation fields for which
/// loader code is generated.
#[derive(Default)]
pub struct EntityLoaderSchema {
/// One entry per relation field; `expand_entity_loader` iterates these in order.
pub fields: Vec<EntityLoaderField>,
}
/// Description of a single relation field consumed by [`expand_entity_loader`].
pub struct EntityLoaderField {
/// `true` selects the single-model loading code (`load_one_ex` / `load_self_ex`);
/// `false` selects the collection loading code (`load_many_ex` and friends).
pub is_one: bool,
/// `true` when the field is handled as a self-referencing relation
/// (matched via `relation_enum` and/or `via` instead of the entity's table ref).
pub is_self: bool,
/// For `via` fields: picks the `TableRefRev` / `EntityReverse` variants and is
/// forwarded as the flag argument of `load_self_via_ex`.
pub is_reverse: bool,
/// Identifier used as the field name in the generated `EntityLoaderWith` /
/// `EntityLoaderNest` structs and when assigning `model.<field>`.
pub field: Ident,
/// Source text of the related entity path; parsed into tokens as-is, and with a
/// trailing `::Entity` trimmed to obtain the related module path.
pub entity: String,
/// String literal compared against `LoadTarget::Relation(..)`; its value is also
/// used as the `Relation::<value>` variant for self-referencing loads.
pub relation_enum: Option<syn::LitStr>,
/// String literal whose value names a sibling module, referenced in the
/// generated code as `super::<value>::Entity`.
pub via: Option<syn::LitStr>,
}
/// Generates the entity-loader support code for one entity module.
///
/// The returned token stream declares `EntityLoader`, `EntityReverse`,
/// `EntityLoaderWith`, `EntityLoaderNest` and the `EntityLoaderWithParam` trait
/// (all with visibility `vis`), plus `Entity::load()`, `Entity::REVERSE` and the
/// `one` / `all` / `with` / `fetch` / `load*` methods on `EntityLoader`.
///
/// # Parameters
/// * `vis` - visibility applied to the generated structs and trait.
/// * `schema` - relation fields to generate loading code for.
///
/// When the `async` feature of this crate is enabled the generated methods are
/// `async` and their calls are `.await`ed; otherwise plain blocking code is emitted.
///
/// # Panics
/// Panics if a field's `entity` string cannot be parsed into tokens, or if a
/// `relation_enum` / `via` literal does not hold a legal identifier
/// (`Ident::new` rejects it).
pub fn expand_entity_loader(vis: &Visibility, schema: EntityLoaderSchema) -> TokenStream {
    // Each accumulator below collects one fragment of the final output; every
    // relation field appends to the accumulators relevant to its kind.
    let mut field_bools: Punctuated<_, Comma> = Punctuated::new();
    let mut field_nests: Punctuated<_, Comma> = Punctuated::new();
    let mut one_fields: Punctuated<_, Comma> = Punctuated::new();
    let mut with_impl = TokenStream::new();
    let mut with_nest_impl = TokenStream::new();
    let mut select_impl = TokenStream::new();
    let mut assemble_one = TokenStream::new();
    let mut load_one = TokenStream::new();
    let mut load_many = TokenStream::new();
    let mut load_one_nest = TokenStream::new();
    let mut load_many_nest = TokenStream::new();
    let mut load_one_nest_nest = TokenStream::new();
    let mut load_many_nest_nest = TokenStream::new();
    let mut into_with_param_impl = TokenStream::new();
    // Number of elements in the tuple produced by the base select; starts at 1
    // for the model itself.
    let mut arity = 1;

    let (async_, await_) = if cfg!(feature = "async") {
        (quote!(async), quote!(.await))
    } else {
        (quote!(), quote!())
    };
    let async_trait = if cfg!(feature = "async") {
        quote!(#[async_trait::async_trait])
    } else {
        quote!()
    };

    // First element of the closure pattern used to destructure select results.
    one_fields.push(quote!(model));

    // Count how many fields point at each entity so that ambiguous
    // (non-self) relations can be skipped below.
    let mut total_count = HashMap::new();
    for entity_field in schema.fields.iter() {
        *total_count.entry(&entity_field.entity).or_insert(0) += 1;
    }

    for entity_field in schema.fields.iter() {
        let field = &entity_field.field;
        let is_one = entity_field.is_one;
        let is_self = entity_field.is_self;
        let is_reverse = entity_field.is_reverse;
        let entity: TokenStream = entity_field
            .entity
            .parse()
            .expect("relation `entity` must parse as a token stream");
        let entity_module: TokenStream = entity_field
            .entity
            .trim_end_matches("::Entity")
            .parse()
            .expect("relation `entity` module path must parse as a token stream");

        // A table-ref based `with(..)` cannot tell apart several relations to the
        // same entity, so no loader code is generated for them.
        let same_entity_count = *total_count
            .get(&entity_field.entity)
            .expect("every field's entity was counted in the first pass");
        if !is_self && same_entity_count != 1 {
            continue;
        }

        // The boolean "load this relation" flag is identical for every kind of field.
        field_bools.push(quote! {
            #[doc = " Generated by sea-orm-macros"]
            pub #field: bool
        });

        if !is_self {
            field_nests.push(quote! {
                #[doc = " Generated by sea-orm-macros"]
                pub #field: #entity_module::EntityLoaderWith
            });
            with_impl.extend(quote! {
                if target == sea_orm::compound::LoadTarget::TableRef(#entity.table_ref()) {
                    self.#field = true;
                }
            });
            with_nest_impl.extend(quote! {
                if left == sea_orm::compound::LoadTarget::TableRef(#entity.table_ref()) {
                    self.with.#field = true;
                    self.nest.#field.set(right);
                    return self;
                }
            });
        } else {
            // Self-referencing: the nested selector is this module's own `EntityLoaderWith`.
            field_nests.push(quote! {
                #[doc = " Generated by sea-orm-macros"]
                pub #field: EntityLoaderWith
            });
            if let Some(relation_enum) = &entity_field.relation_enum {
                with_impl.extend(quote! {
                    if let sea_orm::compound::LoadTarget::Relation(relation_enum) = &target {
                        if relation_enum == #relation_enum {
                            self.#field = true;
                        }
                    }
                });
            }
            if let Some(via_lit) = &entity_field.via {
                let via = Ident::new(&via_lit.value(), via_lit.span());
                let target_type = if !is_reverse {
                    Ident::new("TableRef", via_lit.span())
                } else {
                    Ident::new("TableRefRev", via_lit.span())
                };
                let target_entity = if !is_reverse {
                    quote!(super::#via::Entity)
                } else {
                    quote!(super::#via::EntityReverse)
                };
                with_impl.extend(quote! {
                    if target == sea_orm::compound::LoadTarget::#target_type(super::#via::Entity.table_ref()) {
                        self.#field = true;
                    }
                });
                with_nest_impl.extend(quote! {
                    if left == sea_orm::compound::LoadTarget::#target_type(super::#via::Entity.table_ref()) {
                        self.with.#field = true;
                        self.nest.#field.set(right);
                        return self;
                    }
                });
                into_with_param_impl.extend(quote! {
                    impl EntityLoaderWithParam for #target_entity {
                        fn into_with_param(self) -> (sea_orm::compound::LoadTarget, Option<sea_orm::compound::LoadTarget>) {
                            (sea_orm::compound::LoadTarget::#target_type(super::#via::Entity.table_ref()), None)
                        }
                    }
                    impl<S> EntityLoaderWithParam for (#target_entity, S)
                    where
                        S: EntityTrait,
                        Entity: Related<S>,
                    {
                        fn into_with_param(self) -> (sea_orm::compound::LoadTarget, Option<sea_orm::compound::LoadTarget>) {
                            (
                                sea_orm::compound::LoadTarget::#target_type(super::#via::Entity.table_ref()),
                                Some(sea_orm::compound::LoadTarget::TableRef(self.1.table_ref())),
                            )
                        }
                    }
                });
            }
        }

        if is_one && !is_self {
            arity += 1;
            // Only the first two to-one relations are folded into the base select
            // (tuple arity capped at 3); later ones are loaded separately in `load`.
            if arity <= 3 {
                one_fields.push(quote!(#field));
                select_impl.extend(quote! {
                    let select = if self.with.#field && self.nest.#field.is_empty() {
                        self.with.#field = false;
                        loaded.#field = true;
                        select.find_also(Entity, #entity)
                    } else {
                        select.select_also_fake(#entity)
                    };
                });
                assemble_one.extend(quote! {
                    if loaded.#field {
                        model.#field = #field.map(Into::into).map(Box::new).into();
                    }
                });
            }
            load_one.extend(quote! {
                if with.#field {
                    let #field = models.as_slice().load_one_ex(#entity, db)#await_?;
                    let #field = #entity_module::EntityLoader::load_nest(#field, &nest.#field, db)#await_?;
                    for (model, #field) in models.iter_mut().zip(#field) {
                        model.#field = #field.map(Into::into).map(Box::new).into();
                    }
                }
            });
            load_one_nest.extend(quote! {
                if with.#field {
                    let #field = models.as_slice().load_one_ex(#entity, db)#await_?;
                    for (model, #field) in models.iter_mut().zip(#field) {
                        if let Some(model) = model.as_mut() {
                            model.#field = #field.map(Into::into).map(Box::new).into();
                        }
                    }
                }
            });
            load_one_nest_nest.extend(quote! {
                if with.#field {
                    let #field = models.as_slice().load_one_ex(#entity, db)#await_?;
                    for (models, #field) in models.iter_mut().zip(#field) {
                        for (model, #field) in models.iter_mut().zip(#field) {
                            model.#field = #field.map(Into::into).map(Box::new).into();
                        }
                    }
                }
            });
        } else if !is_one && !is_self {
            load_many.extend(quote! {
                if with.#field {
                    let #field = models.as_slice().load_many_ex(#entity, db)#await_?;
                    let #field = #entity_module::EntityLoader::load_nest_nest(#field, &nest.#field, db)#await_?;
                    for (model, #field) in models.iter_mut().zip(#field) {
                        model.#field = #field.into();
                    }
                }
            });
            load_many_nest.extend(quote! {
                if with.#field {
                    let #field = models.as_slice().load_many_ex(#entity, db)#await_?;
                    for (model, #field) in models.iter_mut().zip(#field) {
                        if let Some(model) = model.as_mut() {
                            model.#field = #field.into();
                        }
                    }
                }
            });
            load_many_nest_nest.extend(quote! {
                if with.#field {
                    let #field = models.as_slice().load_many_ex(#entity, db)#await_?;
                    for (models, #field) in models.iter_mut().zip(#field) {
                        for (model, #field) in models.iter_mut().zip(#field) {
                            model.#field = #field.into();
                        }
                    }
                }
            });
        } else if is_one && is_self {
            if let Some(relation_enum) = &entity_field.relation_enum {
                // The literal's value names the variant of the local `Relation` enum.
                let relation_enum = Ident::new(&relation_enum.value(), relation_enum.span());
                load_one.extend(quote! {
                    if with.#field {
                        let #field = models.as_slice().load_self_ex(#entity, Relation::#relation_enum, db)#await_?;
                        for (model, #field) in models.iter_mut().zip(#field) {
                            model.#field = #field.map(Into::into).map(Box::new).into();
                        }
                    }
                });
            }
        } else if !is_one && is_self {
            if let Some(relation_enum) = &entity_field.relation_enum {
                let relation_enum = Ident::new(&relation_enum.value(), relation_enum.span());
                load_many.extend(quote! {
                    if with.#field {
                        let #field = models.as_slice().load_self_many_ex(#entity, Relation::#relation_enum, db)#await_?;
                        for (model, #field) in models.iter_mut().zip(#field) {
                            model.#field = #field.into();
                        }
                    }
                });
            }
            if let Some(via) = &entity_field.via {
                let via = Ident::new(&via.value(), via.span());
                load_many.extend(quote! {
                    if with.#field {
                        let #field = models.as_slice().load_self_via_ex(super::#via::Entity, #is_reverse, db)#await_?;
                        let #field = EntityLoader::load_nest_nest(#field, &nest.#field, db)#await_?;
                        for (model, #field) in models.iter_mut().zip(#field) {
                            model.#field = #field.into();
                        }
                    }
                });
                load_many_nest.extend(quote! {
                    if with.#field {
                        let #field = models.as_slice().load_self_via_ex(super::#via::Entity, #is_reverse, db)#await_?;
                        for (model, #field) in models.iter_mut().zip(#field) {
                            if let Some(model) = model.as_mut() {
                                model.#field = #field.into();
                            }
                        }
                    }
                });
                load_many_nest_nest.extend(quote! {
                    if with.#field {
                        let #field = models.as_slice().load_self_via_ex(super::#via::Entity, #is_reverse, db)#await_?;
                        for (models, #field) in models.iter_mut().zip(#field) {
                            for (model, #field) in models.iter_mut().zip(#field) {
                                model.#field = #field.into();
                            }
                        }
                    }
                });
            }
        }
    }

    // Final assembly. `DbErr` is spelled `sea_orm::DbErr` everywhere so all
    // generated signatures resolve the error type the same way.
    quote! {
        #[doc = " Generated by sea-orm-macros"]
        #[derive(Clone)]
        #vis struct EntityLoader {
            select: sea_orm::Select<Entity>,
            with: EntityLoaderWith,
            nest: EntityLoaderNest,
        }

        #[doc = " Generated by sea-orm-macros"]
        #[derive(Debug, Default, Clone, PartialEq, Eq)]
        #vis struct EntityReverse;

        impl sea_orm::compound::EntityReverse for EntityReverse {
            type Entity = Entity;
        }

        #[doc = " Generated by sea-orm-macros"]
        #[derive(Debug, Default, Clone, PartialEq, Eq)]
        #vis struct EntityLoaderWith {
            #field_bools
        }

        #[doc = " Generated by sea-orm-macros"]
        #[derive(Debug, Default, Clone, PartialEq, Eq)]
        #vis struct EntityLoaderNest {
            #field_nests
        }

        impl Entity {
            #[doc = " Generated by sea-orm-macros"]
            pub const REVERSE: EntityReverse = EntityReverse;
        }

        impl EntityLoaderWith {
            #[doc = " Generated by sea-orm-macros"]
            pub fn is_empty(&self) -> bool {
                self == &Self::default()
            }

            #[doc = " Generated by sea-orm-macros"]
            pub fn set(&mut self, target: sea_orm::compound::LoadTarget) {
                #with_impl
            }
        }

        #[doc = " Parameters for EntityLoader"]
        #vis trait EntityLoaderWithParam {
            #[doc = " Generated by sea-orm-macros"]
            fn into_with_param(self) -> (sea_orm::compound::LoadTarget, Option<sea_orm::compound::LoadTarget>);
        }

        #[automatically_derived]
        impl<R> EntityLoaderWithParam for R
        where
            R: EntityTrait,
            Entity: Related<R>,
        {
            fn into_with_param(self) -> (sea_orm::compound::LoadTarget, Option<sea_orm::compound::LoadTarget>) {
                (sea_orm::compound::LoadTarget::TableRef(self.table_ref()), None)
            }
        }

        #[automatically_derived]
        impl<R, S> EntityLoaderWithParam for (R, S)
        where
            R: EntityTrait,
            Entity: Related<R>,
            S: EntityTrait,
            R: Related<S>,
        {
            fn into_with_param(self) -> (sea_orm::compound::LoadTarget, Option<sea_orm::compound::LoadTarget>) {
                (
                    sea_orm::compound::LoadTarget::TableRef(self.0.table_ref()),
                    Some(sea_orm::compound::LoadTarget::TableRef(self.1.table_ref())),
                )
            }
        }

        #[automatically_derived]
        impl<R, S> EntityLoaderWithParam for sea_orm::compound::EntityLoaderWithSelf<R, S>
        where
            R: EntityTrait,
            Entity: Related<R>,
            S: EntityTrait,
            R: RelatedSelfVia<S>,
        {
            fn into_with_param(self) -> (sea_orm::compound::LoadTarget, Option<sea_orm::compound::LoadTarget>) {
                (
                    sea_orm::compound::LoadTarget::TableRef(self.0.table_ref()),
                    Some(sea_orm::compound::LoadTarget::TableRef(self.1.table_ref())),
                )
            }
        }

        #[automatically_derived]
        impl<R, S, SR> EntityLoaderWithParam for sea_orm::compound::EntityLoaderWithSelfRev<R, SR>
        where
            R: EntityTrait,
            Entity: Related<R>,
            S: EntityTrait,
            R: RelatedSelfVia<S>,
            SR: sea_orm::compound::EntityReverse<Entity = S>,
        {
            fn into_with_param(self) -> (sea_orm::compound::LoadTarget, Option<sea_orm::compound::LoadTarget>) {
                (
                    sea_orm::compound::LoadTarget::TableRef(self.0.table_ref()),
                    Some(sea_orm::compound::LoadTarget::TableRefRev(S::default().table_ref())),
                )
            }
        }

        #[automatically_derived]
        impl EntityLoaderWithParam for Relation {
            fn into_with_param(self) -> (sea_orm::compound::LoadTarget, Option<sea_orm::compound::LoadTarget>) {
                (sea_orm::compound::LoadTarget::Relation(self.name()), None)
            }
        }

        #into_with_param_impl

        #[automatically_derived]
        impl std::fmt::Debug for EntityLoader {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.debug_struct("EntityLoader")
                    .field("select", &match (Entity::default().schema_name(), Entity::default().table_name()) {
                        (Some(s), t) => format!("{s}.{t}"),
                        (None, t) => t.to_owned(),
                    })
                    .field("with", &self.with)
                    .field("nest", &self.nest)
                    .finish()
            }
        }

        #[automatically_derived]
        impl sea_orm::QueryFilter for EntityLoader {
            type QueryStatement = <sea_orm::Select<Entity> as sea_orm::QueryFilter>::QueryStatement;

            fn query(&mut self) -> &mut Self::QueryStatement {
                sea_orm::QueryFilter::query(&mut self.select)
            }
        }

        #[automatically_derived]
        impl sea_orm::QueryOrder for EntityLoader {
            type QueryStatement = <sea_orm::Select<Entity> as sea_orm::QueryOrder>::QueryStatement;

            fn query(&mut self) -> &mut Self::QueryStatement {
                sea_orm::QueryOrder::query(&mut self.select)
            }
        }

        #[automatically_derived]
        #async_trait
        impl sea_orm::compound::EntityLoaderTrait<Entity> for EntityLoader {
            type ModelEx = ModelEx;

            #async_ fn fetch<C: sea_orm::ConnectionTrait>(self, db: &C, page: u64, page_size: u64) -> Result<Vec<Self::ModelEx>, sea_orm::DbErr> {
                self.fetch(db, page, page_size)#await_
            }

            #async_ fn num_items<C: sea_orm::ConnectionTrait>(self, db: &C, page_size: u64) -> Result<u64, sea_orm::DbErr> {
                self.select.paginate(db, page_size).num_items()#await_
            }
        }

        impl Entity {
            #[doc = " Generated by sea-orm-macros"]
            pub fn load() -> EntityLoader {
                EntityLoader {
                    select: Entity::find(),
                    with: Default::default(),
                    nest: Default::default(),
                }
            }
        }

        impl EntityLoader {
            #[doc = " Generated by sea-orm-macros"]
            pub #async_ fn one<C: sea_orm::ConnectionTrait>(mut self, db: &C) -> Result<Option<ModelEx>, sea_orm::DbErr> {
                use sea_orm::QuerySelect;

                self.select = self.select.limit(1);
                Ok(self.all(db)#await_?.into_iter().next())
            }

            #[doc = " Generated by sea-orm-macros"]
            pub #async_ fn all<C: sea_orm::ConnectionTrait>(self, db: &C) -> Result<Vec<ModelEx>, sea_orm::DbErr> {
                self.fetch(db, 0, 0)#await_
            }

            #[doc = " Generated by sea-orm-macros"]
            pub fn with<T: EntityLoaderWithParam>(mut self, param: T) -> Self {
                match param.into_with_param() {
                    (left, None) => self.with_1(left),
                    (left, Some(right)) => self.with_2(left, right),
                }
            }

            fn with_1(mut self, load_target: sea_orm::compound::LoadTarget) -> Self {
                self.with.set(load_target);
                self
            }

            fn with_2(mut self, left: sea_orm::compound::LoadTarget, right: sea_orm::compound::LoadTarget) -> Self {
                #with_nest_impl
                self
            }

            #[doc = " Generated by sea-orm-macros"]
            #async_ fn fetch<C: sea_orm::ConnectionTrait>(mut self, db: &C, page: u64, page_size: u64) -> Result<Vec<ModelEx>, sea_orm::DbErr> {
                let select = self.select;

                let mut loaded = EntityLoaderWith::default();
                #select_impl

                let models = if page_size != 0 {
                    select.paginate(db, page_size).fetch_page(page)#await_?
                } else {
                    select.all(db)#await_?
                };

                let models = models.into_iter().map(|(#one_fields)| {
                    let mut model = model.into_ex();
                    #assemble_one
                    model
                }).collect::<Vec<_>>();

                let models = Self::load(models, &self.with, &self.nest, db)#await_?;

                Ok(models)
            }

            #[doc = " Generated by sea-orm-macros"]
            pub #async_ fn load<C: sea_orm::ConnectionTrait>(mut models: Vec<ModelEx>, with: &EntityLoaderWith, nest: &EntityLoaderNest, db: &C) -> Result<Vec<ModelEx>, sea_orm::DbErr> {
                use sea_orm::LoaderTraitEx;

                #load_one
                #load_many

                Ok(models)
            }

            #[doc = " Generated by sea-orm-macros"]
            pub #async_ fn load_nest<C: sea_orm::ConnectionTrait>(mut models: Vec<Option<ModelEx>>, with: &EntityLoaderWith, db: &C) -> Result<Vec<Option<ModelEx>>, sea_orm::DbErr> {
                use sea_orm::LoaderTraitEx;

                #load_one_nest
                #load_many_nest

                Ok(models)
            }

            #[doc = " Generated by sea-orm-macros"]
            pub #async_ fn load_nest_nest<C: sea_orm::ConnectionTrait>(mut models: Vec<Vec<ModelEx>>, with: &EntityLoaderWith, db: &C) -> Result<Vec<Vec<ModelEx>>, sea_orm::DbErr> {
                use sea_orm::NestedLoaderTrait;

                #load_one_nest_nest
                #load_many_nest_nest

                Ok(models)
            }
        }
    }
}