Skip to main content

ferriorm_codegen/
model.rs

1//! Generates per-model Rust modules (struct, filters, data inputs, ordering, CRUD).
2//!
3//! For each model in the schema, this module produces:
4//!
5//! - A **data struct** (e.g., `User`) with `sqlx::FromRow` and serde derives.
6//! - A **filter submodule** with `WhereInput` and `WhereUniqueInput` types.
7//! - A **data submodule** with `CreateInput` and `UpdateInput` types.
8//! - An **order submodule** with `OrderByInput`.
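//! - An **`Actions` struct** exposing `create`, `find_unique`, `find_first`,
//!   `find_many`, `update`, `update_first`, `delete`, `upsert`, `count`, and
//!   batch operations (`create_many`, `update_many`, `delete_many`), plus
//!   `aggregate` / `group_by` when the model has suitable fields.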
11//! - **Query builder structs** that chain filters, ordering, pagination, and
12//!   include clauses before calling `.exec()`.
13
14use ferriorm_core::schema::{Field, FieldKind, Model};
15use ferriorm_core::types::ScalarType;
16use ferriorm_core::utils::{to_pascal_case, to_snake_case};
17use proc_macro2::TokenStream;
18use quote::{format_ident, quote};
19
20use crate::rust_type::{ModuleDepth, filter_type_tokens, rust_type_tokens};
21
22/// Generate the complete module for a single model.
23#[must_use]
24pub fn generate_model_module(model: &Model) -> TokenStream {
25    let scalar_fields: Vec<&Field> = model.fields.iter().filter(|f| f.is_scalar()).collect();
26
27    let data_struct = gen_data_struct(model, &scalar_fields);
28    let filter_module = gen_filter_module(model, &scalar_fields);
29    let data_module = gen_data_module(model, &scalar_fields);
30    let order_module = gen_order_module(model, &scalar_fields);
31    let actions_struct = gen_actions(model, &scalar_fields);
32    let query_builders = gen_query_builders(model, &scalar_fields);
33    let aggregate_types = gen_aggregate_types(model, &scalar_fields);
34    let groupby_types = gen_groupby_types(model, &scalar_fields);
35    let select_types = gen_select_types(model, &scalar_fields);
36
37    quote! {
38        #![allow(unused_imports, dead_code, unused_variables, clippy::all, clippy::pedantic, clippy::nursery)]
39
40        use serde::{Deserialize, Serialize};
41        use ferriorm_runtime::prelude::*;
42        use ferriorm_runtime::prelude::sqlx;
43        use ferriorm_runtime::prelude::chrono;
44        use ferriorm_runtime::prelude::uuid;
45
46        #data_struct
47        #filter_module
48        #data_module
49        #order_module
50        #actions_struct
51        #query_builders
52        #aggregate_types
53        #groupby_types
54        #select_types
55    }
56}
57
58// ─── Data Struct ──────────────────────────────────────────────
59
60fn gen_data_struct(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
61    let struct_name = format_ident!("{}", model.name);
62    let table_name = &model.db_name;
63
64    let fields: Vec<TokenStream> = scalar_fields
65        .iter()
66        .map(|f| {
67            let name = format_ident!("{}", to_snake_case(&f.name));
68            let ty = rust_type_tokens(f, ModuleDepth::TopLevel);
69            let db_name = &f.db_name;
70            if db_name == &to_snake_case(&f.name) {
71                quote! { pub #name: #ty }
72            } else {
73                quote! { #[sqlx(rename = #db_name)] pub #name: #ty }
74            }
75        })
76        .collect();
77
78    quote! {
79        #[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
80        #[sqlx(rename_all = "snake_case")]
81        pub struct #struct_name {
82            #(#fields),*
83        }
84
85        impl #struct_name {
86            pub const TABLE_NAME: &'static str = #table_name;
87        }
88    }
89}
90
91// ─── Filter Module ────────────────────────────────────────────
92
93#[allow(clippy::too_many_lines)]
94fn gen_filter_module(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
95    let where_input = format_ident!("{}WhereInput", model.name);
96    let where_unique = format_ident!("{}WhereUniqueInput", model.name);
97
98    let where_fields: Vec<TokenStream> = scalar_fields
99        .iter()
100        .filter_map(|f| {
101            let filter_ty = filter_type_tokens(f, ModuleDepth::Nested)?;
102            let name = format_ident!("{}", to_snake_case(&f.name));
103            Some(quote! { pub #name: Option<#filter_ty> })
104        })
105        .collect();
106
107    let single_unique_variants: Vec<TokenStream> = scalar_fields
108        .iter()
109        .filter(|f| f.is_id || f.is_unique)
110        .map(|f| {
111            let variant = format_ident!("{}", to_pascal_case(&f.name));
112            let ty = rust_type_tokens(f, ModuleDepth::Nested);
113            quote! { #variant(#ty) }
114        })
115        .collect();
116
117    let compound_unique_variants: Vec<TokenStream> = model
118        .unique_constraints
119        .iter()
120        .map(|uc| {
121            let variant = format_ident!("{}", compound_variant_name(&uc.fields));
122            let struct_fields = compound_variant_fields(model, &uc.fields);
123            quote! { #variant { #(#struct_fields),* } }
124        })
125        .collect();
126
127    let unique_variants: Vec<TokenStream> = single_unique_variants
128        .into_iter()
129        .chain(compound_unique_variants)
130        .collect();
131
132    // Generate build_where for WhereInput
133    let db_bounds = collect_db_bounds(scalar_fields);
134    let where_arms = gen_where_arms(scalar_fields);
135    let unique_arms = gen_unique_where_arms(model, scalar_fields);
136    let conflict_target_arms = gen_conflict_target_arms(model, scalar_fields);
137    let first_conflict_col_arms = gen_first_conflict_col_arms(model, scalar_fields);
138
139    quote! {
140        pub mod filter {
141            use ferriorm_runtime::prelude::*;
142
143            #[derive(Debug, Clone, Default)]
144            pub struct #where_input {
145                #(#where_fields,)*
146                pub and: Option<Vec<#where_input>>,
147                pub or: Option<Vec<#where_input>>,
148                pub not: Option<Box<#where_input>>,
149            }
150
151            #[derive(Debug, Clone)]
152            pub enum #where_unique {
153                #(#unique_variants),*
154            }
155
156            impl #where_input {
157                pub(crate) fn build_where<'args, DB: sqlx::Database>(
158                    &self,
159                    qb: &mut sqlx::QueryBuilder<'args, DB>,
160                )
161                where
162                    #(#db_bounds,)*
163                {
164                    #(#where_arms)*
165
166                    if let Some(conditions) = &self.and {
167                        for c in conditions {
168                            c.build_where(qb);
169                        }
170                    }
171                    if let Some(conditions) = &self.or {
172                        if !conditions.is_empty() {
173                            qb.push(" AND (");
174                            for (i, c) in conditions.iter().enumerate() {
175                                if i > 0 { qb.push(" OR "); }
176                                qb.push("(1=1");
177                                c.build_where(qb);
178                                qb.push(")");
179                            }
180                            qb.push(")");
181                        }
182                    }
183                    if let Some(c) = &self.not {
184                        qb.push(" AND NOT (1=1");
185                        c.build_where(qb);
186                        qb.push(")");
187                    }
188                }
189            }
190
191            impl #where_unique {
192                pub(crate) fn build_where<'args, DB: sqlx::Database>(
193                    &self,
194                    qb: &mut sqlx::QueryBuilder<'args, DB>,
195                )
196                where
197                    #(#db_bounds,)*
198                {
199                    match self {
200                        #(#unique_arms)*
201                    }
202                }
203            }
204
205            impl #where_unique {
206                #[allow(dead_code)]
207                pub(crate) fn conflict_target(&self) -> &'static str {
208                    match self {
209                        #(#conflict_target_arms)*
210                    }
211                }
212
213                #[allow(dead_code)]
214                pub(crate) fn first_conflict_col(&self) -> &'static str {
215                    match self {
216                        #(#first_conflict_col_arms)*
217                    }
218                }
219            }
220        }
221    }
222}
223
224/// Collect the sqlx type bounds needed for all scalar types used by the model.
225fn collect_db_bounds(scalar_fields: &[&Field]) -> Vec<TokenStream> {
226    let mut seen = std::collections::HashSet::new();
227    let mut bounds = Vec::new();
228
229    // Always need i64 for LIMIT/OFFSET
230    seen.insert("i64");
231    bounds.push(quote! { i64: sqlx::Type<DB> + for<'e> sqlx::Encode<'e, DB> });
232
233    for f in scalar_fields {
234        match &f.field_type {
235            FieldKind::Scalar(scalar) => {
236                let key = scalar.rust_type();
237                if seen.insert(key)
238                    && let Some(ty) = scalar_bound_tokens(scalar)
239                {
240                    bounds.push(quote! { #ty: sqlx::Type<DB> + for<'e> sqlx::Encode<'e, DB> });
241                    // Also add Option<T> bound for nullable field support
242                    bounds.push(
243                        quote! { Option<#ty>: sqlx::Type<DB> + for<'e> sqlx::Encode<'e, DB> },
244                    );
245                }
246            }
247            FieldKind::Enum(_) | FieldKind::Model(_) => {}
248        }
249    }
250
251    bounds
252}
253
254fn scalar_bound_tokens(scalar: &ScalarType) -> Option<TokenStream> {
255    match scalar {
256        ScalarType::String => Some(quote! { String }),
257        ScalarType::Int => Some(quote! { i32 }),
258        ScalarType::BigInt => Some(quote! { i64 }),
259        ScalarType::Float => Some(quote! { f64 }),
260        ScalarType::Boolean => Some(quote! { bool }),
261        ScalarType::DateTime => Some(quote! { chrono::DateTime<chrono::Utc> }),
262        ScalarType::Bytes => Some(quote! { Vec<u8> }),
263        ScalarType::Json | ScalarType::Decimal => None,
264    }
265}
266
267/// Generate where-clause arms for each filterable scalar field.
268fn gen_where_arms(scalar_fields: &[&Field]) -> Vec<TokenStream> {
269    scalar_fields
270        .iter()
271        .filter_map(|f| {
272            // Only generate filter arms for scalar types (skip enums for now)
273            if !matches!(&f.field_type, FieldKind::Scalar(_)) {
274                return None;
275            }
276            let field_ident = format_ident!("{}", to_snake_case(&f.name));
277            let db_name = &f.db_name;
278            let is_string = matches!(&f.field_type, FieldKind::Scalar(ScalarType::String));
279            let is_comparable = matches!(
280                &f.field_type,
281                FieldKind::Scalar(
282                    ScalarType::Int | ScalarType::BigInt | ScalarType::Float | ScalarType::DateTime
283                )
284            );
285
286            let mut arms = vec![];
287
288            if f.is_optional {
289                // Nullable filter: `equals`/`not` are `Option<Option<T>>`.
290                // `Some(None)` means IS NULL / IS NOT NULL; `Some(Some(v))`
291                // is the ordinary `= ?` / `!= ?` comparison.
292                arms.push(quote! {
293                    if let Some(v) = &filter.equals {
294                        match v {
295                            None => {
296                                qb.push(concat!(" AND \"", #db_name, "\" IS NULL"));
297                            }
298                            Some(inner) => {
299                                qb.push(concat!(" AND \"", #db_name, "\" = "));
300                                qb.push_bind(inner.clone());
301                            }
302                        }
303                    }
304                    if let Some(v) = &filter.not {
305                        match v {
306                            None => {
307                                qb.push(concat!(" AND \"", #db_name, "\" IS NOT NULL"));
308                            }
309                            Some(inner) => {
310                                qb.push(concat!(" AND \"", #db_name, "\" != "));
311                                qb.push_bind(inner.clone());
312                            }
313                        }
314                    }
315                });
316            } else {
317                arms.push(quote! {
318                    if let Some(v) = &filter.equals {
319                        qb.push(concat!(" AND \"", #db_name, "\" = "));
320                        qb.push_bind(v.clone());
321                    }
322                    if let Some(v) = &filter.not {
323                        qb.push(concat!(" AND \"", #db_name, "\" != "));
324                        qb.push_bind(v.clone());
325                    }
326                });
327            }
328
329            if is_string {
330                // `like_escape` quotes %, _, and \ in user input so they
331                // match themselves; `ESCAPE '\\'` tells the DB to treat
332                // backslash as the escape character. Without this the
333                // query `contains: "100%_safe"` would also match
334                // arbitrary strings like `100Xsafe`.
335                arms.push(quote! {
336                    if let Some(v) = &filter.contains {
337                        qb.push(concat!(" AND \"", #db_name, "\" LIKE "));
338                        qb.push_bind(format!("%{}%", ferriorm_runtime::filter::like_escape(v)));
339                        qb.push(" ESCAPE '\\'");
340                    }
341                    if let Some(v) = &filter.starts_with {
342                        qb.push(concat!(" AND \"", #db_name, "\" LIKE "));
343                        qb.push_bind(format!("{}%", ferriorm_runtime::filter::like_escape(v)));
344                        qb.push(" ESCAPE '\\'");
345                    }
346                    if let Some(v) = &filter.ends_with {
347                        qb.push(concat!(" AND \"", #db_name, "\" LIKE "));
348                        qb.push_bind(format!("%{}", ferriorm_runtime::filter::like_escape(v)));
349                        qb.push(" ESCAPE '\\'");
350                    }
351                });
352            }
353
354            if is_comparable {
355                arms.push(quote! {
356                    if let Some(v) = &filter.gt {
357                        qb.push(concat!(" AND \"", #db_name, "\" > "));
358                        qb.push_bind(v.clone());
359                    }
360                    if let Some(v) = &filter.gte {
361                        qb.push(concat!(" AND \"", #db_name, "\" >= "));
362                        qb.push_bind(v.clone());
363                    }
364                    if let Some(v) = &filter.lt {
365                        qb.push(concat!(" AND \"", #db_name, "\" < "));
366                        qb.push_bind(v.clone());
367                    }
368                    if let Some(v) = &filter.lte {
369                        qb.push(concat!(" AND \"", #db_name, "\" <= "));
370                        qb.push_bind(v.clone());
371                    }
372                });
373            }
374
375            Some(quote! {
376                if let Some(filter) = &self.#field_ident {
377                    #(#arms)*
378                }
379            })
380        })
381        .collect()
382}
383
384fn gen_unique_where_arms(model: &Model, scalar_fields: &[&Field]) -> Vec<TokenStream> {
385    let mut arms: Vec<TokenStream> = scalar_fields
386        .iter()
387        .filter(|f| f.is_id || f.is_unique)
388        .map(|f| {
389            let variant = format_ident!("{}", to_pascal_case(&f.name));
390            let db_name = &f.db_name;
391            quote! {
392                Self::#variant(v) => {
393                    qb.push(concat!(" AND \"", #db_name, "\" = "));
394                    qb.push_bind(v.clone());
395                }
396            }
397        })
398        .collect();
399
400    for uc in &model.unique_constraints {
401        let variant = format_ident!("{}", compound_variant_name(&uc.fields));
402        let idents: Vec<_> = uc
403            .fields
404            .iter()
405            .map(|name| format_ident!("{}", to_snake_case(name)))
406            .collect();
407        let binds: Vec<TokenStream> = uc
408            .fields
409            .iter()
410            .map(|name| {
411                let ident = format_ident!("{}", to_snake_case(name));
412                let db_name = resolve_db_name(model, name);
413                quote! {
414                    qb.push(concat!(" AND \"", #db_name, "\" = "));
415                    qb.push_bind(#ident.clone());
416                }
417            })
418            .collect();
419        arms.push(quote! {
420            Self::#variant { #(#idents),* } => {
421                #(#binds)*
422            }
423        });
424    }
425
426    arms
427}
428
429fn gen_conflict_target_arms(model: &Model, scalar_fields: &[&Field]) -> Vec<TokenStream> {
430    let mut arms: Vec<TokenStream> = scalar_fields
431        .iter()
432        .filter(|f| f.is_id || f.is_unique)
433        .map(|f| {
434            let variant = format_ident!("{}", to_pascal_case(&f.name));
435            let target = format!("(\"{}\")", f.db_name);
436            quote! { Self::#variant(_) => #target, }
437        })
438        .collect();
439
440    for uc in &model.unique_constraints {
441        let variant = format_ident!("{}", compound_variant_name(&uc.fields));
442        let cols: Vec<String> = uc
443            .fields
444            .iter()
445            .map(|n| format!("\"{}\"", resolve_db_name(model, n)))
446            .collect();
447        let target = format!("({})", cols.join(", "));
448        arms.push(quote! { Self::#variant { .. } => #target, });
449    }
450
451    arms
452}
453
454fn gen_first_conflict_col_arms(model: &Model, scalar_fields: &[&Field]) -> Vec<TokenStream> {
455    let mut arms: Vec<TokenStream> = scalar_fields
456        .iter()
457        .filter(|f| f.is_id || f.is_unique)
458        .map(|f| {
459            let variant = format_ident!("{}", to_pascal_case(&f.name));
460            let col = format!("\"{}\"", f.db_name);
461            quote! { Self::#variant(_) => #col, }
462        })
463        .collect();
464
465    for uc in &model.unique_constraints {
466        let variant = format_ident!("{}", compound_variant_name(&uc.fields));
467        let first = uc
468            .fields
469            .first()
470            .map_or_else(String::new, |n| resolve_db_name(model, n));
471        let col = format!("\"{first}\"");
472        arms.push(quote! { Self::#variant { .. } => #col, });
473    }
474
475    arms
476}
477
478/// `PascalCase` concatenation of the field names.
479fn compound_variant_name(fields: &[String]) -> String {
480    fields.iter().map(|f| to_pascal_case(f)).collect()
481}
482
483/// Struct-field tokens for a compound variant: `ident: Ty` per field.
484/// Enum struct-variant fields inherit the enum's visibility; no `pub` here.
485fn compound_variant_fields(model: &Model, fields: &[String]) -> Vec<TokenStream> {
486    fields
487        .iter()
488        .filter_map(|field_name| {
489            let field = model.fields.iter().find(|f| f.name == *field_name)?;
490            let ident = format_ident!("{}", to_snake_case(field_name));
491            let ty = rust_type_tokens(field, ModuleDepth::Nested);
492            Some(quote! { #ident: #ty })
493        })
494        .collect()
495}
496
497/// Resolve a schema field name to its `db_name`, falling back to `snake_case`.
498fn resolve_db_name(model: &Model, field_name: &str) -> String {
499    model
500        .fields
501        .iter()
502        .find(|f| f.name == field_name)
503        .map_or_else(|| to_snake_case(field_name), |f| f.db_name.clone())
504}
505
506// ─── Data Module ──────────────────────────────────────────────
507
508fn gen_data_module(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
509    let create_name = format_ident!("{}CreateInput", model.name);
510    let update_name = format_ident!("{}UpdateInput", model.name);
511
512    let required_fields: Vec<TokenStream> = scalar_fields
513        .iter()
514        .filter(|f| !f.has_default() && !f.is_updated_at)
515        .map(|f| {
516            let name = format_ident!("{}", to_snake_case(&f.name));
517            let ty = rust_type_tokens(f, ModuleDepth::Nested);
518            quote! { pub #name: #ty }
519        })
520        .collect();
521
522    let optional_fields: Vec<TokenStream> = scalar_fields
523        .iter()
524        .filter(|f| f.has_default() && !f.is_updated_at)
525        .map(|f| {
526            let name = format_ident!("{}", to_snake_case(&f.name));
527            let base_ty = rust_type_tokens(f, ModuleDepth::Nested);
528            quote! { pub #name: Option<#base_ty> }
529        })
530        .collect();
531
532    let update_fields: Vec<TokenStream> = scalar_fields
533        .iter()
534        .filter(|f| !f.is_id && !f.is_updated_at)
535        .map(|f| {
536            let name = format_ident!("{}", to_snake_case(&f.name));
537            let ty = rust_type_tokens(f, ModuleDepth::Nested);
538            quote! { pub #name: Option<SetValue<#ty>> }
539        })
540        .collect();
541
542    quote! {
543        pub mod data {
544            use ferriorm_runtime::prelude::*;
545
546            #[derive(Debug, Clone)]
547            pub struct #create_name {
548                #(#required_fields,)*
549                #(#optional_fields,)*
550            }
551
552            /// Update payload. Each field is `Option<SetValue<T>>`:
553            /// `None` leaves the column untouched (omitted from the SET clause),
554            /// `Some(SetValue::Set(v))` writes `v`.
555            #[derive(Debug, Clone, Default)]
556            pub struct #update_name {
557                #(#update_fields,)*
558            }
559        }
560    }
561}
562
563// ─── Order Module ─────────────────────────────────────────────
564
565fn gen_order_module(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
566    let order_name = format_ident!("{}OrderByInput", model.name);
567
568    let variants: Vec<TokenStream> = scalar_fields
569        .iter()
570        .map(|f| {
571            let variant = format_ident!("{}", to_pascal_case(&f.name));
572            quote! { #variant(SortOrder) }
573        })
574        .collect();
575
576    let order_arms: Vec<TokenStream> = scalar_fields
577        .iter()
578        .map(|f| {
579            let variant = format_ident!("{}", to_pascal_case(&f.name));
580            let db_name = &f.db_name;
581            quote! {
582                Self::#variant(order) => {
583                    qb.push(concat!("\"", #db_name, "\" "));
584                    qb.push(order.as_sql());
585                }
586            }
587        })
588        .collect();
589
590    quote! {
591        pub mod order {
592            use ferriorm_runtime::prelude::*;
593
594            #[derive(Debug, Clone)]
595            pub enum #order_name {
596                #(#variants),*
597            }
598
599            impl #order_name {
600                pub(crate) fn build_order_by<'args, DB: sqlx::Database>(
601                    &self,
602                    qb: &mut sqlx::QueryBuilder<'args, DB>,
603                ) {
604                    match self {
605                        #(#order_arms)*
606                    }
607                }
608            }
609        }
610    }
611}
612
613// ─── Actions ──────────────────────────────────────────────────
614
615fn gen_actions(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
616    let _model_ident = format_ident!("{}", model.name);
617    let actions_name = format_ident!("{}Actions", model.name);
618    let where_input = format_ident!("{}WhereInput", model.name);
619    let where_unique = format_ident!("{}WhereUniqueInput", model.name);
620    let create_input = format_ident!("{}CreateInput", model.name);
621    let update_input = format_ident!("{}UpdateInput", model.name);
622    let _order_by = format_ident!("{}OrderByInput", model.name);
623
624    // Only generate aggregate() if there are aggregatable fields
625    let has_agg_fields = scalar_fields.iter().any(|f| {
626        matches!(
627            &f.field_type,
628            FieldKind::Scalar(
629                ScalarType::Int | ScalarType::BigInt | ScalarType::Float | ScalarType::DateTime
630            )
631        )
632    });
633    let aggregate_method = if has_agg_fields {
634        quote! {
635            pub fn aggregate(&self, r#where: filter::#where_input) -> AggregateQuery<'a> {
636                AggregateQuery { client: self.client, r#where, ops: vec![] }
637            }
638        }
639    } else {
640        quote! {}
641    };
642
643    // Only generate group_by() if there are groupable fields
644    let has_group_fields = scalar_fields.iter().any(|f| is_groupable(f));
645    let groupby_field_name = format_ident!("{}GroupByField", model.name);
646    let group_by_method = if has_group_fields {
647        quote! {
648            pub fn group_by(&self, keys: Vec<#groupby_field_name>) -> GroupByQuery<'a> {
649                GroupByQuery {
650                    client: self.client,
651                    r#where: filter::#where_input::default(),
652                    group_keys: keys,
653                    agg_ops: vec![],
654                    count: false,
655                    having: None,
656                }
657            }
658        }
659    } else {
660        quote! {}
661    };
662
663    quote! {
664        pub struct #actions_name<'a> {
665            client: &'a DatabaseClient,
666        }
667
668        impl<'a> #actions_name<'a> {
669            pub fn new(client: &'a DatabaseClient) -> Self { Self { client } }
670
671            pub fn find_unique(&self, r#where: filter::#where_unique) -> FindUniqueQuery<'a> {
672                FindUniqueQuery { client: self.client, r#where }
673            }
674
675            pub fn find_first(&self, r#where: filter::#where_input) -> FindFirstQuery<'a> {
676                FindFirstQuery { client: self.client, r#where, order_by: vec![] }
677            }
678
679            pub fn find_many(&self, r#where: filter::#where_input) -> FindManyQuery<'a> {
680                FindManyQuery { client: self.client, r#where, order_by: vec![], skip: None, take: None }
681            }
682
683            pub fn create(&self, data: data::#create_input) -> CreateQuery<'a> {
684                CreateQuery { client: self.client, data }
685            }
686
687            pub fn update(&self, r#where: filter::#where_unique, data: data::#update_input) -> UpdateQuery<'a> {
688                UpdateQuery { client: self.client, r#where, data }
689            }
690
691            /// Like [`update`], but accepts a full `WhereInput` so additional
692            /// predicates (e.g., `status = 'pending'`) can be used for
693            /// compare-and-swap updates. Returns `Ok(None)` if no row matched.
694            pub fn update_first(&self, r#where: filter::#where_input, data: data::#update_input) -> UpdateFirstQuery<'a> {
695                UpdateFirstQuery { client: self.client, r#where, data }
696            }
697
698            pub fn delete(&self, r#where: filter::#where_unique) -> DeleteQuery<'a> {
699                DeleteQuery { client: self.client, r#where }
700            }
701
702            pub fn count(&self, r#where: filter::#where_input) -> CountQuery<'a> {
703                CountQuery { client: self.client, r#where }
704            }
705
706            pub fn create_many(&self, data: Vec<data::#create_input>) -> CreateManyQuery<'a> {
707                CreateManyQuery { client: self.client, data }
708            }
709
710            pub fn update_many(&self, r#where: filter::#where_input, data: data::#update_input) -> UpdateManyQuery<'a> {
711                UpdateManyQuery { client: self.client, r#where, data }
712            }
713
714            pub fn delete_many(&self, r#where: filter::#where_input) -> DeleteManyQuery<'a> {
715                DeleteManyQuery { client: self.client, r#where }
716            }
717
718            pub fn upsert(
719                &self,
720                r#where: filter::#where_unique,
721                create: data::#create_input,
722                update: data::#update_input,
723            ) -> UpsertQuery<'a> {
724                UpsertQuery { client: self.client, r#where, create, update }
725            }
726
727            #aggregate_method
728
729            #group_by_method
730        }
731    }
732}
733
734// ─── Query Builders with exec() ──────────────────────────────
735
736#[allow(clippy::too_many_lines)]
737fn gen_query_builders(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
738    let model_ident = format_ident!("{}", model.name);
739    let table_name = &model.db_name;
740    let _where_input = format_ident!("{}WhereInput", model.name);
741    let _where_unique = format_ident!("{}WhereUniqueInput", model.name);
742    let _create_input = format_ident!("{}CreateInput", model.name);
743    let _update_input = format_ident!("{}UpdateInput", model.name);
744    let order_by = format_ident!("{}OrderByInput", model.name);
745    let _select_struct = format_ident!("{}Select", model.name);
746    let _partial_struct = format_ident!("{}Partial", model.name);
747    let _aggregate_result = format_ident!("{}AggregateResult", model.name);
748    let _aggregate_field = format_ident!("{}AggregateField", model.name);
749    let db_bounds = collect_db_bounds(scalar_fields);
750
751    let select_sql = format!(r#"SELECT * FROM "{table_name}" WHERE 1=1"#);
752    let count_sql = format!(r#"SELECT COUNT(*) as "count" FROM "{table_name}" WHERE 1=1"#);
753    let delete_sql = format!(r#"DELETE FROM "{table_name}" WHERE 1=1"#);
754
755    let insert_code = gen_insert_code(model, scalar_fields, table_name);
756    let insert_ignore_code = gen_insert_ignore_code(model, scalar_fields, table_name);
757    let update_code = gen_update_code(model, scalar_fields, table_name);
758    let update_first_code = gen_update_first_code(model, scalar_fields, table_name);
759    let update_many_code = gen_update_many_code(model, scalar_fields, table_name);
760    let upsert_code = gen_upsert_code(model, scalar_fields, table_name);
761
762    quote! {
763        // ── Generic helper: build ORDER BY clause ──────────────
764        fn build_order_by<'args, DB: sqlx::Database>(
765            orders: &[order::#order_by],
766            qb: &mut sqlx::QueryBuilder<'args, DB>,
767        ) {
768            if !orders.is_empty() {
769                qb.push(" ORDER BY ");
770                for (i, ob) in orders.iter().enumerate() {
771                    if i > 0 { qb.push(", "); }
772                    ob.build_order_by(qb);
773                }
774            }
775        }
776
777        // ── Generic helper: build a SELECT query ───────────────
778        fn build_select_query<'args, DB: sqlx::Database>(
779            base_sql: &str,
780            where_input: &filter::#_where_input,
781            orders: &[order::#order_by],
782            take: Option<i64>,
783            skip: Option<i64>,
784        ) -> sqlx::QueryBuilder<'args, DB>
785        where
786            #(#db_bounds,)*
787        {
788            let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
789            where_input.build_where(&mut qb);
790            build_order_by(orders, &mut qb);
791            if let Some(take) = take {
792                qb.push(" LIMIT ");
793                qb.push_bind(take);
794            }
795            if let Some(skip) = skip {
796                qb.push(" OFFSET ");
797                qb.push_bind(skip);
798            }
799            qb
800        }
801
802        // ── Generic helper: build a SELECT query for unique lookup ──
803        fn build_unique_select_query<'args, DB: sqlx::Database>(
804            base_sql: &str,
805            where_unique: &filter::#_where_unique,
806        ) -> sqlx::QueryBuilder<'args, DB>
807        where
808            #(#db_bounds,)*
809        {
810            let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
811            where_unique.build_where(&mut qb);
812            qb.push(" LIMIT 1");
813            qb
814        }
815
816        // ── Generic helper: build a DELETE-returning query ─────
817        fn build_delete_query<'args, DB: sqlx::Database>(
818            base_sql: &str,
819            where_unique: &filter::#_where_unique,
820        ) -> sqlx::QueryBuilder<'args, DB>
821        where
822            #(#db_bounds,)*
823        {
824            let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
825            where_unique.build_where(&mut qb);
826            qb.push(" RETURNING *");
827            qb
828        }
829
830        // ── Generic helper: build a COUNT query ────────────────
831        fn build_count_query<'args, DB: sqlx::Database>(
832            base_sql: &str,
833            where_input: &filter::#_where_input,
834        ) -> sqlx::QueryBuilder<'args, DB>
835        where
836            #(#db_bounds,)*
837        {
838            let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
839            where_input.build_where(&mut qb);
840            qb
841        }
842
843        // ── Generic helper: build a DELETE-many query ──────────
844        fn build_delete_many_query<'args, DB: sqlx::Database>(
845            base_sql: &str,
846            where_input: &filter::#_where_input,
847        ) -> sqlx::QueryBuilder<'args, DB>
848        where
849            #(#db_bounds,)*
850        {
851            let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
852            where_input.build_where(&mut qb);
853            qb
854        }
855
856        pub struct FindUniqueQuery<'a> {
857            client: &'a DatabaseClient,
858            r#where: filter::#_where_unique,
859        }
860
861        impl<'a> FindUniqueQuery<'a> {
862            pub fn select(self, select: #_select_struct) -> FindUniqueSelectQuery<'a> {
863                FindUniqueSelectQuery { client: self.client, r#where: self.r#where, select }
864            }
865
866            pub async fn exec(self) -> Result<Option<#model_ident>, FerriormError> {
867                match self.client {
868                    DatabaseClient::Postgres(_) => {
869                        let qb = build_unique_select_query::<sqlx::Postgres>(#select_sql, &self.r#where);
870                        self.client.fetch_optional_pg(qb).await
871                    }
872                    DatabaseClient::Sqlite(_) => {
873                        let qb = build_unique_select_query::<sqlx::Sqlite>(#select_sql, &self.r#where);
874                        self.client.fetch_optional_sqlite(qb).await
875                    }
876                }
877            }
878        }
879
880        pub struct FindFirstQuery<'a> {
881            client: &'a DatabaseClient,
882            r#where: filter::#_where_input,
883            order_by: Vec<order::#order_by>,
884        }
885
886        impl<'a> FindFirstQuery<'a> {
887            pub fn order_by(mut self, order: order::#order_by) -> Self {
888                self.order_by.push(order);
889                self
890            }
891
892            pub fn select(self, select: #_select_struct) -> FindFirstSelectQuery<'a> {
893                FindFirstSelectQuery { client: self.client, r#where: self.r#where, order_by: self.order_by, select }
894            }
895
896            pub async fn exec(self) -> Result<Option<#model_ident>, FerriormError> {
897                match self.client {
898                    DatabaseClient::Postgres(_) => {
899                        let qb = build_select_query::<sqlx::Postgres>(#select_sql, &self.r#where, &self.order_by, Some(1), None);
900                        self.client.fetch_optional_pg(qb).await
901                    }
902                    DatabaseClient::Sqlite(_) => {
903                        let qb = build_select_query::<sqlx::Sqlite>(#select_sql, &self.r#where, &self.order_by, Some(1), None);
904                        self.client.fetch_optional_sqlite(qb).await
905                    }
906                }
907            }
908        }
909
910        pub struct FindManyQuery<'a> {
911            client: &'a DatabaseClient,
912            r#where: filter::#_where_input,
913            order_by: Vec<order::#order_by>,
914            skip: Option<i64>,
915            take: Option<i64>,
916        }
917
918        impl<'a> FindManyQuery<'a> {
919            pub fn order_by(mut self, order: order::#order_by) -> Self {
920                self.order_by.push(order);
921                self
922            }
923
924            pub fn skip(mut self, n: i64) -> Self {
925                self.skip = Some(n);
926                self
927            }
928
929            pub fn take(mut self, n: i64) -> Self {
930                self.take = Some(n);
931                self
932            }
933
934            pub fn select(self, select: #_select_struct) -> FindManySelectQuery<'a> {
935                FindManySelectQuery {
936                    client: self.client,
937                    r#where: self.r#where,
938                    order_by: self.order_by,
939                    skip: self.skip,
940                    take: self.take,
941                    select,
942                }
943            }
944
945            pub async fn exec(self) -> Result<Vec<#model_ident>, FerriormError> {
946                match self.client {
947                    DatabaseClient::Postgres(_) => {
948                        let qb = build_select_query::<sqlx::Postgres>(#select_sql, &self.r#where, &self.order_by, self.take, self.skip);
949                        self.client.fetch_all_pg(qb).await
950                    }
951                    DatabaseClient::Sqlite(_) => {
952                        let qb = build_select_query::<sqlx::Sqlite>(#select_sql, &self.r#where, &self.order_by, self.take, self.skip);
953                        self.client.fetch_all_sqlite(qb).await
954                    }
955                }
956            }
957        }
958
959        pub struct CreateQuery<'a> {
960            client: &'a DatabaseClient,
961            data: data::#_create_input,
962        }
963
964        impl<'a> CreateQuery<'a> {
965            pub async fn exec(self) -> Result<#model_ident, FerriormError> {
966                let client = self.client;
967                #insert_code
968            }
969
970            /// Switch the insert into "ignore on conflict" mode:
971            /// PostgreSQL uses `ON CONFLICT DO NOTHING`, SQLite uses `INSERT OR IGNORE`.
972            /// Returns `Ok(None)` when a conflict suppressed the insert.
973            pub fn on_conflict_ignore(self) -> CreateIgnoreQuery<'a> {
974                CreateIgnoreQuery { client: self.client, data: self.data }
975            }
976        }
977
978        pub struct CreateIgnoreQuery<'a> {
979            client: &'a DatabaseClient,
980            data: data::#_create_input,
981        }
982
983        impl<'a> CreateIgnoreQuery<'a> {
984            pub async fn exec(self) -> Result<Option<#model_ident>, FerriormError> {
985                let client = self.client;
986                #insert_ignore_code
987            }
988        }
989
990        pub struct UpdateQuery<'a> {
991            client: &'a DatabaseClient,
992            r#where: filter::#_where_unique,
993            data: data::#_update_input,
994        }
995
996        impl<'a> UpdateQuery<'a> {
997            pub async fn exec(self) -> Result<#model_ident, FerriormError> {
998                let client = self.client;
999                #update_code
1000            }
1001        }
1002
1003        pub struct UpdateFirstQuery<'a> {
1004            client: &'a DatabaseClient,
1005            r#where: filter::#_where_input,
1006            data: data::#_update_input,
1007        }
1008
1009        impl<'a> UpdateFirstQuery<'a> {
1010            pub async fn exec(self) -> Result<Option<#model_ident>, FerriormError> {
1011                let client = self.client;
1012                #update_first_code
1013            }
1014        }
1015
1016        pub struct DeleteQuery<'a> {
1017            client: &'a DatabaseClient,
1018            r#where: filter::#_where_unique,
1019        }
1020
1021        impl<'a> DeleteQuery<'a> {
1022            pub async fn exec(self) -> Result<#model_ident, FerriormError> {
1023                match self.client {
1024                    DatabaseClient::Postgres(_) => {
1025                        let qb = build_delete_query::<sqlx::Postgres>(#delete_sql, &self.r#where);
1026                        self.client.fetch_one_pg(qb).await
1027                    }
1028                    DatabaseClient::Sqlite(_) => {
1029                        let qb = build_delete_query::<sqlx::Sqlite>(#delete_sql, &self.r#where);
1030                        self.client.fetch_one_sqlite(qb).await
1031                    }
1032                }
1033            }
1034        }
1035
1036        #[derive(sqlx::FromRow)]
1037        struct CountResult { count: i64 }
1038
1039        pub struct CountQuery<'a> {
1040            client: &'a DatabaseClient,
1041            r#where: filter::#_where_input,
1042        }
1043
1044        impl<'a> CountQuery<'a> {
1045            pub async fn exec(self) -> Result<i64, FerriormError> {
1046                let row: CountResult = match self.client {
1047                    DatabaseClient::Postgres(_) => {
1048                        let qb = build_count_query::<sqlx::Postgres>(#count_sql, &self.r#where);
1049                        self.client.fetch_one_pg(qb).await?
1050                    }
1051                    DatabaseClient::Sqlite(_) => {
1052                        let qb = build_count_query::<sqlx::Sqlite>(#count_sql, &self.r#where);
1053                        self.client.fetch_one_sqlite(qb).await?
1054                    }
1055                };
1056                Ok(row.count)
1057            }
1058        }
1059
1060        pub struct CreateManyQuery<'a> {
1061            client: &'a DatabaseClient,
1062            data: Vec<data::#_create_input>,
1063        }
1064
1065        impl<'a> CreateManyQuery<'a> {
1066            pub async fn exec(self) -> Result<u64, FerriormError> {
1067                if self.data.is_empty() { return Ok(0); }
1068                let count = self.data.len() as u64;
1069                for item in self.data {
1070                    CreateQuery { client: self.client, data: item }.exec().await?;
1071                }
1072                Ok(count)
1073            }
1074        }
1075
1076        pub struct UpdateManyQuery<'a> {
1077            client: &'a DatabaseClient,
1078            r#where: filter::#_where_input,
1079            data: data::#_update_input,
1080        }
1081
1082        impl<'a> UpdateManyQuery<'a> {
1083            pub async fn exec(self) -> Result<u64, FerriormError> {
1084                let client = self.client;
1085                #update_many_code
1086            }
1087        }
1088
1089        pub struct UpsertQuery<'a> {
1090            client: &'a DatabaseClient,
1091            r#where: filter::#_where_unique,
1092            create: data::#_create_input,
1093            update: data::#_update_input,
1094        }
1095
1096        impl<'a> UpsertQuery<'a> {
1097            pub async fn exec(self) -> Result<#model_ident, FerriormError> {
1098                let client = self.client;
1099                #upsert_code
1100            }
1101        }
1102
1103        pub struct DeleteManyQuery<'a> {
1104            client: &'a DatabaseClient,
1105            r#where: filter::#_where_input,
1106        }
1107
1108        impl<'a> DeleteManyQuery<'a> {
1109            pub async fn exec(self) -> Result<u64, FerriormError> {
1110                match self.client {
1111                    DatabaseClient::Postgres(_) => {
1112                        let qb = build_delete_many_query::<sqlx::Postgres>(#delete_sql, &self.r#where);
1113                        self.client.execute_pg(qb).await
1114                    }
1115                    DatabaseClient::Sqlite(_) => {
1116                        let qb = build_delete_many_query::<sqlx::Sqlite>(#delete_sql, &self.r#where);
1117                        self.client.execute_sqlite(qb).await
1118                    }
1119                }
1120            }
1121        }
1122    }
1123}
1124
1125// ─── INSERT code generation ───────────────────────────────────
1126
1127fn gen_insert_code(model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
1128    let _model_ident = format_ident!("{}", model.name);
1129
1130    // Required columns: scalar, no default, not @updatedAt
1131    let required: Vec<&Field> = scalar_fields
1132        .iter()
1133        .copied()
1134        .filter(|f| !f.has_default() && !f.is_updated_at)
1135        .collect();
1136
1137    // Optional columns: have default (can be overridden), not @updatedAt
1138    let optional: Vec<&Field> = scalar_fields
1139        .iter()
1140        .copied()
1141        .filter(|f| f.has_default() && !f.is_updated_at)
1142        .collect();
1143
1144    // @updatedAt columns: always set to now()
1145    let updated_at: Vec<&Field> = scalar_fields
1146        .iter()
1147        .copied()
1148        .filter(|f| f.is_updated_at)
1149        .collect();
1150
1151    // Build column names and bind values
1152    let mut col_pushes = vec![];
1153    let mut val_pushes = vec![];
1154
1155    // Required fields — always included
1156    for f in &required {
1157        let db_name = &f.db_name;
1158        let field_ident = format_ident!("{}", to_snake_case(&f.name));
1159        col_pushes.push(quote! { cols.push(#db_name); });
1160        val_pushes.push(quote! { sep.push_bind(self.data.#field_ident); });
1161    }
1162
1163    // Optional fields — resolve defaults in Rust
1164    for f in &optional {
1165        let db_name = &f.db_name;
1166        let field_ident = format_ident!("{}", to_snake_case(&f.name));
1167        if is_autoincrement(f) {
1168            // Autoincrement: if caller passed None, omit the column entirely
1169            // so the DB assigns the next sequence value. Binding a literal 0
1170            // would collide on the second insert.
1171            col_pushes.push(quote! {
1172                if self.data.#field_ident.is_some() { cols.push(#db_name); }
1173            });
1174            val_pushes.push(quote! {
1175                if let Some(val) = self.data.#field_ident {
1176                    sep.push_bind(val);
1177                }
1178            });
1179        } else {
1180            let default_expr = gen_default_expr(f, &f.field_type);
1181            col_pushes.push(quote! { cols.push(#db_name); });
1182            val_pushes.push(quote! {
1183                let val = self.data.#field_ident.unwrap_or_else(|| #default_expr);
1184                sep.push_bind(val);
1185            });
1186        }
1187    }
1188
1189    // @updatedAt fields
1190    for f in &updated_at {
1191        let db_name = &f.db_name;
1192        col_pushes.push(quote! { cols.push(#db_name); });
1193        val_pushes.push(quote! { sep.push_bind(chrono::Utc::now()); });
1194    }
1195
1196    let insert_start = format!(r#"INSERT INTO "{table_name}""#);
1197
1198    // The insert_body macro avoids duplicating the column/value building logic
1199    // for each database backend. It captures `self` by reference.
1200    quote! {
1201        // Helper to build the INSERT query for any DB backend
1202        macro_rules! build_insert {
1203            ($qb_type:ty) => {{
1204                let mut cols: Vec<&str> = Vec::new();
1205                #(#col_pushes)*
1206
1207                let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#insert_start);
1208                qb.push(" (");
1209                for (i, col) in cols.iter().enumerate() {
1210                    if i > 0 { qb.push(", "); }
1211                    qb.push("\"");
1212                    qb.push(*col);
1213                    qb.push("\"");
1214                }
1215                qb.push(") VALUES (");
1216                {
1217                    let mut sep = qb.separated(", ");
1218                    #(#val_pushes)*
1219                }
1220                qb.push(") RETURNING *");
1221                qb
1222            }};
1223        }
1224
1225        match client {
1226            DatabaseClient::Postgres(_) => {
1227                let qb = build_insert!(sqlx::Postgres);
1228                client.fetch_one_pg(qb).await
1229            }
1230            DatabaseClient::Sqlite(_) => {
1231                let qb = build_insert!(sqlx::Sqlite);
1232                client.fetch_one_sqlite(qb).await
1233            }
1234        }
1235    }
1236}
1237
1238fn gen_insert_ignore_code(
1239    _model: &Model,
1240    scalar_fields: &[&Field],
1241    table_name: &str,
1242) -> TokenStream {
1243    let required: Vec<&Field> = scalar_fields
1244        .iter()
1245        .copied()
1246        .filter(|f| !f.has_default() && !f.is_updated_at)
1247        .collect();
1248    let optional: Vec<&Field> = scalar_fields
1249        .iter()
1250        .copied()
1251        .filter(|f| f.has_default() && !f.is_updated_at)
1252        .collect();
1253    let updated_at: Vec<&Field> = scalar_fields
1254        .iter()
1255        .copied()
1256        .filter(|f| f.is_updated_at)
1257        .collect();
1258
1259    let mut col_pushes = vec![];
1260    let mut val_pushes = vec![];
1261
1262    for f in &required {
1263        let db_name = &f.db_name;
1264        let field_ident = format_ident!("{}", to_snake_case(&f.name));
1265        col_pushes.push(quote! { cols.push(#db_name); });
1266        val_pushes.push(quote! { sep.push_bind(self.data.#field_ident); });
1267    }
1268    for f in &optional {
1269        let db_name = &f.db_name;
1270        let field_ident = format_ident!("{}", to_snake_case(&f.name));
1271        if is_autoincrement(f) {
1272            col_pushes.push(quote! {
1273                if self.data.#field_ident.is_some() { cols.push(#db_name); }
1274            });
1275            val_pushes.push(quote! {
1276                if let Some(val) = self.data.#field_ident {
1277                    sep.push_bind(val);
1278                }
1279            });
1280        } else {
1281            let default_expr = gen_default_expr(f, &f.field_type);
1282            col_pushes.push(quote! { cols.push(#db_name); });
1283            val_pushes.push(quote! {
1284                let val = self.data.#field_ident.unwrap_or_else(|| #default_expr);
1285                sep.push_bind(val);
1286            });
1287        }
1288    }
1289    for f in &updated_at {
1290        let db_name = &f.db_name;
1291        col_pushes.push(quote! { cols.push(#db_name); });
1292        val_pushes.push(quote! { sep.push_bind(chrono::Utc::now()); });
1293    }
1294
1295    let pg_insert_start = format!(r#"INSERT INTO "{table_name}""#);
1296    let sqlite_insert_start = format!(r#"INSERT OR IGNORE INTO "{table_name}""#);
1297
1298    quote! {
1299        macro_rules! build_insert_ignore {
1300            ($qb_type:ty, $head:expr, $tail:expr) => {{
1301                let mut cols: Vec<&str> = Vec::new();
1302                #(#col_pushes)*
1303
1304                let mut qb = sqlx::QueryBuilder::<$qb_type>::new($head);
1305                qb.push(" (");
1306                for (i, col) in cols.iter().enumerate() {
1307                    if i > 0 { qb.push(", "); }
1308                    qb.push("\"");
1309                    qb.push(*col);
1310                    qb.push("\"");
1311                }
1312                qb.push(") VALUES (");
1313                {
1314                    let mut sep = qb.separated(", ");
1315                    #(#val_pushes)*
1316                }
1317                qb.push(")");
1318                qb.push($tail);
1319                qb.push(" RETURNING *");
1320                qb
1321            }};
1322        }
1323
1324        match client {
1325            DatabaseClient::Postgres(_) => {
1326                let qb = build_insert_ignore!(sqlx::Postgres, #pg_insert_start, " ON CONFLICT DO NOTHING");
1327                client.fetch_optional_pg(qb).await
1328            }
1329            DatabaseClient::Sqlite(_) => {
1330                let qb = build_insert_ignore!(sqlx::Sqlite, #sqlite_insert_start, "");
1331                client.fetch_optional_sqlite(qb).await
1332            }
1333        }
1334    }
1335}
1336
1337/// True if the field is declared with `@default(autoincrement())`.
1338/// Such columns must be omitted from the INSERT when the caller passes `None`,
1339/// otherwise we'd bind a literal 0 and collide on the second insert.
1340fn is_autoincrement(field: &Field) -> bool {
1341    matches!(
1342        field.default,
1343        Some(ferriorm_core::ast::DefaultValue::AutoIncrement)
1344    )
1345}
1346
1347/// Generate a Rust expression for a field's @default value.
1348fn gen_default_expr(field: &Field, field_type: &FieldKind) -> TokenStream {
1349    use ferriorm_core::ast::DefaultValue;
1350
1351    match &field.default {
1352        Some(DefaultValue::Uuid | DefaultValue::Cuid) => {
1353            quote! { uuid::Uuid::new_v4().to_string() }
1354        }
1355        Some(DefaultValue::Now) => quote! { chrono::Utc::now() },
1356        // Autoincrement is handled specially by the INSERT/UPSERT generators
1357        // (see `is_autoincrement`): the column is omitted when the caller passes
1358        // `None`, so this arm is unreachable in practice. It stays as a safe
1359        // fallback only for match exhaustiveness.
1360        Some(DefaultValue::AutoIncrement) => quote! { 0i32 },
1361        Some(DefaultValue::Literal(lit)) => {
1362            use ferriorm_core::ast::LiteralValue;
1363            match lit {
1364                LiteralValue::String(s) => quote! { #s.to_string() },
1365                LiteralValue::Int(i) => {
1366                    // Cast the integer literal to the correct Rust type based on the field's scalar type.
1367                    #[allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
1368                    match field_type {
1369                        FieldKind::Scalar(ScalarType::Float) => {
1370                            let val = *i as f64;
1371                            quote! { #val }
1372                        }
1373                        FieldKind::Scalar(ScalarType::BigInt) => quote! { #i },
1374                        // `@db.BigInt` on an `Int` widens the literal to i64 too.
1375                        FieldKind::Scalar(ScalarType::Int)
1376                            if field.db_type.as_ref().is_some_and(|(ty, _)| ty == "BigInt") =>
1377                        {
1378                            quote! { #i }
1379                        }
1380                        _ => {
1381                            // Default to i32 for Int and other types
1382                            let val = *i as i32;
1383                            quote! { #val }
1384                        }
1385                    }
1386                }
1387                LiteralValue::Float(f) => quote! { #f },
1388                LiteralValue::Bool(b) => quote! { #b },
1389            }
1390        }
1391        Some(DefaultValue::EnumVariant(v)) => {
1392            // Reference the enum variant — insert code runs at model module level
1393            let variant = format_ident!("{}", v);
1394            if let FieldKind::Enum(enum_name) = &field.field_type {
1395                let enum_ident = format_ident!("{}", enum_name);
1396                quote! { super::enums::#enum_ident::#variant }
1397            } else {
1398                quote! { Default::default() }
1399            }
1400        }
1401        None => quote! { Default::default() },
1402    }
1403}
1404
1405// ─── UPDATE code generation ───────────────────────────────────
1406
1407fn gen_update_code(model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
1408    let _model_ident = format_ident!("{}", model.name);
1409
1410    // Updatable fields: non-id, non-updatedAt scalar fields
1411    let updatable: Vec<&Field> = scalar_fields
1412        .iter()
1413        .copied()
1414        .filter(|f| !f.is_id && !f.is_updated_at)
1415        .collect();
1416
1417    let updated_at: Vec<&Field> = scalar_fields
1418        .iter()
1419        .copied()
1420        .filter(|f| f.is_updated_at)
1421        .collect();
1422
1423    let update_start = format!(r#"UPDATE "{table_name}" SET "#);
1424
1425    // Generate SET clause arms
1426    let set_arms: Vec<TokenStream> = updatable
1427        .iter()
1428        .map(|f| {
1429            let field_ident = format_ident!("{}", to_snake_case(&f.name));
1430            let db_name = &f.db_name;
1431            quote! {
1432                if let Some(SetValue::Set(v)) = self.data.#field_ident {
1433                    if !first_set { qb.push(", "); }
1434                    first_set = false;
1435                    qb.push(concat!("\"", #db_name, "\" = "));
1436                    qb.push_bind(v);
1437                }
1438            }
1439        })
1440        .collect();
1441
1442    let updated_at_arms: Vec<TokenStream> = updated_at
1443        .iter()
1444        .map(|f| {
1445            let db_name = &f.db_name;
1446            quote! {
1447                if !first_set { qb.push(", "); }
1448                first_set = false;
1449                qb.push(concat!("\"", #db_name, "\" = "));
1450                qb.push_bind(chrono::Utc::now());
1451            }
1452        })
1453        .collect();
1454
1455    // The build_update macro avoids duplicating the SET clause building logic
1456    // for each database backend.
1457    quote! {
1458        macro_rules! build_update {
1459            ($qb_type:ty) => {{
1460                let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#update_start);
1461                let mut first_set = true;
1462                #(#set_arms)*
1463                #(#updated_at_arms)*
1464
1465                if first_set {
1466                    return Err(FerriormError::Query("No fields to update".into()));
1467                }
1468
1469                qb.push(" WHERE 1=1");
1470                self.r#where.build_where(&mut qb);
1471                qb.push(" RETURNING *");
1472                qb
1473            }};
1474        }
1475
1476        match client {
1477            DatabaseClient::Postgres(_) => {
1478                let qb = build_update!(sqlx::Postgres);
1479                client.fetch_one_pg(qb).await
1480            }
1481            DatabaseClient::Sqlite(_) => {
1482                let qb = build_update!(sqlx::Sqlite);
1483                client.fetch_one_sqlite(qb).await
1484            }
1485        }
1486    }
1487}
1488
1489// ─── UPDATE FIRST (CAS) code generation ──────────────────────
1490
1491fn gen_update_first_code(
1492    _model: &Model,
1493    scalar_fields: &[&Field],
1494    table_name: &str,
1495) -> TokenStream {
1496    let updatable: Vec<&Field> = scalar_fields
1497        .iter()
1498        .copied()
1499        .filter(|f| !f.is_id && !f.is_updated_at)
1500        .collect();
1501
1502    let updated_at: Vec<&Field> = scalar_fields
1503        .iter()
1504        .copied()
1505        .filter(|f| f.is_updated_at)
1506        .collect();
1507
1508    let update_start = format!(r#"UPDATE "{table_name}" SET "#);
1509
1510    let set_arms: Vec<TokenStream> = updatable
1511        .iter()
1512        .map(|f| {
1513            let field_ident = format_ident!("{}", to_snake_case(&f.name));
1514            let db_name = &f.db_name;
1515            quote! {
1516                if let Some(SetValue::Set(v)) = self.data.#field_ident {
1517                    if !first_set { qb.push(", "); }
1518                    first_set = false;
1519                    qb.push(concat!("\"", #db_name, "\" = "));
1520                    qb.push_bind(v);
1521                }
1522            }
1523        })
1524        .collect();
1525
1526    let updated_at_arms: Vec<TokenStream> = updated_at
1527        .iter()
1528        .map(|f| {
1529            let db_name = &f.db_name;
1530            quote! {
1531                if !first_set { qb.push(", "); }
1532                first_set = false;
1533                qb.push(concat!("\"", #db_name, "\" = "));
1534                qb.push_bind(chrono::Utc::now());
1535            }
1536        })
1537        .collect();
1538
1539    quote! {
1540        macro_rules! build_update_first {
1541            ($qb_type:ty) => {{
1542                let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#update_start);
1543                let mut first_set = true;
1544                #(#set_arms)*
1545                #(#updated_at_arms)*
1546
1547                if first_set {
1548                    return Err(FerriormError::Query("No fields to update".into()));
1549                }
1550
1551                qb.push(" WHERE 1=1");
1552                self.r#where.build_where(&mut qb);
1553                qb.push(" RETURNING *");
1554                qb
1555            }};
1556        }
1557
1558        match client {
1559            DatabaseClient::Postgres(_) => {
1560                let qb = build_update_first!(sqlx::Postgres);
1561                client.fetch_optional_pg(qb).await
1562            }
1563            DatabaseClient::Sqlite(_) => {
1564                let qb = build_update_first!(sqlx::Sqlite);
1565                client.fetch_optional_sqlite(qb).await
1566            }
1567        }
1568    }
1569}
1570
1571// ─── UPDATE MANY code generation ──────────────────────────────
1572
1573fn gen_update_many_code(_model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
1574    // Updatable fields: non-id, non-updatedAt scalar fields
1575    let updatable: Vec<&Field> = scalar_fields
1576        .iter()
1577        .copied()
1578        .filter(|f| !f.is_id && !f.is_updated_at)
1579        .collect();
1580
1581    let updated_at: Vec<&Field> = scalar_fields
1582        .iter()
1583        .copied()
1584        .filter(|f| f.is_updated_at)
1585        .collect();
1586
1587    let update_start = format!(r#"UPDATE "{table_name}" SET "#);
1588
1589    // Generate SET clause arms
1590    let set_arms: Vec<TokenStream> = updatable
1591        .iter()
1592        .map(|f| {
1593            let field_ident = format_ident!("{}", to_snake_case(&f.name));
1594            let db_name = &f.db_name;
1595            quote! {
1596                if let Some(SetValue::Set(v)) = self.data.#field_ident {
1597                    if !first_set { qb.push(", "); }
1598                    first_set = false;
1599                    qb.push(concat!("\"", #db_name, "\" = "));
1600                    qb.push_bind(v);
1601                }
1602            }
1603        })
1604        .collect();
1605
1606    let updated_at_arms: Vec<TokenStream> = updated_at
1607        .iter()
1608        .map(|f| {
1609            let db_name = &f.db_name;
1610            quote! {
1611                if !first_set { qb.push(", "); }
1612                first_set = false;
1613                qb.push(concat!("\"", #db_name, "\" = "));
1614                qb.push_bind(chrono::Utc::now());
1615            }
1616        })
1617        .collect();
1618
1619    quote! {
1620        macro_rules! build_update_many {
1621            ($qb_type:ty) => {{
1622                let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#update_start);
1623                let mut first_set = true;
1624                #(#set_arms)*
1625                #(#updated_at_arms)*
1626
1627                if first_set {
1628                    return Ok(0);
1629                }
1630
1631                qb.push(" WHERE 1=1");
1632                self.r#where.build_where(&mut qb);
1633                qb
1634            }};
1635        }
1636
1637        match client {
1638            DatabaseClient::Postgres(_) => {
1639                let qb = build_update_many!(sqlx::Postgres);
1640                client.execute_pg(qb).await
1641            }
1642            DatabaseClient::Sqlite(_) => {
1643                let qb = build_update_many!(sqlx::Sqlite);
1644                client.execute_sqlite(qb).await
1645            }
1646        }
1647    }
1648}
1649
1650// ─── Aggregate Types ──────────────────────────────────────────
1651
/// Classifies an aggregatable scalar field by the aggregate operations the
/// generated query builders expose for it.
///
/// The type is a plain fieldless tag. It derives `Copy`/`Eq` so it can be
/// passed by value and compared directly, and `Debug` for diagnostics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum AggregateKind {
    /// `Int` / `BigInt` / `Float` fields: `avg`, `sum`, `min`, `max`.
    Numeric,
    /// `DateTime` fields: `min` and `max` only.
    DateTime,
}
1659
1660// ─── UPSERT code generation ──────────────────────────────────
1661
1662#[allow(clippy::too_many_lines)]
1663fn gen_upsert_code(_model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
1664    // Required + optional + updatedAt fields for the INSERT part (same as gen_insert_code)
1665    let required: Vec<&Field> = scalar_fields
1666        .iter()
1667        .copied()
1668        .filter(|f| !f.has_default() && !f.is_updated_at)
1669        .collect();
1670    let optional: Vec<&Field> = scalar_fields
1671        .iter()
1672        .copied()
1673        .filter(|f| f.has_default() && !f.is_updated_at)
1674        .collect();
1675    let updated_at: Vec<&Field> = scalar_fields
1676        .iter()
1677        .copied()
1678        .filter(|f| f.is_updated_at)
1679        .collect();
1680
1681    let mut col_pushes = vec![];
1682    let mut val_pushes = vec![];
1683
1684    for f in &required {
1685        let db_name = &f.db_name;
1686        let field_ident = format_ident!("{}", to_snake_case(&f.name));
1687        col_pushes.push(quote! { cols.push(#db_name); });
1688        val_pushes.push(quote! { sep.push_bind(self.create.#field_ident); });
1689    }
1690    for f in &optional {
1691        let db_name = &f.db_name;
1692        let field_ident = format_ident!("{}", to_snake_case(&f.name));
1693        if is_autoincrement(f) {
1694            col_pushes.push(quote! {
1695                if self.create.#field_ident.is_some() { cols.push(#db_name); }
1696            });
1697            val_pushes.push(quote! {
1698                if let Some(val) = self.create.#field_ident {
1699                    sep.push_bind(val);
1700                }
1701            });
1702        } else {
1703            let default_expr = gen_default_expr(f, &f.field_type);
1704            col_pushes.push(quote! { cols.push(#db_name); });
1705            val_pushes.push(quote! {
1706                let val = self.create.#field_ident.unwrap_or_else(|| #default_expr);
1707                sep.push_bind(val);
1708            });
1709        }
1710    }
1711    for f in &updated_at {
1712        let db_name = &f.db_name;
1713        col_pushes.push(quote! { cols.push(#db_name); });
1714        val_pushes.push(quote! { sep.push_bind(chrono::Utc::now()); });
1715    }
1716
1717    // Updatable fields for the DO UPDATE SET part
1718    let updatable: Vec<&Field> = scalar_fields
1719        .iter()
1720        .copied()
1721        .filter(|f| !f.is_id && !f.is_updated_at)
1722        .collect();
1723
1724    let set_arms: Vec<TokenStream> = updatable
1725        .iter()
1726        .map(|f| {
1727            let field_ident = format_ident!("{}", to_snake_case(&f.name));
1728            let db_name = &f.db_name;
1729            quote! {
1730                if let Some(SetValue::Set(v)) = self.update.#field_ident {
1731                    if !first_set { qb.push(", "); }
1732                    first_set = false;
1733                    qb.push(concat!("\"", #db_name, "\" = "));
1734                    qb.push_bind(v);
1735                }
1736            }
1737        })
1738        .collect();
1739
1740    let updated_at_set: Vec<TokenStream> = updated_at
1741        .iter()
1742        .map(|f| {
1743            let db_name = &f.db_name;
1744            quote! {
1745                if !first_set { qb.push(", "); }
1746                first_set = false;
1747                qb.push(concat!("\"", #db_name, "\" = "));
1748                qb.push_bind(chrono::Utc::now());
1749            }
1750        })
1751        .collect();
1752
1753    let insert_start = format!(r#"INSERT INTO "{table_name}""#);
1754
1755    quote! {
1756        let conflict_target = self.r#where.conflict_target();
1757        let first_conflict_col = self.r#where.first_conflict_col();
1758
1759        macro_rules! build_upsert {
1760            ($qb_type:ty) => {{
1761                let mut cols: Vec<&str> = Vec::new();
1762                #(#col_pushes)*
1763
1764                let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#insert_start);
1765                qb.push(" (");
1766                for (i, col) in cols.iter().enumerate() {
1767                    if i > 0 { qb.push(", "); }
1768                    qb.push("\"");
1769                    qb.push(*col);
1770                    qb.push("\"");
1771                }
1772                qb.push(") VALUES (");
1773                {
1774                    let mut sep = qb.separated(", ");
1775                    #(#val_pushes)*
1776                }
1777                qb.push(")");
1778                qb.push(" ON CONFLICT ");
1779                qb.push(conflict_target);
1780                qb.push(" DO UPDATE SET ");
1781
1782                let mut first_set = true;
1783                #(#set_arms)*
1784                #(#updated_at_set)*
1785
1786                if first_set {
1787                    // No update fields specified — use a no-op update on the first
1788                    // conflict-target column so RETURNING * still yields the row.
1789                    qb.push(first_conflict_col);
1790                    qb.push(" = ");
1791                    qb.push(first_conflict_col);
1792                }
1793
1794                qb.push(" RETURNING *");
1795                qb
1796            }};
1797        }
1798
1799        match client {
1800            DatabaseClient::Postgres(_) => {
1801                let qb = build_upsert!(sqlx::Postgres);
1802                client.fetch_one_pg(qb).await
1803            }
1804            DatabaseClient::Sqlite(_) => {
1805                let qb = build_upsert!(sqlx::Sqlite);
1806                client.fetch_one_sqlite(qb).await
1807            }
1808        }
1809    }
1810}
1811
1812#[allow(clippy::too_many_lines)]
1813fn gen_aggregate_types(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
1814    let aggregate_field_name = format_ident!("{}AggregateField", model.name);
1815    let aggregate_result_name = format_ident!("{}AggregateResult", model.name);
1816    let _where_input = format_ident!("{}WhereInput", model.name);
1817    let table_name = &model.db_name;
1818
1819    // Collect aggregatable fields with their kind
1820    let agg_fields: Vec<(&Field, AggregateKind)> = scalar_fields
1821        .iter()
1822        .filter_map(|f| match &f.field_type {
1823            FieldKind::Scalar(ScalarType::Int | ScalarType::BigInt | ScalarType::Float) => {
1824                Some((*f, AggregateKind::Numeric))
1825            }
1826            FieldKind::Scalar(ScalarType::DateTime) => Some((*f, AggregateKind::DateTime)),
1827            _ => None,
1828        })
1829        .collect();
1830
1831    if agg_fields.is_empty() {
1832        return quote! {};
1833    }
1834
1835    // Generate enum variants
1836    let enum_variants: Vec<TokenStream> = agg_fields
1837        .iter()
1838        .map(|(f, _)| {
1839            let variant = format_ident!("{}", to_pascal_case(&f.name));
1840            quote! { #variant }
1841        })
1842        .collect();
1843
1844    // Generate db_name match arms
1845    let db_name_arms: Vec<TokenStream> = agg_fields
1846        .iter()
1847        .map(|(f, _)| {
1848            let variant = format_ident!("{}", to_pascal_case(&f.name));
1849            let db_name = &f.db_name;
1850            quote! { Self::#variant => #db_name }
1851        })
1852        .collect();
1853
1854    // Generate AggregateResult fields
1855    let mut result_fields = Vec::new();
1856    for (f, kind) in &agg_fields {
1857        let snake = to_snake_case(&f.name);
1858        let orig_ty = rust_type_tokens(
1859            &Field {
1860                is_optional: false,
1861                ..(*f).clone()
1862            },
1863            ModuleDepth::TopLevel,
1864        );
1865
1866        match kind {
1867            AggregateKind::Numeric => {
1868                let avg_name = format_ident!("avg_{}", snake);
1869                let sum_name = format_ident!("sum_{}", snake);
1870                let min_name = format_ident!("min_{}", snake);
1871                let max_name = format_ident!("max_{}", snake);
1872                result_fields.push(quote! { #[sqlx(default)] pub #avg_name: Option<f64> });
1873                result_fields.push(quote! { #[sqlx(default)] pub #sum_name: Option<f64> });
1874                result_fields.push(quote! { #[sqlx(default)] pub #min_name: Option<#orig_ty> });
1875                result_fields.push(quote! { #[sqlx(default)] pub #max_name: Option<#orig_ty> });
1876            }
1877            AggregateKind::DateTime => {
1878                let min_name = format_ident!("min_{}", snake);
1879                let max_name = format_ident!("max_{}", snake);
1880                result_fields.push(quote! { #[sqlx(default)] pub #min_name: Option<#orig_ty> });
1881                result_fields.push(quote! { #[sqlx(default)] pub #max_name: Option<#orig_ty> });
1882            }
1883        }
1884    }
1885
1886    // Generate the is_numeric check for avg/sum validation
1887    let numeric_arms: Vec<TokenStream> = agg_fields
1888        .iter()
1889        .filter(|(_, kind)| matches!(kind, AggregateKind::Numeric))
1890        .map(|(f, _)| {
1891            let variant = format_ident!("{}", to_pascal_case(&f.name));
1892            quote! { Self::#variant => true }
1893        })
1894        .collect();
1895
1896    let has_numeric = !numeric_arms.is_empty();
1897    let is_numeric_method = if has_numeric {
1898        quote! {
1899            fn is_numeric(&self) -> bool {
1900                match self {
1901                    #(#numeric_arms,)*
1902                    #[allow(unreachable_patterns)]
1903                    _ => false,
1904                }
1905            }
1906        }
1907    } else {
1908        quote! {
1909            fn is_numeric(&self) -> bool { false }
1910        }
1911    };
1912
1913    // Generate alias match arms for each (prefix, field) combination
1914    let mut alias_arms = Vec::new();
1915    for (f, kind) in &agg_fields {
1916        let variant = format_ident!("{}", to_pascal_case(&f.name));
1917        let snake = to_snake_case(&f.name);
1918        let prefixes = match kind {
1919            AggregateKind::Numeric => vec!["avg", "sum", "min", "max"],
1920            AggregateKind::DateTime => vec!["min", "max"],
1921        };
1922        for prefix in prefixes {
1923            let alias_str = format!("{prefix}_{snake}");
1924            alias_arms.push(quote! { (#prefix, Self::#variant) => #alias_str });
1925        }
1926    }
1927
1928    let agg_select_base = format!(r#"SELECT {{}} FROM "{table_name}" WHERE 1=1"#);
1929
1930    quote! {
1931        #[derive(Debug, Clone, Copy)]
1932        pub enum #aggregate_field_name {
1933            #(#enum_variants),*
1934        }
1935
1936        impl #aggregate_field_name {
1937            pub fn db_name(&self) -> &'static str {
1938                match self {
1939                    #(#db_name_arms,)*
1940                }
1941            }
1942
1943            fn alias(&self, prefix: &'static str) -> &'static str {
1944                match (prefix, self) {
1945                    #(#alias_arms,)*
1946                    _ => unreachable!(),
1947                }
1948            }
1949
1950            #is_numeric_method
1951        }
1952
1953        #[derive(Debug, Clone, sqlx::FromRow, Serialize, Deserialize)]
1954        pub struct #aggregate_result_name {
1955            #(#result_fields,)*
1956        }
1957
1958        pub struct AggregateQuery<'a> {
1959            client: &'a DatabaseClient,
1960            r#where: filter::#_where_input,
1961            ops: Vec<(&'static str, &'static str, &'static str)>,
1962        }
1963
1964        impl<'a> AggregateQuery<'a> {
1965            pub fn avg(mut self, field: #aggregate_field_name) -> Self {
1966                assert!(field.is_numeric(), "avg() is only supported on numeric fields");
1967                let db_name = field.db_name();
1968                let alias = field.alias("avg");
1969                self.ops.push(("AVG", db_name, alias));
1970                self
1971            }
1972
1973            pub fn sum(mut self, field: #aggregate_field_name) -> Self {
1974                assert!(field.is_numeric(), "sum() is only supported on numeric fields");
1975                let db_name = field.db_name();
1976                let alias = field.alias("sum");
1977                self.ops.push(("SUM", db_name, alias));
1978                self
1979            }
1980
1981            pub fn min(mut self, field: #aggregate_field_name) -> Self {
1982                let db_name = field.db_name();
1983                let alias = field.alias("min");
1984                self.ops.push(("MIN", db_name, alias));
1985                self
1986            }
1987
1988            pub fn max(mut self, field: #aggregate_field_name) -> Self {
1989                let db_name = field.db_name();
1990                let alias = field.alias("max");
1991                self.ops.push(("MAX", db_name, alias));
1992                self
1993            }
1994
1995            pub async fn exec(self) -> Result<#aggregate_result_name, FerriormError> {
1996                if self.ops.is_empty() {
1997                    return Err(FerriormError::Query("No aggregate operations specified".into()));
1998                }
1999
2000                let selections: Vec<String> = self.ops.iter()
2001                    .map(|(func, col, alias)| format!(r#"{}("{}") as "{}""#, func, col, alias))
2002                    .collect();
2003                let select_clause = selections.join(", ");
2004                let base_sql = format!(#agg_select_base, select_clause);
2005
2006                match self.client {
2007                    DatabaseClient::Postgres(_) => {
2008                        let mut qb = sqlx::QueryBuilder::<sqlx::Postgres>::new(&base_sql);
2009                        self.r#where.build_where(&mut qb);
2010                        self.client.fetch_one_pg(qb).await
2011                    }
2012                    DatabaseClient::Sqlite(_) => {
2013                        let mut qb = sqlx::QueryBuilder::<sqlx::Sqlite>::new(&base_sql);
2014                        self.r#where.build_where(&mut qb);
2015                        self.client.fetch_one_sqlite(qb).await
2016                    }
2017                }
2018            }
2019        }
2020    }
2021}
2022
2023// ─── GroupBy Types ────────────────────────────────────────────
2024
2025/// Generate the six standard-comparable HAVING arms (`equals`/`not`/`gt`/
2026/// `gte`/`lt`/`lte`) for one aggregate field. `lhs` is the SQL expression on
2027/// the left-hand side of the operator (e.g. `AVG("age")`, `MIN("created_at")`).
2028fn gen_having_comparable_arms(field_ident: &proc_macro2::Ident, lhs: &str) -> TokenStream {
2029    let eq = format!(" AND {lhs} = ");
2030    let ne = format!(" AND {lhs} != ");
2031    let gt = format!(" AND {lhs} > ");
2032    let gte = format!(" AND {lhs} >= ");
2033    let lt = format!(" AND {lhs} < ");
2034    let lte = format!(" AND {lhs} <= ");
2035    quote! {
2036        if let Some(filter) = &self.#field_ident {
2037            if let Some(v) = &filter.equals { qb.push(#eq); qb.push_bind(v.clone()); }
2038            if let Some(v) = &filter.not    { qb.push(#ne); qb.push_bind(v.clone()); }
2039            if let Some(v) = &filter.gt     { qb.push(#gt); qb.push_bind(v.clone()); }
2040            if let Some(v) = &filter.gte    { qb.push(#gte); qb.push_bind(v.clone()); }
2041            if let Some(v) = &filter.lt     { qb.push(#lt); qb.push_bind(v.clone()); }
2042            if let Some(v) = &filter.lte    { qb.push(#lte); qb.push_bind(v.clone()); }
2043        }
2044    }
2045}
2046
2047/// True for fields that can appear in a `GROUP BY` clause: any scalar except
2048/// `Json`/`Bytes`/`Decimal` (which are not orderable / hashable in SQL), plus
2049/// enums. Optional fields are still groupable -- `NULL` becomes its own
2050/// bucket.
2051fn is_groupable(field: &Field) -> bool {
2052    match &field.field_type {
2053        FieldKind::Scalar(
2054            ScalarType::String
2055            | ScalarType::Int
2056            | ScalarType::BigInt
2057            | ScalarType::Float
2058            | ScalarType::Boolean
2059            | ScalarType::DateTime,
2060        )
2061        | FieldKind::Enum(_) => true,
2062        FieldKind::Scalar(ScalarType::Json | ScalarType::Bytes | ScalarType::Decimal)
2063        | FieldKind::Model(_) => false,
2064    }
2065}
2066
2067#[allow(clippy::too_many_lines)]
2068fn gen_groupby_types(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
2069    let groupby_field_name = format_ident!("{}GroupByField", model.name);
2070    let groupby_result_name = format_ident!("{}GroupByResult", model.name);
2071    let having_input_name = format_ident!("{}HavingInput", model.name);
2072    let aggregate_field_name = format_ident!("{}AggregateField", model.name);
2073    let where_input = format_ident!("{}WhereInput", model.name);
2074    let table_name = &model.db_name;
2075
2076    // Reuse the same aggregate-field collection as gen_aggregate_types so the
2077    // result struct columns and HAVING surface stay consistent.
2078    let agg_fields: Vec<(&Field, AggregateKind)> = scalar_fields
2079        .iter()
2080        .filter_map(|f| match &f.field_type {
2081            FieldKind::Scalar(ScalarType::Int | ScalarType::BigInt | ScalarType::Float) => {
2082                Some((*f, AggregateKind::Numeric))
2083            }
2084            FieldKind::Scalar(ScalarType::DateTime) => Some((*f, AggregateKind::DateTime)),
2085            _ => None,
2086        })
2087        .collect();
2088
2089    let group_fields: Vec<&Field> = scalar_fields
2090        .iter()
2091        .filter(|f| is_groupable(f))
2092        .copied()
2093        .collect();
2094
2095    if group_fields.is_empty() {
2096        return quote! {};
2097    }
2098
2099    // ── <Model>GroupByField enum ──────────────────────────────
2100    let groupby_variants: Vec<TokenStream> = group_fields
2101        .iter()
2102        .map(|f| {
2103            let variant = format_ident!("{}", to_pascal_case(&f.name));
2104            quote! { #variant }
2105        })
2106        .collect();
2107
2108    let groupby_db_arms: Vec<TokenStream> = group_fields
2109        .iter()
2110        .map(|f| {
2111            let variant = format_ident!("{}", to_pascal_case(&f.name));
2112            let db_name = &f.db_name;
2113            quote! { Self::#variant => #db_name }
2114        })
2115        .collect();
2116
2117    let groupby_alias_arms: Vec<TokenStream> = group_fields
2118        .iter()
2119        .map(|f| {
2120            let variant = format_ident!("{}", to_pascal_case(&f.name));
2121            let alias = to_snake_case(&f.name);
2122            quote! { Self::#variant => #alias }
2123        })
2124        .collect();
2125
2126    // ── <Model>GroupByResult fields ───────────────────────────
2127    // One Option<T> per groupable field (only filled when that field is in
2128    // the active group key set), then count, then the same avg/sum/min/max
2129    // columns that gen_aggregate_types emits.
2130    let mut result_fields: Vec<TokenStream> = Vec::new();
2131    for f in &group_fields {
2132        let snake = to_snake_case(&f.name);
2133        let name = format_ident!("{}", snake);
2134        // Always wrap in Option so the same struct serves every group_by call.
2135        let base_ty = rust_type_tokens(
2136            &Field {
2137                is_optional: false,
2138                ..(*f).clone()
2139            },
2140            ModuleDepth::TopLevel,
2141        );
2142        result_fields.push(quote! { #[sqlx(default)] pub #name: Option<#base_ty> });
2143    }
2144    result_fields.push(quote! { #[sqlx(default)] pub count: Option<i64> });
2145    for (f, kind) in &agg_fields {
2146        let snake = to_snake_case(&f.name);
2147        let orig_ty = rust_type_tokens(
2148            &Field {
2149                is_optional: false,
2150                ..(*f).clone()
2151            },
2152            ModuleDepth::TopLevel,
2153        );
2154        match kind {
2155            AggregateKind::Numeric => {
2156                let avg_name = format_ident!("avg_{}", snake);
2157                let sum_name = format_ident!("sum_{}", snake);
2158                let min_name = format_ident!("min_{}", snake);
2159                let max_name = format_ident!("max_{}", snake);
2160                result_fields.push(quote! { #[sqlx(default)] pub #avg_name: Option<f64> });
2161                result_fields.push(quote! { #[sqlx(default)] pub #sum_name: Option<f64> });
2162                result_fields.push(quote! { #[sqlx(default)] pub #min_name: Option<#orig_ty> });
2163                result_fields.push(quote! { #[sqlx(default)] pub #max_name: Option<#orig_ty> });
2164            }
2165            AggregateKind::DateTime => {
2166                let min_name = format_ident!("min_{}", snake);
2167                let max_name = format_ident!("max_{}", snake);
2168                result_fields.push(quote! { #[sqlx(default)] pub #min_name: Option<#orig_ty> });
2169                result_fields.push(quote! { #[sqlx(default)] pub #max_name: Option<#orig_ty> });
2170            }
2171        }
2172    }
2173
2174    // ── <Model>HavingInput fields ─────────────────────────────
2175    // Filtering on aggregate expressions: COUNT(*), AVG/SUM/MIN/MAX of each
2176    // aggregatable column. RHS reuses the same scalar filter types as
2177    // WhereInput.
2178    let mut having_fields: Vec<TokenStream> = Vec::new();
2179    // COUNT(*) returns BIGINT in both Postgres and SQLite -> BigIntFilter.
2180    having_fields.push(quote! { pub count: Option<ferriorm_runtime::filter::BigIntFilter> });
2181    for (f, kind) in &agg_fields {
2182        let snake = to_snake_case(&f.name);
2183        let avg_name = format_ident!("avg_{}", snake);
2184        let sum_name = format_ident!("sum_{}", snake);
2185        let min_name = format_ident!("min_{}", snake);
2186        let max_name = format_ident!("max_{}", snake);
2187        let column_filter = filter_type_tokens(
2188            &Field {
2189                is_optional: false,
2190                ..(*f).clone()
2191            },
2192            ModuleDepth::TopLevel,
2193        )
2194        .unwrap_or_else(|| quote! { ferriorm_runtime::filter::BigIntFilter });
2195        match kind {
2196            AggregateKind::Numeric => {
2197                having_fields
2198                    .push(quote! { pub #avg_name: Option<ferriorm_runtime::filter::FloatFilter> });
2199                having_fields
2200                    .push(quote! { pub #sum_name: Option<ferriorm_runtime::filter::FloatFilter> });
2201                having_fields.push(quote! { pub #min_name: Option<#column_filter> });
2202                having_fields.push(quote! { pub #max_name: Option<#column_filter> });
2203            }
2204            AggregateKind::DateTime => {
2205                having_fields.push(quote! { pub #min_name: Option<#column_filter> });
2206                having_fields.push(quote! { pub #max_name: Option<#column_filter> });
2207            }
2208        }
2209    }
2210
2211    // ── build_having arms ─────────────────────────────────────
2212    // Mirrors gen_where_arms but the LHS is the aggregate expression
2213    // (`AVG("col")`, `COUNT(*)`, ...) instead of a bare column reference.
2214    // Aggregate results are never NULL semantically except for empty inputs,
2215    // so we don't need IS NULL handling here.
2216    let mut having_arms: Vec<TokenStream> = Vec::new();
2217    // count: BigIntFilter on COUNT(*)
2218    having_arms.push(quote! {
2219        if let Some(filter) = &self.count {
2220            if let Some(v) = &filter.equals { qb.push(" AND COUNT(*) = "); qb.push_bind(*v); }
2221            if let Some(v) = &filter.not    { qb.push(" AND COUNT(*) != "); qb.push_bind(*v); }
2222            if let Some(v) = &filter.gt     { qb.push(" AND COUNT(*) > "); qb.push_bind(*v); }
2223            if let Some(v) = &filter.gte    { qb.push(" AND COUNT(*) >= "); qb.push_bind(*v); }
2224            if let Some(v) = &filter.lt     { qb.push(" AND COUNT(*) < "); qb.push_bind(*v); }
2225            if let Some(v) = &filter.lte    { qb.push(" AND COUNT(*) <= "); qb.push_bind(*v); }
2226        }
2227    });
2228
2229    for (f, kind) in &agg_fields {
2230        let snake = to_snake_case(&f.name);
2231        let db_name = &f.db_name;
2232        let avg_ident = format_ident!("avg_{}", snake);
2233        let sum_ident = format_ident!("sum_{}", snake);
2234        let min_ident = format_ident!("min_{}", snake);
2235        let max_ident = format_ident!("max_{}", snake);
2236        match kind {
2237            AggregateKind::Numeric => {
2238                let avg_lhs = format!(r#"AVG("{db_name}")"#);
2239                let sum_lhs = format!(r#"SUM("{db_name}")"#);
2240                let min_lhs = format!(r#"MIN("{db_name}")"#);
2241                let max_lhs = format!(r#"MAX("{db_name}")"#);
2242                having_arms.push(gen_having_comparable_arms(&avg_ident, &avg_lhs));
2243                having_arms.push(gen_having_comparable_arms(&sum_ident, &sum_lhs));
2244                having_arms.push(gen_having_comparable_arms(&min_ident, &min_lhs));
2245                having_arms.push(gen_having_comparable_arms(&max_ident, &max_lhs));
2246            }
2247            AggregateKind::DateTime => {
2248                let min_lhs = format!(r#"MIN("{db_name}")"#);
2249                let max_lhs = format!(r#"MAX("{db_name}")"#);
2250                having_arms.push(gen_having_comparable_arms(&min_ident, &min_lhs));
2251                having_arms.push(gen_having_comparable_arms(&max_ident, &max_lhs));
2252            }
2253        }
2254    }
2255
2256    // build_having binds the RHS of `AVG(col) op ?` (always f64) and
2257    // `COUNT(*) op ?` (always i64) regardless of which scalar types appear
2258    // in the model. Reuse collect_db_bounds for the column-type bounds
2259    // (needed by min/max filters), then top up with f64.
2260    let mut db_bounds = collect_db_bounds(scalar_fields);
2261    if !scalar_fields
2262        .iter()
2263        .any(|f| matches!(&f.field_type, FieldKind::Scalar(ScalarType::Float)))
2264    {
2265        db_bounds.push(quote! { f64: sqlx::Type<DB> + for<'e> sqlx::Encode<'e, DB> });
2266    }
2267
2268    // ── is_numeric reuse: AggregateField is the canonical enum, but if no
2269    //    aggregatable fields exist we still want a typed group_by query --
2270    //    just without the agg ops. In that case AggregateField may not have
2271    //    been emitted at all, so skip avg/sum/min/max methods.
2272    let has_agg_fields = !agg_fields.is_empty();
2273
2274    let agg_methods = if has_agg_fields {
2275        quote! {
2276            pub fn count(mut self) -> Self {
2277                self.count = true;
2278                self
2279            }
2280
2281            pub fn avg(mut self, field: #aggregate_field_name) -> Self {
2282                assert!(field.is_numeric(), "avg() is only supported on numeric fields");
2283                let db_name = field.db_name();
2284                let alias = field.alias("avg");
2285                self.agg_ops.push(("AVG", db_name, alias));
2286                self
2287            }
2288
2289            pub fn sum(mut self, field: #aggregate_field_name) -> Self {
2290                assert!(field.is_numeric(), "sum() is only supported on numeric fields");
2291                let db_name = field.db_name();
2292                let alias = field.alias("sum");
2293                self.agg_ops.push(("SUM", db_name, alias));
2294                self
2295            }
2296
2297            pub fn min(mut self, field: #aggregate_field_name) -> Self {
2298                let db_name = field.db_name();
2299                let alias = field.alias("min");
2300                self.agg_ops.push(("MIN", db_name, alias));
2301                self
2302            }
2303
2304            pub fn max(mut self, field: #aggregate_field_name) -> Self {
2305                let db_name = field.db_name();
2306                let alias = field.alias("max");
2307                self.agg_ops.push(("MAX", db_name, alias));
2308                self
2309            }
2310        }
2311    } else {
2312        quote! {
2313            pub fn count(mut self) -> Self {
2314                self.count = true;
2315                self
2316            }
2317        }
2318    };
2319
2320    quote! {
2321        #[derive(Debug, Clone, Copy)]
2322        pub enum #groupby_field_name {
2323            #(#groupby_variants),*
2324        }
2325
2326        impl #groupby_field_name {
2327            pub fn db_name(&self) -> &'static str {
2328                match self {
2329                    #(#groupby_db_arms,)*
2330                }
2331            }
2332
2333            fn alias(&self) -> &'static str {
2334                match self {
2335                    #(#groupby_alias_arms,)*
2336                }
2337            }
2338        }
2339
2340        #[derive(Debug, Clone, sqlx::FromRow, Serialize, Deserialize)]
2341        pub struct #groupby_result_name {
2342            #(#result_fields,)*
2343        }
2344
2345        #[derive(Debug, Clone, Default)]
2346        pub struct #having_input_name {
2347            #(#having_fields,)*
2348            pub and: Option<Vec<#having_input_name>>,
2349            pub or: Option<Vec<#having_input_name>>,
2350            pub not: Option<Box<#having_input_name>>,
2351        }
2352
2353        impl #having_input_name {
2354            pub(crate) fn build_having<'args, DB: sqlx::Database>(
2355                &self,
2356                qb: &mut sqlx::QueryBuilder<'args, DB>,
2357            )
2358            where
2359                #(#db_bounds,)*
2360            {
2361                #(#having_arms)*
2362
2363                if let Some(conditions) = &self.and {
2364                    for c in conditions {
2365                        c.build_having(qb);
2366                    }
2367                }
2368                if let Some(conditions) = &self.or {
2369                    if !conditions.is_empty() {
2370                        qb.push(" AND (");
2371                        for (i, c) in conditions.iter().enumerate() {
2372                            if i > 0 { qb.push(" OR "); }
2373                            qb.push("(1=1");
2374                            c.build_having(qb);
2375                            qb.push(")");
2376                        }
2377                        qb.push(")");
2378                    }
2379                }
2380                if let Some(c) = &self.not {
2381                    qb.push(" AND NOT (1=1");
2382                    c.build_having(qb);
2383                    qb.push(")");
2384                }
2385            }
2386        }
2387
2388        pub struct GroupByQuery<'a> {
2389            client: &'a DatabaseClient,
2390            r#where: filter::#where_input,
2391            group_keys: Vec<#groupby_field_name>,
2392            agg_ops: Vec<(&'static str, &'static str, &'static str)>,
2393            count: bool,
2394            having: Option<#having_input_name>,
2395        }
2396
2397        impl<'a> GroupByQuery<'a> {
2398            pub fn r#where(mut self, r#where: filter::#where_input) -> Self {
2399                self.r#where = r#where;
2400                self
2401            }
2402
2403            #agg_methods
2404
2405            pub fn having(mut self, having: #having_input_name) -> Self {
2406                self.having = Some(having);
2407                self
2408            }
2409
2410            pub async fn exec(self) -> Result<Vec<#groupby_result_name>, FerriormError> {
2411                if self.group_keys.is_empty() {
2412                    return Err(FerriormError::Query(
2413                        "group_by() requires at least one group key".into(),
2414                    ));
2415                }
2416
2417                let mut selections: Vec<String> = self.group_keys
2418                    .iter()
2419                    .map(|k| format!(r#""{}" as "{}""#, k.db_name(), k.alias()))
2420                    .collect();
2421                if self.count {
2422                    selections.push(r#"COUNT(*) as "count""#.to_string());
2423                }
2424                for (func, col, alias) in &self.agg_ops {
2425                    selections.push(format!(r#"{}("{}") as "{}""#, func, col, alias));
2426                }
2427
2428                let group_by_clause: Vec<String> = self.group_keys
2429                    .iter()
2430                    .map(|k| format!(r#""{}""#, k.db_name()))
2431                    .collect();
2432
2433                let base_sql = format!(
2434                    r#"SELECT {} FROM "{}" WHERE 1=1"#,
2435                    selections.join(", "),
2436                    #table_name,
2437                );
2438
2439                match self.client {
2440                    DatabaseClient::Postgres(_) => {
2441                        let mut qb = sqlx::QueryBuilder::<sqlx::Postgres>::new(&base_sql);
2442                        self.r#where.build_where(&mut qb);
2443                        qb.push(format!(" GROUP BY {}", group_by_clause.join(", ")));
2444                        if let Some(h) = &self.having {
2445                            qb.push(" HAVING 1=1");
2446                            h.build_having(&mut qb);
2447                        }
2448                        self.client.fetch_all_pg(qb).await
2449                    }
2450                    DatabaseClient::Sqlite(_) => {
2451                        let mut qb = sqlx::QueryBuilder::<sqlx::Sqlite>::new(&base_sql);
2452                        self.r#where.build_where(&mut qb);
2453                        qb.push(format!(" GROUP BY {}", group_by_clause.join(", ")));
2454                        if let Some(h) = &self.having {
2455                            qb.push(" HAVING 1=1");
2456                            h.build_having(&mut qb);
2457                        }
2458                        self.client.fetch_all_sqlite(qb).await
2459                    }
2460                }
2461            }
2462        }
2463    }
2464}
2465
2466// ─── Select Types ─────────────────────────────────────────────
2467
2468#[allow(clippy::too_many_lines)]
2469fn gen_select_types(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
2470    let select_name = format_ident!("{}Select", model.name);
2471    let partial_name = format_ident!("{}Partial", model.name);
2472    let _where_input = format_ident!("{}WhereInput", model.name);
2473    let _where_unique = format_ident!("{}WhereUniqueInput", model.name);
2474    let order_by_name = format_ident!("{}OrderByInput", model.name);
2475    let table_name = &model.db_name;
2476
2477    // Select struct fields: all bool, default false
2478    let select_fields: Vec<TokenStream> = scalar_fields
2479        .iter()
2480        .map(|f| {
2481            let name = format_ident!("{}", to_snake_case(&f.name));
2482            quote! { pub #name: bool }
2483        })
2484        .collect();
2485
2486    // Partial struct fields: all Option<T> with #[sqlx(default)]
2487    // For already-optional fields, don't double-wrap in Option
2488    let partial_fields: Vec<TokenStream> = scalar_fields
2489        .iter()
2490        .map(|f| {
2491            let name = format_ident!("{}", to_snake_case(&f.name));
2492            let db_name = &f.db_name;
2493            // Get the base type (non-optional version)
2494            let base_ty = rust_type_tokens(
2495                &Field {
2496                    is_optional: false,
2497                    ..(*f).clone()
2498                },
2499                ModuleDepth::TopLevel,
2500            );
2501            let rename = if db_name == &to_snake_case(&f.name) {
2502                quote! {}
2503            } else {
2504                quote! { #[sqlx(rename = #db_name)] }
2505            };
2506            // Always wrap in Option<T>, regardless of whether field was originally optional
2507            quote! { #[sqlx(default)] #rename pub #name: Option<#base_ty> }
2508        })
2509        .collect();
2510
2511    // build_select_columns: maps Select bools to column names
2512    let select_col_arms: Vec<TokenStream> = scalar_fields
2513        .iter()
2514        .map(|f| {
2515            let name = format_ident!("{}", to_snake_case(&f.name));
2516            let db_name = &f.db_name;
2517            let col_expr = format!(r#""{db_name}""#);
2518            quote! {
2519                if select.#name { cols.push(#col_expr); }
2520            }
2521        })
2522        .collect();
2523
2524    let select_sql_prefix = format!(r#"SELECT {{}} FROM "{table_name}" WHERE 1=1"#);
2525
2526    quote! {
2527        #[derive(Debug, Clone, Default)]
2528        pub struct #select_name {
2529            #(#select_fields,)*
2530        }
2531
2532        #[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
2533        #[sqlx(rename_all = "snake_case")]
2534        pub struct #partial_name {
2535            #(#partial_fields,)*
2536        }
2537
2538        fn build_select_columns(select: &#select_name) -> String {
2539            let mut cols = Vec::new();
2540            #(#select_col_arms)*
2541            if cols.is_empty() {
2542                "*".to_string()
2543            } else {
2544                cols.join(", ")
2545            }
2546        }
2547
2548        // ── FindManySelectQuery ──────────────────────────────────
2549
2550        pub struct FindManySelectQuery<'a> {
2551            client: &'a DatabaseClient,
2552            r#where: filter::#_where_input,
2553            order_by: Vec<order::#order_by_name>,
2554            skip: Option<i64>,
2555            take: Option<i64>,
2556            select: #select_name,
2557        }
2558
2559        impl<'a> FindManySelectQuery<'a> {
2560            pub fn order_by(mut self, order: order::#order_by_name) -> Self {
2561                self.order_by.push(order);
2562                self
2563            }
2564
2565            pub fn skip(mut self, n: i64) -> Self {
2566                self.skip = Some(n);
2567                self
2568            }
2569
2570            pub fn take(mut self, n: i64) -> Self {
2571                self.take = Some(n);
2572                self
2573            }
2574
2575            pub async fn exec(self) -> Result<Vec<#partial_name>, FerriormError> {
2576                let cols = build_select_columns(&self.select);
2577                let base_sql = format!(#select_sql_prefix, cols);
2578
2579                match self.client {
2580                    DatabaseClient::Postgres(_) => {
2581                        let qb = build_select_query::<sqlx::Postgres>(
2582                            &base_sql, &self.r#where, &self.order_by, self.take, self.skip,
2583                        );
2584                        self.client.fetch_all_pg(qb).await
2585                    }
2586                    DatabaseClient::Sqlite(_) => {
2587                        let qb = build_select_query::<sqlx::Sqlite>(
2588                            &base_sql, &self.r#where, &self.order_by, self.take, self.skip,
2589                        );
2590                        self.client.fetch_all_sqlite(qb).await
2591                    }
2592                }
2593            }
2594        }
2595
2596        // ── FindUniqueSelectQuery ────────────────────────────────
2597
2598        pub struct FindUniqueSelectQuery<'a> {
2599            client: &'a DatabaseClient,
2600            r#where: filter::#_where_unique,
2601            select: #select_name,
2602        }
2603
2604        impl<'a> FindUniqueSelectQuery<'a> {
2605            pub async fn exec(self) -> Result<Option<#partial_name>, FerriormError> {
2606                let cols = build_select_columns(&self.select);
2607                let base_sql = format!(#select_sql_prefix, cols);
2608
2609                match self.client {
2610                    DatabaseClient::Postgres(_) => {
2611                        let qb = build_unique_select_query::<sqlx::Postgres>(
2612                            &base_sql, &self.r#where,
2613                        );
2614                        self.client.fetch_optional_pg(qb).await
2615                    }
2616                    DatabaseClient::Sqlite(_) => {
2617                        let qb = build_unique_select_query::<sqlx::Sqlite>(
2618                            &base_sql, &self.r#where,
2619                        );
2620                        self.client.fetch_optional_sqlite(qb).await
2621                    }
2622                }
2623            }
2624        }
2625
2626        // ── FindFirstSelectQuery ─────────────────────────────────
2627
2628        pub struct FindFirstSelectQuery<'a> {
2629            client: &'a DatabaseClient,
2630            r#where: filter::#_where_input,
2631            order_by: Vec<order::#order_by_name>,
2632            select: #select_name,
2633        }
2634
2635        impl<'a> FindFirstSelectQuery<'a> {
2636            pub fn order_by(mut self, order: order::#order_by_name) -> Self {
2637                self.order_by.push(order);
2638                self
2639            }
2640
2641            pub async fn exec(self) -> Result<Option<#partial_name>, FerriormError> {
2642                let cols = build_select_columns(&self.select);
2643                let base_sql = format!(#select_sql_prefix, cols);
2644
2645                match self.client {
2646                    DatabaseClient::Postgres(_) => {
2647                        let qb = build_select_query::<sqlx::Postgres>(
2648                            &base_sql, &self.r#where, &self.order_by, Some(1), None,
2649                        );
2650                        self.client.fetch_optional_pg(qb).await
2651                    }
2652                    DatabaseClient::Sqlite(_) => {
2653                        let qb = build_select_query::<sqlx::Sqlite>(
2654                            &base_sql, &self.r#where, &self.order_by, Some(1), None,
2655                        );
2656                        self.client.fetch_optional_sqlite(qb).await
2657                    }
2658                }
2659            }
2660        }
2661    }
2662}