Skip to main content

ferriorm_codegen/
model.rs

//! Generates per-model Rust modules (struct, filters, data inputs, ordering, CRUD).
//!
//! For each model in the schema, this module produces:
//!
//! - A **data struct** (e.g., `User`) with `sqlx::FromRow` and serde derives.
//! - A **filter submodule** with `WhereInput` and `WhereUniqueInput` types.
//! - A **data submodule** with `CreateInput` and `UpdateInput` types.
//! - An **order submodule** with `OrderByInput`.
//! - An **`Actions` struct** exposing `create`, `find_unique`, `find_many`,
//!   `update`, `delete`, `upsert`, and batch operations.
//! - **Query builder structs** that chain filters, ordering, pagination, and
//!   include clauses before calling `.exec()`.
//! - **Aggregate, group-by, and select** helper types.

14use ferriorm_core::schema::{Field, FieldKind, Model};
15use ferriorm_core::types::ScalarType;
16use ferriorm_core::utils::{to_pascal_case, to_snake_case};
17use proc_macro2::TokenStream;
18use quote::{format_ident, quote};
19
20use crate::rust_type::{ModuleDepth, enum_path, filter_type_tokens, rust_type_tokens};
21
22/// Generate the complete module for a single model.
23#[must_use]
24pub fn generate_model_module(model: &Model) -> TokenStream {
25    let scalar_fields: Vec<&Field> = model.fields.iter().filter(|f| f.is_scalar()).collect();
26
27    let data_struct = gen_data_struct(model, &scalar_fields);
28    let filter_module = gen_filter_module(model, &scalar_fields);
29    let data_module = gen_data_module(model, &scalar_fields);
30    let order_module = gen_order_module(model, &scalar_fields);
31    let actions_struct = gen_actions(model, &scalar_fields);
32    let query_builders = gen_query_builders(model, &scalar_fields);
33    let aggregate_types = gen_aggregate_types(model, &scalar_fields);
34    let groupby_types = gen_groupby_types(model, &scalar_fields);
35    let select_types = gen_select_types(model, &scalar_fields);
36
37    quote! {
38        #![allow(unused_imports, dead_code, unused_variables, clippy::all, clippy::pedantic, clippy::nursery)]
39
40        use serde::{Deserialize, Serialize};
41        use ferriorm_runtime::prelude::*;
42        use ferriorm_runtime::prelude::sqlx;
43        use ferriorm_runtime::prelude::chrono;
44        use ferriorm_runtime::prelude::uuid;
45
46        #data_struct
47        #filter_module
48        #data_module
49        #order_module
50        #actions_struct
51        #query_builders
52        #aggregate_types
53        #groupby_types
54        #select_types
55    }
56}
57
// ─── Data Struct ──────────────────────────────────────────────
59
60fn gen_data_struct(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
61    let struct_name = format_ident!("{}", model.name);
62    let table_name = &model.db_name;
63
64    let fields: Vec<TokenStream> = scalar_fields
65        .iter()
66        .map(|f| {
67            let name = format_ident!("{}", to_snake_case(&f.name));
68            let ty = rust_type_tokens(f, ModuleDepth::TopLevel);
69            let db_name = &f.db_name;
70            if db_name == &to_snake_case(&f.name) {
71                quote! { pub #name: #ty }
72            } else {
73                quote! { #[sqlx(rename = #db_name)] pub #name: #ty }
74            }
75        })
76        .collect();
77
78    quote! {
79        #[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
80        #[sqlx(rename_all = "snake_case")]
81        pub struct #struct_name {
82            #(#fields),*
83        }
84
85        impl #struct_name {
86            pub const TABLE_NAME: &'static str = #table_name;
87        }
88    }
89}
90
// ─── Filter Module ────────────────────────────────────────────
92
93#[allow(clippy::too_many_lines)]
94fn gen_filter_module(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
95    let where_input = format_ident!("{}WhereInput", model.name);
96    let where_unique = format_ident!("{}WhereUniqueInput", model.name);
97
98    let where_fields: Vec<TokenStream> = scalar_fields
99        .iter()
100        .filter_map(|f| {
101            let filter_ty = filter_type_tokens(f, ModuleDepth::Nested)?;
102            let name = format_ident!("{}", to_snake_case(&f.name));
103            Some(quote! { pub #name: Option<#filter_ty> })
104        })
105        .collect();
106
107    let single_unique_variants: Vec<TokenStream> = scalar_fields
108        .iter()
109        .filter(|f| f.is_id || f.is_unique)
110        .map(|f| {
111            let variant = format_ident!("{}", to_pascal_case(&f.name));
112            let ty = rust_type_tokens(f, ModuleDepth::Nested);
113            quote! { #variant(#ty) }
114        })
115        .collect();
116
117    let compound_unique_variants: Vec<TokenStream> = model
118        .unique_constraints
119        .iter()
120        .map(|uc| {
121            let variant = format_ident!("{}", compound_variant_name(&uc.fields));
122            let struct_fields = compound_variant_fields(model, &uc.fields);
123            quote! { #variant { #(#struct_fields),* } }
124        })
125        .collect();
126
127    let unique_variants: Vec<TokenStream> = single_unique_variants
128        .into_iter()
129        .chain(compound_unique_variants)
130        .collect();
131
132    // Generate build_where for WhereInput
133    let db_bounds = collect_db_bounds(scalar_fields, ModuleDepth::Nested);
134    let where_arms = gen_where_arms(scalar_fields);
135    let unique_arms = gen_unique_where_arms(model, scalar_fields);
136    let conflict_target_arms = gen_conflict_target_arms(model, scalar_fields);
137    let first_conflict_col_arms = gen_first_conflict_col_arms(model, scalar_fields);
138
139    quote! {
140        pub mod filter {
141            use ferriorm_runtime::prelude::*;
142
143            #[derive(Debug, Clone, Default)]
144            pub struct #where_input {
145                #(#where_fields,)*
146                pub and: Option<Vec<#where_input>>,
147                pub or: Option<Vec<#where_input>>,
148                pub not: Option<Box<#where_input>>,
149            }
150
151            #[derive(Debug, Clone)]
152            pub enum #where_unique {
153                #(#unique_variants),*
154            }
155
156            impl #where_input {
157                pub(crate) fn build_where<'args, DB: sqlx::Database>(
158                    &self,
159                    qb: &mut sqlx::QueryBuilder<'args, DB>,
160                )
161                where
162                    #(#db_bounds,)*
163                {
164                    #(#where_arms)*
165
166                    if let Some(conditions) = &self.and {
167                        for c in conditions {
168                            c.build_where(qb);
169                        }
170                    }
171                    if let Some(conditions) = &self.or {
172                        if !conditions.is_empty() {
173                            qb.push(" AND (");
174                            for (i, c) in conditions.iter().enumerate() {
175                                if i > 0 { qb.push(" OR "); }
176                                qb.push("(1=1");
177                                c.build_where(qb);
178                                qb.push(")");
179                            }
180                            qb.push(")");
181                        }
182                    }
183                    if let Some(c) = &self.not {
184                        qb.push(" AND NOT (1=1");
185                        c.build_where(qb);
186                        qb.push(")");
187                    }
188                }
189            }
190
191            impl #where_unique {
192                pub(crate) fn build_where<'args, DB: sqlx::Database>(
193                    &self,
194                    qb: &mut sqlx::QueryBuilder<'args, DB>,
195                )
196                where
197                    #(#db_bounds,)*
198                {
199                    match self {
200                        #(#unique_arms)*
201                    }
202                }
203            }
204
205            impl #where_unique {
206                #[allow(dead_code)]
207                pub(crate) fn conflict_target(&self) -> &'static str {
208                    match self {
209                        #(#conflict_target_arms)*
210                    }
211                }
212
213                #[allow(dead_code)]
214                pub(crate) fn first_conflict_col(&self) -> &'static str {
215                    match self {
216                        #(#first_conflict_col_arms)*
217                    }
218                }
219            }
220        }
221    }
222}
223
224/// Collect the sqlx type bounds needed for all scalar types used by the model.
225/// `depth` controls the module-path resolution for enum types (the only
226/// bounds that depend on where the `impl` block lives).
227fn collect_db_bounds(scalar_fields: &[&Field], depth: ModuleDepth) -> Vec<TokenStream> {
228    let mut seen = std::collections::HashSet::new();
229    let mut seen_enums = std::collections::HashSet::new();
230    let mut bounds = Vec::new();
231
232    // Always need i64 for LIMIT/OFFSET
233    seen.insert("i64");
234    bounds.push(quote! { i64: sqlx::Type<DB> + for<'e> sqlx::Encode<'e, DB> });
235
236    for f in scalar_fields {
237        match &f.field_type {
238            FieldKind::Scalar(scalar) => {
239                let key = scalar.rust_type();
240                if seen.insert(key)
241                    && let Some(ty) = scalar_bound_tokens(scalar)
242                {
243                    bounds.push(quote! { #ty: sqlx::Type<DB> + for<'e> sqlx::Encode<'e, DB> });
244                    // Also add Option<T> bound for nullable field support
245                    bounds.push(
246                        quote! { Option<#ty>: sqlx::Type<DB> + for<'e> sqlx::Encode<'e, DB> },
247                    );
248                }
249            }
250            FieldKind::Enum(name) => {
251                // Enum WHERE arms call qb.push_bind on enum values, so the
252                // generated impl needs `EnumTy: sqlx::Type<DB> + Encode<DB>`.
253                if seen_enums.insert(name.clone()) {
254                    let enum_ty = enum_path(name, depth);
255                    bounds.push(quote! { #enum_ty: sqlx::Type<DB> + for<'e> sqlx::Encode<'e, DB> });
256                }
257            }
258            FieldKind::Model(_) => {}
259        }
260    }
261
262    bounds
263}
264
265fn scalar_bound_tokens(scalar: &ScalarType) -> Option<TokenStream> {
266    match scalar {
267        ScalarType::String => Some(quote! { String }),
268        ScalarType::Int => Some(quote! { i32 }),
269        ScalarType::BigInt => Some(quote! { i64 }),
270        ScalarType::Float => Some(quote! { f64 }),
271        ScalarType::Boolean => Some(quote! { bool }),
272        ScalarType::DateTime => Some(quote! { chrono::DateTime<chrono::Utc> }),
273        ScalarType::Bytes => Some(quote! { Vec<u8> }),
274        ScalarType::Json | ScalarType::Decimal => None,
275    }
276}
277
278/// Generate `IN (...)` / `NOT IN (...)` arms for `filter.r#in` and
279/// `filter.not_in`. `lhs` is the SQL expression on the left-hand side --
280/// a quoted column for `WhereInput` or an aggregate expression like
281/// `AVG("col")` for `HavingInput`. Empty `r#in` emits the portable
282/// `1 = 0` form (Postgres rejects bare `IN ()`); empty `not_in` is
283/// dropped (vacuously true).
284fn gen_in_arms_lhs(lhs: &str) -> TokenStream {
285    let in_prefix = format!(" AND {lhs} IN (");
286    let not_in_prefix = format!(" AND {lhs} NOT IN (");
287    quote! {
288        if let Some(values) = &filter.r#in {
289            if values.is_empty() {
290                qb.push(" AND 1 = 0");
291            } else {
292                qb.push(#in_prefix);
293                {
294                    let mut sep = qb.separated(", ");
295                    for v in values {
296                        sep.push_bind(v.clone());
297                    }
298                }
299                qb.push(")");
300            }
301        }
302        if let Some(values) = &filter.not_in {
303            if !values.is_empty() {
304                qb.push(#not_in_prefix);
305                {
306                    let mut sep = qb.separated(", ");
307                    for v in values {
308                        sep.push_bind(v.clone());
309                    }
310                }
311                qb.push(")");
312            }
313        }
314    }
315}
316
317/// Generate where-clause arms for each filterable scalar or enum field.
318#[allow(clippy::too_many_lines)]
319fn gen_where_arms(scalar_fields: &[&Field]) -> Vec<TokenStream> {
320    scalar_fields
321        .iter()
322        .filter_map(|f| {
323            let field_ident = format_ident!("{}", to_snake_case(&f.name));
324            let db_name = &f.db_name;
325            let column_lhs = format!("\"{db_name}\"");
326
327            match &f.field_type {
328                FieldKind::Scalar(scalar) => {
329                    // Skip non-filterable scalars (Json/Bytes/Decimal); they
330                    // have no filter struct field, so there's nothing to bind.
331                    if matches!(
332                        scalar,
333                        ScalarType::Json | ScalarType::Bytes | ScalarType::Decimal
334                    ) {
335                        return None;
336                    }
337                    let is_string = matches!(scalar, ScalarType::String);
338                    let is_comparable = matches!(
339                        scalar,
340                        ScalarType::Int
341                            | ScalarType::BigInt
342                            | ScalarType::Float
343                            | ScalarType::DateTime
344                    );
345                    // BoolFilter has no `r#in`/`not_in` -- IN over booleans is
346                    // not useful.
347                    let supports_in = !matches!(scalar, ScalarType::Boolean);
348
349                    let mut arms: Vec<TokenStream> = Vec::new();
350
351                    if f.is_optional {
352                        // Nullable filter: `equals`/`not` are `Option<Option<T>>`.
353                        // `Some(None)` means IS NULL / IS NOT NULL; `Some(Some(v))`
354                        // is the ordinary `= ?` / `!= ?` comparison.
355                        arms.push(quote! {
356                            if let Some(v) = &filter.equals {
357                                match v {
358                                    None => {
359                                        qb.push(concat!(" AND \"", #db_name, "\" IS NULL"));
360                                    }
361                                    Some(inner) => {
362                                        qb.push(concat!(" AND \"", #db_name, "\" = "));
363                                        qb.push_bind(inner.clone());
364                                    }
365                                }
366                            }
367                            if let Some(v) = &filter.not {
368                                match v {
369                                    None => {
370                                        qb.push(concat!(" AND \"", #db_name, "\" IS NOT NULL"));
371                                    }
372                                    Some(inner) => {
373                                        qb.push(concat!(" AND \"", #db_name, "\" != "));
374                                        qb.push_bind(inner.clone());
375                                    }
376                                }
377                            }
378                        });
379                    } else {
380                        arms.push(quote! {
381                            if let Some(v) = &filter.equals {
382                                qb.push(concat!(" AND \"", #db_name, "\" = "));
383                                qb.push_bind(v.clone());
384                            }
385                            if let Some(v) = &filter.not {
386                                qb.push(concat!(" AND \"", #db_name, "\" != "));
387                                qb.push_bind(v.clone());
388                            }
389                        });
390                    }
391
392                    if is_string {
393                        // `like_escape` quotes %, _, and \ in user input so they
394                        // match themselves; `ESCAPE '\\'` tells the DB to treat
395                        // backslash as the escape character. Without this the
396                        // query `contains: "100%_safe"` would also match
397                        // arbitrary strings like `100Xsafe`.
398                        arms.push(quote! {
399                            if let Some(v) = &filter.contains {
400                                qb.push(concat!(" AND \"", #db_name, "\" LIKE "));
401                                qb.push_bind(format!("%{}%", ferriorm_runtime::filter::like_escape(v)));
402                                qb.push(" ESCAPE '\\'");
403                            }
404                            if let Some(v) = &filter.starts_with {
405                                qb.push(concat!(" AND \"", #db_name, "\" LIKE "));
406                                qb.push_bind(format!("{}%", ferriorm_runtime::filter::like_escape(v)));
407                                qb.push(" ESCAPE '\\'");
408                            }
409                            if let Some(v) = &filter.ends_with {
410                                qb.push(concat!(" AND \"", #db_name, "\" LIKE "));
411                                qb.push_bind(format!("%{}", ferriorm_runtime::filter::like_escape(v)));
412                                qb.push(" ESCAPE '\\'");
413                            }
414                        });
415                    }
416
417                    if is_comparable {
418                        arms.push(quote! {
419                            if let Some(v) = &filter.gt {
420                                qb.push(concat!(" AND \"", #db_name, "\" > "));
421                                qb.push_bind(v.clone());
422                            }
423                            if let Some(v) = &filter.gte {
424                                qb.push(concat!(" AND \"", #db_name, "\" >= "));
425                                qb.push_bind(v.clone());
426                            }
427                            if let Some(v) = &filter.lt {
428                                qb.push(concat!(" AND \"", #db_name, "\" < "));
429                                qb.push_bind(v.clone());
430                            }
431                            if let Some(v) = &filter.lte {
432                                qb.push(concat!(" AND \"", #db_name, "\" <= "));
433                                qb.push_bind(v.clone());
434                            }
435                        });
436                    }
437
438                    if supports_in {
439                        arms.push(gen_in_arms_lhs(&column_lhs));
440                    }
441
442                    Some(quote! {
443                        if let Some(filter) = &self.#field_ident {
444                            #(#arms)*
445                        }
446                    })
447                }
448                FieldKind::Enum(_) => {
449                    // EnumFilter exposes equals/not/in/not_in. There is no
450                    // NullableEnumFilter, so IS NULL handling on optional
451                    // enums is not yet supported here.
452                    let in_arms = gen_in_arms_lhs(&column_lhs);
453                    Some(quote! {
454                        if let Some(filter) = &self.#field_ident {
455                            if let Some(v) = &filter.equals {
456                                qb.push(concat!(" AND \"", #db_name, "\" = "));
457                                qb.push_bind(v.clone());
458                            }
459                            if let Some(v) = &filter.not {
460                                qb.push(concat!(" AND \"", #db_name, "\" != "));
461                                qb.push_bind(v.clone());
462                            }
463                            #in_arms
464                        }
465                    })
466                }
467                FieldKind::Model(_) => None,
468            }
469        })
470        .collect()
471}
472
473fn gen_unique_where_arms(model: &Model, scalar_fields: &[&Field]) -> Vec<TokenStream> {
474    let mut arms: Vec<TokenStream> = scalar_fields
475        .iter()
476        .filter(|f| f.is_id || f.is_unique)
477        .map(|f| {
478            let variant = format_ident!("{}", to_pascal_case(&f.name));
479            let db_name = &f.db_name;
480            quote! {
481                Self::#variant(v) => {
482                    qb.push(concat!(" AND \"", #db_name, "\" = "));
483                    qb.push_bind(v.clone());
484                }
485            }
486        })
487        .collect();
488
489    for uc in &model.unique_constraints {
490        let variant = format_ident!("{}", compound_variant_name(&uc.fields));
491        let idents: Vec<_> = uc
492            .fields
493            .iter()
494            .map(|name| format_ident!("{}", to_snake_case(name)))
495            .collect();
496        let binds: Vec<TokenStream> = uc
497            .fields
498            .iter()
499            .map(|name| {
500                let ident = format_ident!("{}", to_snake_case(name));
501                let db_name = resolve_db_name(model, name);
502                quote! {
503                    qb.push(concat!(" AND \"", #db_name, "\" = "));
504                    qb.push_bind(#ident.clone());
505                }
506            })
507            .collect();
508        arms.push(quote! {
509            Self::#variant { #(#idents),* } => {
510                #(#binds)*
511            }
512        });
513    }
514
515    arms
516}
517
518fn gen_conflict_target_arms(model: &Model, scalar_fields: &[&Field]) -> Vec<TokenStream> {
519    let mut arms: Vec<TokenStream> = scalar_fields
520        .iter()
521        .filter(|f| f.is_id || f.is_unique)
522        .map(|f| {
523            let variant = format_ident!("{}", to_pascal_case(&f.name));
524            let target = format!("(\"{}\")", f.db_name);
525            quote! { Self::#variant(_) => #target, }
526        })
527        .collect();
528
529    for uc in &model.unique_constraints {
530        let variant = format_ident!("{}", compound_variant_name(&uc.fields));
531        let cols: Vec<String> = uc
532            .fields
533            .iter()
534            .map(|n| format!("\"{}\"", resolve_db_name(model, n)))
535            .collect();
536        let target = format!("({})", cols.join(", "));
537        arms.push(quote! { Self::#variant { .. } => #target, });
538    }
539
540    arms
541}
542
543fn gen_first_conflict_col_arms(model: &Model, scalar_fields: &[&Field]) -> Vec<TokenStream> {
544    let mut arms: Vec<TokenStream> = scalar_fields
545        .iter()
546        .filter(|f| f.is_id || f.is_unique)
547        .map(|f| {
548            let variant = format_ident!("{}", to_pascal_case(&f.name));
549            let col = format!("\"{}\"", f.db_name);
550            quote! { Self::#variant(_) => #col, }
551        })
552        .collect();
553
554    for uc in &model.unique_constraints {
555        let variant = format_ident!("{}", compound_variant_name(&uc.fields));
556        let first = uc
557            .fields
558            .first()
559            .map_or_else(String::new, |n| resolve_db_name(model, n));
560        let col = format!("\"{first}\"");
561        arms.push(quote! { Self::#variant { .. } => #col, });
562    }
563
564    arms
565}
566
567/// `PascalCase` concatenation of the field names.
568fn compound_variant_name(fields: &[String]) -> String {
569    fields.iter().map(|f| to_pascal_case(f)).collect()
570}
571
572/// Struct-field tokens for a compound variant: `ident: Ty` per field.
573/// Enum struct-variant fields inherit the enum's visibility; no `pub` here.
574fn compound_variant_fields(model: &Model, fields: &[String]) -> Vec<TokenStream> {
575    fields
576        .iter()
577        .filter_map(|field_name| {
578            let field = model.fields.iter().find(|f| f.name == *field_name)?;
579            let ident = format_ident!("{}", to_snake_case(field_name));
580            let ty = rust_type_tokens(field, ModuleDepth::Nested);
581            Some(quote! { #ident: #ty })
582        })
583        .collect()
584}
585
586/// Resolve a schema field name to its `db_name`, falling back to `snake_case`.
587fn resolve_db_name(model: &Model, field_name: &str) -> String {
588    model
589        .fields
590        .iter()
591        .find(|f| f.name == field_name)
592        .map_or_else(|| to_snake_case(field_name), |f| f.db_name.clone())
593}
594
// ─── Data Module ──────────────────────────────────────────────
596
597fn gen_data_module(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
598    let create_name = format_ident!("{}CreateInput", model.name);
599    let update_name = format_ident!("{}UpdateInput", model.name);
600
601    let required_fields: Vec<TokenStream> = scalar_fields
602        .iter()
603        .filter(|f| !f.has_default() && !f.is_updated_at)
604        .map(|f| {
605            let name = format_ident!("{}", to_snake_case(&f.name));
606            let ty = rust_type_tokens(f, ModuleDepth::Nested);
607            quote! { pub #name: #ty }
608        })
609        .collect();
610
611    let optional_fields: Vec<TokenStream> = scalar_fields
612        .iter()
613        .filter(|f| f.has_default() && !f.is_updated_at)
614        .map(|f| {
615            let name = format_ident!("{}", to_snake_case(&f.name));
616            let base_ty = rust_type_tokens(f, ModuleDepth::Nested);
617            quote! { pub #name: Option<#base_ty> }
618        })
619        .collect();
620
621    let update_fields: Vec<TokenStream> = scalar_fields
622        .iter()
623        .filter(|f| !f.is_id && !f.is_updated_at)
624        .map(|f| {
625            let name = format_ident!("{}", to_snake_case(&f.name));
626            let ty = rust_type_tokens(f, ModuleDepth::Nested);
627            quote! { pub #name: Option<SetValue<#ty>> }
628        })
629        .collect();
630
631    quote! {
632        pub mod data {
633            use ferriorm_runtime::prelude::*;
634
635            #[derive(Debug, Clone)]
636            pub struct #create_name {
637                #(#required_fields,)*
638                #(#optional_fields,)*
639            }
640
641            /// Update payload. Each field is `Option<SetValue<T>>`:
642            /// `None` leaves the column untouched (omitted from the SET clause),
643            /// `Some(SetValue::Set(v))` writes `v`.
644            #[derive(Debug, Clone, Default)]
645            pub struct #update_name {
646                #(#update_fields,)*
647            }
648        }
649    }
650}
651
// ─── Order Module ─────────────────────────────────────────────
653
654fn gen_order_module(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
655    let order_name = format_ident!("{}OrderByInput", model.name);
656
657    let variants: Vec<TokenStream> = scalar_fields
658        .iter()
659        .map(|f| {
660            let variant = format_ident!("{}", to_pascal_case(&f.name));
661            quote! { #variant(SortOrder) }
662        })
663        .collect();
664
665    let order_arms: Vec<TokenStream> = scalar_fields
666        .iter()
667        .map(|f| {
668            let variant = format_ident!("{}", to_pascal_case(&f.name));
669            let db_name = &f.db_name;
670            quote! {
671                Self::#variant(order) => {
672                    qb.push(concat!("\"", #db_name, "\" "));
673                    qb.push(order.as_sql());
674                }
675            }
676        })
677        .collect();
678
679    quote! {
680        pub mod order {
681            use ferriorm_runtime::prelude::*;
682
683            #[derive(Debug, Clone)]
684            pub enum #order_name {
685                #(#variants),*
686            }
687
688            impl #order_name {
689                pub(crate) fn build_order_by<'args, DB: sqlx::Database>(
690                    &self,
691                    qb: &mut sqlx::QueryBuilder<'args, DB>,
692                ) {
693                    match self {
694                        #(#order_arms)*
695                    }
696                }
697            }
698        }
699    }
700}
701
// ─── Actions ──────────────────────────────────────────────────
703
704fn gen_actions(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
705    let _model_ident = format_ident!("{}", model.name);
706    let actions_name = format_ident!("{}Actions", model.name);
707    let where_input = format_ident!("{}WhereInput", model.name);
708    let where_unique = format_ident!("{}WhereUniqueInput", model.name);
709    let create_input = format_ident!("{}CreateInput", model.name);
710    let update_input = format_ident!("{}UpdateInput", model.name);
711    let _order_by = format_ident!("{}OrderByInput", model.name);
712
713    // Only generate aggregate() if there are aggregatable fields
714    let has_agg_fields = scalar_fields.iter().any(|f| {
715        matches!(
716            &f.field_type,
717            FieldKind::Scalar(
718                ScalarType::Int | ScalarType::BigInt | ScalarType::Float | ScalarType::DateTime
719            )
720        )
721    });
722    let aggregate_method = if has_agg_fields {
723        quote! {
724            pub fn aggregate(&self, r#where: filter::#where_input) -> AggregateQuery<'a> {
725                AggregateQuery { client: self.client, r#where, ops: vec![] }
726            }
727        }
728    } else {
729        quote! {}
730    };
731
732    // Only generate group_by() if there are groupable fields
733    let has_group_fields = scalar_fields.iter().any(|f| is_groupable(f));
734    let groupby_field_name = format_ident!("{}GroupByField", model.name);
735    let group_by_method = if has_group_fields {
736        quote! {
737            pub fn group_by(&self, keys: Vec<#groupby_field_name>) -> GroupByQuery<'a> {
738                GroupByQuery {
739                    client: self.client,
740                    r#where: filter::#where_input::default(),
741                    group_keys: keys,
742                    agg_ops: vec![],
743                    count: false,
744                    having: None,
745                }
746            }
747        }
748    } else {
749        quote! {}
750    };
751
752    quote! {
753        pub struct #actions_name<'a> {
754            client: &'a DatabaseClient,
755        }
756
757        impl<'a> #actions_name<'a> {
758            pub fn new(client: &'a DatabaseClient) -> Self { Self { client } }
759
760            pub fn find_unique(&self, r#where: filter::#where_unique) -> FindUniqueQuery<'a> {
761                FindUniqueQuery { client: self.client, r#where }
762            }
763
764            pub fn find_first(&self, r#where: filter::#where_input) -> FindFirstQuery<'a> {
765                FindFirstQuery { client: self.client, r#where, order_by: vec![] }
766            }
767
768            pub fn find_many(&self, r#where: filter::#where_input) -> FindManyQuery<'a> {
769                FindManyQuery { client: self.client, r#where, order_by: vec![], skip: None, take: None }
770            }
771
772            pub fn create(&self, data: data::#create_input) -> CreateQuery<'a> {
773                CreateQuery { client: self.client, data }
774            }
775
776            pub fn update(&self, r#where: filter::#where_unique, data: data::#update_input) -> UpdateQuery<'a> {
777                UpdateQuery { client: self.client, r#where, data }
778            }
779
780            /// Like [`update`], but accepts a full `WhereInput` so additional
781            /// predicates (e.g., `status = 'pending'`) can be used for
782            /// compare-and-swap updates. Returns `Ok(None)` if no row matched.
783            pub fn update_first(&self, r#where: filter::#where_input, data: data::#update_input) -> UpdateFirstQuery<'a> {
784                UpdateFirstQuery { client: self.client, r#where, data }
785            }
786
787            pub fn delete(&self, r#where: filter::#where_unique) -> DeleteQuery<'a> {
788                DeleteQuery { client: self.client, r#where }
789            }
790
791            pub fn count(&self, r#where: filter::#where_input) -> CountQuery<'a> {
792                CountQuery { client: self.client, r#where }
793            }
794
795            pub fn create_many(&self, data: Vec<data::#create_input>) -> CreateManyQuery<'a> {
796                CreateManyQuery { client: self.client, data }
797            }
798
799            pub fn update_many(&self, r#where: filter::#where_input, data: data::#update_input) -> UpdateManyQuery<'a> {
800                UpdateManyQuery { client: self.client, r#where, data }
801            }
802
803            pub fn delete_many(&self, r#where: filter::#where_input) -> DeleteManyQuery<'a> {
804                DeleteManyQuery { client: self.client, r#where }
805            }
806
807            pub fn upsert(
808                &self,
809                r#where: filter::#where_unique,
810                create: data::#create_input,
811                update: data::#update_input,
812            ) -> UpsertQuery<'a> {
813                UpsertQuery { client: self.client, r#where, create, update }
814            }
815
816            #aggregate_method
817
818            #group_by_method
819        }
820    }
821}
822
823// ─── Query Builders with exec() ──────────────────────────────
824
825#[allow(clippy::too_many_lines)]
826fn gen_query_builders(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
827    let model_ident = format_ident!("{}", model.name);
828    let table_name = &model.db_name;
829    let _where_input = format_ident!("{}WhereInput", model.name);
830    let _where_unique = format_ident!("{}WhereUniqueInput", model.name);
831    let _create_input = format_ident!("{}CreateInput", model.name);
832    let _update_input = format_ident!("{}UpdateInput", model.name);
833    let order_by = format_ident!("{}OrderByInput", model.name);
834    let _select_struct = format_ident!("{}Select", model.name);
835    let _partial_struct = format_ident!("{}Partial", model.name);
836    let _aggregate_result = format_ident!("{}AggregateResult", model.name);
837    let _aggregate_field = format_ident!("{}AggregateField", model.name);
838    let db_bounds = collect_db_bounds(scalar_fields, ModuleDepth::TopLevel);
839
840    let select_sql = format!(r#"SELECT * FROM "{table_name}" WHERE 1=1"#);
841    let count_sql = format!(r#"SELECT COUNT(*) as "count" FROM "{table_name}" WHERE 1=1"#);
842    let delete_sql = format!(r#"DELETE FROM "{table_name}" WHERE 1=1"#);
843
844    let insert_code = gen_insert_code(model, scalar_fields, table_name);
845    let insert_ignore_code = gen_insert_ignore_code(model, scalar_fields, table_name);
846    let update_code = gen_update_code(model, scalar_fields, table_name);
847    let update_first_code = gen_update_first_code(model, scalar_fields, table_name);
848    let update_many_code = gen_update_many_code(model, scalar_fields, table_name);
849    let upsert_code = gen_upsert_code(model, scalar_fields, table_name);
850
851    quote! {
852        // ── Generic helper: build ORDER BY clause ──────────────
853        fn build_order_by<'args, DB: sqlx::Database>(
854            orders: &[order::#order_by],
855            qb: &mut sqlx::QueryBuilder<'args, DB>,
856        ) {
857            if !orders.is_empty() {
858                qb.push(" ORDER BY ");
859                for (i, ob) in orders.iter().enumerate() {
860                    if i > 0 { qb.push(", "); }
861                    ob.build_order_by(qb);
862                }
863            }
864        }
865
866        // ── Generic helper: build a SELECT query ───────────────
867        fn build_select_query<'args, DB: sqlx::Database>(
868            base_sql: &str,
869            where_input: &filter::#_where_input,
870            orders: &[order::#order_by],
871            take: Option<i64>,
872            skip: Option<i64>,
873        ) -> sqlx::QueryBuilder<'args, DB>
874        where
875            #(#db_bounds,)*
876        {
877            let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
878            where_input.build_where(&mut qb);
879            build_order_by(orders, &mut qb);
880            if let Some(take) = take {
881                qb.push(" LIMIT ");
882                qb.push_bind(take);
883            }
884            if let Some(skip) = skip {
885                qb.push(" OFFSET ");
886                qb.push_bind(skip);
887            }
888            qb
889        }
890
891        // ── Generic helper: build a SELECT query for unique lookup ──
892        fn build_unique_select_query<'args, DB: sqlx::Database>(
893            base_sql: &str,
894            where_unique: &filter::#_where_unique,
895        ) -> sqlx::QueryBuilder<'args, DB>
896        where
897            #(#db_bounds,)*
898        {
899            let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
900            where_unique.build_where(&mut qb);
901            qb.push(" LIMIT 1");
902            qb
903        }
904
905        // ── Generic helper: build a DELETE-returning query ─────
906        fn build_delete_query<'args, DB: sqlx::Database>(
907            base_sql: &str,
908            where_unique: &filter::#_where_unique,
909        ) -> sqlx::QueryBuilder<'args, DB>
910        where
911            #(#db_bounds,)*
912        {
913            let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
914            where_unique.build_where(&mut qb);
915            qb.push(" RETURNING *");
916            qb
917        }
918
919        // ── Generic helper: build a COUNT query ────────────────
920        fn build_count_query<'args, DB: sqlx::Database>(
921            base_sql: &str,
922            where_input: &filter::#_where_input,
923        ) -> sqlx::QueryBuilder<'args, DB>
924        where
925            #(#db_bounds,)*
926        {
927            let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
928            where_input.build_where(&mut qb);
929            qb
930        }
931
932        // ── Generic helper: build a DELETE-many query ──────────
933        fn build_delete_many_query<'args, DB: sqlx::Database>(
934            base_sql: &str,
935            where_input: &filter::#_where_input,
936        ) -> sqlx::QueryBuilder<'args, DB>
937        where
938            #(#db_bounds,)*
939        {
940            let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
941            where_input.build_where(&mut qb);
942            qb
943        }
944
945        pub struct FindUniqueQuery<'a> {
946            client: &'a DatabaseClient,
947            r#where: filter::#_where_unique,
948        }
949
950        impl<'a> FindUniqueQuery<'a> {
951            pub fn select(self, select: #_select_struct) -> FindUniqueSelectQuery<'a> {
952                FindUniqueSelectQuery { client: self.client, r#where: self.r#where, select }
953            }
954
955            pub async fn exec(self) -> Result<Option<#model_ident>, FerriormError> {
956                match self.client {
957                    DatabaseClient::Postgres(_) => {
958                        let qb = build_unique_select_query::<sqlx::Postgres>(#select_sql, &self.r#where);
959                        self.client.fetch_optional_pg(qb).await
960                    }
961                    DatabaseClient::Sqlite(_) => {
962                        let qb = build_unique_select_query::<sqlx::Sqlite>(#select_sql, &self.r#where);
963                        self.client.fetch_optional_sqlite(qb).await
964                    }
965                }
966            }
967        }
968
969        pub struct FindFirstQuery<'a> {
970            client: &'a DatabaseClient,
971            r#where: filter::#_where_input,
972            order_by: Vec<order::#order_by>,
973        }
974
975        impl<'a> FindFirstQuery<'a> {
976            pub fn order_by(mut self, order: order::#order_by) -> Self {
977                self.order_by.push(order);
978                self
979            }
980
981            pub fn select(self, select: #_select_struct) -> FindFirstSelectQuery<'a> {
982                FindFirstSelectQuery { client: self.client, r#where: self.r#where, order_by: self.order_by, select }
983            }
984
985            pub async fn exec(self) -> Result<Option<#model_ident>, FerriormError> {
986                match self.client {
987                    DatabaseClient::Postgres(_) => {
988                        let qb = build_select_query::<sqlx::Postgres>(#select_sql, &self.r#where, &self.order_by, Some(1), None);
989                        self.client.fetch_optional_pg(qb).await
990                    }
991                    DatabaseClient::Sqlite(_) => {
992                        let qb = build_select_query::<sqlx::Sqlite>(#select_sql, &self.r#where, &self.order_by, Some(1), None);
993                        self.client.fetch_optional_sqlite(qb).await
994                    }
995                }
996            }
997        }
998
999        pub struct FindManyQuery<'a> {
1000            client: &'a DatabaseClient,
1001            r#where: filter::#_where_input,
1002            order_by: Vec<order::#order_by>,
1003            skip: Option<i64>,
1004            take: Option<i64>,
1005        }
1006
1007        impl<'a> FindManyQuery<'a> {
1008            pub fn order_by(mut self, order: order::#order_by) -> Self {
1009                self.order_by.push(order);
1010                self
1011            }
1012
1013            pub fn skip(mut self, n: i64) -> Self {
1014                self.skip = Some(n);
1015                self
1016            }
1017
1018            pub fn take(mut self, n: i64) -> Self {
1019                self.take = Some(n);
1020                self
1021            }
1022
1023            pub fn select(self, select: #_select_struct) -> FindManySelectQuery<'a> {
1024                FindManySelectQuery {
1025                    client: self.client,
1026                    r#where: self.r#where,
1027                    order_by: self.order_by,
1028                    skip: self.skip,
1029                    take: self.take,
1030                    select,
1031                }
1032            }
1033
1034            pub async fn exec(self) -> Result<Vec<#model_ident>, FerriormError> {
1035                match self.client {
1036                    DatabaseClient::Postgres(_) => {
1037                        let qb = build_select_query::<sqlx::Postgres>(#select_sql, &self.r#where, &self.order_by, self.take, self.skip);
1038                        self.client.fetch_all_pg(qb).await
1039                    }
1040                    DatabaseClient::Sqlite(_) => {
1041                        let qb = build_select_query::<sqlx::Sqlite>(#select_sql, &self.r#where, &self.order_by, self.take, self.skip);
1042                        self.client.fetch_all_sqlite(qb).await
1043                    }
1044                }
1045            }
1046        }
1047
1048        pub struct CreateQuery<'a> {
1049            client: &'a DatabaseClient,
1050            data: data::#_create_input,
1051        }
1052
1053        impl<'a> CreateQuery<'a> {
1054            pub async fn exec(self) -> Result<#model_ident, FerriormError> {
1055                let client = self.client;
1056                #insert_code
1057            }
1058
1059            /// Switch the insert into "ignore on conflict" mode:
1060            /// PostgreSQL uses `ON CONFLICT DO NOTHING`, SQLite uses `INSERT OR IGNORE`.
1061            /// Returns `Ok(None)` when a conflict suppressed the insert.
1062            pub fn on_conflict_ignore(self) -> CreateIgnoreQuery<'a> {
1063                CreateIgnoreQuery { client: self.client, data: self.data }
1064            }
1065        }
1066
1067        pub struct CreateIgnoreQuery<'a> {
1068            client: &'a DatabaseClient,
1069            data: data::#_create_input,
1070        }
1071
1072        impl<'a> CreateIgnoreQuery<'a> {
1073            pub async fn exec(self) -> Result<Option<#model_ident>, FerriormError> {
1074                let client = self.client;
1075                #insert_ignore_code
1076            }
1077        }
1078
1079        pub struct UpdateQuery<'a> {
1080            client: &'a DatabaseClient,
1081            r#where: filter::#_where_unique,
1082            data: data::#_update_input,
1083        }
1084
1085        impl<'a> UpdateQuery<'a> {
1086            pub async fn exec(self) -> Result<#model_ident, FerriormError> {
1087                let client = self.client;
1088                #update_code
1089            }
1090        }
1091
1092        pub struct UpdateFirstQuery<'a> {
1093            client: &'a DatabaseClient,
1094            r#where: filter::#_where_input,
1095            data: data::#_update_input,
1096        }
1097
1098        impl<'a> UpdateFirstQuery<'a> {
1099            pub async fn exec(self) -> Result<Option<#model_ident>, FerriormError> {
1100                let client = self.client;
1101                #update_first_code
1102            }
1103        }
1104
1105        pub struct DeleteQuery<'a> {
1106            client: &'a DatabaseClient,
1107            r#where: filter::#_where_unique,
1108        }
1109
1110        impl<'a> DeleteQuery<'a> {
1111            pub async fn exec(self) -> Result<#model_ident, FerriormError> {
1112                match self.client {
1113                    DatabaseClient::Postgres(_) => {
1114                        let qb = build_delete_query::<sqlx::Postgres>(#delete_sql, &self.r#where);
1115                        self.client.fetch_one_pg(qb).await
1116                    }
1117                    DatabaseClient::Sqlite(_) => {
1118                        let qb = build_delete_query::<sqlx::Sqlite>(#delete_sql, &self.r#where);
1119                        self.client.fetch_one_sqlite(qb).await
1120                    }
1121                }
1122            }
1123        }
1124
1125        #[derive(sqlx::FromRow)]
1126        struct CountResult { count: i64 }
1127
1128        pub struct CountQuery<'a> {
1129            client: &'a DatabaseClient,
1130            r#where: filter::#_where_input,
1131        }
1132
1133        impl<'a> CountQuery<'a> {
1134            pub async fn exec(self) -> Result<i64, FerriormError> {
1135                let row: CountResult = match self.client {
1136                    DatabaseClient::Postgres(_) => {
1137                        let qb = build_count_query::<sqlx::Postgres>(#count_sql, &self.r#where);
1138                        self.client.fetch_one_pg(qb).await?
1139                    }
1140                    DatabaseClient::Sqlite(_) => {
1141                        let qb = build_count_query::<sqlx::Sqlite>(#count_sql, &self.r#where);
1142                        self.client.fetch_one_sqlite(qb).await?
1143                    }
1144                };
1145                Ok(row.count)
1146            }
1147        }
1148
1149        pub struct CreateManyQuery<'a> {
1150            client: &'a DatabaseClient,
1151            data: Vec<data::#_create_input>,
1152        }
1153
1154        impl<'a> CreateManyQuery<'a> {
1155            pub async fn exec(self) -> Result<u64, FerriormError> {
1156                if self.data.is_empty() { return Ok(0); }
1157                let count = self.data.len() as u64;
1158                for item in self.data {
1159                    CreateQuery { client: self.client, data: item }.exec().await?;
1160                }
1161                Ok(count)
1162            }
1163        }
1164
1165        pub struct UpdateManyQuery<'a> {
1166            client: &'a DatabaseClient,
1167            r#where: filter::#_where_input,
1168            data: data::#_update_input,
1169        }
1170
1171        impl<'a> UpdateManyQuery<'a> {
1172            pub async fn exec(self) -> Result<u64, FerriormError> {
1173                let client = self.client;
1174                #update_many_code
1175            }
1176        }
1177
1178        pub struct UpsertQuery<'a> {
1179            client: &'a DatabaseClient,
1180            r#where: filter::#_where_unique,
1181            create: data::#_create_input,
1182            update: data::#_update_input,
1183        }
1184
1185        impl<'a> UpsertQuery<'a> {
1186            pub async fn exec(self) -> Result<#model_ident, FerriormError> {
1187                let client = self.client;
1188                #upsert_code
1189            }
1190        }
1191
1192        pub struct DeleteManyQuery<'a> {
1193            client: &'a DatabaseClient,
1194            r#where: filter::#_where_input,
1195        }
1196
1197        impl<'a> DeleteManyQuery<'a> {
1198            pub async fn exec(self) -> Result<u64, FerriormError> {
1199                match self.client {
1200                    DatabaseClient::Postgres(_) => {
1201                        let qb = build_delete_many_query::<sqlx::Postgres>(#delete_sql, &self.r#where);
1202                        self.client.execute_pg(qb).await
1203                    }
1204                    DatabaseClient::Sqlite(_) => {
1205                        let qb = build_delete_many_query::<sqlx::Sqlite>(#delete_sql, &self.r#where);
1206                        self.client.execute_sqlite(qb).await
1207                    }
1208                }
1209            }
1210        }
1211    }
1212}
1213
1214// ─── INSERT code generation ───────────────────────────────────
1215
1216fn gen_insert_code(model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
1217    let _model_ident = format_ident!("{}", model.name);
1218
1219    // Required columns: scalar, no default, not @updatedAt
1220    let required: Vec<&Field> = scalar_fields
1221        .iter()
1222        .copied()
1223        .filter(|f| !f.has_default() && !f.is_updated_at)
1224        .collect();
1225
1226    // Optional columns: have default (can be overridden), not @updatedAt
1227    let optional: Vec<&Field> = scalar_fields
1228        .iter()
1229        .copied()
1230        .filter(|f| f.has_default() && !f.is_updated_at)
1231        .collect();
1232
1233    // @updatedAt columns: always set to now()
1234    let updated_at: Vec<&Field> = scalar_fields
1235        .iter()
1236        .copied()
1237        .filter(|f| f.is_updated_at)
1238        .collect();
1239
1240    // Build column names and bind values
1241    let mut col_pushes = vec![];
1242    let mut val_pushes = vec![];
1243
1244    // Required fields — always included
1245    for f in &required {
1246        let db_name = &f.db_name;
1247        let field_ident = format_ident!("{}", to_snake_case(&f.name));
1248        col_pushes.push(quote! { cols.push(#db_name); });
1249        val_pushes.push(quote! { sep.push_bind(self.data.#field_ident); });
1250    }
1251
1252    // Optional fields — resolve defaults in Rust
1253    for f in &optional {
1254        let db_name = &f.db_name;
1255        let field_ident = format_ident!("{}", to_snake_case(&f.name));
1256        if is_autoincrement(f) {
1257            // Autoincrement: if caller passed None, omit the column entirely
1258            // so the DB assigns the next sequence value. Binding a literal 0
1259            // would collide on the second insert.
1260            col_pushes.push(quote! {
1261                if self.data.#field_ident.is_some() { cols.push(#db_name); }
1262            });
1263            val_pushes.push(quote! {
1264                if let Some(val) = self.data.#field_ident {
1265                    sep.push_bind(val);
1266                }
1267            });
1268        } else {
1269            let default_expr = gen_default_expr(f, &f.field_type);
1270            col_pushes.push(quote! { cols.push(#db_name); });
1271            val_pushes.push(quote! {
1272                let val = self.data.#field_ident.unwrap_or_else(|| #default_expr);
1273                sep.push_bind(val);
1274            });
1275        }
1276    }
1277
1278    // @updatedAt fields
1279    for f in &updated_at {
1280        let db_name = &f.db_name;
1281        col_pushes.push(quote! { cols.push(#db_name); });
1282        val_pushes.push(quote! { sep.push_bind(chrono::Utc::now()); });
1283    }
1284
1285    let insert_start = format!(r#"INSERT INTO "{table_name}""#);
1286
1287    // The insert_body macro avoids duplicating the column/value building logic
1288    // for each database backend. It captures `self` by reference.
1289    quote! {
1290        // Helper to build the INSERT query for any DB backend
1291        macro_rules! build_insert {
1292            ($qb_type:ty) => {{
1293                let mut cols: Vec<&str> = Vec::new();
1294                #(#col_pushes)*
1295
1296                let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#insert_start);
1297                qb.push(" (");
1298                for (i, col) in cols.iter().enumerate() {
1299                    if i > 0 { qb.push(", "); }
1300                    qb.push("\"");
1301                    qb.push(*col);
1302                    qb.push("\"");
1303                }
1304                qb.push(") VALUES (");
1305                {
1306                    let mut sep = qb.separated(", ");
1307                    #(#val_pushes)*
1308                }
1309                qb.push(") RETURNING *");
1310                qb
1311            }};
1312        }
1313
1314        match client {
1315            DatabaseClient::Postgres(_) => {
1316                let qb = build_insert!(sqlx::Postgres);
1317                client.fetch_one_pg(qb).await
1318            }
1319            DatabaseClient::Sqlite(_) => {
1320                let qb = build_insert!(sqlx::Sqlite);
1321                client.fetch_one_sqlite(qb).await
1322            }
1323        }
1324    }
1325}
1326
1327fn gen_insert_ignore_code(
1328    _model: &Model,
1329    scalar_fields: &[&Field],
1330    table_name: &str,
1331) -> TokenStream {
1332    let required: Vec<&Field> = scalar_fields
1333        .iter()
1334        .copied()
1335        .filter(|f| !f.has_default() && !f.is_updated_at)
1336        .collect();
1337    let optional: Vec<&Field> = scalar_fields
1338        .iter()
1339        .copied()
1340        .filter(|f| f.has_default() && !f.is_updated_at)
1341        .collect();
1342    let updated_at: Vec<&Field> = scalar_fields
1343        .iter()
1344        .copied()
1345        .filter(|f| f.is_updated_at)
1346        .collect();
1347
1348    let mut col_pushes = vec![];
1349    let mut val_pushes = vec![];
1350
1351    for f in &required {
1352        let db_name = &f.db_name;
1353        let field_ident = format_ident!("{}", to_snake_case(&f.name));
1354        col_pushes.push(quote! { cols.push(#db_name); });
1355        val_pushes.push(quote! { sep.push_bind(self.data.#field_ident); });
1356    }
1357    for f in &optional {
1358        let db_name = &f.db_name;
1359        let field_ident = format_ident!("{}", to_snake_case(&f.name));
1360        if is_autoincrement(f) {
1361            col_pushes.push(quote! {
1362                if self.data.#field_ident.is_some() { cols.push(#db_name); }
1363            });
1364            val_pushes.push(quote! {
1365                if let Some(val) = self.data.#field_ident {
1366                    sep.push_bind(val);
1367                }
1368            });
1369        } else {
1370            let default_expr = gen_default_expr(f, &f.field_type);
1371            col_pushes.push(quote! { cols.push(#db_name); });
1372            val_pushes.push(quote! {
1373                let val = self.data.#field_ident.unwrap_or_else(|| #default_expr);
1374                sep.push_bind(val);
1375            });
1376        }
1377    }
1378    for f in &updated_at {
1379        let db_name = &f.db_name;
1380        col_pushes.push(quote! { cols.push(#db_name); });
1381        val_pushes.push(quote! { sep.push_bind(chrono::Utc::now()); });
1382    }
1383
1384    let pg_insert_start = format!(r#"INSERT INTO "{table_name}""#);
1385    let sqlite_insert_start = format!(r#"INSERT OR IGNORE INTO "{table_name}""#);
1386
1387    quote! {
1388        macro_rules! build_insert_ignore {
1389            ($qb_type:ty, $head:expr, $tail:expr) => {{
1390                let mut cols: Vec<&str> = Vec::new();
1391                #(#col_pushes)*
1392
1393                let mut qb = sqlx::QueryBuilder::<$qb_type>::new($head);
1394                qb.push(" (");
1395                for (i, col) in cols.iter().enumerate() {
1396                    if i > 0 { qb.push(", "); }
1397                    qb.push("\"");
1398                    qb.push(*col);
1399                    qb.push("\"");
1400                }
1401                qb.push(") VALUES (");
1402                {
1403                    let mut sep = qb.separated(", ");
1404                    #(#val_pushes)*
1405                }
1406                qb.push(")");
1407                qb.push($tail);
1408                qb.push(" RETURNING *");
1409                qb
1410            }};
1411        }
1412
1413        match client {
1414            DatabaseClient::Postgres(_) => {
1415                let qb = build_insert_ignore!(sqlx::Postgres, #pg_insert_start, " ON CONFLICT DO NOTHING");
1416                client.fetch_optional_pg(qb).await
1417            }
1418            DatabaseClient::Sqlite(_) => {
1419                let qb = build_insert_ignore!(sqlx::Sqlite, #sqlite_insert_start, "");
1420                client.fetch_optional_sqlite(qb).await
1421            }
1422        }
1423    }
1424}
1425
1426/// True if the field is declared with `@default(autoincrement())`.
1427/// Such columns must be omitted from the INSERT when the caller passes `None`,
1428/// otherwise we'd bind a literal 0 and collide on the second insert.
1429fn is_autoincrement(field: &Field) -> bool {
1430    matches!(
1431        field.default,
1432        Some(ferriorm_core::ast::DefaultValue::AutoIncrement)
1433    )
1434}
1435
1436/// Generate a Rust expression for a field's @default value.
1437fn gen_default_expr(field: &Field, field_type: &FieldKind) -> TokenStream {
1438    use ferriorm_core::ast::DefaultValue;
1439
1440    match &field.default {
1441        Some(DefaultValue::Uuid | DefaultValue::Cuid) => {
1442            quote! { uuid::Uuid::new_v4().to_string() }
1443        }
1444        Some(DefaultValue::Now) => quote! { chrono::Utc::now() },
1445        // Autoincrement is handled specially by the INSERT/UPSERT generators
1446        // (see `is_autoincrement`): the column is omitted when the caller passes
1447        // `None`, so this arm is unreachable in practice. It stays as a safe
1448        // fallback only for match exhaustiveness.
1449        Some(DefaultValue::AutoIncrement) => quote! { 0i32 },
1450        Some(DefaultValue::Literal(lit)) => {
1451            use ferriorm_core::ast::LiteralValue;
1452            match lit {
1453                LiteralValue::String(s) => quote! { #s.to_string() },
1454                LiteralValue::Int(i) => {
1455                    // Cast the integer literal to the correct Rust type based on the field's scalar type.
1456                    #[allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
1457                    match field_type {
1458                        FieldKind::Scalar(ScalarType::Float) => {
1459                            let val = *i as f64;
1460                            quote! { #val }
1461                        }
1462                        FieldKind::Scalar(ScalarType::BigInt) => quote! { #i },
1463                        // `@db.BigInt` on an `Int` widens the literal to i64 too.
1464                        FieldKind::Scalar(ScalarType::Int)
1465                            if field.db_type.as_ref().is_some_and(|(ty, _)| ty == "BigInt") =>
1466                        {
1467                            quote! { #i }
1468                        }
1469                        _ => {
1470                            // Default to i32 for Int and other types
1471                            let val = *i as i32;
1472                            quote! { #val }
1473                        }
1474                    }
1475                }
1476                LiteralValue::Float(f) => quote! { #f },
1477                LiteralValue::Bool(b) => quote! { #b },
1478            }
1479        }
1480        Some(DefaultValue::EnumVariant(v)) => {
1481            // Reference the enum variant — insert code runs at model module level
1482            let variant = format_ident!("{}", v);
1483            if let FieldKind::Enum(enum_name) = &field.field_type {
1484                let enum_ident = format_ident!("{}", enum_name);
1485                quote! { super::enums::#enum_ident::#variant }
1486            } else {
1487                quote! { Default::default() }
1488            }
1489        }
1490        None => quote! { Default::default() },
1491    }
1492}
1493
1494// ─── UPDATE code generation ───────────────────────────────────
1495
1496fn gen_update_code(model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
1497    let _model_ident = format_ident!("{}", model.name);
1498
1499    // Updatable fields: non-id, non-updatedAt scalar fields
1500    let updatable: Vec<&Field> = scalar_fields
1501        .iter()
1502        .copied()
1503        .filter(|f| !f.is_id && !f.is_updated_at)
1504        .collect();
1505
1506    let updated_at: Vec<&Field> = scalar_fields
1507        .iter()
1508        .copied()
1509        .filter(|f| f.is_updated_at)
1510        .collect();
1511
1512    let update_start = format!(r#"UPDATE "{table_name}" SET "#);
1513
1514    // Generate SET clause arms
1515    let set_arms: Vec<TokenStream> = updatable
1516        .iter()
1517        .map(|f| {
1518            let field_ident = format_ident!("{}", to_snake_case(&f.name));
1519            let db_name = &f.db_name;
1520            quote! {
1521                if let Some(SetValue::Set(v)) = self.data.#field_ident {
1522                    if !first_set { qb.push(", "); }
1523                    first_set = false;
1524                    qb.push(concat!("\"", #db_name, "\" = "));
1525                    qb.push_bind(v);
1526                }
1527            }
1528        })
1529        .collect();
1530
1531    let updated_at_arms: Vec<TokenStream> = updated_at
1532        .iter()
1533        .map(|f| {
1534            let db_name = &f.db_name;
1535            quote! {
1536                if !first_set { qb.push(", "); }
1537                first_set = false;
1538                qb.push(concat!("\"", #db_name, "\" = "));
1539                qb.push_bind(chrono::Utc::now());
1540            }
1541        })
1542        .collect();
1543
1544    // The build_update macro avoids duplicating the SET clause building logic
1545    // for each database backend.
1546    quote! {
1547        macro_rules! build_update {
1548            ($qb_type:ty) => {{
1549                let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#update_start);
1550                let mut first_set = true;
1551                #(#set_arms)*
1552                #(#updated_at_arms)*
1553
1554                if first_set {
1555                    return Err(FerriormError::Query("No fields to update".into()));
1556                }
1557
1558                qb.push(" WHERE 1=1");
1559                self.r#where.build_where(&mut qb);
1560                qb.push(" RETURNING *");
1561                qb
1562            }};
1563        }
1564
1565        match client {
1566            DatabaseClient::Postgres(_) => {
1567                let qb = build_update!(sqlx::Postgres);
1568                client.fetch_one_pg(qb).await
1569            }
1570            DatabaseClient::Sqlite(_) => {
1571                let qb = build_update!(sqlx::Sqlite);
1572                client.fetch_one_sqlite(qb).await
1573            }
1574        }
1575    }
1576}
1577
1578// ─── UPDATE FIRST (CAS) code generation ──────────────────────
1579
1580fn gen_update_first_code(
1581    _model: &Model,
1582    scalar_fields: &[&Field],
1583    table_name: &str,
1584) -> TokenStream {
1585    let updatable: Vec<&Field> = scalar_fields
1586        .iter()
1587        .copied()
1588        .filter(|f| !f.is_id && !f.is_updated_at)
1589        .collect();
1590
1591    let updated_at: Vec<&Field> = scalar_fields
1592        .iter()
1593        .copied()
1594        .filter(|f| f.is_updated_at)
1595        .collect();
1596
1597    let update_start = format!(r#"UPDATE "{table_name}" SET "#);
1598
1599    let set_arms: Vec<TokenStream> = updatable
1600        .iter()
1601        .map(|f| {
1602            let field_ident = format_ident!("{}", to_snake_case(&f.name));
1603            let db_name = &f.db_name;
1604            quote! {
1605                if let Some(SetValue::Set(v)) = self.data.#field_ident {
1606                    if !first_set { qb.push(", "); }
1607                    first_set = false;
1608                    qb.push(concat!("\"", #db_name, "\" = "));
1609                    qb.push_bind(v);
1610                }
1611            }
1612        })
1613        .collect();
1614
1615    let updated_at_arms: Vec<TokenStream> = updated_at
1616        .iter()
1617        .map(|f| {
1618            let db_name = &f.db_name;
1619            quote! {
1620                if !first_set { qb.push(", "); }
1621                first_set = false;
1622                qb.push(concat!("\"", #db_name, "\" = "));
1623                qb.push_bind(chrono::Utc::now());
1624            }
1625        })
1626        .collect();
1627
1628    quote! {
1629        macro_rules! build_update_first {
1630            ($qb_type:ty) => {{
1631                let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#update_start);
1632                let mut first_set = true;
1633                #(#set_arms)*
1634                #(#updated_at_arms)*
1635
1636                if first_set {
1637                    return Err(FerriormError::Query("No fields to update".into()));
1638                }
1639
1640                qb.push(" WHERE 1=1");
1641                self.r#where.build_where(&mut qb);
1642                qb.push(" RETURNING *");
1643                qb
1644            }};
1645        }
1646
1647        match client {
1648            DatabaseClient::Postgres(_) => {
1649                let qb = build_update_first!(sqlx::Postgres);
1650                client.fetch_optional_pg(qb).await
1651            }
1652            DatabaseClient::Sqlite(_) => {
1653                let qb = build_update_first!(sqlx::Sqlite);
1654                client.fetch_optional_sqlite(qb).await
1655            }
1656        }
1657    }
1658}
1659
1660// ─── UPDATE MANY code generation ──────────────────────────────
1661
1662fn gen_update_many_code(_model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
1663    // Updatable fields: non-id, non-updatedAt scalar fields
1664    let updatable: Vec<&Field> = scalar_fields
1665        .iter()
1666        .copied()
1667        .filter(|f| !f.is_id && !f.is_updated_at)
1668        .collect();
1669
1670    let updated_at: Vec<&Field> = scalar_fields
1671        .iter()
1672        .copied()
1673        .filter(|f| f.is_updated_at)
1674        .collect();
1675
1676    let update_start = format!(r#"UPDATE "{table_name}" SET "#);
1677
1678    // Generate SET clause arms
1679    let set_arms: Vec<TokenStream> = updatable
1680        .iter()
1681        .map(|f| {
1682            let field_ident = format_ident!("{}", to_snake_case(&f.name));
1683            let db_name = &f.db_name;
1684            quote! {
1685                if let Some(SetValue::Set(v)) = self.data.#field_ident {
1686                    if !first_set { qb.push(", "); }
1687                    first_set = false;
1688                    qb.push(concat!("\"", #db_name, "\" = "));
1689                    qb.push_bind(v);
1690                }
1691            }
1692        })
1693        .collect();
1694
1695    let updated_at_arms: Vec<TokenStream> = updated_at
1696        .iter()
1697        .map(|f| {
1698            let db_name = &f.db_name;
1699            quote! {
1700                if !first_set { qb.push(", "); }
1701                first_set = false;
1702                qb.push(concat!("\"", #db_name, "\" = "));
1703                qb.push_bind(chrono::Utc::now());
1704            }
1705        })
1706        .collect();
1707
1708    quote! {
1709        macro_rules! build_update_many {
1710            ($qb_type:ty) => {{
1711                let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#update_start);
1712                let mut first_set = true;
1713                #(#set_arms)*
1714                #(#updated_at_arms)*
1715
1716                if first_set {
1717                    return Ok(0);
1718                }
1719
1720                qb.push(" WHERE 1=1");
1721                self.r#where.build_where(&mut qb);
1722                qb
1723            }};
1724        }
1725
1726        match client {
1727            DatabaseClient::Postgres(_) => {
1728                let qb = build_update_many!(sqlx::Postgres);
1729                client.execute_pg(qb).await
1730            }
1731            DatabaseClient::Sqlite(_) => {
1732                let qb = build_update_many!(sqlx::Sqlite);
1733                client.execute_sqlite(qb).await
1734            }
1735        }
1736    }
1737}
1738
1739// ─── Aggregate Types ──────────────────────────────────────────
1740
/// How an aggregatable field may be summarised; decides which aggregate
/// operations (and result columns) are generated for it.
///
/// Derives `Copy`/`PartialEq` so a kind can be copied out of the
/// `(field, kind)` tuples it travels in and compared directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum AggregateKind {
    /// `Int` / `BigInt` / `Float` fields: `avg`, `sum`, `min`, `max`.
    Numeric,
    /// `DateTime` fields: `min` and `max` only.
    DateTime,
}
1748
1749// ─── UPSERT code generation ──────────────────────────────────
1750
1751#[allow(clippy::too_many_lines)]
1752fn gen_upsert_code(_model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
1753    // Required + optional + updatedAt fields for the INSERT part (same as gen_insert_code)
1754    let required: Vec<&Field> = scalar_fields
1755        .iter()
1756        .copied()
1757        .filter(|f| !f.has_default() && !f.is_updated_at)
1758        .collect();
1759    let optional: Vec<&Field> = scalar_fields
1760        .iter()
1761        .copied()
1762        .filter(|f| f.has_default() && !f.is_updated_at)
1763        .collect();
1764    let updated_at: Vec<&Field> = scalar_fields
1765        .iter()
1766        .copied()
1767        .filter(|f| f.is_updated_at)
1768        .collect();
1769
1770    let mut col_pushes = vec![];
1771    let mut val_pushes = vec![];
1772
1773    for f in &required {
1774        let db_name = &f.db_name;
1775        let field_ident = format_ident!("{}", to_snake_case(&f.name));
1776        col_pushes.push(quote! { cols.push(#db_name); });
1777        val_pushes.push(quote! { sep.push_bind(self.create.#field_ident); });
1778    }
1779    for f in &optional {
1780        let db_name = &f.db_name;
1781        let field_ident = format_ident!("{}", to_snake_case(&f.name));
1782        if is_autoincrement(f) {
1783            col_pushes.push(quote! {
1784                if self.create.#field_ident.is_some() { cols.push(#db_name); }
1785            });
1786            val_pushes.push(quote! {
1787                if let Some(val) = self.create.#field_ident {
1788                    sep.push_bind(val);
1789                }
1790            });
1791        } else {
1792            let default_expr = gen_default_expr(f, &f.field_type);
1793            col_pushes.push(quote! { cols.push(#db_name); });
1794            val_pushes.push(quote! {
1795                let val = self.create.#field_ident.unwrap_or_else(|| #default_expr);
1796                sep.push_bind(val);
1797            });
1798        }
1799    }
1800    for f in &updated_at {
1801        let db_name = &f.db_name;
1802        col_pushes.push(quote! { cols.push(#db_name); });
1803        val_pushes.push(quote! { sep.push_bind(chrono::Utc::now()); });
1804    }
1805
1806    // Updatable fields for the DO UPDATE SET part
1807    let updatable: Vec<&Field> = scalar_fields
1808        .iter()
1809        .copied()
1810        .filter(|f| !f.is_id && !f.is_updated_at)
1811        .collect();
1812
1813    let set_arms: Vec<TokenStream> = updatable
1814        .iter()
1815        .map(|f| {
1816            let field_ident = format_ident!("{}", to_snake_case(&f.name));
1817            let db_name = &f.db_name;
1818            quote! {
1819                if let Some(SetValue::Set(v)) = self.update.#field_ident {
1820                    if !first_set { qb.push(", "); }
1821                    first_set = false;
1822                    qb.push(concat!("\"", #db_name, "\" = "));
1823                    qb.push_bind(v);
1824                }
1825            }
1826        })
1827        .collect();
1828
1829    let updated_at_set: Vec<TokenStream> = updated_at
1830        .iter()
1831        .map(|f| {
1832            let db_name = &f.db_name;
1833            quote! {
1834                if !first_set { qb.push(", "); }
1835                first_set = false;
1836                qb.push(concat!("\"", #db_name, "\" = "));
1837                qb.push_bind(chrono::Utc::now());
1838            }
1839        })
1840        .collect();
1841
1842    let insert_start = format!(r#"INSERT INTO "{table_name}""#);
1843
1844    quote! {
1845        let conflict_target = self.r#where.conflict_target();
1846        let first_conflict_col = self.r#where.first_conflict_col();
1847
1848        macro_rules! build_upsert {
1849            ($qb_type:ty) => {{
1850                let mut cols: Vec<&str> = Vec::new();
1851                #(#col_pushes)*
1852
1853                let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#insert_start);
1854                qb.push(" (");
1855                for (i, col) in cols.iter().enumerate() {
1856                    if i > 0 { qb.push(", "); }
1857                    qb.push("\"");
1858                    qb.push(*col);
1859                    qb.push("\"");
1860                }
1861                qb.push(") VALUES (");
1862                {
1863                    let mut sep = qb.separated(", ");
1864                    #(#val_pushes)*
1865                }
1866                qb.push(")");
1867                qb.push(" ON CONFLICT ");
1868                qb.push(conflict_target);
1869                qb.push(" DO UPDATE SET ");
1870
1871                let mut first_set = true;
1872                #(#set_arms)*
1873                #(#updated_at_set)*
1874
1875                if first_set {
1876                    // No update fields specified — use a no-op update on the first
1877                    // conflict-target column so RETURNING * still yields the row.
1878                    qb.push(first_conflict_col);
1879                    qb.push(" = ");
1880                    qb.push(first_conflict_col);
1881                }
1882
1883                qb.push(" RETURNING *");
1884                qb
1885            }};
1886        }
1887
1888        match client {
1889            DatabaseClient::Postgres(_) => {
1890                let qb = build_upsert!(sqlx::Postgres);
1891                client.fetch_one_pg(qb).await
1892            }
1893            DatabaseClient::Sqlite(_) => {
1894                let qb = build_upsert!(sqlx::Sqlite);
1895                client.fetch_one_sqlite(qb).await
1896            }
1897        }
1898    }
1899}
1900
1901#[allow(clippy::too_many_lines)]
1902fn gen_aggregate_types(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
1903    let aggregate_field_name = format_ident!("{}AggregateField", model.name);
1904    let aggregate_result_name = format_ident!("{}AggregateResult", model.name);
1905    let _where_input = format_ident!("{}WhereInput", model.name);
1906    let table_name = &model.db_name;
1907
1908    // Collect aggregatable fields with their kind
1909    let agg_fields: Vec<(&Field, AggregateKind)> = scalar_fields
1910        .iter()
1911        .filter_map(|f| match &f.field_type {
1912            FieldKind::Scalar(ScalarType::Int | ScalarType::BigInt | ScalarType::Float) => {
1913                Some((*f, AggregateKind::Numeric))
1914            }
1915            FieldKind::Scalar(ScalarType::DateTime) => Some((*f, AggregateKind::DateTime)),
1916            _ => None,
1917        })
1918        .collect();
1919
1920    if agg_fields.is_empty() {
1921        return quote! {};
1922    }
1923
1924    // Generate enum variants
1925    let enum_variants: Vec<TokenStream> = agg_fields
1926        .iter()
1927        .map(|(f, _)| {
1928            let variant = format_ident!("{}", to_pascal_case(&f.name));
1929            quote! { #variant }
1930        })
1931        .collect();
1932
1933    // Generate db_name match arms
1934    let db_name_arms: Vec<TokenStream> = agg_fields
1935        .iter()
1936        .map(|(f, _)| {
1937            let variant = format_ident!("{}", to_pascal_case(&f.name));
1938            let db_name = &f.db_name;
1939            quote! { Self::#variant => #db_name }
1940        })
1941        .collect();
1942
1943    // Generate AggregateResult fields
1944    let mut result_fields = Vec::new();
1945    for (f, kind) in &agg_fields {
1946        let snake = to_snake_case(&f.name);
1947        let orig_ty = rust_type_tokens(
1948            &Field {
1949                is_optional: false,
1950                ..(*f).clone()
1951            },
1952            ModuleDepth::TopLevel,
1953        );
1954
1955        match kind {
1956            AggregateKind::Numeric => {
1957                let avg_name = format_ident!("avg_{}", snake);
1958                let sum_name = format_ident!("sum_{}", snake);
1959                let min_name = format_ident!("min_{}", snake);
1960                let max_name = format_ident!("max_{}", snake);
1961                result_fields.push(quote! { #[sqlx(default)] pub #avg_name: Option<f64> });
1962                result_fields.push(quote! { #[sqlx(default)] pub #sum_name: Option<f64> });
1963                result_fields.push(quote! { #[sqlx(default)] pub #min_name: Option<#orig_ty> });
1964                result_fields.push(quote! { #[sqlx(default)] pub #max_name: Option<#orig_ty> });
1965            }
1966            AggregateKind::DateTime => {
1967                let min_name = format_ident!("min_{}", snake);
1968                let max_name = format_ident!("max_{}", snake);
1969                result_fields.push(quote! { #[sqlx(default)] pub #min_name: Option<#orig_ty> });
1970                result_fields.push(quote! { #[sqlx(default)] pub #max_name: Option<#orig_ty> });
1971            }
1972        }
1973    }
1974
1975    // Generate the is_numeric check for avg/sum validation
1976    let numeric_arms: Vec<TokenStream> = agg_fields
1977        .iter()
1978        .filter(|(_, kind)| matches!(kind, AggregateKind::Numeric))
1979        .map(|(f, _)| {
1980            let variant = format_ident!("{}", to_pascal_case(&f.name));
1981            quote! { Self::#variant => true }
1982        })
1983        .collect();
1984
1985    let has_numeric = !numeric_arms.is_empty();
1986    let is_numeric_method = if has_numeric {
1987        quote! {
1988            fn is_numeric(&self) -> bool {
1989                match self {
1990                    #(#numeric_arms,)*
1991                    #[allow(unreachable_patterns)]
1992                    _ => false,
1993                }
1994            }
1995        }
1996    } else {
1997        quote! {
1998            fn is_numeric(&self) -> bool { false }
1999        }
2000    };
2001
2002    // Generate alias match arms for each (prefix, field) combination
2003    let mut alias_arms = Vec::new();
2004    for (f, kind) in &agg_fields {
2005        let variant = format_ident!("{}", to_pascal_case(&f.name));
2006        let snake = to_snake_case(&f.name);
2007        let prefixes = match kind {
2008            AggregateKind::Numeric => vec!["avg", "sum", "min", "max"],
2009            AggregateKind::DateTime => vec!["min", "max"],
2010        };
2011        for prefix in prefixes {
2012            let alias_str = format!("{prefix}_{snake}");
2013            alias_arms.push(quote! { (#prefix, Self::#variant) => #alias_str });
2014        }
2015    }
2016
2017    let agg_select_base = format!(r#"SELECT {{}} FROM "{table_name}" WHERE 1=1"#);
2018
2019    quote! {
2020        #[derive(Debug, Clone, Copy)]
2021        pub enum #aggregate_field_name {
2022            #(#enum_variants),*
2023        }
2024
2025        impl #aggregate_field_name {
2026            pub fn db_name(&self) -> &'static str {
2027                match self {
2028                    #(#db_name_arms,)*
2029                }
2030            }
2031
2032            fn alias(&self, prefix: &'static str) -> &'static str {
2033                match (prefix, self) {
2034                    #(#alias_arms,)*
2035                    _ => unreachable!(),
2036                }
2037            }
2038
2039            #is_numeric_method
2040        }
2041
2042        #[derive(Debug, Clone, sqlx::FromRow, Serialize, Deserialize)]
2043        pub struct #aggregate_result_name {
2044            #(#result_fields,)*
2045        }
2046
2047        pub struct AggregateQuery<'a> {
2048            client: &'a DatabaseClient,
2049            r#where: filter::#_where_input,
2050            ops: Vec<(&'static str, &'static str, &'static str)>,
2051        }
2052
2053        impl<'a> AggregateQuery<'a> {
2054            pub fn avg(mut self, field: #aggregate_field_name) -> Self {
2055                assert!(field.is_numeric(), "avg() is only supported on numeric fields");
2056                let db_name = field.db_name();
2057                let alias = field.alias("avg");
2058                self.ops.push(("AVG", db_name, alias));
2059                self
2060            }
2061
2062            pub fn sum(mut self, field: #aggregate_field_name) -> Self {
2063                assert!(field.is_numeric(), "sum() is only supported on numeric fields");
2064                let db_name = field.db_name();
2065                let alias = field.alias("sum");
2066                self.ops.push(("SUM", db_name, alias));
2067                self
2068            }
2069
2070            pub fn min(mut self, field: #aggregate_field_name) -> Self {
2071                let db_name = field.db_name();
2072                let alias = field.alias("min");
2073                self.ops.push(("MIN", db_name, alias));
2074                self
2075            }
2076
2077            pub fn max(mut self, field: #aggregate_field_name) -> Self {
2078                let db_name = field.db_name();
2079                let alias = field.alias("max");
2080                self.ops.push(("MAX", db_name, alias));
2081                self
2082            }
2083
2084            pub async fn exec(self) -> Result<#aggregate_result_name, FerriormError> {
2085                if self.ops.is_empty() {
2086                    return Err(FerriormError::Query("No aggregate operations specified".into()));
2087                }
2088
2089                let selections: Vec<String> = self.ops.iter()
2090                    .map(|(func, col, alias)| format!(r#"{}("{}") as "{}""#, func, col, alias))
2091                    .collect();
2092                let select_clause = selections.join(", ");
2093                let base_sql = format!(#agg_select_base, select_clause);
2094
2095                match self.client {
2096                    DatabaseClient::Postgres(_) => {
2097                        let mut qb = sqlx::QueryBuilder::<sqlx::Postgres>::new(&base_sql);
2098                        self.r#where.build_where(&mut qb);
2099                        self.client.fetch_one_pg(qb).await
2100                    }
2101                    DatabaseClient::Sqlite(_) => {
2102                        let mut qb = sqlx::QueryBuilder::<sqlx::Sqlite>::new(&base_sql);
2103                        self.r#where.build_where(&mut qb);
2104                        self.client.fetch_one_sqlite(qb).await
2105                    }
2106                }
2107            }
2108        }
2109    }
2110}
2111
2112// ─── GroupBy Types ────────────────────────────────────────────
2113
2114/// Generate the standard-comparable HAVING arms (`equals`/`not`/`gt`/
2115/// `gte`/`lt`/`lte`/`in`/`not_in`) for one aggregate field. `lhs` is the
2116/// SQL expression on the left-hand side of the operator (e.g. `AVG("age")`,
2117/// `MIN("created_at")`).
2118fn gen_having_comparable_arms(field_ident: &proc_macro2::Ident, lhs: &str) -> TokenStream {
2119    let eq = format!(" AND {lhs} = ");
2120    let ne = format!(" AND {lhs} != ");
2121    let gt = format!(" AND {lhs} > ");
2122    let gte = format!(" AND {lhs} >= ");
2123    let lt = format!(" AND {lhs} < ");
2124    let lte = format!(" AND {lhs} <= ");
2125    let in_arms = gen_in_arms_lhs(lhs);
2126    quote! {
2127        if let Some(filter) = &self.#field_ident {
2128            if let Some(v) = &filter.equals { qb.push(#eq); qb.push_bind(v.clone()); }
2129            if let Some(v) = &filter.not    { qb.push(#ne); qb.push_bind(v.clone()); }
2130            if let Some(v) = &filter.gt     { qb.push(#gt); qb.push_bind(v.clone()); }
2131            if let Some(v) = &filter.gte    { qb.push(#gte); qb.push_bind(v.clone()); }
2132            if let Some(v) = &filter.lt     { qb.push(#lt); qb.push_bind(v.clone()); }
2133            if let Some(v) = &filter.lte    { qb.push(#lte); qb.push_bind(v.clone()); }
2134            #in_arms
2135        }
2136    }
2137}
2138
2139/// True for fields that can appear in a `GROUP BY` clause: any scalar except
2140/// `Json`/`Bytes`/`Decimal` (which are not orderable / hashable in SQL), plus
2141/// enums. Optional fields are still groupable -- `NULL` becomes its own
2142/// bucket.
2143fn is_groupable(field: &Field) -> bool {
2144    match &field.field_type {
2145        FieldKind::Scalar(
2146            ScalarType::String
2147            | ScalarType::Int
2148            | ScalarType::BigInt
2149            | ScalarType::Float
2150            | ScalarType::Boolean
2151            | ScalarType::DateTime,
2152        )
2153        | FieldKind::Enum(_) => true,
2154        FieldKind::Scalar(ScalarType::Json | ScalarType::Bytes | ScalarType::Decimal)
2155        | FieldKind::Model(_) => false,
2156    }
2157}
2158
2159#[allow(clippy::too_many_lines)]
2160fn gen_groupby_types(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
2161    let groupby_field_name = format_ident!("{}GroupByField", model.name);
2162    let groupby_result_name = format_ident!("{}GroupByResult", model.name);
2163    let having_input_name = format_ident!("{}HavingInput", model.name);
2164    let aggregate_field_name = format_ident!("{}AggregateField", model.name);
2165    let where_input = format_ident!("{}WhereInput", model.name);
2166    let table_name = &model.db_name;
2167
2168    // Reuse the same aggregate-field collection as gen_aggregate_types so the
2169    // result struct columns and HAVING surface stay consistent.
2170    let agg_fields: Vec<(&Field, AggregateKind)> = scalar_fields
2171        .iter()
2172        .filter_map(|f| match &f.field_type {
2173            FieldKind::Scalar(ScalarType::Int | ScalarType::BigInt | ScalarType::Float) => {
2174                Some((*f, AggregateKind::Numeric))
2175            }
2176            FieldKind::Scalar(ScalarType::DateTime) => Some((*f, AggregateKind::DateTime)),
2177            _ => None,
2178        })
2179        .collect();
2180
2181    let group_fields: Vec<&Field> = scalar_fields
2182        .iter()
2183        .filter(|f| is_groupable(f))
2184        .copied()
2185        .collect();
2186
2187    if group_fields.is_empty() {
2188        return quote! {};
2189    }
2190
2191    // ── <Model>GroupByField enum ──────────────────────────────
2192    let groupby_variants: Vec<TokenStream> = group_fields
2193        .iter()
2194        .map(|f| {
2195            let variant = format_ident!("{}", to_pascal_case(&f.name));
2196            quote! { #variant }
2197        })
2198        .collect();
2199
2200    let groupby_db_arms: Vec<TokenStream> = group_fields
2201        .iter()
2202        .map(|f| {
2203            let variant = format_ident!("{}", to_pascal_case(&f.name));
2204            let db_name = &f.db_name;
2205            quote! { Self::#variant => #db_name }
2206        })
2207        .collect();
2208
2209    let groupby_alias_arms: Vec<TokenStream> = group_fields
2210        .iter()
2211        .map(|f| {
2212            let variant = format_ident!("{}", to_pascal_case(&f.name));
2213            let alias = to_snake_case(&f.name);
2214            quote! { Self::#variant => #alias }
2215        })
2216        .collect();
2217
2218    // ── <Model>GroupByResult fields ───────────────────────────
2219    // One Option<T> per groupable field (only filled when that field is in
2220    // the active group key set), then count, then the same avg/sum/min/max
2221    // columns that gen_aggregate_types emits.
2222    let mut result_fields: Vec<TokenStream> = Vec::new();
2223    for f in &group_fields {
2224        let snake = to_snake_case(&f.name);
2225        let name = format_ident!("{}", snake);
2226        // Always wrap in Option so the same struct serves every group_by call.
2227        let base_ty = rust_type_tokens(
2228            &Field {
2229                is_optional: false,
2230                ..(*f).clone()
2231            },
2232            ModuleDepth::TopLevel,
2233        );
2234        result_fields.push(quote! { #[sqlx(default)] pub #name: Option<#base_ty> });
2235    }
2236    result_fields.push(quote! { #[sqlx(default)] pub count: Option<i64> });
2237    for (f, kind) in &agg_fields {
2238        let snake = to_snake_case(&f.name);
2239        let orig_ty = rust_type_tokens(
2240            &Field {
2241                is_optional: false,
2242                ..(*f).clone()
2243            },
2244            ModuleDepth::TopLevel,
2245        );
2246        match kind {
2247            AggregateKind::Numeric => {
2248                let avg_name = format_ident!("avg_{}", snake);
2249                let sum_name = format_ident!("sum_{}", snake);
2250                let min_name = format_ident!("min_{}", snake);
2251                let max_name = format_ident!("max_{}", snake);
2252                result_fields.push(quote! { #[sqlx(default)] pub #avg_name: Option<f64> });
2253                result_fields.push(quote! { #[sqlx(default)] pub #sum_name: Option<f64> });
2254                result_fields.push(quote! { #[sqlx(default)] pub #min_name: Option<#orig_ty> });
2255                result_fields.push(quote! { #[sqlx(default)] pub #max_name: Option<#orig_ty> });
2256            }
2257            AggregateKind::DateTime => {
2258                let min_name = format_ident!("min_{}", snake);
2259                let max_name = format_ident!("max_{}", snake);
2260                result_fields.push(quote! { #[sqlx(default)] pub #min_name: Option<#orig_ty> });
2261                result_fields.push(quote! { #[sqlx(default)] pub #max_name: Option<#orig_ty> });
2262            }
2263        }
2264    }
2265
2266    // ── <Model>HavingInput fields ─────────────────────────────
2267    // Filtering on aggregate expressions: COUNT(*), AVG/SUM/MIN/MAX of each
2268    // aggregatable column. RHS reuses the same scalar filter types as
2269    // WhereInput.
2270    let mut having_fields: Vec<TokenStream> = Vec::new();
2271    // COUNT(*) returns BIGINT in both Postgres and SQLite -> BigIntFilter.
2272    having_fields.push(quote! { pub count: Option<ferriorm_runtime::filter::BigIntFilter> });
2273    for (f, kind) in &agg_fields {
2274        let snake = to_snake_case(&f.name);
2275        let avg_name = format_ident!("avg_{}", snake);
2276        let sum_name = format_ident!("sum_{}", snake);
2277        let min_name = format_ident!("min_{}", snake);
2278        let max_name = format_ident!("max_{}", snake);
2279        let column_filter = filter_type_tokens(
2280            &Field {
2281                is_optional: false,
2282                ..(*f).clone()
2283            },
2284            ModuleDepth::TopLevel,
2285        )
2286        .unwrap_or_else(|| quote! { ferriorm_runtime::filter::BigIntFilter });
2287        match kind {
2288            AggregateKind::Numeric => {
2289                having_fields
2290                    .push(quote! { pub #avg_name: Option<ferriorm_runtime::filter::FloatFilter> });
2291                having_fields
2292                    .push(quote! { pub #sum_name: Option<ferriorm_runtime::filter::FloatFilter> });
2293                having_fields.push(quote! { pub #min_name: Option<#column_filter> });
2294                having_fields.push(quote! { pub #max_name: Option<#column_filter> });
2295            }
2296            AggregateKind::DateTime => {
2297                having_fields.push(quote! { pub #min_name: Option<#column_filter> });
2298                having_fields.push(quote! { pub #max_name: Option<#column_filter> });
2299            }
2300        }
2301    }
2302
2303    // ── build_having arms ─────────────────────────────────────
2304    // Mirrors gen_where_arms but the LHS is the aggregate expression
2305    // (`AVG("col")`, `COUNT(*)`, ...) instead of a bare column reference.
2306    // Aggregate results are never NULL semantically except for empty inputs,
2307    // so we don't need IS NULL handling here.
2308    let mut having_arms: Vec<TokenStream> = Vec::new();
2309    // count: BigIntFilter on COUNT(*) -- mirrors the comparable HAVING arms,
2310    // including `r#in`/`not_in` for `WHERE COUNT(*) IN (...)` semantics.
2311    let count_in_arms = gen_in_arms_lhs("COUNT(*)");
2312    having_arms.push(quote! {
2313        if let Some(filter) = &self.count {
2314            if let Some(v) = &filter.equals { qb.push(" AND COUNT(*) = "); qb.push_bind(*v); }
2315            if let Some(v) = &filter.not    { qb.push(" AND COUNT(*) != "); qb.push_bind(*v); }
2316            if let Some(v) = &filter.gt     { qb.push(" AND COUNT(*) > "); qb.push_bind(*v); }
2317            if let Some(v) = &filter.gte    { qb.push(" AND COUNT(*) >= "); qb.push_bind(*v); }
2318            if let Some(v) = &filter.lt     { qb.push(" AND COUNT(*) < "); qb.push_bind(*v); }
2319            if let Some(v) = &filter.lte    { qb.push(" AND COUNT(*) <= "); qb.push_bind(*v); }
2320            #count_in_arms
2321        }
2322    });
2323
2324    for (f, kind) in &agg_fields {
2325        let snake = to_snake_case(&f.name);
2326        let db_name = &f.db_name;
2327        let avg_ident = format_ident!("avg_{}", snake);
2328        let sum_ident = format_ident!("sum_{}", snake);
2329        let min_ident = format_ident!("min_{}", snake);
2330        let max_ident = format_ident!("max_{}", snake);
2331        match kind {
2332            AggregateKind::Numeric => {
2333                let avg_lhs = format!(r#"AVG("{db_name}")"#);
2334                let sum_lhs = format!(r#"SUM("{db_name}")"#);
2335                let min_lhs = format!(r#"MIN("{db_name}")"#);
2336                let max_lhs = format!(r#"MAX("{db_name}")"#);
2337                having_arms.push(gen_having_comparable_arms(&avg_ident, &avg_lhs));
2338                having_arms.push(gen_having_comparable_arms(&sum_ident, &sum_lhs));
2339                having_arms.push(gen_having_comparable_arms(&min_ident, &min_lhs));
2340                having_arms.push(gen_having_comparable_arms(&max_ident, &max_lhs));
2341            }
2342            AggregateKind::DateTime => {
2343                let min_lhs = format!(r#"MIN("{db_name}")"#);
2344                let max_lhs = format!(r#"MAX("{db_name}")"#);
2345                having_arms.push(gen_having_comparable_arms(&min_ident, &min_lhs));
2346                having_arms.push(gen_having_comparable_arms(&max_ident, &max_lhs));
2347            }
2348        }
2349    }
2350
2351    // build_having binds the RHS of `AVG(col) op ?` (always f64) and
2352    // `COUNT(*) op ?` (always i64) regardless of which scalar types appear
2353    // in the model. Reuse collect_db_bounds for the column-type bounds
2354    // (needed by min/max filters), then top up with f64.
2355    let mut db_bounds = collect_db_bounds(scalar_fields, ModuleDepth::TopLevel);
2356    if !scalar_fields
2357        .iter()
2358        .any(|f| matches!(&f.field_type, FieldKind::Scalar(ScalarType::Float)))
2359    {
2360        db_bounds.push(quote! { f64: sqlx::Type<DB> + for<'e> sqlx::Encode<'e, DB> });
2361    }
2362
2363    // ── is_numeric reuse: AggregateField is the canonical enum, but if no
2364    //    aggregatable fields exist we still want a typed group_by query --
2365    //    just without the agg ops. In that case AggregateField may not have
2366    //    been emitted at all, so skip avg/sum/min/max methods.
2367    let has_agg_fields = !agg_fields.is_empty();
2368
2369    let agg_methods = if has_agg_fields {
2370        quote! {
2371            pub fn count(mut self) -> Self {
2372                self.count = true;
2373                self
2374            }
2375
2376            pub fn avg(mut self, field: #aggregate_field_name) -> Self {
2377                assert!(field.is_numeric(), "avg() is only supported on numeric fields");
2378                let db_name = field.db_name();
2379                let alias = field.alias("avg");
2380                self.agg_ops.push(("AVG", db_name, alias));
2381                self
2382            }
2383
2384            pub fn sum(mut self, field: #aggregate_field_name) -> Self {
2385                assert!(field.is_numeric(), "sum() is only supported on numeric fields");
2386                let db_name = field.db_name();
2387                let alias = field.alias("sum");
2388                self.agg_ops.push(("SUM", db_name, alias));
2389                self
2390            }
2391
2392            pub fn min(mut self, field: #aggregate_field_name) -> Self {
2393                let db_name = field.db_name();
2394                let alias = field.alias("min");
2395                self.agg_ops.push(("MIN", db_name, alias));
2396                self
2397            }
2398
2399            pub fn max(mut self, field: #aggregate_field_name) -> Self {
2400                let db_name = field.db_name();
2401                let alias = field.alias("max");
2402                self.agg_ops.push(("MAX", db_name, alias));
2403                self
2404            }
2405        }
2406    } else {
2407        quote! {
2408            pub fn count(mut self) -> Self {
2409                self.count = true;
2410                self
2411            }
2412        }
2413    };
2414
2415    quote! {
2416        #[derive(Debug, Clone, Copy)]
2417        pub enum #groupby_field_name {
2418            #(#groupby_variants),*
2419        }
2420
2421        impl #groupby_field_name {
2422            pub fn db_name(&self) -> &'static str {
2423                match self {
2424                    #(#groupby_db_arms,)*
2425                }
2426            }
2427
2428            fn alias(&self) -> &'static str {
2429                match self {
2430                    #(#groupby_alias_arms,)*
2431                }
2432            }
2433        }
2434
2435        #[derive(Debug, Clone, sqlx::FromRow, Serialize, Deserialize)]
2436        pub struct #groupby_result_name {
2437            #(#result_fields,)*
2438        }
2439
2440        #[derive(Debug, Clone, Default)]
2441        pub struct #having_input_name {
2442            #(#having_fields,)*
2443            pub and: Option<Vec<#having_input_name>>,
2444            pub or: Option<Vec<#having_input_name>>,
2445            pub not: Option<Box<#having_input_name>>,
2446        }
2447
2448        impl #having_input_name {
2449            pub(crate) fn build_having<'args, DB: sqlx::Database>(
2450                &self,
2451                qb: &mut sqlx::QueryBuilder<'args, DB>,
2452            )
2453            where
2454                #(#db_bounds,)*
2455            {
2456                #(#having_arms)*
2457
2458                if let Some(conditions) = &self.and {
2459                    for c in conditions {
2460                        c.build_having(qb);
2461                    }
2462                }
2463                if let Some(conditions) = &self.or {
2464                    if !conditions.is_empty() {
2465                        qb.push(" AND (");
2466                        for (i, c) in conditions.iter().enumerate() {
2467                            if i > 0 { qb.push(" OR "); }
2468                            qb.push("(1=1");
2469                            c.build_having(qb);
2470                            qb.push(")");
2471                        }
2472                        qb.push(")");
2473                    }
2474                }
2475                if let Some(c) = &self.not {
2476                    qb.push(" AND NOT (1=1");
2477                    c.build_having(qb);
2478                    qb.push(")");
2479                }
2480            }
2481        }
2482
2483        pub struct GroupByQuery<'a> {
2484            client: &'a DatabaseClient,
2485            r#where: filter::#where_input,
2486            group_keys: Vec<#groupby_field_name>,
2487            agg_ops: Vec<(&'static str, &'static str, &'static str)>,
2488            count: bool,
2489            having: Option<#having_input_name>,
2490        }
2491
2492        impl<'a> GroupByQuery<'a> {
2493            pub fn r#where(mut self, r#where: filter::#where_input) -> Self {
2494                self.r#where = r#where;
2495                self
2496            }
2497
2498            #agg_methods
2499
2500            pub fn having(mut self, having: #having_input_name) -> Self {
2501                self.having = Some(having);
2502                self
2503            }
2504
2505            pub async fn exec(self) -> Result<Vec<#groupby_result_name>, FerriormError> {
2506                if self.group_keys.is_empty() {
2507                    return Err(FerriormError::Query(
2508                        "group_by() requires at least one group key".into(),
2509                    ));
2510                }
2511
2512                let mut selections: Vec<String> = self.group_keys
2513                    .iter()
2514                    .map(|k| format!(r#""{}" as "{}""#, k.db_name(), k.alias()))
2515                    .collect();
2516                if self.count {
2517                    selections.push(r#"COUNT(*) as "count""#.to_string());
2518                }
2519                for (func, col, alias) in &self.agg_ops {
2520                    selections.push(format!(r#"{}("{}") as "{}""#, func, col, alias));
2521                }
2522
2523                let group_by_clause: Vec<String> = self.group_keys
2524                    .iter()
2525                    .map(|k| format!(r#""{}""#, k.db_name()))
2526                    .collect();
2527
2528                let base_sql = format!(
2529                    r#"SELECT {} FROM "{}" WHERE 1=1"#,
2530                    selections.join(", "),
2531                    #table_name,
2532                );
2533
2534                match self.client {
2535                    DatabaseClient::Postgres(_) => {
2536                        let mut qb = sqlx::QueryBuilder::<sqlx::Postgres>::new(&base_sql);
2537                        self.r#where.build_where(&mut qb);
2538                        qb.push(format!(" GROUP BY {}", group_by_clause.join(", ")));
2539                        if let Some(h) = &self.having {
2540                            qb.push(" HAVING 1=1");
2541                            h.build_having(&mut qb);
2542                        }
2543                        self.client.fetch_all_pg(qb).await
2544                    }
2545                    DatabaseClient::Sqlite(_) => {
2546                        let mut qb = sqlx::QueryBuilder::<sqlx::Sqlite>::new(&base_sql);
2547                        self.r#where.build_where(&mut qb);
2548                        qb.push(format!(" GROUP BY {}", group_by_clause.join(", ")));
2549                        if let Some(h) = &self.having {
2550                            qb.push(" HAVING 1=1");
2551                            h.build_having(&mut qb);
2552                        }
2553                        self.client.fetch_all_sqlite(qb).await
2554                    }
2555                }
2556            }
2557        }
2558    }
2559}
2560
2561// ─── Select Types ─────────────────────────────────────────────
2562
2563#[allow(clippy::too_many_lines)]
2564fn gen_select_types(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
2565    let select_name = format_ident!("{}Select", model.name);
2566    let partial_name = format_ident!("{}Partial", model.name);
2567    let _where_input = format_ident!("{}WhereInput", model.name);
2568    let _where_unique = format_ident!("{}WhereUniqueInput", model.name);
2569    let order_by_name = format_ident!("{}OrderByInput", model.name);
2570    let table_name = &model.db_name;
2571
2572    // Select struct fields: all bool, default false
2573    let select_fields: Vec<TokenStream> = scalar_fields
2574        .iter()
2575        .map(|f| {
2576            let name = format_ident!("{}", to_snake_case(&f.name));
2577            quote! { pub #name: bool }
2578        })
2579        .collect();
2580
2581    // Partial struct fields: all Option<T> with #[sqlx(default)]
2582    // For already-optional fields, don't double-wrap in Option
2583    let partial_fields: Vec<TokenStream> = scalar_fields
2584        .iter()
2585        .map(|f| {
2586            let name = format_ident!("{}", to_snake_case(&f.name));
2587            let db_name = &f.db_name;
2588            // Get the base type (non-optional version)
2589            let base_ty = rust_type_tokens(
2590                &Field {
2591                    is_optional: false,
2592                    ..(*f).clone()
2593                },
2594                ModuleDepth::TopLevel,
2595            );
2596            let rename = if db_name == &to_snake_case(&f.name) {
2597                quote! {}
2598            } else {
2599                quote! { #[sqlx(rename = #db_name)] }
2600            };
2601            // Always wrap in Option<T>, regardless of whether field was originally optional
2602            quote! { #[sqlx(default)] #rename pub #name: Option<#base_ty> }
2603        })
2604        .collect();
2605
2606    // build_select_columns: maps Select bools to column names
2607    let select_col_arms: Vec<TokenStream> = scalar_fields
2608        .iter()
2609        .map(|f| {
2610            let name = format_ident!("{}", to_snake_case(&f.name));
2611            let db_name = &f.db_name;
2612            let col_expr = format!(r#""{db_name}""#);
2613            quote! {
2614                if select.#name { cols.push(#col_expr); }
2615            }
2616        })
2617        .collect();
2618
2619    let select_sql_prefix = format!(r#"SELECT {{}} FROM "{table_name}" WHERE 1=1"#);
2620
2621    quote! {
2622        #[derive(Debug, Clone, Default)]
2623        pub struct #select_name {
2624            #(#select_fields,)*
2625        }
2626
2627        #[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
2628        #[sqlx(rename_all = "snake_case")]
2629        pub struct #partial_name {
2630            #(#partial_fields,)*
2631        }
2632
2633        fn build_select_columns(select: &#select_name) -> String {
2634            let mut cols = Vec::new();
2635            #(#select_col_arms)*
2636            if cols.is_empty() {
2637                "*".to_string()
2638            } else {
2639                cols.join(", ")
2640            }
2641        }
2642
2643        // ── FindManySelectQuery ──────────────────────────────────
2644
2645        pub struct FindManySelectQuery<'a> {
2646            client: &'a DatabaseClient,
2647            r#where: filter::#_where_input,
2648            order_by: Vec<order::#order_by_name>,
2649            skip: Option<i64>,
2650            take: Option<i64>,
2651            select: #select_name,
2652        }
2653
2654        impl<'a> FindManySelectQuery<'a> {
2655            pub fn order_by(mut self, order: order::#order_by_name) -> Self {
2656                self.order_by.push(order);
2657                self
2658            }
2659
2660            pub fn skip(mut self, n: i64) -> Self {
2661                self.skip = Some(n);
2662                self
2663            }
2664
2665            pub fn take(mut self, n: i64) -> Self {
2666                self.take = Some(n);
2667                self
2668            }
2669
2670            pub async fn exec(self) -> Result<Vec<#partial_name>, FerriormError> {
2671                let cols = build_select_columns(&self.select);
2672                let base_sql = format!(#select_sql_prefix, cols);
2673
2674                match self.client {
2675                    DatabaseClient::Postgres(_) => {
2676                        let qb = build_select_query::<sqlx::Postgres>(
2677                            &base_sql, &self.r#where, &self.order_by, self.take, self.skip,
2678                        );
2679                        self.client.fetch_all_pg(qb).await
2680                    }
2681                    DatabaseClient::Sqlite(_) => {
2682                        let qb = build_select_query::<sqlx::Sqlite>(
2683                            &base_sql, &self.r#where, &self.order_by, self.take, self.skip,
2684                        );
2685                        self.client.fetch_all_sqlite(qb).await
2686                    }
2687                }
2688            }
2689        }
2690
2691        // ── FindUniqueSelectQuery ────────────────────────────────
2692
2693        pub struct FindUniqueSelectQuery<'a> {
2694            client: &'a DatabaseClient,
2695            r#where: filter::#_where_unique,
2696            select: #select_name,
2697        }
2698
2699        impl<'a> FindUniqueSelectQuery<'a> {
2700            pub async fn exec(self) -> Result<Option<#partial_name>, FerriormError> {
2701                let cols = build_select_columns(&self.select);
2702                let base_sql = format!(#select_sql_prefix, cols);
2703
2704                match self.client {
2705                    DatabaseClient::Postgres(_) => {
2706                        let qb = build_unique_select_query::<sqlx::Postgres>(
2707                            &base_sql, &self.r#where,
2708                        );
2709                        self.client.fetch_optional_pg(qb).await
2710                    }
2711                    DatabaseClient::Sqlite(_) => {
2712                        let qb = build_unique_select_query::<sqlx::Sqlite>(
2713                            &base_sql, &self.r#where,
2714                        );
2715                        self.client.fetch_optional_sqlite(qb).await
2716                    }
2717                }
2718            }
2719        }
2720
2721        // ── FindFirstSelectQuery ─────────────────────────────────
2722
2723        pub struct FindFirstSelectQuery<'a> {
2724            client: &'a DatabaseClient,
2725            r#where: filter::#_where_input,
2726            order_by: Vec<order::#order_by_name>,
2727            select: #select_name,
2728        }
2729
2730        impl<'a> FindFirstSelectQuery<'a> {
2731            pub fn order_by(mut self, order: order::#order_by_name) -> Self {
2732                self.order_by.push(order);
2733                self
2734            }
2735
2736            pub async fn exec(self) -> Result<Option<#partial_name>, FerriormError> {
2737                let cols = build_select_columns(&self.select);
2738                let base_sql = format!(#select_sql_prefix, cols);
2739
2740                match self.client {
2741                    DatabaseClient::Postgres(_) => {
2742                        let qb = build_select_query::<sqlx::Postgres>(
2743                            &base_sql, &self.r#where, &self.order_by, Some(1), None,
2744                        );
2745                        self.client.fetch_optional_pg(qb).await
2746                    }
2747                    DatabaseClient::Sqlite(_) => {
2748                        let qb = build_select_query::<sqlx::Sqlite>(
2749                            &base_sql, &self.r#where, &self.order_by, Some(1), None,
2750                        );
2751                        self.client.fetch_optional_sqlite(qb).await
2752                    }
2753                }
2754            }
2755        }
2756    }
2757}