Skip to main content

dbkit_derive/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::parse::Parser;
4use syn::{parse_macro_input, Attribute, Field, Fields, Ident, ItemStruct, Meta, Type};
5
6#[proc_macro_derive(Model, attributes(model, key, autoincrement, unique, index, has_many, belongs_to, many_to_many))]
7pub fn derive_model(_input: TokenStream) -> TokenStream {
8    TokenStream::from(quote! {
9        compile_error!("dbkit: use #[model] instead of #[derive(Model)]");
10    })
11}
12
13#[proc_macro_derive(DbEnum, attributes(dbkit))]
14pub fn derive_db_enum(input: TokenStream) -> TokenStream {
15    let input = parse_macro_input!(input as syn::ItemEnum);
16    match expand_db_enum(input) {
17        Ok(tokens) => tokens,
18        Err(err) => err.to_compile_error().into(),
19    }
20}
21
22#[proc_macro_attribute]
23pub fn model(attr: TokenStream, item: TokenStream) -> TokenStream {
24    let input = parse_macro_input!(item as ItemStruct);
25    let args = parse_macro_input!(attr with syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated);
26    let args = parse_model_args(args);
27    match expand_model(args, input) {
28        Ok(tokens) => tokens,
29        Err(err) => err.to_compile_error().into(),
30    }
31}
32
/// Flavour of relation declared on a model field.
///
/// A field is treated as a relation when it carries `#[has_many]`, `#[belongs_to]`
/// or `#[many_to_many]`. The concrete kind is reported by `relation_type`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum RelationKind {
    /// One-to-many relation. When loaded, the field holds a `Vec` of related items.
    HasMany,
    /// Relation parsed from `#[belongs_to(...)]` key/references arguments.
    /// When loaded, the field holds an `Option` of the related item.
    BelongsTo,
    /// Relation routed through the `through` model named in `#[many_to_many(...)]`.
    /// When loaded, the field holds a `Vec` of related items.
    ManyToMany,
}
39
40struct RelationInfo {
41    field: Field,
42    param_ident: Ident,
43    state_mod_ident: Ident,
44    child_type: Type,
45    kind: RelationKind,
46    belongs_to_key: Option<Ident>,
47    belongs_to_ref: Option<Ident>,
48    many_to_many_through: Option<Ident>,
49    many_to_many_left_key: Option<Ident>,
50    many_to_many_right_key: Option<Ident>,
51}
52
53struct ScalarFieldInfo {
54    field: Field,
55    ident: Ident,
56    ty: Type,
57    is_key: bool,
58    is_autoincrement: bool,
59}
60
/// Options accepted inside `#[model(...)]`.
///
/// Both are optional. Without `table`, `expand_model` uses the snake_case form of
/// the struct name. Without `schema`, the generated `Table` carries no schema.
/// The fields are listed in the order they appear in a qualified `schema.table` name.
#[derive(Default)]
struct ModelArgs {
    /// Schema the table belongs to, passed to `Table::with_schema` when present.
    schema: Option<String>,
    /// Explicit table name overriding the one derived from the struct name.
    table: Option<String>,
}
66
67fn expand_model(args: ModelArgs, input: ItemStruct) -> syn::Result<TokenStream> {
68    if !input.generics.params.is_empty() {
69        return Err(syn::Error::new_spanned(
70            input.generics,
71            "dbkit: #[model] does not support generics yet",
72        ));
73    }
74
75    let struct_ident = input.ident;
76    let model_ident = format_ident!("{}Model", struct_ident);
77    let insert_ident = format_ident!("{}Insert", struct_ident);
78    let vis = input.vis;
79
80    let table_name = args.table.unwrap_or_else(|| to_snake_case(&struct_ident.to_string()));
81    let schema_name = args.schema;
82
83    let mut primary_keys: Vec<(Ident, Type)> = Vec::new();
84    let mut relation_fields = Vec::new();
85    let mut output_fields = Vec::new();
86    let mut insert_fields = Vec::new();
87    let mut scalar_fields = Vec::new();
88
89    let struct_attrs = filter_struct_attrs(&input.attrs);
90
91    let fields = match input.fields {
92        Fields::Named(named) => named.named,
93        _ => {
94            return Err(syn::Error::new_spanned(
95                struct_ident,
96                "dbkit: #[model] requires a struct with named fields",
97            ))
98        }
99    };
100
101    for field in fields {
102        let field_ident = field
103            .ident
104            .clone()
105            .ok_or_else(|| syn::Error::new_spanned(&field, "dbkit: unnamed field"))?;
106
107        let is_relation =
108            has_attr(&field.attrs, "has_many") || has_attr(&field.attrs, "belongs_to") || has_attr(&field.attrs, "many_to_many");
109
110        let is_key = has_attr(&field.attrs, "key");
111        let is_autoincrement = has_attr(&field.attrs, "autoincrement");
112
113        if is_key {
114            primary_keys.push((field_ident.clone(), field.ty.clone()));
115        }
116
117        if is_relation {
118            let (kind, child_type) = relation_type(&field)?;
119            let state_mod_ident = format_ident!("{}_{}_state", to_snake_case(&struct_ident.to_string()), field_ident);
120            let param_ident = format_ident!("{}Rel", to_camel_case(&field_ident.to_string()));
121            let (belongs_to_key, belongs_to_ref) = if kind == RelationKind::BelongsTo {
122                let (key, references) = parse_belongs_to_args(&field.attrs)?;
123                (Some(key), Some(references))
124            } else {
125                (None, None)
126            };
127            let (many_to_many_through, many_to_many_left_key, many_to_many_right_key) = if kind == RelationKind::ManyToMany {
128                let (through, left_key, right_key) = parse_many_to_many_args(&field.attrs)?;
129                (Some(through), Some(left_key), Some(right_key))
130            } else {
131                (None, None, None)
132            };
133
134            relation_fields.push(RelationInfo {
135                field: field.clone(),
136                param_ident: param_ident.clone(),
137                state_mod_ident,
138                child_type,
139                kind,
140                belongs_to_key,
141                belongs_to_ref,
142                many_to_many_through,
143                many_to_many_left_key,
144                many_to_many_right_key,
145            });
146
147            let cleaned_field = Field {
148                attrs: filter_field_attrs(&field.attrs),
149                ty: syn::parse_quote!(#param_ident),
150                ..field
151            };
152            output_fields.push(cleaned_field);
153            continue;
154        }
155
156        let cleaned_field = Field {
157            attrs: filter_field_attrs(&field.attrs),
158            ..field.clone()
159        };
160        output_fields.push(cleaned_field.clone());
161
162        if !(is_key && is_autoincrement) {
163            insert_fields.push(cleaned_field.clone());
164        }
165
166        scalar_fields.push(ScalarFieldInfo {
167            field: cleaned_field,
168            ident: field_ident,
169            ty: field.ty.clone(),
170            is_key,
171            is_autoincrement,
172        });
173    }
174
175    let table_expr = if let Some(schema) = schema_name {
176        quote!(::dbkit::Table::new(#table_name).with_schema(#schema))
177    } else {
178        quote!(::dbkit::Table::new(#table_name))
179    };
180
181    if relation_fields.iter().any(|rel| rel.kind == RelationKind::ManyToMany) && primary_keys.len() != 1 {
182        return Err(syn::Error::new_spanned(
183            struct_ident,
184            "dbkit: many-to-many requires exactly one #[key] on the parent model",
185        ));
186    }
187
188    let generics_with_defaults = relation_fields
189        .iter()
190        .map(|rel| {
191            let ident = &rel.param_ident;
192            let state_mod = &rel.state_mod_ident;
193            quote!(#ident: #state_mod::State = ::dbkit::NotLoaded)
194        })
195        .collect::<Vec<_>>();
196
197    let impl_generics_params = relation_fields
198        .iter()
199        .map(|rel| {
200            let ident = &rel.param_ident;
201            let state_mod = &rel.state_mod_ident;
202            quote!(#ident: #state_mod::State)
203        })
204        .collect::<Vec<_>>();
205
206    let generic_idents = relation_fields.iter().map(|rel| &rel.param_ident).collect::<Vec<_>>();
207
208    let struct_generics = if generics_with_defaults.is_empty() {
209        quote!()
210    } else {
211        quote!(<#(#generics_with_defaults),*>)
212    };
213
214    let impl_generics = if impl_generics_params.is_empty() {
215        quote!()
216    } else {
217        quote!(<#(#impl_generics_params),*>)
218    };
219
220    let struct_type_args = if generic_idents.is_empty() {
221        quote!()
222    } else {
223        quote!(<#(#generic_idents),*>)
224    };
225
226    let columns = output_fields
227        .iter()
228        .filter(|field| !is_relation_field(field, &relation_fields))
229        .map(|field| {
230            let ident = field.ident.as_ref().expect("field ident");
231            let name = ident.to_string();
232            let ty = option_inner_type(&field.ty).unwrap_or_else(|| field.ty.clone());
233            quote!(pub const #ident: ::dbkit::Column<#struct_ident, #ty> = ::dbkit::Column::new(Self::TABLE, #name);)
234        })
235        .collect::<Vec<_>>();
236
237    let column_refs = output_fields
238        .iter()
239        .filter(|field| !is_relation_field(field, &relation_fields))
240        .map(|field| {
241            let ident = field.ident.as_ref().expect("field ident");
242            quote!(Self::#ident.as_ref())
243        })
244        .collect::<Vec<_>>();
245
246    let columns_const = quote!(
247        pub const COLUMNS: &'static [::dbkit::ColumnRef] = &[#(#column_refs),*];
248    );
249
250    let primary_key_refs = primary_keys
251        .iter()
252        .map(|(ident, _)| quote!(Self::#ident.as_ref()))
253        .collect::<Vec<_>>();
254
255    let primary_keys_const = if primary_keys.is_empty() {
256        quote!(
257            pub const PRIMARY_KEYS: &'static [::dbkit::ColumnRef] = &[];
258        )
259    } else {
260        quote!(pub const PRIMARY_KEYS: &'static [::dbkit::ColumnRef] = &[#(#primary_key_refs),*];)
261    };
262
263    let insert_values = insert_fields.iter().map(|field| {
264        let ident = field.ident.as_ref().expect("field ident");
265        quote!(insert = insert.value(Self::#ident, values.#ident);)
266    });
267    let insert_field_idents = insert_fields
268        .iter()
269        .map(|field| field.ident.as_ref().expect("field ident"))
270        .collect::<Vec<_>>();
271
272    let active_ident = format_ident!("{}Active", struct_ident);
273
274    let active_fields = scalar_fields.iter().map(|field| {
275        let ident = &field.ident;
276        let vis = &field.field.vis;
277        let ty = option_inner_type(&field.ty).unwrap_or_else(|| field.ty.clone());
278        quote!(#vis #ident: ::dbkit::ActiveValue<#ty>)
279    });
280
281    let active_from_model = scalar_fields.iter().map(|field| {
282        let ident = &field.ident;
283        if option_inner_type(&field.ty).is_some() {
284            quote!(#ident: ::dbkit::ActiveValue::unchanged_option(#ident))
285        } else {
286            quote!(#ident: ::dbkit::ActiveValue::unchanged(#ident))
287        }
288    });
289
290    let active_destructure = scalar_fields.iter().map(|field| field.ident.clone()).collect::<Vec<_>>();
291
292    let active_insert_steps = scalar_fields.iter().map(|field| {
293        let ident = &field.ident;
294        let name = ident.to_string();
295        let ty = option_inner_type(&field.ty).unwrap_or_else(|| field.ty.clone());
296        let is_option = option_inner_type(&field.ty).is_some();
297        let required = !field.is_autoincrement && !is_option;
298        let required_check = if required {
299            quote!(return Err(::dbkit::Error::Decode(format!("missing required field: {}", #name)));)
300        } else {
301            quote!()
302        };
303        quote!(
304            match #ident {
305                ::dbkit::ActiveValue::Unset => {
306                    #required_check
307                }
308                ::dbkit::ActiveValue::Set(value) => {
309                    insert = insert.value(#struct_ident::#ident, value);
310                }
311                ::dbkit::ActiveValue::Unchanged(value) => {
312                    insert = insert.value(#struct_ident::#ident, value);
313                }
314                ::dbkit::ActiveValue::UnchangedNull => {
315                    insert = insert.value(#struct_ident::#ident, None::<#ty>);
316                }
317                ::dbkit::ActiveValue::Null => {
318                    insert = insert.value(#struct_ident::#ident, None::<#ty>);
319                }
320            }
321        )
322    });
323
324    let active_insert_fn = quote!(
325        pub async fn insert(
326            self,
327            ex: &(impl ::dbkit::Executor + Send + Sync),
328        ) -> Result<#struct_ident, ::dbkit::Error> {
329            let Self { #(#active_destructure,)* } = self;
330            let mut insert = ::dbkit::Insert::new(#struct_ident::TABLE);
331            #(#active_insert_steps)*
332            let insert = insert.returning_all();
333            let row = ::dbkit::InsertExt::one(insert, ex).await?;
334            row.ok_or(::dbkit::Error::NotFound)
335        }
336    );
337
338    let pk_idents = primary_keys.iter().map(|(ident, _)| ident.clone()).collect::<Vec<_>>();
339
340    let active_update_fn = if !primary_keys.is_empty() {
341        let pk_vars = primary_keys
342            .iter()
343            .enumerate()
344            .map(|(idx, _)| format_ident!("pk_value_{}", idx))
345            .collect::<Vec<_>>();
346        let pk_extracts = primary_keys.iter().zip(pk_vars.iter()).map(|((ident, _), var)| {
347            let pk_name = ident.to_string();
348            quote!(
349                let #var = match #ident {
350                    ::dbkit::ActiveValue::Set(value) | ::dbkit::ActiveValue::Unchanged(value) => value,
351                    ::dbkit::ActiveValue::Null | ::dbkit::ActiveValue::Unset | ::dbkit::ActiveValue::UnchangedNull => {
352                        return Err(::dbkit::Error::Decode(format!(
353                            "missing required field: {}",
354                            #pk_name
355                        )));
356                    }
357                };
358            )
359        });
360        let pk_filters = primary_keys
361            .iter()
362            .zip(pk_vars.iter())
363            .map(|((ident, _), var)| quote!(update = update.filter(#struct_ident::#ident.eq(#var));));
364        let update_steps = scalar_fields.iter().filter(|field| !field.is_key).map(|field| {
365            let ident = &field.ident;
366            let ty = option_inner_type(&field.ty).unwrap_or_else(|| field.ty.clone());
367            quote!(
368                match #ident {
369                    ::dbkit::ActiveValue::Unset => {}
370                    ::dbkit::ActiveValue::Set(value) => {
371                        update = update.set(#struct_ident::#ident, value);
372                        any_set = true;
373                    }
374                    ::dbkit::ActiveValue::Unchanged(_) | ::dbkit::ActiveValue::UnchangedNull => {}
375                    ::dbkit::ActiveValue::Null => {
376                        update = update.set(#struct_ident::#ident, None::<#ty>);
377                        any_set = true;
378                    }
379                }
380            )
381        });
382        quote!(
383            pub async fn update(
384                self,
385                ex: &(impl ::dbkit::Executor + Send + Sync),
386            ) -> Result<#struct_ident, ::dbkit::Error> {
387                let Self { #(#active_destructure,)* } = self;
388                #(#pk_extracts)*
389                let mut update = ::dbkit::Update::new(#struct_ident::TABLE);
390                let mut any_set = false;
391                #(#update_steps)*
392                if !any_set {
393                    return Err(::dbkit::Error::Decode("no fields set for update".to_string()));
394                }
395                #(#pk_filters)*
396                let update = update.returning_all();
397                let mut rows = ::dbkit::UpdateExt::all(update, ex).await?;
398                rows.pop().ok_or(::dbkit::Error::NotFound)
399            }
400        )
401    } else {
402        quote!()
403    };
404
405    let active_delete_fn = if !primary_keys.is_empty() {
406        let pk_vars = primary_keys
407            .iter()
408            .enumerate()
409            .map(|(idx, _)| format_ident!("pk_value_{}", idx))
410            .collect::<Vec<_>>();
411        let pk_extracts = primary_keys.iter().zip(pk_vars.iter()).map(|((ident, _), var)| {
412            let pk_name = ident.to_string();
413            quote!(
414                let #var = match #ident {
415                    ::dbkit::ActiveValue::Set(value) | ::dbkit::ActiveValue::Unchanged(value) => value,
416                    ::dbkit::ActiveValue::Null | ::dbkit::ActiveValue::Unset | ::dbkit::ActiveValue::UnchangedNull => {
417                        return Err(::dbkit::Error::Decode(format!(
418                            "missing required field: {}",
419                            #pk_name
420                        )));
421                    }
422                };
423            )
424        });
425        let pk_filters = primary_keys
426            .iter()
427            .zip(pk_vars.iter())
428            .map(|((ident, _), var)| quote!(delete = delete.filter(#struct_ident::#ident.eq(#var));));
429        quote!(
430            pub async fn delete(
431                self,
432                ex: &(impl ::dbkit::Executor + Send + Sync),
433            ) -> Result<u64, ::dbkit::Error> {
434                let Self { #(#pk_idents,)* .. } = self;
435                #(#pk_extracts)*
436                let mut delete = ::dbkit::Delete::new(#struct_ident::TABLE);
437                #(#pk_filters)*
438                ::dbkit::DeleteExt::execute(delete, ex).await
439            }
440        )
441    } else {
442        quote!()
443    };
444
445    let active_save_flag_checks = scalar_fields.iter().map(|field| {
446        let ident = &field.ident;
447        quote!(
448            match &#ident {
449                ::dbkit::ActiveValue::Unchanged(_) | ::dbkit::ActiveValue::UnchangedNull => {
450                    any_loaded = true;
451                }
452                ::dbkit::ActiveValue::Set(_) | ::dbkit::ActiveValue::Null => {
453                    any_changed = true;
454                }
455                ::dbkit::ActiveValue::Unset => {}
456            }
457        )
458    });
459
460    let active_save_model_fields = scalar_fields.iter().map(|field| {
461        let ident = &field.ident;
462        let name = ident.to_string();
463        if option_inner_type(&field.ty).is_some() {
464            quote!(
465                #ident: match #ident {
466                    ::dbkit::ActiveValue::Set(value) | ::dbkit::ActiveValue::Unchanged(value) => Some(value),
467                    ::dbkit::ActiveValue::Null | ::dbkit::ActiveValue::UnchangedNull => None,
468                    ::dbkit::ActiveValue::Unset => {
469                        return Err(::dbkit::Error::Decode(format!(
470                            "missing required field: {}",
471                            #name
472                        )));
473                    }
474                },
475            )
476        } else {
477            quote!(
478                #ident: match #ident {
479                    ::dbkit::ActiveValue::Set(value) | ::dbkit::ActiveValue::Unchanged(value) => value,
480                    ::dbkit::ActiveValue::Null
481                    | ::dbkit::ActiveValue::Unset
482                    | ::dbkit::ActiveValue::UnchangedNull => {
483                        return Err(::dbkit::Error::Decode(format!(
484                            "missing required field: {}",
485                            #name
486                        )));
487                    }
488                },
489            )
490        }
491    });
492
493    let active_save_relation_defaults = relation_fields.iter().map(|rel| {
494        let ident = rel.field.ident.as_ref().expect("field ident");
495        quote!(#ident: Default::default(),)
496    });
497
498    let active_save_update_branch = if !primary_keys.is_empty() {
499        quote!(return Self { #(#active_destructure,)* }.update(ex).await;)
500    } else {
501        quote!(
502            return Err(::dbkit::Error::Decode(
503                "update requires primary key".to_string(),
504            ));
505        )
506    };
507
508    let active_save_fn = quote!(
509        pub async fn save(
510            self,
511            ex: &(impl ::dbkit::Executor + Send + Sync),
512        ) -> Result<#struct_ident, ::dbkit::Error> {
513            let Self { #(#active_destructure,)* } = self;
514            let mut any_loaded = false;
515            let mut any_changed = false;
516            #(#active_save_flag_checks)*
517
518            if any_loaded {
519                if any_changed {
520                    #active_save_update_branch
521                }
522                let model = #struct_ident {
523                    #(#active_save_model_fields)*
524                    #(#active_save_relation_defaults)*
525                };
526                return Ok(model);
527            }
528
529            Self { #(#active_destructure,)* }.insert(ex).await
530        }
531    );
532
533    let model_delete_impl = if !primary_keys.is_empty() {
534        let pk_filters = primary_keys
535            .iter()
536            .map(|(ident, _)| quote!(delete = delete.filter(Self::#ident.eq(#ident));));
537        quote!(
538            impl #impl_generics ::dbkit::ModelDelete for #model_ident #struct_type_args {
539                fn delete<'e, E>(self, ex: &'e E) -> ::dbkit::executor::BoxFuture<'e, Result<u64, ::dbkit::Error>>
540                where
541                    E: ::dbkit::Executor + Send + Sync + 'e,
542                {
543                    let Self { #(#pk_idents,)* .. } = self;
544                    let mut delete = ::dbkit::Delete::new(Self::TABLE);
545                    #(#pk_filters)*
546                    ::dbkit::DeleteExt::execute(delete, ex)
547                }
548            }
549        )
550    } else {
551        quote!()
552    };
553
554    let into_active_fn = quote!(
555        pub fn into_active(self) -> #active_ident {
556            let Self { #(#active_destructure,)* .. } = self;
557            #active_ident {
558                #(#active_from_model,)*
559            }
560        }
561    );
562
563    let primary_key_const = if primary_keys.len() == 1 {
564        let (ident, ty) = primary_keys.first().expect("primary key length checked");
565        let name = ident.to_string();
566        Some(quote!(pub const PRIMARY_KEY: ::dbkit::Column<#struct_ident, #ty> = ::dbkit::Column::new(Self::TABLE, #name);))
567    } else {
568        None
569    };
570
571    let by_id_fn = if primary_keys.len() == 1 {
572        let (ident, ty) = primary_keys.first().expect("primary key length checked");
573        Some(quote!(
574            pub fn by_id(id: #ty) -> ::dbkit::Select<#struct_ident> {
575                Self::query().filter(Self::#ident.eq(id)).limit(1)
576            }
577        ))
578    } else {
579        None
580    };
581
582    let any_state_ident = format_ident!("{}AnyState", struct_ident);
583
584    let relation_state_modules = relation_fields.iter().map(|rel| {
585        let state_mod = &rel.state_mod_ident;
586        let (sealed_impl, state_impl) = match rel.kind {
587            RelationKind::HasMany | RelationKind::ManyToMany => (
588                quote!(
589                    impl<T> Sealed for Vec<T> {}
590                ),
591                quote!(
592                    impl<T> State for Vec<T> {}
593                ),
594            ),
595            RelationKind::BelongsTo => (
596                quote!(
597                    impl<T> Sealed for Option<T> {}
598                ),
599                quote!(
600                    impl<T> State for Option<T> {}
601                ),
602            ),
603        };
604        quote!(
605            pub mod #state_mod {
606                mod sealed {
607                    pub trait Sealed {}
608                    impl Sealed for ::dbkit::NotLoaded {}
609                    #sealed_impl
610                }
611                pub trait State: sealed::Sealed {}
612                impl State for ::dbkit::NotLoaded {}
613                #state_impl
614            }
615        )
616    });
617
618    let relation_methods = relation_fields.iter().map(|rel| {
619        let field_ident = rel.field.ident.as_ref().expect("field ident");
620        let method_ident = format_ident!("{}_loaded", field_ident);
621        let item_ident = format_ident!("{}Item", to_camel_case(&field_ident.to_string()));
622        let loaded_type: Type = match rel.kind {
623            RelationKind::HasMany | RelationKind::ManyToMany => syn::parse_quote!(Vec<#item_ident>),
624            RelationKind::BelongsTo => syn::parse_quote!(Option<#item_ident>),
625        };
626
627        let mut other_params = Vec::new();
628        let mut type_params = Vec::new();
629        for other in &relation_fields {
630            if other.field.ident == rel.field.ident {
631                type_params.push(quote!(#loaded_type));
632            } else {
633                let ident = &other.param_ident;
634                let state_mod = &other.state_mod_ident;
635                other_params.push(quote!(#ident: #state_mod::State));
636                type_params.push(quote!(#ident));
637            }
638        }
639
640        let mut impl_params = Vec::new();
641        impl_params.push(quote!(#item_ident));
642        impl_params.extend(other_params);
643
644        let impl_generics = if impl_params.is_empty() {
645            quote!()
646        } else {
647            quote!(<#(#impl_params),*>)
648        };
649        let type_args = if type_params.is_empty() {
650            quote!()
651        } else {
652            quote!(<#(#type_params),*>)
653        };
654
655        let (return_ty, body) = match rel.kind {
656            RelationKind::HasMany | RelationKind::ManyToMany => (quote!(&[#item_ident]), quote!(&self.#field_ident)),
657            RelationKind::BelongsTo => (quote!(Option<&#item_ident>), quote!(self.#field_ident.as_ref())),
658        };
659
660        quote!(
661            impl #impl_generics #model_ident #type_args {
662                pub fn #method_ident(&self) -> #return_ty {
663                    #body
664                }
665            }
666        )
667    });
668
669    let model_value_arms = output_fields
670        .iter()
671        .filter(|field| !is_relation_field(field, &relation_fields))
672        .map(|field| {
673            let ident = field.ident.as_ref().expect("field ident");
674            let name = ident.to_string();
675            quote!(#name => Some(self.#ident.clone().into()),)
676        });
677
678    let model_value_impl = quote!(
679        impl #impl_generics ::dbkit::ModelValue for #model_ident #struct_type_args {
680            fn column_value(&self, column: ::dbkit::ColumnRef) -> Option<::dbkit::Value> {
681                if column.table.name != Self::TABLE.name {
682                    return None;
683                }
684                match column.name {
685                    #(#model_value_arms)*
686                    _ => None,
687                }
688            }
689        }
690    );
691
692    let from_row_generics = relation_fields.iter().map(|rel| {
693        let ident = &rel.param_ident;
694        let state_mod = &rel.state_mod_ident;
695        quote!(#ident: #state_mod::State + Default)
696    });
697
698    let from_row_impl_generics = if relation_fields.is_empty() {
699        quote!(<'r>)
700    } else {
701        quote!(<'r, #(#from_row_generics),*>)
702    };
703
704    let from_row_fields = output_fields.iter().map(|field| {
705        let ident = field.ident.as_ref().expect("field ident");
706        if is_relation_field(field, &relation_fields) {
707            quote!(#ident: Default::default())
708        } else {
709            let name = ident.to_string();
710            quote!(#ident: ::dbkit::sqlx::Row::try_get(row, #name)?)
711        }
712    });
713
714    let from_row_impl = quote!(
715        impl #from_row_impl_generics ::dbkit::sqlx::FromRow<'r, ::dbkit::sqlx::postgres::PgRow>
716            for #model_ident #struct_type_args
717        {
718            fn from_row(row: &'r ::dbkit::sqlx::postgres::PgRow) -> Result<Self, ::dbkit::sqlx::Error> {
719                Ok(Self {
720                    #(#from_row_fields,)*
721                })
722            }
723        }
724    );
725
726    let joined_from_row_fields = output_fields.iter().map(|field| {
727        let ident = field.ident.as_ref().expect("field ident");
728        if is_relation_field(field, &relation_fields) {
729            quote!(#ident: Default::default())
730        } else {
731            let name = ident.to_string();
732            quote!(
733                #ident: {
734                    let column = format!("{}{}", prefix, #name);
735                    ::dbkit::sqlx::Row::try_get(row, column.as_str())?
736                }
737            )
738        }
739    });
740
741    let joined_pk_checks = if primary_keys.is_empty() {
742        if let Some(first_field) = scalar_fields.first() {
743            let name = first_field.ident.to_string();
744            let ty = option_inner_type(&first_field.ty).unwrap_or_else(|| first_field.ty.clone());
745            quote!(
746                let value: Option<#ty> = {
747                    let column = format!("{}{}", prefix, #name);
748                    ::dbkit::sqlx::Row::try_get(row, column.as_str())?
749                };
750                Ok(value.is_some())
751            )
752        } else {
753            quote!(Ok(false))
754        }
755    } else {
756        let checks = primary_keys.iter().map(|(ident, ty)| {
757            let name = ident.to_string();
758            let ty = option_inner_type(ty).unwrap_or_else(|| ty.clone());
759            quote!(
760                let value: Option<#ty> = {
761                    let column = format!("{}{}", prefix, #name);
762                    ::dbkit::sqlx::Row::try_get(row, column.as_str())?
763                };
764                if value.is_some() {
765                    return Ok(true);
766                }
767            )
768        });
769        quote!(
770            #(#checks)*
771            Ok(false)
772        )
773    };
774
775    let joined_model_impl = quote!(
776        impl #from_row_impl_generics ::dbkit::JoinedModel for #model_ident #struct_type_args {
777            fn joined_columns() -> &'static [::dbkit::ColumnRef] {
778                Self::COLUMNS
779            }
780
781            fn joined_primary_keys() -> &'static [::dbkit::ColumnRef] {
782                Self::PRIMARY_KEYS
783            }
784
785            fn joined_from_row_prefixed(
786                row: &::dbkit::sqlx::postgres::PgRow,
787                prefix: &str,
788            ) -> Result<Self, ::dbkit::sqlx::Error> {
789                Ok(Self {
790                    #(#joined_from_row_fields,)*
791                })
792            }
793
794            fn joined_row_has_pk(
795                row: &::dbkit::sqlx::postgres::PgRow,
796                prefix: &str,
797            ) -> Result<bool, ::dbkit::sqlx::Error> {
798                #joined_pk_checks
799            }
800        }
801    );
802
803    let set_relation_impls = relation_fields.iter().map(|rel| {
804        let field_ident = rel.field.ident.as_ref().expect("field ident");
805        let child_type = &rel.child_type;
806        let item_ident = format_ident!("{}Item", to_camel_case(&field_ident.to_string()));
807        let (value_ty, rel_ty) = match rel.kind {
808            RelationKind::HasMany => (quote!(Vec<#item_ident>), quote!(::dbkit::rel::HasMany<#struct_ident, #child_type>)),
809            RelationKind::ManyToMany => {
810                let through = rel.many_to_many_through.as_ref().expect("many-to-many through");
811                (
812                    quote!(Vec<#item_ident>),
813                    quote!(::dbkit::rel::ManyToMany<#struct_ident, #child_type, #through>),
814                )
815            }
816            RelationKind::BelongsTo => (
817                quote!(Option<#item_ident>),
818                quote!(::dbkit::rel::BelongsTo<#struct_ident, #child_type>),
819            ),
820        };
821
822        let mut other_params = Vec::new();
823        let mut type_params = Vec::new();
824        for other in &relation_fields {
825            if other.field.ident == rel.field.ident {
826                type_params.push(value_ty.clone());
827            } else {
828                let ident = &other.param_ident;
829                let state_mod = &other.state_mod_ident;
830                other_params.push(quote!(#ident: #state_mod::State));
831                type_params.push(quote!(#ident));
832            }
833        }
834
835        let mut impl_params = Vec::new();
836        impl_params.push(quote!(#item_ident));
837        impl_params.extend(other_params);
838
839        let impl_generics = if impl_params.is_empty() {
840            quote!()
841        } else {
842            quote!(<#(#impl_params),*>)
843        };
844        let type_args = if type_params.is_empty() {
845            quote!()
846        } else {
847            quote!(<#(#type_params),*>)
848        };
849
850        quote!(
851            impl #impl_generics ::dbkit::SetRelation<#rel_ty, #value_ty> for #model_ident #type_args {
852                fn set_relation(&mut self, _rel: #rel_ty, value: #value_ty) -> Result<(), ::dbkit::Error> {
853                    self.#field_ident = value;
854                    Ok(())
855                }
856            }
857        )
858    });
859
860    let get_relation_impls = relation_fields.iter().map(|rel| {
861        let field_ident = rel.field.ident.as_ref().expect("field ident");
862        let child_type = &rel.child_type;
863        let item_ident = format_ident!("{}Item", to_camel_case(&field_ident.to_string()));
864        let (value_ty, rel_ty) = match rel.kind {
865            RelationKind::HasMany => (quote!(Vec<#item_ident>), quote!(::dbkit::rel::HasMany<#struct_ident, #child_type>)),
866            RelationKind::ManyToMany => {
867                let through = rel.many_to_many_through.as_ref().expect("many-to-many through");
868                (
869                    quote!(Vec<#item_ident>),
870                    quote!(::dbkit::rel::ManyToMany<#struct_ident, #child_type, #through>),
871                )
872            }
873            RelationKind::BelongsTo => (
874                quote!(Option<#item_ident>),
875                quote!(::dbkit::rel::BelongsTo<#struct_ident, #child_type>),
876            ),
877        };
878
879        let mut other_params = Vec::new();
880        let mut type_params = Vec::new();
881        for other in &relation_fields {
882            if other.field.ident == rel.field.ident {
883                type_params.push(value_ty.clone());
884            } else {
885                let ident = &other.param_ident;
886                let state_mod = &other.state_mod_ident;
887                other_params.push(quote!(#ident: #state_mod::State));
888                type_params.push(quote!(#ident));
889            }
890        }
891
892        let mut impl_params = Vec::new();
893        impl_params.push(quote!(#item_ident));
894        impl_params.extend(other_params);
895
896        let impl_generics = if impl_params.is_empty() {
897            quote!()
898        } else {
899            quote!(<#(#impl_params),*>)
900        };
901        let type_args = if type_params.is_empty() {
902            quote!()
903        } else {
904            quote!(<#(#type_params),*>)
905        };
906
907        quote!(
908            impl #impl_generics ::dbkit::GetRelation<#rel_ty, #value_ty> for #model_ident #type_args {
909                fn get_relation(&self, _rel: #rel_ty) -> Option<&#value_ty> {
910                    Some(&self.#field_ident)
911                }
912
913                fn get_relation_mut(&mut self, _rel: #rel_ty) -> Option<&mut #value_ty> {
914                    Some(&mut self.#field_ident)
915                }
916            }
917        )
918    });
919
920    let load_method = quote!(
921        pub async fn load<Rel>(
922            self,
923            rel: Rel,
924            ex: &(impl ::dbkit::Executor + Send + Sync),
925        ) -> Result<<Self as ::dbkit::LoadRelation<Rel>>::Out, ::dbkit::Error>
926        where
927            Self: ::dbkit::LoadRelation<Rel>,
928        {
929            ::dbkit::LoadRelation::load_relation(self, rel, ex).await
930        }
931    );
932
933    let load_relation_impls = relation_fields.iter().map(|rel| {
934        let field_ident = rel.field.ident.as_ref().expect("field ident");
935        let child_type = &rel.child_type;
936        let rel_type = match rel.kind {
937            RelationKind::HasMany => quote!(::dbkit::rel::HasMany<#struct_ident, #child_type>),
938            RelationKind::BelongsTo => quote!(::dbkit::rel::BelongsTo<#struct_ident, #child_type>),
939            RelationKind::ManyToMany => {
940                let through = rel.many_to_many_through.as_ref().expect("many-to-many through");
941                quote!(::dbkit::rel::ManyToMany<#struct_ident, #child_type, #through>)
942            }
943        };
944        let loaded_type = match rel.kind {
945            RelationKind::HasMany | RelationKind::ManyToMany => quote!(Vec<#child_type>),
946            RelationKind::BelongsTo => quote!(Option<#child_type>),
947        };
948        let loader_fn = match rel.kind {
949            RelationKind::HasMany => quote!(::dbkit::runtime::load_selectin_has_many),
950            RelationKind::ManyToMany => quote!(::dbkit::runtime::load_selectin_many_to_many),
951            RelationKind::BelongsTo => quote!(::dbkit::runtime::load_selectin_belongs_to),
952        };
953
954        let mut other_params = Vec::new();
955        let mut type_params = Vec::new();
956        let mut out_params = Vec::new();
957        for other in &relation_fields {
958            if other.field.ident == rel.field.ident {
959                type_params.push(quote!(::dbkit::NotLoaded));
960                out_params.push(loaded_type.clone());
961            } else {
962                let ident = &other.param_ident;
963                let state_mod = &other.state_mod_ident;
964                other_params.push(quote!(#ident: #state_mod::State + Send + 'static));
965                type_params.push(quote!(#ident));
966                out_params.push(quote!(#ident));
967            }
968        }
969
970        let impl_generics = if other_params.is_empty() {
971            quote!()
972        } else {
973            quote!(<#(#other_params),*>)
974        };
975        let type_args = if type_params.is_empty() {
976            quote!()
977        } else {
978            quote!(<#(#type_params),*>)
979        };
980        let out_type = if out_params.is_empty() {
981            quote!(#model_ident)
982        } else {
983            quote!(#model_ident<#(#out_params),*>)
984        };
985        let out_construct = if out_params.is_empty() {
986            quote!(#model_ident)
987        } else {
988            quote!(#model_ident::<#(#out_params),*>)
989        };
990
991        let destructure_fields = output_fields.iter().map(|field| {
992            let ident = field.ident.as_ref().expect("field ident");
993            if ident == field_ident {
994                quote!(#ident: _)
995            } else {
996                quote!(#ident)
997            }
998        });
999
1000        let build_fields = output_fields.iter().map(|field| {
1001            let ident = field.ident.as_ref().expect("field ident");
1002            if ident == field_ident {
1003                quote!(#ident: Default::default())
1004            } else {
1005                quote!(#ident)
1006            }
1007        });
1008
1009        quote!(
1010            impl #impl_generics ::dbkit::LoadRelation<#rel_type> for #model_ident #type_args {
1011                type Out = #out_type;
1012
1013                fn load_relation<'e, E>(
1014                    self,
1015                    rel: #rel_type,
1016                    ex: &'e E,
1017                ) -> ::dbkit::executor::BoxFuture<'e, Result<Self::Out, ::dbkit::Error>>
1018                where
1019                    E: ::dbkit::Executor + Send + Sync + 'e,
1020                {
1021                    Box::pin(async move {
1022                        let Self { #(#destructure_fields,)* } = self;
1023                        let mut out = #out_construct {
1024                            #(#build_fields,)*
1025                        };
1026                        let mut rows = vec![out];
1027                        #loader_fn(ex, &mut rows, rel, &::dbkit::load::NoLoad).await?;
1028                        Ok(rows.pop().expect("loaded row"))
1029                    })
1030                }
1031            }
1032        )
1033    });
1034
1035    let relation_consts = relation_fields.iter().filter_map(|rel| {
1036        let field_ident = rel.field.ident.as_ref().expect("field ident");
1037        let child_type = &rel.child_type;
1038        match rel.kind {
1039            RelationKind::HasMany => Some(quote!(
1040                pub const #field_ident: ::dbkit::rel::HasMany<#struct_ident, #child_type> =
1041                    ::dbkit::rel::HasMany::new(
1042                        <#child_type as ::dbkit::rel::BelongsToSpec<#struct_ident>>::PARENT_TABLE,
1043                        <#child_type as ::dbkit::rel::BelongsToSpec<#struct_ident>>::CHILD_TABLE,
1044                        <#child_type as ::dbkit::rel::BelongsToSpec<#struct_ident>>::PARENT_KEY,
1045                        <#child_type as ::dbkit::rel::BelongsToSpec<#struct_ident>>::CHILD_KEY,
1046                    );
1047            )),
1048            RelationKind::BelongsTo => {
1049                let key = rel.belongs_to_key.as_ref().expect("belongs_to key");
1050                let references = rel.belongs_to_ref.as_ref().expect("belongs_to references");
1051                Some(quote!(
1052                    pub const #field_ident: ::dbkit::rel::BelongsTo<#struct_ident, #child_type> =
1053                        ::dbkit::rel::BelongsTo::new(
1054                            Self::TABLE,
1055                            #child_type::TABLE,
1056                            Self::#key.as_ref(),
1057                            #child_type::#references.as_ref(),
1058                        );
1059                ))
1060            }
1061            RelationKind::ManyToMany => {
1062                let through = rel.many_to_many_through.as_ref().expect("many-to-many through");
1063                let left_key = rel.many_to_many_left_key.as_ref().expect("many-to-many left_key");
1064                let right_key = rel.many_to_many_right_key.as_ref().expect("many-to-many right_key");
1065                let parent_pk = primary_keys.first().map(|(ident, _)| ident).expect("many-to-many parent pk");
1066                Some(quote!(
1067                    pub const #field_ident: ::dbkit::rel::ManyToMany<#struct_ident, #child_type, #through> =
1068                        ::dbkit::rel::ManyToMany::new(
1069                            Self::TABLE,
1070                            #child_type::TABLE,
1071                            #through::TABLE,
1072                            Self::#parent_pk.as_ref(),
1073                            #child_type::PRIMARY_KEY.as_ref(),
1074                            #through::#left_key.as_ref(),
1075                            #through::#right_key.as_ref(),
1076                        );
1077                ))
1078            }
1079        }
1080    });
1081
1082    let belongs_to_specs = relation_fields.iter().filter_map(|rel| {
1083        if rel.kind != RelationKind::BelongsTo {
1084            return None;
1085        }
1086        let parent_type = &rel.child_type;
1087        let key = rel.belongs_to_key.as_ref().expect("belongs_to key");
1088        let references = rel.belongs_to_ref.as_ref().expect("belongs_to references");
1089        Some(quote!(
1090            impl #impl_generics ::dbkit::rel::BelongsToSpec<#parent_type> for #model_ident #struct_type_args {
1091                const CHILD_TABLE: ::dbkit::Table = Self::TABLE;
1092                const PARENT_TABLE: ::dbkit::Table = #parent_type::TABLE;
1093                const CHILD_KEY: ::dbkit::ColumnRef = Self::#key.as_ref();
1094                const PARENT_KEY: ::dbkit::ColumnRef = #parent_type::#references.as_ref();
1095            }
1096        ))
1097    });
1098
1099    let apply_load_impls = relation_fields.iter().flat_map(|rel| {
1100        let child_type = &rel.child_type;
1101        let rel_type = match rel.kind {
1102            RelationKind::HasMany => quote!(::dbkit::rel::HasMany<#struct_ident, #child_type>),
1103            RelationKind::BelongsTo => quote!(::dbkit::rel::BelongsTo<#struct_ident, #child_type>),
1104            RelationKind::ManyToMany => {
1105                let through = rel.many_to_many_through.as_ref().expect("many-to-many through");
1106                quote!(::dbkit::rel::ManyToMany<#struct_ident, #child_type, #through>)
1107            }
1108        };
1109
1110        let loaded_child = quote!(<Nested as ::dbkit::load::ApplyLoad<#child_type>>::Out2);
1111        let loaded_param = match rel.kind {
1112            RelationKind::HasMany | RelationKind::ManyToMany => quote!(Vec<#loaded_child>),
1113            RelationKind::BelongsTo => quote!(Option<#loaded_child>),
1114        };
1115
1116        let mut out_params = Vec::new();
1117        for other in &relation_fields {
1118            if other.field.ident == rel.field.ident {
1119                out_params.push(loaded_param.clone());
1120            } else {
1121                let ident = &other.param_ident;
1122                out_params.push(quote!(#ident));
1123            }
1124        }
1125
1126        let model_type = if generic_idents.is_empty() {
1127            quote!(#model_ident)
1128        } else {
1129            quote!(#model_ident<#(#generic_idents),*>)
1130        };
1131        let out_type = if out_params.is_empty() {
1132            quote!(#model_ident)
1133        } else {
1134            quote!(#model_ident<#(#out_params),*>)
1135        };
1136
1137        let mut apply_generics = Vec::new();
1138        apply_generics.push(quote!(Nested));
1139        apply_generics.extend(impl_generics_params.iter().cloned());
1140        let apply_generics = if apply_generics.is_empty() {
1141            quote!()
1142        } else {
1143            quote!(<#(#apply_generics),*>)
1144        };
1145
1146        let mut items = Vec::new();
1147        for strategy in ["SelectIn", "Joined"] {
1148            let load_ty = if strategy == "SelectIn" {
1149                quote!(::dbkit::load::SelectIn<#rel_type, Nested>)
1150            } else {
1151                quote!(::dbkit::load::Joined<#rel_type, Nested>)
1152            };
1153            items.push(quote!(
1154                impl #apply_generics ::dbkit::load::ApplyLoad<#model_type> for #load_ty
1155                where
1156                    Nested: ::dbkit::load::ApplyLoad<#child_type>,
1157                {
1158                    type Out2 = #out_type;
1159                }
1160            ));
1161        }
1162        items.into_iter()
1163    });
1164
1165    let run_load_impls = relation_fields.iter().flat_map(|rel| {
1166        let child_type = &rel.child_type;
1167        let through = rel.many_to_many_through.as_ref();
1168        let rel_type = match rel.kind {
1169            RelationKind::HasMany => quote!(::dbkit::rel::HasMany<#struct_ident, #child_type>),
1170            RelationKind::BelongsTo => quote!(::dbkit::rel::BelongsTo<#struct_ident, #child_type>),
1171            RelationKind::ManyToMany => {
1172                let through = through.expect("many-to-many through");
1173                quote!(::dbkit::rel::ManyToMany<#struct_ident, #child_type, #through>)
1174            }
1175        };
1176
1177        let loaded_child = quote!(<Nested as ::dbkit::load::ApplyLoad<#child_type>>::Out2);
1178        let loaded_param = match rel.kind {
1179            RelationKind::HasMany | RelationKind::ManyToMany => quote!(Vec<#loaded_child>),
1180            RelationKind::BelongsTo => quote!(Option<#loaded_child>),
1181        };
1182
1183        let mut out_params = Vec::new();
1184        for other in &relation_fields {
1185            if other.field.ident == rel.field.ident {
1186                out_params.push(loaded_param.clone());
1187            } else {
1188                let ident = &other.param_ident;
1189                out_params.push(quote!(#ident));
1190            }
1191        }
1192
1193        let out_type = if out_params.is_empty() {
1194            quote!(#model_ident)
1195        } else {
1196            quote!(#model_ident<#(#out_params),*>)
1197        };
1198
1199        let mut apply_generics = Vec::new();
1200        apply_generics.push(quote!(Nested));
1201        for other in &relation_fields {
1202            if other.field.ident == rel.field.ident {
1203                continue;
1204            }
1205            let ident = &other.param_ident;
1206            let state_mod = &other.state_mod_ident;
1207            apply_generics.push(quote!(#ident: #state_mod::State + Send + 'static));
1208        }
1209        let apply_generics = if apply_generics.is_empty() {
1210            quote!()
1211        } else {
1212            quote!(<#(#apply_generics),*>)
1213        };
1214
1215        let (child_bounds, loader_fn) = match rel.kind {
1216            RelationKind::HasMany => (
1217                quote!(#loaded_child: ::dbkit::ModelValue + for<'r> ::dbkit::sqlx::FromRow<'r, ::dbkit::sqlx::postgres::PgRow> + Send + Unpin,),
1218                quote!(::dbkit::runtime::load_selectin_has_many),
1219            ),
1220            RelationKind::ManyToMany => {
1221                let through = through.expect("many-to-many through");
1222                (
1223                    quote!(
1224                        #loaded_child: ::dbkit::ModelValue + Clone + for<'r> ::dbkit::sqlx::FromRow<'r, ::dbkit::sqlx::postgres::PgRow> + Send + Unpin,
1225                        #through: ::dbkit::ModelValue + for<'r> ::dbkit::sqlx::FromRow<'r, ::dbkit::sqlx::postgres::PgRow> + Send + Unpin,
1226                    ),
1227                    quote!(::dbkit::runtime::load_selectin_many_to_many),
1228                )
1229            }
1230            RelationKind::BelongsTo => (
1231                quote!(#loaded_child: ::dbkit::ModelValue + Clone + for<'r> ::dbkit::sqlx::FromRow<'r, ::dbkit::sqlx::postgres::PgRow> + Send + Unpin,),
1232                quote!(::dbkit::runtime::load_selectin_belongs_to),
1233            ),
1234        };
1235
1236        let joined_loader_fn = match rel.kind {
1237            RelationKind::HasMany => quote!(::dbkit::runtime::load_joined_has_many),
1238            RelationKind::ManyToMany => quote!(::dbkit::runtime::load_joined_many_to_many),
1239            RelationKind::BelongsTo => quote!(::dbkit::runtime::load_joined_belongs_to),
1240        };
1241
1242        let mut items = Vec::new();
1243        for (strategy, loader) in [
1244            ("SelectIn", loader_fn),
1245            ("Joined", joined_loader_fn),
1246        ] {
1247            let load_ty = if strategy == "SelectIn" {
1248                quote!(::dbkit::load::SelectIn<#rel_type, Nested>)
1249            } else {
1250                quote!(::dbkit::load::Joined<#rel_type, Nested>)
1251            };
1252            let out_bound = if strategy == "SelectIn" {
1253                quote!(::dbkit::ModelValue + ::dbkit::SetRelation<#rel_type, #loaded_param>)
1254            } else {
1255                quote!(::dbkit::GetRelation<#rel_type, #loaded_param>)
1256            };
1257
1258            items.push(quote!(
1259                impl #apply_generics ::dbkit::runtime::RunLoad<#out_type> for #load_ty
1260                where
1261                    Nested: ::dbkit::load::ApplyLoad<#child_type> + ::dbkit::runtime::RunLoads<#loaded_child> + Sync,
1262                    #out_type: #out_bound,
1263                    #child_bounds
1264                {
1265                    fn run<'e, E>(
1266                        &'e self,
1267                        ex: &'e E,
1268                        rows: &'e mut [#out_type],
1269                    ) -> ::dbkit::executor::BoxFuture<'e, Result<(), ::dbkit::Error>>
1270                    where
1271                        E: ::dbkit::Executor + Send + Sync + 'e,
1272                    {
1273                        #loader(ex, rows, self.rel.clone(), &self.nested)
1274                    }
1275                }
1276            ));
1277        }
1278        items.into_iter()
1279    });
1280
1281    let output = quote! {
1282        #(#struct_attrs)*
1283        #[derive(Debug, Clone)]
1284        #vis struct #model_ident #struct_generics {
1285            #(#output_fields,)*
1286        }
1287
1288        #vis type #struct_ident = #model_ident;
1289
1290        #(#relation_state_modules)*
1291
1292        #vis trait #any_state_ident {}
1293        impl #impl_generics #any_state_ident for #model_ident #struct_type_args {}
1294
1295        impl #impl_generics #model_ident #struct_type_args {
1296            pub const TABLE: ::dbkit::Table = #table_expr;
1297            #(#columns)*
1298            #columns_const
1299            #primary_key_const
1300            #primary_keys_const
1301            #(#relation_consts)*
1302
1303            pub fn query() -> ::dbkit::Select<#struct_ident> {
1304                ::dbkit::Select::new(Self::TABLE)
1305            }
1306
1307            #by_id_fn
1308
1309            pub fn insert(values: #insert_ident) -> ::dbkit::Insert<#struct_ident> {
1310                let mut insert = ::dbkit::Insert::new(Self::TABLE);
1311                #(#insert_values)*
1312                insert
1313            }
1314
1315            pub fn insert_many(values: Vec<#insert_ident>) -> ::dbkit::Insert<#struct_ident> {
1316                let mut insert = ::dbkit::Insert::new(Self::TABLE);
1317                for value in values {
1318                    insert = insert.row(|row| {
1319                        let mut row = row;
1320                        #(
1321                            row = row.value(Self::#insert_field_idents, value.#insert_field_idents);
1322                        )*
1323                        row
1324                    });
1325                }
1326                insert
1327            }
1328
1329            pub fn update() -> ::dbkit::Update<#struct_ident> {
1330                ::dbkit::Update::new(Self::TABLE)
1331            }
1332
1333            pub fn delete() -> ::dbkit::Delete {
1334                ::dbkit::Delete::new(Self::TABLE)
1335            }
1336
1337            pub fn new_active() -> #active_ident {
1338                #active_ident::new()
1339            }
1340
1341            #into_active_fn
1342            #load_method
1343        }
1344
1345        #[derive(Debug, Clone)]
1346        #vis struct #insert_ident {
1347            #(#insert_fields,)*
1348        }
1349
1350        #[derive(Debug, Clone, Default)]
1351        #vis struct #active_ident {
1352            #(#active_fields,)*
1353        }
1354
1355        impl #active_ident {
1356            pub fn new() -> Self {
1357                Self::default()
1358            }
1359
1360            #active_insert_fn
1361            #active_update_fn
1362            #active_delete_fn
1363            #active_save_fn
1364        }
1365
1366        #(#relation_methods)*
1367        #model_value_impl
1368        #from_row_impl
1369        #joined_model_impl
1370        #(#set_relation_impls)*
1371        #(#get_relation_impls)*
1372        #(#load_relation_impls)*
1373        #(#belongs_to_specs)*
1374        #(#apply_load_impls)*
1375        #(#run_load_impls)*
1376        #model_delete_impl
1377    };
1378
1379    Ok(output.into())
1380}
1381
1382fn parse_model_args(args: syn::punctuated::Punctuated<Meta, syn::Token![,]>) -> ModelArgs {
1383    let mut out = ModelArgs::default();
1384    for meta in args {
1385        if let Meta::NameValue(nv) = meta {
1386            if nv.path.is_ident("table") {
1387                if let Some(value) = extract_lit_str(&nv.value) {
1388                    out.table = Some(value);
1389                }
1390            } else if nv.path.is_ident("schema") {
1391                if let Some(value) = extract_lit_str(&nv.value) {
1392                    out.schema = Some(value);
1393                }
1394            }
1395        }
1396    }
1397    out
1398}
1399
1400fn parse_belongs_to_args(attrs: &[Attribute]) -> syn::Result<(Ident, Ident)> {
1401    for attr in attrs {
1402        if !attr.path().is_ident("belongs_to") {
1403            continue;
1404        }
1405        let args = attr.parse_args_with(syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated)?;
1406        let mut key = None;
1407        let mut references = None;
1408        for meta in args {
1409            if let Meta::NameValue(nv) = meta {
1410                if nv.path.is_ident("key") {
1411                    key = extract_ident(&nv.value);
1412                } else if nv.path.is_ident("references") {
1413                    references = extract_ident(&nv.value);
1414                }
1415            }
1416        }
1417        if let (Some(key), Some(references)) = (key, references) {
1418            return Ok((key, references));
1419        }
1420    }
1421    Err(syn::Error::new(
1422        proc_macro2::Span::call_site(),
1423        "dbkit: #[belongs_to] requires key = <field> and references = <field>",
1424    ))
1425}
1426
1427fn parse_many_to_many_args(attrs: &[Attribute]) -> syn::Result<(Ident, Ident, Ident)> {
1428    for attr in attrs {
1429        if !attr.path().is_ident("many_to_many") {
1430            continue;
1431        }
1432        let args = attr.parse_args_with(syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated)?;
1433        let mut through = None;
1434        let mut left_key = None;
1435        let mut right_key = None;
1436        for meta in args {
1437            if let Meta::NameValue(nv) = meta {
1438                if nv.path.is_ident("through") {
1439                    through = extract_ident(&nv.value);
1440                } else if nv.path.is_ident("left_key") {
1441                    left_key = extract_ident(&nv.value);
1442                } else if nv.path.is_ident("right_key") {
1443                    right_key = extract_ident(&nv.value);
1444                }
1445            }
1446        }
1447        if let (Some(through), Some(left_key), Some(right_key)) = (through, left_key, right_key) {
1448            return Ok((through, left_key, right_key));
1449        }
1450    }
1451    Err(syn::Error::new(
1452        proc_macro2::Span::call_site(),
1453        "dbkit: #[many_to_many] requires through = <Model>, left_key = <field>, right_key = <field>",
1454    ))
1455}
1456
1457fn extract_lit_str(expr: &syn::Expr) -> Option<String> {
1458    if let syn::Expr::Lit(syn::ExprLit {
1459        lit: syn::Lit::Str(lit), ..
1460    }) = expr
1461    {
1462        Some(lit.value())
1463    } else {
1464        None
1465    }
1466}
1467
1468fn extract_ident(expr: &syn::Expr) -> Option<Ident> {
1469    if let syn::Expr::Path(path) = expr {
1470        path.path.get_ident().cloned()
1471    } else {
1472        None
1473    }
1474}
1475
1476fn option_inner_type(ty: &Type) -> Option<Type> {
1477    let path = match ty {
1478        Type::Path(path) => path,
1479        _ => return None,
1480    };
1481    let segment = path.path.segments.last()?;
1482    if segment.ident != "Option" {
1483        return None;
1484    }
1485    let args = match &segment.arguments {
1486        syn::PathArguments::AngleBracketed(args) => args,
1487        _ => return None,
1488    };
1489    let inner = args.args.first()?;
1490    match inner {
1491        syn::GenericArgument::Type(inner_ty) => Some(inner_ty.clone()),
1492        _ => None,
1493    }
1494}
1495
1496fn has_attr(attrs: &[Attribute], name: &str) -> bool {
1497    attrs.iter().any(|attr| attr.path().is_ident(name))
1498}
1499
1500fn filter_struct_attrs(attrs: &[Attribute]) -> Vec<Attribute> {
1501    let mut kept = Vec::new();
1502    for attr in attrs {
1503        if is_model_attr(attr) {
1504            continue;
1505        }
1506        if attr.path().is_ident("derive") {
1507            if let Ok(mut paths) = attr.parse_args_with(syn::punctuated::Punctuated::<syn::Path, syn::Token![,]>::parse_terminated) {
1508                paths = paths
1509                    .into_iter()
1510                    .filter(|path| !path.segments.last().map(|seg| seg.ident == "Model").unwrap_or(false))
1511                    .collect();
1512                if paths.is_empty() {
1513                    continue;
1514                }
1515                let new_attr = quote!(#[derive(#paths)]);
1516                let parsed = syn::Attribute::parse_outer.parse2(new_attr).expect("derive attr");
1517                kept.extend(parsed);
1518                continue;
1519            }
1520        }
1521        kept.push(attr.clone());
1522    }
1523    kept
1524}
1525
1526fn filter_field_attrs(attrs: &[Attribute]) -> Vec<Attribute> {
1527    attrs.iter().filter(|attr| !is_field_orm_attr(attr)).cloned().collect()
1528}
1529
1530fn is_field_orm_attr(attr: &Attribute) -> bool {
1531    let name = attr.path().get_ident().map(|ident| ident.to_string());
1532    matches!(
1533        name.as_deref(),
1534        Some("key") | Some("autoincrement") | Some("unique") | Some("index") | Some("has_many") | Some("belongs_to") | Some("many_to_many")
1535    )
1536}
1537
1538fn is_model_attr(attr: &Attribute) -> bool {
1539    attr.path().is_ident("model")
1540}
1541
1542fn relation_type(field: &Field) -> syn::Result<(RelationKind, Type)> {
1543    let kind = if has_attr(&field.attrs, "has_many") {
1544        RelationKind::HasMany
1545    } else if has_attr(&field.attrs, "belongs_to") {
1546        RelationKind::BelongsTo
1547    } else if has_attr(&field.attrs, "many_to_many") {
1548        RelationKind::ManyToMany
1549    } else {
1550        return Err(syn::Error::new_spanned(field, "dbkit: missing relation attribute"));
1551    };
1552
1553    let child_type = match &field.ty {
1554        Type::Path(path) => {
1555            let segment = path
1556                .path
1557                .segments
1558                .last()
1559                .ok_or_else(|| syn::Error::new_spanned(&field.ty, "dbkit: invalid type"))?;
1560            let expected = match kind {
1561                RelationKind::HasMany => "HasMany",
1562                RelationKind::BelongsTo => "BelongsTo",
1563                RelationKind::ManyToMany => "ManyToMany",
1564            };
1565            if segment.ident != expected {
1566                return Err(syn::Error::new_spanned(
1567                    &segment.ident,
1568                    format!("dbkit: expected {} marker type", expected),
1569                ));
1570            }
1571            match &segment.arguments {
1572                syn::PathArguments::AngleBracketed(args) => {
1573                    let ty = args.args.iter().find_map(|arg| match arg {
1574                        syn::GenericArgument::Type(ty) => Some(ty.clone()),
1575                        _ => None,
1576                    });
1577                    ty.ok_or_else(|| syn::Error::new_spanned(&segment, "dbkit: missing type"))?
1578                }
1579                _ => return Err(syn::Error::new_spanned(&segment.arguments, "dbkit: expected generic argument")),
1580            }
1581        }
1582        _ => return Err(syn::Error::new_spanned(&field.ty, "dbkit: relation marker must be a type path")),
1583    };
1584
1585    Ok((kind, child_type))
1586}
1587
1588fn is_relation_field(field: &Field, rels: &[RelationInfo]) -> bool {
1589    rels.iter().any(|rel| rel.field.ident == field.ident)
1590}
1591
/// Converts an `UpperCamelCase`-style identifier to `snake_case`.
///
/// A `_` is inserted before an uppercase character when either:
/// * the previous character is lowercase or an ASCII digit (`fooBar`, `Http2Server`), or
/// * it ends an acronym, i.e. it follows another uppercase character and precedes a
///   lowercase one (`HTTPWebhook` -> `http_webhook`). This rule never applies at
///   position 1, so a leading capital pair stays one word (`OAuthToken` -> `oauth_token`).
///
/// No separator is added if the output already ends with `_`. Uppercase characters
/// are lowercased (possibly expanding to several chars); everything else is copied as is.
fn to_snake_case(name: &str) -> String {
    let chars: Vec<char> = name.chars().collect();
    let mut snake = String::with_capacity(name.len() + name.len() / 4);

    for (pos, &current) in chars.iter().enumerate() {
        if !current.is_uppercase() {
            snake.push(current);
            continue;
        }

        let before = if pos == 0 { None } else { chars.get(pos - 1).copied() };
        let after = chars.get(pos + 1).copied();

        // "fooBar" / "v2Api": a lowercase letter or digit closes the previous word.
        let ends_word = matches!(before, Some(c) if c.is_lowercase() || c.is_ascii_digit());
        // "HTTPWebhook": the capital before a lowercase run starts a new word,
        // except at position 1 so "OAuth" is kept together.
        let starts_word_after_acronym = pos >= 2
            && matches!(before, Some(c) if c.is_uppercase())
            && matches!(after, Some(c) if c.is_lowercase());

        if (ends_word || starts_word_after_acronym) && !snake.ends_with('_') {
            snake.push('_');
        }
        snake.extend(current.to_lowercase());
    }

    snake
}
1621
/// Converts a `snake_case` identifier to `UpperCamelCase` (PascalCase).
///
/// The input is split on `_`. Empty segments (from leading, trailing or repeated
/// underscores) are dropped. The first character of each remaining segment is
/// uppercased and the rest is copied unchanged, e.g. `user_account` -> `UserAccount`.
fn to_camel_case(name: &str) -> String {
    name.split('_')
        .filter(|segment| !segment.is_empty())
        .flat_map(|segment| {
            let mut rest = segment.chars();
            let head = rest.next().into_iter().flat_map(char::to_uppercase);
            head.chain(rest)
        })
        .collect()
}
1641
1642// (unused helper removed)
1643
1644// (intentionally removed unused AnyState helpers)
1645
/// Enum-level options gathered from `#[dbkit(...)]` on a `#[derive(DbEnum)]` input.
#[derive(Default)]
struct DbEnumArgs {
    /// Value of `type_name = "..."`; `expand_db_enum` rejects the input when it is absent.
    type_name: Option<String>,
    /// Raw value of `rename_all = "..."`, validated later by `parse_db_enum_rename_all`.
    rename_all: Option<String>,
}
1651
/// `rename_all` strategy applied to variant identifiers that carry no explicit
/// `#[dbkit(rename = "...")]`.
///
/// Derives the same trait set as `RelationKind`, so parsed rules can be compared with
/// `assert_eq!` and printed in diagnostics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DbEnumRenameAll {
    /// No `rename_all` given: the variant identifier is used verbatim.
    AsIs,
    /// `"snake_case"`, e.g. `HTTPWebhook` -> `http_webhook`.
    SnakeCase,
    /// `"lowercase"`, e.g. `HTTPWebhook` -> `httpwebhook`.
    LowerCase,
    /// `"UPPERCASE"`, e.g. `HTTPWebhook` -> `HTTPWEBHOOK`.
    UpperCase,
    /// `"SCREAMING_SNAKE_CASE"`, e.g. `HTTPWebhook` -> `HTTP_WEBHOOK`.
    ScreamingSnakeCase,
}
1660
1661fn expand_db_enum(input: syn::ItemEnum) -> syn::Result<TokenStream> {
1662    if !input.generics.params.is_empty() {
1663        return Err(syn::Error::new_spanned(
1664            input.generics,
1665            "dbkit: #[derive(DbEnum)] does not support generics",
1666        ));
1667    }
1668
1669    let args = parse_db_enum_args(&input.attrs)?;
1670    let type_name = args
1671        .type_name
1672        .ok_or_else(|| syn::Error::new_spanned(&input.ident, "dbkit: DbEnum requires #[dbkit(type_name = \"...\")]"))?;
1673    let rename_rule = parse_db_enum_rename_all(args.rename_all.as_deref())?;
1674
1675    let enum_ident = input.ident.clone();
1676
1677    let mut as_db_arms = Vec::new();
1678    let mut from_db_arms = Vec::new();
1679    let mut expected_values = Vec::new();
1680    let mut seen_db_names: std::collections::BTreeMap<String, syn::Ident> = std::collections::BTreeMap::new();
1681
1682    for variant in input.variants.iter() {
1683        if !matches!(variant.fields, syn::Fields::Unit) {
1684            return Err(syn::Error::new_spanned(
1685                &variant.fields,
1686                "dbkit: DbEnum only supports unit variants",
1687            ));
1688        }
1689
1690        let variant_ident = &variant.ident;
1691        let explicit = parse_db_enum_variant_rename(&variant.attrs)?;
1692        let db_name = match explicit {
1693            Some(value) => value,
1694            None => apply_db_enum_rename_rule(&variant.ident.to_string(), rename_rule),
1695        };
1696        if let Some(first_variant) = seen_db_names.get(&db_name) {
1697            return Err(syn::Error::new_spanned(
1698                variant_ident,
1699                format!(
1700                    "dbkit: duplicate DbEnum wire name `{}` for variants `{}` and `{}`",
1701                    db_name, first_variant, variant_ident
1702                ),
1703            ));
1704        }
1705        seen_db_names.insert(db_name.clone(), variant_ident.clone());
1706        let db_name_lit = syn::LitStr::new(&db_name, variant.ident.span());
1707        expected_values.push(db_name);
1708
1709        as_db_arms.push(quote!(Self::#variant_ident => #db_name_lit,));
1710        from_db_arms.push(quote!(#db_name_lit => Ok(Self::#variant_ident),));
1711    }
1712
1713    if as_db_arms.is_empty() {
1714        return Err(syn::Error::new_spanned(enum_ident, "dbkit: DbEnum requires at least one variant"));
1715    }
1716
1717    let type_name_lit = syn::LitStr::new(&type_name, proc_macro2::Span::call_site());
1718    let expected_lit = syn::LitStr::new(&expected_values.join(", "), proc_macro2::Span::call_site());
1719
1720    let tokens = quote! {
1721        impl #enum_ident {
1722            pub const DB_TYPE_NAME: &'static str = #type_name_lit;
1723
1724            pub fn as_db_str(&self) -> &'static str {
1725                match self {
1726                    #(#as_db_arms)*
1727                }
1728            }
1729        }
1730
1731        impl ::std::str::FromStr for #enum_ident {
1732            type Err = String;
1733
1734            fn from_str(value: &str) -> Result<Self, Self::Err> {
1735                match value {
1736                    #(#from_db_arms)*
1737                    _ => Err(format!(
1738                        "dbkit: invalid value `{}` for enum {} (expected one of: {})",
1739                        value,
1740                        stringify!(#enum_ident),
1741                        #expected_lit
1742                    )),
1743                }
1744            }
1745        }
1746
1747        impl From<#enum_ident> for ::dbkit::Value {
1748            fn from(value: #enum_ident) -> Self {
1749                ::dbkit::Value::Enum {
1750                    type_name: #type_name_lit,
1751                    value: value.as_db_str().to_string(),
1752                }
1753            }
1754        }
1755
1756        impl ::dbkit::sqlx::Type<::dbkit::sqlx::Postgres> for #enum_ident {
1757            fn type_info() -> ::dbkit::sqlx::postgres::PgTypeInfo {
1758                ::dbkit::sqlx::postgres::PgTypeInfo::with_name(#type_name_lit)
1759            }
1760
1761            fn compatible(ty: &::dbkit::sqlx::postgres::PgTypeInfo) -> bool {
1762                *ty == ::dbkit::sqlx::postgres::PgTypeInfo::with_name(#type_name_lit)
1763                    || <&str as ::dbkit::sqlx::Type<::dbkit::sqlx::Postgres>>::compatible(ty)
1764            }
1765        }
1766
1767        impl<'q> ::dbkit::sqlx::Encode<'q, ::dbkit::sqlx::Postgres> for #enum_ident {
1768            fn encode_by_ref(
1769                &self,
1770                buf: &mut ::dbkit::sqlx::postgres::PgArgumentBuffer,
1771            ) -> Result<::dbkit::sqlx::encode::IsNull, ::dbkit::sqlx::error::BoxDynError> {
1772                <&str as ::dbkit::sqlx::Encode<'q, ::dbkit::sqlx::Postgres>>::encode(self.as_db_str(), buf)
1773            }
1774
1775            fn produces(&self) -> Option<::dbkit::sqlx::postgres::PgTypeInfo> {
1776                Some(::dbkit::sqlx::postgres::PgTypeInfo::with_name(#type_name_lit))
1777            }
1778
1779            fn size_hint(&self) -> usize {
1780                self.as_db_str().len()
1781            }
1782        }
1783
1784        impl<'r> ::dbkit::sqlx::Decode<'r, ::dbkit::sqlx::Postgres> for #enum_ident {
1785            fn decode(value: ::dbkit::sqlx::postgres::PgValueRef<'r>) -> Result<Self, ::dbkit::sqlx::error::BoxDynError> {
1786                let value = <&str as ::dbkit::sqlx::Decode<'r, ::dbkit::sqlx::Postgres>>::decode(value)?;
1787                <Self as ::std::str::FromStr>::from_str(value).map_err(|err| err.into())
1788            }
1789        }
1790    };
1791
1792    Ok(TokenStream::from(tokens))
1793}
1794
1795fn parse_db_enum_args(attrs: &[Attribute]) -> syn::Result<DbEnumArgs> {
1796    let mut args = DbEnumArgs::default();
1797
1798    for attr in attrs {
1799        if !attr.path().is_ident("dbkit") {
1800            continue;
1801        }
1802        attr.parse_nested_meta(|meta| {
1803            if meta.path.is_ident("type_name") {
1804                let lit: syn::LitStr = meta.value()?.parse()?;
1805                args.type_name = Some(lit.value());
1806                return Ok(());
1807            }
1808            if meta.path.is_ident("rename_all") {
1809                let lit: syn::LitStr = meta.value()?.parse()?;
1810                args.rename_all = Some(lit.value());
1811                return Ok(());
1812            }
1813            Err(meta.error("dbkit: unsupported DbEnum option; expected `type_name` or `rename_all`"))
1814        })?;
1815    }
1816
1817    Ok(args)
1818}
1819
1820fn parse_db_enum_variant_rename(attrs: &[Attribute]) -> syn::Result<Option<String>> {
1821    let mut rename = None;
1822
1823    for attr in attrs {
1824        if !attr.path().is_ident("dbkit") {
1825            continue;
1826        }
1827        attr.parse_nested_meta(|meta| {
1828            if meta.path.is_ident("rename") {
1829                let lit: syn::LitStr = meta.value()?.parse()?;
1830                rename = Some(lit.value());
1831                return Ok(());
1832            }
1833            Err(meta.error("dbkit: unsupported DbEnum variant option; expected `rename`"))
1834        })?;
1835    }
1836
1837    Ok(rename)
1838}
1839
1840fn parse_db_enum_rename_all(value: Option<&str>) -> syn::Result<DbEnumRenameAll> {
1841    match value {
1842        None => Ok(DbEnumRenameAll::AsIs),
1843        Some("snake_case") => Ok(DbEnumRenameAll::SnakeCase),
1844        Some("lowercase") => Ok(DbEnumRenameAll::LowerCase),
1845        Some("UPPERCASE") => Ok(DbEnumRenameAll::UpperCase),
1846        Some("SCREAMING_SNAKE_CASE") => Ok(DbEnumRenameAll::ScreamingSnakeCase),
1847        Some(other) => Err(syn::Error::new(
1848            proc_macro2::Span::call_site(),
1849            format!(
1850                "dbkit: unsupported rename_all strategy `{}` for DbEnum; supported values: snake_case, lowercase, UPPERCASE, SCREAMING_SNAKE_CASE",
1851                other
1852            ),
1853        )),
1854    }
1855}
1856
1857fn apply_db_enum_rename_rule(value: &str, rule: DbEnumRenameAll) -> String {
1858    match rule {
1859        DbEnumRenameAll::AsIs => value.to_string(),
1860        DbEnumRenameAll::SnakeCase => to_snake_case(value),
1861        DbEnumRenameAll::LowerCase => value.to_lowercase(),
1862        DbEnumRenameAll::UpperCase => value.to_uppercase(),
1863        DbEnumRenameAll::ScreamingSnakeCase => to_snake_case(value).to_uppercase(),
1864    }
1865}
1866
/// Unit tests for the case-conversion helpers and the DbEnum `rename_all` handling.
#[cfg(test)]
mod tests {
    use super::{
        apply_db_enum_rename_rule, parse_db_enum_rename_all, to_camel_case, to_snake_case,
        DbEnumRenameAll,
    };

    #[test]
    fn snake_case_respects_acronym_word_boundaries() {
        assert_eq!(to_snake_case("HTTPWebhook"), "http_webhook");
        assert_eq!(to_snake_case("OAuthToken"), "oauth_token");
        assert_eq!(to_snake_case("XMLHttpRequest"), "xml_http_request");
        assert_eq!(to_snake_case("WebhookHTTP"), "webhook_http");
    }

    #[test]
    fn snake_case_handles_digits_underscores_and_empty_input() {
        assert_eq!(to_snake_case(""), "");
        assert_eq!(to_snake_case("already_snake"), "already_snake");
        // An existing underscore must not be doubled.
        assert_eq!(to_snake_case("Foo_Bar"), "foo_bar");
        assert_eq!(to_snake_case("Http2Server"), "http2_server");
        assert_eq!(to_snake_case("ABTest"), "ab_test");
    }

    #[test]
    fn camel_case_capitalizes_each_underscore_separated_word() {
        assert_eq!(to_camel_case("user_account"), "UserAccount");
        assert_eq!(to_camel_case("user"), "User");
        assert_eq!(to_camel_case("__leading__and_trailing_"), "LeadingAndTrailing");
        assert_eq!(to_camel_case(""), "");
    }

    #[test]
    fn screaming_snake_case_respects_acronym_word_boundaries() {
        assert_eq!(
            apply_db_enum_rename_rule("HTTPWebhook", DbEnumRenameAll::ScreamingSnakeCase),
            "HTTP_WEBHOOK"
        );
        assert_eq!(
            apply_db_enum_rename_rule("XMLHttpRequest", DbEnumRenameAll::ScreamingSnakeCase),
            "XML_HTTP_REQUEST"
        );
    }

    #[test]
    fn remaining_rename_rules_transform_as_documented() {
        assert_eq!(apply_db_enum_rename_rule("HTTPWebhook", DbEnumRenameAll::AsIs), "HTTPWebhook");
        assert_eq!(apply_db_enum_rename_rule("HTTPWebhook", DbEnumRenameAll::LowerCase), "httpwebhook");
        assert_eq!(apply_db_enum_rename_rule("HTTPWebhook", DbEnumRenameAll::UpperCase), "HTTPWEBHOOK");
        assert_eq!(apply_db_enum_rename_rule("HTTPWebhook", DbEnumRenameAll::SnakeCase), "http_webhook");
    }

    #[test]
    fn rename_all_accepts_only_supported_strategies() {
        assert!(matches!(parse_db_enum_rename_all(None), Ok(DbEnumRenameAll::AsIs)));
        assert!(matches!(parse_db_enum_rename_all(Some("snake_case")), Ok(DbEnumRenameAll::SnakeCase)));
        assert!(matches!(parse_db_enum_rename_all(Some("lowercase")), Ok(DbEnumRenameAll::LowerCase)));
        assert!(matches!(parse_db_enum_rename_all(Some("UPPERCASE")), Ok(DbEnumRenameAll::UpperCase)));
        assert!(matches!(
            parse_db_enum_rename_all(Some("SCREAMING_SNAKE_CASE")),
            Ok(DbEnumRenameAll::ScreamingSnakeCase)
        ));
        assert!(parse_db_enum_rename_all(Some("kebab-case")).is_err());
        // Matching is case-sensitive.
        assert!(parse_db_enum_rename_all(Some("Snake_Case")).is_err());
    }
}