1use ferriorm_core::schema::{Field, FieldKind, Model};
15use ferriorm_core::types::ScalarType;
16use ferriorm_core::utils::{to_pascal_case, to_snake_case};
17use proc_macro2::TokenStream;
18use quote::{format_ident, quote};
19
20use crate::rust_type::{ModuleDepth, enum_path, filter_type_tokens, rust_type_tokens};
21
22#[must_use]
24pub fn generate_model_module(model: &Model) -> TokenStream {
25 let scalar_fields: Vec<&Field> = model.fields.iter().filter(|f| f.is_scalar()).collect();
26
27 let data_struct = gen_data_struct(model, &scalar_fields);
28 let filter_module = gen_filter_module(model, &scalar_fields);
29 let data_module = gen_data_module(model, &scalar_fields);
30 let order_module = gen_order_module(model, &scalar_fields);
31 let actions_struct = gen_actions(model, &scalar_fields);
32 let query_builders = gen_query_builders(model, &scalar_fields);
33 let aggregate_types = gen_aggregate_types(model, &scalar_fields);
34 let groupby_types = gen_groupby_types(model, &scalar_fields);
35 let select_types = gen_select_types(model, &scalar_fields);
36
37 quote! {
38 #![allow(unused_imports, dead_code, unused_variables, clippy::all, clippy::pedantic, clippy::nursery)]
39
40 use serde::{Deserialize, Serialize};
41 use ferriorm_runtime::prelude::*;
42 use ferriorm_runtime::prelude::sqlx;
43 use ferriorm_runtime::prelude::chrono;
44 use ferriorm_runtime::prelude::uuid;
45
46 #data_struct
47 #filter_module
48 #data_module
49 #order_module
50 #actions_struct
51 #query_builders
52 #aggregate_types
53 #groupby_types
54 #select_types
55 }
56}
57
58fn gen_data_struct(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
61 let struct_name = format_ident!("{}", model.name);
62 let table_name = &model.db_name;
63
64 let fields: Vec<TokenStream> = scalar_fields
65 .iter()
66 .map(|f| {
67 let name = format_ident!("{}", to_snake_case(&f.name));
68 let ty = rust_type_tokens(f, ModuleDepth::TopLevel);
69 let db_name = &f.db_name;
70 if db_name == &to_snake_case(&f.name) {
71 quote! { pub #name: #ty }
72 } else {
73 quote! { #[sqlx(rename = #db_name)] pub #name: #ty }
74 }
75 })
76 .collect();
77
78 quote! {
79 #[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
80 #[sqlx(rename_all = "snake_case")]
81 pub struct #struct_name {
82 #(#fields),*
83 }
84
85 impl #struct_name {
86 pub const TABLE_NAME: &'static str = #table_name;
87 }
88 }
89}
90
91#[allow(clippy::too_many_lines)]
94fn gen_filter_module(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
95 let where_input = format_ident!("{}WhereInput", model.name);
96 let where_unique = format_ident!("{}WhereUniqueInput", model.name);
97
98 let where_fields: Vec<TokenStream> = scalar_fields
99 .iter()
100 .filter_map(|f| {
101 let filter_ty = filter_type_tokens(f, ModuleDepth::Nested)?;
102 let name = format_ident!("{}", to_snake_case(&f.name));
103 Some(quote! { pub #name: Option<#filter_ty> })
104 })
105 .collect();
106
107 let single_unique_variants: Vec<TokenStream> = scalar_fields
108 .iter()
109 .filter(|f| f.is_id || f.is_unique)
110 .map(|f| {
111 let variant = format_ident!("{}", to_pascal_case(&f.name));
112 let ty = rust_type_tokens(f, ModuleDepth::Nested);
113 quote! { #variant(#ty) }
114 })
115 .collect();
116
117 let compound_unique_variants: Vec<TokenStream> = model
118 .unique_constraints
119 .iter()
120 .map(|uc| {
121 let variant = format_ident!("{}", compound_variant_name(&uc.fields));
122 let struct_fields = compound_variant_fields(model, &uc.fields);
123 quote! { #variant { #(#struct_fields),* } }
124 })
125 .collect();
126
127 let unique_variants: Vec<TokenStream> = single_unique_variants
128 .into_iter()
129 .chain(compound_unique_variants)
130 .collect();
131
132 let db_bounds = collect_db_bounds(scalar_fields, ModuleDepth::Nested);
134 let where_arms = gen_where_arms(scalar_fields);
135 let unique_arms = gen_unique_where_arms(model, scalar_fields);
136 let conflict_target_arms = gen_conflict_target_arms(model, scalar_fields);
137 let first_conflict_col_arms = gen_first_conflict_col_arms(model, scalar_fields);
138
139 quote! {
140 pub mod filter {
141 use ferriorm_runtime::prelude::*;
142
143 #[derive(Debug, Clone, Default)]
144 pub struct #where_input {
145 #(#where_fields,)*
146 pub and: Option<Vec<#where_input>>,
147 pub or: Option<Vec<#where_input>>,
148 pub not: Option<Box<#where_input>>,
149 }
150
151 #[derive(Debug, Clone)]
152 pub enum #where_unique {
153 #(#unique_variants),*
154 }
155
156 impl #where_input {
157 pub(crate) fn build_where<'args, DB: sqlx::Database>(
158 &self,
159 qb: &mut sqlx::QueryBuilder<'args, DB>,
160 )
161 where
162 #(#db_bounds,)*
163 {
164 #(#where_arms)*
165
166 if let Some(conditions) = &self.and {
167 for c in conditions {
168 c.build_where(qb);
169 }
170 }
171 if let Some(conditions) = &self.or {
172 if !conditions.is_empty() {
173 qb.push(" AND (");
174 for (i, c) in conditions.iter().enumerate() {
175 if i > 0 { qb.push(" OR "); }
176 qb.push("(1=1");
177 c.build_where(qb);
178 qb.push(")");
179 }
180 qb.push(")");
181 }
182 }
183 if let Some(c) = &self.not {
184 qb.push(" AND NOT (1=1");
185 c.build_where(qb);
186 qb.push(")");
187 }
188 }
189 }
190
191 impl #where_unique {
192 pub(crate) fn build_where<'args, DB: sqlx::Database>(
193 &self,
194 qb: &mut sqlx::QueryBuilder<'args, DB>,
195 )
196 where
197 #(#db_bounds,)*
198 {
199 match self {
200 #(#unique_arms)*
201 }
202 }
203 }
204
205 impl #where_unique {
206 #[allow(dead_code)]
207 pub(crate) fn conflict_target(&self) -> &'static str {
208 match self {
209 #(#conflict_target_arms)*
210 }
211 }
212
213 #[allow(dead_code)]
214 pub(crate) fn first_conflict_col(&self) -> &'static str {
215 match self {
216 #(#first_conflict_col_arms)*
217 }
218 }
219 }
220 }
221 }
222}
223
224fn collect_db_bounds(scalar_fields: &[&Field], depth: ModuleDepth) -> Vec<TokenStream> {
228 let mut seen = std::collections::HashSet::new();
229 let mut seen_enums = std::collections::HashSet::new();
230 let mut bounds = Vec::new();
231
232 seen.insert("i64");
234 bounds.push(quote! { i64: sqlx::Type<DB> + for<'e> sqlx::Encode<'e, DB> });
235
236 for f in scalar_fields {
237 match &f.field_type {
238 FieldKind::Scalar(scalar) => {
239 let key = scalar.rust_type();
240 if seen.insert(key)
241 && let Some(ty) = scalar_bound_tokens(scalar)
242 {
243 bounds.push(quote! { #ty: sqlx::Type<DB> + for<'e> sqlx::Encode<'e, DB> });
244 bounds.push(
246 quote! { Option<#ty>: sqlx::Type<DB> + for<'e> sqlx::Encode<'e, DB> },
247 );
248 }
249 }
250 FieldKind::Enum(name) => {
251 if seen_enums.insert(name.clone()) {
254 let enum_ty = enum_path(name, depth);
255 bounds.push(quote! { #enum_ty: sqlx::Type<DB> + for<'e> sqlx::Encode<'e, DB> });
256 }
257 }
258 FieldKind::Model(_) => {}
259 }
260 }
261
262 bounds
263}
264
265fn scalar_bound_tokens(scalar: &ScalarType) -> Option<TokenStream> {
266 match scalar {
267 ScalarType::String => Some(quote! { String }),
268 ScalarType::Int => Some(quote! { i32 }),
269 ScalarType::BigInt => Some(quote! { i64 }),
270 ScalarType::Float => Some(quote! { f64 }),
271 ScalarType::Boolean => Some(quote! { bool }),
272 ScalarType::DateTime => Some(quote! { chrono::DateTime<chrono::Utc> }),
273 ScalarType::Bytes => Some(quote! { Vec<u8> }),
274 ScalarType::Json | ScalarType::Decimal => None,
275 }
276}
277
278fn gen_in_arms_lhs(lhs: &str) -> TokenStream {
285 let in_prefix = format!(" AND {lhs} IN (");
286 let not_in_prefix = format!(" AND {lhs} NOT IN (");
287 quote! {
288 if let Some(values) = &filter.r#in {
289 if values.is_empty() {
290 qb.push(" AND 1 = 0");
291 } else {
292 qb.push(#in_prefix);
293 {
294 let mut sep = qb.separated(", ");
295 for v in values {
296 sep.push_bind(v.clone());
297 }
298 }
299 qb.push(")");
300 }
301 }
302 if let Some(values) = &filter.not_in {
303 if !values.is_empty() {
304 qb.push(#not_in_prefix);
305 {
306 let mut sep = qb.separated(", ");
307 for v in values {
308 sep.push_bind(v.clone());
309 }
310 }
311 qb.push(")");
312 }
313 }
314 }
315}
316
317#[allow(clippy::too_many_lines)]
319fn gen_where_arms(scalar_fields: &[&Field]) -> Vec<TokenStream> {
320 scalar_fields
321 .iter()
322 .filter_map(|f| {
323 let field_ident = format_ident!("{}", to_snake_case(&f.name));
324 let db_name = &f.db_name;
325 let column_lhs = format!("\"{db_name}\"");
326
327 match &f.field_type {
328 FieldKind::Scalar(scalar) => {
329 if matches!(
332 scalar,
333 ScalarType::Json | ScalarType::Bytes | ScalarType::Decimal
334 ) {
335 return None;
336 }
337 let is_string = matches!(scalar, ScalarType::String);
338 let is_comparable = matches!(
339 scalar,
340 ScalarType::Int
341 | ScalarType::BigInt
342 | ScalarType::Float
343 | ScalarType::DateTime
344 );
345 let supports_in = !matches!(scalar, ScalarType::Boolean);
348
349 let mut arms: Vec<TokenStream> = Vec::new();
350
351 if f.is_optional {
352 arms.push(quote! {
356 if let Some(v) = &filter.equals {
357 match v {
358 None => {
359 qb.push(concat!(" AND \"", #db_name, "\" IS NULL"));
360 }
361 Some(inner) => {
362 qb.push(concat!(" AND \"", #db_name, "\" = "));
363 qb.push_bind(inner.clone());
364 }
365 }
366 }
367 if let Some(v) = &filter.not {
368 match v {
369 None => {
370 qb.push(concat!(" AND \"", #db_name, "\" IS NOT NULL"));
371 }
372 Some(inner) => {
373 qb.push(concat!(" AND \"", #db_name, "\" != "));
374 qb.push_bind(inner.clone());
375 }
376 }
377 }
378 });
379 } else {
380 arms.push(quote! {
381 if let Some(v) = &filter.equals {
382 qb.push(concat!(" AND \"", #db_name, "\" = "));
383 qb.push_bind(v.clone());
384 }
385 if let Some(v) = &filter.not {
386 qb.push(concat!(" AND \"", #db_name, "\" != "));
387 qb.push_bind(v.clone());
388 }
389 });
390 }
391
392 if is_string {
393 arms.push(quote! {
399 if let Some(v) = &filter.contains {
400 qb.push(concat!(" AND \"", #db_name, "\" LIKE "));
401 qb.push_bind(format!("%{}%", ferriorm_runtime::filter::like_escape(v)));
402 qb.push(" ESCAPE '\\'");
403 }
404 if let Some(v) = &filter.starts_with {
405 qb.push(concat!(" AND \"", #db_name, "\" LIKE "));
406 qb.push_bind(format!("{}%", ferriorm_runtime::filter::like_escape(v)));
407 qb.push(" ESCAPE '\\'");
408 }
409 if let Some(v) = &filter.ends_with {
410 qb.push(concat!(" AND \"", #db_name, "\" LIKE "));
411 qb.push_bind(format!("%{}", ferriorm_runtime::filter::like_escape(v)));
412 qb.push(" ESCAPE '\\'");
413 }
414 });
415 }
416
417 if is_comparable {
418 arms.push(quote! {
419 if let Some(v) = &filter.gt {
420 qb.push(concat!(" AND \"", #db_name, "\" > "));
421 qb.push_bind(v.clone());
422 }
423 if let Some(v) = &filter.gte {
424 qb.push(concat!(" AND \"", #db_name, "\" >= "));
425 qb.push_bind(v.clone());
426 }
427 if let Some(v) = &filter.lt {
428 qb.push(concat!(" AND \"", #db_name, "\" < "));
429 qb.push_bind(v.clone());
430 }
431 if let Some(v) = &filter.lte {
432 qb.push(concat!(" AND \"", #db_name, "\" <= "));
433 qb.push_bind(v.clone());
434 }
435 });
436 }
437
438 if supports_in {
439 arms.push(gen_in_arms_lhs(&column_lhs));
440 }
441
442 Some(quote! {
443 if let Some(filter) = &self.#field_ident {
444 #(#arms)*
445 }
446 })
447 }
448 FieldKind::Enum(_) => {
449 let in_arms = gen_in_arms_lhs(&column_lhs);
453 Some(quote! {
454 if let Some(filter) = &self.#field_ident {
455 if let Some(v) = &filter.equals {
456 qb.push(concat!(" AND \"", #db_name, "\" = "));
457 qb.push_bind(v.clone());
458 }
459 if let Some(v) = &filter.not {
460 qb.push(concat!(" AND \"", #db_name, "\" != "));
461 qb.push_bind(v.clone());
462 }
463 #in_arms
464 }
465 })
466 }
467 FieldKind::Model(_) => None,
468 }
469 })
470 .collect()
471}
472
473fn gen_unique_where_arms(model: &Model, scalar_fields: &[&Field]) -> Vec<TokenStream> {
474 let mut arms: Vec<TokenStream> = scalar_fields
475 .iter()
476 .filter(|f| f.is_id || f.is_unique)
477 .map(|f| {
478 let variant = format_ident!("{}", to_pascal_case(&f.name));
479 let db_name = &f.db_name;
480 quote! {
481 Self::#variant(v) => {
482 qb.push(concat!(" AND \"", #db_name, "\" = "));
483 qb.push_bind(v.clone());
484 }
485 }
486 })
487 .collect();
488
489 for uc in &model.unique_constraints {
490 let variant = format_ident!("{}", compound_variant_name(&uc.fields));
491 let idents: Vec<_> = uc
492 .fields
493 .iter()
494 .map(|name| format_ident!("{}", to_snake_case(name)))
495 .collect();
496 let binds: Vec<TokenStream> = uc
497 .fields
498 .iter()
499 .map(|name| {
500 let ident = format_ident!("{}", to_snake_case(name));
501 let db_name = resolve_db_name(model, name);
502 quote! {
503 qb.push(concat!(" AND \"", #db_name, "\" = "));
504 qb.push_bind(#ident.clone());
505 }
506 })
507 .collect();
508 arms.push(quote! {
509 Self::#variant { #(#idents),* } => {
510 #(#binds)*
511 }
512 });
513 }
514
515 arms
516}
517
518fn gen_conflict_target_arms(model: &Model, scalar_fields: &[&Field]) -> Vec<TokenStream> {
519 let mut arms: Vec<TokenStream> = scalar_fields
520 .iter()
521 .filter(|f| f.is_id || f.is_unique)
522 .map(|f| {
523 let variant = format_ident!("{}", to_pascal_case(&f.name));
524 let target = format!("(\"{}\")", f.db_name);
525 quote! { Self::#variant(_) => #target, }
526 })
527 .collect();
528
529 for uc in &model.unique_constraints {
530 let variant = format_ident!("{}", compound_variant_name(&uc.fields));
531 let cols: Vec<String> = uc
532 .fields
533 .iter()
534 .map(|n| format!("\"{}\"", resolve_db_name(model, n)))
535 .collect();
536 let target = format!("({})", cols.join(", "));
537 arms.push(quote! { Self::#variant { .. } => #target, });
538 }
539
540 arms
541}
542
543fn gen_first_conflict_col_arms(model: &Model, scalar_fields: &[&Field]) -> Vec<TokenStream> {
544 let mut arms: Vec<TokenStream> = scalar_fields
545 .iter()
546 .filter(|f| f.is_id || f.is_unique)
547 .map(|f| {
548 let variant = format_ident!("{}", to_pascal_case(&f.name));
549 let col = format!("\"{}\"", f.db_name);
550 quote! { Self::#variant(_) => #col, }
551 })
552 .collect();
553
554 for uc in &model.unique_constraints {
555 let variant = format_ident!("{}", compound_variant_name(&uc.fields));
556 let first = uc
557 .fields
558 .first()
559 .map_or_else(String::new, |n| resolve_db_name(model, n));
560 let col = format!("\"{first}\"");
561 arms.push(quote! { Self::#variant { .. } => #col, });
562 }
563
564 arms
565}
566
567fn compound_variant_name(fields: &[String]) -> String {
569 fields.iter().map(|f| to_pascal_case(f)).collect()
570}
571
572fn compound_variant_fields(model: &Model, fields: &[String]) -> Vec<TokenStream> {
575 fields
576 .iter()
577 .filter_map(|field_name| {
578 let field = model.fields.iter().find(|f| f.name == *field_name)?;
579 let ident = format_ident!("{}", to_snake_case(field_name));
580 let ty = rust_type_tokens(field, ModuleDepth::Nested);
581 Some(quote! { #ident: #ty })
582 })
583 .collect()
584}
585
586fn resolve_db_name(model: &Model, field_name: &str) -> String {
588 model
589 .fields
590 .iter()
591 .find(|f| f.name == field_name)
592 .map_or_else(|| to_snake_case(field_name), |f| f.db_name.clone())
593}
594
595fn gen_data_module(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
598 let create_name = format_ident!("{}CreateInput", model.name);
599 let update_name = format_ident!("{}UpdateInput", model.name);
600
601 let required_fields: Vec<TokenStream> = scalar_fields
602 .iter()
603 .filter(|f| !f.has_default() && !f.is_updated_at)
604 .map(|f| {
605 let name = format_ident!("{}", to_snake_case(&f.name));
606 let ty = rust_type_tokens(f, ModuleDepth::Nested);
607 quote! { pub #name: #ty }
608 })
609 .collect();
610
611 let optional_fields: Vec<TokenStream> = scalar_fields
612 .iter()
613 .filter(|f| f.has_default() && !f.is_updated_at)
614 .map(|f| {
615 let name = format_ident!("{}", to_snake_case(&f.name));
616 let base_ty = rust_type_tokens(f, ModuleDepth::Nested);
617 quote! { pub #name: Option<#base_ty> }
618 })
619 .collect();
620
621 let update_fields: Vec<TokenStream> = scalar_fields
622 .iter()
623 .filter(|f| !f.is_id && !f.is_updated_at)
624 .map(|f| {
625 let name = format_ident!("{}", to_snake_case(&f.name));
626 let ty = rust_type_tokens(f, ModuleDepth::Nested);
627 quote! { pub #name: Option<SetValue<#ty>> }
628 })
629 .collect();
630
631 quote! {
632 pub mod data {
633 use ferriorm_runtime::prelude::*;
634
635 #[derive(Debug, Clone)]
636 pub struct #create_name {
637 #(#required_fields,)*
638 #(#optional_fields,)*
639 }
640
641 #[derive(Debug, Clone, Default)]
645 pub struct #update_name {
646 #(#update_fields,)*
647 }
648 }
649 }
650}
651
652fn gen_order_module(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
655 let order_name = format_ident!("{}OrderByInput", model.name);
656
657 let variants: Vec<TokenStream> = scalar_fields
658 .iter()
659 .map(|f| {
660 let variant = format_ident!("{}", to_pascal_case(&f.name));
661 quote! { #variant(SortOrder) }
662 })
663 .collect();
664
665 let order_arms: Vec<TokenStream> = scalar_fields
666 .iter()
667 .map(|f| {
668 let variant = format_ident!("{}", to_pascal_case(&f.name));
669 let db_name = &f.db_name;
670 quote! {
671 Self::#variant(order) => {
672 qb.push(concat!("\"", #db_name, "\" "));
673 qb.push(order.as_sql());
674 }
675 }
676 })
677 .collect();
678
679 quote! {
680 pub mod order {
681 use ferriorm_runtime::prelude::*;
682
683 #[derive(Debug, Clone)]
684 pub enum #order_name {
685 #(#variants),*
686 }
687
688 impl #order_name {
689 pub(crate) fn build_order_by<'args, DB: sqlx::Database>(
690 &self,
691 qb: &mut sqlx::QueryBuilder<'args, DB>,
692 ) {
693 match self {
694 #(#order_arms)*
695 }
696 }
697 }
698 }
699 }
700}
701
702fn gen_actions(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
705 let _model_ident = format_ident!("{}", model.name);
706 let actions_name = format_ident!("{}Actions", model.name);
707 let where_input = format_ident!("{}WhereInput", model.name);
708 let where_unique = format_ident!("{}WhereUniqueInput", model.name);
709 let create_input = format_ident!("{}CreateInput", model.name);
710 let update_input = format_ident!("{}UpdateInput", model.name);
711 let _order_by = format_ident!("{}OrderByInput", model.name);
712
713 let has_agg_fields = scalar_fields.iter().any(|f| {
715 matches!(
716 &f.field_type,
717 FieldKind::Scalar(
718 ScalarType::Int | ScalarType::BigInt | ScalarType::Float | ScalarType::DateTime
719 )
720 )
721 });
722 let aggregate_method = if has_agg_fields {
723 quote! {
724 pub fn aggregate(&self, r#where: filter::#where_input) -> AggregateQuery<'a> {
725 AggregateQuery { client: self.client, r#where, ops: vec![] }
726 }
727 }
728 } else {
729 quote! {}
730 };
731
732 let has_group_fields = scalar_fields.iter().any(|f| is_groupable(f));
734 let groupby_field_name = format_ident!("{}GroupByField", model.name);
735 let group_by_method = if has_group_fields {
736 quote! {
737 pub fn group_by(&self, keys: Vec<#groupby_field_name>) -> GroupByQuery<'a> {
738 GroupByQuery {
739 client: self.client,
740 r#where: filter::#where_input::default(),
741 group_keys: keys,
742 agg_ops: vec![],
743 count: false,
744 having: None,
745 }
746 }
747 }
748 } else {
749 quote! {}
750 };
751
752 quote! {
753 pub struct #actions_name<'a> {
754 client: &'a DatabaseClient,
755 }
756
757 impl<'a> #actions_name<'a> {
758 pub fn new(client: &'a DatabaseClient) -> Self { Self { client } }
759
760 pub fn find_unique(&self, r#where: filter::#where_unique) -> FindUniqueQuery<'a> {
761 FindUniqueQuery { client: self.client, r#where }
762 }
763
764 pub fn find_first(&self, r#where: filter::#where_input) -> FindFirstQuery<'a> {
765 FindFirstQuery { client: self.client, r#where, order_by: vec![] }
766 }
767
768 pub fn find_many(&self, r#where: filter::#where_input) -> FindManyQuery<'a> {
769 FindManyQuery { client: self.client, r#where, order_by: vec![], skip: None, take: None }
770 }
771
772 pub fn create(&self, data: data::#create_input) -> CreateQuery<'a> {
773 CreateQuery { client: self.client, data }
774 }
775
776 pub fn update(&self, r#where: filter::#where_unique, data: data::#update_input) -> UpdateQuery<'a> {
777 UpdateQuery { client: self.client, r#where, data }
778 }
779
780 pub fn update_first(&self, r#where: filter::#where_input, data: data::#update_input) -> UpdateFirstQuery<'a> {
784 UpdateFirstQuery { client: self.client, r#where, data }
785 }
786
787 pub fn delete(&self, r#where: filter::#where_unique) -> DeleteQuery<'a> {
788 DeleteQuery { client: self.client, r#where }
789 }
790
791 pub fn count(&self, r#where: filter::#where_input) -> CountQuery<'a> {
792 CountQuery { client: self.client, r#where }
793 }
794
795 pub fn create_many(&self, data: Vec<data::#create_input>) -> CreateManyQuery<'a> {
796 CreateManyQuery { client: self.client, data }
797 }
798
799 pub fn update_many(&self, r#where: filter::#where_input, data: data::#update_input) -> UpdateManyQuery<'a> {
800 UpdateManyQuery { client: self.client, r#where, data }
801 }
802
803 pub fn delete_many(&self, r#where: filter::#where_input) -> DeleteManyQuery<'a> {
804 DeleteManyQuery { client: self.client, r#where }
805 }
806
807 pub fn upsert(
808 &self,
809 r#where: filter::#where_unique,
810 create: data::#create_input,
811 update: data::#update_input,
812 ) -> UpsertQuery<'a> {
813 UpsertQuery { client: self.client, r#where, create, update }
814 }
815
816 #aggregate_method
817
818 #group_by_method
819 }
820 }
821}
822
823#[allow(clippy::too_many_lines)]
826fn gen_query_builders(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
827 let model_ident = format_ident!("{}", model.name);
828 let table_name = &model.db_name;
829 let _where_input = format_ident!("{}WhereInput", model.name);
830 let _where_unique = format_ident!("{}WhereUniqueInput", model.name);
831 let _create_input = format_ident!("{}CreateInput", model.name);
832 let _update_input = format_ident!("{}UpdateInput", model.name);
833 let order_by = format_ident!("{}OrderByInput", model.name);
834 let _select_struct = format_ident!("{}Select", model.name);
835 let _partial_struct = format_ident!("{}Partial", model.name);
836 let _aggregate_result = format_ident!("{}AggregateResult", model.name);
837 let _aggregate_field = format_ident!("{}AggregateField", model.name);
838 let db_bounds = collect_db_bounds(scalar_fields, ModuleDepth::TopLevel);
839
840 let select_sql = format!(r#"SELECT * FROM "{table_name}" WHERE 1=1"#);
841 let count_sql = format!(r#"SELECT COUNT(*) as "count" FROM "{table_name}" WHERE 1=1"#);
842 let delete_sql = format!(r#"DELETE FROM "{table_name}" WHERE 1=1"#);
843
844 let insert_code = gen_insert_code(model, scalar_fields, table_name);
845 let insert_ignore_code = gen_insert_ignore_code(model, scalar_fields, table_name);
846 let update_code = gen_update_code(model, scalar_fields, table_name);
847 let update_first_code = gen_update_first_code(model, scalar_fields, table_name);
848 let update_many_code = gen_update_many_code(model, scalar_fields, table_name);
849 let upsert_code = gen_upsert_code(model, scalar_fields, table_name);
850
851 quote! {
852 fn build_order_by<'args, DB: sqlx::Database>(
854 orders: &[order::#order_by],
855 qb: &mut sqlx::QueryBuilder<'args, DB>,
856 ) {
857 if !orders.is_empty() {
858 qb.push(" ORDER BY ");
859 for (i, ob) in orders.iter().enumerate() {
860 if i > 0 { qb.push(", "); }
861 ob.build_order_by(qb);
862 }
863 }
864 }
865
866 fn build_select_query<'args, DB: sqlx::Database>(
868 base_sql: &str,
869 where_input: &filter::#_where_input,
870 orders: &[order::#order_by],
871 take: Option<i64>,
872 skip: Option<i64>,
873 ) -> sqlx::QueryBuilder<'args, DB>
874 where
875 #(#db_bounds,)*
876 {
877 let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
878 where_input.build_where(&mut qb);
879 build_order_by(orders, &mut qb);
880 if let Some(take) = take {
881 qb.push(" LIMIT ");
882 qb.push_bind(take);
883 }
884 if let Some(skip) = skip {
885 qb.push(" OFFSET ");
886 qb.push_bind(skip);
887 }
888 qb
889 }
890
891 fn build_unique_select_query<'args, DB: sqlx::Database>(
893 base_sql: &str,
894 where_unique: &filter::#_where_unique,
895 ) -> sqlx::QueryBuilder<'args, DB>
896 where
897 #(#db_bounds,)*
898 {
899 let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
900 where_unique.build_where(&mut qb);
901 qb.push(" LIMIT 1");
902 qb
903 }
904
905 fn build_delete_query<'args, DB: sqlx::Database>(
907 base_sql: &str,
908 where_unique: &filter::#_where_unique,
909 ) -> sqlx::QueryBuilder<'args, DB>
910 where
911 #(#db_bounds,)*
912 {
913 let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
914 where_unique.build_where(&mut qb);
915 qb.push(" RETURNING *");
916 qb
917 }
918
919 fn build_count_query<'args, DB: sqlx::Database>(
921 base_sql: &str,
922 where_input: &filter::#_where_input,
923 ) -> sqlx::QueryBuilder<'args, DB>
924 where
925 #(#db_bounds,)*
926 {
927 let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
928 where_input.build_where(&mut qb);
929 qb
930 }
931
932 fn build_delete_many_query<'args, DB: sqlx::Database>(
934 base_sql: &str,
935 where_input: &filter::#_where_input,
936 ) -> sqlx::QueryBuilder<'args, DB>
937 where
938 #(#db_bounds,)*
939 {
940 let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
941 where_input.build_where(&mut qb);
942 qb
943 }
944
945 pub struct FindUniqueQuery<'a> {
946 client: &'a DatabaseClient,
947 r#where: filter::#_where_unique,
948 }
949
950 impl<'a> FindUniqueQuery<'a> {
951 pub fn select(self, select: #_select_struct) -> FindUniqueSelectQuery<'a> {
952 FindUniqueSelectQuery { client: self.client, r#where: self.r#where, select }
953 }
954
955 pub async fn exec(self) -> Result<Option<#model_ident>, FerriormError> {
956 match self.client {
957 DatabaseClient::Postgres(_) => {
958 let qb = build_unique_select_query::<sqlx::Postgres>(#select_sql, &self.r#where);
959 self.client.fetch_optional_pg(qb).await
960 }
961 DatabaseClient::Sqlite(_) => {
962 let qb = build_unique_select_query::<sqlx::Sqlite>(#select_sql, &self.r#where);
963 self.client.fetch_optional_sqlite(qb).await
964 }
965 }
966 }
967 }
968
969 pub struct FindFirstQuery<'a> {
970 client: &'a DatabaseClient,
971 r#where: filter::#_where_input,
972 order_by: Vec<order::#order_by>,
973 }
974
975 impl<'a> FindFirstQuery<'a> {
976 pub fn order_by(mut self, order: order::#order_by) -> Self {
977 self.order_by.push(order);
978 self
979 }
980
981 pub fn select(self, select: #_select_struct) -> FindFirstSelectQuery<'a> {
982 FindFirstSelectQuery { client: self.client, r#where: self.r#where, order_by: self.order_by, select }
983 }
984
985 pub async fn exec(self) -> Result<Option<#model_ident>, FerriormError> {
986 match self.client {
987 DatabaseClient::Postgres(_) => {
988 let qb = build_select_query::<sqlx::Postgres>(#select_sql, &self.r#where, &self.order_by, Some(1), None);
989 self.client.fetch_optional_pg(qb).await
990 }
991 DatabaseClient::Sqlite(_) => {
992 let qb = build_select_query::<sqlx::Sqlite>(#select_sql, &self.r#where, &self.order_by, Some(1), None);
993 self.client.fetch_optional_sqlite(qb).await
994 }
995 }
996 }
997 }
998
999 pub struct FindManyQuery<'a> {
1000 client: &'a DatabaseClient,
1001 r#where: filter::#_where_input,
1002 order_by: Vec<order::#order_by>,
1003 skip: Option<i64>,
1004 take: Option<i64>,
1005 }
1006
1007 impl<'a> FindManyQuery<'a> {
1008 pub fn order_by(mut self, order: order::#order_by) -> Self {
1009 self.order_by.push(order);
1010 self
1011 }
1012
1013 pub fn skip(mut self, n: i64) -> Self {
1014 self.skip = Some(n);
1015 self
1016 }
1017
1018 pub fn take(mut self, n: i64) -> Self {
1019 self.take = Some(n);
1020 self
1021 }
1022
1023 pub fn select(self, select: #_select_struct) -> FindManySelectQuery<'a> {
1024 FindManySelectQuery {
1025 client: self.client,
1026 r#where: self.r#where,
1027 order_by: self.order_by,
1028 skip: self.skip,
1029 take: self.take,
1030 select,
1031 }
1032 }
1033
1034 pub async fn exec(self) -> Result<Vec<#model_ident>, FerriormError> {
1035 match self.client {
1036 DatabaseClient::Postgres(_) => {
1037 let qb = build_select_query::<sqlx::Postgres>(#select_sql, &self.r#where, &self.order_by, self.take, self.skip);
1038 self.client.fetch_all_pg(qb).await
1039 }
1040 DatabaseClient::Sqlite(_) => {
1041 let qb = build_select_query::<sqlx::Sqlite>(#select_sql, &self.r#where, &self.order_by, self.take, self.skip);
1042 self.client.fetch_all_sqlite(qb).await
1043 }
1044 }
1045 }
1046 }
1047
1048 pub struct CreateQuery<'a> {
1049 client: &'a DatabaseClient,
1050 data: data::#_create_input,
1051 }
1052
1053 impl<'a> CreateQuery<'a> {
1054 pub async fn exec(self) -> Result<#model_ident, FerriormError> {
1055 let client = self.client;
1056 #insert_code
1057 }
1058
1059 pub fn on_conflict_ignore(self) -> CreateIgnoreQuery<'a> {
1063 CreateIgnoreQuery { client: self.client, data: self.data }
1064 }
1065 }
1066
1067 pub struct CreateIgnoreQuery<'a> {
1068 client: &'a DatabaseClient,
1069 data: data::#_create_input,
1070 }
1071
1072 impl<'a> CreateIgnoreQuery<'a> {
1073 pub async fn exec(self) -> Result<Option<#model_ident>, FerriormError> {
1074 let client = self.client;
1075 #insert_ignore_code
1076 }
1077 }
1078
1079 pub struct UpdateQuery<'a> {
1080 client: &'a DatabaseClient,
1081 r#where: filter::#_where_unique,
1082 data: data::#_update_input,
1083 }
1084
1085 impl<'a> UpdateQuery<'a> {
1086 pub async fn exec(self) -> Result<#model_ident, FerriormError> {
1087 let client = self.client;
1088 #update_code
1089 }
1090 }
1091
1092 pub struct UpdateFirstQuery<'a> {
1093 client: &'a DatabaseClient,
1094 r#where: filter::#_where_input,
1095 data: data::#_update_input,
1096 }
1097
1098 impl<'a> UpdateFirstQuery<'a> {
1099 pub async fn exec(self) -> Result<Option<#model_ident>, FerriormError> {
1100 let client = self.client;
1101 #update_first_code
1102 }
1103 }
1104
1105 pub struct DeleteQuery<'a> {
1106 client: &'a DatabaseClient,
1107 r#where: filter::#_where_unique,
1108 }
1109
1110 impl<'a> DeleteQuery<'a> {
1111 pub async fn exec(self) -> Result<#model_ident, FerriormError> {
1112 match self.client {
1113 DatabaseClient::Postgres(_) => {
1114 let qb = build_delete_query::<sqlx::Postgres>(#delete_sql, &self.r#where);
1115 self.client.fetch_one_pg(qb).await
1116 }
1117 DatabaseClient::Sqlite(_) => {
1118 let qb = build_delete_query::<sqlx::Sqlite>(#delete_sql, &self.r#where);
1119 self.client.fetch_one_sqlite(qb).await
1120 }
1121 }
1122 }
1123 }
1124
1125 #[derive(sqlx::FromRow)]
1126 struct CountResult { count: i64 }
1127
1128 pub struct CountQuery<'a> {
1129 client: &'a DatabaseClient,
1130 r#where: filter::#_where_input,
1131 }
1132
1133 impl<'a> CountQuery<'a> {
1134 pub async fn exec(self) -> Result<i64, FerriormError> {
1135 let row: CountResult = match self.client {
1136 DatabaseClient::Postgres(_) => {
1137 let qb = build_count_query::<sqlx::Postgres>(#count_sql, &self.r#where);
1138 self.client.fetch_one_pg(qb).await?
1139 }
1140 DatabaseClient::Sqlite(_) => {
1141 let qb = build_count_query::<sqlx::Sqlite>(#count_sql, &self.r#where);
1142 self.client.fetch_one_sqlite(qb).await?
1143 }
1144 };
1145 Ok(row.count)
1146 }
1147 }
1148
1149 pub struct CreateManyQuery<'a> {
1150 client: &'a DatabaseClient,
1151 data: Vec<data::#_create_input>,
1152 }
1153
1154 impl<'a> CreateManyQuery<'a> {
1155 pub async fn exec(self) -> Result<u64, FerriormError> {
1156 if self.data.is_empty() { return Ok(0); }
1157 let count = self.data.len() as u64;
1158 for item in self.data {
1159 CreateQuery { client: self.client, data: item }.exec().await?;
1160 }
1161 Ok(count)
1162 }
1163 }
1164
1165 pub struct UpdateManyQuery<'a> {
1166 client: &'a DatabaseClient,
1167 r#where: filter::#_where_input,
1168 data: data::#_update_input,
1169 }
1170
1171 impl<'a> UpdateManyQuery<'a> {
1172 pub async fn exec(self) -> Result<u64, FerriormError> {
1173 let client = self.client;
1174 #update_many_code
1175 }
1176 }
1177
1178 pub struct UpsertQuery<'a> {
1179 client: &'a DatabaseClient,
1180 r#where: filter::#_where_unique,
1181 create: data::#_create_input,
1182 update: data::#_update_input,
1183 }
1184
1185 impl<'a> UpsertQuery<'a> {
1186 pub async fn exec(self) -> Result<#model_ident, FerriormError> {
1187 let client = self.client;
1188 #upsert_code
1189 }
1190 }
1191
1192 pub struct DeleteManyQuery<'a> {
1193 client: &'a DatabaseClient,
1194 r#where: filter::#_where_input,
1195 }
1196
1197 impl<'a> DeleteManyQuery<'a> {
1198 pub async fn exec(self) -> Result<u64, FerriormError> {
1199 match self.client {
1200 DatabaseClient::Postgres(_) => {
1201 let qb = build_delete_many_query::<sqlx::Postgres>(#delete_sql, &self.r#where);
1202 self.client.execute_pg(qb).await
1203 }
1204 DatabaseClient::Sqlite(_) => {
1205 let qb = build_delete_many_query::<sqlx::Sqlite>(#delete_sql, &self.r#where);
1206 self.client.execute_sqlite(qb).await
1207 }
1208 }
1209 }
1210 }
1211 }
1212}
1213
1214fn gen_insert_code(model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
1217 let _model_ident = format_ident!("{}", model.name);
1218
1219 let required: Vec<&Field> = scalar_fields
1221 .iter()
1222 .copied()
1223 .filter(|f| !f.has_default() && !f.is_updated_at)
1224 .collect();
1225
1226 let optional: Vec<&Field> = scalar_fields
1228 .iter()
1229 .copied()
1230 .filter(|f| f.has_default() && !f.is_updated_at)
1231 .collect();
1232
1233 let updated_at: Vec<&Field> = scalar_fields
1235 .iter()
1236 .copied()
1237 .filter(|f| f.is_updated_at)
1238 .collect();
1239
1240 let mut col_pushes = vec![];
1242 let mut val_pushes = vec![];
1243
1244 for f in &required {
1246 let db_name = &f.db_name;
1247 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1248 col_pushes.push(quote! { cols.push(#db_name); });
1249 val_pushes.push(quote! { sep.push_bind(self.data.#field_ident); });
1250 }
1251
1252 for f in &optional {
1254 let db_name = &f.db_name;
1255 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1256 if is_autoincrement(f) {
1257 col_pushes.push(quote! {
1261 if self.data.#field_ident.is_some() { cols.push(#db_name); }
1262 });
1263 val_pushes.push(quote! {
1264 if let Some(val) = self.data.#field_ident {
1265 sep.push_bind(val);
1266 }
1267 });
1268 } else {
1269 let default_expr = gen_default_expr(f, &f.field_type);
1270 col_pushes.push(quote! { cols.push(#db_name); });
1271 val_pushes.push(quote! {
1272 let val = self.data.#field_ident.unwrap_or_else(|| #default_expr);
1273 sep.push_bind(val);
1274 });
1275 }
1276 }
1277
1278 for f in &updated_at {
1280 let db_name = &f.db_name;
1281 col_pushes.push(quote! { cols.push(#db_name); });
1282 val_pushes.push(quote! { sep.push_bind(chrono::Utc::now()); });
1283 }
1284
1285 let insert_start = format!(r#"INSERT INTO "{table_name}""#);
1286
1287 quote! {
1290 macro_rules! build_insert {
1292 ($qb_type:ty) => {{
1293 let mut cols: Vec<&str> = Vec::new();
1294 #(#col_pushes)*
1295
1296 let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#insert_start);
1297 qb.push(" (");
1298 for (i, col) in cols.iter().enumerate() {
1299 if i > 0 { qb.push(", "); }
1300 qb.push("\"");
1301 qb.push(*col);
1302 qb.push("\"");
1303 }
1304 qb.push(") VALUES (");
1305 {
1306 let mut sep = qb.separated(", ");
1307 #(#val_pushes)*
1308 }
1309 qb.push(") RETURNING *");
1310 qb
1311 }};
1312 }
1313
1314 match client {
1315 DatabaseClient::Postgres(_) => {
1316 let qb = build_insert!(sqlx::Postgres);
1317 client.fetch_one_pg(qb).await
1318 }
1319 DatabaseClient::Sqlite(_) => {
1320 let qb = build_insert!(sqlx::Sqlite);
1321 client.fetch_one_sqlite(qb).await
1322 }
1323 }
1324 }
1325}
1326
1327fn gen_insert_ignore_code(
1328 _model: &Model,
1329 scalar_fields: &[&Field],
1330 table_name: &str,
1331) -> TokenStream {
1332 let required: Vec<&Field> = scalar_fields
1333 .iter()
1334 .copied()
1335 .filter(|f| !f.has_default() && !f.is_updated_at)
1336 .collect();
1337 let optional: Vec<&Field> = scalar_fields
1338 .iter()
1339 .copied()
1340 .filter(|f| f.has_default() && !f.is_updated_at)
1341 .collect();
1342 let updated_at: Vec<&Field> = scalar_fields
1343 .iter()
1344 .copied()
1345 .filter(|f| f.is_updated_at)
1346 .collect();
1347
1348 let mut col_pushes = vec![];
1349 let mut val_pushes = vec![];
1350
1351 for f in &required {
1352 let db_name = &f.db_name;
1353 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1354 col_pushes.push(quote! { cols.push(#db_name); });
1355 val_pushes.push(quote! { sep.push_bind(self.data.#field_ident); });
1356 }
1357 for f in &optional {
1358 let db_name = &f.db_name;
1359 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1360 if is_autoincrement(f) {
1361 col_pushes.push(quote! {
1362 if self.data.#field_ident.is_some() { cols.push(#db_name); }
1363 });
1364 val_pushes.push(quote! {
1365 if let Some(val) = self.data.#field_ident {
1366 sep.push_bind(val);
1367 }
1368 });
1369 } else {
1370 let default_expr = gen_default_expr(f, &f.field_type);
1371 col_pushes.push(quote! { cols.push(#db_name); });
1372 val_pushes.push(quote! {
1373 let val = self.data.#field_ident.unwrap_or_else(|| #default_expr);
1374 sep.push_bind(val);
1375 });
1376 }
1377 }
1378 for f in &updated_at {
1379 let db_name = &f.db_name;
1380 col_pushes.push(quote! { cols.push(#db_name); });
1381 val_pushes.push(quote! { sep.push_bind(chrono::Utc::now()); });
1382 }
1383
1384 let pg_insert_start = format!(r#"INSERT INTO "{table_name}""#);
1385 let sqlite_insert_start = format!(r#"INSERT OR IGNORE INTO "{table_name}""#);
1386
1387 quote! {
1388 macro_rules! build_insert_ignore {
1389 ($qb_type:ty, $head:expr, $tail:expr) => {{
1390 let mut cols: Vec<&str> = Vec::new();
1391 #(#col_pushes)*
1392
1393 let mut qb = sqlx::QueryBuilder::<$qb_type>::new($head);
1394 qb.push(" (");
1395 for (i, col) in cols.iter().enumerate() {
1396 if i > 0 { qb.push(", "); }
1397 qb.push("\"");
1398 qb.push(*col);
1399 qb.push("\"");
1400 }
1401 qb.push(") VALUES (");
1402 {
1403 let mut sep = qb.separated(", ");
1404 #(#val_pushes)*
1405 }
1406 qb.push(")");
1407 qb.push($tail);
1408 qb.push(" RETURNING *");
1409 qb
1410 }};
1411 }
1412
1413 match client {
1414 DatabaseClient::Postgres(_) => {
1415 let qb = build_insert_ignore!(sqlx::Postgres, #pg_insert_start, " ON CONFLICT DO NOTHING");
1416 client.fetch_optional_pg(qb).await
1417 }
1418 DatabaseClient::Sqlite(_) => {
1419 let qb = build_insert_ignore!(sqlx::Sqlite, #sqlite_insert_start, "");
1420 client.fetch_optional_sqlite(qb).await
1421 }
1422 }
1423 }
1424}
1425
1426fn is_autoincrement(field: &Field) -> bool {
1430 matches!(
1431 field.default,
1432 Some(ferriorm_core::ast::DefaultValue::AutoIncrement)
1433 )
1434}
1435
1436fn gen_default_expr(field: &Field, field_type: &FieldKind) -> TokenStream {
1438 use ferriorm_core::ast::DefaultValue;
1439
1440 match &field.default {
1441 Some(DefaultValue::Uuid | DefaultValue::Cuid) => {
1442 quote! { uuid::Uuid::new_v4().to_string() }
1443 }
1444 Some(DefaultValue::Now) => quote! { chrono::Utc::now() },
1445 Some(DefaultValue::AutoIncrement) => quote! { 0i32 },
1450 Some(DefaultValue::Literal(lit)) => {
1451 use ferriorm_core::ast::LiteralValue;
1452 match lit {
1453 LiteralValue::String(s) => quote! { #s.to_string() },
1454 LiteralValue::Int(i) => {
1455 #[allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
1457 match field_type {
1458 FieldKind::Scalar(ScalarType::Float) => {
1459 let val = *i as f64;
1460 quote! { #val }
1461 }
1462 FieldKind::Scalar(ScalarType::BigInt) => quote! { #i },
1463 FieldKind::Scalar(ScalarType::Int)
1465 if field.db_type.as_ref().is_some_and(|(ty, _)| ty == "BigInt") =>
1466 {
1467 quote! { #i }
1468 }
1469 _ => {
1470 let val = *i as i32;
1472 quote! { #val }
1473 }
1474 }
1475 }
1476 LiteralValue::Float(f) => quote! { #f },
1477 LiteralValue::Bool(b) => quote! { #b },
1478 }
1479 }
1480 Some(DefaultValue::EnumVariant(v)) => {
1481 let variant = format_ident!("{}", v);
1483 if let FieldKind::Enum(enum_name) = &field.field_type {
1484 let enum_ident = format_ident!("{}", enum_name);
1485 quote! { super::enums::#enum_ident::#variant }
1486 } else {
1487 quote! { Default::default() }
1488 }
1489 }
1490 None => quote! { Default::default() },
1491 }
1492}
1493
1494fn gen_update_code(model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
1497 let _model_ident = format_ident!("{}", model.name);
1498
1499 let updatable: Vec<&Field> = scalar_fields
1501 .iter()
1502 .copied()
1503 .filter(|f| !f.is_id && !f.is_updated_at)
1504 .collect();
1505
1506 let updated_at: Vec<&Field> = scalar_fields
1507 .iter()
1508 .copied()
1509 .filter(|f| f.is_updated_at)
1510 .collect();
1511
1512 let update_start = format!(r#"UPDATE "{table_name}" SET "#);
1513
1514 let set_arms: Vec<TokenStream> = updatable
1516 .iter()
1517 .map(|f| {
1518 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1519 let db_name = &f.db_name;
1520 quote! {
1521 if let Some(SetValue::Set(v)) = self.data.#field_ident {
1522 if !first_set { qb.push(", "); }
1523 first_set = false;
1524 qb.push(concat!("\"", #db_name, "\" = "));
1525 qb.push_bind(v);
1526 }
1527 }
1528 })
1529 .collect();
1530
1531 let updated_at_arms: Vec<TokenStream> = updated_at
1532 .iter()
1533 .map(|f| {
1534 let db_name = &f.db_name;
1535 quote! {
1536 if !first_set { qb.push(", "); }
1537 first_set = false;
1538 qb.push(concat!("\"", #db_name, "\" = "));
1539 qb.push_bind(chrono::Utc::now());
1540 }
1541 })
1542 .collect();
1543
1544 quote! {
1547 macro_rules! build_update {
1548 ($qb_type:ty) => {{
1549 let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#update_start);
1550 let mut first_set = true;
1551 #(#set_arms)*
1552 #(#updated_at_arms)*
1553
1554 if first_set {
1555 return Err(FerriormError::Query("No fields to update".into()));
1556 }
1557
1558 qb.push(" WHERE 1=1");
1559 self.r#where.build_where(&mut qb);
1560 qb.push(" RETURNING *");
1561 qb
1562 }};
1563 }
1564
1565 match client {
1566 DatabaseClient::Postgres(_) => {
1567 let qb = build_update!(sqlx::Postgres);
1568 client.fetch_one_pg(qb).await
1569 }
1570 DatabaseClient::Sqlite(_) => {
1571 let qb = build_update!(sqlx::Sqlite);
1572 client.fetch_one_sqlite(qb).await
1573 }
1574 }
1575 }
1576}
1577
1578fn gen_update_first_code(
1581 _model: &Model,
1582 scalar_fields: &[&Field],
1583 table_name: &str,
1584) -> TokenStream {
1585 let updatable: Vec<&Field> = scalar_fields
1586 .iter()
1587 .copied()
1588 .filter(|f| !f.is_id && !f.is_updated_at)
1589 .collect();
1590
1591 let updated_at: Vec<&Field> = scalar_fields
1592 .iter()
1593 .copied()
1594 .filter(|f| f.is_updated_at)
1595 .collect();
1596
1597 let update_start = format!(r#"UPDATE "{table_name}" SET "#);
1598
1599 let set_arms: Vec<TokenStream> = updatable
1600 .iter()
1601 .map(|f| {
1602 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1603 let db_name = &f.db_name;
1604 quote! {
1605 if let Some(SetValue::Set(v)) = self.data.#field_ident {
1606 if !first_set { qb.push(", "); }
1607 first_set = false;
1608 qb.push(concat!("\"", #db_name, "\" = "));
1609 qb.push_bind(v);
1610 }
1611 }
1612 })
1613 .collect();
1614
1615 let updated_at_arms: Vec<TokenStream> = updated_at
1616 .iter()
1617 .map(|f| {
1618 let db_name = &f.db_name;
1619 quote! {
1620 if !first_set { qb.push(", "); }
1621 first_set = false;
1622 qb.push(concat!("\"", #db_name, "\" = "));
1623 qb.push_bind(chrono::Utc::now());
1624 }
1625 })
1626 .collect();
1627
1628 quote! {
1629 macro_rules! build_update_first {
1630 ($qb_type:ty) => {{
1631 let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#update_start);
1632 let mut first_set = true;
1633 #(#set_arms)*
1634 #(#updated_at_arms)*
1635
1636 if first_set {
1637 return Err(FerriormError::Query("No fields to update".into()));
1638 }
1639
1640 qb.push(" WHERE 1=1");
1641 self.r#where.build_where(&mut qb);
1642 qb.push(" RETURNING *");
1643 qb
1644 }};
1645 }
1646
1647 match client {
1648 DatabaseClient::Postgres(_) => {
1649 let qb = build_update_first!(sqlx::Postgres);
1650 client.fetch_optional_pg(qb).await
1651 }
1652 DatabaseClient::Sqlite(_) => {
1653 let qb = build_update_first!(sqlx::Sqlite);
1654 client.fetch_optional_sqlite(qb).await
1655 }
1656 }
1657 }
1658}
1659
1660fn gen_update_many_code(_model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
1663 let updatable: Vec<&Field> = scalar_fields
1665 .iter()
1666 .copied()
1667 .filter(|f| !f.is_id && !f.is_updated_at)
1668 .collect();
1669
1670 let updated_at: Vec<&Field> = scalar_fields
1671 .iter()
1672 .copied()
1673 .filter(|f| f.is_updated_at)
1674 .collect();
1675
1676 let update_start = format!(r#"UPDATE "{table_name}" SET "#);
1677
1678 let set_arms: Vec<TokenStream> = updatable
1680 .iter()
1681 .map(|f| {
1682 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1683 let db_name = &f.db_name;
1684 quote! {
1685 if let Some(SetValue::Set(v)) = self.data.#field_ident {
1686 if !first_set { qb.push(", "); }
1687 first_set = false;
1688 qb.push(concat!("\"", #db_name, "\" = "));
1689 qb.push_bind(v);
1690 }
1691 }
1692 })
1693 .collect();
1694
1695 let updated_at_arms: Vec<TokenStream> = updated_at
1696 .iter()
1697 .map(|f| {
1698 let db_name = &f.db_name;
1699 quote! {
1700 if !first_set { qb.push(", "); }
1701 first_set = false;
1702 qb.push(concat!("\"", #db_name, "\" = "));
1703 qb.push_bind(chrono::Utc::now());
1704 }
1705 })
1706 .collect();
1707
1708 quote! {
1709 macro_rules! build_update_many {
1710 ($qb_type:ty) => {{
1711 let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#update_start);
1712 let mut first_set = true;
1713 #(#set_arms)*
1714 #(#updated_at_arms)*
1715
1716 if first_set {
1717 return Ok(0);
1718 }
1719
1720 qb.push(" WHERE 1=1");
1721 self.r#where.build_where(&mut qb);
1722 qb
1723 }};
1724 }
1725
1726 match client {
1727 DatabaseClient::Postgres(_) => {
1728 let qb = build_update_many!(sqlx::Postgres);
1729 client.execute_pg(qb).await
1730 }
1731 DatabaseClient::Sqlite(_) => {
1732 let qb = build_update_many!(sqlx::Sqlite);
1733 client.execute_sqlite(qb).await
1734 }
1735 }
1736 }
1737}
1738
1739enum AggregateKind {
1743 Numeric,
1745 DateTime,
1747}
1748
1749#[allow(clippy::too_many_lines)]
1752fn gen_upsert_code(_model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
1753 let required: Vec<&Field> = scalar_fields
1755 .iter()
1756 .copied()
1757 .filter(|f| !f.has_default() && !f.is_updated_at)
1758 .collect();
1759 let optional: Vec<&Field> = scalar_fields
1760 .iter()
1761 .copied()
1762 .filter(|f| f.has_default() && !f.is_updated_at)
1763 .collect();
1764 let updated_at: Vec<&Field> = scalar_fields
1765 .iter()
1766 .copied()
1767 .filter(|f| f.is_updated_at)
1768 .collect();
1769
1770 let mut col_pushes = vec![];
1771 let mut val_pushes = vec![];
1772
1773 for f in &required {
1774 let db_name = &f.db_name;
1775 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1776 col_pushes.push(quote! { cols.push(#db_name); });
1777 val_pushes.push(quote! { sep.push_bind(self.create.#field_ident); });
1778 }
1779 for f in &optional {
1780 let db_name = &f.db_name;
1781 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1782 if is_autoincrement(f) {
1783 col_pushes.push(quote! {
1784 if self.create.#field_ident.is_some() { cols.push(#db_name); }
1785 });
1786 val_pushes.push(quote! {
1787 if let Some(val) = self.create.#field_ident {
1788 sep.push_bind(val);
1789 }
1790 });
1791 } else {
1792 let default_expr = gen_default_expr(f, &f.field_type);
1793 col_pushes.push(quote! { cols.push(#db_name); });
1794 val_pushes.push(quote! {
1795 let val = self.create.#field_ident.unwrap_or_else(|| #default_expr);
1796 sep.push_bind(val);
1797 });
1798 }
1799 }
1800 for f in &updated_at {
1801 let db_name = &f.db_name;
1802 col_pushes.push(quote! { cols.push(#db_name); });
1803 val_pushes.push(quote! { sep.push_bind(chrono::Utc::now()); });
1804 }
1805
1806 let updatable: Vec<&Field> = scalar_fields
1808 .iter()
1809 .copied()
1810 .filter(|f| !f.is_id && !f.is_updated_at)
1811 .collect();
1812
1813 let set_arms: Vec<TokenStream> = updatable
1814 .iter()
1815 .map(|f| {
1816 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1817 let db_name = &f.db_name;
1818 quote! {
1819 if let Some(SetValue::Set(v)) = self.update.#field_ident {
1820 if !first_set { qb.push(", "); }
1821 first_set = false;
1822 qb.push(concat!("\"", #db_name, "\" = "));
1823 qb.push_bind(v);
1824 }
1825 }
1826 })
1827 .collect();
1828
1829 let updated_at_set: Vec<TokenStream> = updated_at
1830 .iter()
1831 .map(|f| {
1832 let db_name = &f.db_name;
1833 quote! {
1834 if !first_set { qb.push(", "); }
1835 first_set = false;
1836 qb.push(concat!("\"", #db_name, "\" = "));
1837 qb.push_bind(chrono::Utc::now());
1838 }
1839 })
1840 .collect();
1841
1842 let insert_start = format!(r#"INSERT INTO "{table_name}""#);
1843
1844 quote! {
1845 let conflict_target = self.r#where.conflict_target();
1846 let first_conflict_col = self.r#where.first_conflict_col();
1847
1848 macro_rules! build_upsert {
1849 ($qb_type:ty) => {{
1850 let mut cols: Vec<&str> = Vec::new();
1851 #(#col_pushes)*
1852
1853 let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#insert_start);
1854 qb.push(" (");
1855 for (i, col) in cols.iter().enumerate() {
1856 if i > 0 { qb.push(", "); }
1857 qb.push("\"");
1858 qb.push(*col);
1859 qb.push("\"");
1860 }
1861 qb.push(") VALUES (");
1862 {
1863 let mut sep = qb.separated(", ");
1864 #(#val_pushes)*
1865 }
1866 qb.push(")");
1867 qb.push(" ON CONFLICT ");
1868 qb.push(conflict_target);
1869 qb.push(" DO UPDATE SET ");
1870
1871 let mut first_set = true;
1872 #(#set_arms)*
1873 #(#updated_at_set)*
1874
1875 if first_set {
1876 qb.push(first_conflict_col);
1879 qb.push(" = ");
1880 qb.push(first_conflict_col);
1881 }
1882
1883 qb.push(" RETURNING *");
1884 qb
1885 }};
1886 }
1887
1888 match client {
1889 DatabaseClient::Postgres(_) => {
1890 let qb = build_upsert!(sqlx::Postgres);
1891 client.fetch_one_pg(qb).await
1892 }
1893 DatabaseClient::Sqlite(_) => {
1894 let qb = build_upsert!(sqlx::Sqlite);
1895 client.fetch_one_sqlite(qb).await
1896 }
1897 }
1898 }
1899}
1900
1901#[allow(clippy::too_many_lines)]
1902fn gen_aggregate_types(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
1903 let aggregate_field_name = format_ident!("{}AggregateField", model.name);
1904 let aggregate_result_name = format_ident!("{}AggregateResult", model.name);
1905 let _where_input = format_ident!("{}WhereInput", model.name);
1906 let table_name = &model.db_name;
1907
1908 let agg_fields: Vec<(&Field, AggregateKind)> = scalar_fields
1910 .iter()
1911 .filter_map(|f| match &f.field_type {
1912 FieldKind::Scalar(ScalarType::Int | ScalarType::BigInt | ScalarType::Float) => {
1913 Some((*f, AggregateKind::Numeric))
1914 }
1915 FieldKind::Scalar(ScalarType::DateTime) => Some((*f, AggregateKind::DateTime)),
1916 _ => None,
1917 })
1918 .collect();
1919
1920 if agg_fields.is_empty() {
1921 return quote! {};
1922 }
1923
1924 let enum_variants: Vec<TokenStream> = agg_fields
1926 .iter()
1927 .map(|(f, _)| {
1928 let variant = format_ident!("{}", to_pascal_case(&f.name));
1929 quote! { #variant }
1930 })
1931 .collect();
1932
1933 let db_name_arms: Vec<TokenStream> = agg_fields
1935 .iter()
1936 .map(|(f, _)| {
1937 let variant = format_ident!("{}", to_pascal_case(&f.name));
1938 let db_name = &f.db_name;
1939 quote! { Self::#variant => #db_name }
1940 })
1941 .collect();
1942
1943 let mut result_fields = Vec::new();
1945 for (f, kind) in &agg_fields {
1946 let snake = to_snake_case(&f.name);
1947 let orig_ty = rust_type_tokens(
1948 &Field {
1949 is_optional: false,
1950 ..(*f).clone()
1951 },
1952 ModuleDepth::TopLevel,
1953 );
1954
1955 match kind {
1956 AggregateKind::Numeric => {
1957 let avg_name = format_ident!("avg_{}", snake);
1958 let sum_name = format_ident!("sum_{}", snake);
1959 let min_name = format_ident!("min_{}", snake);
1960 let max_name = format_ident!("max_{}", snake);
1961 result_fields.push(quote! { #[sqlx(default)] pub #avg_name: Option<f64> });
1962 result_fields.push(quote! { #[sqlx(default)] pub #sum_name: Option<f64> });
1963 result_fields.push(quote! { #[sqlx(default)] pub #min_name: Option<#orig_ty> });
1964 result_fields.push(quote! { #[sqlx(default)] pub #max_name: Option<#orig_ty> });
1965 }
1966 AggregateKind::DateTime => {
1967 let min_name = format_ident!("min_{}", snake);
1968 let max_name = format_ident!("max_{}", snake);
1969 result_fields.push(quote! { #[sqlx(default)] pub #min_name: Option<#orig_ty> });
1970 result_fields.push(quote! { #[sqlx(default)] pub #max_name: Option<#orig_ty> });
1971 }
1972 }
1973 }
1974
1975 let numeric_arms: Vec<TokenStream> = agg_fields
1977 .iter()
1978 .filter(|(_, kind)| matches!(kind, AggregateKind::Numeric))
1979 .map(|(f, _)| {
1980 let variant = format_ident!("{}", to_pascal_case(&f.name));
1981 quote! { Self::#variant => true }
1982 })
1983 .collect();
1984
1985 let has_numeric = !numeric_arms.is_empty();
1986 let is_numeric_method = if has_numeric {
1987 quote! {
1988 fn is_numeric(&self) -> bool {
1989 match self {
1990 #(#numeric_arms,)*
1991 #[allow(unreachable_patterns)]
1992 _ => false,
1993 }
1994 }
1995 }
1996 } else {
1997 quote! {
1998 fn is_numeric(&self) -> bool { false }
1999 }
2000 };
2001
2002 let mut alias_arms = Vec::new();
2004 for (f, kind) in &agg_fields {
2005 let variant = format_ident!("{}", to_pascal_case(&f.name));
2006 let snake = to_snake_case(&f.name);
2007 let prefixes = match kind {
2008 AggregateKind::Numeric => vec!["avg", "sum", "min", "max"],
2009 AggregateKind::DateTime => vec!["min", "max"],
2010 };
2011 for prefix in prefixes {
2012 let alias_str = format!("{prefix}_{snake}");
2013 alias_arms.push(quote! { (#prefix, Self::#variant) => #alias_str });
2014 }
2015 }
2016
2017 let agg_select_base = format!(r#"SELECT {{}} FROM "{table_name}" WHERE 1=1"#);
2018
2019 quote! {
2020 #[derive(Debug, Clone, Copy)]
2021 pub enum #aggregate_field_name {
2022 #(#enum_variants),*
2023 }
2024
2025 impl #aggregate_field_name {
2026 pub fn db_name(&self) -> &'static str {
2027 match self {
2028 #(#db_name_arms,)*
2029 }
2030 }
2031
2032 fn alias(&self, prefix: &'static str) -> &'static str {
2033 match (prefix, self) {
2034 #(#alias_arms,)*
2035 _ => unreachable!(),
2036 }
2037 }
2038
2039 #is_numeric_method
2040 }
2041
2042 #[derive(Debug, Clone, sqlx::FromRow, Serialize, Deserialize)]
2043 pub struct #aggregate_result_name {
2044 #(#result_fields,)*
2045 }
2046
2047 pub struct AggregateQuery<'a> {
2048 client: &'a DatabaseClient,
2049 r#where: filter::#_where_input,
2050 ops: Vec<(&'static str, &'static str, &'static str)>,
2051 }
2052
2053 impl<'a> AggregateQuery<'a> {
2054 pub fn avg(mut self, field: #aggregate_field_name) -> Self {
2055 assert!(field.is_numeric(), "avg() is only supported on numeric fields");
2056 let db_name = field.db_name();
2057 let alias = field.alias("avg");
2058 self.ops.push(("AVG", db_name, alias));
2059 self
2060 }
2061
2062 pub fn sum(mut self, field: #aggregate_field_name) -> Self {
2063 assert!(field.is_numeric(), "sum() is only supported on numeric fields");
2064 let db_name = field.db_name();
2065 let alias = field.alias("sum");
2066 self.ops.push(("SUM", db_name, alias));
2067 self
2068 }
2069
2070 pub fn min(mut self, field: #aggregate_field_name) -> Self {
2071 let db_name = field.db_name();
2072 let alias = field.alias("min");
2073 self.ops.push(("MIN", db_name, alias));
2074 self
2075 }
2076
2077 pub fn max(mut self, field: #aggregate_field_name) -> Self {
2078 let db_name = field.db_name();
2079 let alias = field.alias("max");
2080 self.ops.push(("MAX", db_name, alias));
2081 self
2082 }
2083
2084 pub async fn exec(self) -> Result<#aggregate_result_name, FerriormError> {
2085 if self.ops.is_empty() {
2086 return Err(FerriormError::Query("No aggregate operations specified".into()));
2087 }
2088
2089 let selections: Vec<String> = self.ops.iter()
2090 .map(|(func, col, alias)| format!(r#"{}("{}") as "{}""#, func, col, alias))
2091 .collect();
2092 let select_clause = selections.join(", ");
2093 let base_sql = format!(#agg_select_base, select_clause);
2094
2095 match self.client {
2096 DatabaseClient::Postgres(_) => {
2097 let mut qb = sqlx::QueryBuilder::<sqlx::Postgres>::new(&base_sql);
2098 self.r#where.build_where(&mut qb);
2099 self.client.fetch_one_pg(qb).await
2100 }
2101 DatabaseClient::Sqlite(_) => {
2102 let mut qb = sqlx::QueryBuilder::<sqlx::Sqlite>::new(&base_sql);
2103 self.r#where.build_where(&mut qb);
2104 self.client.fetch_one_sqlite(qb).await
2105 }
2106 }
2107 }
2108 }
2109 }
2110}
2111
2112fn gen_having_comparable_arms(field_ident: &proc_macro2::Ident, lhs: &str) -> TokenStream {
2119 let eq = format!(" AND {lhs} = ");
2120 let ne = format!(" AND {lhs} != ");
2121 let gt = format!(" AND {lhs} > ");
2122 let gte = format!(" AND {lhs} >= ");
2123 let lt = format!(" AND {lhs} < ");
2124 let lte = format!(" AND {lhs} <= ");
2125 let in_arms = gen_in_arms_lhs(lhs);
2126 quote! {
2127 if let Some(filter) = &self.#field_ident {
2128 if let Some(v) = &filter.equals { qb.push(#eq); qb.push_bind(v.clone()); }
2129 if let Some(v) = &filter.not { qb.push(#ne); qb.push_bind(v.clone()); }
2130 if let Some(v) = &filter.gt { qb.push(#gt); qb.push_bind(v.clone()); }
2131 if let Some(v) = &filter.gte { qb.push(#gte); qb.push_bind(v.clone()); }
2132 if let Some(v) = &filter.lt { qb.push(#lt); qb.push_bind(v.clone()); }
2133 if let Some(v) = &filter.lte { qb.push(#lte); qb.push_bind(v.clone()); }
2134 #in_arms
2135 }
2136 }
2137}
2138
2139fn is_groupable(field: &Field) -> bool {
2144 match &field.field_type {
2145 FieldKind::Scalar(
2146 ScalarType::String
2147 | ScalarType::Int
2148 | ScalarType::BigInt
2149 | ScalarType::Float
2150 | ScalarType::Boolean
2151 | ScalarType::DateTime,
2152 )
2153 | FieldKind::Enum(_) => true,
2154 FieldKind::Scalar(ScalarType::Json | ScalarType::Bytes | ScalarType::Decimal)
2155 | FieldKind::Model(_) => false,
2156 }
2157}
2158
2159#[allow(clippy::too_many_lines)]
2160fn gen_groupby_types(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
2161 let groupby_field_name = format_ident!("{}GroupByField", model.name);
2162 let groupby_result_name = format_ident!("{}GroupByResult", model.name);
2163 let having_input_name = format_ident!("{}HavingInput", model.name);
2164 let aggregate_field_name = format_ident!("{}AggregateField", model.name);
2165 let where_input = format_ident!("{}WhereInput", model.name);
2166 let table_name = &model.db_name;
2167
2168 let agg_fields: Vec<(&Field, AggregateKind)> = scalar_fields
2171 .iter()
2172 .filter_map(|f| match &f.field_type {
2173 FieldKind::Scalar(ScalarType::Int | ScalarType::BigInt | ScalarType::Float) => {
2174 Some((*f, AggregateKind::Numeric))
2175 }
2176 FieldKind::Scalar(ScalarType::DateTime) => Some((*f, AggregateKind::DateTime)),
2177 _ => None,
2178 })
2179 .collect();
2180
2181 let group_fields: Vec<&Field> = scalar_fields
2182 .iter()
2183 .filter(|f| is_groupable(f))
2184 .copied()
2185 .collect();
2186
2187 if group_fields.is_empty() {
2188 return quote! {};
2189 }
2190
2191 let groupby_variants: Vec<TokenStream> = group_fields
2193 .iter()
2194 .map(|f| {
2195 let variant = format_ident!("{}", to_pascal_case(&f.name));
2196 quote! { #variant }
2197 })
2198 .collect();
2199
2200 let groupby_db_arms: Vec<TokenStream> = group_fields
2201 .iter()
2202 .map(|f| {
2203 let variant = format_ident!("{}", to_pascal_case(&f.name));
2204 let db_name = &f.db_name;
2205 quote! { Self::#variant => #db_name }
2206 })
2207 .collect();
2208
2209 let groupby_alias_arms: Vec<TokenStream> = group_fields
2210 .iter()
2211 .map(|f| {
2212 let variant = format_ident!("{}", to_pascal_case(&f.name));
2213 let alias = to_snake_case(&f.name);
2214 quote! { Self::#variant => #alias }
2215 })
2216 .collect();
2217
2218 let mut result_fields: Vec<TokenStream> = Vec::new();
2223 for f in &group_fields {
2224 let snake = to_snake_case(&f.name);
2225 let name = format_ident!("{}", snake);
2226 let base_ty = rust_type_tokens(
2228 &Field {
2229 is_optional: false,
2230 ..(*f).clone()
2231 },
2232 ModuleDepth::TopLevel,
2233 );
2234 result_fields.push(quote! { #[sqlx(default)] pub #name: Option<#base_ty> });
2235 }
2236 result_fields.push(quote! { #[sqlx(default)] pub count: Option<i64> });
2237 for (f, kind) in &agg_fields {
2238 let snake = to_snake_case(&f.name);
2239 let orig_ty = rust_type_tokens(
2240 &Field {
2241 is_optional: false,
2242 ..(*f).clone()
2243 },
2244 ModuleDepth::TopLevel,
2245 );
2246 match kind {
2247 AggregateKind::Numeric => {
2248 let avg_name = format_ident!("avg_{}", snake);
2249 let sum_name = format_ident!("sum_{}", snake);
2250 let min_name = format_ident!("min_{}", snake);
2251 let max_name = format_ident!("max_{}", snake);
2252 result_fields.push(quote! { #[sqlx(default)] pub #avg_name: Option<f64> });
2253 result_fields.push(quote! { #[sqlx(default)] pub #sum_name: Option<f64> });
2254 result_fields.push(quote! { #[sqlx(default)] pub #min_name: Option<#orig_ty> });
2255 result_fields.push(quote! { #[sqlx(default)] pub #max_name: Option<#orig_ty> });
2256 }
2257 AggregateKind::DateTime => {
2258 let min_name = format_ident!("min_{}", snake);
2259 let max_name = format_ident!("max_{}", snake);
2260 result_fields.push(quote! { #[sqlx(default)] pub #min_name: Option<#orig_ty> });
2261 result_fields.push(quote! { #[sqlx(default)] pub #max_name: Option<#orig_ty> });
2262 }
2263 }
2264 }
2265
2266 let mut having_fields: Vec<TokenStream> = Vec::new();
2271 having_fields.push(quote! { pub count: Option<ferriorm_runtime::filter::BigIntFilter> });
2273 for (f, kind) in &agg_fields {
2274 let snake = to_snake_case(&f.name);
2275 let avg_name = format_ident!("avg_{}", snake);
2276 let sum_name = format_ident!("sum_{}", snake);
2277 let min_name = format_ident!("min_{}", snake);
2278 let max_name = format_ident!("max_{}", snake);
2279 let column_filter = filter_type_tokens(
2280 &Field {
2281 is_optional: false,
2282 ..(*f).clone()
2283 },
2284 ModuleDepth::TopLevel,
2285 )
2286 .unwrap_or_else(|| quote! { ferriorm_runtime::filter::BigIntFilter });
2287 match kind {
2288 AggregateKind::Numeric => {
2289 having_fields
2290 .push(quote! { pub #avg_name: Option<ferriorm_runtime::filter::FloatFilter> });
2291 having_fields
2292 .push(quote! { pub #sum_name: Option<ferriorm_runtime::filter::FloatFilter> });
2293 having_fields.push(quote! { pub #min_name: Option<#column_filter> });
2294 having_fields.push(quote! { pub #max_name: Option<#column_filter> });
2295 }
2296 AggregateKind::DateTime => {
2297 having_fields.push(quote! { pub #min_name: Option<#column_filter> });
2298 having_fields.push(quote! { pub #max_name: Option<#column_filter> });
2299 }
2300 }
2301 }
2302
2303 let mut having_arms: Vec<TokenStream> = Vec::new();
2309 let count_in_arms = gen_in_arms_lhs("COUNT(*)");
2312 having_arms.push(quote! {
2313 if let Some(filter) = &self.count {
2314 if let Some(v) = &filter.equals { qb.push(" AND COUNT(*) = "); qb.push_bind(*v); }
2315 if let Some(v) = &filter.not { qb.push(" AND COUNT(*) != "); qb.push_bind(*v); }
2316 if let Some(v) = &filter.gt { qb.push(" AND COUNT(*) > "); qb.push_bind(*v); }
2317 if let Some(v) = &filter.gte { qb.push(" AND COUNT(*) >= "); qb.push_bind(*v); }
2318 if let Some(v) = &filter.lt { qb.push(" AND COUNT(*) < "); qb.push_bind(*v); }
2319 if let Some(v) = &filter.lte { qb.push(" AND COUNT(*) <= "); qb.push_bind(*v); }
2320 #count_in_arms
2321 }
2322 });
2323
2324 for (f, kind) in &agg_fields {
2325 let snake = to_snake_case(&f.name);
2326 let db_name = &f.db_name;
2327 let avg_ident = format_ident!("avg_{}", snake);
2328 let sum_ident = format_ident!("sum_{}", snake);
2329 let min_ident = format_ident!("min_{}", snake);
2330 let max_ident = format_ident!("max_{}", snake);
2331 match kind {
2332 AggregateKind::Numeric => {
2333 let avg_lhs = format!(r#"AVG("{db_name}")"#);
2334 let sum_lhs = format!(r#"SUM("{db_name}")"#);
2335 let min_lhs = format!(r#"MIN("{db_name}")"#);
2336 let max_lhs = format!(r#"MAX("{db_name}")"#);
2337 having_arms.push(gen_having_comparable_arms(&avg_ident, &avg_lhs));
2338 having_arms.push(gen_having_comparable_arms(&sum_ident, &sum_lhs));
2339 having_arms.push(gen_having_comparable_arms(&min_ident, &min_lhs));
2340 having_arms.push(gen_having_comparable_arms(&max_ident, &max_lhs));
2341 }
2342 AggregateKind::DateTime => {
2343 let min_lhs = format!(r#"MIN("{db_name}")"#);
2344 let max_lhs = format!(r#"MAX("{db_name}")"#);
2345 having_arms.push(gen_having_comparable_arms(&min_ident, &min_lhs));
2346 having_arms.push(gen_having_comparable_arms(&max_ident, &max_lhs));
2347 }
2348 }
2349 }
2350
2351 let mut db_bounds = collect_db_bounds(scalar_fields, ModuleDepth::TopLevel);
2356 if !scalar_fields
2357 .iter()
2358 .any(|f| matches!(&f.field_type, FieldKind::Scalar(ScalarType::Float)))
2359 {
2360 db_bounds.push(quote! { f64: sqlx::Type<DB> + for<'e> sqlx::Encode<'e, DB> });
2361 }
2362
2363 let has_agg_fields = !agg_fields.is_empty();
2368
2369 let agg_methods = if has_agg_fields {
2370 quote! {
2371 pub fn count(mut self) -> Self {
2372 self.count = true;
2373 self
2374 }
2375
2376 pub fn avg(mut self, field: #aggregate_field_name) -> Self {
2377 assert!(field.is_numeric(), "avg() is only supported on numeric fields");
2378 let db_name = field.db_name();
2379 let alias = field.alias("avg");
2380 self.agg_ops.push(("AVG", db_name, alias));
2381 self
2382 }
2383
2384 pub fn sum(mut self, field: #aggregate_field_name) -> Self {
2385 assert!(field.is_numeric(), "sum() is only supported on numeric fields");
2386 let db_name = field.db_name();
2387 let alias = field.alias("sum");
2388 self.agg_ops.push(("SUM", db_name, alias));
2389 self
2390 }
2391
2392 pub fn min(mut self, field: #aggregate_field_name) -> Self {
2393 let db_name = field.db_name();
2394 let alias = field.alias("min");
2395 self.agg_ops.push(("MIN", db_name, alias));
2396 self
2397 }
2398
2399 pub fn max(mut self, field: #aggregate_field_name) -> Self {
2400 let db_name = field.db_name();
2401 let alias = field.alias("max");
2402 self.agg_ops.push(("MAX", db_name, alias));
2403 self
2404 }
2405 }
2406 } else {
2407 quote! {
2408 pub fn count(mut self) -> Self {
2409 self.count = true;
2410 self
2411 }
2412 }
2413 };
2414
2415 quote! {
2416 #[derive(Debug, Clone, Copy)]
2417 pub enum #groupby_field_name {
2418 #(#groupby_variants),*
2419 }
2420
2421 impl #groupby_field_name {
2422 pub fn db_name(&self) -> &'static str {
2423 match self {
2424 #(#groupby_db_arms,)*
2425 }
2426 }
2427
2428 fn alias(&self) -> &'static str {
2429 match self {
2430 #(#groupby_alias_arms,)*
2431 }
2432 }
2433 }
2434
2435 #[derive(Debug, Clone, sqlx::FromRow, Serialize, Deserialize)]
2436 pub struct #groupby_result_name {
2437 #(#result_fields,)*
2438 }
2439
2440 #[derive(Debug, Clone, Default)]
2441 pub struct #having_input_name {
2442 #(#having_fields,)*
2443 pub and: Option<Vec<#having_input_name>>,
2444 pub or: Option<Vec<#having_input_name>>,
2445 pub not: Option<Box<#having_input_name>>,
2446 }
2447
2448 impl #having_input_name {
2449 pub(crate) fn build_having<'args, DB: sqlx::Database>(
2450 &self,
2451 qb: &mut sqlx::QueryBuilder<'args, DB>,
2452 )
2453 where
2454 #(#db_bounds,)*
2455 {
2456 #(#having_arms)*
2457
2458 if let Some(conditions) = &self.and {
2459 for c in conditions {
2460 c.build_having(qb);
2461 }
2462 }
2463 if let Some(conditions) = &self.or {
2464 if !conditions.is_empty() {
2465 qb.push(" AND (");
2466 for (i, c) in conditions.iter().enumerate() {
2467 if i > 0 { qb.push(" OR "); }
2468 qb.push("(1=1");
2469 c.build_having(qb);
2470 qb.push(")");
2471 }
2472 qb.push(")");
2473 }
2474 }
2475 if let Some(c) = &self.not {
2476 qb.push(" AND NOT (1=1");
2477 c.build_having(qb);
2478 qb.push(")");
2479 }
2480 }
2481 }
2482
2483 pub struct GroupByQuery<'a> {
2484 client: &'a DatabaseClient,
2485 r#where: filter::#where_input,
2486 group_keys: Vec<#groupby_field_name>,
2487 agg_ops: Vec<(&'static str, &'static str, &'static str)>,
2488 count: bool,
2489 having: Option<#having_input_name>,
2490 }
2491
2492 impl<'a> GroupByQuery<'a> {
2493 pub fn r#where(mut self, r#where: filter::#where_input) -> Self {
2494 self.r#where = r#where;
2495 self
2496 }
2497
2498 #agg_methods
2499
2500 pub fn having(mut self, having: #having_input_name) -> Self {
2501 self.having = Some(having);
2502 self
2503 }
2504
2505 pub async fn exec(self) -> Result<Vec<#groupby_result_name>, FerriormError> {
2506 if self.group_keys.is_empty() {
2507 return Err(FerriormError::Query(
2508 "group_by() requires at least one group key".into(),
2509 ));
2510 }
2511
2512 let mut selections: Vec<String> = self.group_keys
2513 .iter()
2514 .map(|k| format!(r#""{}" as "{}""#, k.db_name(), k.alias()))
2515 .collect();
2516 if self.count {
2517 selections.push(r#"COUNT(*) as "count""#.to_string());
2518 }
2519 for (func, col, alias) in &self.agg_ops {
2520 selections.push(format!(r#"{}("{}") as "{}""#, func, col, alias));
2521 }
2522
2523 let group_by_clause: Vec<String> = self.group_keys
2524 .iter()
2525 .map(|k| format!(r#""{}""#, k.db_name()))
2526 .collect();
2527
2528 let base_sql = format!(
2529 r#"SELECT {} FROM "{}" WHERE 1=1"#,
2530 selections.join(", "),
2531 #table_name,
2532 );
2533
2534 match self.client {
2535 DatabaseClient::Postgres(_) => {
2536 let mut qb = sqlx::QueryBuilder::<sqlx::Postgres>::new(&base_sql);
2537 self.r#where.build_where(&mut qb);
2538 qb.push(format!(" GROUP BY {}", group_by_clause.join(", ")));
2539 if let Some(h) = &self.having {
2540 qb.push(" HAVING 1=1");
2541 h.build_having(&mut qb);
2542 }
2543 self.client.fetch_all_pg(qb).await
2544 }
2545 DatabaseClient::Sqlite(_) => {
2546 let mut qb = sqlx::QueryBuilder::<sqlx::Sqlite>::new(&base_sql);
2547 self.r#where.build_where(&mut qb);
2548 qb.push(format!(" GROUP BY {}", group_by_clause.join(", ")));
2549 if let Some(h) = &self.having {
2550 qb.push(" HAVING 1=1");
2551 h.build_having(&mut qb);
2552 }
2553 self.client.fetch_all_sqlite(qb).await
2554 }
2555 }
2556 }
2557 }
2558 }
2559}
2560
2561#[allow(clippy::too_many_lines)]
2564fn gen_select_types(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
2565 let select_name = format_ident!("{}Select", model.name);
2566 let partial_name = format_ident!("{}Partial", model.name);
2567 let _where_input = format_ident!("{}WhereInput", model.name);
2568 let _where_unique = format_ident!("{}WhereUniqueInput", model.name);
2569 let order_by_name = format_ident!("{}OrderByInput", model.name);
2570 let table_name = &model.db_name;
2571
2572 let select_fields: Vec<TokenStream> = scalar_fields
2574 .iter()
2575 .map(|f| {
2576 let name = format_ident!("{}", to_snake_case(&f.name));
2577 quote! { pub #name: bool }
2578 })
2579 .collect();
2580
2581 let partial_fields: Vec<TokenStream> = scalar_fields
2584 .iter()
2585 .map(|f| {
2586 let name = format_ident!("{}", to_snake_case(&f.name));
2587 let db_name = &f.db_name;
2588 let base_ty = rust_type_tokens(
2590 &Field {
2591 is_optional: false,
2592 ..(*f).clone()
2593 },
2594 ModuleDepth::TopLevel,
2595 );
2596 let rename = if db_name == &to_snake_case(&f.name) {
2597 quote! {}
2598 } else {
2599 quote! { #[sqlx(rename = #db_name)] }
2600 };
2601 quote! { #[sqlx(default)] #rename pub #name: Option<#base_ty> }
2603 })
2604 .collect();
2605
2606 let select_col_arms: Vec<TokenStream> = scalar_fields
2608 .iter()
2609 .map(|f| {
2610 let name = format_ident!("{}", to_snake_case(&f.name));
2611 let db_name = &f.db_name;
2612 let col_expr = format!(r#""{db_name}""#);
2613 quote! {
2614 if select.#name { cols.push(#col_expr); }
2615 }
2616 })
2617 .collect();
2618
2619 let select_sql_prefix = format!(r#"SELECT {{}} FROM "{table_name}" WHERE 1=1"#);
2620
2621 quote! {
2622 #[derive(Debug, Clone, Default)]
2623 pub struct #select_name {
2624 #(#select_fields,)*
2625 }
2626
2627 #[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
2628 #[sqlx(rename_all = "snake_case")]
2629 pub struct #partial_name {
2630 #(#partial_fields,)*
2631 }
2632
2633 fn build_select_columns(select: &#select_name) -> String {
2634 let mut cols = Vec::new();
2635 #(#select_col_arms)*
2636 if cols.is_empty() {
2637 "*".to_string()
2638 } else {
2639 cols.join(", ")
2640 }
2641 }
2642
2643 pub struct FindManySelectQuery<'a> {
2646 client: &'a DatabaseClient,
2647 r#where: filter::#_where_input,
2648 order_by: Vec<order::#order_by_name>,
2649 skip: Option<i64>,
2650 take: Option<i64>,
2651 select: #select_name,
2652 }
2653
2654 impl<'a> FindManySelectQuery<'a> {
2655 pub fn order_by(mut self, order: order::#order_by_name) -> Self {
2656 self.order_by.push(order);
2657 self
2658 }
2659
2660 pub fn skip(mut self, n: i64) -> Self {
2661 self.skip = Some(n);
2662 self
2663 }
2664
2665 pub fn take(mut self, n: i64) -> Self {
2666 self.take = Some(n);
2667 self
2668 }
2669
2670 pub async fn exec(self) -> Result<Vec<#partial_name>, FerriormError> {
2671 let cols = build_select_columns(&self.select);
2672 let base_sql = format!(#select_sql_prefix, cols);
2673
2674 match self.client {
2675 DatabaseClient::Postgres(_) => {
2676 let qb = build_select_query::<sqlx::Postgres>(
2677 &base_sql, &self.r#where, &self.order_by, self.take, self.skip,
2678 );
2679 self.client.fetch_all_pg(qb).await
2680 }
2681 DatabaseClient::Sqlite(_) => {
2682 let qb = build_select_query::<sqlx::Sqlite>(
2683 &base_sql, &self.r#where, &self.order_by, self.take, self.skip,
2684 );
2685 self.client.fetch_all_sqlite(qb).await
2686 }
2687 }
2688 }
2689 }
2690
2691 pub struct FindUniqueSelectQuery<'a> {
2694 client: &'a DatabaseClient,
2695 r#where: filter::#_where_unique,
2696 select: #select_name,
2697 }
2698
2699 impl<'a> FindUniqueSelectQuery<'a> {
2700 pub async fn exec(self) -> Result<Option<#partial_name>, FerriormError> {
2701 let cols = build_select_columns(&self.select);
2702 let base_sql = format!(#select_sql_prefix, cols);
2703
2704 match self.client {
2705 DatabaseClient::Postgres(_) => {
2706 let qb = build_unique_select_query::<sqlx::Postgres>(
2707 &base_sql, &self.r#where,
2708 );
2709 self.client.fetch_optional_pg(qb).await
2710 }
2711 DatabaseClient::Sqlite(_) => {
2712 let qb = build_unique_select_query::<sqlx::Sqlite>(
2713 &base_sql, &self.r#where,
2714 );
2715 self.client.fetch_optional_sqlite(qb).await
2716 }
2717 }
2718 }
2719 }
2720
2721 pub struct FindFirstSelectQuery<'a> {
2724 client: &'a DatabaseClient,
2725 r#where: filter::#_where_input,
2726 order_by: Vec<order::#order_by_name>,
2727 select: #select_name,
2728 }
2729
2730 impl<'a> FindFirstSelectQuery<'a> {
2731 pub fn order_by(mut self, order: order::#order_by_name) -> Self {
2732 self.order_by.push(order);
2733 self
2734 }
2735
2736 pub async fn exec(self) -> Result<Option<#partial_name>, FerriormError> {
2737 let cols = build_select_columns(&self.select);
2738 let base_sql = format!(#select_sql_prefix, cols);
2739
2740 match self.client {
2741 DatabaseClient::Postgres(_) => {
2742 let qb = build_select_query::<sqlx::Postgres>(
2743 &base_sql, &self.r#where, &self.order_by, Some(1), None,
2744 );
2745 self.client.fetch_optional_pg(qb).await
2746 }
2747 DatabaseClient::Sqlite(_) => {
2748 let qb = build_select_query::<sqlx::Sqlite>(
2749 &base_sql, &self.r#where, &self.order_by, Some(1), None,
2750 );
2751 self.client.fetch_optional_sqlite(qb).await
2752 }
2753 }
2754 }
2755 }
2756 }
2757}