1use ferriorm_core::schema::{Field, FieldKind, Model};
15use ferriorm_core::types::ScalarType;
16use ferriorm_core::utils::{to_pascal_case, to_snake_case};
17use proc_macro2::TokenStream;
18use quote::{format_ident, quote};
19
20use crate::rust_type::{ModuleDepth, filter_type_tokens, rust_type_tokens};
21
22#[must_use]
24pub fn generate_model_module(model: &Model) -> TokenStream {
25 let scalar_fields: Vec<&Field> = model.fields.iter().filter(|f| f.is_scalar()).collect();
26
27 let data_struct = gen_data_struct(model, &scalar_fields);
28 let filter_module = gen_filter_module(model, &scalar_fields);
29 let data_module = gen_data_module(model, &scalar_fields);
30 let order_module = gen_order_module(model, &scalar_fields);
31 let actions_struct = gen_actions(model, &scalar_fields);
32 let query_builders = gen_query_builders(model, &scalar_fields);
33 let aggregate_types = gen_aggregate_types(model, &scalar_fields);
34 let select_types = gen_select_types(model, &scalar_fields);
35
36 quote! {
37 #![allow(unused_imports, dead_code, unused_variables, clippy::all, clippy::pedantic, clippy::nursery)]
38
39 use serde::{Deserialize, Serialize};
40 use ferriorm_runtime::prelude::*;
41 use ferriorm_runtime::prelude::sqlx;
42 use ferriorm_runtime::prelude::chrono;
43 use ferriorm_runtime::prelude::uuid;
44
45 #data_struct
46 #filter_module
47 #data_module
48 #order_module
49 #actions_struct
50 #query_builders
51 #aggregate_types
52 #select_types
53 }
54}
55
56fn gen_data_struct(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
59 let struct_name = format_ident!("{}", model.name);
60 let table_name = &model.db_name;
61
62 let fields: Vec<TokenStream> = scalar_fields
63 .iter()
64 .map(|f| {
65 let name = format_ident!("{}", to_snake_case(&f.name));
66 let ty = rust_type_tokens(f, ModuleDepth::TopLevel);
67 let db_name = &f.db_name;
68 if db_name == &to_snake_case(&f.name) {
69 quote! { pub #name: #ty }
70 } else {
71 quote! { #[sqlx(rename = #db_name)] pub #name: #ty }
72 }
73 })
74 .collect();
75
76 quote! {
77 #[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
78 #[sqlx(rename_all = "snake_case")]
79 pub struct #struct_name {
80 #(#fields),*
81 }
82
83 impl #struct_name {
84 pub const TABLE_NAME: &'static str = #table_name;
85 }
86 }
87}
88
89fn gen_filter_module(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
92 let where_input = format_ident!("{}WhereInput", model.name);
93 let where_unique = format_ident!("{}WhereUniqueInput", model.name);
94
95 let where_fields: Vec<TokenStream> = scalar_fields
96 .iter()
97 .filter_map(|f| {
98 let filter_ty = filter_type_tokens(f, ModuleDepth::Nested)?;
99 let name = format_ident!("{}", to_snake_case(&f.name));
100 Some(quote! { pub #name: Option<#filter_ty> })
101 })
102 .collect();
103
104 let unique_variants: Vec<TokenStream> = scalar_fields
105 .iter()
106 .filter(|f| f.is_id || f.is_unique)
107 .map(|f| {
108 let variant = format_ident!("{}", to_pascal_case(&f.name));
109 let ty = rust_type_tokens(f, ModuleDepth::Nested);
110 quote! { #variant(#ty) }
111 })
112 .collect();
113
114 let db_bounds = collect_db_bounds(scalar_fields);
116 let where_arms = gen_where_arms(scalar_fields);
117 let unique_arms = gen_unique_where_arms(scalar_fields);
118
119 quote! {
120 pub mod filter {
121 use ferriorm_runtime::prelude::*;
122
123 #[derive(Debug, Clone, Default)]
124 pub struct #where_input {
125 #(#where_fields,)*
126 pub and: Option<Vec<#where_input>>,
127 pub or: Option<Vec<#where_input>>,
128 pub not: Option<Box<#where_input>>,
129 }
130
131 #[derive(Debug, Clone)]
132 pub enum #where_unique {
133 #(#unique_variants),*
134 }
135
136 impl #where_input {
137 pub(crate) fn build_where<'args, DB: sqlx::Database>(
138 &self,
139 qb: &mut sqlx::QueryBuilder<'args, DB>,
140 )
141 where
142 #(#db_bounds,)*
143 {
144 #(#where_arms)*
145
146 if let Some(conditions) = &self.and {
147 for c in conditions {
148 c.build_where(qb);
149 }
150 }
151 if let Some(conditions) = &self.or {
152 if !conditions.is_empty() {
153 qb.push(" AND (");
154 for (i, c) in conditions.iter().enumerate() {
155 if i > 0 { qb.push(" OR "); }
156 qb.push("(1=1");
157 c.build_where(qb);
158 qb.push(")");
159 }
160 qb.push(")");
161 }
162 }
163 if let Some(c) = &self.not {
164 qb.push(" AND NOT (1=1");
165 c.build_where(qb);
166 qb.push(")");
167 }
168 }
169 }
170
171 impl #where_unique {
172 pub(crate) fn build_where<'args, DB: sqlx::Database>(
173 &self,
174 qb: &mut sqlx::QueryBuilder<'args, DB>,
175 )
176 where
177 #(#db_bounds,)*
178 {
179 match self {
180 #(#unique_arms)*
181 }
182 }
183 }
184 }
185 }
186}
187
188fn collect_db_bounds(scalar_fields: &[&Field]) -> Vec<TokenStream> {
190 let mut seen = std::collections::HashSet::new();
191 let mut bounds = Vec::new();
192
193 seen.insert("i64");
195 bounds.push(quote! { i64: sqlx::Type<DB> + for<'e> sqlx::Encode<'e, DB> });
196
197 for f in scalar_fields {
198 match &f.field_type {
199 FieldKind::Scalar(scalar) => {
200 let key = scalar.rust_type();
201 if seen.insert(key)
202 && let Some(ty) = scalar_bound_tokens(scalar)
203 {
204 bounds.push(quote! { #ty: sqlx::Type<DB> + for<'e> sqlx::Encode<'e, DB> });
205 bounds.push(
207 quote! { Option<#ty>: sqlx::Type<DB> + for<'e> sqlx::Encode<'e, DB> },
208 );
209 }
210 }
211 FieldKind::Enum(_) | FieldKind::Model(_) => {}
212 }
213 }
214
215 bounds
216}
217
218fn scalar_bound_tokens(scalar: &ScalarType) -> Option<TokenStream> {
219 match scalar {
220 ScalarType::String => Some(quote! { String }),
221 ScalarType::Int => Some(quote! { i32 }),
222 ScalarType::BigInt => Some(quote! { i64 }),
223 ScalarType::Float => Some(quote! { f64 }),
224 ScalarType::Boolean => Some(quote! { bool }),
225 ScalarType::DateTime => Some(quote! { chrono::DateTime<chrono::Utc> }),
226 ScalarType::Bytes => Some(quote! { Vec<u8> }),
227 ScalarType::Json | ScalarType::Decimal => None,
228 }
229}
230
231fn gen_where_arms(scalar_fields: &[&Field]) -> Vec<TokenStream> {
233 scalar_fields
234 .iter()
235 .filter_map(|f| {
236 if !matches!(&f.field_type, FieldKind::Scalar(_)) {
238 return None;
239 }
240 let field_ident = format_ident!("{}", to_snake_case(&f.name));
241 let db_name = &f.db_name;
242 let is_string = matches!(&f.field_type, FieldKind::Scalar(ScalarType::String));
243 let is_comparable = matches!(
244 &f.field_type,
245 FieldKind::Scalar(
246 ScalarType::Int | ScalarType::BigInt | ScalarType::Float | ScalarType::DateTime
247 )
248 );
249
250 let mut arms = vec![];
251
252 arms.push(quote! {
253 if let Some(v) = &filter.equals {
254 qb.push(concat!(" AND \"", #db_name, "\" = "));
255 qb.push_bind(v.clone());
256 }
257 if let Some(v) = &filter.not {
258 qb.push(concat!(" AND \"", #db_name, "\" != "));
259 qb.push_bind(v.clone());
260 }
261 });
262
263 if is_string {
264 arms.push(quote! {
265 if let Some(v) = &filter.contains {
266 qb.push(concat!(" AND \"", #db_name, "\" LIKE "));
267 qb.push_bind(format!("%{}%", v));
268 }
269 if let Some(v) = &filter.starts_with {
270 qb.push(concat!(" AND \"", #db_name, "\" LIKE "));
271 qb.push_bind(format!("{}%", v));
272 }
273 if let Some(v) = &filter.ends_with {
274 qb.push(concat!(" AND \"", #db_name, "\" LIKE "));
275 qb.push_bind(format!("%{}", v));
276 }
277 });
278 }
279
280 if is_comparable {
281 arms.push(quote! {
282 if let Some(v) = &filter.gt {
283 qb.push(concat!(" AND \"", #db_name, "\" > "));
284 qb.push_bind(v.clone());
285 }
286 if let Some(v) = &filter.gte {
287 qb.push(concat!(" AND \"", #db_name, "\" >= "));
288 qb.push_bind(v.clone());
289 }
290 if let Some(v) = &filter.lt {
291 qb.push(concat!(" AND \"", #db_name, "\" < "));
292 qb.push_bind(v.clone());
293 }
294 if let Some(v) = &filter.lte {
295 qb.push(concat!(" AND \"", #db_name, "\" <= "));
296 qb.push_bind(v.clone());
297 }
298 });
299 }
300
301 Some(quote! {
302 if let Some(filter) = &self.#field_ident {
303 #(#arms)*
304 }
305 })
306 })
307 .collect()
308}
309
310fn gen_unique_where_arms(scalar_fields: &[&Field]) -> Vec<TokenStream> {
311 let _where_unique = format_ident!(
312 "{}WhereUniqueInput",
313 "" );
315 scalar_fields
316 .iter()
317 .filter(|f| f.is_id || f.is_unique)
318 .map(|f| {
319 let variant = format_ident!("{}", to_pascal_case(&f.name));
320 let db_name = &f.db_name;
321 quote! {
322 Self::#variant(v) => {
323 qb.push(concat!(" AND \"", #db_name, "\" = "));
324 qb.push_bind(v.clone());
325 }
326 }
327 })
328 .collect()
329}
330
331fn gen_data_module(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
334 let create_name = format_ident!("{}CreateInput", model.name);
335 let update_name = format_ident!("{}UpdateInput", model.name);
336
337 let required_fields: Vec<TokenStream> = scalar_fields
338 .iter()
339 .filter(|f| !f.has_default() && !f.is_updated_at)
340 .map(|f| {
341 let name = format_ident!("{}", to_snake_case(&f.name));
342 let ty = rust_type_tokens(f, ModuleDepth::Nested);
343 quote! { pub #name: #ty }
344 })
345 .collect();
346
347 let optional_fields: Vec<TokenStream> = scalar_fields
348 .iter()
349 .filter(|f| f.has_default() && !f.is_updated_at)
350 .map(|f| {
351 let name = format_ident!("{}", to_snake_case(&f.name));
352 let base_ty = rust_type_tokens(f, ModuleDepth::Nested);
353 quote! { pub #name: Option<#base_ty> }
354 })
355 .collect();
356
357 let update_fields: Vec<TokenStream> = scalar_fields
358 .iter()
359 .filter(|f| !f.is_id && !f.is_updated_at)
360 .map(|f| {
361 let name = format_ident!("{}", to_snake_case(&f.name));
362 let ty = rust_type_tokens(f, ModuleDepth::Nested);
363 quote! { pub #name: Option<SetValue<#ty>> }
364 })
365 .collect();
366
367 quote! {
368 pub mod data {
369 use ferriorm_runtime::prelude::*;
370
371 #[derive(Debug, Clone)]
372 pub struct #create_name {
373 #(#required_fields,)*
374 #(#optional_fields,)*
375 }
376
377 #[derive(Debug, Clone, Default)]
378 pub struct #update_name {
379 #(#update_fields,)*
380 }
381 }
382 }
383}
384
385fn gen_order_module(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
388 let order_name = format_ident!("{}OrderByInput", model.name);
389
390 let variants: Vec<TokenStream> = scalar_fields
391 .iter()
392 .map(|f| {
393 let variant = format_ident!("{}", to_pascal_case(&f.name));
394 quote! { #variant(SortOrder) }
395 })
396 .collect();
397
398 let order_arms: Vec<TokenStream> = scalar_fields
399 .iter()
400 .map(|f| {
401 let variant = format_ident!("{}", to_pascal_case(&f.name));
402 let db_name = &f.db_name;
403 quote! {
404 Self::#variant(order) => {
405 qb.push(concat!("\"", #db_name, "\" "));
406 qb.push(order.as_sql());
407 }
408 }
409 })
410 .collect();
411
412 quote! {
413 pub mod order {
414 use ferriorm_runtime::prelude::*;
415
416 #[derive(Debug, Clone)]
417 pub enum #order_name {
418 #(#variants),*
419 }
420
421 impl #order_name {
422 pub(crate) fn build_order_by<'args, DB: sqlx::Database>(
423 &self,
424 qb: &mut sqlx::QueryBuilder<'args, DB>,
425 ) {
426 match self {
427 #(#order_arms)*
428 }
429 }
430 }
431 }
432 }
433}
434
435fn gen_actions(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
438 let _model_ident = format_ident!("{}", model.name);
439 let actions_name = format_ident!("{}Actions", model.name);
440 let where_input = format_ident!("{}WhereInput", model.name);
441 let where_unique = format_ident!("{}WhereUniqueInput", model.name);
442 let create_input = format_ident!("{}CreateInput", model.name);
443 let update_input = format_ident!("{}UpdateInput", model.name);
444 let _order_by = format_ident!("{}OrderByInput", model.name);
445
446 let has_agg_fields = scalar_fields.iter().any(|f| {
448 matches!(
449 &f.field_type,
450 FieldKind::Scalar(
451 ScalarType::Int | ScalarType::BigInt | ScalarType::Float | ScalarType::DateTime
452 )
453 )
454 });
455 let aggregate_method = if has_agg_fields {
456 quote! {
457 pub fn aggregate(&self, r#where: filter::#where_input) -> AggregateQuery<'a> {
458 AggregateQuery { client: self.client, r#where, ops: vec![] }
459 }
460 }
461 } else {
462 quote! {}
463 };
464
465 quote! {
466 pub struct #actions_name<'a> {
467 client: &'a DatabaseClient,
468 }
469
470 impl<'a> #actions_name<'a> {
471 pub fn new(client: &'a DatabaseClient) -> Self { Self { client } }
472
473 pub fn find_unique(&self, r#where: filter::#where_unique) -> FindUniqueQuery<'a> {
474 FindUniqueQuery { client: self.client, r#where }
475 }
476
477 pub fn find_first(&self, r#where: filter::#where_input) -> FindFirstQuery<'a> {
478 FindFirstQuery { client: self.client, r#where, order_by: vec![] }
479 }
480
481 pub fn find_many(&self, r#where: filter::#where_input) -> FindManyQuery<'a> {
482 FindManyQuery { client: self.client, r#where, order_by: vec![], skip: None, take: None }
483 }
484
485 pub fn create(&self, data: data::#create_input) -> CreateQuery<'a> {
486 CreateQuery { client: self.client, data }
487 }
488
489 pub fn update(&self, r#where: filter::#where_unique, data: data::#update_input) -> UpdateQuery<'a> {
490 UpdateQuery { client: self.client, r#where, data }
491 }
492
493 pub fn delete(&self, r#where: filter::#where_unique) -> DeleteQuery<'a> {
494 DeleteQuery { client: self.client, r#where }
495 }
496
497 pub fn count(&self, r#where: filter::#where_input) -> CountQuery<'a> {
498 CountQuery { client: self.client, r#where }
499 }
500
501 pub fn create_many(&self, data: Vec<data::#create_input>) -> CreateManyQuery<'a> {
502 CreateManyQuery { client: self.client, data }
503 }
504
505 pub fn update_many(&self, r#where: filter::#where_input, data: data::#update_input) -> UpdateManyQuery<'a> {
506 UpdateManyQuery { client: self.client, r#where, data }
507 }
508
509 pub fn delete_many(&self, r#where: filter::#where_input) -> DeleteManyQuery<'a> {
510 DeleteManyQuery { client: self.client, r#where }
511 }
512
513 pub fn upsert(
514 &self,
515 r#where: filter::#where_unique,
516 create: data::#create_input,
517 update: data::#update_input,
518 ) -> UpsertQuery<'a> {
519 UpsertQuery { client: self.client, r#where, create, update }
520 }
521
522 #aggregate_method
523 }
524 }
525}
526
527#[allow(clippy::too_many_lines)]
530fn gen_query_builders(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
531 let model_ident = format_ident!("{}", model.name);
532 let table_name = &model.db_name;
533 let _where_input = format_ident!("{}WhereInput", model.name);
534 let _where_unique = format_ident!("{}WhereUniqueInput", model.name);
535 let _create_input = format_ident!("{}CreateInput", model.name);
536 let _update_input = format_ident!("{}UpdateInput", model.name);
537 let order_by = format_ident!("{}OrderByInput", model.name);
538 let _select_struct = format_ident!("{}Select", model.name);
539 let _partial_struct = format_ident!("{}Partial", model.name);
540 let _aggregate_result = format_ident!("{}AggregateResult", model.name);
541 let _aggregate_field = format_ident!("{}AggregateField", model.name);
542 let db_bounds = collect_db_bounds(scalar_fields);
543
544 let select_sql = format!(r#"SELECT * FROM "{table_name}" WHERE 1=1"#);
545 let count_sql = format!(r#"SELECT COUNT(*) as "count" FROM "{table_name}" WHERE 1=1"#);
546 let delete_sql = format!(r#"DELETE FROM "{table_name}" WHERE 1=1"#);
547
548 let insert_code = gen_insert_code(model, scalar_fields, table_name);
549 let update_code = gen_update_code(model, scalar_fields, table_name);
550 let update_many_code = gen_update_many_code(model, scalar_fields, table_name);
551 let upsert_code = gen_upsert_code(model, scalar_fields, table_name);
552
553 quote! {
554 fn build_order_by<'args, DB: sqlx::Database>(
556 orders: &[order::#order_by],
557 qb: &mut sqlx::QueryBuilder<'args, DB>,
558 ) {
559 if !orders.is_empty() {
560 qb.push(" ORDER BY ");
561 for (i, ob) in orders.iter().enumerate() {
562 if i > 0 { qb.push(", "); }
563 ob.build_order_by(qb);
564 }
565 }
566 }
567
568 fn build_select_query<'args, DB: sqlx::Database>(
570 base_sql: &str,
571 where_input: &filter::#_where_input,
572 orders: &[order::#order_by],
573 take: Option<i64>,
574 skip: Option<i64>,
575 ) -> sqlx::QueryBuilder<'args, DB>
576 where
577 #(#db_bounds,)*
578 {
579 let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
580 where_input.build_where(&mut qb);
581 build_order_by(orders, &mut qb);
582 if let Some(take) = take {
583 qb.push(" LIMIT ");
584 qb.push_bind(take);
585 }
586 if let Some(skip) = skip {
587 qb.push(" OFFSET ");
588 qb.push_bind(skip);
589 }
590 qb
591 }
592
593 fn build_unique_select_query<'args, DB: sqlx::Database>(
595 base_sql: &str,
596 where_unique: &filter::#_where_unique,
597 ) -> sqlx::QueryBuilder<'args, DB>
598 where
599 #(#db_bounds,)*
600 {
601 let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
602 where_unique.build_where(&mut qb);
603 qb.push(" LIMIT 1");
604 qb
605 }
606
607 fn build_delete_query<'args, DB: sqlx::Database>(
609 base_sql: &str,
610 where_unique: &filter::#_where_unique,
611 ) -> sqlx::QueryBuilder<'args, DB>
612 where
613 #(#db_bounds,)*
614 {
615 let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
616 where_unique.build_where(&mut qb);
617 qb.push(" RETURNING *");
618 qb
619 }
620
621 fn build_count_query<'args, DB: sqlx::Database>(
623 base_sql: &str,
624 where_input: &filter::#_where_input,
625 ) -> sqlx::QueryBuilder<'args, DB>
626 where
627 #(#db_bounds,)*
628 {
629 let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
630 where_input.build_where(&mut qb);
631 qb
632 }
633
634 fn build_delete_many_query<'args, DB: sqlx::Database>(
636 base_sql: &str,
637 where_input: &filter::#_where_input,
638 ) -> sqlx::QueryBuilder<'args, DB>
639 where
640 #(#db_bounds,)*
641 {
642 let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
643 where_input.build_where(&mut qb);
644 qb
645 }
646
647 pub struct FindUniqueQuery<'a> {
648 client: &'a DatabaseClient,
649 r#where: filter::#_where_unique,
650 }
651
652 impl<'a> FindUniqueQuery<'a> {
653 pub fn select(self, select: #_select_struct) -> FindUniqueSelectQuery<'a> {
654 FindUniqueSelectQuery { client: self.client, r#where: self.r#where, select }
655 }
656
657 pub async fn exec(self) -> Result<Option<#model_ident>, FerriormError> {
658 match self.client {
659 DatabaseClient::Postgres(_) => {
660 let qb = build_unique_select_query::<sqlx::Postgres>(#select_sql, &self.r#where);
661 self.client.fetch_optional_pg(qb).await
662 }
663 DatabaseClient::Sqlite(_) => {
664 let qb = build_unique_select_query::<sqlx::Sqlite>(#select_sql, &self.r#where);
665 self.client.fetch_optional_sqlite(qb).await
666 }
667 }
668 }
669 }
670
671 pub struct FindFirstQuery<'a> {
672 client: &'a DatabaseClient,
673 r#where: filter::#_where_input,
674 order_by: Vec<order::#order_by>,
675 }
676
677 impl<'a> FindFirstQuery<'a> {
678 pub fn order_by(mut self, order: order::#order_by) -> Self {
679 self.order_by.push(order);
680 self
681 }
682
683 pub fn select(self, select: #_select_struct) -> FindFirstSelectQuery<'a> {
684 FindFirstSelectQuery { client: self.client, r#where: self.r#where, order_by: self.order_by, select }
685 }
686
687 pub async fn exec(self) -> Result<Option<#model_ident>, FerriormError> {
688 match self.client {
689 DatabaseClient::Postgres(_) => {
690 let qb = build_select_query::<sqlx::Postgres>(#select_sql, &self.r#where, &self.order_by, Some(1), None);
691 self.client.fetch_optional_pg(qb).await
692 }
693 DatabaseClient::Sqlite(_) => {
694 let qb = build_select_query::<sqlx::Sqlite>(#select_sql, &self.r#where, &self.order_by, Some(1), None);
695 self.client.fetch_optional_sqlite(qb).await
696 }
697 }
698 }
699 }
700
701 pub struct FindManyQuery<'a> {
702 client: &'a DatabaseClient,
703 r#where: filter::#_where_input,
704 order_by: Vec<order::#order_by>,
705 skip: Option<i64>,
706 take: Option<i64>,
707 }
708
709 impl<'a> FindManyQuery<'a> {
710 pub fn order_by(mut self, order: order::#order_by) -> Self {
711 self.order_by.push(order);
712 self
713 }
714
715 pub fn skip(mut self, n: i64) -> Self {
716 self.skip = Some(n);
717 self
718 }
719
720 pub fn take(mut self, n: i64) -> Self {
721 self.take = Some(n);
722 self
723 }
724
725 pub fn select(self, select: #_select_struct) -> FindManySelectQuery<'a> {
726 FindManySelectQuery {
727 client: self.client,
728 r#where: self.r#where,
729 order_by: self.order_by,
730 skip: self.skip,
731 take: self.take,
732 select,
733 }
734 }
735
736 pub async fn exec(self) -> Result<Vec<#model_ident>, FerriormError> {
737 match self.client {
738 DatabaseClient::Postgres(_) => {
739 let qb = build_select_query::<sqlx::Postgres>(#select_sql, &self.r#where, &self.order_by, self.take, self.skip);
740 self.client.fetch_all_pg(qb).await
741 }
742 DatabaseClient::Sqlite(_) => {
743 let qb = build_select_query::<sqlx::Sqlite>(#select_sql, &self.r#where, &self.order_by, self.take, self.skip);
744 self.client.fetch_all_sqlite(qb).await
745 }
746 }
747 }
748 }
749
750 pub struct CreateQuery<'a> {
751 client: &'a DatabaseClient,
752 data: data::#_create_input,
753 }
754
755 impl<'a> CreateQuery<'a> {
756 pub async fn exec(self) -> Result<#model_ident, FerriormError> {
757 let client = self.client;
758 #insert_code
759 }
760 }
761
762 pub struct UpdateQuery<'a> {
763 client: &'a DatabaseClient,
764 r#where: filter::#_where_unique,
765 data: data::#_update_input,
766 }
767
768 impl<'a> UpdateQuery<'a> {
769 pub async fn exec(self) -> Result<#model_ident, FerriormError> {
770 let client = self.client;
771 #update_code
772 }
773 }
774
775 pub struct DeleteQuery<'a> {
776 client: &'a DatabaseClient,
777 r#where: filter::#_where_unique,
778 }
779
780 impl<'a> DeleteQuery<'a> {
781 pub async fn exec(self) -> Result<#model_ident, FerriormError> {
782 match self.client {
783 DatabaseClient::Postgres(_) => {
784 let qb = build_delete_query::<sqlx::Postgres>(#delete_sql, &self.r#where);
785 self.client.fetch_one_pg(qb).await
786 }
787 DatabaseClient::Sqlite(_) => {
788 let qb = build_delete_query::<sqlx::Sqlite>(#delete_sql, &self.r#where);
789 self.client.fetch_one_sqlite(qb).await
790 }
791 }
792 }
793 }
794
795 #[derive(sqlx::FromRow)]
796 struct CountResult { count: i64 }
797
798 pub struct CountQuery<'a> {
799 client: &'a DatabaseClient,
800 r#where: filter::#_where_input,
801 }
802
803 impl<'a> CountQuery<'a> {
804 pub async fn exec(self) -> Result<i64, FerriormError> {
805 let row: CountResult = match self.client {
806 DatabaseClient::Postgres(_) => {
807 let qb = build_count_query::<sqlx::Postgres>(#count_sql, &self.r#where);
808 self.client.fetch_one_pg(qb).await?
809 }
810 DatabaseClient::Sqlite(_) => {
811 let qb = build_count_query::<sqlx::Sqlite>(#count_sql, &self.r#where);
812 self.client.fetch_one_sqlite(qb).await?
813 }
814 };
815 Ok(row.count)
816 }
817 }
818
819 pub struct CreateManyQuery<'a> {
820 client: &'a DatabaseClient,
821 data: Vec<data::#_create_input>,
822 }
823
824 impl<'a> CreateManyQuery<'a> {
825 pub async fn exec(self) -> Result<u64, FerriormError> {
826 if self.data.is_empty() { return Ok(0); }
827 let count = self.data.len() as u64;
828 for item in self.data {
829 CreateQuery { client: self.client, data: item }.exec().await?;
830 }
831 Ok(count)
832 }
833 }
834
835 pub struct UpdateManyQuery<'a> {
836 client: &'a DatabaseClient,
837 r#where: filter::#_where_input,
838 data: data::#_update_input,
839 }
840
841 impl<'a> UpdateManyQuery<'a> {
842 pub async fn exec(self) -> Result<u64, FerriormError> {
843 let client = self.client;
844 #update_many_code
845 }
846 }
847
848 pub struct UpsertQuery<'a> {
849 client: &'a DatabaseClient,
850 r#where: filter::#_where_unique,
851 create: data::#_create_input,
852 update: data::#_update_input,
853 }
854
855 impl<'a> UpsertQuery<'a> {
856 pub async fn exec(self) -> Result<#model_ident, FerriormError> {
857 let client = self.client;
858 #upsert_code
859 }
860 }
861
862 pub struct DeleteManyQuery<'a> {
863 client: &'a DatabaseClient,
864 r#where: filter::#_where_input,
865 }
866
867 impl<'a> DeleteManyQuery<'a> {
868 pub async fn exec(self) -> Result<u64, FerriormError> {
869 match self.client {
870 DatabaseClient::Postgres(_) => {
871 let qb = build_delete_many_query::<sqlx::Postgres>(#delete_sql, &self.r#where);
872 self.client.execute_pg(qb).await
873 }
874 DatabaseClient::Sqlite(_) => {
875 let qb = build_delete_many_query::<sqlx::Sqlite>(#delete_sql, &self.r#where);
876 self.client.execute_sqlite(qb).await
877 }
878 }
879 }
880 }
881 }
882}
883
884fn gen_insert_code(model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
887 let _model_ident = format_ident!("{}", model.name);
888
889 let required: Vec<&Field> = scalar_fields
891 .iter()
892 .copied()
893 .filter(|f| !f.has_default() && !f.is_updated_at)
894 .collect();
895
896 let optional: Vec<&Field> = scalar_fields
898 .iter()
899 .copied()
900 .filter(|f| f.has_default() && !f.is_updated_at)
901 .collect();
902
903 let updated_at: Vec<&Field> = scalar_fields
905 .iter()
906 .copied()
907 .filter(|f| f.is_updated_at)
908 .collect();
909
910 let mut col_pushes = vec![];
912 let mut val_pushes = vec![];
913
914 for f in &required {
916 let db_name = &f.db_name;
917 let field_ident = format_ident!("{}", to_snake_case(&f.name));
918 col_pushes.push(quote! { cols.push(#db_name); });
919 val_pushes.push(quote! { sep.push_bind(self.data.#field_ident); });
920 }
921
922 for f in &optional {
924 let db_name = &f.db_name;
925 let field_ident = format_ident!("{}", to_snake_case(&f.name));
926 let default_expr = gen_default_expr(f, &f.field_type);
927
928 col_pushes.push(quote! { cols.push(#db_name); });
929 val_pushes.push(quote! {
930 let val = self.data.#field_ident.unwrap_or_else(|| #default_expr);
931 sep.push_bind(val);
932 });
933 }
934
935 for f in &updated_at {
937 let db_name = &f.db_name;
938 col_pushes.push(quote! { cols.push(#db_name); });
939 val_pushes.push(quote! { sep.push_bind(chrono::Utc::now()); });
940 }
941
942 let insert_start = format!(r#"INSERT INTO "{table_name}""#);
943
944 quote! {
947 macro_rules! build_insert {
949 ($qb_type:ty) => {{
950 let mut cols: Vec<&str> = Vec::new();
951 #(#col_pushes)*
952
953 let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#insert_start);
954 qb.push(" (");
955 for (i, col) in cols.iter().enumerate() {
956 if i > 0 { qb.push(", "); }
957 qb.push("\"");
958 qb.push(*col);
959 qb.push("\"");
960 }
961 qb.push(") VALUES (");
962 {
963 let mut sep = qb.separated(", ");
964 #(#val_pushes)*
965 }
966 qb.push(") RETURNING *");
967 qb
968 }};
969 }
970
971 match client {
972 DatabaseClient::Postgres(_) => {
973 let qb = build_insert!(sqlx::Postgres);
974 client.fetch_one_pg(qb).await
975 }
976 DatabaseClient::Sqlite(_) => {
977 let qb = build_insert!(sqlx::Sqlite);
978 client.fetch_one_sqlite(qb).await
979 }
980 }
981 }
982}
983
984fn gen_default_expr(field: &Field, field_type: &FieldKind) -> TokenStream {
986 use ferriorm_core::ast::DefaultValue;
987
988 match &field.default {
989 Some(DefaultValue::Uuid | DefaultValue::Cuid) => {
990 quote! { uuid::Uuid::new_v4().to_string() }
991 }
992 Some(DefaultValue::Now) => quote! { chrono::Utc::now() },
993 Some(DefaultValue::AutoIncrement) => quote! { 0i32 }, Some(DefaultValue::Literal(lit)) => {
995 use ferriorm_core::ast::LiteralValue;
996 match lit {
997 LiteralValue::String(s) => quote! { #s.to_string() },
998 LiteralValue::Int(i) => {
999 #[allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
1001 match field_type {
1002 FieldKind::Scalar(ScalarType::Float) => {
1003 let val = *i as f64;
1004 quote! { #val }
1005 }
1006 FieldKind::Scalar(ScalarType::BigInt) => quote! { #i },
1007 _ => {
1008 let val = *i as i32;
1010 quote! { #val }
1011 }
1012 }
1013 }
1014 LiteralValue::Float(f) => quote! { #f },
1015 LiteralValue::Bool(b) => quote! { #b },
1016 }
1017 }
1018 Some(DefaultValue::EnumVariant(v)) => {
1019 let variant = format_ident!("{}", v);
1021 if let FieldKind::Enum(enum_name) = &field.field_type {
1022 let enum_ident = format_ident!("{}", enum_name);
1023 quote! { super::enums::#enum_ident::#variant }
1024 } else {
1025 quote! { Default::default() }
1026 }
1027 }
1028 None => quote! { Default::default() },
1029 }
1030}
1031
1032fn gen_update_code(model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
1035 let _model_ident = format_ident!("{}", model.name);
1036
1037 let updatable: Vec<&Field> = scalar_fields
1039 .iter()
1040 .copied()
1041 .filter(|f| !f.is_id && !f.is_updated_at)
1042 .collect();
1043
1044 let updated_at: Vec<&Field> = scalar_fields
1045 .iter()
1046 .copied()
1047 .filter(|f| f.is_updated_at)
1048 .collect();
1049
1050 let update_start = format!(r#"UPDATE "{table_name}" SET "#);
1051
1052 let set_arms: Vec<TokenStream> = updatable
1054 .iter()
1055 .map(|f| {
1056 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1057 let db_name = &f.db_name;
1058 quote! {
1059 if let Some(SetValue::Set(v)) = self.data.#field_ident {
1060 if !first_set { qb.push(", "); }
1061 first_set = false;
1062 qb.push(concat!("\"", #db_name, "\" = "));
1063 qb.push_bind(v);
1064 }
1065 }
1066 })
1067 .collect();
1068
1069 let updated_at_arms: Vec<TokenStream> = updated_at
1070 .iter()
1071 .map(|f| {
1072 let db_name = &f.db_name;
1073 quote! {
1074 if !first_set { qb.push(", "); }
1075 first_set = false;
1076 qb.push(concat!("\"", #db_name, "\" = "));
1077 qb.push_bind(chrono::Utc::now());
1078 }
1079 })
1080 .collect();
1081
1082 quote! {
1085 macro_rules! build_update {
1086 ($qb_type:ty) => {{
1087 let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#update_start);
1088 let mut first_set = true;
1089 #(#set_arms)*
1090 #(#updated_at_arms)*
1091
1092 if first_set {
1093 return Err(FerriormError::Query("No fields to update".into()));
1094 }
1095
1096 qb.push(" WHERE 1=1");
1097 self.r#where.build_where(&mut qb);
1098 qb.push(" RETURNING *");
1099 qb
1100 }};
1101 }
1102
1103 match client {
1104 DatabaseClient::Postgres(_) => {
1105 let qb = build_update!(sqlx::Postgres);
1106 client.fetch_one_pg(qb).await
1107 }
1108 DatabaseClient::Sqlite(_) => {
1109 let qb = build_update!(sqlx::Sqlite);
1110 client.fetch_one_sqlite(qb).await
1111 }
1112 }
1113 }
1114}
1115
1116fn gen_update_many_code(_model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
1119 let updatable: Vec<&Field> = scalar_fields
1121 .iter()
1122 .copied()
1123 .filter(|f| !f.is_id && !f.is_updated_at)
1124 .collect();
1125
1126 let updated_at: Vec<&Field> = scalar_fields
1127 .iter()
1128 .copied()
1129 .filter(|f| f.is_updated_at)
1130 .collect();
1131
1132 let update_start = format!(r#"UPDATE "{table_name}" SET "#);
1133
1134 let set_arms: Vec<TokenStream> = updatable
1136 .iter()
1137 .map(|f| {
1138 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1139 let db_name = &f.db_name;
1140 quote! {
1141 if let Some(SetValue::Set(v)) = self.data.#field_ident {
1142 if !first_set { qb.push(", "); }
1143 first_set = false;
1144 qb.push(concat!("\"", #db_name, "\" = "));
1145 qb.push_bind(v);
1146 }
1147 }
1148 })
1149 .collect();
1150
1151 let updated_at_arms: Vec<TokenStream> = updated_at
1152 .iter()
1153 .map(|f| {
1154 let db_name = &f.db_name;
1155 quote! {
1156 if !first_set { qb.push(", "); }
1157 first_set = false;
1158 qb.push(concat!("\"", #db_name, "\" = "));
1159 qb.push_bind(chrono::Utc::now());
1160 }
1161 })
1162 .collect();
1163
1164 quote! {
1165 macro_rules! build_update_many {
1166 ($qb_type:ty) => {{
1167 let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#update_start);
1168 let mut first_set = true;
1169 #(#set_arms)*
1170 #(#updated_at_arms)*
1171
1172 if first_set {
1173 return Ok(0);
1174 }
1175
1176 qb.push(" WHERE 1=1");
1177 self.r#where.build_where(&mut qb);
1178 qb
1179 }};
1180 }
1181
1182 match client {
1183 DatabaseClient::Postgres(_) => {
1184 let qb = build_update_many!(sqlx::Postgres);
1185 client.execute_pg(qb).await
1186 }
1187 DatabaseClient::Sqlite(_) => {
1188 let qb = build_update_many!(sqlx::Sqlite);
1189 client.execute_sqlite(qb).await
1190 }
1191 }
1192 }
1193}
1194
/// Category of a field for aggregate-query generation.
///
/// NOTE(review): the consumer of this enum is outside the visible code.
/// Presumably `Numeric` covers `Int` / `BigInt` / `Float` fields and
/// `DateTime` covers `DateTime` fields; confirm against the aggregate
/// generator.
enum AggregateKind {
    /// Numeric field (presumably eligible for sum/avg as well as min/max; confirm).
    Numeric,
    /// Date/time field (presumably min/max only; confirm).
    DateTime,
}
1204
1205#[allow(clippy::too_many_lines)]
1208fn gen_upsert_code(model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
1209 let pk_db_names: Vec<String> = model
1211 .primary_key
1212 .fields
1213 .iter()
1214 .filter_map(|pk| {
1215 model
1216 .fields
1217 .iter()
1218 .find(|f| f.name == *pk || to_snake_case(&f.name) == *pk)
1219 .map(|f| f.db_name.clone())
1220 })
1221 .collect();
1222 let pk_conflict_cols = pk_db_names
1223 .iter()
1224 .map(|c| format!("\"{c}\""))
1225 .collect::<Vec<_>>()
1226 .join(", ");
1227
1228 let required: Vec<&Field> = scalar_fields
1230 .iter()
1231 .copied()
1232 .filter(|f| !f.has_default() && !f.is_updated_at)
1233 .collect();
1234 let optional: Vec<&Field> = scalar_fields
1235 .iter()
1236 .copied()
1237 .filter(|f| f.has_default() && !f.is_updated_at)
1238 .collect();
1239 let updated_at: Vec<&Field> = scalar_fields
1240 .iter()
1241 .copied()
1242 .filter(|f| f.is_updated_at)
1243 .collect();
1244
1245 let mut col_pushes = vec![];
1246 let mut val_pushes = vec![];
1247
1248 for f in &required {
1249 let db_name = &f.db_name;
1250 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1251 col_pushes.push(quote! { cols.push(#db_name); });
1252 val_pushes.push(quote! { sep.push_bind(self.create.#field_ident); });
1253 }
1254 for f in &optional {
1255 let db_name = &f.db_name;
1256 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1257 let default_expr = gen_default_expr(f, &f.field_type);
1258 col_pushes.push(quote! { cols.push(#db_name); });
1259 val_pushes.push(quote! {
1260 let val = self.create.#field_ident.unwrap_or_else(|| #default_expr);
1261 sep.push_bind(val);
1262 });
1263 }
1264 for f in &updated_at {
1265 let db_name = &f.db_name;
1266 col_pushes.push(quote! { cols.push(#db_name); });
1267 val_pushes.push(quote! { sep.push_bind(chrono::Utc::now()); });
1268 }
1269
1270 let updatable: Vec<&Field> = scalar_fields
1272 .iter()
1273 .copied()
1274 .filter(|f| !f.is_id && !f.is_updated_at)
1275 .collect();
1276
1277 let set_arms: Vec<TokenStream> = updatable
1278 .iter()
1279 .map(|f| {
1280 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1281 let db_name = &f.db_name;
1282 quote! {
1283 if let Some(SetValue::Set(v)) = self.update.#field_ident {
1284 if !first_set { qb.push(", "); }
1285 first_set = false;
1286 qb.push(concat!("\"", #db_name, "\" = "));
1287 qb.push_bind(v);
1288 }
1289 }
1290 })
1291 .collect();
1292
1293 let updated_at_set: Vec<TokenStream> = updated_at
1294 .iter()
1295 .map(|f| {
1296 let db_name = &f.db_name;
1297 quote! {
1298 if !first_set { qb.push(", "); }
1299 first_set = false;
1300 qb.push(concat!("\"", #db_name, "\" = "));
1301 qb.push_bind(chrono::Utc::now());
1302 }
1303 })
1304 .collect();
1305
1306 let insert_start = format!(r#"INSERT INTO "{table_name}""#);
1307 let conflict_clause = format!(" ON CONFLICT ({pk_conflict_cols}) DO UPDATE SET ");
1308 let noop_set = format!(
1309 r#""{}" = "{}""#,
1310 pk_db_names.first().unwrap_or(&"id".to_string()),
1311 pk_db_names.first().unwrap_or(&"id".to_string()),
1312 );
1313
1314 quote! {
1315 macro_rules! build_upsert {
1316 ($qb_type:ty) => {{
1317 let mut cols: Vec<&str> = Vec::new();
1318 #(#col_pushes)*
1319
1320 let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#insert_start);
1321 qb.push(" (");
1322 for (i, col) in cols.iter().enumerate() {
1323 if i > 0 { qb.push(", "); }
1324 qb.push("\"");
1325 qb.push(*col);
1326 qb.push("\"");
1327 }
1328 qb.push(") VALUES (");
1329 {
1330 let mut sep = qb.separated(", ");
1331 #(#val_pushes)*
1332 }
1333 qb.push(")");
1334 qb.push(#conflict_clause);
1335
1336 let mut first_set = true;
1337 #(#set_arms)*
1338 #(#updated_at_set)*
1339
1340 if first_set {
1341 qb.push(#noop_set);
1343 }
1344
1345 qb.push(" RETURNING *");
1346 qb
1347 }};
1348 }
1349
1350 match client {
1351 DatabaseClient::Postgres(_) => {
1352 let qb = build_upsert!(sqlx::Postgres);
1353 client.fetch_one_pg(qb).await
1354 }
1355 DatabaseClient::Sqlite(_) => {
1356 let qb = build_upsert!(sqlx::Sqlite);
1357 client.fetch_one_sqlite(qb).await
1358 }
1359 }
1360 }
1361}
1362
1363#[allow(clippy::too_many_lines)]
1364fn gen_aggregate_types(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
1365 let aggregate_field_name = format_ident!("{}AggregateField", model.name);
1366 let aggregate_result_name = format_ident!("{}AggregateResult", model.name);
1367 let _where_input = format_ident!("{}WhereInput", model.name);
1368 let table_name = &model.db_name;
1369
1370 let agg_fields: Vec<(&Field, AggregateKind)> = scalar_fields
1372 .iter()
1373 .filter_map(|f| match &f.field_type {
1374 FieldKind::Scalar(ScalarType::Int | ScalarType::BigInt | ScalarType::Float) => {
1375 Some((*f, AggregateKind::Numeric))
1376 }
1377 FieldKind::Scalar(ScalarType::DateTime) => Some((*f, AggregateKind::DateTime)),
1378 _ => None,
1379 })
1380 .collect();
1381
1382 if agg_fields.is_empty() {
1383 return quote! {};
1384 }
1385
1386 let enum_variants: Vec<TokenStream> = agg_fields
1388 .iter()
1389 .map(|(f, _)| {
1390 let variant = format_ident!("{}", to_pascal_case(&f.name));
1391 quote! { #variant }
1392 })
1393 .collect();
1394
1395 let db_name_arms: Vec<TokenStream> = agg_fields
1397 .iter()
1398 .map(|(f, _)| {
1399 let variant = format_ident!("{}", to_pascal_case(&f.name));
1400 let db_name = &f.db_name;
1401 quote! { Self::#variant => #db_name }
1402 })
1403 .collect();
1404
1405 let mut result_fields = Vec::new();
1407 for (f, kind) in &agg_fields {
1408 let snake = to_snake_case(&f.name);
1409 let orig_ty = rust_type_tokens(
1410 &Field {
1411 is_optional: false,
1412 ..(*f).clone()
1413 },
1414 ModuleDepth::TopLevel,
1415 );
1416
1417 match kind {
1418 AggregateKind::Numeric => {
1419 let avg_name = format_ident!("avg_{}", snake);
1420 let sum_name = format_ident!("sum_{}", snake);
1421 let min_name = format_ident!("min_{}", snake);
1422 let max_name = format_ident!("max_{}", snake);
1423 result_fields.push(quote! { #[sqlx(default)] pub #avg_name: Option<f64> });
1424 result_fields.push(quote! { #[sqlx(default)] pub #sum_name: Option<f64> });
1425 result_fields.push(quote! { #[sqlx(default)] pub #min_name: Option<#orig_ty> });
1426 result_fields.push(quote! { #[sqlx(default)] pub #max_name: Option<#orig_ty> });
1427 }
1428 AggregateKind::DateTime => {
1429 let min_name = format_ident!("min_{}", snake);
1430 let max_name = format_ident!("max_{}", snake);
1431 result_fields.push(quote! { #[sqlx(default)] pub #min_name: Option<#orig_ty> });
1432 result_fields.push(quote! { #[sqlx(default)] pub #max_name: Option<#orig_ty> });
1433 }
1434 }
1435 }
1436
1437 let numeric_arms: Vec<TokenStream> = agg_fields
1439 .iter()
1440 .filter(|(_, kind)| matches!(kind, AggregateKind::Numeric))
1441 .map(|(f, _)| {
1442 let variant = format_ident!("{}", to_pascal_case(&f.name));
1443 quote! { Self::#variant => true }
1444 })
1445 .collect();
1446
1447 let has_numeric = !numeric_arms.is_empty();
1448 let is_numeric_method = if has_numeric {
1449 quote! {
1450 fn is_numeric(&self) -> bool {
1451 match self {
1452 #(#numeric_arms,)*
1453 #[allow(unreachable_patterns)]
1454 _ => false,
1455 }
1456 }
1457 }
1458 } else {
1459 quote! {
1460 fn is_numeric(&self) -> bool { false }
1461 }
1462 };
1463
1464 let mut alias_arms = Vec::new();
1466 for (f, kind) in &agg_fields {
1467 let variant = format_ident!("{}", to_pascal_case(&f.name));
1468 let snake = to_snake_case(&f.name);
1469 let prefixes = match kind {
1470 AggregateKind::Numeric => vec!["avg", "sum", "min", "max"],
1471 AggregateKind::DateTime => vec!["min", "max"],
1472 };
1473 for prefix in prefixes {
1474 let alias_str = format!("{prefix}_{snake}");
1475 alias_arms.push(quote! { (#prefix, Self::#variant) => #alias_str });
1476 }
1477 }
1478
1479 let agg_select_base = format!(r#"SELECT {{}} FROM "{table_name}" WHERE 1=1"#);
1480
1481 quote! {
1482 #[derive(Debug, Clone, Copy)]
1483 pub enum #aggregate_field_name {
1484 #(#enum_variants),*
1485 }
1486
1487 impl #aggregate_field_name {
1488 pub fn db_name(&self) -> &'static str {
1489 match self {
1490 #(#db_name_arms,)*
1491 }
1492 }
1493
1494 fn alias(&self, prefix: &'static str) -> &'static str {
1495 match (prefix, self) {
1496 #(#alias_arms,)*
1497 _ => unreachable!(),
1498 }
1499 }
1500
1501 #is_numeric_method
1502 }
1503
1504 #[derive(Debug, Clone, sqlx::FromRow, Serialize, Deserialize)]
1505 pub struct #aggregate_result_name {
1506 #(#result_fields,)*
1507 }
1508
1509 pub struct AggregateQuery<'a> {
1510 client: &'a DatabaseClient,
1511 r#where: filter::#_where_input,
1512 ops: Vec<(&'static str, &'static str, &'static str)>,
1513 }
1514
1515 impl<'a> AggregateQuery<'a> {
1516 pub fn avg(mut self, field: #aggregate_field_name) -> Self {
1517 assert!(field.is_numeric(), "avg() is only supported on numeric fields");
1518 let db_name = field.db_name();
1519 let alias = field.alias("avg");
1520 self.ops.push(("AVG", db_name, alias));
1521 self
1522 }
1523
1524 pub fn sum(mut self, field: #aggregate_field_name) -> Self {
1525 assert!(field.is_numeric(), "sum() is only supported on numeric fields");
1526 let db_name = field.db_name();
1527 let alias = field.alias("sum");
1528 self.ops.push(("SUM", db_name, alias));
1529 self
1530 }
1531
1532 pub fn min(mut self, field: #aggregate_field_name) -> Self {
1533 let db_name = field.db_name();
1534 let alias = field.alias("min");
1535 self.ops.push(("MIN", db_name, alias));
1536 self
1537 }
1538
1539 pub fn max(mut self, field: #aggregate_field_name) -> Self {
1540 let db_name = field.db_name();
1541 let alias = field.alias("max");
1542 self.ops.push(("MAX", db_name, alias));
1543 self
1544 }
1545
1546 pub async fn exec(self) -> Result<#aggregate_result_name, FerriormError> {
1547 if self.ops.is_empty() {
1548 return Err(FerriormError::Query("No aggregate operations specified".into()));
1549 }
1550
1551 let selections: Vec<String> = self.ops.iter()
1552 .map(|(func, col, alias)| format!(r#"{}("{}") as "{}""#, func, col, alias))
1553 .collect();
1554 let select_clause = selections.join(", ");
1555 let base_sql = format!(#agg_select_base, select_clause);
1556
1557 match self.client {
1558 DatabaseClient::Postgres(_) => {
1559 let mut qb = sqlx::QueryBuilder::<sqlx::Postgres>::new(&base_sql);
1560 self.r#where.build_where(&mut qb);
1561 self.client.fetch_one_pg(qb).await
1562 }
1563 DatabaseClient::Sqlite(_) => {
1564 let mut qb = sqlx::QueryBuilder::<sqlx::Sqlite>::new(&base_sql);
1565 self.r#where.build_where(&mut qb);
1566 self.client.fetch_one_sqlite(qb).await
1567 }
1568 }
1569 }
1570 }
1571 }
1572}
1573
1574#[allow(clippy::too_many_lines)]
1577fn gen_select_types(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
1578 let select_name = format_ident!("{}Select", model.name);
1579 let partial_name = format_ident!("{}Partial", model.name);
1580 let _where_input = format_ident!("{}WhereInput", model.name);
1581 let _where_unique = format_ident!("{}WhereUniqueInput", model.name);
1582 let order_by_name = format_ident!("{}OrderByInput", model.name);
1583 let table_name = &model.db_name;
1584
1585 let select_fields: Vec<TokenStream> = scalar_fields
1587 .iter()
1588 .map(|f| {
1589 let name = format_ident!("{}", to_snake_case(&f.name));
1590 quote! { pub #name: bool }
1591 })
1592 .collect();
1593
1594 let partial_fields: Vec<TokenStream> = scalar_fields
1597 .iter()
1598 .map(|f| {
1599 let name = format_ident!("{}", to_snake_case(&f.name));
1600 let db_name = &f.db_name;
1601 let base_ty = rust_type_tokens(
1603 &Field {
1604 is_optional: false,
1605 ..(*f).clone()
1606 },
1607 ModuleDepth::TopLevel,
1608 );
1609 let rename = if db_name == &to_snake_case(&f.name) {
1610 quote! {}
1611 } else {
1612 quote! { #[sqlx(rename = #db_name)] }
1613 };
1614 quote! { #[sqlx(default)] #rename pub #name: Option<#base_ty> }
1616 })
1617 .collect();
1618
1619 let select_col_arms: Vec<TokenStream> = scalar_fields
1621 .iter()
1622 .map(|f| {
1623 let name = format_ident!("{}", to_snake_case(&f.name));
1624 let db_name = &f.db_name;
1625 let col_expr = format!(r#""{db_name}""#);
1626 quote! {
1627 if select.#name { cols.push(#col_expr); }
1628 }
1629 })
1630 .collect();
1631
1632 let select_sql_prefix = format!(r#"SELECT {{}} FROM "{table_name}" WHERE 1=1"#);
1633
1634 quote! {
1635 #[derive(Debug, Clone, Default)]
1636 pub struct #select_name {
1637 #(#select_fields,)*
1638 }
1639
1640 #[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
1641 #[sqlx(rename_all = "snake_case")]
1642 pub struct #partial_name {
1643 #(#partial_fields,)*
1644 }
1645
1646 fn build_select_columns(select: &#select_name) -> String {
1647 let mut cols = Vec::new();
1648 #(#select_col_arms)*
1649 if cols.is_empty() {
1650 "*".to_string()
1651 } else {
1652 cols.join(", ")
1653 }
1654 }
1655
1656 pub struct FindManySelectQuery<'a> {
1659 client: &'a DatabaseClient,
1660 r#where: filter::#_where_input,
1661 order_by: Vec<order::#order_by_name>,
1662 skip: Option<i64>,
1663 take: Option<i64>,
1664 select: #select_name,
1665 }
1666
1667 impl<'a> FindManySelectQuery<'a> {
1668 pub fn order_by(mut self, order: order::#order_by_name) -> Self {
1669 self.order_by.push(order);
1670 self
1671 }
1672
1673 pub fn skip(mut self, n: i64) -> Self {
1674 self.skip = Some(n);
1675 self
1676 }
1677
1678 pub fn take(mut self, n: i64) -> Self {
1679 self.take = Some(n);
1680 self
1681 }
1682
1683 pub async fn exec(self) -> Result<Vec<#partial_name>, FerriormError> {
1684 let cols = build_select_columns(&self.select);
1685 let base_sql = format!(#select_sql_prefix, cols);
1686
1687 match self.client {
1688 DatabaseClient::Postgres(_) => {
1689 let qb = build_select_query::<sqlx::Postgres>(
1690 &base_sql, &self.r#where, &self.order_by, self.take, self.skip,
1691 );
1692 self.client.fetch_all_pg(qb).await
1693 }
1694 DatabaseClient::Sqlite(_) => {
1695 let qb = build_select_query::<sqlx::Sqlite>(
1696 &base_sql, &self.r#where, &self.order_by, self.take, self.skip,
1697 );
1698 self.client.fetch_all_sqlite(qb).await
1699 }
1700 }
1701 }
1702 }
1703
1704 pub struct FindUniqueSelectQuery<'a> {
1707 client: &'a DatabaseClient,
1708 r#where: filter::#_where_unique,
1709 select: #select_name,
1710 }
1711
1712 impl<'a> FindUniqueSelectQuery<'a> {
1713 pub async fn exec(self) -> Result<Option<#partial_name>, FerriormError> {
1714 let cols = build_select_columns(&self.select);
1715 let base_sql = format!(#select_sql_prefix, cols);
1716
1717 match self.client {
1718 DatabaseClient::Postgres(_) => {
1719 let qb = build_unique_select_query::<sqlx::Postgres>(
1720 &base_sql, &self.r#where,
1721 );
1722 self.client.fetch_optional_pg(qb).await
1723 }
1724 DatabaseClient::Sqlite(_) => {
1725 let qb = build_unique_select_query::<sqlx::Sqlite>(
1726 &base_sql, &self.r#where,
1727 );
1728 self.client.fetch_optional_sqlite(qb).await
1729 }
1730 }
1731 }
1732 }
1733
1734 pub struct FindFirstSelectQuery<'a> {
1737 client: &'a DatabaseClient,
1738 r#where: filter::#_where_input,
1739 order_by: Vec<order::#order_by_name>,
1740 select: #select_name,
1741 }
1742
1743 impl<'a> FindFirstSelectQuery<'a> {
1744 pub fn order_by(mut self, order: order::#order_by_name) -> Self {
1745 self.order_by.push(order);
1746 self
1747 }
1748
1749 pub async fn exec(self) -> Result<Option<#partial_name>, FerriormError> {
1750 let cols = build_select_columns(&self.select);
1751 let base_sql = format!(#select_sql_prefix, cols);
1752
1753 match self.client {
1754 DatabaseClient::Postgres(_) => {
1755 let qb = build_select_query::<sqlx::Postgres>(
1756 &base_sql, &self.r#where, &self.order_by, Some(1), None,
1757 );
1758 self.client.fetch_optional_pg(qb).await
1759 }
1760 DatabaseClient::Sqlite(_) => {
1761 let qb = build_select_query::<sqlx::Sqlite>(
1762 &base_sql, &self.r#where, &self.order_by, Some(1), None,
1763 );
1764 self.client.fetch_optional_sqlite(qb).await
1765 }
1766 }
1767 }
1768 }
1769 }
1770}