use crate::crate_paths::{get_reinhardt_crate, get_reinhardt_orm_crate};
use proc_macro2::TokenStream;
use quote::quote;
use syn::{Attribute, Field, ItemStruct, Result, Type};
/// Returns the first generic type argument of a `ForeignKeyField<T>` or
/// `OneToOneField<T>` type, or `None` for any other type.
///
/// Only the last path segment is inspected, so both `ForeignKeyField<User>`
/// and `some::path::ForeignKeyField<User>` yield `User`. `None` is also
/// returned when the segment has no angle-bracketed arguments or when the
/// first generic argument is not a type.
fn extract_fk_target_type(ty: &Type) -> Option<&Type> {
    let Type::Path(type_path) = ty else {
        return None;
    };
    let segment = type_path.path.segments.last()?;

    let is_relation_wrapper =
        segment.ident == "ForeignKeyField" || segment.ident == "OneToOneField";
    if !is_relation_wrapper {
        return None;
    }

    match &segment.arguments {
        syn::PathArguments::AngleBracketed(generics) => match generics.args.first()? {
            syn::GenericArgument::Type(target) => Some(target),
            _ => None,
        },
        _ => None,
    }
}
pub(crate) fn model_attribute_impl(
args: TokenStream,
mut input: ItemStruct,
) -> Result<TokenStream> {
let reinhardt = get_reinhardt_crate();
let orm_crate = get_reinhardt_orm_crate();
let has_derive_model = input.attrs.iter().any(|attr| {
if attr.path().is_ident("derive")
&& let syn::Meta::List(meta_list) = &attr.meta
{
if let Ok(paths) = meta_list.parse_args_with(
syn::punctuated::Punctuated::<syn::Path, syn::Token![,]>::parse_terminated,
) {
return paths.iter().any(|path| {
path.segments.last().is_some_and(|seg| seg.ident == "Model")
});
}
return false;
}
false
});
if has_derive_model {
return Ok(quote! { #input });
}
fn has_derive_trait(attrs: &[Attribute], trait_name: &str) -> bool {
attrs.iter().any(|attr| {
if attr.path().is_ident("derive")
&& let syn::Meta::List(meta_list) = &attr.meta
{
if let Ok(paths) = meta_list.parse_args_with(
syn::punctuated::Punctuated::<syn::Path, syn::Token![,]>::parse_terminated,
) {
return paths.iter().any(|path| {
path.segments
.last()
.is_some_and(|seg| seg.ident == trait_name)
});
}
return false;
}
false
})
}
fn has_fk_or_one_to_one_rel(attrs: &[Attribute]) -> bool {
attrs.iter().any(|attr| {
if attr.path().is_ident("rel")
&& let syn::Meta::List(meta_list) = &attr.meta
{
let tokens_str = meta_list.tokens.to_string();
return tokens_str.contains("foreign_key") || tokens_str.contains("one_to_one");
}
false
})
}
let existing_field_names: std::collections::HashSet<String> =
if let syn::Fields::Named(ref fields) = input.fields {
fields
.named
.iter()
.filter_map(|f| f.ident.as_ref().map(|i| i.to_string()))
.collect()
} else {
std::collections::HashSet::new()
};
if let syn::Fields::Named(ref fields) = input.fields {
for field in fields.named.iter() {
if has_fk_or_one_to_one_rel(&field.attrs)
&& let Some(field_name) = &field.ident
{
let id_field_name_str = format!("{}_id", field_name);
if existing_field_names.contains(&id_field_name_str) {
return Err(syn::Error::new_spanned(
field,
format!(
"Field '{}' must not be manually defined. It will be auto-generated from the '{}' relationship field.",
id_field_name_str, field_name
),
));
}
}
}
}
let mut fk_id_fields: Vec<Field> = Vec::new();
let mut generated_id_field_names: std::collections::HashSet<String> =
std::collections::HashSet::new();
if let syn::Fields::Named(ref fields) = input.fields {
for field in fields.named.iter() {
if has_fk_or_one_to_one_rel(&field.attrs)
&& let Some(field_name) = &field.ident
&& let Some(target_ty) = extract_fk_target_type(&field.ty)
{
let id_field_name_str = format!("{}_id", field_name);
if !existing_field_names.contains(&id_field_name_str)
&& !generated_id_field_names.contains(&id_field_name_str)
{
let id_field_name = syn::Ident::new(&id_field_name_str, field_name.span());
let new_field: Field = syn::parse_quote! {
#[serde(default)]
#id_field_name: <#target_ty as #orm_crate::Model>::PrimaryKey
};
fk_id_fields.push(new_field);
generated_id_field_names.insert(id_field_name_str);
}
}
}
}
if let syn::Fields::Named(ref mut fields) = input.fields {
for field in fields.named.iter_mut() {
let has_many_to_many = field.attrs.iter().any(|attr| {
if attr.path().is_ident("rel") {
if let syn::Meta::List(meta_list) = &attr.meta {
let tokens_str = meta_list.tokens.to_string();
return tokens_str.contains("many_to_many");
}
}
false
});
let is_fk_field = extract_fk_target_type(&field.ty).is_some();
if has_many_to_many || is_fk_field {
let has_serde_skip = field.attrs.iter().any(|attr| {
if attr.path().is_ident("serde")
&& let syn::Meta::List(meta_list) = &attr.meta
{
let tokens_str = meta_list.tokens.to_string();
return tokens_str.contains("skip");
}
false
});
if !has_serde_skip {
let serde_skip_attr: Attribute = syn::parse_quote! { #[serde(skip)] };
field.attrs.push(serde_skip_attr);
}
}
}
for fk_field in fk_id_fields {
fields.named.push(fk_field);
}
}
let config_attr: Attribute = if args.is_empty() {
syn::parse_quote! { #[model_config] }
} else {
syn::parse_quote! { #[model_config(#args)] }
};
let model_path = quote!(#reinhardt::macros::Model);
let required_traits = ["Debug", "Clone", "PartialEq"];
let mut additional_traits = Vec::new();
for &trait_name in &required_traits {
if !has_derive_trait(&input.attrs, trait_name) {
additional_traits.push(trait_name);
}
}
let existing_derive_idx = input.attrs.iter().position(|attr| {
attr.path().is_ident("derive") && matches!(&attr.meta, syn::Meta::List(_))
});
if let Some(idx) = existing_derive_idx {
if let syn::Meta::List(ref meta_list) = input.attrs[idx].meta {
let existing_tokens = &meta_list.tokens;
let new_derive_attr: Attribute = if additional_traits.is_empty() {
syn::parse_quote! { #[derive(#model_path, #existing_tokens)] }
} else {
let traits_str = additional_traits.join(", ");
let tokens: TokenStream = traits_str
.parse()
.expect("Failed to parse derive traits - this is a bug");
syn::parse_quote! { #[derive(#model_path, #tokens, #existing_tokens)] }
};
input.attrs[idx] = new_derive_attr;
}
} else {
let derive_attr: Attribute = if additional_traits.is_empty() {
syn::parse_quote! { #[derive(#model_path)] }
} else {
let traits_str = additional_traits.join(", ");
let tokens: TokenStream = traits_str
.parse()
.expect("Failed to parse derive traits - this is a bug");
syn::parse_quote! { #[derive(#model_path, #tokens)] }
};
input.attrs.insert(0, derive_attr);
}
let config_insert_pos = if let Some(idx) = existing_derive_idx {
idx + 1
} else {
1
};
input.attrs.insert(config_insert_pos, config_attr);
Ok(quote! { #input })
}