Skip to main content

ferriorm_codegen/
model.rs

//! Generates per-model Rust modules (struct, filters, data inputs, ordering, CRUD).
//!
//! For each model in the schema, this module produces:
//!
//! - A **data struct** (e.g., `User`) with `sqlx::FromRow` and serde derives.
//! - A **filter submodule** with `WhereInput` and `WhereUniqueInput` types.
//! - A **data submodule** with `CreateInput` and `UpdateInput` types.
//! - An **order submodule** with `OrderByInput`.
//! - An **`Actions` struct** exposing `find_unique`, `find_first`, `find_many`, `create`,
//!   `update`, `delete`, `upsert`, `count`, the batch operations `create_many`,
//!   `update_many` and `delete_many`, and — when the model has an Int, BigInt, Float or
//!   DateTime field — `aggregate`.
//! - **Query builder structs** that chain filters, ordering, pagination, and
//!   include clauses before calling `.exec()`.
//! - **Aggregate** and **select** types backing `aggregate()` and the builders' `.select(...)`.
13
14use ferriorm_core::schema::{Field, FieldKind, Model};
15use ferriorm_core::types::ScalarType;
16use ferriorm_core::utils::{to_pascal_case, to_snake_case};
17use proc_macro2::TokenStream;
18use quote::{format_ident, quote};
19
20use crate::rust_type::{ModuleDepth, filter_type_tokens, rust_type_tokens};
21
22/// Generate the complete module for a single model.
23#[must_use]
24pub fn generate_model_module(model: &Model) -> TokenStream {
25    let scalar_fields: Vec<&Field> = model.fields.iter().filter(|f| f.is_scalar()).collect();
26
27    let data_struct = gen_data_struct(model, &scalar_fields);
28    let filter_module = gen_filter_module(model, &scalar_fields);
29    let data_module = gen_data_module(model, &scalar_fields);
30    let order_module = gen_order_module(model, &scalar_fields);
31    let actions_struct = gen_actions(model, &scalar_fields);
32    let query_builders = gen_query_builders(model, &scalar_fields);
33    let aggregate_types = gen_aggregate_types(model, &scalar_fields);
34    let select_types = gen_select_types(model, &scalar_fields);
35
36    quote! {
37        #![allow(unused_imports, dead_code, unused_variables, clippy::all, clippy::pedantic, clippy::nursery)]
38
39        use serde::{Deserialize, Serialize};
40        use ferriorm_runtime::prelude::*;
41        use ferriorm_runtime::prelude::sqlx;
42        use ferriorm_runtime::prelude::chrono;
43        use ferriorm_runtime::prelude::uuid;
44
45        #data_struct
46        #filter_module
47        #data_module
48        #order_module
49        #actions_struct
50        #query_builders
51        #aggregate_types
52        #select_types
53    }
54}
55
56// ─── Data Struct ──────────────────────────────────────────────
57
58fn gen_data_struct(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
59    let struct_name = format_ident!("{}", model.name);
60    let table_name = &model.db_name;
61
62    let fields: Vec<TokenStream> = scalar_fields
63        .iter()
64        .map(|f| {
65            let name = format_ident!("{}", to_snake_case(&f.name));
66            let ty = rust_type_tokens(f, ModuleDepth::TopLevel);
67            let db_name = &f.db_name;
68            if db_name == &to_snake_case(&f.name) {
69                quote! { pub #name: #ty }
70            } else {
71                quote! { #[sqlx(rename = #db_name)] pub #name: #ty }
72            }
73        })
74        .collect();
75
76    quote! {
77        #[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
78        #[sqlx(rename_all = "snake_case")]
79        pub struct #struct_name {
80            #(#fields),*
81        }
82
83        impl #struct_name {
84            pub const TABLE_NAME: &'static str = #table_name;
85        }
86    }
87}
88
89// ─── Filter Module ────────────────────────────────────────────
90
91fn gen_filter_module(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
92    let where_input = format_ident!("{}WhereInput", model.name);
93    let where_unique = format_ident!("{}WhereUniqueInput", model.name);
94
95    let where_fields: Vec<TokenStream> = scalar_fields
96        .iter()
97        .filter_map(|f| {
98            let filter_ty = filter_type_tokens(f, ModuleDepth::Nested)?;
99            let name = format_ident!("{}", to_snake_case(&f.name));
100            Some(quote! { pub #name: Option<#filter_ty> })
101        })
102        .collect();
103
104    let unique_variants: Vec<TokenStream> = scalar_fields
105        .iter()
106        .filter(|f| f.is_id || f.is_unique)
107        .map(|f| {
108            let variant = format_ident!("{}", to_pascal_case(&f.name));
109            let ty = rust_type_tokens(f, ModuleDepth::Nested);
110            quote! { #variant(#ty) }
111        })
112        .collect();
113
114    // Generate build_where for WhereInput
115    let db_bounds = collect_db_bounds(scalar_fields);
116    let where_arms = gen_where_arms(scalar_fields);
117    let unique_arms = gen_unique_where_arms(scalar_fields);
118
119    quote! {
120        pub mod filter {
121            use ferriorm_runtime::prelude::*;
122
123            #[derive(Debug, Clone, Default)]
124            pub struct #where_input {
125                #(#where_fields,)*
126                pub and: Option<Vec<#where_input>>,
127                pub or: Option<Vec<#where_input>>,
128                pub not: Option<Box<#where_input>>,
129            }
130
131            #[derive(Debug, Clone)]
132            pub enum #where_unique {
133                #(#unique_variants),*
134            }
135
136            impl #where_input {
137                pub(crate) fn build_where<'args, DB: sqlx::Database>(
138                    &self,
139                    qb: &mut sqlx::QueryBuilder<'args, DB>,
140                )
141                where
142                    #(#db_bounds,)*
143                {
144                    #(#where_arms)*
145
146                    if let Some(conditions) = &self.and {
147                        for c in conditions {
148                            c.build_where(qb);
149                        }
150                    }
151                    if let Some(conditions) = &self.or {
152                        if !conditions.is_empty() {
153                            qb.push(" AND (");
154                            for (i, c) in conditions.iter().enumerate() {
155                                if i > 0 { qb.push(" OR "); }
156                                qb.push("(1=1");
157                                c.build_where(qb);
158                                qb.push(")");
159                            }
160                            qb.push(")");
161                        }
162                    }
163                    if let Some(c) = &self.not {
164                        qb.push(" AND NOT (1=1");
165                        c.build_where(qb);
166                        qb.push(")");
167                    }
168                }
169            }
170
171            impl #where_unique {
172                pub(crate) fn build_where<'args, DB: sqlx::Database>(
173                    &self,
174                    qb: &mut sqlx::QueryBuilder<'args, DB>,
175                )
176                where
177                    #(#db_bounds,)*
178                {
179                    match self {
180                        #(#unique_arms)*
181                    }
182                }
183            }
184        }
185    }
186}
187
188/// Collect the sqlx type bounds needed for all scalar types used by the model.
189fn collect_db_bounds(scalar_fields: &[&Field]) -> Vec<TokenStream> {
190    let mut seen = std::collections::HashSet::new();
191    let mut bounds = Vec::new();
192
193    // Always need i64 for LIMIT/OFFSET
194    seen.insert("i64");
195    bounds.push(quote! { i64: sqlx::Type<DB> + for<'e> sqlx::Encode<'e, DB> });
196
197    for f in scalar_fields {
198        match &f.field_type {
199            FieldKind::Scalar(scalar) => {
200                let key = scalar.rust_type();
201                if seen.insert(key)
202                    && let Some(ty) = scalar_bound_tokens(scalar)
203                {
204                    bounds.push(quote! { #ty: sqlx::Type<DB> + for<'e> sqlx::Encode<'e, DB> });
205                    // Also add Option<T> bound for nullable field support
206                    bounds.push(
207                        quote! { Option<#ty>: sqlx::Type<DB> + for<'e> sqlx::Encode<'e, DB> },
208                    );
209                }
210            }
211            FieldKind::Enum(_) | FieldKind::Model(_) => {}
212        }
213    }
214
215    bounds
216}
217
218fn scalar_bound_tokens(scalar: &ScalarType) -> Option<TokenStream> {
219    match scalar {
220        ScalarType::String => Some(quote! { String }),
221        ScalarType::Int => Some(quote! { i32 }),
222        ScalarType::BigInt => Some(quote! { i64 }),
223        ScalarType::Float => Some(quote! { f64 }),
224        ScalarType::Boolean => Some(quote! { bool }),
225        ScalarType::DateTime => Some(quote! { chrono::DateTime<chrono::Utc> }),
226        ScalarType::Bytes => Some(quote! { Vec<u8> }),
227        ScalarType::Json | ScalarType::Decimal => None,
228    }
229}
230
231/// Generate where-clause arms for each filterable scalar field.
232fn gen_where_arms(scalar_fields: &[&Field]) -> Vec<TokenStream> {
233    scalar_fields
234        .iter()
235        .filter_map(|f| {
236            // Only generate filter arms for scalar types (skip enums for now)
237            if !matches!(&f.field_type, FieldKind::Scalar(_)) {
238                return None;
239            }
240            let field_ident = format_ident!("{}", to_snake_case(&f.name));
241            let db_name = &f.db_name;
242            let is_string = matches!(&f.field_type, FieldKind::Scalar(ScalarType::String));
243            let is_comparable = matches!(
244                &f.field_type,
245                FieldKind::Scalar(
246                    ScalarType::Int | ScalarType::BigInt | ScalarType::Float | ScalarType::DateTime
247                )
248            );
249
250            let mut arms = vec![];
251
252            arms.push(quote! {
253                if let Some(v) = &filter.equals {
254                    qb.push(concat!(" AND \"", #db_name, "\" = "));
255                    qb.push_bind(v.clone());
256                }
257                if let Some(v) = &filter.not {
258                    qb.push(concat!(" AND \"", #db_name, "\" != "));
259                    qb.push_bind(v.clone());
260                }
261            });
262
263            if is_string {
264                arms.push(quote! {
265                    if let Some(v) = &filter.contains {
266                        qb.push(concat!(" AND \"", #db_name, "\" LIKE "));
267                        qb.push_bind(format!("%{}%", v));
268                    }
269                    if let Some(v) = &filter.starts_with {
270                        qb.push(concat!(" AND \"", #db_name, "\" LIKE "));
271                        qb.push_bind(format!("{}%", v));
272                    }
273                    if let Some(v) = &filter.ends_with {
274                        qb.push(concat!(" AND \"", #db_name, "\" LIKE "));
275                        qb.push_bind(format!("%{}", v));
276                    }
277                });
278            }
279
280            if is_comparable {
281                arms.push(quote! {
282                    if let Some(v) = &filter.gt {
283                        qb.push(concat!(" AND \"", #db_name, "\" > "));
284                        qb.push_bind(v.clone());
285                    }
286                    if let Some(v) = &filter.gte {
287                        qb.push(concat!(" AND \"", #db_name, "\" >= "));
288                        qb.push_bind(v.clone());
289                    }
290                    if let Some(v) = &filter.lt {
291                        qb.push(concat!(" AND \"", #db_name, "\" < "));
292                        qb.push_bind(v.clone());
293                    }
294                    if let Some(v) = &filter.lte {
295                        qb.push(concat!(" AND \"", #db_name, "\" <= "));
296                        qb.push_bind(v.clone());
297                    }
298                });
299            }
300
301            Some(quote! {
302                if let Some(filter) = &self.#field_ident {
303                    #(#arms)*
304                }
305            })
306        })
307        .collect()
308}
309
310fn gen_unique_where_arms(scalar_fields: &[&Field]) -> Vec<TokenStream> {
311    let _where_unique = format_ident!(
312        "{}WhereUniqueInput",
313        "" // placeholder, we use Self:: instead
314    );
315    scalar_fields
316        .iter()
317        .filter(|f| f.is_id || f.is_unique)
318        .map(|f| {
319            let variant = format_ident!("{}", to_pascal_case(&f.name));
320            let db_name = &f.db_name;
321            quote! {
322                Self::#variant(v) => {
323                    qb.push(concat!(" AND \"", #db_name, "\" = "));
324                    qb.push_bind(v.clone());
325                }
326            }
327        })
328        .collect()
329}
330
331// ─── Data Module ──────────────────────────────────────────────
332
333fn gen_data_module(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
334    let create_name = format_ident!("{}CreateInput", model.name);
335    let update_name = format_ident!("{}UpdateInput", model.name);
336
337    let required_fields: Vec<TokenStream> = scalar_fields
338        .iter()
339        .filter(|f| !f.has_default() && !f.is_updated_at)
340        .map(|f| {
341            let name = format_ident!("{}", to_snake_case(&f.name));
342            let ty = rust_type_tokens(f, ModuleDepth::Nested);
343            quote! { pub #name: #ty }
344        })
345        .collect();
346
347    let optional_fields: Vec<TokenStream> = scalar_fields
348        .iter()
349        .filter(|f| f.has_default() && !f.is_updated_at)
350        .map(|f| {
351            let name = format_ident!("{}", to_snake_case(&f.name));
352            let base_ty = rust_type_tokens(f, ModuleDepth::Nested);
353            quote! { pub #name: Option<#base_ty> }
354        })
355        .collect();
356
357    let update_fields: Vec<TokenStream> = scalar_fields
358        .iter()
359        .filter(|f| !f.is_id && !f.is_updated_at)
360        .map(|f| {
361            let name = format_ident!("{}", to_snake_case(&f.name));
362            let ty = rust_type_tokens(f, ModuleDepth::Nested);
363            quote! { pub #name: Option<SetValue<#ty>> }
364        })
365        .collect();
366
367    quote! {
368        pub mod data {
369            use ferriorm_runtime::prelude::*;
370
371            #[derive(Debug, Clone)]
372            pub struct #create_name {
373                #(#required_fields,)*
374                #(#optional_fields,)*
375            }
376
377            #[derive(Debug, Clone, Default)]
378            pub struct #update_name {
379                #(#update_fields,)*
380            }
381        }
382    }
383}
384
385// ─── Order Module ─────────────────────────────────────────────
386
387fn gen_order_module(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
388    let order_name = format_ident!("{}OrderByInput", model.name);
389
390    let variants: Vec<TokenStream> = scalar_fields
391        .iter()
392        .map(|f| {
393            let variant = format_ident!("{}", to_pascal_case(&f.name));
394            quote! { #variant(SortOrder) }
395        })
396        .collect();
397
398    let order_arms: Vec<TokenStream> = scalar_fields
399        .iter()
400        .map(|f| {
401            let variant = format_ident!("{}", to_pascal_case(&f.name));
402            let db_name = &f.db_name;
403            quote! {
404                Self::#variant(order) => {
405                    qb.push(concat!("\"", #db_name, "\" "));
406                    qb.push(order.as_sql());
407                }
408            }
409        })
410        .collect();
411
412    quote! {
413        pub mod order {
414            use ferriorm_runtime::prelude::*;
415
416            #[derive(Debug, Clone)]
417            pub enum #order_name {
418                #(#variants),*
419            }
420
421            impl #order_name {
422                pub(crate) fn build_order_by<'args, DB: sqlx::Database>(
423                    &self,
424                    qb: &mut sqlx::QueryBuilder<'args, DB>,
425                ) {
426                    match self {
427                        #(#order_arms)*
428                    }
429                }
430            }
431        }
432    }
433}
434
435// ─── Actions ──────────────────────────────────────────────────
436
437fn gen_actions(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
438    let _model_ident = format_ident!("{}", model.name);
439    let actions_name = format_ident!("{}Actions", model.name);
440    let where_input = format_ident!("{}WhereInput", model.name);
441    let where_unique = format_ident!("{}WhereUniqueInput", model.name);
442    let create_input = format_ident!("{}CreateInput", model.name);
443    let update_input = format_ident!("{}UpdateInput", model.name);
444    let _order_by = format_ident!("{}OrderByInput", model.name);
445
446    // Only generate aggregate() if there are aggregatable fields
447    let has_agg_fields = scalar_fields.iter().any(|f| {
448        matches!(
449            &f.field_type,
450            FieldKind::Scalar(
451                ScalarType::Int | ScalarType::BigInt | ScalarType::Float | ScalarType::DateTime
452            )
453        )
454    });
455    let aggregate_method = if has_agg_fields {
456        quote! {
457            pub fn aggregate(&self, r#where: filter::#where_input) -> AggregateQuery<'a> {
458                AggregateQuery { client: self.client, r#where, ops: vec![] }
459            }
460        }
461    } else {
462        quote! {}
463    };
464
465    quote! {
466        pub struct #actions_name<'a> {
467            client: &'a DatabaseClient,
468        }
469
470        impl<'a> #actions_name<'a> {
471            pub fn new(client: &'a DatabaseClient) -> Self { Self { client } }
472
473            pub fn find_unique(&self, r#where: filter::#where_unique) -> FindUniqueQuery<'a> {
474                FindUniqueQuery { client: self.client, r#where }
475            }
476
477            pub fn find_first(&self, r#where: filter::#where_input) -> FindFirstQuery<'a> {
478                FindFirstQuery { client: self.client, r#where, order_by: vec![] }
479            }
480
481            pub fn find_many(&self, r#where: filter::#where_input) -> FindManyQuery<'a> {
482                FindManyQuery { client: self.client, r#where, order_by: vec![], skip: None, take: None }
483            }
484
485            pub fn create(&self, data: data::#create_input) -> CreateQuery<'a> {
486                CreateQuery { client: self.client, data }
487            }
488
489            pub fn update(&self, r#where: filter::#where_unique, data: data::#update_input) -> UpdateQuery<'a> {
490                UpdateQuery { client: self.client, r#where, data }
491            }
492
493            pub fn delete(&self, r#where: filter::#where_unique) -> DeleteQuery<'a> {
494                DeleteQuery { client: self.client, r#where }
495            }
496
497            pub fn count(&self, r#where: filter::#where_input) -> CountQuery<'a> {
498                CountQuery { client: self.client, r#where }
499            }
500
501            pub fn create_many(&self, data: Vec<data::#create_input>) -> CreateManyQuery<'a> {
502                CreateManyQuery { client: self.client, data }
503            }
504
505            pub fn update_many(&self, r#where: filter::#where_input, data: data::#update_input) -> UpdateManyQuery<'a> {
506                UpdateManyQuery { client: self.client, r#where, data }
507            }
508
509            pub fn delete_many(&self, r#where: filter::#where_input) -> DeleteManyQuery<'a> {
510                DeleteManyQuery { client: self.client, r#where }
511            }
512
513            pub fn upsert(
514                &self,
515                r#where: filter::#where_unique,
516                create: data::#create_input,
517                update: data::#update_input,
518            ) -> UpsertQuery<'a> {
519                UpsertQuery { client: self.client, r#where, create, update }
520            }
521
522            #aggregate_method
523        }
524    }
525}
526
527// ─── Query Builders with exec() ──────────────────────────────
528
529#[allow(clippy::too_many_lines)]
530fn gen_query_builders(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
531    let model_ident = format_ident!("{}", model.name);
532    let table_name = &model.db_name;
533    let _where_input = format_ident!("{}WhereInput", model.name);
534    let _where_unique = format_ident!("{}WhereUniqueInput", model.name);
535    let _create_input = format_ident!("{}CreateInput", model.name);
536    let _update_input = format_ident!("{}UpdateInput", model.name);
537    let order_by = format_ident!("{}OrderByInput", model.name);
538    let _select_struct = format_ident!("{}Select", model.name);
539    let _partial_struct = format_ident!("{}Partial", model.name);
540    let _aggregate_result = format_ident!("{}AggregateResult", model.name);
541    let _aggregate_field = format_ident!("{}AggregateField", model.name);
542    let db_bounds = collect_db_bounds(scalar_fields);
543
544    let select_sql = format!(r#"SELECT * FROM "{table_name}" WHERE 1=1"#);
545    let count_sql = format!(r#"SELECT COUNT(*) as "count" FROM "{table_name}" WHERE 1=1"#);
546    let delete_sql = format!(r#"DELETE FROM "{table_name}" WHERE 1=1"#);
547
548    let insert_code = gen_insert_code(model, scalar_fields, table_name);
549    let update_code = gen_update_code(model, scalar_fields, table_name);
550    let update_many_code = gen_update_many_code(model, scalar_fields, table_name);
551    let upsert_code = gen_upsert_code(model, scalar_fields, table_name);
552
553    quote! {
554        // ── Generic helper: build ORDER BY clause ──────────────
555        fn build_order_by<'args, DB: sqlx::Database>(
556            orders: &[order::#order_by],
557            qb: &mut sqlx::QueryBuilder<'args, DB>,
558        ) {
559            if !orders.is_empty() {
560                qb.push(" ORDER BY ");
561                for (i, ob) in orders.iter().enumerate() {
562                    if i > 0 { qb.push(", "); }
563                    ob.build_order_by(qb);
564                }
565            }
566        }
567
568        // ── Generic helper: build a SELECT query ───────────────
569        fn build_select_query<'args, DB: sqlx::Database>(
570            base_sql: &str,
571            where_input: &filter::#_where_input,
572            orders: &[order::#order_by],
573            take: Option<i64>,
574            skip: Option<i64>,
575        ) -> sqlx::QueryBuilder<'args, DB>
576        where
577            #(#db_bounds,)*
578        {
579            let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
580            where_input.build_where(&mut qb);
581            build_order_by(orders, &mut qb);
582            if let Some(take) = take {
583                qb.push(" LIMIT ");
584                qb.push_bind(take);
585            }
586            if let Some(skip) = skip {
587                qb.push(" OFFSET ");
588                qb.push_bind(skip);
589            }
590            qb
591        }
592
593        // ── Generic helper: build a SELECT query for unique lookup ──
594        fn build_unique_select_query<'args, DB: sqlx::Database>(
595            base_sql: &str,
596            where_unique: &filter::#_where_unique,
597        ) -> sqlx::QueryBuilder<'args, DB>
598        where
599            #(#db_bounds,)*
600        {
601            let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
602            where_unique.build_where(&mut qb);
603            qb.push(" LIMIT 1");
604            qb
605        }
606
607        // ── Generic helper: build a DELETE-returning query ─────
608        fn build_delete_query<'args, DB: sqlx::Database>(
609            base_sql: &str,
610            where_unique: &filter::#_where_unique,
611        ) -> sqlx::QueryBuilder<'args, DB>
612        where
613            #(#db_bounds,)*
614        {
615            let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
616            where_unique.build_where(&mut qb);
617            qb.push(" RETURNING *");
618            qb
619        }
620
621        // ── Generic helper: build a COUNT query ────────────────
622        fn build_count_query<'args, DB: sqlx::Database>(
623            base_sql: &str,
624            where_input: &filter::#_where_input,
625        ) -> sqlx::QueryBuilder<'args, DB>
626        where
627            #(#db_bounds,)*
628        {
629            let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
630            where_input.build_where(&mut qb);
631            qb
632        }
633
634        // ── Generic helper: build a DELETE-many query ──────────
635        fn build_delete_many_query<'args, DB: sqlx::Database>(
636            base_sql: &str,
637            where_input: &filter::#_where_input,
638        ) -> sqlx::QueryBuilder<'args, DB>
639        where
640            #(#db_bounds,)*
641        {
642            let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
643            where_input.build_where(&mut qb);
644            qb
645        }
646
647        pub struct FindUniqueQuery<'a> {
648            client: &'a DatabaseClient,
649            r#where: filter::#_where_unique,
650        }
651
652        impl<'a> FindUniqueQuery<'a> {
653            pub fn select(self, select: #_select_struct) -> FindUniqueSelectQuery<'a> {
654                FindUniqueSelectQuery { client: self.client, r#where: self.r#where, select }
655            }
656
657            pub async fn exec(self) -> Result<Option<#model_ident>, FerriormError> {
658                match self.client {
659                    DatabaseClient::Postgres(_) => {
660                        let qb = build_unique_select_query::<sqlx::Postgres>(#select_sql, &self.r#where);
661                        self.client.fetch_optional_pg(qb).await
662                    }
663                    DatabaseClient::Sqlite(_) => {
664                        let qb = build_unique_select_query::<sqlx::Sqlite>(#select_sql, &self.r#where);
665                        self.client.fetch_optional_sqlite(qb).await
666                    }
667                }
668            }
669        }
670
671        pub struct FindFirstQuery<'a> {
672            client: &'a DatabaseClient,
673            r#where: filter::#_where_input,
674            order_by: Vec<order::#order_by>,
675        }
676
677        impl<'a> FindFirstQuery<'a> {
678            pub fn order_by(mut self, order: order::#order_by) -> Self {
679                self.order_by.push(order);
680                self
681            }
682
683            pub fn select(self, select: #_select_struct) -> FindFirstSelectQuery<'a> {
684                FindFirstSelectQuery { client: self.client, r#where: self.r#where, order_by: self.order_by, select }
685            }
686
687            pub async fn exec(self) -> Result<Option<#model_ident>, FerriormError> {
688                match self.client {
689                    DatabaseClient::Postgres(_) => {
690                        let qb = build_select_query::<sqlx::Postgres>(#select_sql, &self.r#where, &self.order_by, Some(1), None);
691                        self.client.fetch_optional_pg(qb).await
692                    }
693                    DatabaseClient::Sqlite(_) => {
694                        let qb = build_select_query::<sqlx::Sqlite>(#select_sql, &self.r#where, &self.order_by, Some(1), None);
695                        self.client.fetch_optional_sqlite(qb).await
696                    }
697                }
698            }
699        }
700
701        pub struct FindManyQuery<'a> {
702            client: &'a DatabaseClient,
703            r#where: filter::#_where_input,
704            order_by: Vec<order::#order_by>,
705            skip: Option<i64>,
706            take: Option<i64>,
707        }
708
709        impl<'a> FindManyQuery<'a> {
710            pub fn order_by(mut self, order: order::#order_by) -> Self {
711                self.order_by.push(order);
712                self
713            }
714
715            pub fn skip(mut self, n: i64) -> Self {
716                self.skip = Some(n);
717                self
718            }
719
720            pub fn take(mut self, n: i64) -> Self {
721                self.take = Some(n);
722                self
723            }
724
725            pub fn select(self, select: #_select_struct) -> FindManySelectQuery<'a> {
726                FindManySelectQuery {
727                    client: self.client,
728                    r#where: self.r#where,
729                    order_by: self.order_by,
730                    skip: self.skip,
731                    take: self.take,
732                    select,
733                }
734            }
735
736            pub async fn exec(self) -> Result<Vec<#model_ident>, FerriormError> {
737                match self.client {
738                    DatabaseClient::Postgres(_) => {
739                        let qb = build_select_query::<sqlx::Postgres>(#select_sql, &self.r#where, &self.order_by, self.take, self.skip);
740                        self.client.fetch_all_pg(qb).await
741                    }
742                    DatabaseClient::Sqlite(_) => {
743                        let qb = build_select_query::<sqlx::Sqlite>(#select_sql, &self.r#where, &self.order_by, self.take, self.skip);
744                        self.client.fetch_all_sqlite(qb).await
745                    }
746                }
747            }
748        }
749
750        pub struct CreateQuery<'a> {
751            client: &'a DatabaseClient,
752            data: data::#_create_input,
753        }
754
755        impl<'a> CreateQuery<'a> {
756            pub async fn exec(self) -> Result<#model_ident, FerriormError> {
757                let client = self.client;
758                #insert_code
759            }
760        }
761
762        pub struct UpdateQuery<'a> {
763            client: &'a DatabaseClient,
764            r#where: filter::#_where_unique,
765            data: data::#_update_input,
766        }
767
768        impl<'a> UpdateQuery<'a> {
769            pub async fn exec(self) -> Result<#model_ident, FerriormError> {
770                let client = self.client;
771                #update_code
772            }
773        }
774
775        pub struct DeleteQuery<'a> {
776            client: &'a DatabaseClient,
777            r#where: filter::#_where_unique,
778        }
779
780        impl<'a> DeleteQuery<'a> {
781            pub async fn exec(self) -> Result<#model_ident, FerriormError> {
782                match self.client {
783                    DatabaseClient::Postgres(_) => {
784                        let qb = build_delete_query::<sqlx::Postgres>(#delete_sql, &self.r#where);
785                        self.client.fetch_one_pg(qb).await
786                    }
787                    DatabaseClient::Sqlite(_) => {
788                        let qb = build_delete_query::<sqlx::Sqlite>(#delete_sql, &self.r#where);
789                        self.client.fetch_one_sqlite(qb).await
790                    }
791                }
792            }
793        }
794
795        #[derive(sqlx::FromRow)]
796        struct CountResult { count: i64 }
797
798        pub struct CountQuery<'a> {
799            client: &'a DatabaseClient,
800            r#where: filter::#_where_input,
801        }
802
803        impl<'a> CountQuery<'a> {
804            pub async fn exec(self) -> Result<i64, FerriormError> {
805                let row: CountResult = match self.client {
806                    DatabaseClient::Postgres(_) => {
807                        let qb = build_count_query::<sqlx::Postgres>(#count_sql, &self.r#where);
808                        self.client.fetch_one_pg(qb).await?
809                    }
810                    DatabaseClient::Sqlite(_) => {
811                        let qb = build_count_query::<sqlx::Sqlite>(#count_sql, &self.r#where);
812                        self.client.fetch_one_sqlite(qb).await?
813                    }
814                };
815                Ok(row.count)
816            }
817        }
818
819        pub struct CreateManyQuery<'a> {
820            client: &'a DatabaseClient,
821            data: Vec<data::#_create_input>,
822        }
823
824        impl<'a> CreateManyQuery<'a> {
825            pub async fn exec(self) -> Result<u64, FerriormError> {
826                if self.data.is_empty() { return Ok(0); }
827                let count = self.data.len() as u64;
828                for item in self.data {
829                    CreateQuery { client: self.client, data: item }.exec().await?;
830                }
831                Ok(count)
832            }
833        }
834
835        pub struct UpdateManyQuery<'a> {
836            client: &'a DatabaseClient,
837            r#where: filter::#_where_input,
838            data: data::#_update_input,
839        }
840
841        impl<'a> UpdateManyQuery<'a> {
842            pub async fn exec(self) -> Result<u64, FerriormError> {
843                let client = self.client;
844                #update_many_code
845            }
846        }
847
848        pub struct UpsertQuery<'a> {
849            client: &'a DatabaseClient,
850            r#where: filter::#_where_unique,
851            create: data::#_create_input,
852            update: data::#_update_input,
853        }
854
855        impl<'a> UpsertQuery<'a> {
856            pub async fn exec(self) -> Result<#model_ident, FerriormError> {
857                let client = self.client;
858                #upsert_code
859            }
860        }
861
862        pub struct DeleteManyQuery<'a> {
863            client: &'a DatabaseClient,
864            r#where: filter::#_where_input,
865        }
866
867        impl<'a> DeleteManyQuery<'a> {
868            pub async fn exec(self) -> Result<u64, FerriormError> {
869                match self.client {
870                    DatabaseClient::Postgres(_) => {
871                        let qb = build_delete_many_query::<sqlx::Postgres>(#delete_sql, &self.r#where);
872                        self.client.execute_pg(qb).await
873                    }
874                    DatabaseClient::Sqlite(_) => {
875                        let qb = build_delete_many_query::<sqlx::Sqlite>(#delete_sql, &self.r#where);
876                        self.client.execute_sqlite(qb).await
877                    }
878                }
879            }
880        }
881    }
882}
883
884// ─── INSERT code generation ───────────────────────────────────
885
886fn gen_insert_code(model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
887    let _model_ident = format_ident!("{}", model.name);
888
889    // Required columns: scalar, no default, not @updatedAt
890    let required: Vec<&Field> = scalar_fields
891        .iter()
892        .copied()
893        .filter(|f| !f.has_default() && !f.is_updated_at)
894        .collect();
895
896    // Optional columns: have default (can be overridden), not @updatedAt
897    let optional: Vec<&Field> = scalar_fields
898        .iter()
899        .copied()
900        .filter(|f| f.has_default() && !f.is_updated_at)
901        .collect();
902
903    // @updatedAt columns: always set to now()
904    let updated_at: Vec<&Field> = scalar_fields
905        .iter()
906        .copied()
907        .filter(|f| f.is_updated_at)
908        .collect();
909
910    // Build column names and bind values
911    let mut col_pushes = vec![];
912    let mut val_pushes = vec![];
913
914    // Required fields — always included
915    for f in &required {
916        let db_name = &f.db_name;
917        let field_ident = format_ident!("{}", to_snake_case(&f.name));
918        col_pushes.push(quote! { cols.push(#db_name); });
919        val_pushes.push(quote! { sep.push_bind(self.data.#field_ident); });
920    }
921
922    // Optional fields — resolve defaults in Rust
923    for f in &optional {
924        let db_name = &f.db_name;
925        let field_ident = format_ident!("{}", to_snake_case(&f.name));
926        let default_expr = gen_default_expr(f, &f.field_type);
927
928        col_pushes.push(quote! { cols.push(#db_name); });
929        val_pushes.push(quote! {
930            let val = self.data.#field_ident.unwrap_or_else(|| #default_expr);
931            sep.push_bind(val);
932        });
933    }
934
935    // @updatedAt fields
936    for f in &updated_at {
937        let db_name = &f.db_name;
938        col_pushes.push(quote! { cols.push(#db_name); });
939        val_pushes.push(quote! { sep.push_bind(chrono::Utc::now()); });
940    }
941
942    let insert_start = format!(r#"INSERT INTO "{table_name}""#);
943
944    // The insert_body macro avoids duplicating the column/value building logic
945    // for each database backend. It captures `self` by reference.
946    quote! {
947        // Helper to build the INSERT query for any DB backend
948        macro_rules! build_insert {
949            ($qb_type:ty) => {{
950                let mut cols: Vec<&str> = Vec::new();
951                #(#col_pushes)*
952
953                let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#insert_start);
954                qb.push(" (");
955                for (i, col) in cols.iter().enumerate() {
956                    if i > 0 { qb.push(", "); }
957                    qb.push("\"");
958                    qb.push(*col);
959                    qb.push("\"");
960                }
961                qb.push(") VALUES (");
962                {
963                    let mut sep = qb.separated(", ");
964                    #(#val_pushes)*
965                }
966                qb.push(") RETURNING *");
967                qb
968            }};
969        }
970
971        match client {
972            DatabaseClient::Postgres(_) => {
973                let qb = build_insert!(sqlx::Postgres);
974                client.fetch_one_pg(qb).await
975            }
976            DatabaseClient::Sqlite(_) => {
977                let qb = build_insert!(sqlx::Sqlite);
978                client.fetch_one_sqlite(qb).await
979            }
980        }
981    }
982}
983
984/// Generate a Rust expression for a field's @default value.
985fn gen_default_expr(field: &Field, field_type: &FieldKind) -> TokenStream {
986    use ferriorm_core::ast::DefaultValue;
987
988    match &field.default {
989        Some(DefaultValue::Uuid | DefaultValue::Cuid) => {
990            quote! { uuid::Uuid::new_v4().to_string() }
991        }
992        Some(DefaultValue::Now) => quote! { chrono::Utc::now() },
993        Some(DefaultValue::AutoIncrement) => quote! { 0i32 }, // DB handles this
994        Some(DefaultValue::Literal(lit)) => {
995            use ferriorm_core::ast::LiteralValue;
996            match lit {
997                LiteralValue::String(s) => quote! { #s.to_string() },
998                LiteralValue::Int(i) => {
999                    // Cast the integer literal to the correct Rust type based on the field's scalar type.
1000                    #[allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
1001                    match field_type {
1002                        FieldKind::Scalar(ScalarType::Float) => {
1003                            let val = *i as f64;
1004                            quote! { #val }
1005                        }
1006                        FieldKind::Scalar(ScalarType::BigInt) => quote! { #i },
1007                        _ => {
1008                            // Default to i32 for Int and other types
1009                            let val = *i as i32;
1010                            quote! { #val }
1011                        }
1012                    }
1013                }
1014                LiteralValue::Float(f) => quote! { #f },
1015                LiteralValue::Bool(b) => quote! { #b },
1016            }
1017        }
1018        Some(DefaultValue::EnumVariant(v)) => {
1019            // Reference the enum variant — insert code runs at model module level
1020            let variant = format_ident!("{}", v);
1021            if let FieldKind::Enum(enum_name) = &field.field_type {
1022                let enum_ident = format_ident!("{}", enum_name);
1023                quote! { super::enums::#enum_ident::#variant }
1024            } else {
1025                quote! { Default::default() }
1026            }
1027        }
1028        None => quote! { Default::default() },
1029    }
1030}
1031
1032// ─── UPDATE code generation ───────────────────────────────────
1033
1034fn gen_update_code(model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
1035    let _model_ident = format_ident!("{}", model.name);
1036
1037    // Updatable fields: non-id, non-updatedAt scalar fields
1038    let updatable: Vec<&Field> = scalar_fields
1039        .iter()
1040        .copied()
1041        .filter(|f| !f.is_id && !f.is_updated_at)
1042        .collect();
1043
1044    let updated_at: Vec<&Field> = scalar_fields
1045        .iter()
1046        .copied()
1047        .filter(|f| f.is_updated_at)
1048        .collect();
1049
1050    let update_start = format!(r#"UPDATE "{table_name}" SET "#);
1051
1052    // Generate SET clause arms
1053    let set_arms: Vec<TokenStream> = updatable
1054        .iter()
1055        .map(|f| {
1056            let field_ident = format_ident!("{}", to_snake_case(&f.name));
1057            let db_name = &f.db_name;
1058            quote! {
1059                if let Some(SetValue::Set(v)) = self.data.#field_ident {
1060                    if !first_set { qb.push(", "); }
1061                    first_set = false;
1062                    qb.push(concat!("\"", #db_name, "\" = "));
1063                    qb.push_bind(v);
1064                }
1065            }
1066        })
1067        .collect();
1068
1069    let updated_at_arms: Vec<TokenStream> = updated_at
1070        .iter()
1071        .map(|f| {
1072            let db_name = &f.db_name;
1073            quote! {
1074                if !first_set { qb.push(", "); }
1075                first_set = false;
1076                qb.push(concat!("\"", #db_name, "\" = "));
1077                qb.push_bind(chrono::Utc::now());
1078            }
1079        })
1080        .collect();
1081
1082    // The build_update macro avoids duplicating the SET clause building logic
1083    // for each database backend.
1084    quote! {
1085        macro_rules! build_update {
1086            ($qb_type:ty) => {{
1087                let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#update_start);
1088                let mut first_set = true;
1089                #(#set_arms)*
1090                #(#updated_at_arms)*
1091
1092                if first_set {
1093                    return Err(FerriormError::Query("No fields to update".into()));
1094                }
1095
1096                qb.push(" WHERE 1=1");
1097                self.r#where.build_where(&mut qb);
1098                qb.push(" RETURNING *");
1099                qb
1100            }};
1101        }
1102
1103        match client {
1104            DatabaseClient::Postgres(_) => {
1105                let qb = build_update!(sqlx::Postgres);
1106                client.fetch_one_pg(qb).await
1107            }
1108            DatabaseClient::Sqlite(_) => {
1109                let qb = build_update!(sqlx::Sqlite);
1110                client.fetch_one_sqlite(qb).await
1111            }
1112        }
1113    }
1114}
1115
1116// ─── UPDATE MANY code generation ──────────────────────────────
1117
1118fn gen_update_many_code(_model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
1119    // Updatable fields: non-id, non-updatedAt scalar fields
1120    let updatable: Vec<&Field> = scalar_fields
1121        .iter()
1122        .copied()
1123        .filter(|f| !f.is_id && !f.is_updated_at)
1124        .collect();
1125
1126    let updated_at: Vec<&Field> = scalar_fields
1127        .iter()
1128        .copied()
1129        .filter(|f| f.is_updated_at)
1130        .collect();
1131
1132    let update_start = format!(r#"UPDATE "{table_name}" SET "#);
1133
1134    // Generate SET clause arms
1135    let set_arms: Vec<TokenStream> = updatable
1136        .iter()
1137        .map(|f| {
1138            let field_ident = format_ident!("{}", to_snake_case(&f.name));
1139            let db_name = &f.db_name;
1140            quote! {
1141                if let Some(SetValue::Set(v)) = self.data.#field_ident {
1142                    if !first_set { qb.push(", "); }
1143                    first_set = false;
1144                    qb.push(concat!("\"", #db_name, "\" = "));
1145                    qb.push_bind(v);
1146                }
1147            }
1148        })
1149        .collect();
1150
1151    let updated_at_arms: Vec<TokenStream> = updated_at
1152        .iter()
1153        .map(|f| {
1154            let db_name = &f.db_name;
1155            quote! {
1156                if !first_set { qb.push(", "); }
1157                first_set = false;
1158                qb.push(concat!("\"", #db_name, "\" = "));
1159                qb.push_bind(chrono::Utc::now());
1160            }
1161        })
1162        .collect();
1163
1164    quote! {
1165        macro_rules! build_update_many {
1166            ($qb_type:ty) => {{
1167                let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#update_start);
1168                let mut first_set = true;
1169                #(#set_arms)*
1170                #(#updated_at_arms)*
1171
1172                if first_set {
1173                    return Ok(0);
1174                }
1175
1176                qb.push(" WHERE 1=1");
1177                self.r#where.build_where(&mut qb);
1178                qb
1179            }};
1180        }
1181
1182        match client {
1183            DatabaseClient::Postgres(_) => {
1184                let qb = build_update_many!(sqlx::Postgres);
1185                client.execute_pg(qb).await
1186            }
1187            DatabaseClient::Sqlite(_) => {
1188                let qb = build_update_many!(sqlx::Sqlite);
1189                client.execute_sqlite(qb).await
1190            }
1191        }
1192    }
1193}
1194
1195// ─── Aggregate Types ──────────────────────────────────────────
1196
/// Identifies which fields are aggregatable and what operations they support.
///
/// Small, data-free classification tag, so it is `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum AggregateKind {
    /// Numeric fields (`Int`, `BigInt`, `Float`): avg, sum, min, max
    Numeric,
    /// `DateTime` fields: min, max only
    DateTime,
}
1204
1205// ─── UPSERT code generation ──────────────────────────────────
1206
1207#[allow(clippy::too_many_lines)]
1208fn gen_upsert_code(model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
1209    // Collect primary key db_names for ON CONFLICT clause
1210    let pk_db_names: Vec<String> = model
1211        .primary_key
1212        .fields
1213        .iter()
1214        .filter_map(|pk| {
1215            model
1216                .fields
1217                .iter()
1218                .find(|f| f.name == *pk || to_snake_case(&f.name) == *pk)
1219                .map(|f| f.db_name.clone())
1220        })
1221        .collect();
1222    let pk_conflict_cols = pk_db_names
1223        .iter()
1224        .map(|c| format!("\"{c}\""))
1225        .collect::<Vec<_>>()
1226        .join(", ");
1227
1228    // Required + optional + updatedAt fields for the INSERT part (same as gen_insert_code)
1229    let required: Vec<&Field> = scalar_fields
1230        .iter()
1231        .copied()
1232        .filter(|f| !f.has_default() && !f.is_updated_at)
1233        .collect();
1234    let optional: Vec<&Field> = scalar_fields
1235        .iter()
1236        .copied()
1237        .filter(|f| f.has_default() && !f.is_updated_at)
1238        .collect();
1239    let updated_at: Vec<&Field> = scalar_fields
1240        .iter()
1241        .copied()
1242        .filter(|f| f.is_updated_at)
1243        .collect();
1244
1245    let mut col_pushes = vec![];
1246    let mut val_pushes = vec![];
1247
1248    for f in &required {
1249        let db_name = &f.db_name;
1250        let field_ident = format_ident!("{}", to_snake_case(&f.name));
1251        col_pushes.push(quote! { cols.push(#db_name); });
1252        val_pushes.push(quote! { sep.push_bind(self.create.#field_ident); });
1253    }
1254    for f in &optional {
1255        let db_name = &f.db_name;
1256        let field_ident = format_ident!("{}", to_snake_case(&f.name));
1257        let default_expr = gen_default_expr(f, &f.field_type);
1258        col_pushes.push(quote! { cols.push(#db_name); });
1259        val_pushes.push(quote! {
1260            let val = self.create.#field_ident.unwrap_or_else(|| #default_expr);
1261            sep.push_bind(val);
1262        });
1263    }
1264    for f in &updated_at {
1265        let db_name = &f.db_name;
1266        col_pushes.push(quote! { cols.push(#db_name); });
1267        val_pushes.push(quote! { sep.push_bind(chrono::Utc::now()); });
1268    }
1269
1270    // Updatable fields for the DO UPDATE SET part
1271    let updatable: Vec<&Field> = scalar_fields
1272        .iter()
1273        .copied()
1274        .filter(|f| !f.is_id && !f.is_updated_at)
1275        .collect();
1276
1277    let set_arms: Vec<TokenStream> = updatable
1278        .iter()
1279        .map(|f| {
1280            let field_ident = format_ident!("{}", to_snake_case(&f.name));
1281            let db_name = &f.db_name;
1282            quote! {
1283                if let Some(SetValue::Set(v)) = self.update.#field_ident {
1284                    if !first_set { qb.push(", "); }
1285                    first_set = false;
1286                    qb.push(concat!("\"", #db_name, "\" = "));
1287                    qb.push_bind(v);
1288                }
1289            }
1290        })
1291        .collect();
1292
1293    let updated_at_set: Vec<TokenStream> = updated_at
1294        .iter()
1295        .map(|f| {
1296            let db_name = &f.db_name;
1297            quote! {
1298                if !first_set { qb.push(", "); }
1299                first_set = false;
1300                qb.push(concat!("\"", #db_name, "\" = "));
1301                qb.push_bind(chrono::Utc::now());
1302            }
1303        })
1304        .collect();
1305
1306    let insert_start = format!(r#"INSERT INTO "{table_name}""#);
1307    let conflict_clause = format!(" ON CONFLICT ({pk_conflict_cols}) DO UPDATE SET ");
1308    let noop_set = format!(
1309        r#""{}" = "{}""#,
1310        pk_db_names.first().unwrap_or(&"id".to_string()),
1311        pk_db_names.first().unwrap_or(&"id".to_string()),
1312    );
1313
1314    quote! {
1315        macro_rules! build_upsert {
1316            ($qb_type:ty) => {{
1317                let mut cols: Vec<&str> = Vec::new();
1318                #(#col_pushes)*
1319
1320                let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#insert_start);
1321                qb.push(" (");
1322                for (i, col) in cols.iter().enumerate() {
1323                    if i > 0 { qb.push(", "); }
1324                    qb.push("\"");
1325                    qb.push(*col);
1326                    qb.push("\"");
1327                }
1328                qb.push(") VALUES (");
1329                {
1330                    let mut sep = qb.separated(", ");
1331                    #(#val_pushes)*
1332                }
1333                qb.push(")");
1334                qb.push(#conflict_clause);
1335
1336                let mut first_set = true;
1337                #(#set_arms)*
1338                #(#updated_at_set)*
1339
1340                if first_set {
1341                    // No update fields specified — use a no-op update on the PK
1342                    qb.push(#noop_set);
1343                }
1344
1345                qb.push(" RETURNING *");
1346                qb
1347            }};
1348        }
1349
1350        match client {
1351            DatabaseClient::Postgres(_) => {
1352                let qb = build_upsert!(sqlx::Postgres);
1353                client.fetch_one_pg(qb).await
1354            }
1355            DatabaseClient::Sqlite(_) => {
1356                let qb = build_upsert!(sqlx::Sqlite);
1357                client.fetch_one_sqlite(qb).await
1358            }
1359        }
1360    }
1361}
1362
1363#[allow(clippy::too_many_lines)]
1364fn gen_aggregate_types(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
1365    let aggregate_field_name = format_ident!("{}AggregateField", model.name);
1366    let aggregate_result_name = format_ident!("{}AggregateResult", model.name);
1367    let _where_input = format_ident!("{}WhereInput", model.name);
1368    let table_name = &model.db_name;
1369
1370    // Collect aggregatable fields with their kind
1371    let agg_fields: Vec<(&Field, AggregateKind)> = scalar_fields
1372        .iter()
1373        .filter_map(|f| match &f.field_type {
1374            FieldKind::Scalar(ScalarType::Int | ScalarType::BigInt | ScalarType::Float) => {
1375                Some((*f, AggregateKind::Numeric))
1376            }
1377            FieldKind::Scalar(ScalarType::DateTime) => Some((*f, AggregateKind::DateTime)),
1378            _ => None,
1379        })
1380        .collect();
1381
1382    if agg_fields.is_empty() {
1383        return quote! {};
1384    }
1385
1386    // Generate enum variants
1387    let enum_variants: Vec<TokenStream> = agg_fields
1388        .iter()
1389        .map(|(f, _)| {
1390            let variant = format_ident!("{}", to_pascal_case(&f.name));
1391            quote! { #variant }
1392        })
1393        .collect();
1394
1395    // Generate db_name match arms
1396    let db_name_arms: Vec<TokenStream> = agg_fields
1397        .iter()
1398        .map(|(f, _)| {
1399            let variant = format_ident!("{}", to_pascal_case(&f.name));
1400            let db_name = &f.db_name;
1401            quote! { Self::#variant => #db_name }
1402        })
1403        .collect();
1404
1405    // Generate AggregateResult fields
1406    let mut result_fields = Vec::new();
1407    for (f, kind) in &agg_fields {
1408        let snake = to_snake_case(&f.name);
1409        let orig_ty = rust_type_tokens(
1410            &Field {
1411                is_optional: false,
1412                ..(*f).clone()
1413            },
1414            ModuleDepth::TopLevel,
1415        );
1416
1417        match kind {
1418            AggregateKind::Numeric => {
1419                let avg_name = format_ident!("avg_{}", snake);
1420                let sum_name = format_ident!("sum_{}", snake);
1421                let min_name = format_ident!("min_{}", snake);
1422                let max_name = format_ident!("max_{}", snake);
1423                result_fields.push(quote! { #[sqlx(default)] pub #avg_name: Option<f64> });
1424                result_fields.push(quote! { #[sqlx(default)] pub #sum_name: Option<f64> });
1425                result_fields.push(quote! { #[sqlx(default)] pub #min_name: Option<#orig_ty> });
1426                result_fields.push(quote! { #[sqlx(default)] pub #max_name: Option<#orig_ty> });
1427            }
1428            AggregateKind::DateTime => {
1429                let min_name = format_ident!("min_{}", snake);
1430                let max_name = format_ident!("max_{}", snake);
1431                result_fields.push(quote! { #[sqlx(default)] pub #min_name: Option<#orig_ty> });
1432                result_fields.push(quote! { #[sqlx(default)] pub #max_name: Option<#orig_ty> });
1433            }
1434        }
1435    }
1436
1437    // Generate the is_numeric check for avg/sum validation
1438    let numeric_arms: Vec<TokenStream> = agg_fields
1439        .iter()
1440        .filter(|(_, kind)| matches!(kind, AggregateKind::Numeric))
1441        .map(|(f, _)| {
1442            let variant = format_ident!("{}", to_pascal_case(&f.name));
1443            quote! { Self::#variant => true }
1444        })
1445        .collect();
1446
1447    let has_numeric = !numeric_arms.is_empty();
1448    let is_numeric_method = if has_numeric {
1449        quote! {
1450            fn is_numeric(&self) -> bool {
1451                match self {
1452                    #(#numeric_arms,)*
1453                    #[allow(unreachable_patterns)]
1454                    _ => false,
1455                }
1456            }
1457        }
1458    } else {
1459        quote! {
1460            fn is_numeric(&self) -> bool { false }
1461        }
1462    };
1463
1464    // Generate alias match arms for each (prefix, field) combination
1465    let mut alias_arms = Vec::new();
1466    for (f, kind) in &agg_fields {
1467        let variant = format_ident!("{}", to_pascal_case(&f.name));
1468        let snake = to_snake_case(&f.name);
1469        let prefixes = match kind {
1470            AggregateKind::Numeric => vec!["avg", "sum", "min", "max"],
1471            AggregateKind::DateTime => vec!["min", "max"],
1472        };
1473        for prefix in prefixes {
1474            let alias_str = format!("{prefix}_{snake}");
1475            alias_arms.push(quote! { (#prefix, Self::#variant) => #alias_str });
1476        }
1477    }
1478
1479    let agg_select_base = format!(r#"SELECT {{}} FROM "{table_name}" WHERE 1=1"#);
1480
1481    quote! {
1482        #[derive(Debug, Clone, Copy)]
1483        pub enum #aggregate_field_name {
1484            #(#enum_variants),*
1485        }
1486
1487        impl #aggregate_field_name {
1488            pub fn db_name(&self) -> &'static str {
1489                match self {
1490                    #(#db_name_arms,)*
1491                }
1492            }
1493
1494            fn alias(&self, prefix: &'static str) -> &'static str {
1495                match (prefix, self) {
1496                    #(#alias_arms,)*
1497                    _ => unreachable!(),
1498                }
1499            }
1500
1501            #is_numeric_method
1502        }
1503
1504        #[derive(Debug, Clone, sqlx::FromRow, Serialize, Deserialize)]
1505        pub struct #aggregate_result_name {
1506            #(#result_fields,)*
1507        }
1508
1509        pub struct AggregateQuery<'a> {
1510            client: &'a DatabaseClient,
1511            r#where: filter::#_where_input,
1512            ops: Vec<(&'static str, &'static str, &'static str)>,
1513        }
1514
1515        impl<'a> AggregateQuery<'a> {
1516            pub fn avg(mut self, field: #aggregate_field_name) -> Self {
1517                assert!(field.is_numeric(), "avg() is only supported on numeric fields");
1518                let db_name = field.db_name();
1519                let alias = field.alias("avg");
1520                self.ops.push(("AVG", db_name, alias));
1521                self
1522            }
1523
1524            pub fn sum(mut self, field: #aggregate_field_name) -> Self {
1525                assert!(field.is_numeric(), "sum() is only supported on numeric fields");
1526                let db_name = field.db_name();
1527                let alias = field.alias("sum");
1528                self.ops.push(("SUM", db_name, alias));
1529                self
1530            }
1531
1532            pub fn min(mut self, field: #aggregate_field_name) -> Self {
1533                let db_name = field.db_name();
1534                let alias = field.alias("min");
1535                self.ops.push(("MIN", db_name, alias));
1536                self
1537            }
1538
1539            pub fn max(mut self, field: #aggregate_field_name) -> Self {
1540                let db_name = field.db_name();
1541                let alias = field.alias("max");
1542                self.ops.push(("MAX", db_name, alias));
1543                self
1544            }
1545
1546            pub async fn exec(self) -> Result<#aggregate_result_name, FerriormError> {
1547                if self.ops.is_empty() {
1548                    return Err(FerriormError::Query("No aggregate operations specified".into()));
1549                }
1550
1551                let selections: Vec<String> = self.ops.iter()
1552                    .map(|(func, col, alias)| format!(r#"{}("{}") as "{}""#, func, col, alias))
1553                    .collect();
1554                let select_clause = selections.join(", ");
1555                let base_sql = format!(#agg_select_base, select_clause);
1556
1557                match self.client {
1558                    DatabaseClient::Postgres(_) => {
1559                        let mut qb = sqlx::QueryBuilder::<sqlx::Postgres>::new(&base_sql);
1560                        self.r#where.build_where(&mut qb);
1561                        self.client.fetch_one_pg(qb).await
1562                    }
1563                    DatabaseClient::Sqlite(_) => {
1564                        let mut qb = sqlx::QueryBuilder::<sqlx::Sqlite>::new(&base_sql);
1565                        self.r#where.build_where(&mut qb);
1566                        self.client.fetch_one_sqlite(qb).await
1567                    }
1568                }
1569            }
1570        }
1571    }
1572}
1573
1574// ─── Select Types ─────────────────────────────────────────────
1575
1576#[allow(clippy::too_many_lines)]
1577fn gen_select_types(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
1578    let select_name = format_ident!("{}Select", model.name);
1579    let partial_name = format_ident!("{}Partial", model.name);
1580    let _where_input = format_ident!("{}WhereInput", model.name);
1581    let _where_unique = format_ident!("{}WhereUniqueInput", model.name);
1582    let order_by_name = format_ident!("{}OrderByInput", model.name);
1583    let table_name = &model.db_name;
1584
1585    // Select struct fields: all bool, default false
1586    let select_fields: Vec<TokenStream> = scalar_fields
1587        .iter()
1588        .map(|f| {
1589            let name = format_ident!("{}", to_snake_case(&f.name));
1590            quote! { pub #name: bool }
1591        })
1592        .collect();
1593
1594    // Partial struct fields: all Option<T> with #[sqlx(default)]
1595    // For already-optional fields, don't double-wrap in Option
1596    let partial_fields: Vec<TokenStream> = scalar_fields
1597        .iter()
1598        .map(|f| {
1599            let name = format_ident!("{}", to_snake_case(&f.name));
1600            let db_name = &f.db_name;
1601            // Get the base type (non-optional version)
1602            let base_ty = rust_type_tokens(
1603                &Field {
1604                    is_optional: false,
1605                    ..(*f).clone()
1606                },
1607                ModuleDepth::TopLevel,
1608            );
1609            let rename = if db_name == &to_snake_case(&f.name) {
1610                quote! {}
1611            } else {
1612                quote! { #[sqlx(rename = #db_name)] }
1613            };
1614            // Always wrap in Option<T>, regardless of whether field was originally optional
1615            quote! { #[sqlx(default)] #rename pub #name: Option<#base_ty> }
1616        })
1617        .collect();
1618
1619    // build_select_columns: maps Select bools to column names
1620    let select_col_arms: Vec<TokenStream> = scalar_fields
1621        .iter()
1622        .map(|f| {
1623            let name = format_ident!("{}", to_snake_case(&f.name));
1624            let db_name = &f.db_name;
1625            let col_expr = format!(r#""{db_name}""#);
1626            quote! {
1627                if select.#name { cols.push(#col_expr); }
1628            }
1629        })
1630        .collect();
1631
1632    let select_sql_prefix = format!(r#"SELECT {{}} FROM "{table_name}" WHERE 1=1"#);
1633
1634    quote! {
1635        #[derive(Debug, Clone, Default)]
1636        pub struct #select_name {
1637            #(#select_fields,)*
1638        }
1639
1640        #[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
1641        #[sqlx(rename_all = "snake_case")]
1642        pub struct #partial_name {
1643            #(#partial_fields,)*
1644        }
1645
1646        fn build_select_columns(select: &#select_name) -> String {
1647            let mut cols = Vec::new();
1648            #(#select_col_arms)*
1649            if cols.is_empty() {
1650                "*".to_string()
1651            } else {
1652                cols.join(", ")
1653            }
1654        }
1655
1656        // ── FindManySelectQuery ──────────────────────────────────
1657
1658        pub struct FindManySelectQuery<'a> {
1659            client: &'a DatabaseClient,
1660            r#where: filter::#_where_input,
1661            order_by: Vec<order::#order_by_name>,
1662            skip: Option<i64>,
1663            take: Option<i64>,
1664            select: #select_name,
1665        }
1666
1667        impl<'a> FindManySelectQuery<'a> {
1668            pub fn order_by(mut self, order: order::#order_by_name) -> Self {
1669                self.order_by.push(order);
1670                self
1671            }
1672
1673            pub fn skip(mut self, n: i64) -> Self {
1674                self.skip = Some(n);
1675                self
1676            }
1677
1678            pub fn take(mut self, n: i64) -> Self {
1679                self.take = Some(n);
1680                self
1681            }
1682
1683            pub async fn exec(self) -> Result<Vec<#partial_name>, FerriormError> {
1684                let cols = build_select_columns(&self.select);
1685                let base_sql = format!(#select_sql_prefix, cols);
1686
1687                match self.client {
1688                    DatabaseClient::Postgres(_) => {
1689                        let qb = build_select_query::<sqlx::Postgres>(
1690                            &base_sql, &self.r#where, &self.order_by, self.take, self.skip,
1691                        );
1692                        self.client.fetch_all_pg(qb).await
1693                    }
1694                    DatabaseClient::Sqlite(_) => {
1695                        let qb = build_select_query::<sqlx::Sqlite>(
1696                            &base_sql, &self.r#where, &self.order_by, self.take, self.skip,
1697                        );
1698                        self.client.fetch_all_sqlite(qb).await
1699                    }
1700                }
1701            }
1702        }
1703
1704        // ── FindUniqueSelectQuery ────────────────────────────────
1705
1706        pub struct FindUniqueSelectQuery<'a> {
1707            client: &'a DatabaseClient,
1708            r#where: filter::#_where_unique,
1709            select: #select_name,
1710        }
1711
1712        impl<'a> FindUniqueSelectQuery<'a> {
1713            pub async fn exec(self) -> Result<Option<#partial_name>, FerriormError> {
1714                let cols = build_select_columns(&self.select);
1715                let base_sql = format!(#select_sql_prefix, cols);
1716
1717                match self.client {
1718                    DatabaseClient::Postgres(_) => {
1719                        let qb = build_unique_select_query::<sqlx::Postgres>(
1720                            &base_sql, &self.r#where,
1721                        );
1722                        self.client.fetch_optional_pg(qb).await
1723                    }
1724                    DatabaseClient::Sqlite(_) => {
1725                        let qb = build_unique_select_query::<sqlx::Sqlite>(
1726                            &base_sql, &self.r#where,
1727                        );
1728                        self.client.fetch_optional_sqlite(qb).await
1729                    }
1730                }
1731            }
1732        }
1733
1734        // ── FindFirstSelectQuery ─────────────────────────────────
1735
1736        pub struct FindFirstSelectQuery<'a> {
1737            client: &'a DatabaseClient,
1738            r#where: filter::#_where_input,
1739            order_by: Vec<order::#order_by_name>,
1740            select: #select_name,
1741        }
1742
1743        impl<'a> FindFirstSelectQuery<'a> {
1744            pub fn order_by(mut self, order: order::#order_by_name) -> Self {
1745                self.order_by.push(order);
1746                self
1747            }
1748
1749            pub async fn exec(self) -> Result<Option<#partial_name>, FerriormError> {
1750                let cols = build_select_columns(&self.select);
1751                let base_sql = format!(#select_sql_prefix, cols);
1752
1753                match self.client {
1754                    DatabaseClient::Postgres(_) => {
1755                        let qb = build_select_query::<sqlx::Postgres>(
1756                            &base_sql, &self.r#where, &self.order_by, Some(1), None,
1757                        );
1758                        self.client.fetch_optional_pg(qb).await
1759                    }
1760                    DatabaseClient::Sqlite(_) => {
1761                        let qb = build_select_query::<sqlx::Sqlite>(
1762                            &base_sql, &self.r#where, &self.order_by, Some(1), None,
1763                        );
1764                        self.client.fetch_optional_sqlite(qb).await
1765                    }
1766                }
1767            }
1768        }
1769    }
1770}