crudcrate_derive/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::parse::Parser;
4use syn::{
5    parse_macro_input, punctuated::Punctuated, token::Comma, Data, DeriveInput, Fields, Lit, Meta,
6};
7
8/// Returns true if the field’s type is of the form Option<…>
9fn field_is_optional(field: &syn::Field) -> bool {
10    if let syn::Type::Path(type_path) = &field.ty {
11        type_path
12            .path
13            .segments
14            .first()
15            .map(|seg| seg.ident == "Option")
16            .unwrap_or(false)
17    } else {
18        false
19    }
20}
21
22/// Given a field and a key (e.g. "create_model" or "update_model"),
23/// look for a `#[crudcrate(...)]` attribute on the field and return the boolean value
24/// associated with that key, if present.
25fn get_crudcrate_bool(field: &syn::Field, key: &str) -> Option<bool> {
26    for attr in &field.attrs {
27        if attr.path().is_ident("crudcrate") {
28            if let Meta::List(meta_list) = &attr.meta {
29                let metas: Punctuated<Meta, Comma> = Punctuated::parse_terminated
30                    .parse2(meta_list.tokens.clone())
31                    .ok()?;
32                for meta in metas.iter() {
33                    if let Meta::NameValue(nv) = meta {
34                        if nv.path.is_ident(key) {
35                            if let syn::Expr::Lit(expr_lit) = &nv.value {
36                                if let Lit::Bool(b) = &expr_lit.lit {
37                                    return Some(b.value);
38                                }
39                            }
40                        }
41                    }
42                }
43            }
44        }
45    }
46    None
47}
48
49/// Given a field and a key (e.g. "on_create" or "on_update"), returns the expression
50/// provided in the `#[crudcrate(...)]` attribute for that key.
51fn get_crudcrate_expr(field: &syn::Field, key: &str) -> Option<syn::Expr> {
52    for attr in &field.attrs {
53        if attr.path().is_ident("crudcrate") {
54            if let Meta::List(meta_list) = &attr.meta {
55                let metas: Punctuated<Meta, Comma> = Punctuated::parse_terminated
56                    .parse2(meta_list.tokens.clone())
57                    .ok()?;
58                for meta in metas.iter() {
59                    if let Meta::NameValue(nv) = meta {
60                        if nv.path.is_ident(key) {
61                            return Some(nv.value.clone());
62                        }
63                    }
64                }
65            }
66        }
67    }
68    None
69}
70
71/// Extracts a string literal from a struct-level attribute of the form:
72///   #[active_model = "some::path"]
73fn get_string_from_attr(attr: &syn::Attribute) -> Option<String> {
74    if let Meta::NameValue(nv) = &attr.meta {
75        if let syn::Expr::Lit(expr_lit) = &nv.value {
76            if let Lit::Str(s) = &expr_lit.lit {
77                return Some(s.value());
78            }
79        }
80    }
81    None
82}
83
84/// ===================
85/// ToCreateModel Macro
86/// ===================
87/// This macro:
88/// 1. Generates a struct named `<OriginalName>Create` that includes only the fields
89///    where `#[crudcrate(create_model = false)]` is NOT specified (default is true).
90///    If a field has an `on_create` expression, its type is made optional (with `#[serde(default)]`)
91///    so that the user may override the default.
92/// 2. Generates an impl of `From<<OriginalName>Create> for <ActiveModelType>` where:
93///    - For each field that is exposed (create_model = true) with an on_create expression,
94///      the value is taken from the create struct if provided; otherwise, the on_create
95///      expression is used (with `.into()` called if necessary).
96///    - For fields that are exposed without an on_create expression, the value is taken directly.
97///    - For fields that are not exposed but have an `on_create` expression, that expression
98///      is always used.
99#[proc_macro_derive(ToCreateModel, attributes(crudcrate))]
100pub fn to_create_model(input: TokenStream) -> TokenStream {
101    let input = parse_macro_input!(input as DeriveInput);
102    let name = input.ident;
103    let create_name = format_ident!("{}Create", name);
104
105    // Look for a struct-level active_model override.
106    let mut active_model_override = None;
107    for attr in &input.attrs {
108        if attr.path().is_ident("active_model") {
109            if let Some(s) = get_string_from_attr(attr) {
110                active_model_override =
111                    Some(syn::parse_str::<syn::Type>(&s).expect("Invalid active_model type"));
112            }
113        }
114    }
115    let active_model_type = if let Some(ty) = active_model_override {
116        quote! { #ty }
117    } else {
118        let ident = format_ident!("{}ActiveModel", name);
119        quote! { #ident }
120    };
121
122    // Support only structs with named fields.
123    let fields = if let Data::Struct(data) = input.data {
124        if let Fields::Named(named) = data.fields {
125            named.named
126        } else {
127            panic!("ToCreateModel only supports structs with named fields");
128        }
129    } else {
130        panic!("ToCreateModel can only be derived for structs");
131    };
132
133    // Build the Create struct.
134    // For each field where create_model is true, if an on_create expression is provided,
135    // make its type optional and add a #[serde(default)] so missing JSON will default to None.
136    let create_struct_fields = fields
137        .iter()
138        .filter(|field| get_crudcrate_bool(field, "create_model").unwrap_or(true))
139        .map(|field| {
140            let ident = &field.ident;
141            let ty = &field.ty;
142            if get_crudcrate_expr(field, "on_create").is_some() {
143                quote! {
144                    #[serde(default)]
145                    pub #ident: Option<#ty>
146                }
147            } else {
148                quote! {
149                    pub #ident: #ty
150                }
151            }
152        });
153
154    // Generate conversion lines.
155    let mut conv_lines = Vec::new();
156    for field in fields.iter() {
157        let ident = field.ident.as_ref().unwrap();
158        let include = get_crudcrate_bool(field, "create_model").unwrap_or(true);
159        let is_optional = field_is_optional(field);
160        if include {
161            if let Some(expr) = get_crudcrate_expr(field, "on_create") {
162                // Field is included and has a default on_create expression.
163                // The create struct field is an Option, so if the user provides Some(val),
164                // use that; otherwise, use the on_create expression.
165                conv_lines.push(quote! {
166                    #ident: ActiveValue::Set(match create.#ident {
167                        Some(val) => val,
168                        None => (#expr).into(),
169                    })
170                });
171            } else {
172                // Field is included and has no on_create.
173                conv_lines.push(quote! {
174                    #ident: ActiveValue::Set(create.#ident)
175                });
176            }
177        } else if let Some(expr) = get_crudcrate_expr(field, "on_create") {
178            // Field is not exposed in the create struct but has an on_create default.
179            if is_optional {
180                conv_lines.push(quote! {
181                    #ident: ActiveValue::Set(Some((#expr).into()))
182                });
183            } else {
184                conv_lines.push(quote! {
185                    #ident: ActiveValue::Set((#expr).into())
186                });
187            }
188        }
189    }
190
191    let expanded = quote! {
192        #[derive(Serialize, Deserialize, ToSchema, Copy, Clone)]
193        pub struct #create_name {
194            #(#create_struct_fields),*
195        }
196
197        impl From<#create_name> for #active_model_type {
198            fn from(create: #create_name) -> Self {
199                #active_model_type {
200                    #(#conv_lines),*
201                }
202            }
203        }
204    };
205
206    TokenStream::from(expanded)
207}
208
209/// ===================
210/// ToUpdateModel Macro
211/// ===================
212/// This macro:
213/// 1. Generates a struct named `<OriginalName>Update` that includes only the fields
214///    where `#[crudcrate(update_model = false)]` is NOT specified (default is true).
215/// 2. Generates an impl for a method
216///    `merge_into_activemodel(self, mut model: ActiveModelType) -> ActiveModelType`
217///    that, for each field:
218///    - For fields included in the update struct, if a value is provided, it is merged into the model.
219///      If the field is optional, the value (of type T) must be wrapped in Some to match the ActiveModel’s field of type Option<T>.
220///    - For fields excluded (update_model = false) but with an `on_update` expression, that expression is used
221///      (wrapped with Some(...) if the field is optional).
222///    - Other fields are left unchanged.
223#[proc_macro_derive(ToUpdateModel, attributes(crudcrate, active_model))]
224pub fn to_update_model(input: TokenStream) -> TokenStream {
225    let input = parse_macro_input!(input as DeriveInput);
226    let name = input.ident;
227    let update_name = format_ident!("{}Update", name);
228
229    // Look for a struct-level active_model override.
230    let mut active_model_override = None;
231    for attr in &input.attrs {
232        if attr.path().is_ident("active_model") {
233            if let Some(s) = get_string_from_attr(attr) {
234                active_model_override =
235                    Some(syn::parse_str::<syn::Type>(&s).expect("Invalid active_model type"));
236            }
237        }
238    }
239    let active_model_type = if let Some(ty) = active_model_override {
240        quote! { #ty }
241    } else {
242        let ident = format_ident!("{}ActiveModel", name);
243        quote! { #ident }
244    };
245
246    // Support only structs with named fields.
247    let fields = if let Data::Struct(data) = input.data {
248        if let Fields::Named(named) = data.fields {
249            named.named
250        } else {
251            panic!("ToUpdateModel only supports structs with named fields");
252        }
253    } else {
254        panic!("ToUpdateModel can only be derived for structs");
255    };
256
257    // Build the Update struct with only fields where update_model is true.
258    let included_fields: Vec<_> = fields
259        .iter()
260        .filter(|field| get_crudcrate_bool(field, "update_model").unwrap_or(true))
261        .collect();
262
263    let update_struct_fields = included_fields.iter().map(|field| {
264        let ident = &field.ident;
265        let ty = &field.ty;
266        // For update, if the field is Option<T>, we want the update struct field to be Option<Option<T>>.
267        let (_is_option, inner_ty) = if let syn::Type::Path(type_path) = ty {
268            if let Some(seg) = type_path.path.segments.first() {
269                if seg.ident == "Option" {
270                    if let syn::PathArguments::AngleBracketed(inner_args) = &seg.arguments {
271                        if let Some(syn::GenericArgument::Type(inner_ty)) = inner_args.args.first()
272                        {
273                            (true, inner_ty.clone())
274                        } else {
275                            (false, ty.clone())
276                        }
277                    } else {
278                        (false, ty.clone())
279                    }
280                } else {
281                    (false, ty.clone())
282                }
283            } else {
284                (false, ty.clone())
285            }
286        } else {
287            (false, ty.clone())
288        };
289        quote! {
290            #[serde(
291                default,
292                skip_serializing_if = "Option::is_none",
293                with = "::serde_with::rust::double_option"
294            )]
295            pub #ident: Option<Option<#inner_ty>>
296        }
297    });
298
299    // Generate merge code for fields included in the update struct.
300    let included_merge: Vec<_> = included_fields
301        .iter()
302        .map(|field| {
303            let ident = &field.ident;
304            let is_optional = field_is_optional(field);
305            if is_optional {
306                // For optional fields, wrap the inner value in Some.
307                quote! {
308                    model.#ident = match self.#ident {
309                        Some(Some(value)) => ActiveValue::Set(Some(value)),
310                        Some(_) => ActiveValue::NotSet,
311                        _ => ActiveValue::NotSet,
312                    };
313                }
314            } else {
315                quote! {
316                    model.#ident = match self.#ident {
317                        Some(Some(value)) => ActiveValue::Set(value),
318                        Some(_) => ActiveValue::NotSet,
319                        _ => ActiveValue::NotSet,
320                    };
321                }
322            }
323        })
324        .collect();
325
326    // For fields excluded (update_model = false) that have an on_update expression,
327    // generate merge code. Wrap the expression using `.into()` if needed.
328    let excluded_merge: Vec<_> = fields
329        .iter()
330        .filter_map(|field| {
331            if get_crudcrate_bool(field, "update_model") == Some(false) {
332                if let Some(expr) = get_crudcrate_expr(field, "on_update") {
333                    let ident = &field.ident;
334                    if field_is_optional(field) {
335                        Some(quote! {
336                            model.#ident = ActiveValue::Set(Some((#expr).into()));
337                        })
338                    } else {
339                        Some(quote! {
340                            model.#ident = ActiveValue::Set((#expr).into());
341                        })
342                    }
343                } else {
344                    None
345                }
346            } else {
347                None
348            }
349        })
350        .collect();
351
352    let expanded = quote! {
353        #[derive(Serialize, Deserialize, ToSchema, Copy, Clone)]
354        pub struct #update_name {
355            #(#update_struct_fields),*
356        }
357
358        impl #update_name {
359            pub fn merge_into_activemodel(self, mut model: #active_model_type) -> #active_model_type {
360                #(#included_merge)*
361                #(#excluded_merge)*
362                model
363            }
364        }
365    };
366
367    TokenStream::from(expanded)
368}