// ferriorm_codegen/model.rs

//! Generates per-model Rust modules (struct, filters, data inputs, ordering, CRUD).
//!
//! For each model in the schema, this module produces:
//!
//! - A **data struct** (e.g., `User`) with `sqlx::FromRow` and serde derives.
//! - A **filter submodule** with `WhereInput` and `WhereUniqueInput` types.
//! - A **data submodule** with `CreateInput` and `UpdateInput` types.
//! - An **order submodule** with `OrderByInput`.
//! - An **`Actions` struct** exposing `create`, `find_unique`, `find_first`,
//!   `find_many`, `count`, `update`, `delete`, `upsert`, and batch operations
//!   (plus `aggregate` when the model has numeric or date-time fields).
//! - **Query builder structs** that chain filters, ordering, pagination, and
//!   include clauses before calling `.exec()`.
//! - **Aggregate and select types** for aggregate queries and partial selection.

use std::collections::HashSet;

use ferriorm_core::schema::{Field, FieldKind, Model};
use ferriorm_core::types::ScalarType;
use ferriorm_core::utils::{to_pascal_case, to_snake_case};
use proc_macro2::TokenStream;
use quote::{format_ident, quote};

use crate::rust_type::{ModuleDepth, filter_type_tokens, rust_type_tokens};

22/// Generate the complete module for a single model.
23pub fn generate_model_module(model: &Model) -> TokenStream {
24    let scalar_fields: Vec<&Field> = model.fields.iter().filter(|f| f.is_scalar()).collect();
25
26    let data_struct = gen_data_struct(model, &scalar_fields);
27    let filter_module = gen_filter_module(model, &scalar_fields);
28    let data_module = gen_data_module(model, &scalar_fields);
29    let order_module = gen_order_module(model, &scalar_fields);
30    let actions_struct = gen_actions(model, &scalar_fields);
31    let query_builders = gen_query_builders(model, &scalar_fields);
32    let aggregate_types = gen_aggregate_types(model, &scalar_fields);
33    let select_types = gen_select_types(model, &scalar_fields);
34
35    quote! {
36        #![allow(unused_imports, dead_code, unused_variables, clippy::all, clippy::pedantic, clippy::nursery)]
37
38        use serde::{Deserialize, Serialize};
39        use ferriorm_runtime::prelude::*;
40        use ferriorm_runtime::prelude::sqlx;
41        use ferriorm_runtime::prelude::chrono;
42        use ferriorm_runtime::prelude::uuid;
43
44        #data_struct
45        #filter_module
46        #data_module
47        #order_module
48        #actions_struct
49        #query_builders
50        #aggregate_types
51        #select_types
52    }
53}

// ─── Data Struct ──────────────────────────────────────────────

57fn gen_data_struct(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
58    let struct_name = format_ident!("{}", model.name);
59    let table_name = &model.db_name;
60
61    let fields: Vec<TokenStream> = scalar_fields
62        .iter()
63        .map(|f| {
64            let name = format_ident!("{}", to_snake_case(&f.name));
65            let ty = rust_type_tokens(f, ModuleDepth::TopLevel);
66            let db_name = &f.db_name;
67            if db_name != &to_snake_case(&f.name) {
68                quote! { #[sqlx(rename = #db_name)] pub #name: #ty }
69            } else {
70                quote! { pub #name: #ty }
71            }
72        })
73        .collect();
74
75    quote! {
76        #[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
77        #[sqlx(rename_all = "snake_case")]
78        pub struct #struct_name {
79            #(#fields),*
80        }
81
82        impl #struct_name {
83            pub const TABLE_NAME: &'static str = #table_name;
84        }
85    }
86}

// ─── Filter Module ────────────────────────────────────────────

90fn gen_filter_module(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
91    let where_input = format_ident!("{}WhereInput", model.name);
92    let where_unique = format_ident!("{}WhereUniqueInput", model.name);
93
94    let where_fields: Vec<TokenStream> = scalar_fields
95        .iter()
96        .filter_map(|f| {
97            let filter_ty = filter_type_tokens(f, ModuleDepth::Nested)?;
98            let name = format_ident!("{}", to_snake_case(&f.name));
99            Some(quote! { pub #name: Option<#filter_ty> })
100        })
101        .collect();
102
103    let unique_variants: Vec<TokenStream> = scalar_fields
104        .iter()
105        .filter(|f| f.is_id || f.is_unique)
106        .map(|f| {
107            let variant = format_ident!("{}", to_pascal_case(&f.name));
108            let ty = rust_type_tokens(f, ModuleDepth::Nested);
109            quote! { #variant(#ty) }
110        })
111        .collect();
112
113    // Generate build_where for WhereInput
114    let db_bounds = collect_db_bounds(scalar_fields);
115    let where_arms = gen_where_arms(scalar_fields);
116    let unique_arms = gen_unique_where_arms(scalar_fields);
117
118    quote! {
119        pub mod filter {
120            use ferriorm_runtime::prelude::*;
121
122            #[derive(Debug, Clone, Default)]
123            pub struct #where_input {
124                #(#where_fields,)*
125                pub and: Option<Vec<#where_input>>,
126                pub or: Option<Vec<#where_input>>,
127                pub not: Option<Box<#where_input>>,
128            }
129
130            #[derive(Debug, Clone)]
131            pub enum #where_unique {
132                #(#unique_variants),*
133            }
134
135            impl #where_input {
136                pub(crate) fn build_where<'args, DB: sqlx::Database>(
137                    &self,
138                    qb: &mut sqlx::QueryBuilder<'args, DB>,
139                )
140                where
141                    #(#db_bounds,)*
142                {
143                    #(#where_arms)*
144
145                    if let Some(conditions) = &self.and {
146                        for c in conditions {
147                            c.build_where(qb);
148                        }
149                    }
150                    if let Some(conditions) = &self.or {
151                        if !conditions.is_empty() {
152                            qb.push(" AND (");
153                            for (i, c) in conditions.iter().enumerate() {
154                                if i > 0 { qb.push(" OR "); }
155                                qb.push("(1=1");
156                                c.build_where(qb);
157                                qb.push(")");
158                            }
159                            qb.push(")");
160                        }
161                    }
162                    if let Some(c) = &self.not {
163                        qb.push(" AND NOT (1=1");
164                        c.build_where(qb);
165                        qb.push(")");
166                    }
167                }
168            }
169
170            impl #where_unique {
171                pub(crate) fn build_where<'args, DB: sqlx::Database>(
172                    &self,
173                    qb: &mut sqlx::QueryBuilder<'args, DB>,
174                )
175                where
176                    #(#db_bounds,)*
177                {
178                    match self {
179                        #(#unique_arms)*
180                    }
181                }
182            }
183        }
184    }
185}
186
187/// Collect the sqlx type bounds needed for all scalar types used by the model.
188fn collect_db_bounds(scalar_fields: &[&Field]) -> Vec<TokenStream> {
189    let mut seen = std::collections::HashSet::new();
190    let mut bounds = Vec::new();
191
192    // Always need i64 for LIMIT/OFFSET
193    seen.insert("i64");
194    bounds.push(quote! { i64: sqlx::Type<DB> + for<'e> sqlx::Encode<'e, DB> });
195
196    for f in scalar_fields {
197        match &f.field_type {
198            FieldKind::Scalar(scalar) => {
199                let key = scalar.rust_type();
200                if seen.insert(key)
201                    && let Some(ty) = scalar_bound_tokens(scalar)
202                {
203                    bounds.push(quote! { #ty: sqlx::Type<DB> + for<'e> sqlx::Encode<'e, DB> });
204                    // Also add Option<T> bound for nullable field support
205                    bounds.push(
206                        quote! { Option<#ty>: sqlx::Type<DB> + for<'e> sqlx::Encode<'e, DB> },
207                    );
208                }
209            }
210            FieldKind::Enum(_) => {}
211            _ => {}
212        }
213    }
214
215    bounds
216}
217
218fn scalar_bound_tokens(scalar: &ScalarType) -> Option<TokenStream> {
219    match scalar {
220        ScalarType::String => Some(quote! { String }),
221        ScalarType::Int => Some(quote! { i32 }),
222        ScalarType::BigInt => Some(quote! { i64 }),
223        ScalarType::Float => Some(quote! { f64 }),
224        ScalarType::Boolean => Some(quote! { bool }),
225        ScalarType::DateTime => Some(quote! { chrono::DateTime<chrono::Utc> }),
226        ScalarType::Bytes => Some(quote! { Vec<u8> }),
227        ScalarType::Json | ScalarType::Decimal => None,
228    }
229}
230
231/// Generate where-clause arms for each filterable scalar field.
232fn gen_where_arms(scalar_fields: &[&Field]) -> Vec<TokenStream> {
233    scalar_fields
234        .iter()
235        .filter_map(|f| {
236            // Only generate filter arms for scalar types (skip enums for now)
237            if !matches!(&f.field_type, FieldKind::Scalar(_)) {
238                return None;
239            }
240            let field_ident = format_ident!("{}", to_snake_case(&f.name));
241            let db_name = &f.db_name;
242            let is_string = matches!(&f.field_type, FieldKind::Scalar(ScalarType::String));
243            let is_comparable = matches!(
244                &f.field_type,
245                FieldKind::Scalar(
246                    ScalarType::Int | ScalarType::BigInt | ScalarType::Float | ScalarType::DateTime
247                )
248            );
249
250            let mut arms = vec![];
251
252            arms.push(quote! {
253                if let Some(v) = &filter.equals {
254                    qb.push(concat!(" AND \"", #db_name, "\" = "));
255                    qb.push_bind(v.clone());
256                }
257                if let Some(v) = &filter.not {
258                    qb.push(concat!(" AND \"", #db_name, "\" != "));
259                    qb.push_bind(v.clone());
260                }
261            });
262
263            if is_string {
264                arms.push(quote! {
265                    if let Some(v) = &filter.contains {
266                        qb.push(concat!(" AND \"", #db_name, "\" LIKE "));
267                        qb.push_bind(format!("%{}%", v));
268                    }
269                    if let Some(v) = &filter.starts_with {
270                        qb.push(concat!(" AND \"", #db_name, "\" LIKE "));
271                        qb.push_bind(format!("{}%", v));
272                    }
273                    if let Some(v) = &filter.ends_with {
274                        qb.push(concat!(" AND \"", #db_name, "\" LIKE "));
275                        qb.push_bind(format!("%{}", v));
276                    }
277                });
278            }
279
280            if is_comparable {
281                arms.push(quote! {
282                    if let Some(v) = &filter.gt {
283                        qb.push(concat!(" AND \"", #db_name, "\" > "));
284                        qb.push_bind(v.clone());
285                    }
286                    if let Some(v) = &filter.gte {
287                        qb.push(concat!(" AND \"", #db_name, "\" >= "));
288                        qb.push_bind(v.clone());
289                    }
290                    if let Some(v) = &filter.lt {
291                        qb.push(concat!(" AND \"", #db_name, "\" < "));
292                        qb.push_bind(v.clone());
293                    }
294                    if let Some(v) = &filter.lte {
295                        qb.push(concat!(" AND \"", #db_name, "\" <= "));
296                        qb.push_bind(v.clone());
297                    }
298                });
299            }
300
301            Some(quote! {
302                if let Some(filter) = &self.#field_ident {
303                    #(#arms)*
304                }
305            })
306        })
307        .collect()
308}
309
310fn gen_unique_where_arms(scalar_fields: &[&Field]) -> Vec<TokenStream> {
311    let _where_unique = format_ident!(
312        "{}WhereUniqueInput",
313        "" // placeholder, we use Self:: instead
314    );
315    scalar_fields
316        .iter()
317        .filter(|f| f.is_id || f.is_unique)
318        .map(|f| {
319            let variant = format_ident!("{}", to_pascal_case(&f.name));
320            let db_name = &f.db_name;
321            quote! {
322                Self::#variant(v) => {
323                    qb.push(concat!(" AND \"", #db_name, "\" = "));
324                    qb.push_bind(v.clone());
325                }
326            }
327        })
328        .collect()
329}

// ─── Data Module ──────────────────────────────────────────────

333fn gen_data_module(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
334    let create_name = format_ident!("{}CreateInput", model.name);
335    let update_name = format_ident!("{}UpdateInput", model.name);
336
337    let required_fields: Vec<TokenStream> = scalar_fields
338        .iter()
339        .filter(|f| !f.has_default() && !f.is_updated_at)
340        .map(|f| {
341            let name = format_ident!("{}", to_snake_case(&f.name));
342            let ty = rust_type_tokens(f, ModuleDepth::Nested);
343            quote! { pub #name: #ty }
344        })
345        .collect();
346
347    let optional_fields: Vec<TokenStream> = scalar_fields
348        .iter()
349        .filter(|f| f.has_default() && !f.is_updated_at)
350        .map(|f| {
351            let name = format_ident!("{}", to_snake_case(&f.name));
352            let base_ty = rust_type_tokens(f, ModuleDepth::Nested);
353            quote! { pub #name: Option<#base_ty> }
354        })
355        .collect();
356
357    let update_fields: Vec<TokenStream> = scalar_fields
358        .iter()
359        .filter(|f| !f.is_id && !f.is_updated_at)
360        .map(|f| {
361            let name = format_ident!("{}", to_snake_case(&f.name));
362            let ty = rust_type_tokens(f, ModuleDepth::Nested);
363            quote! { pub #name: Option<SetValue<#ty>> }
364        })
365        .collect();
366
367    quote! {
368        pub mod data {
369            use ferriorm_runtime::prelude::*;
370
371            #[derive(Debug, Clone)]
372            pub struct #create_name {
373                #(#required_fields,)*
374                #(#optional_fields,)*
375            }
376
377            #[derive(Debug, Clone, Default)]
378            pub struct #update_name {
379                #(#update_fields,)*
380            }
381        }
382    }
383}

// ─── Order Module ─────────────────────────────────────────────

387fn gen_order_module(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
388    let order_name = format_ident!("{}OrderByInput", model.name);
389
390    let variants: Vec<TokenStream> = scalar_fields
391        .iter()
392        .map(|f| {
393            let variant = format_ident!("{}", to_pascal_case(&f.name));
394            quote! { #variant(SortOrder) }
395        })
396        .collect();
397
398    let order_arms: Vec<TokenStream> = scalar_fields
399        .iter()
400        .map(|f| {
401            let variant = format_ident!("{}", to_pascal_case(&f.name));
402            let db_name = &f.db_name;
403            quote! {
404                Self::#variant(order) => {
405                    qb.push(concat!("\"", #db_name, "\" "));
406                    qb.push(order.as_sql());
407                }
408            }
409        })
410        .collect();
411
412    quote! {
413        pub mod order {
414            use ferriorm_runtime::prelude::*;
415
416            #[derive(Debug, Clone)]
417            pub enum #order_name {
418                #(#variants),*
419            }
420
421            impl #order_name {
422                pub(crate) fn build_order_by<'args, DB: sqlx::Database>(
423                    &self,
424                    qb: &mut sqlx::QueryBuilder<'args, DB>,
425                ) {
426                    match self {
427                        #(#order_arms)*
428                    }
429                }
430            }
431        }
432    }
433}

// ─── Actions ──────────────────────────────────────────────────

437fn gen_actions(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
438    let _model_ident = format_ident!("{}", model.name);
439    let actions_name = format_ident!("{}Actions", model.name);
440    let where_input = format_ident!("{}WhereInput", model.name);
441    let where_unique = format_ident!("{}WhereUniqueInput", model.name);
442    let create_input = format_ident!("{}CreateInput", model.name);
443    let update_input = format_ident!("{}UpdateInput", model.name);
444    let _order_by = format_ident!("{}OrderByInput", model.name);
445
446    // Only generate aggregate() if there are aggregatable fields
447    let has_agg_fields = scalar_fields.iter().any(|f| {
448        matches!(
449            &f.field_type,
450            FieldKind::Scalar(
451                ScalarType::Int | ScalarType::BigInt | ScalarType::Float | ScalarType::DateTime
452            )
453        )
454    });
455    let aggregate_method = if has_agg_fields {
456        quote! {
457            pub fn aggregate(&self, r#where: filter::#where_input) -> AggregateQuery<'a> {
458                AggregateQuery { client: self.client, r#where, ops: vec![] }
459            }
460        }
461    } else {
462        quote! {}
463    };
464
465    quote! {
466        pub struct #actions_name<'a> {
467            client: &'a DatabaseClient,
468        }
469
470        impl<'a> #actions_name<'a> {
471            pub fn new(client: &'a DatabaseClient) -> Self { Self { client } }
472
473            pub fn find_unique(&self, r#where: filter::#where_unique) -> FindUniqueQuery<'a> {
474                FindUniqueQuery { client: self.client, r#where }
475            }
476
477            pub fn find_first(&self, r#where: filter::#where_input) -> FindFirstQuery<'a> {
478                FindFirstQuery { client: self.client, r#where, order_by: vec![] }
479            }
480
481            pub fn find_many(&self, r#where: filter::#where_input) -> FindManyQuery<'a> {
482                FindManyQuery { client: self.client, r#where, order_by: vec![], skip: None, take: None }
483            }
484
485            pub fn create(&self, data: data::#create_input) -> CreateQuery<'a> {
486                CreateQuery { client: self.client, data }
487            }
488
489            pub fn update(&self, r#where: filter::#where_unique, data: data::#update_input) -> UpdateQuery<'a> {
490                UpdateQuery { client: self.client, r#where, data }
491            }
492
493            pub fn delete(&self, r#where: filter::#where_unique) -> DeleteQuery<'a> {
494                DeleteQuery { client: self.client, r#where }
495            }
496
497            pub fn count(&self, r#where: filter::#where_input) -> CountQuery<'a> {
498                CountQuery { client: self.client, r#where }
499            }
500
501            pub fn create_many(&self, data: Vec<data::#create_input>) -> CreateManyQuery<'a> {
502                CreateManyQuery { client: self.client, data }
503            }
504
505            pub fn update_many(&self, r#where: filter::#where_input, data: data::#update_input) -> UpdateManyQuery<'a> {
506                UpdateManyQuery { client: self.client, r#where, data }
507            }
508
509            pub fn delete_many(&self, r#where: filter::#where_input) -> DeleteManyQuery<'a> {
510                DeleteManyQuery { client: self.client, r#where }
511            }
512
513            pub fn upsert(
514                &self,
515                r#where: filter::#where_unique,
516                create: data::#create_input,
517                update: data::#update_input,
518            ) -> UpsertQuery<'a> {
519                UpsertQuery { client: self.client, r#where, create, update }
520            }
521
522            #aggregate_method
523        }
524    }
525}

// ─── Query Builders with exec() ──────────────────────────────

529fn gen_query_builders(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
530    let model_ident = format_ident!("{}", model.name);
531    let table_name = &model.db_name;
532    let _where_input = format_ident!("{}WhereInput", model.name);
533    let _where_unique = format_ident!("{}WhereUniqueInput", model.name);
534    let _create_input = format_ident!("{}CreateInput", model.name);
535    let _update_input = format_ident!("{}UpdateInput", model.name);
536    let order_by = format_ident!("{}OrderByInput", model.name);
537    let _select_struct = format_ident!("{}Select", model.name);
538    let _partial_struct = format_ident!("{}Partial", model.name);
539    let _aggregate_result = format_ident!("{}AggregateResult", model.name);
540    let _aggregate_field = format_ident!("{}AggregateField", model.name);
541    let db_bounds = collect_db_bounds(scalar_fields);
542
543    let select_sql = format!(r#"SELECT * FROM "{}" WHERE 1=1"#, table_name);
544    let count_sql = format!(
545        r#"SELECT COUNT(*) as "count" FROM "{}" WHERE 1=1"#,
546        table_name
547    );
548    let delete_sql = format!(r#"DELETE FROM "{}" WHERE 1=1"#, table_name);
549
550    let insert_code = gen_insert_code(model, scalar_fields, table_name);
551    let update_code = gen_update_code(model, scalar_fields, table_name);
552    let update_many_code = gen_update_many_code(model, scalar_fields, table_name);
553    let upsert_code = gen_upsert_code(model, scalar_fields, table_name);
554
555    quote! {
556        // ── Generic helper: build ORDER BY clause ──────────────
557        fn build_order_by<'args, DB: sqlx::Database>(
558            orders: &[order::#order_by],
559            qb: &mut sqlx::QueryBuilder<'args, DB>,
560        ) {
561            if !orders.is_empty() {
562                qb.push(" ORDER BY ");
563                for (i, ob) in orders.iter().enumerate() {
564                    if i > 0 { qb.push(", "); }
565                    ob.build_order_by(qb);
566                }
567            }
568        }
569
570        // ── Generic helper: build a SELECT query ───────────────
571        fn build_select_query<'args, DB: sqlx::Database>(
572            base_sql: &str,
573            where_input: &filter::#_where_input,
574            orders: &[order::#order_by],
575            take: Option<i64>,
576            skip: Option<i64>,
577        ) -> sqlx::QueryBuilder<'args, DB>
578        where
579            #(#db_bounds,)*
580        {
581            let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
582            where_input.build_where(&mut qb);
583            build_order_by(orders, &mut qb);
584            if let Some(take) = take {
585                qb.push(" LIMIT ");
586                qb.push_bind(take);
587            }
588            if let Some(skip) = skip {
589                qb.push(" OFFSET ");
590                qb.push_bind(skip);
591            }
592            qb
593        }
594
595        // ── Generic helper: build a SELECT query for unique lookup ──
596        fn build_unique_select_query<'args, DB: sqlx::Database>(
597            base_sql: &str,
598            where_unique: &filter::#_where_unique,
599        ) -> sqlx::QueryBuilder<'args, DB>
600        where
601            #(#db_bounds,)*
602        {
603            let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
604            where_unique.build_where(&mut qb);
605            qb.push(" LIMIT 1");
606            qb
607        }
608
609        // ── Generic helper: build a DELETE-returning query ─────
610        fn build_delete_query<'args, DB: sqlx::Database>(
611            base_sql: &str,
612            where_unique: &filter::#_where_unique,
613        ) -> sqlx::QueryBuilder<'args, DB>
614        where
615            #(#db_bounds,)*
616        {
617            let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
618            where_unique.build_where(&mut qb);
619            qb.push(" RETURNING *");
620            qb
621        }
622
623        // ── Generic helper: build a COUNT query ────────────────
624        fn build_count_query<'args, DB: sqlx::Database>(
625            base_sql: &str,
626            where_input: &filter::#_where_input,
627        ) -> sqlx::QueryBuilder<'args, DB>
628        where
629            #(#db_bounds,)*
630        {
631            let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
632            where_input.build_where(&mut qb);
633            qb
634        }
635
636        // ── Generic helper: build a DELETE-many query ──────────
637        fn build_delete_many_query<'args, DB: sqlx::Database>(
638            base_sql: &str,
639            where_input: &filter::#_where_input,
640        ) -> sqlx::QueryBuilder<'args, DB>
641        where
642            #(#db_bounds,)*
643        {
644            let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
645            where_input.build_where(&mut qb);
646            qb
647        }
648
649        pub struct FindUniqueQuery<'a> {
650            client: &'a DatabaseClient,
651            r#where: filter::#_where_unique,
652        }
653
654        impl<'a> FindUniqueQuery<'a> {
655            pub fn select(self, select: #_select_struct) -> FindUniqueSelectQuery<'a> {
656                FindUniqueSelectQuery { client: self.client, r#where: self.r#where, select }
657            }
658
659            pub async fn exec(self) -> Result<Option<#model_ident>, FerriormError> {
660                match self.client {
661                    DatabaseClient::Postgres(_) => {
662                        let qb = build_unique_select_query::<sqlx::Postgres>(#select_sql, &self.r#where);
663                        self.client.fetch_optional_pg(qb).await
664                    }
665                    DatabaseClient::Sqlite(_) => {
666                        let qb = build_unique_select_query::<sqlx::Sqlite>(#select_sql, &self.r#where);
667                        self.client.fetch_optional_sqlite(qb).await
668                    }
669                }
670            }
671        }
672
673        pub struct FindFirstQuery<'a> {
674            client: &'a DatabaseClient,
675            r#where: filter::#_where_input,
676            order_by: Vec<order::#order_by>,
677        }
678
679        impl<'a> FindFirstQuery<'a> {
680            pub fn order_by(mut self, order: order::#order_by) -> Self {
681                self.order_by.push(order);
682                self
683            }
684
685            pub fn select(self, select: #_select_struct) -> FindFirstSelectQuery<'a> {
686                FindFirstSelectQuery { client: self.client, r#where: self.r#where, order_by: self.order_by, select }
687            }
688
689            pub async fn exec(self) -> Result<Option<#model_ident>, FerriormError> {
690                match self.client {
691                    DatabaseClient::Postgres(_) => {
692                        let qb = build_select_query::<sqlx::Postgres>(#select_sql, &self.r#where, &self.order_by, Some(1), None);
693                        self.client.fetch_optional_pg(qb).await
694                    }
695                    DatabaseClient::Sqlite(_) => {
696                        let qb = build_select_query::<sqlx::Sqlite>(#select_sql, &self.r#where, &self.order_by, Some(1), None);
697                        self.client.fetch_optional_sqlite(qb).await
698                    }
699                }
700            }
701        }
702
703        pub struct FindManyQuery<'a> {
704            client: &'a DatabaseClient,
705            r#where: filter::#_where_input,
706            order_by: Vec<order::#order_by>,
707            skip: Option<i64>,
708            take: Option<i64>,
709        }
710
711        impl<'a> FindManyQuery<'a> {
712            pub fn order_by(mut self, order: order::#order_by) -> Self {
713                self.order_by.push(order);
714                self
715            }
716
717            pub fn skip(mut self, n: i64) -> Self {
718                self.skip = Some(n);
719                self
720            }
721
722            pub fn take(mut self, n: i64) -> Self {
723                self.take = Some(n);
724                self
725            }
726
727            pub fn select(self, select: #_select_struct) -> FindManySelectQuery<'a> {
728                FindManySelectQuery {
729                    client: self.client,
730                    r#where: self.r#where,
731                    order_by: self.order_by,
732                    skip: self.skip,
733                    take: self.take,
734                    select,
735                }
736            }
737
738            pub async fn exec(self) -> Result<Vec<#model_ident>, FerriormError> {
739                match self.client {
740                    DatabaseClient::Postgres(_) => {
741                        let qb = build_select_query::<sqlx::Postgres>(#select_sql, &self.r#where, &self.order_by, self.take, self.skip);
742                        self.client.fetch_all_pg(qb).await
743                    }
744                    DatabaseClient::Sqlite(_) => {
745                        let qb = build_select_query::<sqlx::Sqlite>(#select_sql, &self.r#where, &self.order_by, self.take, self.skip);
746                        self.client.fetch_all_sqlite(qb).await
747                    }
748                }
749            }
750        }
751
752        pub struct CreateQuery<'a> {
753            client: &'a DatabaseClient,
754            data: data::#_create_input,
755        }
756
757        impl<'a> CreateQuery<'a> {
758            pub async fn exec(self) -> Result<#model_ident, FerriormError> {
759                let client = self.client;
760                #insert_code
761            }
762        }
763
764        pub struct UpdateQuery<'a> {
765            client: &'a DatabaseClient,
766            r#where: filter::#_where_unique,
767            data: data::#_update_input,
768        }
769
770        impl<'a> UpdateQuery<'a> {
771            pub async fn exec(self) -> Result<#model_ident, FerriormError> {
772                let client = self.client;
773                #update_code
774            }
775        }
776
777        pub struct DeleteQuery<'a> {
778            client: &'a DatabaseClient,
779            r#where: filter::#_where_unique,
780        }
781
782        impl<'a> DeleteQuery<'a> {
783            pub async fn exec(self) -> Result<#model_ident, FerriormError> {
784                match self.client {
785                    DatabaseClient::Postgres(_) => {
786                        let qb = build_delete_query::<sqlx::Postgres>(#delete_sql, &self.r#where);
787                        self.client.fetch_one_pg(qb).await
788                    }
789                    DatabaseClient::Sqlite(_) => {
790                        let qb = build_delete_query::<sqlx::Sqlite>(#delete_sql, &self.r#where);
791                        self.client.fetch_one_sqlite(qb).await
792                    }
793                }
794            }
795        }
796
797        #[derive(sqlx::FromRow)]
798        struct CountResult { count: i64 }
799
800        pub struct CountQuery<'a> {
801            client: &'a DatabaseClient,
802            r#where: filter::#_where_input,
803        }
804
805        impl<'a> CountQuery<'a> {
806            pub async fn exec(self) -> Result<i64, FerriormError> {
807                let row: CountResult = match self.client {
808                    DatabaseClient::Postgres(_) => {
809                        let qb = build_count_query::<sqlx::Postgres>(#count_sql, &self.r#where);
810                        self.client.fetch_one_pg(qb).await?
811                    }
812                    DatabaseClient::Sqlite(_) => {
813                        let qb = build_count_query::<sqlx::Sqlite>(#count_sql, &self.r#where);
814                        self.client.fetch_one_sqlite(qb).await?
815                    }
816                };
817                Ok(row.count)
818            }
819        }
820
821        pub struct CreateManyQuery<'a> {
822            client: &'a DatabaseClient,
823            data: Vec<data::#_create_input>,
824        }
825
826        impl<'a> CreateManyQuery<'a> {
827            pub async fn exec(self) -> Result<u64, FerriormError> {
828                if self.data.is_empty() { return Ok(0); }
829                let count = self.data.len() as u64;
830                for item in self.data {
831                    CreateQuery { client: self.client, data: item }.exec().await?;
832                }
833                Ok(count)
834            }
835        }
836
837        pub struct UpdateManyQuery<'a> {
838            client: &'a DatabaseClient,
839            r#where: filter::#_where_input,
840            data: data::#_update_input,
841        }
842
843        impl<'a> UpdateManyQuery<'a> {
844            pub async fn exec(self) -> Result<u64, FerriormError> {
845                let client = self.client;
846                #update_many_code
847            }
848        }
849
850        pub struct UpsertQuery<'a> {
851            client: &'a DatabaseClient,
852            r#where: filter::#_where_unique,
853            create: data::#_create_input,
854            update: data::#_update_input,
855        }
856
857        impl<'a> UpsertQuery<'a> {
858            pub async fn exec(self) -> Result<#model_ident, FerriormError> {
859                let client = self.client;
860                #upsert_code
861            }
862        }
863
864        pub struct DeleteManyQuery<'a> {
865            client: &'a DatabaseClient,
866            r#where: filter::#_where_input,
867        }
868
869        impl<'a> DeleteManyQuery<'a> {
870            pub async fn exec(self) -> Result<u64, FerriormError> {
871                match self.client {
872                    DatabaseClient::Postgres(_) => {
873                        let qb = build_delete_many_query::<sqlx::Postgres>(#delete_sql, &self.r#where);
874                        self.client.execute_pg(qb).await
875                    }
876                    DatabaseClient::Sqlite(_) => {
877                        let qb = build_delete_many_query::<sqlx::Sqlite>(#delete_sql, &self.r#where);
878                        self.client.execute_sqlite(qb).await
879                    }
880                }
881            }
882        }
883    }
884}
885
886// ─── INSERT code generation ───────────────────────────────────
887
888fn gen_insert_code(model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
889    let _model_ident = format_ident!("{}", model.name);
890
891    // Required columns: scalar, no default, not @updatedAt
892    let required: Vec<&Field> = scalar_fields
893        .iter()
894        .copied()
895        .filter(|f| !f.has_default() && !f.is_updated_at)
896        .collect();
897
898    // Optional columns: have default (can be overridden), not @updatedAt
899    let optional: Vec<&Field> = scalar_fields
900        .iter()
901        .copied()
902        .filter(|f| f.has_default() && !f.is_updated_at)
903        .collect();
904
905    // @updatedAt columns: always set to now()
906    let updated_at: Vec<&Field> = scalar_fields
907        .iter()
908        .copied()
909        .filter(|f| f.is_updated_at)
910        .collect();
911
912    // Build column names and bind values
913    let mut col_pushes = vec![];
914    let mut val_pushes = vec![];
915
916    // Required fields — always included
917    for f in &required {
918        let db_name = &f.db_name;
919        let field_ident = format_ident!("{}", to_snake_case(&f.name));
920        col_pushes.push(quote! { cols.push(#db_name); });
921        val_pushes.push(quote! { sep.push_bind(self.data.#field_ident); });
922    }
923
924    // Optional fields — resolve defaults in Rust
925    for f in &optional {
926        let db_name = &f.db_name;
927        let field_ident = format_ident!("{}", to_snake_case(&f.name));
928        let default_expr = gen_default_expr(f, &f.field_type);
929
930        col_pushes.push(quote! { cols.push(#db_name); });
931        val_pushes.push(quote! {
932            let val = self.data.#field_ident.unwrap_or_else(|| #default_expr);
933            sep.push_bind(val);
934        });
935    }
936
937    // @updatedAt fields
938    for f in &updated_at {
939        let db_name = &f.db_name;
940        col_pushes.push(quote! { cols.push(#db_name); });
941        val_pushes.push(quote! { sep.push_bind(chrono::Utc::now()); });
942    }
943
944    let insert_start = format!(r#"INSERT INTO "{}""#, table_name);
945
946    // The insert_body macro avoids duplicating the column/value building logic
947    // for each database backend. It captures `self` by reference.
948    quote! {
949        // Helper to build the INSERT query for any DB backend
950        macro_rules! build_insert {
951            ($qb_type:ty) => {{
952                let mut cols: Vec<&str> = Vec::new();
953                #(#col_pushes)*
954
955                let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#insert_start);
956                qb.push(" (");
957                for (i, col) in cols.iter().enumerate() {
958                    if i > 0 { qb.push(", "); }
959                    qb.push("\"");
960                    qb.push(*col);
961                    qb.push("\"");
962                }
963                qb.push(") VALUES (");
964                {
965                    let mut sep = qb.separated(", ");
966                    #(#val_pushes)*
967                }
968                qb.push(") RETURNING *");
969                qb
970            }};
971        }
972
973        match client {
974            DatabaseClient::Postgres(_) => {
975                let qb = build_insert!(sqlx::Postgres);
976                client.fetch_one_pg(qb).await
977            }
978            DatabaseClient::Sqlite(_) => {
979                let qb = build_insert!(sqlx::Sqlite);
980                client.fetch_one_sqlite(qb).await
981            }
982        }
983    }
984}
985
986/// Generate a Rust expression for a field's @default value.
987fn gen_default_expr(field: &Field, field_type: &FieldKind) -> TokenStream {
988    use ferriorm_core::ast::DefaultValue;
989
990    match &field.default {
991        Some(DefaultValue::Uuid) => quote! { uuid::Uuid::new_v4().to_string() },
992        Some(DefaultValue::Cuid) => quote! { uuid::Uuid::new_v4().to_string() }, // fallback
993        Some(DefaultValue::Now) => quote! { chrono::Utc::now() },
994        Some(DefaultValue::AutoIncrement) => quote! { 0i32 }, // DB handles this
995        Some(DefaultValue::Literal(lit)) => {
996            use ferriorm_core::ast::LiteralValue;
997            match lit {
998                LiteralValue::String(s) => quote! { #s.to_string() },
999                LiteralValue::Int(i) => {
1000                    // Cast the integer literal to the correct Rust type based on the field's scalar type.
1001                    match field_type {
1002                        FieldKind::Scalar(ScalarType::Float) => {
1003                            let val = *i as f64;
1004                            quote! { #val }
1005                        }
1006                        FieldKind::Scalar(ScalarType::BigInt) => quote! { #i },
1007                        _ => {
1008                            // Default to i32 for Int and other types
1009                            let val = *i as i32;
1010                            quote! { #val }
1011                        }
1012                    }
1013                }
1014                LiteralValue::Float(f) => quote! { #f },
1015                LiteralValue::Bool(b) => quote! { #b },
1016            }
1017        }
1018        Some(DefaultValue::EnumVariant(v)) => {
1019            // Reference the enum variant — insert code runs at model module level
1020            let variant = format_ident!("{}", v);
1021            if let FieldKind::Enum(enum_name) = &field.field_type {
1022                let enum_ident = format_ident!("{}", enum_name);
1023                quote! { super::enums::#enum_ident::#variant }
1024            } else {
1025                quote! { Default::default() }
1026            }
1027        }
1028        None => quote! { Default::default() },
1029    }
1030}
1031
1032// ─── UPDATE code generation ───────────────────────────────────
1033
1034fn gen_update_code(model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
1035    let _model_ident = format_ident!("{}", model.name);
1036
1037    // Updatable fields: non-id, non-updatedAt scalar fields
1038    let updatable: Vec<&Field> = scalar_fields
1039        .iter()
1040        .copied()
1041        .filter(|f| !f.is_id && !f.is_updated_at)
1042        .collect();
1043
1044    let updated_at: Vec<&Field> = scalar_fields
1045        .iter()
1046        .copied()
1047        .filter(|f| f.is_updated_at)
1048        .collect();
1049
1050    let update_start = format!(r#"UPDATE "{}" SET "#, table_name);
1051
1052    // Generate SET clause arms
1053    let set_arms: Vec<TokenStream> = updatable
1054        .iter()
1055        .map(|f| {
1056            let field_ident = format_ident!("{}", to_snake_case(&f.name));
1057            let db_name = &f.db_name;
1058            quote! {
1059                if let Some(SetValue::Set(v)) = self.data.#field_ident {
1060                    if !first_set { qb.push(", "); }
1061                    first_set = false;
1062                    qb.push(concat!("\"", #db_name, "\" = "));
1063                    qb.push_bind(v);
1064                }
1065            }
1066        })
1067        .collect();
1068
1069    let updated_at_arms: Vec<TokenStream> = updated_at
1070        .iter()
1071        .map(|f| {
1072            let db_name = &f.db_name;
1073            quote! {
1074                if !first_set { qb.push(", "); }
1075                first_set = false;
1076                qb.push(concat!("\"", #db_name, "\" = "));
1077                qb.push_bind(chrono::Utc::now());
1078            }
1079        })
1080        .collect();
1081
1082    // The build_update macro avoids duplicating the SET clause building logic
1083    // for each database backend.
1084    quote! {
1085        macro_rules! build_update {
1086            ($qb_type:ty) => {{
1087                let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#update_start);
1088                let mut first_set = true;
1089                #(#set_arms)*
1090                #(#updated_at_arms)*
1091
1092                if first_set {
1093                    return Err(FerriormError::Query("No fields to update".into()));
1094                }
1095
1096                qb.push(" WHERE 1=1");
1097                self.r#where.build_where(&mut qb);
1098                qb.push(" RETURNING *");
1099                qb
1100            }};
1101        }
1102
1103        match client {
1104            DatabaseClient::Postgres(_) => {
1105                let qb = build_update!(sqlx::Postgres);
1106                client.fetch_one_pg(qb).await
1107            }
1108            DatabaseClient::Sqlite(_) => {
1109                let qb = build_update!(sqlx::Sqlite);
1110                client.fetch_one_sqlite(qb).await
1111            }
1112        }
1113    }
1114}
1115
1116// ─── UPDATE MANY code generation ──────────────────────────────
1117
1118fn gen_update_many_code(_model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
1119    // Updatable fields: non-id, non-updatedAt scalar fields
1120    let updatable: Vec<&Field> = scalar_fields
1121        .iter()
1122        .copied()
1123        .filter(|f| !f.is_id && !f.is_updated_at)
1124        .collect();
1125
1126    let updated_at: Vec<&Field> = scalar_fields
1127        .iter()
1128        .copied()
1129        .filter(|f| f.is_updated_at)
1130        .collect();
1131
1132    let update_start = format!(r#"UPDATE "{}" SET "#, table_name);
1133
1134    // Generate SET clause arms
1135    let set_arms: Vec<TokenStream> = updatable
1136        .iter()
1137        .map(|f| {
1138            let field_ident = format_ident!("{}", to_snake_case(&f.name));
1139            let db_name = &f.db_name;
1140            quote! {
1141                if let Some(SetValue::Set(v)) = self.data.#field_ident {
1142                    if !first_set { qb.push(", "); }
1143                    first_set = false;
1144                    qb.push(concat!("\"", #db_name, "\" = "));
1145                    qb.push_bind(v);
1146                }
1147            }
1148        })
1149        .collect();
1150
1151    let updated_at_arms: Vec<TokenStream> = updated_at
1152        .iter()
1153        .map(|f| {
1154            let db_name = &f.db_name;
1155            quote! {
1156                if !first_set { qb.push(", "); }
1157                first_set = false;
1158                qb.push(concat!("\"", #db_name, "\" = "));
1159                qb.push_bind(chrono::Utc::now());
1160            }
1161        })
1162        .collect();
1163
1164    quote! {
1165        macro_rules! build_update_many {
1166            ($qb_type:ty) => {{
1167                let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#update_start);
1168                let mut first_set = true;
1169                #(#set_arms)*
1170                #(#updated_at_arms)*
1171
1172                if first_set {
1173                    return Ok(0);
1174                }
1175
1176                qb.push(" WHERE 1=1");
1177                self.r#where.build_where(&mut qb);
1178                qb
1179            }};
1180        }
1181
1182        match client {
1183            DatabaseClient::Postgres(_) => {
1184                let qb = build_update_many!(sqlx::Postgres);
1185                client.execute_pg(qb).await
1186            }
1187            DatabaseClient::Sqlite(_) => {
1188                let qb = build_update_many!(sqlx::Sqlite);
1189                client.execute_sqlite(qb).await
1190            }
1191        }
1192    }
1193}
1194
1195// ─── Aggregate Types ──────────────────────────────────────────
1196
/// Identifies which fields are aggregatable and what operations they support.
///
/// This is a small fieldless enum, so it derives `Copy` and `PartialEq`. Classifications can
/// then be passed by value and compared directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum AggregateKind {
    /// Numeric fields: avg, sum, min, max
    Numeric,
    /// DateTime fields: min, max only
    DateTime,
}
1204
1205// ─── UPSERT code generation ──────────────────────────────────
1206
1207fn gen_upsert_code(model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
1208    // Collect primary key db_names for ON CONFLICT clause
1209    let pk_db_names: Vec<String> = model
1210        .primary_key
1211        .fields
1212        .iter()
1213        .filter_map(|pk| {
1214            model
1215                .fields
1216                .iter()
1217                .find(|f| f.name == *pk || to_snake_case(&f.name) == *pk)
1218                .map(|f| f.db_name.clone())
1219        })
1220        .collect();
1221    let pk_conflict_cols = pk_db_names
1222        .iter()
1223        .map(|c| format!("\"{}\"", c))
1224        .collect::<Vec<_>>()
1225        .join(", ");
1226
1227    // Required + optional + updatedAt fields for the INSERT part (same as gen_insert_code)
1228    let required: Vec<&Field> = scalar_fields
1229        .iter()
1230        .copied()
1231        .filter(|f| !f.has_default() && !f.is_updated_at)
1232        .collect();
1233    let optional: Vec<&Field> = scalar_fields
1234        .iter()
1235        .copied()
1236        .filter(|f| f.has_default() && !f.is_updated_at)
1237        .collect();
1238    let updated_at: Vec<&Field> = scalar_fields
1239        .iter()
1240        .copied()
1241        .filter(|f| f.is_updated_at)
1242        .collect();
1243
1244    let mut col_pushes = vec![];
1245    let mut val_pushes = vec![];
1246
1247    for f in &required {
1248        let db_name = &f.db_name;
1249        let field_ident = format_ident!("{}", to_snake_case(&f.name));
1250        col_pushes.push(quote! { cols.push(#db_name); });
1251        val_pushes.push(quote! { sep.push_bind(self.create.#field_ident); });
1252    }
1253    for f in &optional {
1254        let db_name = &f.db_name;
1255        let field_ident = format_ident!("{}", to_snake_case(&f.name));
1256        let default_expr = gen_default_expr(f, &f.field_type);
1257        col_pushes.push(quote! { cols.push(#db_name); });
1258        val_pushes.push(quote! {
1259            let val = self.create.#field_ident.unwrap_or_else(|| #default_expr);
1260            sep.push_bind(val);
1261        });
1262    }
1263    for f in &updated_at {
1264        let db_name = &f.db_name;
1265        col_pushes.push(quote! { cols.push(#db_name); });
1266        val_pushes.push(quote! { sep.push_bind(chrono::Utc::now()); });
1267    }
1268
1269    // Updatable fields for the DO UPDATE SET part
1270    let updatable: Vec<&Field> = scalar_fields
1271        .iter()
1272        .copied()
1273        .filter(|f| !f.is_id && !f.is_updated_at)
1274        .collect();
1275
1276    let set_arms: Vec<TokenStream> = updatable
1277        .iter()
1278        .map(|f| {
1279            let field_ident = format_ident!("{}", to_snake_case(&f.name));
1280            let db_name = &f.db_name;
1281            quote! {
1282                if let Some(SetValue::Set(v)) = self.update.#field_ident {
1283                    if !first_set { qb.push(", "); }
1284                    first_set = false;
1285                    qb.push(concat!("\"", #db_name, "\" = "));
1286                    qb.push_bind(v);
1287                }
1288            }
1289        })
1290        .collect();
1291
1292    let updated_at_set: Vec<TokenStream> = updated_at
1293        .iter()
1294        .map(|f| {
1295            let db_name = &f.db_name;
1296            quote! {
1297                if !first_set { qb.push(", "); }
1298                first_set = false;
1299                qb.push(concat!("\"", #db_name, "\" = "));
1300                qb.push_bind(chrono::Utc::now());
1301            }
1302        })
1303        .collect();
1304
1305    let insert_start = format!(r#"INSERT INTO "{}""#, table_name);
1306    let conflict_clause = format!(" ON CONFLICT ({}) DO UPDATE SET ", pk_conflict_cols);
1307    let noop_set = format!(
1308        r#""{}" = "{}""#,
1309        pk_db_names.first().unwrap_or(&"id".to_string()),
1310        pk_db_names.first().unwrap_or(&"id".to_string()),
1311    );
1312
1313    quote! {
1314        macro_rules! build_upsert {
1315            ($qb_type:ty) => {{
1316                let mut cols: Vec<&str> = Vec::new();
1317                #(#col_pushes)*
1318
1319                let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#insert_start);
1320                qb.push(" (");
1321                for (i, col) in cols.iter().enumerate() {
1322                    if i > 0 { qb.push(", "); }
1323                    qb.push("\"");
1324                    qb.push(*col);
1325                    qb.push("\"");
1326                }
1327                qb.push(") VALUES (");
1328                {
1329                    let mut sep = qb.separated(", ");
1330                    #(#val_pushes)*
1331                }
1332                qb.push(")");
1333                qb.push(#conflict_clause);
1334
1335                let mut first_set = true;
1336                #(#set_arms)*
1337                #(#updated_at_set)*
1338
1339                if first_set {
1340                    // No update fields specified — use a no-op update on the PK
1341                    qb.push(#noop_set);
1342                }
1343
1344                qb.push(" RETURNING *");
1345                qb
1346            }};
1347        }
1348
1349        match client {
1350            DatabaseClient::Postgres(_) => {
1351                let qb = build_upsert!(sqlx::Postgres);
1352                client.fetch_one_pg(qb).await
1353            }
1354            DatabaseClient::Sqlite(_) => {
1355                let qb = build_upsert!(sqlx::Sqlite);
1356                client.fetch_one_sqlite(qb).await
1357            }
1358        }
1359    }
1360}
1361
1362fn gen_aggregate_types(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
1363    let aggregate_field_name = format_ident!("{}AggregateField", model.name);
1364    let aggregate_result_name = format_ident!("{}AggregateResult", model.name);
1365    let _where_input = format_ident!("{}WhereInput", model.name);
1366    let table_name = &model.db_name;
1367
1368    // Collect aggregatable fields with their kind
1369    let agg_fields: Vec<(&Field, AggregateKind)> = scalar_fields
1370        .iter()
1371        .filter_map(|f| match &f.field_type {
1372            FieldKind::Scalar(ScalarType::Int | ScalarType::BigInt | ScalarType::Float) => {
1373                Some((*f, AggregateKind::Numeric))
1374            }
1375            FieldKind::Scalar(ScalarType::DateTime) => Some((*f, AggregateKind::DateTime)),
1376            _ => None,
1377        })
1378        .collect();
1379
1380    if agg_fields.is_empty() {
1381        return quote! {};
1382    }
1383
1384    // Generate enum variants
1385    let enum_variants: Vec<TokenStream> = agg_fields
1386        .iter()
1387        .map(|(f, _)| {
1388            let variant = format_ident!("{}", to_pascal_case(&f.name));
1389            quote! { #variant }
1390        })
1391        .collect();
1392
1393    // Generate db_name match arms
1394    let db_name_arms: Vec<TokenStream> = agg_fields
1395        .iter()
1396        .map(|(f, _)| {
1397            let variant = format_ident!("{}", to_pascal_case(&f.name));
1398            let db_name = &f.db_name;
1399            quote! { Self::#variant => #db_name }
1400        })
1401        .collect();
1402
1403    // Generate AggregateResult fields
1404    let mut result_fields = Vec::new();
1405    for (f, kind) in &agg_fields {
1406        let snake = to_snake_case(&f.name);
1407        let orig_ty = rust_type_tokens(
1408            &Field {
1409                is_optional: false,
1410                ..(*f).clone()
1411            },
1412            ModuleDepth::TopLevel,
1413        );
1414
1415        match kind {
1416            AggregateKind::Numeric => {
1417                let avg_name = format_ident!("avg_{}", snake);
1418                let sum_name = format_ident!("sum_{}", snake);
1419                let min_name = format_ident!("min_{}", snake);
1420                let max_name = format_ident!("max_{}", snake);
1421                result_fields.push(quote! { #[sqlx(default)] pub #avg_name: Option<f64> });
1422                result_fields.push(quote! { #[sqlx(default)] pub #sum_name: Option<f64> });
1423                result_fields.push(quote! { #[sqlx(default)] pub #min_name: Option<#orig_ty> });
1424                result_fields.push(quote! { #[sqlx(default)] pub #max_name: Option<#orig_ty> });
1425            }
1426            AggregateKind::DateTime => {
1427                let min_name = format_ident!("min_{}", snake);
1428                let max_name = format_ident!("max_{}", snake);
1429                result_fields.push(quote! { #[sqlx(default)] pub #min_name: Option<#orig_ty> });
1430                result_fields.push(quote! { #[sqlx(default)] pub #max_name: Option<#orig_ty> });
1431            }
1432        }
1433    }
1434
1435    // Generate the is_numeric check for avg/sum validation
1436    let numeric_arms: Vec<TokenStream> = agg_fields
1437        .iter()
1438        .filter(|(_, kind)| matches!(kind, AggregateKind::Numeric))
1439        .map(|(f, _)| {
1440            let variant = format_ident!("{}", to_pascal_case(&f.name));
1441            quote! { Self::#variant => true }
1442        })
1443        .collect();
1444
1445    let has_numeric = !numeric_arms.is_empty();
1446    let is_numeric_method = if has_numeric {
1447        quote! {
1448            fn is_numeric(&self) -> bool {
1449                match self {
1450                    #(#numeric_arms,)*
1451                    #[allow(unreachable_patterns)]
1452                    _ => false,
1453                }
1454            }
1455        }
1456    } else {
1457        quote! {
1458            fn is_numeric(&self) -> bool { false }
1459        }
1460    };
1461
1462    // Generate alias match arms for each (prefix, field) combination
1463    let mut alias_arms = Vec::new();
1464    for (f, kind) in &agg_fields {
1465        let variant = format_ident!("{}", to_pascal_case(&f.name));
1466        let snake = to_snake_case(&f.name);
1467        let prefixes = match kind {
1468            AggregateKind::Numeric => vec!["avg", "sum", "min", "max"],
1469            AggregateKind::DateTime => vec!["min", "max"],
1470        };
1471        for prefix in prefixes {
1472            let alias_str = format!("{}_{}", prefix, snake);
1473            alias_arms.push(quote! { (#prefix, Self::#variant) => #alias_str });
1474        }
1475    }
1476
1477    let agg_select_base = format!(r#"SELECT {{}} FROM "{}" WHERE 1=1"#, table_name);
1478
1479    quote! {
1480        #[derive(Debug, Clone, Copy)]
1481        pub enum #aggregate_field_name {
1482            #(#enum_variants),*
1483        }
1484
1485        impl #aggregate_field_name {
1486            pub fn db_name(&self) -> &'static str {
1487                match self {
1488                    #(#db_name_arms,)*
1489                }
1490            }
1491
1492            fn alias(&self, prefix: &'static str) -> &'static str {
1493                match (prefix, self) {
1494                    #(#alias_arms,)*
1495                    _ => unreachable!(),
1496                }
1497            }
1498
1499            #is_numeric_method
1500        }
1501
1502        #[derive(Debug, Clone, sqlx::FromRow, Serialize, Deserialize)]
1503        pub struct #aggregate_result_name {
1504            #(#result_fields,)*
1505        }
1506
1507        pub struct AggregateQuery<'a> {
1508            client: &'a DatabaseClient,
1509            r#where: filter::#_where_input,
1510            ops: Vec<(&'static str, &'static str, &'static str)>,
1511        }
1512
1513        impl<'a> AggregateQuery<'a> {
1514            pub fn avg(mut self, field: #aggregate_field_name) -> Self {
1515                assert!(field.is_numeric(), "avg() is only supported on numeric fields");
1516                let db_name = field.db_name();
1517                let alias = field.alias("avg");
1518                self.ops.push(("AVG", db_name, alias));
1519                self
1520            }
1521
1522            pub fn sum(mut self, field: #aggregate_field_name) -> Self {
1523                assert!(field.is_numeric(), "sum() is only supported on numeric fields");
1524                let db_name = field.db_name();
1525                let alias = field.alias("sum");
1526                self.ops.push(("SUM", db_name, alias));
1527                self
1528            }
1529
1530            pub fn min(mut self, field: #aggregate_field_name) -> Self {
1531                let db_name = field.db_name();
1532                let alias = field.alias("min");
1533                self.ops.push(("MIN", db_name, alias));
1534                self
1535            }
1536
1537            pub fn max(mut self, field: #aggregate_field_name) -> Self {
1538                let db_name = field.db_name();
1539                let alias = field.alias("max");
1540                self.ops.push(("MAX", db_name, alias));
1541                self
1542            }
1543
1544            pub async fn exec(self) -> Result<#aggregate_result_name, FerriormError> {
1545                if self.ops.is_empty() {
1546                    return Err(FerriormError::Query("No aggregate operations specified".into()));
1547                }
1548
1549                let selections: Vec<String> = self.ops.iter()
1550                    .map(|(func, col, alias)| format!(r#"{}("{}") as "{}""#, func, col, alias))
1551                    .collect();
1552                let select_clause = selections.join(", ");
1553                let base_sql = format!(#agg_select_base, select_clause);
1554
1555                match self.client {
1556                    DatabaseClient::Postgres(_) => {
1557                        let mut qb = sqlx::QueryBuilder::<sqlx::Postgres>::new(&base_sql);
1558                        self.r#where.build_where(&mut qb);
1559                        self.client.fetch_one_pg(qb).await
1560                    }
1561                    DatabaseClient::Sqlite(_) => {
1562                        let mut qb = sqlx::QueryBuilder::<sqlx::Sqlite>::new(&base_sql);
1563                        self.r#where.build_where(&mut qb);
1564                        self.client.fetch_one_sqlite(qb).await
1565                    }
1566                }
1567            }
1568        }
1569    }
1570}
1571
1572// ─── Select Types ─────────────────────────────────────────────
1573
1574fn gen_select_types(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
1575    let select_name = format_ident!("{}Select", model.name);
1576    let partial_name = format_ident!("{}Partial", model.name);
1577    let _where_input = format_ident!("{}WhereInput", model.name);
1578    let _where_unique = format_ident!("{}WhereUniqueInput", model.name);
1579    let order_by_name = format_ident!("{}OrderByInput", model.name);
1580    let table_name = &model.db_name;
1581
1582    // Select struct fields: all bool, default false
1583    let select_fields: Vec<TokenStream> = scalar_fields
1584        .iter()
1585        .map(|f| {
1586            let name = format_ident!("{}", to_snake_case(&f.name));
1587            quote! { pub #name: bool }
1588        })
1589        .collect();
1590
1591    // Partial struct fields: all Option<T> with #[sqlx(default)]
1592    // For already-optional fields, don't double-wrap in Option
1593    let partial_fields: Vec<TokenStream> = scalar_fields
1594        .iter()
1595        .map(|f| {
1596            let name = format_ident!("{}", to_snake_case(&f.name));
1597            let db_name = &f.db_name;
1598            // Get the base type (non-optional version)
1599            let base_ty = rust_type_tokens(
1600                &Field {
1601                    is_optional: false,
1602                    ..(*f).clone()
1603                },
1604                ModuleDepth::TopLevel,
1605            );
1606            let rename = if db_name != &to_snake_case(&f.name) {
1607                quote! { #[sqlx(rename = #db_name)] }
1608            } else {
1609                quote! {}
1610            };
1611            // Always wrap in Option<T>, regardless of whether field was originally optional
1612            quote! { #[sqlx(default)] #rename pub #name: Option<#base_ty> }
1613        })
1614        .collect();
1615
1616    // build_select_columns: maps Select bools to column names
1617    let select_col_arms: Vec<TokenStream> = scalar_fields
1618        .iter()
1619        .map(|f| {
1620            let name = format_ident!("{}", to_snake_case(&f.name));
1621            let db_name = &f.db_name;
1622            let col_expr = format!(r#""{}""#, db_name);
1623            quote! {
1624                if select.#name { cols.push(#col_expr); }
1625            }
1626        })
1627        .collect();
1628
1629    let select_sql_prefix = format!(r#"SELECT {{}} FROM "{}" WHERE 1=1"#, table_name);
1630
1631    quote! {
1632        #[derive(Debug, Clone, Default)]
1633        pub struct #select_name {
1634            #(#select_fields,)*
1635        }
1636
1637        #[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
1638        #[sqlx(rename_all = "snake_case")]
1639        pub struct #partial_name {
1640            #(#partial_fields,)*
1641        }
1642
1643        fn build_select_columns(select: &#select_name) -> String {
1644            let mut cols = Vec::new();
1645            #(#select_col_arms)*
1646            if cols.is_empty() {
1647                "*".to_string()
1648            } else {
1649                cols.join(", ")
1650            }
1651        }
1652
1653        // ── FindManySelectQuery ──────────────────────────────────
1654
1655        pub struct FindManySelectQuery<'a> {
1656            client: &'a DatabaseClient,
1657            r#where: filter::#_where_input,
1658            order_by: Vec<order::#order_by_name>,
1659            skip: Option<i64>,
1660            take: Option<i64>,
1661            select: #select_name,
1662        }
1663
1664        impl<'a> FindManySelectQuery<'a> {
1665            pub fn order_by(mut self, order: order::#order_by_name) -> Self {
1666                self.order_by.push(order);
1667                self
1668            }
1669
1670            pub fn skip(mut self, n: i64) -> Self {
1671                self.skip = Some(n);
1672                self
1673            }
1674
1675            pub fn take(mut self, n: i64) -> Self {
1676                self.take = Some(n);
1677                self
1678            }
1679
1680            pub async fn exec(self) -> Result<Vec<#partial_name>, FerriormError> {
1681                let cols = build_select_columns(&self.select);
1682                let base_sql = format!(#select_sql_prefix, cols);
1683
1684                match self.client {
1685                    DatabaseClient::Postgres(_) => {
1686                        let qb = build_select_query::<sqlx::Postgres>(
1687                            &base_sql, &self.r#where, &self.order_by, self.take, self.skip,
1688                        );
1689                        self.client.fetch_all_pg(qb).await
1690                    }
1691                    DatabaseClient::Sqlite(_) => {
1692                        let qb = build_select_query::<sqlx::Sqlite>(
1693                            &base_sql, &self.r#where, &self.order_by, self.take, self.skip,
1694                        );
1695                        self.client.fetch_all_sqlite(qb).await
1696                    }
1697                }
1698            }
1699        }
1700
1701        // ── FindUniqueSelectQuery ────────────────────────────────
1702
1703        pub struct FindUniqueSelectQuery<'a> {
1704            client: &'a DatabaseClient,
1705            r#where: filter::#_where_unique,
1706            select: #select_name,
1707        }
1708
1709        impl<'a> FindUniqueSelectQuery<'a> {
1710            pub async fn exec(self) -> Result<Option<#partial_name>, FerriormError> {
1711                let cols = build_select_columns(&self.select);
1712                let base_sql = format!(#select_sql_prefix, cols);
1713
1714                match self.client {
1715                    DatabaseClient::Postgres(_) => {
1716                        let qb = build_unique_select_query::<sqlx::Postgres>(
1717                            &base_sql, &self.r#where,
1718                        );
1719                        self.client.fetch_optional_pg(qb).await
1720                    }
1721                    DatabaseClient::Sqlite(_) => {
1722                        let qb = build_unique_select_query::<sqlx::Sqlite>(
1723                            &base_sql, &self.r#where,
1724                        );
1725                        self.client.fetch_optional_sqlite(qb).await
1726                    }
1727                }
1728            }
1729        }
1730
1731        // ── FindFirstSelectQuery ─────────────────────────────────
1732
1733        pub struct FindFirstSelectQuery<'a> {
1734            client: &'a DatabaseClient,
1735            r#where: filter::#_where_input,
1736            order_by: Vec<order::#order_by_name>,
1737            select: #select_name,
1738        }
1739
1740        impl<'a> FindFirstSelectQuery<'a> {
1741            pub fn order_by(mut self, order: order::#order_by_name) -> Self {
1742                self.order_by.push(order);
1743                self
1744            }
1745
1746            pub async fn exec(self) -> Result<Option<#partial_name>, FerriormError> {
1747                let cols = build_select_columns(&self.select);
1748                let base_sql = format!(#select_sql_prefix, cols);
1749
1750                match self.client {
1751                    DatabaseClient::Postgres(_) => {
1752                        let qb = build_select_query::<sqlx::Postgres>(
1753                            &base_sql, &self.r#where, &self.order_by, Some(1), None,
1754                        );
1755                        self.client.fetch_optional_pg(qb).await
1756                    }
1757                    DatabaseClient::Sqlite(_) => {
1758                        let qb = build_select_query::<sqlx::Sqlite>(
1759                            &base_sql, &self.r#where, &self.order_by, Some(1), None,
1760                        );
1761                        self.client.fetch_optional_sqlite(qb).await
1762                    }
1763                }
1764            }
1765        }
1766    }
1767}