use convert_case::{Case, Casing};
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use super::parse::{EntityDef, FieldDef, SqlLevel};
use crate::utils::marker;
pub fn generate(entity: &EntityDef) -> TokenStream {
if entity.sql == SqlLevel::None {
return TokenStream::new();
}
let vis = &entity.vis;
let entity_name = entity.name();
let trait_name = format_ident!("{}Repository", entity_name);
let create_dto = entity.ident_with("Create", "Request");
let update_dto = entity.ident_with("Update", "Request");
let id_type = entity.id_field().ty();
let create_method = if entity.create_fields().is_empty() {
TokenStream::new()
} else {
quote! { async fn create(&self, dto: #create_dto) -> Result<#entity_name, Self::Error>; }
};
let update_method = if entity.update_fields().is_empty() {
TokenStream::new()
} else {
quote! { async fn update(&self, id: #id_type, dto: #update_dto) -> Result<#entity_name, Self::Error>; }
};
let relation_methods = generate_relation_methods(entity, id_type);
let projection_methods = generate_projection_methods(entity, id_type);
let soft_delete_methods = generate_soft_delete_methods(entity, id_type);
let query_method = generate_query_method(entity);
let stream_method = generate_stream_method(entity);
let marker = marker::generated();
quote! {
#marker
#[async_trait::async_trait]
#vis trait #trait_name: Send + Sync {
type Error: std::error::Error + Send + Sync;
type Pool;
fn pool(&self) -> &Self::Pool;
#create_method
async fn find_by_id(&self, id: #id_type) -> Result<Option<#entity_name>, Self::Error>;
#update_method
async fn delete(&self, id: #id_type) -> Result<bool, Self::Error>;
async fn list(&self, limit: i64, offset: i64) -> Result<Vec<#entity_name>, Self::Error>;
#query_method
#stream_method
#relation_methods
#projection_methods
#soft_delete_methods
}
}
}
fn generate_relation_methods(entity: &EntityDef, id_type: &syn::Type) -> TokenStream {
let belongs_to_methods: Vec<TokenStream> = entity
.relation_fields()
.iter()
.filter_map(|field| generate_belongs_to_method(field, id_type))
.collect();
let has_many_methods: Vec<TokenStream> = entity
.has_many_relations()
.iter()
.map(|related| generate_has_many_method(entity, related, id_type))
.collect();
quote! {
#(#belongs_to_methods)*
#(#has_many_methods)*
}
}
fn generate_belongs_to_method(field: &FieldDef, id_type: &syn::Type) -> Option<TokenStream> {
let related_entity = field.belongs_to()?;
let method_name = format_ident!("find_{}", related_entity.to_string().to_case(Case::Snake));
Some(quote! {
async fn #method_name(&self, id: #id_type) -> Result<Option<#related_entity>, Self::Error>;
})
}
/// Builds the `find_{related}s` signature for a has-many relation.
///
/// The single parameter is named `{entity}_id`, using the snake-cased owning
/// entity name, and is typed with `id_type`. The method returns a `Vec` of the
/// related type.
///
/// Pluralisation is a plain trailing `s` appended to the snake-cased related
/// name. Irregular plurals are therefore not handled.
fn generate_has_many_method(
    entity: &EntityDef,
    related: &syn::Ident,
    id_type: &syn::Type,
) -> TokenStream {
    let method_name = format_ident!("find_{}s", related.to_string().to_case(Case::Snake));
    let fk_param = format_ident!("{}_id", entity.name_str().to_case(Case::Snake));
    quote! {
        async fn #method_name(&self, #fk_param: #id_type) -> Result<Vec<#related>, Self::Error>;
    }
}
fn generate_projection_methods(entity: &EntityDef, id_type: &syn::Type) -> TokenStream {
let entity_name = entity.name();
let methods: Vec<TokenStream> = entity
.projections
.iter()
.map(|proj| {
let proj_snake = proj.name.to_string().to_case(Case::Snake);
let method_name = format_ident!("find_by_id_{}", proj_snake);
let proj_type = format_ident!("{}{}", entity_name, proj.name);
quote! {
async fn #method_name(&self, id: #id_type) -> Result<Option<#proj_type>, Self::Error>;
}
})
.collect();
quote! { #(#methods)* }
}
/// Emits the soft-delete helper signatures when `entity.is_soft_delete()` is true.
///
/// The helpers are `hard_delete`, `restore`, `find_by_id_with_deleted` and
/// `list_with_deleted`. For entities that are not soft-deletable an empty
/// stream is returned.
fn generate_soft_delete_methods(entity: &EntityDef, id_type: &syn::Type) -> TokenStream {
    if entity.is_soft_delete() {
        let entity_name = entity.name();
        quote! {
            async fn hard_delete(&self, id: #id_type) -> Result<bool, Self::Error>;
            async fn restore(&self, id: #id_type) -> Result<bool, Self::Error>;
            async fn find_by_id_with_deleted(&self, id: #id_type) -> Result<Option<#entity_name>, Self::Error>;
            async fn list_with_deleted(&self, limit: i64, offset: i64) -> Result<Vec<#entity_name>, Self::Error>;
        }
    } else {
        TokenStream::new()
    }
}
/// Emits the `query` method signature when `entity.has_filters()` is true.
///
/// The argument type is the identifier from `entity.ident_with("", "Query")`.
/// The result is a `Vec` of the entity. Without filters an empty stream is
/// returned.
fn generate_query_method(entity: &EntityDef) -> TokenStream {
    if entity.has_filters() {
        let entity_name = entity.name();
        let query_type = entity.ident_with("", "Query");
        quote! {
            async fn query(&self, query: #query_type) -> Result<Vec<#entity_name>, Self::Error>;
        }
    } else {
        TokenStream::new()
    }
}
/// Emits the `stream_filtered` method signature.
///
/// The method is emitted only when `entity.has_streams()` and
/// `entity.has_filters()` are both true, because the argument is the filter
/// type from `entity.ident_with("", "Filter")`. The method resolves to a pinned,
/// boxed `futures::Stream` whose items are individually `Result`s. Otherwise an
/// empty stream is returned.
pub fn generate_stream_method(entity: &EntityDef) -> TokenStream {
    if entity.has_streams() && entity.has_filters() {
        let entity_name = entity.name();
        let filter_type = entity.ident_with("", "Filter");
        // `'_` ties the returned stream to the `&self` borrow of the method.
        let boxed_stream = quote! {
            std::pin::Pin<Box<dyn futures::Stream<Item = Result<#entity_name, Self::Error>> + Send + '_>>
        };
        quote! {
            async fn stream_filtered(
                &self,
                filter: #filter_type,
            ) -> Result<#boxed_stream, Self::Error>;
        }
    } else {
        TokenStream::new()
    }
}