use convert_case::{Case, Casing};
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use super::context::Context;
use crate::entity::parse::FieldDef;
impl Context<'_> {
    /// Generates the relation accessor methods for the entity.
    ///
    /// Emits one `find_<related>` method for every relation field that declares a
    /// `belongs_to` target, followed by one `find_<related>s` method for every
    /// `has_many` relation.
    pub fn relation_methods(&self) -> TokenStream {
        // Collected eagerly so the generated tokens do not keep borrowing the values
        // returned by `relation_fields()` / `has_many_relations()`.
        let belongs_to_methods: Vec<TokenStream> = self
            .entity
            .relation_fields()
            .iter()
            .filter_map(|field| self.belongs_to_method(field))
            .collect();
        let has_many_methods: Vec<TokenStream> = self
            .entity
            .has_many_relations()
            .iter()
            .map(|related| self.has_many_method(related))
            .collect();
        quote! {
            #(#belongs_to_methods)*
            #(#has_many_methods)*
        }
    }

    /// Generates `find_<related>` for a `belongs_to` relation field.
    ///
    /// The generated method loads this entity by `id` through the trait's
    /// `find_by_id`. It then loads the related row whose `id` equals the entity's
    /// foreign-key field and converts it into the related entity type. The result
    /// is `Ok(None)` when either the entity or the related row is missing.
    ///
    /// Returns `None` (no method emitted) when `field` has no `belongs_to` target.
    fn belongs_to_method(&self, field: &FieldDef) -> Option<TokenStream> {
        let related_entity = field.belongs_to()?;
        let related_snake = related_entity.to_string().to_case(Case::Snake);
        let method_name = format_ident!("find_{}", related_snake);
        let related_row = format_ident!("{}Row", related_entity);
        // Everything except the dialect placeholder is known at expansion time. The
        // prefix is built once here instead of by a `format!` on every call of the
        // generated method.
        let sql_prefix = format!(
            "SELECT * FROM {} WHERE id = ",
            self.relation_table_name(&related_snake)
        );
        let fk_name = field.name();
        let id_type = self.id_type;
        // NOTE(review): `concat!` below requires these tokens to be a literal
        // (e.g. a `String` such as "$1" or "?"). Confirm against the dialect impl.
        let placeholder = self.dialect.placeholder(1);
        let trait_name = &self.trait_name;
        Some(quote! {
            async fn #method_name(&self, id: #id_type) -> Result<Option<#related_entity>, Self::Error> {
                let entity = <Self as #trait_name>::find_by_id(self, id).await?;
                match entity {
                    Some(e) => {
                        let row: Option<#related_row> = sqlx::query_as(
                            ::core::concat!(#sql_prefix, #placeholder)
                        )
                        .bind(&e.#fk_name)
                        .fetch_optional(self)
                        .await?;
                        Ok(row.map(#related_entity::from))
                    }
                    None => Ok(None),
                }
            }
        })
    }

    /// Generates `find_<related>s` for a `has_many` relation.
    ///
    /// The generated method takes this entity's id as a parameter named
    /// `<entity>_id`. It returns every related row whose `<entity>_id` column
    /// equals that id, each converted into the related entity type.
    fn has_many_method(&self, related: &syn::Ident) -> TokenStream {
        let related_snake = related.to_string().to_case(Case::Snake);
        let method_name = format_ident!("find_{}s", related_snake);
        let related_row = format_ident!("{}Row", related);
        let entity_snake = self.entity.name_str().to_case(Case::Snake);
        // The foreign-key column and the generated parameter share one name.
        let fk_column = format!("{}_id", entity_snake);
        let fk_field = format_ident!("{}", fk_column);
        let sql_prefix = format!(
            "SELECT * FROM {} WHERE {} = ",
            self.relation_table_name(&related_snake),
            fk_column
        );
        let id_type = self.id_type;
        // NOTE(review): must expand to a literal for `concat!`. See `belongs_to_method`.
        let placeholder = self.dialect.placeholder(1);
        quote! {
            async fn #method_name(&self, #fk_field: #id_type) -> Result<Vec<#related>, Self::Error> {
                let rows: Vec<#related_row> = sqlx::query_as(
                    ::core::concat!(#sql_prefix, #placeholder)
                )
                .bind(&#fk_field)
                .fetch_all(self)
                .await?;
                Ok(rows.into_iter().map(#related::from).collect())
            }
        }
    }

    /// Returns the schema-qualified table name `<schema>.<related_snake>s` that the
    /// generated relation queries select from.
    ///
    /// NOTE(review): pluralisation is a bare trailing `s` (so `category` becomes
    /// `categorys`). Confirm this matches the table naming used by the rest of the
    /// generated code.
    fn relation_table_name(&self, related_snake: &str) -> String {
        format!("{}.{}s", self.entity.schema, related_snake)
    }
}