1use ferriorm_core::schema::{Field, FieldKind, Model};
15use ferriorm_core::types::ScalarType;
16use ferriorm_core::utils::{to_pascal_case, to_snake_case};
17use proc_macro2::TokenStream;
18use quote::{format_ident, quote};
19
20use crate::rust_type::{ModuleDepth, filter_type_tokens, rust_type_tokens};
21
22#[must_use]
24pub fn generate_model_module(model: &Model) -> TokenStream {
25 let scalar_fields: Vec<&Field> = model.fields.iter().filter(|f| f.is_scalar()).collect();
26
27 let data_struct = gen_data_struct(model, &scalar_fields);
28 let filter_module = gen_filter_module(model, &scalar_fields);
29 let data_module = gen_data_module(model, &scalar_fields);
30 let order_module = gen_order_module(model, &scalar_fields);
31 let actions_struct = gen_actions(model, &scalar_fields);
32 let query_builders = gen_query_builders(model, &scalar_fields);
33 let aggregate_types = gen_aggregate_types(model, &scalar_fields);
34 let groupby_types = gen_groupby_types(model, &scalar_fields);
35 let select_types = gen_select_types(model, &scalar_fields);
36
37 quote! {
38 #![allow(unused_imports, dead_code, unused_variables, clippy::all, clippy::pedantic, clippy::nursery)]
39
40 use serde::{Deserialize, Serialize};
41 use ferriorm_runtime::prelude::*;
42 use ferriorm_runtime::prelude::sqlx;
43 use ferriorm_runtime::prelude::chrono;
44 use ferriorm_runtime::prelude::uuid;
45
46 #data_struct
47 #filter_module
48 #data_module
49 #order_module
50 #actions_struct
51 #query_builders
52 #aggregate_types
53 #groupby_types
54 #select_types
55 }
56}
57
58fn gen_data_struct(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
61 let struct_name = format_ident!("{}", model.name);
62 let table_name = &model.db_name;
63
64 let fields: Vec<TokenStream> = scalar_fields
65 .iter()
66 .map(|f| {
67 let name = format_ident!("{}", to_snake_case(&f.name));
68 let ty = rust_type_tokens(f, ModuleDepth::TopLevel);
69 let db_name = &f.db_name;
70 if db_name == &to_snake_case(&f.name) {
71 quote! { pub #name: #ty }
72 } else {
73 quote! { #[sqlx(rename = #db_name)] pub #name: #ty }
74 }
75 })
76 .collect();
77
78 quote! {
79 #[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
80 #[sqlx(rename_all = "snake_case")]
81 pub struct #struct_name {
82 #(#fields),*
83 }
84
85 impl #struct_name {
86 pub const TABLE_NAME: &'static str = #table_name;
87 }
88 }
89}
90
/// Generates the model's `pub mod filter`, which holds the query-predicate types.
///
/// The emitted module defines:
/// - `<Model>WhereInput`: a `Default` struct with one `Option<filter type>`
///   field per scalar field for which `filter_type_tokens` returns a type, plus
///   the `and` / `or` / `not` combinators;
/// - `<Model>WhereUniqueInput`: an enum with one tuple variant per `is_id` /
///   `is_unique` scalar field, followed by one struct variant per entry in
///   `model.unique_constraints`;
/// - `build_where` on both types, which appends ` AND ...` SQL fragments and
///   bound values to a `sqlx::QueryBuilder` (arms come from `gen_where_arms` /
///   `gen_unique_where_arms`);
/// - `conflict_target` / `first_conflict_col` on the unique enum, returning
///   double-quoted column names as `&'static str`.
///
/// NOTE(review): `where_fields` skips fields for which `filter_type_tokens`
/// returns `None`, but `gen_where_arms` emits an arm reading `self.<field>` for
/// every `FieldKind::Scalar` field. This only compiles if the two agree for every
/// scalar type — confirm.
/// NOTE(review): if a model has no id/unique fields and no compound constraints,
/// `<Model>WhereUniqueInput` is an empty enum. The generated `match self {}` on
/// `&Self` would then likely fail to compile, so presumably the schema validator
/// guarantees an id — confirm.
#[allow(clippy::too_many_lines)]
fn gen_filter_module(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
    let where_input = format_ident!("{}WhereInput", model.name);
    let where_unique = format_ident!("{}WhereUniqueInput", model.name);

    // One optional filter field per scalar field that has a filter type.
    let where_fields: Vec<TokenStream> = scalar_fields
        .iter()
        .filter_map(|f| {
            let filter_ty = filter_type_tokens(f, ModuleDepth::Nested)?;
            let name = format_ident!("{}", to_snake_case(&f.name));
            Some(quote! { pub #name: Option<#filter_ty> })
        })
        .collect();

    // `Variant(T)` for each single-column id / unique field.
    let single_unique_variants: Vec<TokenStream> = scalar_fields
        .iter()
        .filter(|f| f.is_id || f.is_unique)
        .map(|f| {
            let variant = format_ident!("{}", to_pascal_case(&f.name));
            let ty = rust_type_tokens(f, ModuleDepth::Nested);
            quote! { #variant(#ty) }
        })
        .collect();

    // `FieldAFieldB { field_a: A, field_b: B }` for each compound constraint.
    let compound_unique_variants: Vec<TokenStream> = model
        .unique_constraints
        .iter()
        .map(|uc| {
            let variant = format_ident!("{}", compound_variant_name(&uc.fields));
            let struct_fields = compound_variant_fields(model, &uc.fields);
            quote! { #variant { #(#struct_fields),* } }
        })
        .collect();

    let unique_variants: Vec<TokenStream> = single_unique_variants
        .into_iter()
        .chain(compound_unique_variants)
        .collect();

    // Where-clause predicates so every bound value can be encoded for `DB`.
    let db_bounds = collect_db_bounds(scalar_fields);
    let where_arms = gen_where_arms(scalar_fields);
    let unique_arms = gen_unique_where_arms(model, scalar_fields);
    let conflict_target_arms = gen_conflict_target_arms(model, scalar_fields);
    let first_conflict_col_arms = gen_first_conflict_col_arms(model, scalar_fields);

    quote! {
        pub mod filter {
            use ferriorm_runtime::prelude::*;

            #[derive(Debug, Clone, Default)]
            pub struct #where_input {
                #(#where_fields,)*
                pub and: Option<Vec<#where_input>>,
                pub or: Option<Vec<#where_input>>,
                pub not: Option<Box<#where_input>>,
            }

            #[derive(Debug, Clone)]
            pub enum #where_unique {
                #(#unique_variants),*
            }

            impl #where_input {
                pub(crate) fn build_where<'args, DB: sqlx::Database>(
                    &self,
                    qb: &mut sqlx::QueryBuilder<'args, DB>,
                )
                where
                    #(#db_bounds,)*
                {
                    #(#where_arms)*

                    // Every nested condition only appends ` AND ...`
                    // fragments, so `and` simply inlines them.
                    if let Some(conditions) = &self.and {
                        for c in conditions {
                            c.build_where(qb);
                        }
                    }
                    // Each alternative is wrapped as `(1=1 AND ...)`; the
                    // always-true `1=1` keeps the group valid even when the
                    // nested input contributes no fragments.
                    if let Some(conditions) = &self.or {
                        if !conditions.is_empty() {
                            qb.push(" AND (");
                            for (i, c) in conditions.iter().enumerate() {
                                if i > 0 { qb.push(" OR "); }
                                qb.push("(1=1");
                                c.build_where(qb);
                                qb.push(")");
                            }
                            qb.push(")");
                        }
                    }
                    // Same `1=1` anchor trick for the negated group.
                    if let Some(c) = &self.not {
                        qb.push(" AND NOT (1=1");
                        c.build_where(qb);
                        qb.push(")");
                    }
                }
            }

            impl #where_unique {
                pub(crate) fn build_where<'args, DB: sqlx::Database>(
                    &self,
                    qb: &mut sqlx::QueryBuilder<'args, DB>,
                )
                where
                    #(#db_bounds,)*
                {
                    match self {
                        #(#unique_arms)*
                    }
                }
            }

            impl #where_unique {
                #[allow(dead_code)]
                pub(crate) fn conflict_target(&self) -> &'static str {
                    match self {
                        #(#conflict_target_arms)*
                    }
                }

                #[allow(dead_code)]
                pub(crate) fn first_conflict_col(&self) -> &'static str {
                    match self {
                        #(#first_conflict_col_arms)*
                    }
                }
            }
        }
    }
}
223
224fn collect_db_bounds(scalar_fields: &[&Field]) -> Vec<TokenStream> {
226 let mut seen = std::collections::HashSet::new();
227 let mut bounds = Vec::new();
228
229 seen.insert("i64");
231 bounds.push(quote! { i64: sqlx::Type<DB> + for<'e> sqlx::Encode<'e, DB> });
232
233 for f in scalar_fields {
234 match &f.field_type {
235 FieldKind::Scalar(scalar) => {
236 let key = scalar.rust_type();
237 if seen.insert(key)
238 && let Some(ty) = scalar_bound_tokens(scalar)
239 {
240 bounds.push(quote! { #ty: sqlx::Type<DB> + for<'e> sqlx::Encode<'e, DB> });
241 bounds.push(
243 quote! { Option<#ty>: sqlx::Type<DB> + for<'e> sqlx::Encode<'e, DB> },
244 );
245 }
246 }
247 FieldKind::Enum(_) | FieldKind::Model(_) => {}
248 }
249 }
250
251 bounds
252}
253
254fn scalar_bound_tokens(scalar: &ScalarType) -> Option<TokenStream> {
255 match scalar {
256 ScalarType::String => Some(quote! { String }),
257 ScalarType::Int => Some(quote! { i32 }),
258 ScalarType::BigInt => Some(quote! { i64 }),
259 ScalarType::Float => Some(quote! { f64 }),
260 ScalarType::Boolean => Some(quote! { bool }),
261 ScalarType::DateTime => Some(quote! { chrono::DateTime<chrono::Utc> }),
262 ScalarType::Bytes => Some(quote! { Vec<u8> }),
263 ScalarType::Json | ScalarType::Decimal => None,
264 }
265}
266
/// Builds the per-field statements spliced into `<Model>WhereInput::build_where`.
///
/// For each `FieldKind::Scalar` field this emits
/// `if let Some(filter) = &self.<field> { ... }`, which appends ` AND "col" ...`
/// fragments and binds values:
/// - `equals` / `not` for every scalar field. For optional fields these are
///   nested options, where an inner `None` becomes `IS NULL` / `IS NOT NULL`;
/// - `contains` / `starts_with` / `ends_with` (`LIKE` with `ESCAPE '\'`) for
///   `String`;
/// - `gt` / `gte` / `lt` / `lte` for `Int`, `BigInt`, `Float` and `DateTime`.
///
/// NOTE(review): enum and relation fields are skipped (the early `return None`).
/// If `filter_type_tokens` gives enum fields a filter in `<Model>WhereInput`,
/// that filter is never applied to the SQL — confirm.
/// Note that `not: Some(x)` uses `!=`. For a nullable column, rows whose value is
/// NULL are therefore excluded as well, because `NULL != x` is not true in SQL.
fn gen_where_arms(scalar_fields: &[&Field]) -> Vec<TokenStream> {
    scalar_fields
        .iter()
        .filter_map(|f| {
            // Only scalar columns produce SQL predicates here.
            if !matches!(&f.field_type, FieldKind::Scalar(_)) {
                return None;
            }
            let field_ident = format_ident!("{}", to_snake_case(&f.name));
            let db_name = &f.db_name;
            let is_string = matches!(&f.field_type, FieldKind::Scalar(ScalarType::String));
            let is_comparable = matches!(
                &f.field_type,
                FieldKind::Scalar(
                    ScalarType::Int | ScalarType::BigInt | ScalarType::Float | ScalarType::DateTime
                )
            );

            let mut arms = vec![];

            if f.is_optional {
                // Nullable column: `equals` / `not` hold `Option<T>`, where
                // `None` means "compare against NULL".
                arms.push(quote! {
                    if let Some(v) = &filter.equals {
                        match v {
                            None => {
                                qb.push(concat!(" AND \"", #db_name, "\" IS NULL"));
                            }
                            Some(inner) => {
                                qb.push(concat!(" AND \"", #db_name, "\" = "));
                                qb.push_bind(inner.clone());
                            }
                        }
                    }
                    if let Some(v) = &filter.not {
                        match v {
                            None => {
                                qb.push(concat!(" AND \"", #db_name, "\" IS NOT NULL"));
                            }
                            Some(inner) => {
                                qb.push(concat!(" AND \"", #db_name, "\" != "));
                                qb.push_bind(inner.clone());
                            }
                        }
                    }
                });
            } else {
                arms.push(quote! {
                    if let Some(v) = &filter.equals {
                        qb.push(concat!(" AND \"", #db_name, "\" = "));
                        qb.push_bind(v.clone());
                    }
                    if let Some(v) = &filter.not {
                        qb.push(concat!(" AND \"", #db_name, "\" != "));
                        qb.push_bind(v.clone());
                    }
                });
            }

            if is_string {
                // The user's text goes through `like_escape` (presumably escaping
                // LIKE wildcards — confirm) before being wrapped in `%`. The
                // explicit `ESCAPE '\'` makes backslash the escape character on
                // both backends.
                arms.push(quote! {
                    if let Some(v) = &filter.contains {
                        qb.push(concat!(" AND \"", #db_name, "\" LIKE "));
                        qb.push_bind(format!("%{}%", ferriorm_runtime::filter::like_escape(v)));
                        qb.push(" ESCAPE '\\'");
                    }
                    if let Some(v) = &filter.starts_with {
                        qb.push(concat!(" AND \"", #db_name, "\" LIKE "));
                        qb.push_bind(format!("{}%", ferriorm_runtime::filter::like_escape(v)));
                        qb.push(" ESCAPE '\\'");
                    }
                    if let Some(v) = &filter.ends_with {
                        qb.push(concat!(" AND \"", #db_name, "\" LIKE "));
                        qb.push_bind(format!("%{}", ferriorm_runtime::filter::like_escape(v)));
                        qb.push(" ESCAPE '\\'");
                    }
                });
            }

            if is_comparable {
                arms.push(quote! {
                    if let Some(v) = &filter.gt {
                        qb.push(concat!(" AND \"", #db_name, "\" > "));
                        qb.push_bind(v.clone());
                    }
                    if let Some(v) = &filter.gte {
                        qb.push(concat!(" AND \"", #db_name, "\" >= "));
                        qb.push_bind(v.clone());
                    }
                    if let Some(v) = &filter.lt {
                        qb.push(concat!(" AND \"", #db_name, "\" < "));
                        qb.push_bind(v.clone());
                    }
                    if let Some(v) = &filter.lte {
                        qb.push(concat!(" AND \"", #db_name, "\" <= "));
                        qb.push_bind(v.clone());
                    }
                });
            }

            Some(quote! {
                if let Some(filter) = &self.#field_ident {
                    #(#arms)*
                }
            })
        })
        .collect()
}
383
384fn gen_unique_where_arms(model: &Model, scalar_fields: &[&Field]) -> Vec<TokenStream> {
385 let mut arms: Vec<TokenStream> = scalar_fields
386 .iter()
387 .filter(|f| f.is_id || f.is_unique)
388 .map(|f| {
389 let variant = format_ident!("{}", to_pascal_case(&f.name));
390 let db_name = &f.db_name;
391 quote! {
392 Self::#variant(v) => {
393 qb.push(concat!(" AND \"", #db_name, "\" = "));
394 qb.push_bind(v.clone());
395 }
396 }
397 })
398 .collect();
399
400 for uc in &model.unique_constraints {
401 let variant = format_ident!("{}", compound_variant_name(&uc.fields));
402 let idents: Vec<_> = uc
403 .fields
404 .iter()
405 .map(|name| format_ident!("{}", to_snake_case(name)))
406 .collect();
407 let binds: Vec<TokenStream> = uc
408 .fields
409 .iter()
410 .map(|name| {
411 let ident = format_ident!("{}", to_snake_case(name));
412 let db_name = resolve_db_name(model, name);
413 quote! {
414 qb.push(concat!(" AND \"", #db_name, "\" = "));
415 qb.push_bind(#ident.clone());
416 }
417 })
418 .collect();
419 arms.push(quote! {
420 Self::#variant { #(#idents),* } => {
421 #(#binds)*
422 }
423 });
424 }
425
426 arms
427}
428
429fn gen_conflict_target_arms(model: &Model, scalar_fields: &[&Field]) -> Vec<TokenStream> {
430 let mut arms: Vec<TokenStream> = scalar_fields
431 .iter()
432 .filter(|f| f.is_id || f.is_unique)
433 .map(|f| {
434 let variant = format_ident!("{}", to_pascal_case(&f.name));
435 let target = format!("(\"{}\")", f.db_name);
436 quote! { Self::#variant(_) => #target, }
437 })
438 .collect();
439
440 for uc in &model.unique_constraints {
441 let variant = format_ident!("{}", compound_variant_name(&uc.fields));
442 let cols: Vec<String> = uc
443 .fields
444 .iter()
445 .map(|n| format!("\"{}\"", resolve_db_name(model, n)))
446 .collect();
447 let target = format!("({})", cols.join(", "));
448 arms.push(quote! { Self::#variant { .. } => #target, });
449 }
450
451 arms
452}
453
454fn gen_first_conflict_col_arms(model: &Model, scalar_fields: &[&Field]) -> Vec<TokenStream> {
455 let mut arms: Vec<TokenStream> = scalar_fields
456 .iter()
457 .filter(|f| f.is_id || f.is_unique)
458 .map(|f| {
459 let variant = format_ident!("{}", to_pascal_case(&f.name));
460 let col = format!("\"{}\"", f.db_name);
461 quote! { Self::#variant(_) => #col, }
462 })
463 .collect();
464
465 for uc in &model.unique_constraints {
466 let variant = format_ident!("{}", compound_variant_name(&uc.fields));
467 let first = uc
468 .fields
469 .first()
470 .map_or_else(String::new, |n| resolve_db_name(model, n));
471 let col = format!("\"{first}\"");
472 arms.push(quote! { Self::#variant { .. } => #col, });
473 }
474
475 arms
476}
477
478fn compound_variant_name(fields: &[String]) -> String {
480 fields.iter().map(|f| to_pascal_case(f)).collect()
481}
482
483fn compound_variant_fields(model: &Model, fields: &[String]) -> Vec<TokenStream> {
486 fields
487 .iter()
488 .filter_map(|field_name| {
489 let field = model.fields.iter().find(|f| f.name == *field_name)?;
490 let ident = format_ident!("{}", to_snake_case(field_name));
491 let ty = rust_type_tokens(field, ModuleDepth::Nested);
492 Some(quote! { #ident: #ty })
493 })
494 .collect()
495}
496
497fn resolve_db_name(model: &Model, field_name: &str) -> String {
499 model
500 .fields
501 .iter()
502 .find(|f| f.name == field_name)
503 .map_or_else(|| to_snake_case(field_name), |f| f.db_name.clone())
504}
505
506fn gen_data_module(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
509 let create_name = format_ident!("{}CreateInput", model.name);
510 let update_name = format_ident!("{}UpdateInput", model.name);
511
512 let required_fields: Vec<TokenStream> = scalar_fields
513 .iter()
514 .filter(|f| !f.has_default() && !f.is_updated_at)
515 .map(|f| {
516 let name = format_ident!("{}", to_snake_case(&f.name));
517 let ty = rust_type_tokens(f, ModuleDepth::Nested);
518 quote! { pub #name: #ty }
519 })
520 .collect();
521
522 let optional_fields: Vec<TokenStream> = scalar_fields
523 .iter()
524 .filter(|f| f.has_default() && !f.is_updated_at)
525 .map(|f| {
526 let name = format_ident!("{}", to_snake_case(&f.name));
527 let base_ty = rust_type_tokens(f, ModuleDepth::Nested);
528 quote! { pub #name: Option<#base_ty> }
529 })
530 .collect();
531
532 let update_fields: Vec<TokenStream> = scalar_fields
533 .iter()
534 .filter(|f| !f.is_id && !f.is_updated_at)
535 .map(|f| {
536 let name = format_ident!("{}", to_snake_case(&f.name));
537 let ty = rust_type_tokens(f, ModuleDepth::Nested);
538 quote! { pub #name: Option<SetValue<#ty>> }
539 })
540 .collect();
541
542 quote! {
543 pub mod data {
544 use ferriorm_runtime::prelude::*;
545
546 #[derive(Debug, Clone)]
547 pub struct #create_name {
548 #(#required_fields,)*
549 #(#optional_fields,)*
550 }
551
552 #[derive(Debug, Clone, Default)]
556 pub struct #update_name {
557 #(#update_fields,)*
558 }
559 }
560 }
561}
562
563fn gen_order_module(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
566 let order_name = format_ident!("{}OrderByInput", model.name);
567
568 let variants: Vec<TokenStream> = scalar_fields
569 .iter()
570 .map(|f| {
571 let variant = format_ident!("{}", to_pascal_case(&f.name));
572 quote! { #variant(SortOrder) }
573 })
574 .collect();
575
576 let order_arms: Vec<TokenStream> = scalar_fields
577 .iter()
578 .map(|f| {
579 let variant = format_ident!("{}", to_pascal_case(&f.name));
580 let db_name = &f.db_name;
581 quote! {
582 Self::#variant(order) => {
583 qb.push(concat!("\"", #db_name, "\" "));
584 qb.push(order.as_sql());
585 }
586 }
587 })
588 .collect();
589
590 quote! {
591 pub mod order {
592 use ferriorm_runtime::prelude::*;
593
594 #[derive(Debug, Clone)]
595 pub enum #order_name {
596 #(#variants),*
597 }
598
599 impl #order_name {
600 pub(crate) fn build_order_by<'args, DB: sqlx::Database>(
601 &self,
602 qb: &mut sqlx::QueryBuilder<'args, DB>,
603 ) {
604 match self {
605 #(#order_arms)*
606 }
607 }
608 }
609 }
610 }
611}
612
613fn gen_actions(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
616 let _model_ident = format_ident!("{}", model.name);
617 let actions_name = format_ident!("{}Actions", model.name);
618 let where_input = format_ident!("{}WhereInput", model.name);
619 let where_unique = format_ident!("{}WhereUniqueInput", model.name);
620 let create_input = format_ident!("{}CreateInput", model.name);
621 let update_input = format_ident!("{}UpdateInput", model.name);
622 let _order_by = format_ident!("{}OrderByInput", model.name);
623
624 let has_agg_fields = scalar_fields.iter().any(|f| {
626 matches!(
627 &f.field_type,
628 FieldKind::Scalar(
629 ScalarType::Int | ScalarType::BigInt | ScalarType::Float | ScalarType::DateTime
630 )
631 )
632 });
633 let aggregate_method = if has_agg_fields {
634 quote! {
635 pub fn aggregate(&self, r#where: filter::#where_input) -> AggregateQuery<'a> {
636 AggregateQuery { client: self.client, r#where, ops: vec![] }
637 }
638 }
639 } else {
640 quote! {}
641 };
642
643 let has_group_fields = scalar_fields.iter().any(|f| is_groupable(f));
645 let groupby_field_name = format_ident!("{}GroupByField", model.name);
646 let group_by_method = if has_group_fields {
647 quote! {
648 pub fn group_by(&self, keys: Vec<#groupby_field_name>) -> GroupByQuery<'a> {
649 GroupByQuery {
650 client: self.client,
651 r#where: filter::#where_input::default(),
652 group_keys: keys,
653 agg_ops: vec![],
654 count: false,
655 having: None,
656 }
657 }
658 }
659 } else {
660 quote! {}
661 };
662
663 quote! {
664 pub struct #actions_name<'a> {
665 client: &'a DatabaseClient,
666 }
667
668 impl<'a> #actions_name<'a> {
669 pub fn new(client: &'a DatabaseClient) -> Self { Self { client } }
670
671 pub fn find_unique(&self, r#where: filter::#where_unique) -> FindUniqueQuery<'a> {
672 FindUniqueQuery { client: self.client, r#where }
673 }
674
675 pub fn find_first(&self, r#where: filter::#where_input) -> FindFirstQuery<'a> {
676 FindFirstQuery { client: self.client, r#where, order_by: vec![] }
677 }
678
679 pub fn find_many(&self, r#where: filter::#where_input) -> FindManyQuery<'a> {
680 FindManyQuery { client: self.client, r#where, order_by: vec![], skip: None, take: None }
681 }
682
683 pub fn create(&self, data: data::#create_input) -> CreateQuery<'a> {
684 CreateQuery { client: self.client, data }
685 }
686
687 pub fn update(&self, r#where: filter::#where_unique, data: data::#update_input) -> UpdateQuery<'a> {
688 UpdateQuery { client: self.client, r#where, data }
689 }
690
691 pub fn update_first(&self, r#where: filter::#where_input, data: data::#update_input) -> UpdateFirstQuery<'a> {
695 UpdateFirstQuery { client: self.client, r#where, data }
696 }
697
698 pub fn delete(&self, r#where: filter::#where_unique) -> DeleteQuery<'a> {
699 DeleteQuery { client: self.client, r#where }
700 }
701
702 pub fn count(&self, r#where: filter::#where_input) -> CountQuery<'a> {
703 CountQuery { client: self.client, r#where }
704 }
705
706 pub fn create_many(&self, data: Vec<data::#create_input>) -> CreateManyQuery<'a> {
707 CreateManyQuery { client: self.client, data }
708 }
709
710 pub fn update_many(&self, r#where: filter::#where_input, data: data::#update_input) -> UpdateManyQuery<'a> {
711 UpdateManyQuery { client: self.client, r#where, data }
712 }
713
714 pub fn delete_many(&self, r#where: filter::#where_input) -> DeleteManyQuery<'a> {
715 DeleteManyQuery { client: self.client, r#where }
716 }
717
718 pub fn upsert(
719 &self,
720 r#where: filter::#where_unique,
721 create: data::#create_input,
722 update: data::#update_input,
723 ) -> UpsertQuery<'a> {
724 UpsertQuery { client: self.client, r#where, create, update }
725 }
726
727 #aggregate_method
728
729 #group_by_method
730 }
731 }
732}
733
734#[allow(clippy::too_many_lines)]
737fn gen_query_builders(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
738 let model_ident = format_ident!("{}", model.name);
739 let table_name = &model.db_name;
740 let _where_input = format_ident!("{}WhereInput", model.name);
741 let _where_unique = format_ident!("{}WhereUniqueInput", model.name);
742 let _create_input = format_ident!("{}CreateInput", model.name);
743 let _update_input = format_ident!("{}UpdateInput", model.name);
744 let order_by = format_ident!("{}OrderByInput", model.name);
745 let _select_struct = format_ident!("{}Select", model.name);
746 let _partial_struct = format_ident!("{}Partial", model.name);
747 let _aggregate_result = format_ident!("{}AggregateResult", model.name);
748 let _aggregate_field = format_ident!("{}AggregateField", model.name);
749 let db_bounds = collect_db_bounds(scalar_fields);
750
751 let select_sql = format!(r#"SELECT * FROM "{table_name}" WHERE 1=1"#);
752 let count_sql = format!(r#"SELECT COUNT(*) as "count" FROM "{table_name}" WHERE 1=1"#);
753 let delete_sql = format!(r#"DELETE FROM "{table_name}" WHERE 1=1"#);
754
755 let insert_code = gen_insert_code(model, scalar_fields, table_name);
756 let insert_ignore_code = gen_insert_ignore_code(model, scalar_fields, table_name);
757 let update_code = gen_update_code(model, scalar_fields, table_name);
758 let update_first_code = gen_update_first_code(model, scalar_fields, table_name);
759 let update_many_code = gen_update_many_code(model, scalar_fields, table_name);
760 let upsert_code = gen_upsert_code(model, scalar_fields, table_name);
761
762 quote! {
763 fn build_order_by<'args, DB: sqlx::Database>(
765 orders: &[order::#order_by],
766 qb: &mut sqlx::QueryBuilder<'args, DB>,
767 ) {
768 if !orders.is_empty() {
769 qb.push(" ORDER BY ");
770 for (i, ob) in orders.iter().enumerate() {
771 if i > 0 { qb.push(", "); }
772 ob.build_order_by(qb);
773 }
774 }
775 }
776
777 fn build_select_query<'args, DB: sqlx::Database>(
779 base_sql: &str,
780 where_input: &filter::#_where_input,
781 orders: &[order::#order_by],
782 take: Option<i64>,
783 skip: Option<i64>,
784 ) -> sqlx::QueryBuilder<'args, DB>
785 where
786 #(#db_bounds,)*
787 {
788 let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
789 where_input.build_where(&mut qb);
790 build_order_by(orders, &mut qb);
791 if let Some(take) = take {
792 qb.push(" LIMIT ");
793 qb.push_bind(take);
794 }
795 if let Some(skip) = skip {
796 qb.push(" OFFSET ");
797 qb.push_bind(skip);
798 }
799 qb
800 }
801
802 fn build_unique_select_query<'args, DB: sqlx::Database>(
804 base_sql: &str,
805 where_unique: &filter::#_where_unique,
806 ) -> sqlx::QueryBuilder<'args, DB>
807 where
808 #(#db_bounds,)*
809 {
810 let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
811 where_unique.build_where(&mut qb);
812 qb.push(" LIMIT 1");
813 qb
814 }
815
816 fn build_delete_query<'args, DB: sqlx::Database>(
818 base_sql: &str,
819 where_unique: &filter::#_where_unique,
820 ) -> sqlx::QueryBuilder<'args, DB>
821 where
822 #(#db_bounds,)*
823 {
824 let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
825 where_unique.build_where(&mut qb);
826 qb.push(" RETURNING *");
827 qb
828 }
829
830 fn build_count_query<'args, DB: sqlx::Database>(
832 base_sql: &str,
833 where_input: &filter::#_where_input,
834 ) -> sqlx::QueryBuilder<'args, DB>
835 where
836 #(#db_bounds,)*
837 {
838 let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
839 where_input.build_where(&mut qb);
840 qb
841 }
842
843 fn build_delete_many_query<'args, DB: sqlx::Database>(
845 base_sql: &str,
846 where_input: &filter::#_where_input,
847 ) -> sqlx::QueryBuilder<'args, DB>
848 where
849 #(#db_bounds,)*
850 {
851 let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
852 where_input.build_where(&mut qb);
853 qb
854 }
855
856 pub struct FindUniqueQuery<'a> {
857 client: &'a DatabaseClient,
858 r#where: filter::#_where_unique,
859 }
860
861 impl<'a> FindUniqueQuery<'a> {
862 pub fn select(self, select: #_select_struct) -> FindUniqueSelectQuery<'a> {
863 FindUniqueSelectQuery { client: self.client, r#where: self.r#where, select }
864 }
865
866 pub async fn exec(self) -> Result<Option<#model_ident>, FerriormError> {
867 match self.client {
868 DatabaseClient::Postgres(_) => {
869 let qb = build_unique_select_query::<sqlx::Postgres>(#select_sql, &self.r#where);
870 self.client.fetch_optional_pg(qb).await
871 }
872 DatabaseClient::Sqlite(_) => {
873 let qb = build_unique_select_query::<sqlx::Sqlite>(#select_sql, &self.r#where);
874 self.client.fetch_optional_sqlite(qb).await
875 }
876 }
877 }
878 }
879
880 pub struct FindFirstQuery<'a> {
881 client: &'a DatabaseClient,
882 r#where: filter::#_where_input,
883 order_by: Vec<order::#order_by>,
884 }
885
886 impl<'a> FindFirstQuery<'a> {
887 pub fn order_by(mut self, order: order::#order_by) -> Self {
888 self.order_by.push(order);
889 self
890 }
891
892 pub fn select(self, select: #_select_struct) -> FindFirstSelectQuery<'a> {
893 FindFirstSelectQuery { client: self.client, r#where: self.r#where, order_by: self.order_by, select }
894 }
895
896 pub async fn exec(self) -> Result<Option<#model_ident>, FerriormError> {
897 match self.client {
898 DatabaseClient::Postgres(_) => {
899 let qb = build_select_query::<sqlx::Postgres>(#select_sql, &self.r#where, &self.order_by, Some(1), None);
900 self.client.fetch_optional_pg(qb).await
901 }
902 DatabaseClient::Sqlite(_) => {
903 let qb = build_select_query::<sqlx::Sqlite>(#select_sql, &self.r#where, &self.order_by, Some(1), None);
904 self.client.fetch_optional_sqlite(qb).await
905 }
906 }
907 }
908 }
909
910 pub struct FindManyQuery<'a> {
911 client: &'a DatabaseClient,
912 r#where: filter::#_where_input,
913 order_by: Vec<order::#order_by>,
914 skip: Option<i64>,
915 take: Option<i64>,
916 }
917
918 impl<'a> FindManyQuery<'a> {
919 pub fn order_by(mut self, order: order::#order_by) -> Self {
920 self.order_by.push(order);
921 self
922 }
923
924 pub fn skip(mut self, n: i64) -> Self {
925 self.skip = Some(n);
926 self
927 }
928
929 pub fn take(mut self, n: i64) -> Self {
930 self.take = Some(n);
931 self
932 }
933
934 pub fn select(self, select: #_select_struct) -> FindManySelectQuery<'a> {
935 FindManySelectQuery {
936 client: self.client,
937 r#where: self.r#where,
938 order_by: self.order_by,
939 skip: self.skip,
940 take: self.take,
941 select,
942 }
943 }
944
945 pub async fn exec(self) -> Result<Vec<#model_ident>, FerriormError> {
946 match self.client {
947 DatabaseClient::Postgres(_) => {
948 let qb = build_select_query::<sqlx::Postgres>(#select_sql, &self.r#where, &self.order_by, self.take, self.skip);
949 self.client.fetch_all_pg(qb).await
950 }
951 DatabaseClient::Sqlite(_) => {
952 let qb = build_select_query::<sqlx::Sqlite>(#select_sql, &self.r#where, &self.order_by, self.take, self.skip);
953 self.client.fetch_all_sqlite(qb).await
954 }
955 }
956 }
957 }
958
959 pub struct CreateQuery<'a> {
960 client: &'a DatabaseClient,
961 data: data::#_create_input,
962 }
963
964 impl<'a> CreateQuery<'a> {
965 pub async fn exec(self) -> Result<#model_ident, FerriormError> {
966 let client = self.client;
967 #insert_code
968 }
969
970 pub fn on_conflict_ignore(self) -> CreateIgnoreQuery<'a> {
974 CreateIgnoreQuery { client: self.client, data: self.data }
975 }
976 }
977
978 pub struct CreateIgnoreQuery<'a> {
979 client: &'a DatabaseClient,
980 data: data::#_create_input,
981 }
982
983 impl<'a> CreateIgnoreQuery<'a> {
984 pub async fn exec(self) -> Result<Option<#model_ident>, FerriormError> {
985 let client = self.client;
986 #insert_ignore_code
987 }
988 }
989
990 pub struct UpdateQuery<'a> {
991 client: &'a DatabaseClient,
992 r#where: filter::#_where_unique,
993 data: data::#_update_input,
994 }
995
996 impl<'a> UpdateQuery<'a> {
997 pub async fn exec(self) -> Result<#model_ident, FerriormError> {
998 let client = self.client;
999 #update_code
1000 }
1001 }
1002
1003 pub struct UpdateFirstQuery<'a> {
1004 client: &'a DatabaseClient,
1005 r#where: filter::#_where_input,
1006 data: data::#_update_input,
1007 }
1008
1009 impl<'a> UpdateFirstQuery<'a> {
1010 pub async fn exec(self) -> Result<Option<#model_ident>, FerriormError> {
1011 let client = self.client;
1012 #update_first_code
1013 }
1014 }
1015
1016 pub struct DeleteQuery<'a> {
1017 client: &'a DatabaseClient,
1018 r#where: filter::#_where_unique,
1019 }
1020
1021 impl<'a> DeleteQuery<'a> {
1022 pub async fn exec(self) -> Result<#model_ident, FerriormError> {
1023 match self.client {
1024 DatabaseClient::Postgres(_) => {
1025 let qb = build_delete_query::<sqlx::Postgres>(#delete_sql, &self.r#where);
1026 self.client.fetch_one_pg(qb).await
1027 }
1028 DatabaseClient::Sqlite(_) => {
1029 let qb = build_delete_query::<sqlx::Sqlite>(#delete_sql, &self.r#where);
1030 self.client.fetch_one_sqlite(qb).await
1031 }
1032 }
1033 }
1034 }
1035
1036 #[derive(sqlx::FromRow)]
1037 struct CountResult { count: i64 }
1038
1039 pub struct CountQuery<'a> {
1040 client: &'a DatabaseClient,
1041 r#where: filter::#_where_input,
1042 }
1043
1044 impl<'a> CountQuery<'a> {
1045 pub async fn exec(self) -> Result<i64, FerriormError> {
1046 let row: CountResult = match self.client {
1047 DatabaseClient::Postgres(_) => {
1048 let qb = build_count_query::<sqlx::Postgres>(#count_sql, &self.r#where);
1049 self.client.fetch_one_pg(qb).await?
1050 }
1051 DatabaseClient::Sqlite(_) => {
1052 let qb = build_count_query::<sqlx::Sqlite>(#count_sql, &self.r#where);
1053 self.client.fetch_one_sqlite(qb).await?
1054 }
1055 };
1056 Ok(row.count)
1057 }
1058 }
1059
1060 pub struct CreateManyQuery<'a> {
1061 client: &'a DatabaseClient,
1062 data: Vec<data::#_create_input>,
1063 }
1064
1065 impl<'a> CreateManyQuery<'a> {
1066 pub async fn exec(self) -> Result<u64, FerriormError> {
1067 if self.data.is_empty() { return Ok(0); }
1068 let count = self.data.len() as u64;
1069 for item in self.data {
1070 CreateQuery { client: self.client, data: item }.exec().await?;
1071 }
1072 Ok(count)
1073 }
1074 }
1075
1076 pub struct UpdateManyQuery<'a> {
1077 client: &'a DatabaseClient,
1078 r#where: filter::#_where_input,
1079 data: data::#_update_input,
1080 }
1081
1082 impl<'a> UpdateManyQuery<'a> {
1083 pub async fn exec(self) -> Result<u64, FerriormError> {
1084 let client = self.client;
1085 #update_many_code
1086 }
1087 }
1088
1089 pub struct UpsertQuery<'a> {
1090 client: &'a DatabaseClient,
1091 r#where: filter::#_where_unique,
1092 create: data::#_create_input,
1093 update: data::#_update_input,
1094 }
1095
1096 impl<'a> UpsertQuery<'a> {
1097 pub async fn exec(self) -> Result<#model_ident, FerriormError> {
1098 let client = self.client;
1099 #upsert_code
1100 }
1101 }
1102
1103 pub struct DeleteManyQuery<'a> {
1104 client: &'a DatabaseClient,
1105 r#where: filter::#_where_input,
1106 }
1107
1108 impl<'a> DeleteManyQuery<'a> {
1109 pub async fn exec(self) -> Result<u64, FerriormError> {
1110 match self.client {
1111 DatabaseClient::Postgres(_) => {
1112 let qb = build_delete_many_query::<sqlx::Postgres>(#delete_sql, &self.r#where);
1113 self.client.execute_pg(qb).await
1114 }
1115 DatabaseClient::Sqlite(_) => {
1116 let qb = build_delete_many_query::<sqlx::Sqlite>(#delete_sql, &self.r#where);
1117 self.client.execute_sqlite(qb).await
1118 }
1119 }
1120 }
1121 }
1122 }
1123}
1124
1125fn gen_insert_code(model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
1128 let _model_ident = format_ident!("{}", model.name);
1129
1130 let required: Vec<&Field> = scalar_fields
1132 .iter()
1133 .copied()
1134 .filter(|f| !f.has_default() && !f.is_updated_at)
1135 .collect();
1136
1137 let optional: Vec<&Field> = scalar_fields
1139 .iter()
1140 .copied()
1141 .filter(|f| f.has_default() && !f.is_updated_at)
1142 .collect();
1143
1144 let updated_at: Vec<&Field> = scalar_fields
1146 .iter()
1147 .copied()
1148 .filter(|f| f.is_updated_at)
1149 .collect();
1150
1151 let mut col_pushes = vec![];
1153 let mut val_pushes = vec![];
1154
1155 for f in &required {
1157 let db_name = &f.db_name;
1158 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1159 col_pushes.push(quote! { cols.push(#db_name); });
1160 val_pushes.push(quote! { sep.push_bind(self.data.#field_ident); });
1161 }
1162
1163 for f in &optional {
1165 let db_name = &f.db_name;
1166 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1167 if is_autoincrement(f) {
1168 col_pushes.push(quote! {
1172 if self.data.#field_ident.is_some() { cols.push(#db_name); }
1173 });
1174 val_pushes.push(quote! {
1175 if let Some(val) = self.data.#field_ident {
1176 sep.push_bind(val);
1177 }
1178 });
1179 } else {
1180 let default_expr = gen_default_expr(f, &f.field_type);
1181 col_pushes.push(quote! { cols.push(#db_name); });
1182 val_pushes.push(quote! {
1183 let val = self.data.#field_ident.unwrap_or_else(|| #default_expr);
1184 sep.push_bind(val);
1185 });
1186 }
1187 }
1188
1189 for f in &updated_at {
1191 let db_name = &f.db_name;
1192 col_pushes.push(quote! { cols.push(#db_name); });
1193 val_pushes.push(quote! { sep.push_bind(chrono::Utc::now()); });
1194 }
1195
1196 let insert_start = format!(r#"INSERT INTO "{table_name}""#);
1197
1198 quote! {
1201 macro_rules! build_insert {
1203 ($qb_type:ty) => {{
1204 let mut cols: Vec<&str> = Vec::new();
1205 #(#col_pushes)*
1206
1207 let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#insert_start);
1208 qb.push(" (");
1209 for (i, col) in cols.iter().enumerate() {
1210 if i > 0 { qb.push(", "); }
1211 qb.push("\"");
1212 qb.push(*col);
1213 qb.push("\"");
1214 }
1215 qb.push(") VALUES (");
1216 {
1217 let mut sep = qb.separated(", ");
1218 #(#val_pushes)*
1219 }
1220 qb.push(") RETURNING *");
1221 qb
1222 }};
1223 }
1224
1225 match client {
1226 DatabaseClient::Postgres(_) => {
1227 let qb = build_insert!(sqlx::Postgres);
1228 client.fetch_one_pg(qb).await
1229 }
1230 DatabaseClient::Sqlite(_) => {
1231 let qb = build_insert!(sqlx::Sqlite);
1232 client.fetch_one_sqlite(qb).await
1233 }
1234 }
1235 }
1236}
1237
1238fn gen_insert_ignore_code(
1239 _model: &Model,
1240 scalar_fields: &[&Field],
1241 table_name: &str,
1242) -> TokenStream {
1243 let required: Vec<&Field> = scalar_fields
1244 .iter()
1245 .copied()
1246 .filter(|f| !f.has_default() && !f.is_updated_at)
1247 .collect();
1248 let optional: Vec<&Field> = scalar_fields
1249 .iter()
1250 .copied()
1251 .filter(|f| f.has_default() && !f.is_updated_at)
1252 .collect();
1253 let updated_at: Vec<&Field> = scalar_fields
1254 .iter()
1255 .copied()
1256 .filter(|f| f.is_updated_at)
1257 .collect();
1258
1259 let mut col_pushes = vec![];
1260 let mut val_pushes = vec![];
1261
1262 for f in &required {
1263 let db_name = &f.db_name;
1264 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1265 col_pushes.push(quote! { cols.push(#db_name); });
1266 val_pushes.push(quote! { sep.push_bind(self.data.#field_ident); });
1267 }
1268 for f in &optional {
1269 let db_name = &f.db_name;
1270 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1271 if is_autoincrement(f) {
1272 col_pushes.push(quote! {
1273 if self.data.#field_ident.is_some() { cols.push(#db_name); }
1274 });
1275 val_pushes.push(quote! {
1276 if let Some(val) = self.data.#field_ident {
1277 sep.push_bind(val);
1278 }
1279 });
1280 } else {
1281 let default_expr = gen_default_expr(f, &f.field_type);
1282 col_pushes.push(quote! { cols.push(#db_name); });
1283 val_pushes.push(quote! {
1284 let val = self.data.#field_ident.unwrap_or_else(|| #default_expr);
1285 sep.push_bind(val);
1286 });
1287 }
1288 }
1289 for f in &updated_at {
1290 let db_name = &f.db_name;
1291 col_pushes.push(quote! { cols.push(#db_name); });
1292 val_pushes.push(quote! { sep.push_bind(chrono::Utc::now()); });
1293 }
1294
1295 let pg_insert_start = format!(r#"INSERT INTO "{table_name}""#);
1296 let sqlite_insert_start = format!(r#"INSERT OR IGNORE INTO "{table_name}""#);
1297
1298 quote! {
1299 macro_rules! build_insert_ignore {
1300 ($qb_type:ty, $head:expr, $tail:expr) => {{
1301 let mut cols: Vec<&str> = Vec::new();
1302 #(#col_pushes)*
1303
1304 let mut qb = sqlx::QueryBuilder::<$qb_type>::new($head);
1305 qb.push(" (");
1306 for (i, col) in cols.iter().enumerate() {
1307 if i > 0 { qb.push(", "); }
1308 qb.push("\"");
1309 qb.push(*col);
1310 qb.push("\"");
1311 }
1312 qb.push(") VALUES (");
1313 {
1314 let mut sep = qb.separated(", ");
1315 #(#val_pushes)*
1316 }
1317 qb.push(")");
1318 qb.push($tail);
1319 qb.push(" RETURNING *");
1320 qb
1321 }};
1322 }
1323
1324 match client {
1325 DatabaseClient::Postgres(_) => {
1326 let qb = build_insert_ignore!(sqlx::Postgres, #pg_insert_start, " ON CONFLICT DO NOTHING");
1327 client.fetch_optional_pg(qb).await
1328 }
1329 DatabaseClient::Sqlite(_) => {
1330 let qb = build_insert_ignore!(sqlx::Sqlite, #sqlite_insert_start, "");
1331 client.fetch_optional_sqlite(qb).await
1332 }
1333 }
1334 }
1335}
1336
1337fn is_autoincrement(field: &Field) -> bool {
1341 matches!(
1342 field.default,
1343 Some(ferriorm_core::ast::DefaultValue::AutoIncrement)
1344 )
1345}
1346
1347fn gen_default_expr(field: &Field, field_type: &FieldKind) -> TokenStream {
1349 use ferriorm_core::ast::DefaultValue;
1350
1351 match &field.default {
1352 Some(DefaultValue::Uuid | DefaultValue::Cuid) => {
1353 quote! { uuid::Uuid::new_v4().to_string() }
1354 }
1355 Some(DefaultValue::Now) => quote! { chrono::Utc::now() },
1356 Some(DefaultValue::AutoIncrement) => quote! { 0i32 },
1361 Some(DefaultValue::Literal(lit)) => {
1362 use ferriorm_core::ast::LiteralValue;
1363 match lit {
1364 LiteralValue::String(s) => quote! { #s.to_string() },
1365 LiteralValue::Int(i) => {
1366 #[allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
1368 match field_type {
1369 FieldKind::Scalar(ScalarType::Float) => {
1370 let val = *i as f64;
1371 quote! { #val }
1372 }
1373 FieldKind::Scalar(ScalarType::BigInt) => quote! { #i },
1374 FieldKind::Scalar(ScalarType::Int)
1376 if field.db_type.as_ref().is_some_and(|(ty, _)| ty == "BigInt") =>
1377 {
1378 quote! { #i }
1379 }
1380 _ => {
1381 let val = *i as i32;
1383 quote! { #val }
1384 }
1385 }
1386 }
1387 LiteralValue::Float(f) => quote! { #f },
1388 LiteralValue::Bool(b) => quote! { #b },
1389 }
1390 }
1391 Some(DefaultValue::EnumVariant(v)) => {
1392 let variant = format_ident!("{}", v);
1394 if let FieldKind::Enum(enum_name) = &field.field_type {
1395 let enum_ident = format_ident!("{}", enum_name);
1396 quote! { super::enums::#enum_ident::#variant }
1397 } else {
1398 quote! { Default::default() }
1399 }
1400 }
1401 None => quote! { Default::default() },
1402 }
1403}
1404
1405fn gen_update_code(model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
1408 let _model_ident = format_ident!("{}", model.name);
1409
1410 let updatable: Vec<&Field> = scalar_fields
1412 .iter()
1413 .copied()
1414 .filter(|f| !f.is_id && !f.is_updated_at)
1415 .collect();
1416
1417 let updated_at: Vec<&Field> = scalar_fields
1418 .iter()
1419 .copied()
1420 .filter(|f| f.is_updated_at)
1421 .collect();
1422
1423 let update_start = format!(r#"UPDATE "{table_name}" SET "#);
1424
1425 let set_arms: Vec<TokenStream> = updatable
1427 .iter()
1428 .map(|f| {
1429 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1430 let db_name = &f.db_name;
1431 quote! {
1432 if let Some(SetValue::Set(v)) = self.data.#field_ident {
1433 if !first_set { qb.push(", "); }
1434 first_set = false;
1435 qb.push(concat!("\"", #db_name, "\" = "));
1436 qb.push_bind(v);
1437 }
1438 }
1439 })
1440 .collect();
1441
1442 let updated_at_arms: Vec<TokenStream> = updated_at
1443 .iter()
1444 .map(|f| {
1445 let db_name = &f.db_name;
1446 quote! {
1447 if !first_set { qb.push(", "); }
1448 first_set = false;
1449 qb.push(concat!("\"", #db_name, "\" = "));
1450 qb.push_bind(chrono::Utc::now());
1451 }
1452 })
1453 .collect();
1454
1455 quote! {
1458 macro_rules! build_update {
1459 ($qb_type:ty) => {{
1460 let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#update_start);
1461 let mut first_set = true;
1462 #(#set_arms)*
1463 #(#updated_at_arms)*
1464
1465 if first_set {
1466 return Err(FerriormError::Query("No fields to update".into()));
1467 }
1468
1469 qb.push(" WHERE 1=1");
1470 self.r#where.build_where(&mut qb);
1471 qb.push(" RETURNING *");
1472 qb
1473 }};
1474 }
1475
1476 match client {
1477 DatabaseClient::Postgres(_) => {
1478 let qb = build_update!(sqlx::Postgres);
1479 client.fetch_one_pg(qb).await
1480 }
1481 DatabaseClient::Sqlite(_) => {
1482 let qb = build_update!(sqlx::Sqlite);
1483 client.fetch_one_sqlite(qb).await
1484 }
1485 }
1486 }
1487}
1488
1489fn gen_update_first_code(
1492 _model: &Model,
1493 scalar_fields: &[&Field],
1494 table_name: &str,
1495) -> TokenStream {
1496 let updatable: Vec<&Field> = scalar_fields
1497 .iter()
1498 .copied()
1499 .filter(|f| !f.is_id && !f.is_updated_at)
1500 .collect();
1501
1502 let updated_at: Vec<&Field> = scalar_fields
1503 .iter()
1504 .copied()
1505 .filter(|f| f.is_updated_at)
1506 .collect();
1507
1508 let update_start = format!(r#"UPDATE "{table_name}" SET "#);
1509
1510 let set_arms: Vec<TokenStream> = updatable
1511 .iter()
1512 .map(|f| {
1513 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1514 let db_name = &f.db_name;
1515 quote! {
1516 if let Some(SetValue::Set(v)) = self.data.#field_ident {
1517 if !first_set { qb.push(", "); }
1518 first_set = false;
1519 qb.push(concat!("\"", #db_name, "\" = "));
1520 qb.push_bind(v);
1521 }
1522 }
1523 })
1524 .collect();
1525
1526 let updated_at_arms: Vec<TokenStream> = updated_at
1527 .iter()
1528 .map(|f| {
1529 let db_name = &f.db_name;
1530 quote! {
1531 if !first_set { qb.push(", "); }
1532 first_set = false;
1533 qb.push(concat!("\"", #db_name, "\" = "));
1534 qb.push_bind(chrono::Utc::now());
1535 }
1536 })
1537 .collect();
1538
1539 quote! {
1540 macro_rules! build_update_first {
1541 ($qb_type:ty) => {{
1542 let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#update_start);
1543 let mut first_set = true;
1544 #(#set_arms)*
1545 #(#updated_at_arms)*
1546
1547 if first_set {
1548 return Err(FerriormError::Query("No fields to update".into()));
1549 }
1550
1551 qb.push(" WHERE 1=1");
1552 self.r#where.build_where(&mut qb);
1553 qb.push(" RETURNING *");
1554 qb
1555 }};
1556 }
1557
1558 match client {
1559 DatabaseClient::Postgres(_) => {
1560 let qb = build_update_first!(sqlx::Postgres);
1561 client.fetch_optional_pg(qb).await
1562 }
1563 DatabaseClient::Sqlite(_) => {
1564 let qb = build_update_first!(sqlx::Sqlite);
1565 client.fetch_optional_sqlite(qb).await
1566 }
1567 }
1568 }
1569}
1570
1571fn gen_update_many_code(_model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
1574 let updatable: Vec<&Field> = scalar_fields
1576 .iter()
1577 .copied()
1578 .filter(|f| !f.is_id && !f.is_updated_at)
1579 .collect();
1580
1581 let updated_at: Vec<&Field> = scalar_fields
1582 .iter()
1583 .copied()
1584 .filter(|f| f.is_updated_at)
1585 .collect();
1586
1587 let update_start = format!(r#"UPDATE "{table_name}" SET "#);
1588
1589 let set_arms: Vec<TokenStream> = updatable
1591 .iter()
1592 .map(|f| {
1593 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1594 let db_name = &f.db_name;
1595 quote! {
1596 if let Some(SetValue::Set(v)) = self.data.#field_ident {
1597 if !first_set { qb.push(", "); }
1598 first_set = false;
1599 qb.push(concat!("\"", #db_name, "\" = "));
1600 qb.push_bind(v);
1601 }
1602 }
1603 })
1604 .collect();
1605
1606 let updated_at_arms: Vec<TokenStream> = updated_at
1607 .iter()
1608 .map(|f| {
1609 let db_name = &f.db_name;
1610 quote! {
1611 if !first_set { qb.push(", "); }
1612 first_set = false;
1613 qb.push(concat!("\"", #db_name, "\" = "));
1614 qb.push_bind(chrono::Utc::now());
1615 }
1616 })
1617 .collect();
1618
1619 quote! {
1620 macro_rules! build_update_many {
1621 ($qb_type:ty) => {{
1622 let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#update_start);
1623 let mut first_set = true;
1624 #(#set_arms)*
1625 #(#updated_at_arms)*
1626
1627 if first_set {
1628 return Ok(0);
1629 }
1630
1631 qb.push(" WHERE 1=1");
1632 self.r#where.build_where(&mut qb);
1633 qb
1634 }};
1635 }
1636
1637 match client {
1638 DatabaseClient::Postgres(_) => {
1639 let qb = build_update_many!(sqlx::Postgres);
1640 client.execute_pg(qb).await
1641 }
1642 DatabaseClient::Sqlite(_) => {
1643 let qb = build_update_many!(sqlx::Sqlite);
1644 client.execute_sqlite(qb).await
1645 }
1646 }
1647 }
1648}
1649
/// Classifies how a scalar field participates in generated aggregate and
/// group-by queries.
enum AggregateKind {
    /// `Int`, `BigInt` and `Float` fields: `avg`, `sum`, `min` and `max` are generated.
    Numeric,
    /// `DateTime` fields: only `min` and `max` are generated.
    DateTime,
}
1659
1660#[allow(clippy::too_many_lines)]
1663fn gen_upsert_code(_model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
1664 let required: Vec<&Field> = scalar_fields
1666 .iter()
1667 .copied()
1668 .filter(|f| !f.has_default() && !f.is_updated_at)
1669 .collect();
1670 let optional: Vec<&Field> = scalar_fields
1671 .iter()
1672 .copied()
1673 .filter(|f| f.has_default() && !f.is_updated_at)
1674 .collect();
1675 let updated_at: Vec<&Field> = scalar_fields
1676 .iter()
1677 .copied()
1678 .filter(|f| f.is_updated_at)
1679 .collect();
1680
1681 let mut col_pushes = vec![];
1682 let mut val_pushes = vec![];
1683
1684 for f in &required {
1685 let db_name = &f.db_name;
1686 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1687 col_pushes.push(quote! { cols.push(#db_name); });
1688 val_pushes.push(quote! { sep.push_bind(self.create.#field_ident); });
1689 }
1690 for f in &optional {
1691 let db_name = &f.db_name;
1692 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1693 if is_autoincrement(f) {
1694 col_pushes.push(quote! {
1695 if self.create.#field_ident.is_some() { cols.push(#db_name); }
1696 });
1697 val_pushes.push(quote! {
1698 if let Some(val) = self.create.#field_ident {
1699 sep.push_bind(val);
1700 }
1701 });
1702 } else {
1703 let default_expr = gen_default_expr(f, &f.field_type);
1704 col_pushes.push(quote! { cols.push(#db_name); });
1705 val_pushes.push(quote! {
1706 let val = self.create.#field_ident.unwrap_or_else(|| #default_expr);
1707 sep.push_bind(val);
1708 });
1709 }
1710 }
1711 for f in &updated_at {
1712 let db_name = &f.db_name;
1713 col_pushes.push(quote! { cols.push(#db_name); });
1714 val_pushes.push(quote! { sep.push_bind(chrono::Utc::now()); });
1715 }
1716
1717 let updatable: Vec<&Field> = scalar_fields
1719 .iter()
1720 .copied()
1721 .filter(|f| !f.is_id && !f.is_updated_at)
1722 .collect();
1723
1724 let set_arms: Vec<TokenStream> = updatable
1725 .iter()
1726 .map(|f| {
1727 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1728 let db_name = &f.db_name;
1729 quote! {
1730 if let Some(SetValue::Set(v)) = self.update.#field_ident {
1731 if !first_set { qb.push(", "); }
1732 first_set = false;
1733 qb.push(concat!("\"", #db_name, "\" = "));
1734 qb.push_bind(v);
1735 }
1736 }
1737 })
1738 .collect();
1739
1740 let updated_at_set: Vec<TokenStream> = updated_at
1741 .iter()
1742 .map(|f| {
1743 let db_name = &f.db_name;
1744 quote! {
1745 if !first_set { qb.push(", "); }
1746 first_set = false;
1747 qb.push(concat!("\"", #db_name, "\" = "));
1748 qb.push_bind(chrono::Utc::now());
1749 }
1750 })
1751 .collect();
1752
1753 let insert_start = format!(r#"INSERT INTO "{table_name}""#);
1754
1755 quote! {
1756 let conflict_target = self.r#where.conflict_target();
1757 let first_conflict_col = self.r#where.first_conflict_col();
1758
1759 macro_rules! build_upsert {
1760 ($qb_type:ty) => {{
1761 let mut cols: Vec<&str> = Vec::new();
1762 #(#col_pushes)*
1763
1764 let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#insert_start);
1765 qb.push(" (");
1766 for (i, col) in cols.iter().enumerate() {
1767 if i > 0 { qb.push(", "); }
1768 qb.push("\"");
1769 qb.push(*col);
1770 qb.push("\"");
1771 }
1772 qb.push(") VALUES (");
1773 {
1774 let mut sep = qb.separated(", ");
1775 #(#val_pushes)*
1776 }
1777 qb.push(")");
1778 qb.push(" ON CONFLICT ");
1779 qb.push(conflict_target);
1780 qb.push(" DO UPDATE SET ");
1781
1782 let mut first_set = true;
1783 #(#set_arms)*
1784 #(#updated_at_set)*
1785
1786 if first_set {
1787 qb.push(first_conflict_col);
1790 qb.push(" = ");
1791 qb.push(first_conflict_col);
1792 }
1793
1794 qb.push(" RETURNING *");
1795 qb
1796 }};
1797 }
1798
1799 match client {
1800 DatabaseClient::Postgres(_) => {
1801 let qb = build_upsert!(sqlx::Postgres);
1802 client.fetch_one_pg(qb).await
1803 }
1804 DatabaseClient::Sqlite(_) => {
1805 let qb = build_upsert!(sqlx::Sqlite);
1806 client.fetch_one_sqlite(qb).await
1807 }
1808 }
1809 }
1810}
1811
1812#[allow(clippy::too_many_lines)]
1813fn gen_aggregate_types(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
1814 let aggregate_field_name = format_ident!("{}AggregateField", model.name);
1815 let aggregate_result_name = format_ident!("{}AggregateResult", model.name);
1816 let _where_input = format_ident!("{}WhereInput", model.name);
1817 let table_name = &model.db_name;
1818
1819 let agg_fields: Vec<(&Field, AggregateKind)> = scalar_fields
1821 .iter()
1822 .filter_map(|f| match &f.field_type {
1823 FieldKind::Scalar(ScalarType::Int | ScalarType::BigInt | ScalarType::Float) => {
1824 Some((*f, AggregateKind::Numeric))
1825 }
1826 FieldKind::Scalar(ScalarType::DateTime) => Some((*f, AggregateKind::DateTime)),
1827 _ => None,
1828 })
1829 .collect();
1830
1831 if agg_fields.is_empty() {
1832 return quote! {};
1833 }
1834
1835 let enum_variants: Vec<TokenStream> = agg_fields
1837 .iter()
1838 .map(|(f, _)| {
1839 let variant = format_ident!("{}", to_pascal_case(&f.name));
1840 quote! { #variant }
1841 })
1842 .collect();
1843
1844 let db_name_arms: Vec<TokenStream> = agg_fields
1846 .iter()
1847 .map(|(f, _)| {
1848 let variant = format_ident!("{}", to_pascal_case(&f.name));
1849 let db_name = &f.db_name;
1850 quote! { Self::#variant => #db_name }
1851 })
1852 .collect();
1853
1854 let mut result_fields = Vec::new();
1856 for (f, kind) in &agg_fields {
1857 let snake = to_snake_case(&f.name);
1858 let orig_ty = rust_type_tokens(
1859 &Field {
1860 is_optional: false,
1861 ..(*f).clone()
1862 },
1863 ModuleDepth::TopLevel,
1864 );
1865
1866 match kind {
1867 AggregateKind::Numeric => {
1868 let avg_name = format_ident!("avg_{}", snake);
1869 let sum_name = format_ident!("sum_{}", snake);
1870 let min_name = format_ident!("min_{}", snake);
1871 let max_name = format_ident!("max_{}", snake);
1872 result_fields.push(quote! { #[sqlx(default)] pub #avg_name: Option<f64> });
1873 result_fields.push(quote! { #[sqlx(default)] pub #sum_name: Option<f64> });
1874 result_fields.push(quote! { #[sqlx(default)] pub #min_name: Option<#orig_ty> });
1875 result_fields.push(quote! { #[sqlx(default)] pub #max_name: Option<#orig_ty> });
1876 }
1877 AggregateKind::DateTime => {
1878 let min_name = format_ident!("min_{}", snake);
1879 let max_name = format_ident!("max_{}", snake);
1880 result_fields.push(quote! { #[sqlx(default)] pub #min_name: Option<#orig_ty> });
1881 result_fields.push(quote! { #[sqlx(default)] pub #max_name: Option<#orig_ty> });
1882 }
1883 }
1884 }
1885
1886 let numeric_arms: Vec<TokenStream> = agg_fields
1888 .iter()
1889 .filter(|(_, kind)| matches!(kind, AggregateKind::Numeric))
1890 .map(|(f, _)| {
1891 let variant = format_ident!("{}", to_pascal_case(&f.name));
1892 quote! { Self::#variant => true }
1893 })
1894 .collect();
1895
1896 let has_numeric = !numeric_arms.is_empty();
1897 let is_numeric_method = if has_numeric {
1898 quote! {
1899 fn is_numeric(&self) -> bool {
1900 match self {
1901 #(#numeric_arms,)*
1902 #[allow(unreachable_patterns)]
1903 _ => false,
1904 }
1905 }
1906 }
1907 } else {
1908 quote! {
1909 fn is_numeric(&self) -> bool { false }
1910 }
1911 };
1912
1913 let mut alias_arms = Vec::new();
1915 for (f, kind) in &agg_fields {
1916 let variant = format_ident!("{}", to_pascal_case(&f.name));
1917 let snake = to_snake_case(&f.name);
1918 let prefixes = match kind {
1919 AggregateKind::Numeric => vec!["avg", "sum", "min", "max"],
1920 AggregateKind::DateTime => vec!["min", "max"],
1921 };
1922 for prefix in prefixes {
1923 let alias_str = format!("{prefix}_{snake}");
1924 alias_arms.push(quote! { (#prefix, Self::#variant) => #alias_str });
1925 }
1926 }
1927
1928 let agg_select_base = format!(r#"SELECT {{}} FROM "{table_name}" WHERE 1=1"#);
1929
1930 quote! {
1931 #[derive(Debug, Clone, Copy)]
1932 pub enum #aggregate_field_name {
1933 #(#enum_variants),*
1934 }
1935
1936 impl #aggregate_field_name {
1937 pub fn db_name(&self) -> &'static str {
1938 match self {
1939 #(#db_name_arms,)*
1940 }
1941 }
1942
1943 fn alias(&self, prefix: &'static str) -> &'static str {
1944 match (prefix, self) {
1945 #(#alias_arms,)*
1946 _ => unreachable!(),
1947 }
1948 }
1949
1950 #is_numeric_method
1951 }
1952
1953 #[derive(Debug, Clone, sqlx::FromRow, Serialize, Deserialize)]
1954 pub struct #aggregate_result_name {
1955 #(#result_fields,)*
1956 }
1957
1958 pub struct AggregateQuery<'a> {
1959 client: &'a DatabaseClient,
1960 r#where: filter::#_where_input,
1961 ops: Vec<(&'static str, &'static str, &'static str)>,
1962 }
1963
1964 impl<'a> AggregateQuery<'a> {
1965 pub fn avg(mut self, field: #aggregate_field_name) -> Self {
1966 assert!(field.is_numeric(), "avg() is only supported on numeric fields");
1967 let db_name = field.db_name();
1968 let alias = field.alias("avg");
1969 self.ops.push(("AVG", db_name, alias));
1970 self
1971 }
1972
1973 pub fn sum(mut self, field: #aggregate_field_name) -> Self {
1974 assert!(field.is_numeric(), "sum() is only supported on numeric fields");
1975 let db_name = field.db_name();
1976 let alias = field.alias("sum");
1977 self.ops.push(("SUM", db_name, alias));
1978 self
1979 }
1980
1981 pub fn min(mut self, field: #aggregate_field_name) -> Self {
1982 let db_name = field.db_name();
1983 let alias = field.alias("min");
1984 self.ops.push(("MIN", db_name, alias));
1985 self
1986 }
1987
1988 pub fn max(mut self, field: #aggregate_field_name) -> Self {
1989 let db_name = field.db_name();
1990 let alias = field.alias("max");
1991 self.ops.push(("MAX", db_name, alias));
1992 self
1993 }
1994
1995 pub async fn exec(self) -> Result<#aggregate_result_name, FerriormError> {
1996 if self.ops.is_empty() {
1997 return Err(FerriormError::Query("No aggregate operations specified".into()));
1998 }
1999
2000 let selections: Vec<String> = self.ops.iter()
2001 .map(|(func, col, alias)| format!(r#"{}("{}") as "{}""#, func, col, alias))
2002 .collect();
2003 let select_clause = selections.join(", ");
2004 let base_sql = format!(#agg_select_base, select_clause);
2005
2006 match self.client {
2007 DatabaseClient::Postgres(_) => {
2008 let mut qb = sqlx::QueryBuilder::<sqlx::Postgres>::new(&base_sql);
2009 self.r#where.build_where(&mut qb);
2010 self.client.fetch_one_pg(qb).await
2011 }
2012 DatabaseClient::Sqlite(_) => {
2013 let mut qb = sqlx::QueryBuilder::<sqlx::Sqlite>::new(&base_sql);
2014 self.r#where.build_where(&mut qb);
2015 self.client.fetch_one_sqlite(qb).await
2016 }
2017 }
2018 }
2019 }
2020 }
2021}
2022
2023fn gen_having_comparable_arms(field_ident: &proc_macro2::Ident, lhs: &str) -> TokenStream {
2029 let eq = format!(" AND {lhs} = ");
2030 let ne = format!(" AND {lhs} != ");
2031 let gt = format!(" AND {lhs} > ");
2032 let gte = format!(" AND {lhs} >= ");
2033 let lt = format!(" AND {lhs} < ");
2034 let lte = format!(" AND {lhs} <= ");
2035 quote! {
2036 if let Some(filter) = &self.#field_ident {
2037 if let Some(v) = &filter.equals { qb.push(#eq); qb.push_bind(v.clone()); }
2038 if let Some(v) = &filter.not { qb.push(#ne); qb.push_bind(v.clone()); }
2039 if let Some(v) = &filter.gt { qb.push(#gt); qb.push_bind(v.clone()); }
2040 if let Some(v) = &filter.gte { qb.push(#gte); qb.push_bind(v.clone()); }
2041 if let Some(v) = &filter.lt { qb.push(#lt); qb.push_bind(v.clone()); }
2042 if let Some(v) = &filter.lte { qb.push(#lte); qb.push_bind(v.clone()); }
2043 }
2044 }
2045}
2046
2047fn is_groupable(field: &Field) -> bool {
2052 match &field.field_type {
2053 FieldKind::Scalar(
2054 ScalarType::String
2055 | ScalarType::Int
2056 | ScalarType::BigInt
2057 | ScalarType::Float
2058 | ScalarType::Boolean
2059 | ScalarType::DateTime,
2060 )
2061 | FieldKind::Enum(_) => true,
2062 FieldKind::Scalar(ScalarType::Json | ScalarType::Bytes | ScalarType::Decimal)
2063 | FieldKind::Model(_) => false,
2064 }
2065}
2066
2067#[allow(clippy::too_many_lines)]
2068fn gen_groupby_types(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
2069 let groupby_field_name = format_ident!("{}GroupByField", model.name);
2070 let groupby_result_name = format_ident!("{}GroupByResult", model.name);
2071 let having_input_name = format_ident!("{}HavingInput", model.name);
2072 let aggregate_field_name = format_ident!("{}AggregateField", model.name);
2073 let where_input = format_ident!("{}WhereInput", model.name);
2074 let table_name = &model.db_name;
2075
2076 let agg_fields: Vec<(&Field, AggregateKind)> = scalar_fields
2079 .iter()
2080 .filter_map(|f| match &f.field_type {
2081 FieldKind::Scalar(ScalarType::Int | ScalarType::BigInt | ScalarType::Float) => {
2082 Some((*f, AggregateKind::Numeric))
2083 }
2084 FieldKind::Scalar(ScalarType::DateTime) => Some((*f, AggregateKind::DateTime)),
2085 _ => None,
2086 })
2087 .collect();
2088
2089 let group_fields: Vec<&Field> = scalar_fields
2090 .iter()
2091 .filter(|f| is_groupable(f))
2092 .copied()
2093 .collect();
2094
2095 if group_fields.is_empty() {
2096 return quote! {};
2097 }
2098
2099 let groupby_variants: Vec<TokenStream> = group_fields
2101 .iter()
2102 .map(|f| {
2103 let variant = format_ident!("{}", to_pascal_case(&f.name));
2104 quote! { #variant }
2105 })
2106 .collect();
2107
2108 let groupby_db_arms: Vec<TokenStream> = group_fields
2109 .iter()
2110 .map(|f| {
2111 let variant = format_ident!("{}", to_pascal_case(&f.name));
2112 let db_name = &f.db_name;
2113 quote! { Self::#variant => #db_name }
2114 })
2115 .collect();
2116
2117 let groupby_alias_arms: Vec<TokenStream> = group_fields
2118 .iter()
2119 .map(|f| {
2120 let variant = format_ident!("{}", to_pascal_case(&f.name));
2121 let alias = to_snake_case(&f.name);
2122 quote! { Self::#variant => #alias }
2123 })
2124 .collect();
2125
2126 let mut result_fields: Vec<TokenStream> = Vec::new();
2131 for f in &group_fields {
2132 let snake = to_snake_case(&f.name);
2133 let name = format_ident!("{}", snake);
2134 let base_ty = rust_type_tokens(
2136 &Field {
2137 is_optional: false,
2138 ..(*f).clone()
2139 },
2140 ModuleDepth::TopLevel,
2141 );
2142 result_fields.push(quote! { #[sqlx(default)] pub #name: Option<#base_ty> });
2143 }
2144 result_fields.push(quote! { #[sqlx(default)] pub count: Option<i64> });
2145 for (f, kind) in &agg_fields {
2146 let snake = to_snake_case(&f.name);
2147 let orig_ty = rust_type_tokens(
2148 &Field {
2149 is_optional: false,
2150 ..(*f).clone()
2151 },
2152 ModuleDepth::TopLevel,
2153 );
2154 match kind {
2155 AggregateKind::Numeric => {
2156 let avg_name = format_ident!("avg_{}", snake);
2157 let sum_name = format_ident!("sum_{}", snake);
2158 let min_name = format_ident!("min_{}", snake);
2159 let max_name = format_ident!("max_{}", snake);
2160 result_fields.push(quote! { #[sqlx(default)] pub #avg_name: Option<f64> });
2161 result_fields.push(quote! { #[sqlx(default)] pub #sum_name: Option<f64> });
2162 result_fields.push(quote! { #[sqlx(default)] pub #min_name: Option<#orig_ty> });
2163 result_fields.push(quote! { #[sqlx(default)] pub #max_name: Option<#orig_ty> });
2164 }
2165 AggregateKind::DateTime => {
2166 let min_name = format_ident!("min_{}", snake);
2167 let max_name = format_ident!("max_{}", snake);
2168 result_fields.push(quote! { #[sqlx(default)] pub #min_name: Option<#orig_ty> });
2169 result_fields.push(quote! { #[sqlx(default)] pub #max_name: Option<#orig_ty> });
2170 }
2171 }
2172 }
2173
2174 let mut having_fields: Vec<TokenStream> = Vec::new();
2179 having_fields.push(quote! { pub count: Option<ferriorm_runtime::filter::BigIntFilter> });
2181 for (f, kind) in &agg_fields {
2182 let snake = to_snake_case(&f.name);
2183 let avg_name = format_ident!("avg_{}", snake);
2184 let sum_name = format_ident!("sum_{}", snake);
2185 let min_name = format_ident!("min_{}", snake);
2186 let max_name = format_ident!("max_{}", snake);
2187 let column_filter = filter_type_tokens(
2188 &Field {
2189 is_optional: false,
2190 ..(*f).clone()
2191 },
2192 ModuleDepth::TopLevel,
2193 )
2194 .unwrap_or_else(|| quote! { ferriorm_runtime::filter::BigIntFilter });
2195 match kind {
2196 AggregateKind::Numeric => {
2197 having_fields
2198 .push(quote! { pub #avg_name: Option<ferriorm_runtime::filter::FloatFilter> });
2199 having_fields
2200 .push(quote! { pub #sum_name: Option<ferriorm_runtime::filter::FloatFilter> });
2201 having_fields.push(quote! { pub #min_name: Option<#column_filter> });
2202 having_fields.push(quote! { pub #max_name: Option<#column_filter> });
2203 }
2204 AggregateKind::DateTime => {
2205 having_fields.push(quote! { pub #min_name: Option<#column_filter> });
2206 having_fields.push(quote! { pub #max_name: Option<#column_filter> });
2207 }
2208 }
2209 }
2210
2211 let mut having_arms: Vec<TokenStream> = Vec::new();
2217 having_arms.push(quote! {
2219 if let Some(filter) = &self.count {
2220 if let Some(v) = &filter.equals { qb.push(" AND COUNT(*) = "); qb.push_bind(*v); }
2221 if let Some(v) = &filter.not { qb.push(" AND COUNT(*) != "); qb.push_bind(*v); }
2222 if let Some(v) = &filter.gt { qb.push(" AND COUNT(*) > "); qb.push_bind(*v); }
2223 if let Some(v) = &filter.gte { qb.push(" AND COUNT(*) >= "); qb.push_bind(*v); }
2224 if let Some(v) = &filter.lt { qb.push(" AND COUNT(*) < "); qb.push_bind(*v); }
2225 if let Some(v) = &filter.lte { qb.push(" AND COUNT(*) <= "); qb.push_bind(*v); }
2226 }
2227 });
2228
2229 for (f, kind) in &agg_fields {
2230 let snake = to_snake_case(&f.name);
2231 let db_name = &f.db_name;
2232 let avg_ident = format_ident!("avg_{}", snake);
2233 let sum_ident = format_ident!("sum_{}", snake);
2234 let min_ident = format_ident!("min_{}", snake);
2235 let max_ident = format_ident!("max_{}", snake);
2236 match kind {
2237 AggregateKind::Numeric => {
2238 let avg_lhs = format!(r#"AVG("{db_name}")"#);
2239 let sum_lhs = format!(r#"SUM("{db_name}")"#);
2240 let min_lhs = format!(r#"MIN("{db_name}")"#);
2241 let max_lhs = format!(r#"MAX("{db_name}")"#);
2242 having_arms.push(gen_having_comparable_arms(&avg_ident, &avg_lhs));
2243 having_arms.push(gen_having_comparable_arms(&sum_ident, &sum_lhs));
2244 having_arms.push(gen_having_comparable_arms(&min_ident, &min_lhs));
2245 having_arms.push(gen_having_comparable_arms(&max_ident, &max_lhs));
2246 }
2247 AggregateKind::DateTime => {
2248 let min_lhs = format!(r#"MIN("{db_name}")"#);
2249 let max_lhs = format!(r#"MAX("{db_name}")"#);
2250 having_arms.push(gen_having_comparable_arms(&min_ident, &min_lhs));
2251 having_arms.push(gen_having_comparable_arms(&max_ident, &max_lhs));
2252 }
2253 }
2254 }
2255
2256 let mut db_bounds = collect_db_bounds(scalar_fields);
2261 if !scalar_fields
2262 .iter()
2263 .any(|f| matches!(&f.field_type, FieldKind::Scalar(ScalarType::Float)))
2264 {
2265 db_bounds.push(quote! { f64: sqlx::Type<DB> + for<'e> sqlx::Encode<'e, DB> });
2266 }
2267
2268 let has_agg_fields = !agg_fields.is_empty();
2273
2274 let agg_methods = if has_agg_fields {
2275 quote! {
2276 pub fn count(mut self) -> Self {
2277 self.count = true;
2278 self
2279 }
2280
2281 pub fn avg(mut self, field: #aggregate_field_name) -> Self {
2282 assert!(field.is_numeric(), "avg() is only supported on numeric fields");
2283 let db_name = field.db_name();
2284 let alias = field.alias("avg");
2285 self.agg_ops.push(("AVG", db_name, alias));
2286 self
2287 }
2288
2289 pub fn sum(mut self, field: #aggregate_field_name) -> Self {
2290 assert!(field.is_numeric(), "sum() is only supported on numeric fields");
2291 let db_name = field.db_name();
2292 let alias = field.alias("sum");
2293 self.agg_ops.push(("SUM", db_name, alias));
2294 self
2295 }
2296
2297 pub fn min(mut self, field: #aggregate_field_name) -> Self {
2298 let db_name = field.db_name();
2299 let alias = field.alias("min");
2300 self.agg_ops.push(("MIN", db_name, alias));
2301 self
2302 }
2303
2304 pub fn max(mut self, field: #aggregate_field_name) -> Self {
2305 let db_name = field.db_name();
2306 let alias = field.alias("max");
2307 self.agg_ops.push(("MAX", db_name, alias));
2308 self
2309 }
2310 }
2311 } else {
2312 quote! {
2313 pub fn count(mut self) -> Self {
2314 self.count = true;
2315 self
2316 }
2317 }
2318 };
2319
2320 quote! {
2321 #[derive(Debug, Clone, Copy)]
2322 pub enum #groupby_field_name {
2323 #(#groupby_variants),*
2324 }
2325
2326 impl #groupby_field_name {
2327 pub fn db_name(&self) -> &'static str {
2328 match self {
2329 #(#groupby_db_arms,)*
2330 }
2331 }
2332
2333 fn alias(&self) -> &'static str {
2334 match self {
2335 #(#groupby_alias_arms,)*
2336 }
2337 }
2338 }
2339
2340 #[derive(Debug, Clone, sqlx::FromRow, Serialize, Deserialize)]
2341 pub struct #groupby_result_name {
2342 #(#result_fields,)*
2343 }
2344
2345 #[derive(Debug, Clone, Default)]
2346 pub struct #having_input_name {
2347 #(#having_fields,)*
2348 pub and: Option<Vec<#having_input_name>>,
2349 pub or: Option<Vec<#having_input_name>>,
2350 pub not: Option<Box<#having_input_name>>,
2351 }
2352
2353 impl #having_input_name {
2354 pub(crate) fn build_having<'args, DB: sqlx::Database>(
2355 &self,
2356 qb: &mut sqlx::QueryBuilder<'args, DB>,
2357 )
2358 where
2359 #(#db_bounds,)*
2360 {
2361 #(#having_arms)*
2362
2363 if let Some(conditions) = &self.and {
2364 for c in conditions {
2365 c.build_having(qb);
2366 }
2367 }
2368 if let Some(conditions) = &self.or {
2369 if !conditions.is_empty() {
2370 qb.push(" AND (");
2371 for (i, c) in conditions.iter().enumerate() {
2372 if i > 0 { qb.push(" OR "); }
2373 qb.push("(1=1");
2374 c.build_having(qb);
2375 qb.push(")");
2376 }
2377 qb.push(")");
2378 }
2379 }
2380 if let Some(c) = &self.not {
2381 qb.push(" AND NOT (1=1");
2382 c.build_having(qb);
2383 qb.push(")");
2384 }
2385 }
2386 }
2387
2388 pub struct GroupByQuery<'a> {
2389 client: &'a DatabaseClient,
2390 r#where: filter::#where_input,
2391 group_keys: Vec<#groupby_field_name>,
2392 agg_ops: Vec<(&'static str, &'static str, &'static str)>,
2393 count: bool,
2394 having: Option<#having_input_name>,
2395 }
2396
2397 impl<'a> GroupByQuery<'a> {
2398 pub fn r#where(mut self, r#where: filter::#where_input) -> Self {
2399 self.r#where = r#where;
2400 self
2401 }
2402
2403 #agg_methods
2404
2405 pub fn having(mut self, having: #having_input_name) -> Self {
2406 self.having = Some(having);
2407 self
2408 }
2409
2410 pub async fn exec(self) -> Result<Vec<#groupby_result_name>, FerriormError> {
2411 if self.group_keys.is_empty() {
2412 return Err(FerriormError::Query(
2413 "group_by() requires at least one group key".into(),
2414 ));
2415 }
2416
2417 let mut selections: Vec<String> = self.group_keys
2418 .iter()
2419 .map(|k| format!(r#""{}" as "{}""#, k.db_name(), k.alias()))
2420 .collect();
2421 if self.count {
2422 selections.push(r#"COUNT(*) as "count""#.to_string());
2423 }
2424 for (func, col, alias) in &self.agg_ops {
2425 selections.push(format!(r#"{}("{}") as "{}""#, func, col, alias));
2426 }
2427
2428 let group_by_clause: Vec<String> = self.group_keys
2429 .iter()
2430 .map(|k| format!(r#""{}""#, k.db_name()))
2431 .collect();
2432
2433 let base_sql = format!(
2434 r#"SELECT {} FROM "{}" WHERE 1=1"#,
2435 selections.join(", "),
2436 #table_name,
2437 );
2438
2439 match self.client {
2440 DatabaseClient::Postgres(_) => {
2441 let mut qb = sqlx::QueryBuilder::<sqlx::Postgres>::new(&base_sql);
2442 self.r#where.build_where(&mut qb);
2443 qb.push(format!(" GROUP BY {}", group_by_clause.join(", ")));
2444 if let Some(h) = &self.having {
2445 qb.push(" HAVING 1=1");
2446 h.build_having(&mut qb);
2447 }
2448 self.client.fetch_all_pg(qb).await
2449 }
2450 DatabaseClient::Sqlite(_) => {
2451 let mut qb = sqlx::QueryBuilder::<sqlx::Sqlite>::new(&base_sql);
2452 self.r#where.build_where(&mut qb);
2453 qb.push(format!(" GROUP BY {}", group_by_clause.join(", ")));
2454 if let Some(h) = &self.having {
2455 qb.push(" HAVING 1=1");
2456 h.build_having(&mut qb);
2457 }
2458 self.client.fetch_all_sqlite(qb).await
2459 }
2460 }
2461 }
2462 }
2463 }
2464}
2465
2466#[allow(clippy::too_many_lines)]
2469fn gen_select_types(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
2470 let select_name = format_ident!("{}Select", model.name);
2471 let partial_name = format_ident!("{}Partial", model.name);
2472 let _where_input = format_ident!("{}WhereInput", model.name);
2473 let _where_unique = format_ident!("{}WhereUniqueInput", model.name);
2474 let order_by_name = format_ident!("{}OrderByInput", model.name);
2475 let table_name = &model.db_name;
2476
2477 let select_fields: Vec<TokenStream> = scalar_fields
2479 .iter()
2480 .map(|f| {
2481 let name = format_ident!("{}", to_snake_case(&f.name));
2482 quote! { pub #name: bool }
2483 })
2484 .collect();
2485
2486 let partial_fields: Vec<TokenStream> = scalar_fields
2489 .iter()
2490 .map(|f| {
2491 let name = format_ident!("{}", to_snake_case(&f.name));
2492 let db_name = &f.db_name;
2493 let base_ty = rust_type_tokens(
2495 &Field {
2496 is_optional: false,
2497 ..(*f).clone()
2498 },
2499 ModuleDepth::TopLevel,
2500 );
2501 let rename = if db_name == &to_snake_case(&f.name) {
2502 quote! {}
2503 } else {
2504 quote! { #[sqlx(rename = #db_name)] }
2505 };
2506 quote! { #[sqlx(default)] #rename pub #name: Option<#base_ty> }
2508 })
2509 .collect();
2510
2511 let select_col_arms: Vec<TokenStream> = scalar_fields
2513 .iter()
2514 .map(|f| {
2515 let name = format_ident!("{}", to_snake_case(&f.name));
2516 let db_name = &f.db_name;
2517 let col_expr = format!(r#""{db_name}""#);
2518 quote! {
2519 if select.#name { cols.push(#col_expr); }
2520 }
2521 })
2522 .collect();
2523
2524 let select_sql_prefix = format!(r#"SELECT {{}} FROM "{table_name}" WHERE 1=1"#);
2525
2526 quote! {
2527 #[derive(Debug, Clone, Default)]
2528 pub struct #select_name {
2529 #(#select_fields,)*
2530 }
2531
2532 #[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
2533 #[sqlx(rename_all = "snake_case")]
2534 pub struct #partial_name {
2535 #(#partial_fields,)*
2536 }
2537
2538 fn build_select_columns(select: &#select_name) -> String {
2539 let mut cols = Vec::new();
2540 #(#select_col_arms)*
2541 if cols.is_empty() {
2542 "*".to_string()
2543 } else {
2544 cols.join(", ")
2545 }
2546 }
2547
2548 pub struct FindManySelectQuery<'a> {
2551 client: &'a DatabaseClient,
2552 r#where: filter::#_where_input,
2553 order_by: Vec<order::#order_by_name>,
2554 skip: Option<i64>,
2555 take: Option<i64>,
2556 select: #select_name,
2557 }
2558
2559 impl<'a> FindManySelectQuery<'a> {
2560 pub fn order_by(mut self, order: order::#order_by_name) -> Self {
2561 self.order_by.push(order);
2562 self
2563 }
2564
2565 pub fn skip(mut self, n: i64) -> Self {
2566 self.skip = Some(n);
2567 self
2568 }
2569
2570 pub fn take(mut self, n: i64) -> Self {
2571 self.take = Some(n);
2572 self
2573 }
2574
2575 pub async fn exec(self) -> Result<Vec<#partial_name>, FerriormError> {
2576 let cols = build_select_columns(&self.select);
2577 let base_sql = format!(#select_sql_prefix, cols);
2578
2579 match self.client {
2580 DatabaseClient::Postgres(_) => {
2581 let qb = build_select_query::<sqlx::Postgres>(
2582 &base_sql, &self.r#where, &self.order_by, self.take, self.skip,
2583 );
2584 self.client.fetch_all_pg(qb).await
2585 }
2586 DatabaseClient::Sqlite(_) => {
2587 let qb = build_select_query::<sqlx::Sqlite>(
2588 &base_sql, &self.r#where, &self.order_by, self.take, self.skip,
2589 );
2590 self.client.fetch_all_sqlite(qb).await
2591 }
2592 }
2593 }
2594 }
2595
2596 pub struct FindUniqueSelectQuery<'a> {
2599 client: &'a DatabaseClient,
2600 r#where: filter::#_where_unique,
2601 select: #select_name,
2602 }
2603
2604 impl<'a> FindUniqueSelectQuery<'a> {
2605 pub async fn exec(self) -> Result<Option<#partial_name>, FerriormError> {
2606 let cols = build_select_columns(&self.select);
2607 let base_sql = format!(#select_sql_prefix, cols);
2608
2609 match self.client {
2610 DatabaseClient::Postgres(_) => {
2611 let qb = build_unique_select_query::<sqlx::Postgres>(
2612 &base_sql, &self.r#where,
2613 );
2614 self.client.fetch_optional_pg(qb).await
2615 }
2616 DatabaseClient::Sqlite(_) => {
2617 let qb = build_unique_select_query::<sqlx::Sqlite>(
2618 &base_sql, &self.r#where,
2619 );
2620 self.client.fetch_optional_sqlite(qb).await
2621 }
2622 }
2623 }
2624 }
2625
2626 pub struct FindFirstSelectQuery<'a> {
2629 client: &'a DatabaseClient,
2630 r#where: filter::#_where_input,
2631 order_by: Vec<order::#order_by_name>,
2632 select: #select_name,
2633 }
2634
2635 impl<'a> FindFirstSelectQuery<'a> {
2636 pub fn order_by(mut self, order: order::#order_by_name) -> Self {
2637 self.order_by.push(order);
2638 self
2639 }
2640
2641 pub async fn exec(self) -> Result<Option<#partial_name>, FerriormError> {
2642 let cols = build_select_columns(&self.select);
2643 let base_sql = format!(#select_sql_prefix, cols);
2644
2645 match self.client {
2646 DatabaseClient::Postgres(_) => {
2647 let qb = build_select_query::<sqlx::Postgres>(
2648 &base_sql, &self.r#where, &self.order_by, Some(1), None,
2649 );
2650 self.client.fetch_optional_pg(qb).await
2651 }
2652 DatabaseClient::Sqlite(_) => {
2653 let qb = build_select_query::<sqlx::Sqlite>(
2654 &base_sql, &self.r#where, &self.order_by, Some(1), None,
2655 );
2656 self.client.fetch_optional_sqlite(qb).await
2657 }
2658 }
2659 }
2660 }
2661 }
2662}