Skip to main content

ferriorm_codegen/
model.rs

//! Generates per-model Rust modules (struct, filters, data inputs, ordering, CRUD).
//!
//! For each model in the schema, this module produces:
//!
//! - A **data struct** (e.g., `User`) with `sqlx::FromRow` and serde derives.
//! - A **filter submodule** with `WhereInput` and `WhereUniqueInput` types.
//! - A **data submodule** with `CreateInput` and `UpdateInput` types.
//! - An **order submodule** with `OrderByInput`.
//! - An **`Actions` struct** exposing `find_unique`, `find_first`, `find_many`,
//!   `create`, `update`, `delete`, `count`, and the batch operations
//!   `create_many`, `update_many`, and `delete_many`.
//! - **Query builder structs** that chain filters, ordering, and pagination
//!   before calling `.exec()`.

14use ferriorm_core::schema::{Field, FieldKind, Model};
15use ferriorm_core::types::ScalarType;
16use ferriorm_core::utils::{to_pascal_case, to_snake_case};
17use proc_macro2::TokenStream;
18use quote::{format_ident, quote};
19
20use crate::rust_type::{ModuleDepth, filter_type_tokens, rust_type_tokens};
21
22/// Generate the complete module for a single model.
23pub fn generate_model_module(model: &Model) -> TokenStream {
24    let scalar_fields: Vec<&Field> = model.fields.iter().filter(|f| f.is_scalar()).collect();
25
26    let data_struct = gen_data_struct(model, &scalar_fields);
27    let filter_module = gen_filter_module(model, &scalar_fields);
28    let data_module = gen_data_module(model, &scalar_fields);
29    let order_module = gen_order_module(model, &scalar_fields);
30    let actions_struct = gen_actions(model);
31    let query_builders = gen_query_builders(model, &scalar_fields);
32
33    quote! {
34        #![allow(unused_imports, dead_code, clippy::all, unused_variables)]
35
36        use serde::{Deserialize, Serialize};
37        use ferriorm_runtime::prelude::*;
38
39        #data_struct
40        #filter_module
41        #data_module
42        #order_module
43        #actions_struct
44        #query_builders
45    }
46}
47
48// ─── Data Struct ──────────────────────────────────────────────
49
50fn gen_data_struct(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
51    let struct_name = format_ident!("{}", model.name);
52    let table_name = &model.db_name;
53
54    let fields: Vec<TokenStream> = scalar_fields
55        .iter()
56        .map(|f| {
57            let name = format_ident!("{}", to_snake_case(&f.name));
58            let ty = rust_type_tokens(f, ModuleDepth::TopLevel);
59            let db_name = &f.db_name;
60            if db_name != &to_snake_case(&f.name) {
61                quote! { #[sqlx(rename = #db_name)] pub #name: #ty }
62            } else {
63                quote! { pub #name: #ty }
64            }
65        })
66        .collect();
67
68    quote! {
69        #[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
70        #[sqlx(rename_all = "snake_case")]
71        pub struct #struct_name {
72            #(#fields),*
73        }
74
75        impl #struct_name {
76            pub const TABLE_NAME: &'static str = #table_name;
77        }
78    }
79}
80
81// ─── Filter Module ────────────────────────────────────────────
82
83fn gen_filter_module(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
84    let where_input = format_ident!("{}WhereInput", model.name);
85    let where_unique = format_ident!("{}WhereUniqueInput", model.name);
86
87    let where_fields: Vec<TokenStream> = scalar_fields
88        .iter()
89        .filter_map(|f| {
90            let filter_ty = filter_type_tokens(f, ModuleDepth::Nested)?;
91            let name = format_ident!("{}", to_snake_case(&f.name));
92            Some(quote! { pub #name: Option<#filter_ty> })
93        })
94        .collect();
95
96    let unique_variants: Vec<TokenStream> = scalar_fields
97        .iter()
98        .filter(|f| f.is_id || f.is_unique)
99        .map(|f| {
100            let variant = format_ident!("{}", to_pascal_case(&f.name));
101            let ty = rust_type_tokens(f, ModuleDepth::Nested);
102            quote! { #variant(#ty) }
103        })
104        .collect();
105
106    // Generate build_where for WhereInput
107    let db_bounds = collect_db_bounds(scalar_fields);
108    let where_arms = gen_where_arms(scalar_fields);
109    let unique_arms = gen_unique_where_arms(scalar_fields);
110
111    quote! {
112        pub mod filter {
113            use ferriorm_runtime::prelude::*;
114
115            #[derive(Debug, Clone, Default)]
116            pub struct #where_input {
117                #(#where_fields,)*
118                pub and: Option<Vec<#where_input>>,
119                pub or: Option<Vec<#where_input>>,
120                pub not: Option<Box<#where_input>>,
121            }
122
123            #[derive(Debug, Clone)]
124            pub enum #where_unique {
125                #(#unique_variants),*
126            }
127
128            impl #where_input {
129                pub(crate) fn build_where<'args, DB: sqlx::Database>(
130                    &self,
131                    qb: &mut sqlx::QueryBuilder<'args, DB>,
132                )
133                where
134                    #(#db_bounds,)*
135                {
136                    #(#where_arms)*
137
138                    if let Some(conditions) = &self.and {
139                        for c in conditions {
140                            c.build_where(qb);
141                        }
142                    }
143                    if let Some(conditions) = &self.or {
144                        if !conditions.is_empty() {
145                            qb.push(" AND (");
146                            for (i, c) in conditions.iter().enumerate() {
147                                if i > 0 { qb.push(" OR "); }
148                                qb.push("(1=1");
149                                c.build_where(qb);
150                                qb.push(")");
151                            }
152                            qb.push(")");
153                        }
154                    }
155                    if let Some(c) = &self.not {
156                        qb.push(" AND NOT (1=1");
157                        c.build_where(qb);
158                        qb.push(")");
159                    }
160                }
161            }
162
163            impl #where_unique {
164                pub(crate) fn build_where<'args, DB: sqlx::Database>(
165                    &self,
166                    qb: &mut sqlx::QueryBuilder<'args, DB>,
167                )
168                where
169                    #(#db_bounds,)*
170                {
171                    match self {
172                        #(#unique_arms)*
173                    }
174                }
175            }
176        }
177    }
178}
179
180/// Collect the sqlx type bounds needed for all scalar types used by the model.
181fn collect_db_bounds(scalar_fields: &[&Field]) -> Vec<TokenStream> {
182    let mut seen = std::collections::HashSet::new();
183    let mut bounds = Vec::new();
184
185    // Always need i64 for LIMIT/OFFSET
186    seen.insert("i64");
187    bounds.push(quote! { i64: sqlx::Type<DB> + for<'e> sqlx::Encode<'e, DB> });
188
189    for f in scalar_fields {
190        match &f.field_type {
191            FieldKind::Scalar(scalar) => {
192                let key = scalar.rust_type();
193                if seen.insert(key)
194                    && let Some(ty) = scalar_bound_tokens(scalar)
195                {
196                    bounds.push(quote! { #ty: sqlx::Type<DB> + for<'e> sqlx::Encode<'e, DB> });
197                    // Also add Option<T> bound for nullable field support
198                    bounds.push(
199                        quote! { Option<#ty>: sqlx::Type<DB> + for<'e> sqlx::Encode<'e, DB> },
200                    );
201                }
202            }
203            FieldKind::Enum(_) => {}
204            _ => {}
205        }
206    }
207
208    bounds
209}
210
211fn scalar_bound_tokens(scalar: &ScalarType) -> Option<TokenStream> {
212    match scalar {
213        ScalarType::String => Some(quote! { String }),
214        ScalarType::Int => Some(quote! { i32 }),
215        ScalarType::BigInt => Some(quote! { i64 }),
216        ScalarType::Float => Some(quote! { f64 }),
217        ScalarType::Boolean => Some(quote! { bool }),
218        ScalarType::DateTime => Some(quote! { chrono::DateTime<chrono::Utc> }),
219        ScalarType::Bytes => Some(quote! { Vec<u8> }),
220        ScalarType::Json | ScalarType::Decimal => None,
221    }
222}
223
224/// Generate where-clause arms for each filterable scalar field.
225fn gen_where_arms(scalar_fields: &[&Field]) -> Vec<TokenStream> {
226    scalar_fields
227        .iter()
228        .filter_map(|f| {
229            // Only generate filter arms for scalar types (skip enums for now)
230            if !matches!(&f.field_type, FieldKind::Scalar(_)) {
231                return None;
232            }
233            let field_ident = format_ident!("{}", to_snake_case(&f.name));
234            let db_name = &f.db_name;
235            let is_string = matches!(&f.field_type, FieldKind::Scalar(ScalarType::String));
236            let is_comparable = matches!(
237                &f.field_type,
238                FieldKind::Scalar(
239                    ScalarType::Int | ScalarType::BigInt | ScalarType::Float | ScalarType::DateTime
240                )
241            );
242
243            let mut arms = vec![];
244
245            arms.push(quote! {
246                if let Some(v) = &filter.equals {
247                    qb.push(concat!(" AND \"", #db_name, "\" = "));
248                    qb.push_bind(v.clone());
249                }
250                if let Some(v) = &filter.not {
251                    qb.push(concat!(" AND \"", #db_name, "\" != "));
252                    qb.push_bind(v.clone());
253                }
254            });
255
256            if is_string {
257                arms.push(quote! {
258                    if let Some(v) = &filter.contains {
259                        qb.push(concat!(" AND \"", #db_name, "\" LIKE "));
260                        qb.push_bind(format!("%{}%", v));
261                    }
262                    if let Some(v) = &filter.starts_with {
263                        qb.push(concat!(" AND \"", #db_name, "\" LIKE "));
264                        qb.push_bind(format!("{}%", v));
265                    }
266                    if let Some(v) = &filter.ends_with {
267                        qb.push(concat!(" AND \"", #db_name, "\" LIKE "));
268                        qb.push_bind(format!("%{}", v));
269                    }
270                });
271            }
272
273            if is_comparable {
274                arms.push(quote! {
275                    if let Some(v) = &filter.gt {
276                        qb.push(concat!(" AND \"", #db_name, "\" > "));
277                        qb.push_bind(v.clone());
278                    }
279                    if let Some(v) = &filter.gte {
280                        qb.push(concat!(" AND \"", #db_name, "\" >= "));
281                        qb.push_bind(v.clone());
282                    }
283                    if let Some(v) = &filter.lt {
284                        qb.push(concat!(" AND \"", #db_name, "\" < "));
285                        qb.push_bind(v.clone());
286                    }
287                    if let Some(v) = &filter.lte {
288                        qb.push(concat!(" AND \"", #db_name, "\" <= "));
289                        qb.push_bind(v.clone());
290                    }
291                });
292            }
293
294            Some(quote! {
295                if let Some(filter) = &self.#field_ident {
296                    #(#arms)*
297                }
298            })
299        })
300        .collect()
301}
302
303fn gen_unique_where_arms(scalar_fields: &[&Field]) -> Vec<TokenStream> {
304    let _where_unique = format_ident!(
305        "{}WhereUniqueInput",
306        "" // placeholder, we use Self:: instead
307    );
308    scalar_fields
309        .iter()
310        .filter(|f| f.is_id || f.is_unique)
311        .map(|f| {
312            let variant = format_ident!("{}", to_pascal_case(&f.name));
313            let db_name = &f.db_name;
314            quote! {
315                Self::#variant(v) => {
316                    qb.push(concat!(" AND \"", #db_name, "\" = "));
317                    qb.push_bind(v.clone());
318                }
319            }
320        })
321        .collect()
322}
323
324// ─── Data Module ──────────────────────────────────────────────
325
326fn gen_data_module(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
327    let create_name = format_ident!("{}CreateInput", model.name);
328    let update_name = format_ident!("{}UpdateInput", model.name);
329
330    let required_fields: Vec<TokenStream> = scalar_fields
331        .iter()
332        .filter(|f| !f.has_default() && !f.is_updated_at)
333        .map(|f| {
334            let name = format_ident!("{}", to_snake_case(&f.name));
335            let ty = rust_type_tokens(f, ModuleDepth::Nested);
336            quote! { pub #name: #ty }
337        })
338        .collect();
339
340    let optional_fields: Vec<TokenStream> = scalar_fields
341        .iter()
342        .filter(|f| f.has_default() && !f.is_updated_at)
343        .map(|f| {
344            let name = format_ident!("{}", to_snake_case(&f.name));
345            let base_ty = rust_type_tokens(f, ModuleDepth::Nested);
346            quote! { pub #name: Option<#base_ty> }
347        })
348        .collect();
349
350    let update_fields: Vec<TokenStream> = scalar_fields
351        .iter()
352        .filter(|f| !f.is_id && !f.is_updated_at)
353        .map(|f| {
354            let name = format_ident!("{}", to_snake_case(&f.name));
355            let ty = rust_type_tokens(f, ModuleDepth::Nested);
356            quote! { pub #name: Option<SetValue<#ty>> }
357        })
358        .collect();
359
360    quote! {
361        pub mod data {
362            use ferriorm_runtime::prelude::*;
363
364            #[derive(Debug, Clone)]
365            pub struct #create_name {
366                #(#required_fields,)*
367                #(#optional_fields,)*
368            }
369
370            #[derive(Debug, Clone, Default)]
371            pub struct #update_name {
372                #(#update_fields,)*
373            }
374        }
375    }
376}
377
378// ─── Order Module ─────────────────────────────────────────────
379
380fn gen_order_module(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
381    let order_name = format_ident!("{}OrderByInput", model.name);
382
383    let variants: Vec<TokenStream> = scalar_fields
384        .iter()
385        .map(|f| {
386            let variant = format_ident!("{}", to_pascal_case(&f.name));
387            quote! { #variant(SortOrder) }
388        })
389        .collect();
390
391    let order_arms: Vec<TokenStream> = scalar_fields
392        .iter()
393        .map(|f| {
394            let variant = format_ident!("{}", to_pascal_case(&f.name));
395            let db_name = &f.db_name;
396            quote! {
397                Self::#variant(order) => {
398                    qb.push(concat!("\"", #db_name, "\" "));
399                    qb.push(order.as_sql());
400                }
401            }
402        })
403        .collect();
404
405    quote! {
406        pub mod order {
407            use ferriorm_runtime::prelude::*;
408
409            #[derive(Debug, Clone)]
410            pub enum #order_name {
411                #(#variants),*
412            }
413
414            impl #order_name {
415                pub(crate) fn build_order_by<'args, DB: sqlx::Database>(
416                    &self,
417                    qb: &mut sqlx::QueryBuilder<'args, DB>,
418                ) {
419                    match self {
420                        #(#order_arms)*
421                    }
422                }
423            }
424        }
425    }
426}
427
428// ─── Actions ──────────────────────────────────────────────────
429
430fn gen_actions(model: &Model) -> TokenStream {
431    let _model_ident = format_ident!("{}", model.name);
432    let actions_name = format_ident!("{}Actions", model.name);
433    let where_input = format_ident!("{}WhereInput", model.name);
434    let where_unique = format_ident!("{}WhereUniqueInput", model.name);
435    let create_input = format_ident!("{}CreateInput", model.name);
436    let update_input = format_ident!("{}UpdateInput", model.name);
437    let _order_by = format_ident!("{}OrderByInput", model.name);
438
439    quote! {
440        pub struct #actions_name<'a> {
441            client: &'a DatabaseClient,
442        }
443
444        impl<'a> #actions_name<'a> {
445            pub fn new(client: &'a DatabaseClient) -> Self { Self { client } }
446
447            pub fn find_unique(&self, r#where: filter::#where_unique) -> FindUniqueQuery<'a> {
448                FindUniqueQuery { client: self.client, r#where }
449            }
450
451            pub fn find_first(&self, r#where: filter::#where_input) -> FindFirstQuery<'a> {
452                FindFirstQuery { client: self.client, r#where, order_by: vec![] }
453            }
454
455            pub fn find_many(&self, r#where: filter::#where_input) -> FindManyQuery<'a> {
456                FindManyQuery { client: self.client, r#where, order_by: vec![], skip: None, take: None }
457            }
458
459            pub fn create(&self, data: data::#create_input) -> CreateQuery<'a> {
460                CreateQuery { client: self.client, data }
461            }
462
463            pub fn update(&self, r#where: filter::#where_unique, data: data::#update_input) -> UpdateQuery<'a> {
464                UpdateQuery { client: self.client, r#where, data }
465            }
466
467            pub fn delete(&self, r#where: filter::#where_unique) -> DeleteQuery<'a> {
468                DeleteQuery { client: self.client, r#where }
469            }
470
471            pub fn count(&self, r#where: filter::#where_input) -> CountQuery<'a> {
472                CountQuery { client: self.client, r#where }
473            }
474
475            pub fn create_many(&self, data: Vec<data::#create_input>) -> CreateManyQuery<'a> {
476                CreateManyQuery { client: self.client, data }
477            }
478
479            pub fn update_many(&self, r#where: filter::#where_input, data: data::#update_input) -> UpdateManyQuery<'a> {
480                UpdateManyQuery { client: self.client, r#where, data }
481            }
482
483            pub fn delete_many(&self, r#where: filter::#where_input) -> DeleteManyQuery<'a> {
484                DeleteManyQuery { client: self.client, r#where }
485            }
486        }
487    }
488}
489
490// ─── Query Builders with exec() ──────────────────────────────
491
492fn gen_query_builders(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
493    let model_ident = format_ident!("{}", model.name);
494    let table_name = &model.db_name;
495    let _where_input = format_ident!("{}WhereInput", model.name);
496    let _where_unique = format_ident!("{}WhereUniqueInput", model.name);
497    let _create_input = format_ident!("{}CreateInput", model.name);
498    let _update_input = format_ident!("{}UpdateInput", model.name);
499    let order_by = format_ident!("{}OrderByInput", model.name);
500    let db_bounds = collect_db_bounds(scalar_fields);
501
502    let select_sql = format!(r#"SELECT * FROM "{}" WHERE 1=1"#, table_name);
503    let count_sql = format!(
504        r#"SELECT COUNT(*) as "count" FROM "{}" WHERE 1=1"#,
505        table_name
506    );
507    let delete_sql = format!(r#"DELETE FROM "{}" WHERE 1=1"#, table_name);
508
509    let insert_code = gen_insert_code(model, scalar_fields, table_name);
510    let update_code = gen_update_code(model, scalar_fields, table_name);
511    let update_many_code = gen_update_many_code(model, scalar_fields, table_name);
512
513    quote! {
514        // ── Generic helper: build ORDER BY clause ──────────────
515        fn build_order_by<'args, DB: sqlx::Database>(
516            orders: &[order::#order_by],
517            qb: &mut sqlx::QueryBuilder<'args, DB>,
518        ) {
519            if !orders.is_empty() {
520                qb.push(" ORDER BY ");
521                for (i, ob) in orders.iter().enumerate() {
522                    if i > 0 { qb.push(", "); }
523                    ob.build_order_by(qb);
524                }
525            }
526        }
527
528        // ── Generic helper: build a SELECT query ───────────────
529        fn build_select_query<'args, DB: sqlx::Database>(
530            base_sql: &str,
531            where_input: &filter::#_where_input,
532            orders: &[order::#order_by],
533            take: Option<i64>,
534            skip: Option<i64>,
535        ) -> sqlx::QueryBuilder<'args, DB>
536        where
537            #(#db_bounds,)*
538        {
539            let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
540            where_input.build_where(&mut qb);
541            build_order_by(orders, &mut qb);
542            if let Some(take) = take {
543                qb.push(" LIMIT ");
544                qb.push_bind(take);
545            }
546            if let Some(skip) = skip {
547                qb.push(" OFFSET ");
548                qb.push_bind(skip);
549            }
550            qb
551        }
552
553        // ── Generic helper: build a SELECT query for unique lookup ──
554        fn build_unique_select_query<'args, DB: sqlx::Database>(
555            base_sql: &str,
556            where_unique: &filter::#_where_unique,
557        ) -> sqlx::QueryBuilder<'args, DB>
558        where
559            #(#db_bounds,)*
560        {
561            let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
562            where_unique.build_where(&mut qb);
563            qb.push(" LIMIT 1");
564            qb
565        }
566
567        // ── Generic helper: build a DELETE-returning query ─────
568        fn build_delete_query<'args, DB: sqlx::Database>(
569            base_sql: &str,
570            where_unique: &filter::#_where_unique,
571        ) -> sqlx::QueryBuilder<'args, DB>
572        where
573            #(#db_bounds,)*
574        {
575            let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
576            where_unique.build_where(&mut qb);
577            qb.push(" RETURNING *");
578            qb
579        }
580
581        // ── Generic helper: build a COUNT query ────────────────
582        fn build_count_query<'args, DB: sqlx::Database>(
583            base_sql: &str,
584            where_input: &filter::#_where_input,
585        ) -> sqlx::QueryBuilder<'args, DB>
586        where
587            #(#db_bounds,)*
588        {
589            let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
590            where_input.build_where(&mut qb);
591            qb
592        }
593
594        // ── Generic helper: build a DELETE-many query ──────────
595        fn build_delete_many_query<'args, DB: sqlx::Database>(
596            base_sql: &str,
597            where_input: &filter::#_where_input,
598        ) -> sqlx::QueryBuilder<'args, DB>
599        where
600            #(#db_bounds,)*
601        {
602            let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
603            where_input.build_where(&mut qb);
604            qb
605        }
606
607        pub struct FindUniqueQuery<'a> {
608            client: &'a DatabaseClient,
609            r#where: filter::#_where_unique,
610        }
611
612        impl<'a> FindUniqueQuery<'a> {
613            pub async fn exec(self) -> Result<Option<#model_ident>, FerriormError> {
614                match self.client {
615                    DatabaseClient::Postgres(_) => {
616                        let qb = build_unique_select_query::<sqlx::Postgres>(#select_sql, &self.r#where);
617                        self.client.fetch_optional_pg(qb).await
618                    }
619                    DatabaseClient::Sqlite(_) => {
620                        let qb = build_unique_select_query::<sqlx::Sqlite>(#select_sql, &self.r#where);
621                        self.client.fetch_optional_sqlite(qb).await
622                    }
623                }
624            }
625        }
626
627        pub struct FindFirstQuery<'a> {
628            client: &'a DatabaseClient,
629            r#where: filter::#_where_input,
630            order_by: Vec<order::#order_by>,
631        }
632
633        impl<'a> FindFirstQuery<'a> {
634            pub fn order_by(mut self, order: order::#order_by) -> Self {
635                self.order_by.push(order);
636                self
637            }
638
639            pub async fn exec(self) -> Result<Option<#model_ident>, FerriormError> {
640                match self.client {
641                    DatabaseClient::Postgres(_) => {
642                        let qb = build_select_query::<sqlx::Postgres>(#select_sql, &self.r#where, &self.order_by, Some(1), None);
643                        self.client.fetch_optional_pg(qb).await
644                    }
645                    DatabaseClient::Sqlite(_) => {
646                        let qb = build_select_query::<sqlx::Sqlite>(#select_sql, &self.r#where, &self.order_by, Some(1), None);
647                        self.client.fetch_optional_sqlite(qb).await
648                    }
649                }
650            }
651        }
652
653        pub struct FindManyQuery<'a> {
654            client: &'a DatabaseClient,
655            r#where: filter::#_where_input,
656            order_by: Vec<order::#order_by>,
657            skip: Option<i64>,
658            take: Option<i64>,
659        }
660
661        impl<'a> FindManyQuery<'a> {
662            pub fn order_by(mut self, order: order::#order_by) -> Self {
663                self.order_by.push(order);
664                self
665            }
666
667            pub fn skip(mut self, n: i64) -> Self {
668                self.skip = Some(n);
669                self
670            }
671
672            pub fn take(mut self, n: i64) -> Self {
673                self.take = Some(n);
674                self
675            }
676
677            pub async fn exec(self) -> Result<Vec<#model_ident>, FerriormError> {
678                match self.client {
679                    DatabaseClient::Postgres(_) => {
680                        let qb = build_select_query::<sqlx::Postgres>(#select_sql, &self.r#where, &self.order_by, self.take, self.skip);
681                        self.client.fetch_all_pg(qb).await
682                    }
683                    DatabaseClient::Sqlite(_) => {
684                        let qb = build_select_query::<sqlx::Sqlite>(#select_sql, &self.r#where, &self.order_by, self.take, self.skip);
685                        self.client.fetch_all_sqlite(qb).await
686                    }
687                }
688            }
689        }
690
691        pub struct CreateQuery<'a> {
692            client: &'a DatabaseClient,
693            data: data::#_create_input,
694        }
695
696        impl<'a> CreateQuery<'a> {
697            pub async fn exec(self) -> Result<#model_ident, FerriormError> {
698                let client = self.client;
699                #insert_code
700            }
701        }
702
703        pub struct UpdateQuery<'a> {
704            client: &'a DatabaseClient,
705            r#where: filter::#_where_unique,
706            data: data::#_update_input,
707        }
708
709        impl<'a> UpdateQuery<'a> {
710            pub async fn exec(self) -> Result<#model_ident, FerriormError> {
711                let client = self.client;
712                #update_code
713            }
714        }
715
716        pub struct DeleteQuery<'a> {
717            client: &'a DatabaseClient,
718            r#where: filter::#_where_unique,
719        }
720
721        impl<'a> DeleteQuery<'a> {
722            pub async fn exec(self) -> Result<#model_ident, FerriormError> {
723                match self.client {
724                    DatabaseClient::Postgres(_) => {
725                        let qb = build_delete_query::<sqlx::Postgres>(#delete_sql, &self.r#where);
726                        self.client.fetch_one_pg(qb).await
727                    }
728                    DatabaseClient::Sqlite(_) => {
729                        let qb = build_delete_query::<sqlx::Sqlite>(#delete_sql, &self.r#where);
730                        self.client.fetch_one_sqlite(qb).await
731                    }
732                }
733            }
734        }
735
736        #[derive(sqlx::FromRow)]
737        struct CountResult { count: i64 }
738
739        pub struct CountQuery<'a> {
740            client: &'a DatabaseClient,
741            r#where: filter::#_where_input,
742        }
743
744        impl<'a> CountQuery<'a> {
745            pub async fn exec(self) -> Result<i64, FerriormError> {
746                let row: CountResult = match self.client {
747                    DatabaseClient::Postgres(_) => {
748                        let qb = build_count_query::<sqlx::Postgres>(#count_sql, &self.r#where);
749                        self.client.fetch_one_pg(qb).await?
750                    }
751                    DatabaseClient::Sqlite(_) => {
752                        let qb = build_count_query::<sqlx::Sqlite>(#count_sql, &self.r#where);
753                        self.client.fetch_one_sqlite(qb).await?
754                    }
755                };
756                Ok(row.count)
757            }
758        }
759
760        pub struct CreateManyQuery<'a> {
761            client: &'a DatabaseClient,
762            data: Vec<data::#_create_input>,
763        }
764
765        impl<'a> CreateManyQuery<'a> {
766            pub async fn exec(self) -> Result<u64, FerriormError> {
767                if self.data.is_empty() { return Ok(0); }
768                let count = self.data.len() as u64;
769                for item in self.data {
770                    CreateQuery { client: self.client, data: item }.exec().await?;
771                }
772                Ok(count)
773            }
774        }
775
776        pub struct UpdateManyQuery<'a> {
777            client: &'a DatabaseClient,
778            r#where: filter::#_where_input,
779            data: data::#_update_input,
780        }
781
782        impl<'a> UpdateManyQuery<'a> {
783            pub async fn exec(self) -> Result<u64, FerriormError> {
784                let client = self.client;
785                #update_many_code
786            }
787        }
788
789        pub struct DeleteManyQuery<'a> {
790            client: &'a DatabaseClient,
791            r#where: filter::#_where_input,
792        }
793
794        impl<'a> DeleteManyQuery<'a> {
795            pub async fn exec(self) -> Result<u64, FerriormError> {
796                match self.client {
797                    DatabaseClient::Postgres(_) => {
798                        let qb = build_delete_many_query::<sqlx::Postgres>(#delete_sql, &self.r#where);
799                        self.client.execute_pg(qb).await
800                    }
801                    DatabaseClient::Sqlite(_) => {
802                        let qb = build_delete_many_query::<sqlx::Sqlite>(#delete_sql, &self.r#where);
803                        self.client.execute_sqlite(qb).await
804                    }
805                }
806            }
807        }
808    }
809}
810
811// ─── INSERT code generation ───────────────────────────────────
812
813fn gen_insert_code(model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
814    let _model_ident = format_ident!("{}", model.name);
815
816    // Required columns: scalar, no default, not @updatedAt
817    let required: Vec<&Field> = scalar_fields
818        .iter()
819        .copied()
820        .filter(|f| !f.has_default() && !f.is_updated_at)
821        .collect();
822
823    // Optional columns: have default (can be overridden), not @updatedAt
824    let optional: Vec<&Field> = scalar_fields
825        .iter()
826        .copied()
827        .filter(|f| f.has_default() && !f.is_updated_at)
828        .collect();
829
830    // @updatedAt columns: always set to now()
831    let updated_at: Vec<&Field> = scalar_fields
832        .iter()
833        .copied()
834        .filter(|f| f.is_updated_at)
835        .collect();
836
837    // Build column names and bind values
838    let mut col_pushes = vec![];
839    let mut val_pushes = vec![];
840
841    // Required fields — always included
842    for f in &required {
843        let db_name = &f.db_name;
844        let field_ident = format_ident!("{}", to_snake_case(&f.name));
845        col_pushes.push(quote! { cols.push(#db_name); });
846        val_pushes.push(quote! { sep.push_bind(self.data.#field_ident); });
847    }
848
849    // Optional fields — resolve defaults in Rust
850    for f in &optional {
851        let db_name = &f.db_name;
852        let field_ident = format_ident!("{}", to_snake_case(&f.name));
853        let default_expr = gen_default_expr(f);
854
855        col_pushes.push(quote! { cols.push(#db_name); });
856        val_pushes.push(quote! {
857            let val = self.data.#field_ident.unwrap_or_else(|| #default_expr);
858            sep.push_bind(val);
859        });
860    }
861
862    // @updatedAt fields
863    for f in &updated_at {
864        let db_name = &f.db_name;
865        col_pushes.push(quote! { cols.push(#db_name); });
866        val_pushes.push(quote! { sep.push_bind(chrono::Utc::now()); });
867    }
868
869    let insert_start = format!(r#"INSERT INTO "{}""#, table_name);
870
871    // The insert_body macro avoids duplicating the column/value building logic
872    // for each database backend. It captures `self` by reference.
873    quote! {
874        // Helper to build the INSERT query for any DB backend
875        macro_rules! build_insert {
876            ($qb_type:ty) => {{
877                let mut cols: Vec<&str> = Vec::new();
878                #(#col_pushes)*
879
880                let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#insert_start);
881                qb.push(" (");
882                for (i, col) in cols.iter().enumerate() {
883                    if i > 0 { qb.push(", "); }
884                    qb.push("\"");
885                    qb.push(*col);
886                    qb.push("\"");
887                }
888                qb.push(") VALUES (");
889                {
890                    let mut sep = qb.separated(", ");
891                    #(#val_pushes)*
892                }
893                qb.push(") RETURNING *");
894                qb
895            }};
896        }
897
898        match client {
899            DatabaseClient::Postgres(_) => {
900                let qb = build_insert!(sqlx::Postgres);
901                client.fetch_one_pg(qb).await
902            }
903            DatabaseClient::Sqlite(_) => {
904                let qb = build_insert!(sqlx::Sqlite);
905                client.fetch_one_sqlite(qb).await
906            }
907        }
908    }
909}
910
911/// Generate a Rust expression for a field's @default value.
912fn gen_default_expr(field: &Field) -> TokenStream {
913    use ferriorm_core::ast::DefaultValue;
914
915    match &field.default {
916        Some(DefaultValue::Uuid) => quote! { uuid::Uuid::new_v4().to_string() },
917        Some(DefaultValue::Cuid) => quote! { uuid::Uuid::new_v4().to_string() }, // fallback
918        Some(DefaultValue::Now) => quote! { chrono::Utc::now() },
919        Some(DefaultValue::AutoIncrement) => quote! { 0 }, // DB handles this
920        Some(DefaultValue::Literal(lit)) => {
921            use ferriorm_core::ast::LiteralValue;
922            match lit {
923                LiteralValue::String(s) => quote! { #s.to_string() },
924                LiteralValue::Int(i) => quote! { #i },
925                LiteralValue::Float(f) => quote! { #f },
926                LiteralValue::Bool(b) => quote! { #b },
927            }
928        }
929        Some(DefaultValue::EnumVariant(v)) => {
930            // Reference the enum variant — insert code runs at model module level
931            let variant = format_ident!("{}", v);
932            if let FieldKind::Enum(enum_name) = &field.field_type {
933                let enum_ident = format_ident!("{}", enum_name);
934                quote! { super::enums::#enum_ident::#variant }
935            } else {
936                quote! { Default::default() }
937            }
938        }
939        None => quote! { Default::default() },
940    }
941}
942
943// ─── UPDATE code generation ───────────────────────────────────
944
945fn gen_update_code(model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
946    let _model_ident = format_ident!("{}", model.name);
947
948    // Updatable fields: non-id, non-updatedAt scalar fields
949    let updatable: Vec<&Field> = scalar_fields
950        .iter()
951        .copied()
952        .filter(|f| !f.is_id && !f.is_updated_at)
953        .collect();
954
955    let updated_at: Vec<&Field> = scalar_fields
956        .iter()
957        .copied()
958        .filter(|f| f.is_updated_at)
959        .collect();
960
961    let update_start = format!(r#"UPDATE "{}" SET "#, table_name);
962
963    // Generate SET clause arms
964    let set_arms: Vec<TokenStream> = updatable
965        .iter()
966        .map(|f| {
967            let field_ident = format_ident!("{}", to_snake_case(&f.name));
968            let db_name = &f.db_name;
969            quote! {
970                if let Some(SetValue::Set(v)) = self.data.#field_ident {
971                    if !first_set { qb.push(", "); }
972                    first_set = false;
973                    qb.push(concat!("\"", #db_name, "\" = "));
974                    qb.push_bind(v);
975                }
976            }
977        })
978        .collect();
979
980    let updated_at_arms: Vec<TokenStream> = updated_at
981        .iter()
982        .map(|f| {
983            let db_name = &f.db_name;
984            quote! {
985                if !first_set { qb.push(", "); }
986                first_set = false;
987                qb.push(concat!("\"", #db_name, "\" = "));
988                qb.push_bind(chrono::Utc::now());
989            }
990        })
991        .collect();
992
993    // The build_update macro avoids duplicating the SET clause building logic
994    // for each database backend.
995    quote! {
996        macro_rules! build_update {
997            ($qb_type:ty) => {{
998                let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#update_start);
999                let mut first_set = true;
1000                #(#set_arms)*
1001                #(#updated_at_arms)*
1002
1003                if first_set {
1004                    return Err(FerriormError::Query("No fields to update".into()));
1005                }
1006
1007                qb.push(" WHERE 1=1");
1008                self.r#where.build_where(&mut qb);
1009                qb.push(" RETURNING *");
1010                qb
1011            }};
1012        }
1013
1014        match client {
1015            DatabaseClient::Postgres(_) => {
1016                let qb = build_update!(sqlx::Postgres);
1017                client.fetch_one_pg(qb).await
1018            }
1019            DatabaseClient::Sqlite(_) => {
1020                let qb = build_update!(sqlx::Sqlite);
1021                client.fetch_one_sqlite(qb).await
1022            }
1023        }
1024    }
1025}
1026
1027// ─── UPDATE MANY code generation ──────────────────────────────
1028
1029fn gen_update_many_code(_model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
1030    // Updatable fields: non-id, non-updatedAt scalar fields
1031    let updatable: Vec<&Field> = scalar_fields
1032        .iter()
1033        .copied()
1034        .filter(|f| !f.is_id && !f.is_updated_at)
1035        .collect();
1036
1037    let updated_at: Vec<&Field> = scalar_fields
1038        .iter()
1039        .copied()
1040        .filter(|f| f.is_updated_at)
1041        .collect();
1042
1043    let update_start = format!(r#"UPDATE "{}" SET "#, table_name);
1044
1045    // Generate SET clause arms
1046    let set_arms: Vec<TokenStream> = updatable
1047        .iter()
1048        .map(|f| {
1049            let field_ident = format_ident!("{}", to_snake_case(&f.name));
1050            let db_name = &f.db_name;
1051            quote! {
1052                if let Some(SetValue::Set(v)) = self.data.#field_ident {
1053                    if !first_set { qb.push(", "); }
1054                    first_set = false;
1055                    qb.push(concat!("\"", #db_name, "\" = "));
1056                    qb.push_bind(v);
1057                }
1058            }
1059        })
1060        .collect();
1061
1062    let updated_at_arms: Vec<TokenStream> = updated_at
1063        .iter()
1064        .map(|f| {
1065            let db_name = &f.db_name;
1066            quote! {
1067                if !first_set { qb.push(", "); }
1068                first_set = false;
1069                qb.push(concat!("\"", #db_name, "\" = "));
1070                qb.push_bind(chrono::Utc::now());
1071            }
1072        })
1073        .collect();
1074
1075    quote! {
1076        macro_rules! build_update_many {
1077            ($qb_type:ty) => {{
1078                let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#update_start);
1079                let mut first_set = true;
1080                #(#set_arms)*
1081                #(#updated_at_arms)*
1082
1083                if first_set {
1084                    return Ok(0);
1085                }
1086
1087                qb.push(" WHERE 1=1");
1088                self.r#where.build_where(&mut qb);
1089                qb
1090            }};
1091        }
1092
1093        match client {
1094            DatabaseClient::Postgres(_) => {
1095                let qb = build_update_many!(sqlx::Postgres);
1096                client.execute_pg(qb).await
1097            }
1098            DatabaseClient::Sqlite(_) => {
1099                let qb = build_update_many!(sqlx::Sqlite);
1100                client.execute_sqlite(qb).await
1101            }
1102        }
1103    }
1104}