use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::parse::Parser as _;
use syn::{DeriveInput, Field, LitStr};
fn parse_attr_args(attr: TokenStream) -> syn::Result<Option<String>> {
if attr.is_empty() {
return Ok(None);
}
let mut table = None;
syn::meta::parser(|meta| {
if meta.path.is_ident("table") {
let value: LitStr = meta.value()?.parse()?;
table = Some(value.value());
Ok(())
} else {
Err(meta.error("unsupported model attribute"))
}
})
.parse2(attr)?;
Ok(table)
}
fn has_attr(field: &Field, name: &str) -> bool {
field.attrs.iter().any(|a| a.path().is_ident(name))
}
fn validate_attrs(field: &Field) -> Vec<&syn::Attribute> {
field
.attrs
.iter()
.filter(|a| a.path().is_ident("validate"))
.collect()
}
fn user_attrs(field: &Field) -> Vec<&syn::Attribute> {
field
.attrs
.iter()
.filter(|a| {
!a.path().is_ident("id")
&& !a.path().is_ident("indexed")
&& !a.path().is_ident("validate")
&& !a.path().is_ident("default")
})
.collect()
}
fn excluded_from_new(field: &Field) -> bool {
has_attr(field, "id") || has_attr(field, "default")
}
fn pascal_case(s: &str) -> String {
s.split('_')
.map(|word| {
let mut chars = word.chars();
chars.next().map_or_else(String::new, |c| {
c.to_uppercase().to_string() + &chars.collect::<String>()
})
})
.collect()
}
fn is_option_type(ty: &syn::Type) -> bool {
if let syn::Type::Path(tp) = ty {
tp.path
.segments
.last()
.is_some_and(|seg| seg.ident == "Option")
} else {
false
}
}
#[allow(clippy::too_many_lines)]
pub fn model_macro(attr: TokenStream, item: TokenStream) -> TokenStream {
let input: DeriveInput = match syn::parse2(item) {
Ok(input) => input,
Err(err) => return err.to_compile_error(),
};
let syn::Data::Struct(syn::DataStruct {
fields: syn::Fields::Named(ref fields),
..
}) = input.data
else {
return syn::Error::new_spanned(
&input.ident,
"#[model] can only be applied to structs with named fields",
)
.to_compile_error();
};
let table_name = match parse_attr_args(attr) {
Ok(explicit) => explicit.unwrap_or_else(|| infer_table_name(&input.ident)),
Err(err) => return err.to_compile_error(),
};
let table_ident = syn::Ident::new(&table_name, input.ident.span());
let name = &input.ident;
let vis = &input.vis;
let outer_attrs = &input.attrs;
let new_name = format_ident!("New{name}");
let update_name = format_ident!("Update{name}");
let all_fields: Vec<&Field> = fields.named.iter().collect();
let id_fields: Vec<&&Field> = all_fields.iter().filter(|f| has_attr(f, "id")).collect();
let id_field_names: Vec<&syn::Ident> = if id_fields.is_empty() {
all_fields
.iter()
.filter(|f| {
if let syn::Type::Path(tp) = &f.ty {
tp.path.is_ident("i32") || tp.path.is_ident("i64")
} else {
false
}
})
.take(1)
.filter_map(|f| f.ident.as_ref())
.collect()
} else {
id_fields.iter().filter_map(|f| f.ident.as_ref()).collect()
};
let fields_for_new: Vec<&&Field> = all_fields
.iter()
.filter(|f| {
!excluded_from_new(f)
&& f.ident
.as_ref()
.is_some_and(|id| !id_field_names.contains(&id))
})
.collect();
let has_validation = all_fields.iter().any(|f| !validate_attrs(f).is_empty());
let query_fields: Vec<TokenStream> = all_fields
.iter()
.map(|f| {
let ident = &f.ident;
let ty = &f.ty;
let attrs = user_attrs(f);
quote! { #(#attrs)* pub #ident: #ty }
})
.collect();
let new_fields: Vec<TokenStream> = fields_for_new
.iter()
.map(|f| {
let ident = &f.ident;
let ty = &f.ty;
let val_attrs = validate_attrs(f);
quote! { #(#val_attrs)* pub #ident: #ty }
})
.collect();
let update_fields: Vec<TokenStream> = fields_for_new
.iter()
.map(|f| {
let ident = &f.ident;
let ty = &f.ty;
quote! {
#[serde(default)]
pub #ident: ::autumn_web::hooks::Patch<#ty>
}
})
.collect();
let field_enum_name = format_ident!("{name}Field");
let field_enum_variants: Vec<TokenStream> = fields_for_new
.iter()
.map(|f| {
let ident = f.ident.as_ref().unwrap();
let variant = format_ident!("{}", pascal_case(&ident.to_string()));
quote! { #variant }
})
.collect();
let validate_derive = if has_validation {
quote! { #[derive(::autumn_web::reexports::validator::Validate)] }
} else {
quote! {}
};
let merge_arms: Vec<TokenStream> = fields_for_new
.iter()
.map(|f| {
let ident = f.ident.as_ref().unwrap();
let is_option = is_option_type(&f.ty);
if is_option {
quote! {
match &patch.#ident {
::autumn_web::hooks::Patch::Set(v) => after.#ident = v.clone(),
::autumn_web::hooks::Patch::Clear => after.#ident = None,
::autumn_web::hooks::Patch::Unchanged => {}
}
}
} else {
quote! {
match &patch.#ident {
::autumn_web::hooks::Patch::Set(v) => after.#ident = v.clone(),
::autumn_web::hooks::Patch::Clear => {
return Err(::autumn_web::AutumnError::bad_request_msg(
format!("Cannot clear non-nullable field `{}`", stringify!(#ident))
));
}
::autumn_web::hooks::Patch::Unchanged => {}
}
}
}
})
.collect();
let draft_accessor_sigs: Vec<TokenStream> = fields_for_new
.iter()
.map(|f| {
let ident = f.ident.as_ref().unwrap();
let ty = &f.ty;
quote! {
fn #ident(&mut self) -> ::autumn_web::hooks::DraftField<'_, #ty>;
}
})
.collect();
let draft_accessors: Vec<TokenStream> = fields_for_new
.iter()
.map(|f| {
let ident = f.ident.as_ref().unwrap();
let ty = &f.ty;
quote! {
fn #ident(&mut self) -> ::autumn_web::hooks::DraftField<'_, #ty> {
::autumn_web::hooks::DraftField::new(&self.before.#ident, &mut self.after.#ident)
}
}
})
.collect();
let draft_ext_name = format_ident!("{name}DraftExt");
let changeset_name = format_ident!("__{}Changeset", name);
let changeset_fields: Vec<TokenStream> = fields_for_new
.iter()
.map(|f| {
let ident = &f.ident;
let ty = &f.ty;
quote! { pub #ident: Option<#ty> }
})
.collect();
let changeset_conversions: Vec<TokenStream> = fields_for_new
.iter()
.map(|f| {
let ident = f.ident.as_ref().unwrap();
let is_option = is_option_type(&f.ty);
if is_option {
quote! {
#ident: match &self.#ident {
::autumn_web::hooks::Patch::Set(v) => Some(v.clone()),
::autumn_web::hooks::Patch::Clear => Some(None),
::autumn_web::hooks::Patch::Unchanged => None,
}
}
} else {
quote! {
#ident: match &self.#ident {
::autumn_web::hooks::Patch::Set(v) => Some(v.clone()),
::autumn_web::hooks::Patch::Clear => {
panic!("Cannot clear non-nullable field `{}`", stringify!(#ident));
}
::autumn_web::hooks::Patch::Unchanged => None,
}
}
}
})
.collect();
quote! {
#[derive(Debug, Clone, ::diesel::Queryable, ::diesel::Selectable, ::diesel::AsChangeset)]
#[derive(::serde::Serialize, ::serde::Deserialize)]
#[diesel(table_name = #table_ident)]
#(#outer_attrs)*
#vis struct #name {
#(#query_fields,)*
}
#[derive(Debug, Clone, ::diesel::Insertable)]
#[derive(::serde::Serialize, ::serde::Deserialize)]
#validate_derive
#[diesel(table_name = #table_ident)]
#vis struct #new_name {
#(#new_fields,)*
}
#[derive(Debug, Clone, Default)]
#[derive(::serde::Serialize, ::serde::Deserialize)]
#vis struct #update_name {
#(#update_fields,)*
}
#[doc(hidden)]
#[derive(Debug, Clone, ::diesel::AsChangeset)]
#[diesel(table_name = #table_ident)]
pub struct #changeset_name {
#(#changeset_fields,)*
}
impl #update_name {
#[doc(hidden)]
#[must_use]
pub fn __to_changeset(&self) -> #changeset_name {
#changeset_name {
#(#changeset_conversions,)*
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#vis enum #field_enum_name {
#(#field_enum_variants,)*
}
#vis trait #draft_ext_name {
fn from_patch(current: &#name, patch: &#update_name) -> ::autumn_web::AutumnResult<::autumn_web::hooks::UpdateDraft<#name>>;
#(#draft_accessor_sigs)*
}
impl #draft_ext_name for ::autumn_web::hooks::UpdateDraft<#name> {
fn from_patch(current: &#name, patch: &#update_name) -> ::autumn_web::AutumnResult<Self> {
let mut after = current.clone();
#(#merge_arms)*
Ok(Self::new_with_changes(current.clone(), after))
}
#(#draft_accessors)*
}
}
}
pub fn infer_table_name(ident: &syn::Ident) -> String {
let name = ident.to_string();
let snake = pascal_to_snake(&name);
format!("{snake}s")
}
pub fn pascal_to_snake(s: &str) -> String {
let mut result = String::new();
for (i, ch) in s.chars().enumerate() {
if ch.is_uppercase() && i > 0 {
result.push('_');
}
result.push(ch.to_ascii_lowercase());
}
result
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn pascal_to_snake_simple() {
assert_eq!(pascal_to_snake("User"), "user");
}
#[test]
fn pascal_to_snake_multi_word() {
assert_eq!(pascal_to_snake("BlogPost"), "blog_post");
}
#[test]
fn pascal_to_snake_three_words() {
assert_eq!(
pascal_to_snake("UserProfileSettings"),
"user_profile_settings"
);
}
#[test]
fn pascal_case_simple() {
assert_eq!(pascal_case("title"), "Title");
}
#[test]
fn pascal_case_multi_word() {
assert_eq!(pascal_case("approved_at"), "ApprovedAt");
}
#[test]
fn pascal_case_single_char() {
assert_eq!(pascal_case("x"), "X");
}
#[test]
fn infer_table_name_simple() {
let ident = syn::Ident::new("User", proc_macro2::Span::call_site());
assert_eq!(infer_table_name(&ident), "users");
}
#[test]
fn infer_table_name_multi_word() {
let ident = syn::Ident::new("BlogPost", proc_macro2::Span::call_site());
assert_eq!(infer_table_name(&ident), "blog_posts");
}
}