1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::parse::Parser;
4use syn::{
5 parse_macro_input, punctuated::Punctuated, token::Comma, Data, DeriveInput, Fields, Lit, Meta,
6};
7
8fn field_is_optional(field: &syn::Field) -> bool {
10 if let syn::Type::Path(type_path) = &field.ty {
11 type_path
12 .path
13 .segments
14 .first()
15 .map(|seg| seg.ident == "Option")
16 .unwrap_or(false)
17 } else {
18 false
19 }
20}
21
22fn get_crudcrate_bool(field: &syn::Field, key: &str) -> Option<bool> {
26 for attr in &field.attrs {
27 if attr.path().is_ident("crudcrate") {
28 if let Meta::List(meta_list) = &attr.meta {
29 let metas: Punctuated<Meta, Comma> = Punctuated::parse_terminated
30 .parse2(meta_list.tokens.clone())
31 .ok()?;
32 for meta in metas.iter() {
33 if let Meta::NameValue(nv) = meta {
34 if nv.path.is_ident(key) {
35 if let syn::Expr::Lit(expr_lit) = &nv.value {
36 if let Lit::Bool(b) = &expr_lit.lit {
37 return Some(b.value);
38 }
39 }
40 }
41 }
42 }
43 }
44 }
45 }
46 None
47}
48
49fn get_crudcrate_expr(field: &syn::Field, key: &str) -> Option<syn::Expr> {
52 for attr in &field.attrs {
53 if attr.path().is_ident("crudcrate") {
54 if let Meta::List(meta_list) = &attr.meta {
55 let metas: Punctuated<Meta, Comma> = Punctuated::parse_terminated
56 .parse2(meta_list.tokens.clone())
57 .ok()?;
58 for meta in metas.iter() {
59 if let Meta::NameValue(nv) = meta {
60 if nv.path.is_ident(key) {
61 return Some(nv.value.clone());
62 }
63 }
64 }
65 }
66 }
67 }
68 None
69}
70
71fn get_string_from_attr(attr: &syn::Attribute) -> Option<String> {
74 if let Meta::NameValue(nv) = &attr.meta {
75 if let syn::Expr::Lit(expr_lit) = &nv.value {
76 if let Lit::Str(s) = &expr_lit.lit {
77 return Some(s.value());
78 }
79 }
80 }
81 None
82}
83
84#[proc_macro_derive(ToCreateModel, attributes(crudcrate))]
100pub fn to_create_model(input: TokenStream) -> TokenStream {
101 let input = parse_macro_input!(input as DeriveInput);
102 let name = input.ident;
103 let create_name = format_ident!("{}Create", name);
104
105 let mut active_model_override = None;
107 for attr in &input.attrs {
108 if attr.path().is_ident("active_model") {
109 if let Some(s) = get_string_from_attr(attr) {
110 active_model_override =
111 Some(syn::parse_str::<syn::Type>(&s).expect("Invalid active_model type"));
112 }
113 }
114 }
115 let active_model_type = if let Some(ty) = active_model_override {
116 quote! { #ty }
117 } else {
118 let ident = format_ident!("{}ActiveModel", name);
119 quote! { #ident }
120 };
121
122 let fields = if let Data::Struct(data) = input.data {
124 if let Fields::Named(named) = data.fields {
125 named.named
126 } else {
127 panic!("ToCreateModel only supports structs with named fields");
128 }
129 } else {
130 panic!("ToCreateModel can only be derived for structs");
131 };
132
133 let create_struct_fields = fields
137 .iter()
138 .filter(|field| get_crudcrate_bool(field, "create_model").unwrap_or(true))
139 .map(|field| {
140 let ident = &field.ident;
141 let ty = &field.ty;
142 if get_crudcrate_expr(field, "on_create").is_some() {
143 quote! {
144 #[serde(default)]
145 pub #ident: Option<#ty>
146 }
147 } else {
148 quote! {
149 pub #ident: #ty
150 }
151 }
152 });
153
154 let mut conv_lines = Vec::new();
156 for field in fields.iter() {
157 let ident = field.ident.as_ref().unwrap();
158 let include = get_crudcrate_bool(field, "create_model").unwrap_or(true);
159 let is_optional = field_is_optional(field);
160 if include {
161 if let Some(expr) = get_crudcrate_expr(field, "on_create") {
162 conv_lines.push(quote! {
166 #ident: ActiveValue::Set(match create.#ident {
167 Some(val) => val,
168 None => (#expr).into(),
169 })
170 });
171 } else {
172 conv_lines.push(quote! {
174 #ident: ActiveValue::Set(create.#ident)
175 });
176 }
177 } else if let Some(expr) = get_crudcrate_expr(field, "on_create") {
178 if is_optional {
180 conv_lines.push(quote! {
181 #ident: ActiveValue::Set(Some((#expr).into()))
182 });
183 } else {
184 conv_lines.push(quote! {
185 #ident: ActiveValue::Set((#expr).into())
186 });
187 }
188 }
189 }
190
191 let expanded = quote! {
192 #[derive(Serialize, Deserialize, ToSchema, Copy, Clone)]
193 pub struct #create_name {
194 #(#create_struct_fields),*
195 }
196
197 impl From<#create_name> for #active_model_type {
198 fn from(create: #create_name) -> Self {
199 #active_model_type {
200 #(#conv_lines),*
201 }
202 }
203 }
204 };
205
206 TokenStream::from(expanded)
207}
208
209#[proc_macro_derive(ToUpdateModel, attributes(crudcrate, active_model))]
224pub fn to_update_model(input: TokenStream) -> TokenStream {
225 let input = parse_macro_input!(input as DeriveInput);
226 let name = input.ident;
227 let update_name = format_ident!("{}Update", name);
228
229 let mut active_model_override = None;
231 for attr in &input.attrs {
232 if attr.path().is_ident("active_model") {
233 if let Some(s) = get_string_from_attr(attr) {
234 active_model_override =
235 Some(syn::parse_str::<syn::Type>(&s).expect("Invalid active_model type"));
236 }
237 }
238 }
239 let active_model_type = if let Some(ty) = active_model_override {
240 quote! { #ty }
241 } else {
242 let ident = format_ident!("{}ActiveModel", name);
243 quote! { #ident }
244 };
245
246 let fields = if let Data::Struct(data) = input.data {
248 if let Fields::Named(named) = data.fields {
249 named.named
250 } else {
251 panic!("ToUpdateModel only supports structs with named fields");
252 }
253 } else {
254 panic!("ToUpdateModel can only be derived for structs");
255 };
256
257 let included_fields: Vec<_> = fields
259 .iter()
260 .filter(|field| get_crudcrate_bool(field, "update_model").unwrap_or(true))
261 .collect();
262
263 let update_struct_fields = included_fields.iter().map(|field| {
264 let ident = &field.ident;
265 let ty = &field.ty;
266 let (_is_option, inner_ty) = if let syn::Type::Path(type_path) = ty {
268 if let Some(seg) = type_path.path.segments.first() {
269 if seg.ident == "Option" {
270 if let syn::PathArguments::AngleBracketed(inner_args) = &seg.arguments {
271 if let Some(syn::GenericArgument::Type(inner_ty)) = inner_args.args.first()
272 {
273 (true, inner_ty.clone())
274 } else {
275 (false, ty.clone())
276 }
277 } else {
278 (false, ty.clone())
279 }
280 } else {
281 (false, ty.clone())
282 }
283 } else {
284 (false, ty.clone())
285 }
286 } else {
287 (false, ty.clone())
288 };
289 quote! {
290 #[serde(
291 default,
292 skip_serializing_if = "Option::is_none",
293 with = "::serde_with::rust::double_option"
294 )]
295 pub #ident: Option<Option<#inner_ty>>
296 }
297 });
298
299 let included_merge: Vec<_> = included_fields
301 .iter()
302 .map(|field| {
303 let ident = &field.ident;
304 let is_optional = field_is_optional(field);
305 if is_optional {
306 quote! {
308 model.#ident = match self.#ident {
309 Some(Some(value)) => ActiveValue::Set(Some(value)),
310 Some(_) => ActiveValue::NotSet,
311 _ => ActiveValue::NotSet,
312 };
313 }
314 } else {
315 quote! {
316 model.#ident = match self.#ident {
317 Some(Some(value)) => ActiveValue::Set(value),
318 Some(_) => ActiveValue::NotSet,
319 _ => ActiveValue::NotSet,
320 };
321 }
322 }
323 })
324 .collect();
325
326 let excluded_merge: Vec<_> = fields
329 .iter()
330 .filter_map(|field| {
331 if get_crudcrate_bool(field, "update_model") == Some(false) {
332 if let Some(expr) = get_crudcrate_expr(field, "on_update") {
333 let ident = &field.ident;
334 if field_is_optional(field) {
335 Some(quote! {
336 model.#ident = ActiveValue::Set(Some((#expr).into()));
337 })
338 } else {
339 Some(quote! {
340 model.#ident = ActiveValue::Set((#expr).into());
341 })
342 }
343 } else {
344 None
345 }
346 } else {
347 None
348 }
349 })
350 .collect();
351
352 let expanded = quote! {
353 #[derive(Serialize, Deserialize, ToSchema, Copy, Clone)]
354 pub struct #update_name {
355 #(#update_struct_fields),*
356 }
357
358 impl #update_name {
359 pub fn merge_into_activemodel(self, mut model: #active_model_type) -> #active_model_type {
360 #(#included_merge)*
361 #(#excluded_merge)*
362 model
363 }
364 }
365 };
366
367 TokenStream::from(expanded)
368}