Skip to main content

ferriorm_codegen/
model.rs

1//! Generates per-model Rust modules (struct, filters, data inputs, ordering, CRUD).
2//!
3//! For each model in the schema, this module produces:
4//!
5//! - A **data struct** (e.g., `User`) with `sqlx::FromRow` and serde derives.
6//! - A **filter submodule** with `WhereInput` and `WhereUniqueInput` types.
7//! - A **data submodule** with `CreateInput` and `UpdateInput` types.
8//! - An **order submodule** with `OrderByInput`.
9//! - An **`Actions` struct** exposing `create`, `find_unique`, `find_many`,
10//!   `update`, `delete`, `upsert`, and batch operations.
11//! - **Query builder structs** that chain filters, ordering, pagination, and
12//!   include clauses before calling `.exec()`.
13
14use ferriorm_core::schema::{Field, FieldKind, Model};
15use ferriorm_core::types::ScalarType;
16use ferriorm_core::utils::{to_pascal_case, to_snake_case};
17use proc_macro2::TokenStream;
18use quote::{format_ident, quote};
19
20use crate::rust_type::{ModuleDepth, filter_type_tokens, rust_type_tokens};
21
22/// Generate the complete module for a single model.
23pub fn generate_model_module(model: &Model) -> TokenStream {
24    let scalar_fields: Vec<&Field> = model.fields.iter().filter(|f| f.is_scalar()).collect();
25
26    let data_struct = gen_data_struct(model, &scalar_fields);
27    let filter_module = gen_filter_module(model, &scalar_fields);
28    let data_module = gen_data_module(model, &scalar_fields);
29    let order_module = gen_order_module(model, &scalar_fields);
30    let actions_struct = gen_actions(model, &scalar_fields);
31    let query_builders = gen_query_builders(model, &scalar_fields);
32    let aggregate_types = gen_aggregate_types(model, &scalar_fields);
33    let select_types = gen_select_types(model, &scalar_fields);
34
35    quote! {
36        #![allow(unused_imports, dead_code, clippy::all, unused_variables)]
37
38        use serde::{Deserialize, Serialize};
39        use ferriorm_runtime::prelude::*;
40        use ferriorm_runtime::prelude::sqlx;
41        use ferriorm_runtime::prelude::chrono;
42        use ferriorm_runtime::prelude::uuid;
43
44        #data_struct
45        #filter_module
46        #data_module
47        #order_module
48        #actions_struct
49        #query_builders
50        #aggregate_types
51        #select_types
52    }
53}
54
55// ─── Data Struct ──────────────────────────────────────────────
56
57fn gen_data_struct(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
58    let struct_name = format_ident!("{}", model.name);
59    let table_name = &model.db_name;
60
61    let fields: Vec<TokenStream> = scalar_fields
62        .iter()
63        .map(|f| {
64            let name = format_ident!("{}", to_snake_case(&f.name));
65            let ty = rust_type_tokens(f, ModuleDepth::TopLevel);
66            let db_name = &f.db_name;
67            if db_name != &to_snake_case(&f.name) {
68                quote! { #[sqlx(rename = #db_name)] pub #name: #ty }
69            } else {
70                quote! { pub #name: #ty }
71            }
72        })
73        .collect();
74
75    quote! {
76        #[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
77        #[sqlx(rename_all = "snake_case")]
78        pub struct #struct_name {
79            #(#fields),*
80        }
81
82        impl #struct_name {
83            pub const TABLE_NAME: &'static str = #table_name;
84        }
85    }
86}
87
88// ─── Filter Module ────────────────────────────────────────────
89
90fn gen_filter_module(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
91    let where_input = format_ident!("{}WhereInput", model.name);
92    let where_unique = format_ident!("{}WhereUniqueInput", model.name);
93
94    let where_fields: Vec<TokenStream> = scalar_fields
95        .iter()
96        .filter_map(|f| {
97            let filter_ty = filter_type_tokens(f, ModuleDepth::Nested)?;
98            let name = format_ident!("{}", to_snake_case(&f.name));
99            Some(quote! { pub #name: Option<#filter_ty> })
100        })
101        .collect();
102
103    let unique_variants: Vec<TokenStream> = scalar_fields
104        .iter()
105        .filter(|f| f.is_id || f.is_unique)
106        .map(|f| {
107            let variant = format_ident!("{}", to_pascal_case(&f.name));
108            let ty = rust_type_tokens(f, ModuleDepth::Nested);
109            quote! { #variant(#ty) }
110        })
111        .collect();
112
113    // Generate build_where for WhereInput
114    let db_bounds = collect_db_bounds(scalar_fields);
115    let where_arms = gen_where_arms(scalar_fields);
116    let unique_arms = gen_unique_where_arms(scalar_fields);
117
118    quote! {
119        pub mod filter {
120            use ferriorm_runtime::prelude::*;
121
122            #[derive(Debug, Clone, Default)]
123            pub struct #where_input {
124                #(#where_fields,)*
125                pub and: Option<Vec<#where_input>>,
126                pub or: Option<Vec<#where_input>>,
127                pub not: Option<Box<#where_input>>,
128            }
129
130            #[derive(Debug, Clone)]
131            pub enum #where_unique {
132                #(#unique_variants),*
133            }
134
135            impl #where_input {
136                pub(crate) fn build_where<'args, DB: sqlx::Database>(
137                    &self,
138                    qb: &mut sqlx::QueryBuilder<'args, DB>,
139                )
140                where
141                    #(#db_bounds,)*
142                {
143                    #(#where_arms)*
144
145                    if let Some(conditions) = &self.and {
146                        for c in conditions {
147                            c.build_where(qb);
148                        }
149                    }
150                    if let Some(conditions) = &self.or {
151                        if !conditions.is_empty() {
152                            qb.push(" AND (");
153                            for (i, c) in conditions.iter().enumerate() {
154                                if i > 0 { qb.push(" OR "); }
155                                qb.push("(1=1");
156                                c.build_where(qb);
157                                qb.push(")");
158                            }
159                            qb.push(")");
160                        }
161                    }
162                    if let Some(c) = &self.not {
163                        qb.push(" AND NOT (1=1");
164                        c.build_where(qb);
165                        qb.push(")");
166                    }
167                }
168            }
169
170            impl #where_unique {
171                pub(crate) fn build_where<'args, DB: sqlx::Database>(
172                    &self,
173                    qb: &mut sqlx::QueryBuilder<'args, DB>,
174                )
175                where
176                    #(#db_bounds,)*
177                {
178                    match self {
179                        #(#unique_arms)*
180                    }
181                }
182            }
183        }
184    }
185}
186
187/// Collect the sqlx type bounds needed for all scalar types used by the model.
188fn collect_db_bounds(scalar_fields: &[&Field]) -> Vec<TokenStream> {
189    let mut seen = std::collections::HashSet::new();
190    let mut bounds = Vec::new();
191
192    // Always need i64 for LIMIT/OFFSET
193    seen.insert("i64");
194    bounds.push(quote! { i64: sqlx::Type<DB> + for<'e> sqlx::Encode<'e, DB> });
195
196    for f in scalar_fields {
197        match &f.field_type {
198            FieldKind::Scalar(scalar) => {
199                let key = scalar.rust_type();
200                if seen.insert(key)
201                    && let Some(ty) = scalar_bound_tokens(scalar)
202                {
203                    bounds.push(quote! { #ty: sqlx::Type<DB> + for<'e> sqlx::Encode<'e, DB> });
204                    // Also add Option<T> bound for nullable field support
205                    bounds.push(
206                        quote! { Option<#ty>: sqlx::Type<DB> + for<'e> sqlx::Encode<'e, DB> },
207                    );
208                }
209            }
210            FieldKind::Enum(_) => {}
211            _ => {}
212        }
213    }
214
215    bounds
216}
217
218fn scalar_bound_tokens(scalar: &ScalarType) -> Option<TokenStream> {
219    match scalar {
220        ScalarType::String => Some(quote! { String }),
221        ScalarType::Int => Some(quote! { i32 }),
222        ScalarType::BigInt => Some(quote! { i64 }),
223        ScalarType::Float => Some(quote! { f64 }),
224        ScalarType::Boolean => Some(quote! { bool }),
225        ScalarType::DateTime => Some(quote! { chrono::DateTime<chrono::Utc> }),
226        ScalarType::Bytes => Some(quote! { Vec<u8> }),
227        ScalarType::Json | ScalarType::Decimal => None,
228    }
229}
230
231/// Generate where-clause arms for each filterable scalar field.
232fn gen_where_arms(scalar_fields: &[&Field]) -> Vec<TokenStream> {
233    scalar_fields
234        .iter()
235        .filter_map(|f| {
236            // Only generate filter arms for scalar types (skip enums for now)
237            if !matches!(&f.field_type, FieldKind::Scalar(_)) {
238                return None;
239            }
240            let field_ident = format_ident!("{}", to_snake_case(&f.name));
241            let db_name = &f.db_name;
242            let is_string = matches!(&f.field_type, FieldKind::Scalar(ScalarType::String));
243            let is_comparable = matches!(
244                &f.field_type,
245                FieldKind::Scalar(
246                    ScalarType::Int | ScalarType::BigInt | ScalarType::Float | ScalarType::DateTime
247                )
248            );
249
250            let mut arms = vec![];
251
252            arms.push(quote! {
253                if let Some(v) = &filter.equals {
254                    qb.push(concat!(" AND \"", #db_name, "\" = "));
255                    qb.push_bind(v.clone());
256                }
257                if let Some(v) = &filter.not {
258                    qb.push(concat!(" AND \"", #db_name, "\" != "));
259                    qb.push_bind(v.clone());
260                }
261            });
262
263            if is_string {
264                arms.push(quote! {
265                    if let Some(v) = &filter.contains {
266                        qb.push(concat!(" AND \"", #db_name, "\" LIKE "));
267                        qb.push_bind(format!("%{}%", v));
268                    }
269                    if let Some(v) = &filter.starts_with {
270                        qb.push(concat!(" AND \"", #db_name, "\" LIKE "));
271                        qb.push_bind(format!("{}%", v));
272                    }
273                    if let Some(v) = &filter.ends_with {
274                        qb.push(concat!(" AND \"", #db_name, "\" LIKE "));
275                        qb.push_bind(format!("%{}", v));
276                    }
277                });
278            }
279
280            if is_comparable {
281                arms.push(quote! {
282                    if let Some(v) = &filter.gt {
283                        qb.push(concat!(" AND \"", #db_name, "\" > "));
284                        qb.push_bind(v.clone());
285                    }
286                    if let Some(v) = &filter.gte {
287                        qb.push(concat!(" AND \"", #db_name, "\" >= "));
288                        qb.push_bind(v.clone());
289                    }
290                    if let Some(v) = &filter.lt {
291                        qb.push(concat!(" AND \"", #db_name, "\" < "));
292                        qb.push_bind(v.clone());
293                    }
294                    if let Some(v) = &filter.lte {
295                        qb.push(concat!(" AND \"", #db_name, "\" <= "));
296                        qb.push_bind(v.clone());
297                    }
298                });
299            }
300
301            Some(quote! {
302                if let Some(filter) = &self.#field_ident {
303                    #(#arms)*
304                }
305            })
306        })
307        .collect()
308}
309
310fn gen_unique_where_arms(scalar_fields: &[&Field]) -> Vec<TokenStream> {
311    let _where_unique = format_ident!(
312        "{}WhereUniqueInput",
313        "" // placeholder, we use Self:: instead
314    );
315    scalar_fields
316        .iter()
317        .filter(|f| f.is_id || f.is_unique)
318        .map(|f| {
319            let variant = format_ident!("{}", to_pascal_case(&f.name));
320            let db_name = &f.db_name;
321            quote! {
322                Self::#variant(v) => {
323                    qb.push(concat!(" AND \"", #db_name, "\" = "));
324                    qb.push_bind(v.clone());
325                }
326            }
327        })
328        .collect()
329}
330
331// ─── Data Module ──────────────────────────────────────────────
332
333fn gen_data_module(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
334    let create_name = format_ident!("{}CreateInput", model.name);
335    let update_name = format_ident!("{}UpdateInput", model.name);
336
337    let required_fields: Vec<TokenStream> = scalar_fields
338        .iter()
339        .filter(|f| !f.has_default() && !f.is_updated_at)
340        .map(|f| {
341            let name = format_ident!("{}", to_snake_case(&f.name));
342            let ty = rust_type_tokens(f, ModuleDepth::Nested);
343            quote! { pub #name: #ty }
344        })
345        .collect();
346
347    let optional_fields: Vec<TokenStream> = scalar_fields
348        .iter()
349        .filter(|f| f.has_default() && !f.is_updated_at)
350        .map(|f| {
351            let name = format_ident!("{}", to_snake_case(&f.name));
352            let base_ty = rust_type_tokens(f, ModuleDepth::Nested);
353            quote! { pub #name: Option<#base_ty> }
354        })
355        .collect();
356
357    let update_fields: Vec<TokenStream> = scalar_fields
358        .iter()
359        .filter(|f| !f.is_id && !f.is_updated_at)
360        .map(|f| {
361            let name = format_ident!("{}", to_snake_case(&f.name));
362            let ty = rust_type_tokens(f, ModuleDepth::Nested);
363            quote! { pub #name: Option<SetValue<#ty>> }
364        })
365        .collect();
366
367    quote! {
368        pub mod data {
369            use ferriorm_runtime::prelude::*;
370
371            #[derive(Debug, Clone)]
372            pub struct #create_name {
373                #(#required_fields,)*
374                #(#optional_fields,)*
375            }
376
377            #[derive(Debug, Clone, Default)]
378            pub struct #update_name {
379                #(#update_fields,)*
380            }
381        }
382    }
383}
384
385// ─── Order Module ─────────────────────────────────────────────
386
387fn gen_order_module(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
388    let order_name = format_ident!("{}OrderByInput", model.name);
389
390    let variants: Vec<TokenStream> = scalar_fields
391        .iter()
392        .map(|f| {
393            let variant = format_ident!("{}", to_pascal_case(&f.name));
394            quote! { #variant(SortOrder) }
395        })
396        .collect();
397
398    let order_arms: Vec<TokenStream> = scalar_fields
399        .iter()
400        .map(|f| {
401            let variant = format_ident!("{}", to_pascal_case(&f.name));
402            let db_name = &f.db_name;
403            quote! {
404                Self::#variant(order) => {
405                    qb.push(concat!("\"", #db_name, "\" "));
406                    qb.push(order.as_sql());
407                }
408            }
409        })
410        .collect();
411
412    quote! {
413        pub mod order {
414            use ferriorm_runtime::prelude::*;
415
416            #[derive(Debug, Clone)]
417            pub enum #order_name {
418                #(#variants),*
419            }
420
421            impl #order_name {
422                pub(crate) fn build_order_by<'args, DB: sqlx::Database>(
423                    &self,
424                    qb: &mut sqlx::QueryBuilder<'args, DB>,
425                ) {
426                    match self {
427                        #(#order_arms)*
428                    }
429                }
430            }
431        }
432    }
433}
434
435// ─── Actions ──────────────────────────────────────────────────
436
437fn gen_actions(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
438    let _model_ident = format_ident!("{}", model.name);
439    let actions_name = format_ident!("{}Actions", model.name);
440    let where_input = format_ident!("{}WhereInput", model.name);
441    let where_unique = format_ident!("{}WhereUniqueInput", model.name);
442    let create_input = format_ident!("{}CreateInput", model.name);
443    let update_input = format_ident!("{}UpdateInput", model.name);
444    let _order_by = format_ident!("{}OrderByInput", model.name);
445
446    // Only generate aggregate() if there are aggregatable fields
447    let has_agg_fields = scalar_fields.iter().any(|f| {
448        matches!(
449            &f.field_type,
450            FieldKind::Scalar(
451                ScalarType::Int | ScalarType::BigInt | ScalarType::Float | ScalarType::DateTime
452            )
453        )
454    });
455    let aggregate_method = if has_agg_fields {
456        quote! {
457            pub fn aggregate(&self, r#where: filter::#where_input) -> AggregateQuery<'a> {
458                AggregateQuery { client: self.client, r#where, ops: vec![] }
459            }
460        }
461    } else {
462        quote! {}
463    };
464
465    quote! {
466        pub struct #actions_name<'a> {
467            client: &'a DatabaseClient,
468        }
469
470        impl<'a> #actions_name<'a> {
471            pub fn new(client: &'a DatabaseClient) -> Self { Self { client } }
472
473            pub fn find_unique(&self, r#where: filter::#where_unique) -> FindUniqueQuery<'a> {
474                FindUniqueQuery { client: self.client, r#where }
475            }
476
477            pub fn find_first(&self, r#where: filter::#where_input) -> FindFirstQuery<'a> {
478                FindFirstQuery { client: self.client, r#where, order_by: vec![] }
479            }
480
481            pub fn find_many(&self, r#where: filter::#where_input) -> FindManyQuery<'a> {
482                FindManyQuery { client: self.client, r#where, order_by: vec![], skip: None, take: None }
483            }
484
485            pub fn create(&self, data: data::#create_input) -> CreateQuery<'a> {
486                CreateQuery { client: self.client, data }
487            }
488
489            pub fn update(&self, r#where: filter::#where_unique, data: data::#update_input) -> UpdateQuery<'a> {
490                UpdateQuery { client: self.client, r#where, data }
491            }
492
493            pub fn delete(&self, r#where: filter::#where_unique) -> DeleteQuery<'a> {
494                DeleteQuery { client: self.client, r#where }
495            }
496
497            pub fn count(&self, r#where: filter::#where_input) -> CountQuery<'a> {
498                CountQuery { client: self.client, r#where }
499            }
500
501            pub fn create_many(&self, data: Vec<data::#create_input>) -> CreateManyQuery<'a> {
502                CreateManyQuery { client: self.client, data }
503            }
504
505            pub fn update_many(&self, r#where: filter::#where_input, data: data::#update_input) -> UpdateManyQuery<'a> {
506                UpdateManyQuery { client: self.client, r#where, data }
507            }
508
509            pub fn delete_many(&self, r#where: filter::#where_input) -> DeleteManyQuery<'a> {
510                DeleteManyQuery { client: self.client, r#where }
511            }
512
513            #aggregate_method
514        }
515    }
516}
517
518// ─── Query Builders with exec() ──────────────────────────────
519
520fn gen_query_builders(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
521    let model_ident = format_ident!("{}", model.name);
522    let table_name = &model.db_name;
523    let _where_input = format_ident!("{}WhereInput", model.name);
524    let _where_unique = format_ident!("{}WhereUniqueInput", model.name);
525    let _create_input = format_ident!("{}CreateInput", model.name);
526    let _update_input = format_ident!("{}UpdateInput", model.name);
527    let order_by = format_ident!("{}OrderByInput", model.name);
528    let _select_struct = format_ident!("{}Select", model.name);
529    let _partial_struct = format_ident!("{}Partial", model.name);
530    let _aggregate_result = format_ident!("{}AggregateResult", model.name);
531    let _aggregate_field = format_ident!("{}AggregateField", model.name);
532    let db_bounds = collect_db_bounds(scalar_fields);
533
534    let select_sql = format!(r#"SELECT * FROM "{}" WHERE 1=1"#, table_name);
535    let count_sql = format!(
536        r#"SELECT COUNT(*) as "count" FROM "{}" WHERE 1=1"#,
537        table_name
538    );
539    let delete_sql = format!(r#"DELETE FROM "{}" WHERE 1=1"#, table_name);
540
541    let insert_code = gen_insert_code(model, scalar_fields, table_name);
542    let update_code = gen_update_code(model, scalar_fields, table_name);
543    let update_many_code = gen_update_many_code(model, scalar_fields, table_name);
544
545    quote! {
546        // ── Generic helper: build ORDER BY clause ──────────────
547        fn build_order_by<'args, DB: sqlx::Database>(
548            orders: &[order::#order_by],
549            qb: &mut sqlx::QueryBuilder<'args, DB>,
550        ) {
551            if !orders.is_empty() {
552                qb.push(" ORDER BY ");
553                for (i, ob) in orders.iter().enumerate() {
554                    if i > 0 { qb.push(", "); }
555                    ob.build_order_by(qb);
556                }
557            }
558        }
559
560        // ── Generic helper: build a SELECT query ───────────────
561        fn build_select_query<'args, DB: sqlx::Database>(
562            base_sql: &str,
563            where_input: &filter::#_where_input,
564            orders: &[order::#order_by],
565            take: Option<i64>,
566            skip: Option<i64>,
567        ) -> sqlx::QueryBuilder<'args, DB>
568        where
569            #(#db_bounds,)*
570        {
571            let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
572            where_input.build_where(&mut qb);
573            build_order_by(orders, &mut qb);
574            if let Some(take) = take {
575                qb.push(" LIMIT ");
576                qb.push_bind(take);
577            }
578            if let Some(skip) = skip {
579                qb.push(" OFFSET ");
580                qb.push_bind(skip);
581            }
582            qb
583        }
584
585        // ── Generic helper: build a SELECT query for unique lookup ──
586        fn build_unique_select_query<'args, DB: sqlx::Database>(
587            base_sql: &str,
588            where_unique: &filter::#_where_unique,
589        ) -> sqlx::QueryBuilder<'args, DB>
590        where
591            #(#db_bounds,)*
592        {
593            let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
594            where_unique.build_where(&mut qb);
595            qb.push(" LIMIT 1");
596            qb
597        }
598
599        // ── Generic helper: build a DELETE-returning query ─────
600        fn build_delete_query<'args, DB: sqlx::Database>(
601            base_sql: &str,
602            where_unique: &filter::#_where_unique,
603        ) -> sqlx::QueryBuilder<'args, DB>
604        where
605            #(#db_bounds,)*
606        {
607            let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
608            where_unique.build_where(&mut qb);
609            qb.push(" RETURNING *");
610            qb
611        }
612
613        // ── Generic helper: build a COUNT query ────────────────
614        fn build_count_query<'args, DB: sqlx::Database>(
615            base_sql: &str,
616            where_input: &filter::#_where_input,
617        ) -> sqlx::QueryBuilder<'args, DB>
618        where
619            #(#db_bounds,)*
620        {
621            let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
622            where_input.build_where(&mut qb);
623            qb
624        }
625
626        // ── Generic helper: build a DELETE-many query ──────────
627        fn build_delete_many_query<'args, DB: sqlx::Database>(
628            base_sql: &str,
629            where_input: &filter::#_where_input,
630        ) -> sqlx::QueryBuilder<'args, DB>
631        where
632            #(#db_bounds,)*
633        {
634            let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
635            where_input.build_where(&mut qb);
636            qb
637        }
638
639        pub struct FindUniqueQuery<'a> {
640            client: &'a DatabaseClient,
641            r#where: filter::#_where_unique,
642        }
643
644        impl<'a> FindUniqueQuery<'a> {
645            pub fn select(self, select: #_select_struct) -> FindUniqueSelectQuery<'a> {
646                FindUniqueSelectQuery { client: self.client, r#where: self.r#where, select }
647            }
648
649            pub async fn exec(self) -> Result<Option<#model_ident>, FerriormError> {
650                match self.client {
651                    DatabaseClient::Postgres(_) => {
652                        let qb = build_unique_select_query::<sqlx::Postgres>(#select_sql, &self.r#where);
653                        self.client.fetch_optional_pg(qb).await
654                    }
655                    DatabaseClient::Sqlite(_) => {
656                        let qb = build_unique_select_query::<sqlx::Sqlite>(#select_sql, &self.r#where);
657                        self.client.fetch_optional_sqlite(qb).await
658                    }
659                }
660            }
661        }
662
663        pub struct FindFirstQuery<'a> {
664            client: &'a DatabaseClient,
665            r#where: filter::#_where_input,
666            order_by: Vec<order::#order_by>,
667        }
668
669        impl<'a> FindFirstQuery<'a> {
670            pub fn order_by(mut self, order: order::#order_by) -> Self {
671                self.order_by.push(order);
672                self
673            }
674
675            pub fn select(self, select: #_select_struct) -> FindFirstSelectQuery<'a> {
676                FindFirstSelectQuery { client: self.client, r#where: self.r#where, order_by: self.order_by, select }
677            }
678
679            pub async fn exec(self) -> Result<Option<#model_ident>, FerriormError> {
680                match self.client {
681                    DatabaseClient::Postgres(_) => {
682                        let qb = build_select_query::<sqlx::Postgres>(#select_sql, &self.r#where, &self.order_by, Some(1), None);
683                        self.client.fetch_optional_pg(qb).await
684                    }
685                    DatabaseClient::Sqlite(_) => {
686                        let qb = build_select_query::<sqlx::Sqlite>(#select_sql, &self.r#where, &self.order_by, Some(1), None);
687                        self.client.fetch_optional_sqlite(qb).await
688                    }
689                }
690            }
691        }
692
693        pub struct FindManyQuery<'a> {
694            client: &'a DatabaseClient,
695            r#where: filter::#_where_input,
696            order_by: Vec<order::#order_by>,
697            skip: Option<i64>,
698            take: Option<i64>,
699        }
700
701        impl<'a> FindManyQuery<'a> {
702            pub fn order_by(mut self, order: order::#order_by) -> Self {
703                self.order_by.push(order);
704                self
705            }
706
707            pub fn skip(mut self, n: i64) -> Self {
708                self.skip = Some(n);
709                self
710            }
711
712            pub fn take(mut self, n: i64) -> Self {
713                self.take = Some(n);
714                self
715            }
716
717            pub fn select(self, select: #_select_struct) -> FindManySelectQuery<'a> {
718                FindManySelectQuery {
719                    client: self.client,
720                    r#where: self.r#where,
721                    order_by: self.order_by,
722                    skip: self.skip,
723                    take: self.take,
724                    select,
725                }
726            }
727
728            pub async fn exec(self) -> Result<Vec<#model_ident>, FerriormError> {
729                match self.client {
730                    DatabaseClient::Postgres(_) => {
731                        let qb = build_select_query::<sqlx::Postgres>(#select_sql, &self.r#where, &self.order_by, self.take, self.skip);
732                        self.client.fetch_all_pg(qb).await
733                    }
734                    DatabaseClient::Sqlite(_) => {
735                        let qb = build_select_query::<sqlx::Sqlite>(#select_sql, &self.r#where, &self.order_by, self.take, self.skip);
736                        self.client.fetch_all_sqlite(qb).await
737                    }
738                }
739            }
740        }
741
742        pub struct CreateQuery<'a> {
743            client: &'a DatabaseClient,
744            data: data::#_create_input,
745        }
746
747        impl<'a> CreateQuery<'a> {
748            pub async fn exec(self) -> Result<#model_ident, FerriormError> {
749                let client = self.client;
750                #insert_code
751            }
752        }
753
754        pub struct UpdateQuery<'a> {
755            client: &'a DatabaseClient,
756            r#where: filter::#_where_unique,
757            data: data::#_update_input,
758        }
759
760        impl<'a> UpdateQuery<'a> {
761            pub async fn exec(self) -> Result<#model_ident, FerriormError> {
762                let client = self.client;
763                #update_code
764            }
765        }
766
767        pub struct DeleteQuery<'a> {
768            client: &'a DatabaseClient,
769            r#where: filter::#_where_unique,
770        }
771
772        impl<'a> DeleteQuery<'a> {
773            pub async fn exec(self) -> Result<#model_ident, FerriormError> {
774                match self.client {
775                    DatabaseClient::Postgres(_) => {
776                        let qb = build_delete_query::<sqlx::Postgres>(#delete_sql, &self.r#where);
777                        self.client.fetch_one_pg(qb).await
778                    }
779                    DatabaseClient::Sqlite(_) => {
780                        let qb = build_delete_query::<sqlx::Sqlite>(#delete_sql, &self.r#where);
781                        self.client.fetch_one_sqlite(qb).await
782                    }
783                }
784            }
785        }
786
787        #[derive(sqlx::FromRow)]
788        struct CountResult { count: i64 }
789
790        pub struct CountQuery<'a> {
791            client: &'a DatabaseClient,
792            r#where: filter::#_where_input,
793        }
794
795        impl<'a> CountQuery<'a> {
796            pub async fn exec(self) -> Result<i64, FerriormError> {
797                let row: CountResult = match self.client {
798                    DatabaseClient::Postgres(_) => {
799                        let qb = build_count_query::<sqlx::Postgres>(#count_sql, &self.r#where);
800                        self.client.fetch_one_pg(qb).await?
801                    }
802                    DatabaseClient::Sqlite(_) => {
803                        let qb = build_count_query::<sqlx::Sqlite>(#count_sql, &self.r#where);
804                        self.client.fetch_one_sqlite(qb).await?
805                    }
806                };
807                Ok(row.count)
808            }
809        }
810
811        pub struct CreateManyQuery<'a> {
812            client: &'a DatabaseClient,
813            data: Vec<data::#_create_input>,
814        }
815
816        impl<'a> CreateManyQuery<'a> {
817            pub async fn exec(self) -> Result<u64, FerriormError> {
818                if self.data.is_empty() { return Ok(0); }
819                let count = self.data.len() as u64;
820                for item in self.data {
821                    CreateQuery { client: self.client, data: item }.exec().await?;
822                }
823                Ok(count)
824            }
825        }
826
827        pub struct UpdateManyQuery<'a> {
828            client: &'a DatabaseClient,
829            r#where: filter::#_where_input,
830            data: data::#_update_input,
831        }
832
833        impl<'a> UpdateManyQuery<'a> {
834            pub async fn exec(self) -> Result<u64, FerriormError> {
835                let client = self.client;
836                #update_many_code
837            }
838        }
839
840        pub struct DeleteManyQuery<'a> {
841            client: &'a DatabaseClient,
842            r#where: filter::#_where_input,
843        }
844
845        impl<'a> DeleteManyQuery<'a> {
846            pub async fn exec(self) -> Result<u64, FerriormError> {
847                match self.client {
848                    DatabaseClient::Postgres(_) => {
849                        let qb = build_delete_many_query::<sqlx::Postgres>(#delete_sql, &self.r#where);
850                        self.client.execute_pg(qb).await
851                    }
852                    DatabaseClient::Sqlite(_) => {
853                        let qb = build_delete_many_query::<sqlx::Sqlite>(#delete_sql, &self.r#where);
854                        self.client.execute_sqlite(qb).await
855                    }
856                }
857            }
858        }
859    }
860}
861
862// ─── INSERT code generation ───────────────────────────────────
863
864fn gen_insert_code(model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
865    let _model_ident = format_ident!("{}", model.name);
866
867    // Required columns: scalar, no default, not @updatedAt
868    let required: Vec<&Field> = scalar_fields
869        .iter()
870        .copied()
871        .filter(|f| !f.has_default() && !f.is_updated_at)
872        .collect();
873
874    // Optional columns: have default (can be overridden), not @updatedAt
875    let optional: Vec<&Field> = scalar_fields
876        .iter()
877        .copied()
878        .filter(|f| f.has_default() && !f.is_updated_at)
879        .collect();
880
881    // @updatedAt columns: always set to now()
882    let updated_at: Vec<&Field> = scalar_fields
883        .iter()
884        .copied()
885        .filter(|f| f.is_updated_at)
886        .collect();
887
888    // Build column names and bind values
889    let mut col_pushes = vec![];
890    let mut val_pushes = vec![];
891
892    // Required fields — always included
893    for f in &required {
894        let db_name = &f.db_name;
895        let field_ident = format_ident!("{}", to_snake_case(&f.name));
896        col_pushes.push(quote! { cols.push(#db_name); });
897        val_pushes.push(quote! { sep.push_bind(self.data.#field_ident); });
898    }
899
900    // Optional fields — resolve defaults in Rust
901    for f in &optional {
902        let db_name = &f.db_name;
903        let field_ident = format_ident!("{}", to_snake_case(&f.name));
904        let default_expr = gen_default_expr(f, &f.field_type);
905
906        col_pushes.push(quote! { cols.push(#db_name); });
907        val_pushes.push(quote! {
908            let val = self.data.#field_ident.unwrap_or_else(|| #default_expr);
909            sep.push_bind(val);
910        });
911    }
912
913    // @updatedAt fields
914    for f in &updated_at {
915        let db_name = &f.db_name;
916        col_pushes.push(quote! { cols.push(#db_name); });
917        val_pushes.push(quote! { sep.push_bind(chrono::Utc::now()); });
918    }
919
920    let insert_start = format!(r#"INSERT INTO "{}""#, table_name);
921
922    // The insert_body macro avoids duplicating the column/value building logic
923    // for each database backend. It captures `self` by reference.
924    quote! {
925        // Helper to build the INSERT query for any DB backend
926        macro_rules! build_insert {
927            ($qb_type:ty) => {{
928                let mut cols: Vec<&str> = Vec::new();
929                #(#col_pushes)*
930
931                let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#insert_start);
932                qb.push(" (");
933                for (i, col) in cols.iter().enumerate() {
934                    if i > 0 { qb.push(", "); }
935                    qb.push("\"");
936                    qb.push(*col);
937                    qb.push("\"");
938                }
939                qb.push(") VALUES (");
940                {
941                    let mut sep = qb.separated(", ");
942                    #(#val_pushes)*
943                }
944                qb.push(") RETURNING *");
945                qb
946            }};
947        }
948
949        match client {
950            DatabaseClient::Postgres(_) => {
951                let qb = build_insert!(sqlx::Postgres);
952                client.fetch_one_pg(qb).await
953            }
954            DatabaseClient::Sqlite(_) => {
955                let qb = build_insert!(sqlx::Sqlite);
956                client.fetch_one_sqlite(qb).await
957            }
958        }
959    }
960}
961
962/// Generate a Rust expression for a field's @default value.
963fn gen_default_expr(field: &Field, field_type: &FieldKind) -> TokenStream {
964    use ferriorm_core::ast::DefaultValue;
965
966    match &field.default {
967        Some(DefaultValue::Uuid) => quote! { uuid::Uuid::new_v4().to_string() },
968        Some(DefaultValue::Cuid) => quote! { uuid::Uuid::new_v4().to_string() }, // fallback
969        Some(DefaultValue::Now) => quote! { chrono::Utc::now() },
970        Some(DefaultValue::AutoIncrement) => quote! { 0i32 }, // DB handles this
971        Some(DefaultValue::Literal(lit)) => {
972            use ferriorm_core::ast::LiteralValue;
973            match lit {
974                LiteralValue::String(s) => quote! { #s.to_string() },
975                LiteralValue::Int(i) => {
976                    // Cast the integer literal to the correct Rust type based on the field's scalar type.
977                    match field_type {
978                        FieldKind::Scalar(ScalarType::Float) => {
979                            let val = *i as f64;
980                            quote! { #val }
981                        }
982                        FieldKind::Scalar(ScalarType::BigInt) => quote! { #i },
983                        _ => {
984                            // Default to i32 for Int and other types
985                            let val = *i as i32;
986                            quote! { #val }
987                        }
988                    }
989                }
990                LiteralValue::Float(f) => quote! { #f },
991                LiteralValue::Bool(b) => quote! { #b },
992            }
993        }
994        Some(DefaultValue::EnumVariant(v)) => {
995            // Reference the enum variant — insert code runs at model module level
996            let variant = format_ident!("{}", v);
997            if let FieldKind::Enum(enum_name) = &field.field_type {
998                let enum_ident = format_ident!("{}", enum_name);
999                quote! { super::enums::#enum_ident::#variant }
1000            } else {
1001                quote! { Default::default() }
1002            }
1003        }
1004        None => quote! { Default::default() },
1005    }
1006}
1007
1008// ─── UPDATE code generation ───────────────────────────────────
1009
1010fn gen_update_code(model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
1011    let _model_ident = format_ident!("{}", model.name);
1012
1013    // Updatable fields: non-id, non-updatedAt scalar fields
1014    let updatable: Vec<&Field> = scalar_fields
1015        .iter()
1016        .copied()
1017        .filter(|f| !f.is_id && !f.is_updated_at)
1018        .collect();
1019
1020    let updated_at: Vec<&Field> = scalar_fields
1021        .iter()
1022        .copied()
1023        .filter(|f| f.is_updated_at)
1024        .collect();
1025
1026    let update_start = format!(r#"UPDATE "{}" SET "#, table_name);
1027
1028    // Generate SET clause arms
1029    let set_arms: Vec<TokenStream> = updatable
1030        .iter()
1031        .map(|f| {
1032            let field_ident = format_ident!("{}", to_snake_case(&f.name));
1033            let db_name = &f.db_name;
1034            quote! {
1035                if let Some(SetValue::Set(v)) = self.data.#field_ident {
1036                    if !first_set { qb.push(", "); }
1037                    first_set = false;
1038                    qb.push(concat!("\"", #db_name, "\" = "));
1039                    qb.push_bind(v);
1040                }
1041            }
1042        })
1043        .collect();
1044
1045    let updated_at_arms: Vec<TokenStream> = updated_at
1046        .iter()
1047        .map(|f| {
1048            let db_name = &f.db_name;
1049            quote! {
1050                if !first_set { qb.push(", "); }
1051                first_set = false;
1052                qb.push(concat!("\"", #db_name, "\" = "));
1053                qb.push_bind(chrono::Utc::now());
1054            }
1055        })
1056        .collect();
1057
1058    // The build_update macro avoids duplicating the SET clause building logic
1059    // for each database backend.
1060    quote! {
1061        macro_rules! build_update {
1062            ($qb_type:ty) => {{
1063                let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#update_start);
1064                let mut first_set = true;
1065                #(#set_arms)*
1066                #(#updated_at_arms)*
1067
1068                if first_set {
1069                    return Err(FerriormError::Query("No fields to update".into()));
1070                }
1071
1072                qb.push(" WHERE 1=1");
1073                self.r#where.build_where(&mut qb);
1074                qb.push(" RETURNING *");
1075                qb
1076            }};
1077        }
1078
1079        match client {
1080            DatabaseClient::Postgres(_) => {
1081                let qb = build_update!(sqlx::Postgres);
1082                client.fetch_one_pg(qb).await
1083            }
1084            DatabaseClient::Sqlite(_) => {
1085                let qb = build_update!(sqlx::Sqlite);
1086                client.fetch_one_sqlite(qb).await
1087            }
1088        }
1089    }
1090}
1091
1092// ─── UPDATE MANY code generation ──────────────────────────────
1093
1094fn gen_update_many_code(_model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
1095    // Updatable fields: non-id, non-updatedAt scalar fields
1096    let updatable: Vec<&Field> = scalar_fields
1097        .iter()
1098        .copied()
1099        .filter(|f| !f.is_id && !f.is_updated_at)
1100        .collect();
1101
1102    let updated_at: Vec<&Field> = scalar_fields
1103        .iter()
1104        .copied()
1105        .filter(|f| f.is_updated_at)
1106        .collect();
1107
1108    let update_start = format!(r#"UPDATE "{}" SET "#, table_name);
1109
1110    // Generate SET clause arms
1111    let set_arms: Vec<TokenStream> = updatable
1112        .iter()
1113        .map(|f| {
1114            let field_ident = format_ident!("{}", to_snake_case(&f.name));
1115            let db_name = &f.db_name;
1116            quote! {
1117                if let Some(SetValue::Set(v)) = self.data.#field_ident {
1118                    if !first_set { qb.push(", "); }
1119                    first_set = false;
1120                    qb.push(concat!("\"", #db_name, "\" = "));
1121                    qb.push_bind(v);
1122                }
1123            }
1124        })
1125        .collect();
1126
1127    let updated_at_arms: Vec<TokenStream> = updated_at
1128        .iter()
1129        .map(|f| {
1130            let db_name = &f.db_name;
1131            quote! {
1132                if !first_set { qb.push(", "); }
1133                first_set = false;
1134                qb.push(concat!("\"", #db_name, "\" = "));
1135                qb.push_bind(chrono::Utc::now());
1136            }
1137        })
1138        .collect();
1139
1140    quote! {
1141        macro_rules! build_update_many {
1142            ($qb_type:ty) => {{
1143                let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#update_start);
1144                let mut first_set = true;
1145                #(#set_arms)*
1146                #(#updated_at_arms)*
1147
1148                if first_set {
1149                    return Ok(0);
1150                }
1151
1152                qb.push(" WHERE 1=1");
1153                self.r#where.build_where(&mut qb);
1154                qb
1155            }};
1156        }
1157
1158        match client {
1159            DatabaseClient::Postgres(_) => {
1160                let qb = build_update_many!(sqlx::Postgres);
1161                client.execute_pg(qb).await
1162            }
1163            DatabaseClient::Sqlite(_) => {
1164                let qb = build_update_many!(sqlx::Sqlite);
1165                client.execute_sqlite(qb).await
1166            }
1167        }
1168    }
1169}
1170
1171// ─── Aggregate Types ──────────────────────────────────────────
1172
/// Classifies an aggregatable scalar field by the SQL aggregate functions it supports.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum AggregateKind {
    /// Numeric fields (`Int`, `BigInt`, `Float`): `AVG`, `SUM`, `MIN`, `MAX`.
    Numeric,
    /// `DateTime` fields: `MIN` and `MAX` only.
    DateTime,
}
1180
1181fn gen_aggregate_types(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
1182    let aggregate_field_name = format_ident!("{}AggregateField", model.name);
1183    let aggregate_result_name = format_ident!("{}AggregateResult", model.name);
1184    let _where_input = format_ident!("{}WhereInput", model.name);
1185    let table_name = &model.db_name;
1186
1187    // Collect aggregatable fields with their kind
1188    let agg_fields: Vec<(&Field, AggregateKind)> = scalar_fields
1189        .iter()
1190        .filter_map(|f| match &f.field_type {
1191            FieldKind::Scalar(ScalarType::Int | ScalarType::BigInt | ScalarType::Float) => {
1192                Some((*f, AggregateKind::Numeric))
1193            }
1194            FieldKind::Scalar(ScalarType::DateTime) => Some((*f, AggregateKind::DateTime)),
1195            _ => None,
1196        })
1197        .collect();
1198
1199    if agg_fields.is_empty() {
1200        return quote! {};
1201    }
1202
1203    // Generate enum variants
1204    let enum_variants: Vec<TokenStream> = agg_fields
1205        .iter()
1206        .map(|(f, _)| {
1207            let variant = format_ident!("{}", to_pascal_case(&f.name));
1208            quote! { #variant }
1209        })
1210        .collect();
1211
1212    // Generate db_name match arms
1213    let db_name_arms: Vec<TokenStream> = agg_fields
1214        .iter()
1215        .map(|(f, _)| {
1216            let variant = format_ident!("{}", to_pascal_case(&f.name));
1217            let db_name = &f.db_name;
1218            quote! { Self::#variant => #db_name }
1219        })
1220        .collect();
1221
1222    // Generate AggregateResult fields
1223    let mut result_fields = Vec::new();
1224    for (f, kind) in &agg_fields {
1225        let snake = to_snake_case(&f.name);
1226        let orig_ty = rust_type_tokens(
1227            &Field {
1228                is_optional: false,
1229                ..(*f).clone()
1230            },
1231            ModuleDepth::TopLevel,
1232        );
1233
1234        match kind {
1235            AggregateKind::Numeric => {
1236                let avg_name = format_ident!("avg_{}", snake);
1237                let sum_name = format_ident!("sum_{}", snake);
1238                let min_name = format_ident!("min_{}", snake);
1239                let max_name = format_ident!("max_{}", snake);
1240                result_fields.push(quote! { #[sqlx(default)] pub #avg_name: Option<f64> });
1241                result_fields.push(quote! { #[sqlx(default)] pub #sum_name: Option<f64> });
1242                result_fields.push(quote! { #[sqlx(default)] pub #min_name: Option<#orig_ty> });
1243                result_fields.push(quote! { #[sqlx(default)] pub #max_name: Option<#orig_ty> });
1244            }
1245            AggregateKind::DateTime => {
1246                let min_name = format_ident!("min_{}", snake);
1247                let max_name = format_ident!("max_{}", snake);
1248                result_fields.push(quote! { #[sqlx(default)] pub #min_name: Option<#orig_ty> });
1249                result_fields.push(quote! { #[sqlx(default)] pub #max_name: Option<#orig_ty> });
1250            }
1251        }
1252    }
1253
1254    // Generate the is_numeric check for avg/sum validation
1255    let numeric_arms: Vec<TokenStream> = agg_fields
1256        .iter()
1257        .filter(|(_, kind)| matches!(kind, AggregateKind::Numeric))
1258        .map(|(f, _)| {
1259            let variant = format_ident!("{}", to_pascal_case(&f.name));
1260            quote! { Self::#variant => true }
1261        })
1262        .collect();
1263
1264    let has_numeric = !numeric_arms.is_empty();
1265    let is_numeric_method = if has_numeric {
1266        quote! {
1267            fn is_numeric(&self) -> bool {
1268                match self {
1269                    #(#numeric_arms,)*
1270                    #[allow(unreachable_patterns)]
1271                    _ => false,
1272                }
1273            }
1274        }
1275    } else {
1276        quote! {
1277            fn is_numeric(&self) -> bool { false }
1278        }
1279    };
1280
1281    // Generate alias match arms for each (prefix, field) combination
1282    let mut alias_arms = Vec::new();
1283    for (f, kind) in &agg_fields {
1284        let variant = format_ident!("{}", to_pascal_case(&f.name));
1285        let snake = to_snake_case(&f.name);
1286        let prefixes = match kind {
1287            AggregateKind::Numeric => vec!["avg", "sum", "min", "max"],
1288            AggregateKind::DateTime => vec!["min", "max"],
1289        };
1290        for prefix in prefixes {
1291            let alias_str = format!("{}_{}", prefix, snake);
1292            alias_arms.push(quote! { (#prefix, Self::#variant) => #alias_str });
1293        }
1294    }
1295
1296    let agg_select_base = format!(r#"SELECT {{}} FROM "{}" WHERE 1=1"#, table_name);
1297
1298    quote! {
1299        #[derive(Debug, Clone, Copy)]
1300        pub enum #aggregate_field_name {
1301            #(#enum_variants),*
1302        }
1303
1304        impl #aggregate_field_name {
1305            pub fn db_name(&self) -> &'static str {
1306                match self {
1307                    #(#db_name_arms,)*
1308                }
1309            }
1310
1311            fn alias(&self, prefix: &'static str) -> &'static str {
1312                match (prefix, self) {
1313                    #(#alias_arms,)*
1314                    _ => unreachable!(),
1315                }
1316            }
1317
1318            #is_numeric_method
1319        }
1320
1321        #[derive(Debug, Clone, sqlx::FromRow, Serialize, Deserialize)]
1322        pub struct #aggregate_result_name {
1323            #(#result_fields,)*
1324        }
1325
1326        pub struct AggregateQuery<'a> {
1327            client: &'a DatabaseClient,
1328            r#where: filter::#_where_input,
1329            ops: Vec<(&'static str, &'static str, &'static str)>,
1330        }
1331
1332        impl<'a> AggregateQuery<'a> {
1333            pub fn avg(mut self, field: #aggregate_field_name) -> Self {
1334                assert!(field.is_numeric(), "avg() is only supported on numeric fields");
1335                let db_name = field.db_name();
1336                let alias = field.alias("avg");
1337                self.ops.push(("AVG", db_name, alias));
1338                self
1339            }
1340
1341            pub fn sum(mut self, field: #aggregate_field_name) -> Self {
1342                assert!(field.is_numeric(), "sum() is only supported on numeric fields");
1343                let db_name = field.db_name();
1344                let alias = field.alias("sum");
1345                self.ops.push(("SUM", db_name, alias));
1346                self
1347            }
1348
1349            pub fn min(mut self, field: #aggregate_field_name) -> Self {
1350                let db_name = field.db_name();
1351                let alias = field.alias("min");
1352                self.ops.push(("MIN", db_name, alias));
1353                self
1354            }
1355
1356            pub fn max(mut self, field: #aggregate_field_name) -> Self {
1357                let db_name = field.db_name();
1358                let alias = field.alias("max");
1359                self.ops.push(("MAX", db_name, alias));
1360                self
1361            }
1362
1363            pub async fn exec(self) -> Result<#aggregate_result_name, FerriormError> {
1364                if self.ops.is_empty() {
1365                    return Err(FerriormError::Query("No aggregate operations specified".into()));
1366                }
1367
1368                let selections: Vec<String> = self.ops.iter()
1369                    .map(|(func, col, alias)| format!(r#"{}("{}") as "{}""#, func, col, alias))
1370                    .collect();
1371                let select_clause = selections.join(", ");
1372                let base_sql = format!(#agg_select_base, select_clause);
1373
1374                match self.client {
1375                    DatabaseClient::Postgres(_) => {
1376                        let mut qb = sqlx::QueryBuilder::<sqlx::Postgres>::new(&base_sql);
1377                        self.r#where.build_where(&mut qb);
1378                        self.client.fetch_one_pg(qb).await
1379                    }
1380                    DatabaseClient::Sqlite(_) => {
1381                        let mut qb = sqlx::QueryBuilder::<sqlx::Sqlite>::new(&base_sql);
1382                        self.r#where.build_where(&mut qb);
1383                        self.client.fetch_one_sqlite(qb).await
1384                    }
1385                }
1386            }
1387        }
1388    }
1389}
1390
1391// ─── Select Types ─────────────────────────────────────────────
1392
1393fn gen_select_types(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
1394    let select_name = format_ident!("{}Select", model.name);
1395    let partial_name = format_ident!("{}Partial", model.name);
1396    let _where_input = format_ident!("{}WhereInput", model.name);
1397    let _where_unique = format_ident!("{}WhereUniqueInput", model.name);
1398    let order_by_name = format_ident!("{}OrderByInput", model.name);
1399    let table_name = &model.db_name;
1400
1401    // Select struct fields: all bool, default false
1402    let select_fields: Vec<TokenStream> = scalar_fields
1403        .iter()
1404        .map(|f| {
1405            let name = format_ident!("{}", to_snake_case(&f.name));
1406            quote! { pub #name: bool }
1407        })
1408        .collect();
1409
1410    // Partial struct fields: all Option<T> with #[sqlx(default)]
1411    // For already-optional fields, don't double-wrap in Option
1412    let partial_fields: Vec<TokenStream> = scalar_fields
1413        .iter()
1414        .map(|f| {
1415            let name = format_ident!("{}", to_snake_case(&f.name));
1416            let db_name = &f.db_name;
1417            // Get the base type (non-optional version)
1418            let base_ty = rust_type_tokens(
1419                &Field {
1420                    is_optional: false,
1421                    ..(*f).clone()
1422                },
1423                ModuleDepth::TopLevel,
1424            );
1425            let rename = if db_name != &to_snake_case(&f.name) {
1426                quote! { #[sqlx(rename = #db_name)] }
1427            } else {
1428                quote! {}
1429            };
1430            // Always wrap in Option<T>, regardless of whether field was originally optional
1431            quote! { #[sqlx(default)] #rename pub #name: Option<#base_ty> }
1432        })
1433        .collect();
1434
1435    // build_select_columns: maps Select bools to column names
1436    let select_col_arms: Vec<TokenStream> = scalar_fields
1437        .iter()
1438        .map(|f| {
1439            let name = format_ident!("{}", to_snake_case(&f.name));
1440            let db_name = &f.db_name;
1441            let col_expr = format!(r#""{}""#, db_name);
1442            quote! {
1443                if select.#name { cols.push(#col_expr); }
1444            }
1445        })
1446        .collect();
1447
1448    let select_sql_prefix = format!(r#"SELECT {{}} FROM "{}" WHERE 1=1"#, table_name);
1449
1450    quote! {
1451        #[derive(Debug, Clone, Default)]
1452        pub struct #select_name {
1453            #(#select_fields,)*
1454        }
1455
1456        #[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
1457        #[sqlx(rename_all = "snake_case")]
1458        pub struct #partial_name {
1459            #(#partial_fields,)*
1460        }
1461
1462        fn build_select_columns(select: &#select_name) -> String {
1463            let mut cols = Vec::new();
1464            #(#select_col_arms)*
1465            if cols.is_empty() {
1466                "*".to_string()
1467            } else {
1468                cols.join(", ")
1469            }
1470        }
1471
1472        // ── FindManySelectQuery ──────────────────────────────────
1473
1474        pub struct FindManySelectQuery<'a> {
1475            client: &'a DatabaseClient,
1476            r#where: filter::#_where_input,
1477            order_by: Vec<order::#order_by_name>,
1478            skip: Option<i64>,
1479            take: Option<i64>,
1480            select: #select_name,
1481        }
1482
1483        impl<'a> FindManySelectQuery<'a> {
1484            pub fn order_by(mut self, order: order::#order_by_name) -> Self {
1485                self.order_by.push(order);
1486                self
1487            }
1488
1489            pub fn skip(mut self, n: i64) -> Self {
1490                self.skip = Some(n);
1491                self
1492            }
1493
1494            pub fn take(mut self, n: i64) -> Self {
1495                self.take = Some(n);
1496                self
1497            }
1498
1499            pub async fn exec(self) -> Result<Vec<#partial_name>, FerriormError> {
1500                let cols = build_select_columns(&self.select);
1501                let base_sql = format!(#select_sql_prefix, cols);
1502
1503                match self.client {
1504                    DatabaseClient::Postgres(_) => {
1505                        let qb = build_select_query::<sqlx::Postgres>(
1506                            &base_sql, &self.r#where, &self.order_by, self.take, self.skip,
1507                        );
1508                        self.client.fetch_all_pg(qb).await
1509                    }
1510                    DatabaseClient::Sqlite(_) => {
1511                        let qb = build_select_query::<sqlx::Sqlite>(
1512                            &base_sql, &self.r#where, &self.order_by, self.take, self.skip,
1513                        );
1514                        self.client.fetch_all_sqlite(qb).await
1515                    }
1516                }
1517            }
1518        }
1519
1520        // ── FindUniqueSelectQuery ────────────────────────────────
1521
1522        pub struct FindUniqueSelectQuery<'a> {
1523            client: &'a DatabaseClient,
1524            r#where: filter::#_where_unique,
1525            select: #select_name,
1526        }
1527
1528        impl<'a> FindUniqueSelectQuery<'a> {
1529            pub async fn exec(self) -> Result<Option<#partial_name>, FerriormError> {
1530                let cols = build_select_columns(&self.select);
1531                let base_sql = format!(#select_sql_prefix, cols);
1532
1533                match self.client {
1534                    DatabaseClient::Postgres(_) => {
1535                        let qb = build_unique_select_query::<sqlx::Postgres>(
1536                            &base_sql, &self.r#where,
1537                        );
1538                        self.client.fetch_optional_pg(qb).await
1539                    }
1540                    DatabaseClient::Sqlite(_) => {
1541                        let qb = build_unique_select_query::<sqlx::Sqlite>(
1542                            &base_sql, &self.r#where,
1543                        );
1544                        self.client.fetch_optional_sqlite(qb).await
1545                    }
1546                }
1547            }
1548        }
1549
1550        // ── FindFirstSelectQuery ─────────────────────────────────
1551
1552        pub struct FindFirstSelectQuery<'a> {
1553            client: &'a DatabaseClient,
1554            r#where: filter::#_where_input,
1555            order_by: Vec<order::#order_by_name>,
1556            select: #select_name,
1557        }
1558
1559        impl<'a> FindFirstSelectQuery<'a> {
1560            pub fn order_by(mut self, order: order::#order_by_name) -> Self {
1561                self.order_by.push(order);
1562                self
1563            }
1564
1565            pub async fn exec(self) -> Result<Option<#partial_name>, FerriormError> {
1566                let cols = build_select_columns(&self.select);
1567                let base_sql = format!(#select_sql_prefix, cols);
1568
1569                match self.client {
1570                    DatabaseClient::Postgres(_) => {
1571                        let qb = build_select_query::<sqlx::Postgres>(
1572                            &base_sql, &self.r#where, &self.order_by, Some(1), None,
1573                        );
1574                        self.client.fetch_optional_pg(qb).await
1575                    }
1576                    DatabaseClient::Sqlite(_) => {
1577                        let qb = build_select_query::<sqlx::Sqlite>(
1578                            &base_sql, &self.r#where, &self.order_by, Some(1), None,
1579                        );
1580                        self.client.fetch_optional_sqlite(qb).await
1581                    }
1582                }
1583            }
1584        }
1585    }
1586}