1use ferriorm_core::schema::{Field, FieldKind, Model};
15use ferriorm_core::types::ScalarType;
16use ferriorm_core::utils::{to_pascal_case, to_snake_case};
17use proc_macro2::TokenStream;
18use quote::{format_ident, quote};
19
20use crate::rust_type::{ModuleDepth, filter_type_tokens, rust_type_tokens};
21
22pub fn generate_model_module(model: &Model) -> TokenStream {
24 let scalar_fields: Vec<&Field> = model.fields.iter().filter(|f| f.is_scalar()).collect();
25
26 let data_struct = gen_data_struct(model, &scalar_fields);
27 let filter_module = gen_filter_module(model, &scalar_fields);
28 let data_module = gen_data_module(model, &scalar_fields);
29 let order_module = gen_order_module(model, &scalar_fields);
30 let actions_struct = gen_actions(model, &scalar_fields);
31 let query_builders = gen_query_builders(model, &scalar_fields);
32 let aggregate_types = gen_aggregate_types(model, &scalar_fields);
33 let select_types = gen_select_types(model, &scalar_fields);
34
35 quote! {
36 #![allow(unused_imports, dead_code, unused_variables, clippy::all, clippy::pedantic, clippy::nursery)]
37
38 use serde::{Deserialize, Serialize};
39 use ferriorm_runtime::prelude::*;
40 use ferriorm_runtime::prelude::sqlx;
41 use ferriorm_runtime::prelude::chrono;
42 use ferriorm_runtime::prelude::uuid;
43
44 #data_struct
45 #filter_module
46 #data_module
47 #order_module
48 #actions_struct
49 #query_builders
50 #aggregate_types
51 #select_types
52 }
53}
54
55fn gen_data_struct(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
58 let struct_name = format_ident!("{}", model.name);
59 let table_name = &model.db_name;
60
61 let fields: Vec<TokenStream> = scalar_fields
62 .iter()
63 .map(|f| {
64 let name = format_ident!("{}", to_snake_case(&f.name));
65 let ty = rust_type_tokens(f, ModuleDepth::TopLevel);
66 let db_name = &f.db_name;
67 if db_name != &to_snake_case(&f.name) {
68 quote! { #[sqlx(rename = #db_name)] pub #name: #ty }
69 } else {
70 quote! { pub #name: #ty }
71 }
72 })
73 .collect();
74
75 quote! {
76 #[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
77 #[sqlx(rename_all = "snake_case")]
78 pub struct #struct_name {
79 #(#fields),*
80 }
81
82 impl #struct_name {
83 pub const TABLE_NAME: &'static str = #table_name;
84 }
85 }
86}
87
88fn gen_filter_module(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
91 let where_input = format_ident!("{}WhereInput", model.name);
92 let where_unique = format_ident!("{}WhereUniqueInput", model.name);
93
94 let where_fields: Vec<TokenStream> = scalar_fields
95 .iter()
96 .filter_map(|f| {
97 let filter_ty = filter_type_tokens(f, ModuleDepth::Nested)?;
98 let name = format_ident!("{}", to_snake_case(&f.name));
99 Some(quote! { pub #name: Option<#filter_ty> })
100 })
101 .collect();
102
103 let unique_variants: Vec<TokenStream> = scalar_fields
104 .iter()
105 .filter(|f| f.is_id || f.is_unique)
106 .map(|f| {
107 let variant = format_ident!("{}", to_pascal_case(&f.name));
108 let ty = rust_type_tokens(f, ModuleDepth::Nested);
109 quote! { #variant(#ty) }
110 })
111 .collect();
112
113 let db_bounds = collect_db_bounds(scalar_fields);
115 let where_arms = gen_where_arms(scalar_fields);
116 let unique_arms = gen_unique_where_arms(scalar_fields);
117
118 quote! {
119 pub mod filter {
120 use ferriorm_runtime::prelude::*;
121
122 #[derive(Debug, Clone, Default)]
123 pub struct #where_input {
124 #(#where_fields,)*
125 pub and: Option<Vec<#where_input>>,
126 pub or: Option<Vec<#where_input>>,
127 pub not: Option<Box<#where_input>>,
128 }
129
130 #[derive(Debug, Clone)]
131 pub enum #where_unique {
132 #(#unique_variants),*
133 }
134
135 impl #where_input {
136 pub(crate) fn build_where<'args, DB: sqlx::Database>(
137 &self,
138 qb: &mut sqlx::QueryBuilder<'args, DB>,
139 )
140 where
141 #(#db_bounds,)*
142 {
143 #(#where_arms)*
144
145 if let Some(conditions) = &self.and {
146 for c in conditions {
147 c.build_where(qb);
148 }
149 }
150 if let Some(conditions) = &self.or {
151 if !conditions.is_empty() {
152 qb.push(" AND (");
153 for (i, c) in conditions.iter().enumerate() {
154 if i > 0 { qb.push(" OR "); }
155 qb.push("(1=1");
156 c.build_where(qb);
157 qb.push(")");
158 }
159 qb.push(")");
160 }
161 }
162 if let Some(c) = &self.not {
163 qb.push(" AND NOT (1=1");
164 c.build_where(qb);
165 qb.push(")");
166 }
167 }
168 }
169
170 impl #where_unique {
171 pub(crate) fn build_where<'args, DB: sqlx::Database>(
172 &self,
173 qb: &mut sqlx::QueryBuilder<'args, DB>,
174 )
175 where
176 #(#db_bounds,)*
177 {
178 match self {
179 #(#unique_arms)*
180 }
181 }
182 }
183 }
184 }
185}
186
187fn collect_db_bounds(scalar_fields: &[&Field]) -> Vec<TokenStream> {
189 let mut seen = std::collections::HashSet::new();
190 let mut bounds = Vec::new();
191
192 seen.insert("i64");
194 bounds.push(quote! { i64: sqlx::Type<DB> + for<'e> sqlx::Encode<'e, DB> });
195
196 for f in scalar_fields {
197 match &f.field_type {
198 FieldKind::Scalar(scalar) => {
199 let key = scalar.rust_type();
200 if seen.insert(key)
201 && let Some(ty) = scalar_bound_tokens(scalar)
202 {
203 bounds.push(quote! { #ty: sqlx::Type<DB> + for<'e> sqlx::Encode<'e, DB> });
204 bounds.push(
206 quote! { Option<#ty>: sqlx::Type<DB> + for<'e> sqlx::Encode<'e, DB> },
207 );
208 }
209 }
210 FieldKind::Enum(_) => {}
211 _ => {}
212 }
213 }
214
215 bounds
216}
217
218fn scalar_bound_tokens(scalar: &ScalarType) -> Option<TokenStream> {
219 match scalar {
220 ScalarType::String => Some(quote! { String }),
221 ScalarType::Int => Some(quote! { i32 }),
222 ScalarType::BigInt => Some(quote! { i64 }),
223 ScalarType::Float => Some(quote! { f64 }),
224 ScalarType::Boolean => Some(quote! { bool }),
225 ScalarType::DateTime => Some(quote! { chrono::DateTime<chrono::Utc> }),
226 ScalarType::Bytes => Some(quote! { Vec<u8> }),
227 ScalarType::Json | ScalarType::Decimal => None,
228 }
229}
230
231fn gen_where_arms(scalar_fields: &[&Field]) -> Vec<TokenStream> {
233 scalar_fields
234 .iter()
235 .filter_map(|f| {
236 if !matches!(&f.field_type, FieldKind::Scalar(_)) {
238 return None;
239 }
240 let field_ident = format_ident!("{}", to_snake_case(&f.name));
241 let db_name = &f.db_name;
242 let is_string = matches!(&f.field_type, FieldKind::Scalar(ScalarType::String));
243 let is_comparable = matches!(
244 &f.field_type,
245 FieldKind::Scalar(
246 ScalarType::Int | ScalarType::BigInt | ScalarType::Float | ScalarType::DateTime
247 )
248 );
249
250 let mut arms = vec![];
251
252 arms.push(quote! {
253 if let Some(v) = &filter.equals {
254 qb.push(concat!(" AND \"", #db_name, "\" = "));
255 qb.push_bind(v.clone());
256 }
257 if let Some(v) = &filter.not {
258 qb.push(concat!(" AND \"", #db_name, "\" != "));
259 qb.push_bind(v.clone());
260 }
261 });
262
263 if is_string {
264 arms.push(quote! {
265 if let Some(v) = &filter.contains {
266 qb.push(concat!(" AND \"", #db_name, "\" LIKE "));
267 qb.push_bind(format!("%{}%", v));
268 }
269 if let Some(v) = &filter.starts_with {
270 qb.push(concat!(" AND \"", #db_name, "\" LIKE "));
271 qb.push_bind(format!("{}%", v));
272 }
273 if let Some(v) = &filter.ends_with {
274 qb.push(concat!(" AND \"", #db_name, "\" LIKE "));
275 qb.push_bind(format!("%{}", v));
276 }
277 });
278 }
279
280 if is_comparable {
281 arms.push(quote! {
282 if let Some(v) = &filter.gt {
283 qb.push(concat!(" AND \"", #db_name, "\" > "));
284 qb.push_bind(v.clone());
285 }
286 if let Some(v) = &filter.gte {
287 qb.push(concat!(" AND \"", #db_name, "\" >= "));
288 qb.push_bind(v.clone());
289 }
290 if let Some(v) = &filter.lt {
291 qb.push(concat!(" AND \"", #db_name, "\" < "));
292 qb.push_bind(v.clone());
293 }
294 if let Some(v) = &filter.lte {
295 qb.push(concat!(" AND \"", #db_name, "\" <= "));
296 qb.push_bind(v.clone());
297 }
298 });
299 }
300
301 Some(quote! {
302 if let Some(filter) = &self.#field_ident {
303 #(#arms)*
304 }
305 })
306 })
307 .collect()
308}
309
310fn gen_unique_where_arms(scalar_fields: &[&Field]) -> Vec<TokenStream> {
311 let _where_unique = format_ident!(
312 "{}WhereUniqueInput",
313 "" );
315 scalar_fields
316 .iter()
317 .filter(|f| f.is_id || f.is_unique)
318 .map(|f| {
319 let variant = format_ident!("{}", to_pascal_case(&f.name));
320 let db_name = &f.db_name;
321 quote! {
322 Self::#variant(v) => {
323 qb.push(concat!(" AND \"", #db_name, "\" = "));
324 qb.push_bind(v.clone());
325 }
326 }
327 })
328 .collect()
329}
330
331fn gen_data_module(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
334 let create_name = format_ident!("{}CreateInput", model.name);
335 let update_name = format_ident!("{}UpdateInput", model.name);
336
337 let required_fields: Vec<TokenStream> = scalar_fields
338 .iter()
339 .filter(|f| !f.has_default() && !f.is_updated_at)
340 .map(|f| {
341 let name = format_ident!("{}", to_snake_case(&f.name));
342 let ty = rust_type_tokens(f, ModuleDepth::Nested);
343 quote! { pub #name: #ty }
344 })
345 .collect();
346
347 let optional_fields: Vec<TokenStream> = scalar_fields
348 .iter()
349 .filter(|f| f.has_default() && !f.is_updated_at)
350 .map(|f| {
351 let name = format_ident!("{}", to_snake_case(&f.name));
352 let base_ty = rust_type_tokens(f, ModuleDepth::Nested);
353 quote! { pub #name: Option<#base_ty> }
354 })
355 .collect();
356
357 let update_fields: Vec<TokenStream> = scalar_fields
358 .iter()
359 .filter(|f| !f.is_id && !f.is_updated_at)
360 .map(|f| {
361 let name = format_ident!("{}", to_snake_case(&f.name));
362 let ty = rust_type_tokens(f, ModuleDepth::Nested);
363 quote! { pub #name: Option<SetValue<#ty>> }
364 })
365 .collect();
366
367 quote! {
368 pub mod data {
369 use ferriorm_runtime::prelude::*;
370
371 #[derive(Debug, Clone)]
372 pub struct #create_name {
373 #(#required_fields,)*
374 #(#optional_fields,)*
375 }
376
377 #[derive(Debug, Clone, Default)]
378 pub struct #update_name {
379 #(#update_fields,)*
380 }
381 }
382 }
383}
384
385fn gen_order_module(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
388 let order_name = format_ident!("{}OrderByInput", model.name);
389
390 let variants: Vec<TokenStream> = scalar_fields
391 .iter()
392 .map(|f| {
393 let variant = format_ident!("{}", to_pascal_case(&f.name));
394 quote! { #variant(SortOrder) }
395 })
396 .collect();
397
398 let order_arms: Vec<TokenStream> = scalar_fields
399 .iter()
400 .map(|f| {
401 let variant = format_ident!("{}", to_pascal_case(&f.name));
402 let db_name = &f.db_name;
403 quote! {
404 Self::#variant(order) => {
405 qb.push(concat!("\"", #db_name, "\" "));
406 qb.push(order.as_sql());
407 }
408 }
409 })
410 .collect();
411
412 quote! {
413 pub mod order {
414 use ferriorm_runtime::prelude::*;
415
416 #[derive(Debug, Clone)]
417 pub enum #order_name {
418 #(#variants),*
419 }
420
421 impl #order_name {
422 pub(crate) fn build_order_by<'args, DB: sqlx::Database>(
423 &self,
424 qb: &mut sqlx::QueryBuilder<'args, DB>,
425 ) {
426 match self {
427 #(#order_arms)*
428 }
429 }
430 }
431 }
432 }
433}
434
435fn gen_actions(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
438 let _model_ident = format_ident!("{}", model.name);
439 let actions_name = format_ident!("{}Actions", model.name);
440 let where_input = format_ident!("{}WhereInput", model.name);
441 let where_unique = format_ident!("{}WhereUniqueInput", model.name);
442 let create_input = format_ident!("{}CreateInput", model.name);
443 let update_input = format_ident!("{}UpdateInput", model.name);
444 let _order_by = format_ident!("{}OrderByInput", model.name);
445
446 let has_agg_fields = scalar_fields.iter().any(|f| {
448 matches!(
449 &f.field_type,
450 FieldKind::Scalar(
451 ScalarType::Int | ScalarType::BigInt | ScalarType::Float | ScalarType::DateTime
452 )
453 )
454 });
455 let aggregate_method = if has_agg_fields {
456 quote! {
457 pub fn aggregate(&self, r#where: filter::#where_input) -> AggregateQuery<'a> {
458 AggregateQuery { client: self.client, r#where, ops: vec![] }
459 }
460 }
461 } else {
462 quote! {}
463 };
464
465 quote! {
466 pub struct #actions_name<'a> {
467 client: &'a DatabaseClient,
468 }
469
470 impl<'a> #actions_name<'a> {
471 pub fn new(client: &'a DatabaseClient) -> Self { Self { client } }
472
473 pub fn find_unique(&self, r#where: filter::#where_unique) -> FindUniqueQuery<'a> {
474 FindUniqueQuery { client: self.client, r#where }
475 }
476
477 pub fn find_first(&self, r#where: filter::#where_input) -> FindFirstQuery<'a> {
478 FindFirstQuery { client: self.client, r#where, order_by: vec![] }
479 }
480
481 pub fn find_many(&self, r#where: filter::#where_input) -> FindManyQuery<'a> {
482 FindManyQuery { client: self.client, r#where, order_by: vec![], skip: None, take: None }
483 }
484
485 pub fn create(&self, data: data::#create_input) -> CreateQuery<'a> {
486 CreateQuery { client: self.client, data }
487 }
488
489 pub fn update(&self, r#where: filter::#where_unique, data: data::#update_input) -> UpdateQuery<'a> {
490 UpdateQuery { client: self.client, r#where, data }
491 }
492
493 pub fn delete(&self, r#where: filter::#where_unique) -> DeleteQuery<'a> {
494 DeleteQuery { client: self.client, r#where }
495 }
496
497 pub fn count(&self, r#where: filter::#where_input) -> CountQuery<'a> {
498 CountQuery { client: self.client, r#where }
499 }
500
501 pub fn create_many(&self, data: Vec<data::#create_input>) -> CreateManyQuery<'a> {
502 CreateManyQuery { client: self.client, data }
503 }
504
505 pub fn update_many(&self, r#where: filter::#where_input, data: data::#update_input) -> UpdateManyQuery<'a> {
506 UpdateManyQuery { client: self.client, r#where, data }
507 }
508
509 pub fn delete_many(&self, r#where: filter::#where_input) -> DeleteManyQuery<'a> {
510 DeleteManyQuery { client: self.client, r#where }
511 }
512
513 pub fn upsert(
514 &self,
515 r#where: filter::#where_unique,
516 create: data::#create_input,
517 update: data::#update_input,
518 ) -> UpsertQuery<'a> {
519 UpsertQuery { client: self.client, r#where, create, update }
520 }
521
522 #aggregate_method
523 }
524 }
525}
526
527fn gen_query_builders(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
530 let model_ident = format_ident!("{}", model.name);
531 let table_name = &model.db_name;
532 let _where_input = format_ident!("{}WhereInput", model.name);
533 let _where_unique = format_ident!("{}WhereUniqueInput", model.name);
534 let _create_input = format_ident!("{}CreateInput", model.name);
535 let _update_input = format_ident!("{}UpdateInput", model.name);
536 let order_by = format_ident!("{}OrderByInput", model.name);
537 let _select_struct = format_ident!("{}Select", model.name);
538 let _partial_struct = format_ident!("{}Partial", model.name);
539 let _aggregate_result = format_ident!("{}AggregateResult", model.name);
540 let _aggregate_field = format_ident!("{}AggregateField", model.name);
541 let db_bounds = collect_db_bounds(scalar_fields);
542
543 let select_sql = format!(r#"SELECT * FROM "{}" WHERE 1=1"#, table_name);
544 let count_sql = format!(
545 r#"SELECT COUNT(*) as "count" FROM "{}" WHERE 1=1"#,
546 table_name
547 );
548 let delete_sql = format!(r#"DELETE FROM "{}" WHERE 1=1"#, table_name);
549
550 let insert_code = gen_insert_code(model, scalar_fields, table_name);
551 let update_code = gen_update_code(model, scalar_fields, table_name);
552 let update_many_code = gen_update_many_code(model, scalar_fields, table_name);
553 let upsert_code = gen_upsert_code(model, scalar_fields, table_name);
554
555 quote! {
556 fn build_order_by<'args, DB: sqlx::Database>(
558 orders: &[order::#order_by],
559 qb: &mut sqlx::QueryBuilder<'args, DB>,
560 ) {
561 if !orders.is_empty() {
562 qb.push(" ORDER BY ");
563 for (i, ob) in orders.iter().enumerate() {
564 if i > 0 { qb.push(", "); }
565 ob.build_order_by(qb);
566 }
567 }
568 }
569
570 fn build_select_query<'args, DB: sqlx::Database>(
572 base_sql: &str,
573 where_input: &filter::#_where_input,
574 orders: &[order::#order_by],
575 take: Option<i64>,
576 skip: Option<i64>,
577 ) -> sqlx::QueryBuilder<'args, DB>
578 where
579 #(#db_bounds,)*
580 {
581 let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
582 where_input.build_where(&mut qb);
583 build_order_by(orders, &mut qb);
584 if let Some(take) = take {
585 qb.push(" LIMIT ");
586 qb.push_bind(take);
587 }
588 if let Some(skip) = skip {
589 qb.push(" OFFSET ");
590 qb.push_bind(skip);
591 }
592 qb
593 }
594
595 fn build_unique_select_query<'args, DB: sqlx::Database>(
597 base_sql: &str,
598 where_unique: &filter::#_where_unique,
599 ) -> sqlx::QueryBuilder<'args, DB>
600 where
601 #(#db_bounds,)*
602 {
603 let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
604 where_unique.build_where(&mut qb);
605 qb.push(" LIMIT 1");
606 qb
607 }
608
609 fn build_delete_query<'args, DB: sqlx::Database>(
611 base_sql: &str,
612 where_unique: &filter::#_where_unique,
613 ) -> sqlx::QueryBuilder<'args, DB>
614 where
615 #(#db_bounds,)*
616 {
617 let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
618 where_unique.build_where(&mut qb);
619 qb.push(" RETURNING *");
620 qb
621 }
622
623 fn build_count_query<'args, DB: sqlx::Database>(
625 base_sql: &str,
626 where_input: &filter::#_where_input,
627 ) -> sqlx::QueryBuilder<'args, DB>
628 where
629 #(#db_bounds,)*
630 {
631 let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
632 where_input.build_where(&mut qb);
633 qb
634 }
635
636 fn build_delete_many_query<'args, DB: sqlx::Database>(
638 base_sql: &str,
639 where_input: &filter::#_where_input,
640 ) -> sqlx::QueryBuilder<'args, DB>
641 where
642 #(#db_bounds,)*
643 {
644 let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
645 where_input.build_where(&mut qb);
646 qb
647 }
648
649 pub struct FindUniqueQuery<'a> {
650 client: &'a DatabaseClient,
651 r#where: filter::#_where_unique,
652 }
653
654 impl<'a> FindUniqueQuery<'a> {
655 pub fn select(self, select: #_select_struct) -> FindUniqueSelectQuery<'a> {
656 FindUniqueSelectQuery { client: self.client, r#where: self.r#where, select }
657 }
658
659 pub async fn exec(self) -> Result<Option<#model_ident>, FerriormError> {
660 match self.client {
661 DatabaseClient::Postgres(_) => {
662 let qb = build_unique_select_query::<sqlx::Postgres>(#select_sql, &self.r#where);
663 self.client.fetch_optional_pg(qb).await
664 }
665 DatabaseClient::Sqlite(_) => {
666 let qb = build_unique_select_query::<sqlx::Sqlite>(#select_sql, &self.r#where);
667 self.client.fetch_optional_sqlite(qb).await
668 }
669 }
670 }
671 }
672
673 pub struct FindFirstQuery<'a> {
674 client: &'a DatabaseClient,
675 r#where: filter::#_where_input,
676 order_by: Vec<order::#order_by>,
677 }
678
679 impl<'a> FindFirstQuery<'a> {
680 pub fn order_by(mut self, order: order::#order_by) -> Self {
681 self.order_by.push(order);
682 self
683 }
684
685 pub fn select(self, select: #_select_struct) -> FindFirstSelectQuery<'a> {
686 FindFirstSelectQuery { client: self.client, r#where: self.r#where, order_by: self.order_by, select }
687 }
688
689 pub async fn exec(self) -> Result<Option<#model_ident>, FerriormError> {
690 match self.client {
691 DatabaseClient::Postgres(_) => {
692 let qb = build_select_query::<sqlx::Postgres>(#select_sql, &self.r#where, &self.order_by, Some(1), None);
693 self.client.fetch_optional_pg(qb).await
694 }
695 DatabaseClient::Sqlite(_) => {
696 let qb = build_select_query::<sqlx::Sqlite>(#select_sql, &self.r#where, &self.order_by, Some(1), None);
697 self.client.fetch_optional_sqlite(qb).await
698 }
699 }
700 }
701 }
702
703 pub struct FindManyQuery<'a> {
704 client: &'a DatabaseClient,
705 r#where: filter::#_where_input,
706 order_by: Vec<order::#order_by>,
707 skip: Option<i64>,
708 take: Option<i64>,
709 }
710
711 impl<'a> FindManyQuery<'a> {
712 pub fn order_by(mut self, order: order::#order_by) -> Self {
713 self.order_by.push(order);
714 self
715 }
716
717 pub fn skip(mut self, n: i64) -> Self {
718 self.skip = Some(n);
719 self
720 }
721
722 pub fn take(mut self, n: i64) -> Self {
723 self.take = Some(n);
724 self
725 }
726
727 pub fn select(self, select: #_select_struct) -> FindManySelectQuery<'a> {
728 FindManySelectQuery {
729 client: self.client,
730 r#where: self.r#where,
731 order_by: self.order_by,
732 skip: self.skip,
733 take: self.take,
734 select,
735 }
736 }
737
738 pub async fn exec(self) -> Result<Vec<#model_ident>, FerriormError> {
739 match self.client {
740 DatabaseClient::Postgres(_) => {
741 let qb = build_select_query::<sqlx::Postgres>(#select_sql, &self.r#where, &self.order_by, self.take, self.skip);
742 self.client.fetch_all_pg(qb).await
743 }
744 DatabaseClient::Sqlite(_) => {
745 let qb = build_select_query::<sqlx::Sqlite>(#select_sql, &self.r#where, &self.order_by, self.take, self.skip);
746 self.client.fetch_all_sqlite(qb).await
747 }
748 }
749 }
750 }
751
752 pub struct CreateQuery<'a> {
753 client: &'a DatabaseClient,
754 data: data::#_create_input,
755 }
756
757 impl<'a> CreateQuery<'a> {
758 pub async fn exec(self) -> Result<#model_ident, FerriormError> {
759 let client = self.client;
760 #insert_code
761 }
762 }
763
764 pub struct UpdateQuery<'a> {
765 client: &'a DatabaseClient,
766 r#where: filter::#_where_unique,
767 data: data::#_update_input,
768 }
769
770 impl<'a> UpdateQuery<'a> {
771 pub async fn exec(self) -> Result<#model_ident, FerriormError> {
772 let client = self.client;
773 #update_code
774 }
775 }
776
777 pub struct DeleteQuery<'a> {
778 client: &'a DatabaseClient,
779 r#where: filter::#_where_unique,
780 }
781
782 impl<'a> DeleteQuery<'a> {
783 pub async fn exec(self) -> Result<#model_ident, FerriormError> {
784 match self.client {
785 DatabaseClient::Postgres(_) => {
786 let qb = build_delete_query::<sqlx::Postgres>(#delete_sql, &self.r#where);
787 self.client.fetch_one_pg(qb).await
788 }
789 DatabaseClient::Sqlite(_) => {
790 let qb = build_delete_query::<sqlx::Sqlite>(#delete_sql, &self.r#where);
791 self.client.fetch_one_sqlite(qb).await
792 }
793 }
794 }
795 }
796
797 #[derive(sqlx::FromRow)]
798 struct CountResult { count: i64 }
799
800 pub struct CountQuery<'a> {
801 client: &'a DatabaseClient,
802 r#where: filter::#_where_input,
803 }
804
805 impl<'a> CountQuery<'a> {
806 pub async fn exec(self) -> Result<i64, FerriormError> {
807 let row: CountResult = match self.client {
808 DatabaseClient::Postgres(_) => {
809 let qb = build_count_query::<sqlx::Postgres>(#count_sql, &self.r#where);
810 self.client.fetch_one_pg(qb).await?
811 }
812 DatabaseClient::Sqlite(_) => {
813 let qb = build_count_query::<sqlx::Sqlite>(#count_sql, &self.r#where);
814 self.client.fetch_one_sqlite(qb).await?
815 }
816 };
817 Ok(row.count)
818 }
819 }
820
821 pub struct CreateManyQuery<'a> {
822 client: &'a DatabaseClient,
823 data: Vec<data::#_create_input>,
824 }
825
826 impl<'a> CreateManyQuery<'a> {
827 pub async fn exec(self) -> Result<u64, FerriormError> {
828 if self.data.is_empty() { return Ok(0); }
829 let count = self.data.len() as u64;
830 for item in self.data {
831 CreateQuery { client: self.client, data: item }.exec().await?;
832 }
833 Ok(count)
834 }
835 }
836
837 pub struct UpdateManyQuery<'a> {
838 client: &'a DatabaseClient,
839 r#where: filter::#_where_input,
840 data: data::#_update_input,
841 }
842
843 impl<'a> UpdateManyQuery<'a> {
844 pub async fn exec(self) -> Result<u64, FerriormError> {
845 let client = self.client;
846 #update_many_code
847 }
848 }
849
850 pub struct UpsertQuery<'a> {
851 client: &'a DatabaseClient,
852 r#where: filter::#_where_unique,
853 create: data::#_create_input,
854 update: data::#_update_input,
855 }
856
857 impl<'a> UpsertQuery<'a> {
858 pub async fn exec(self) -> Result<#model_ident, FerriormError> {
859 let client = self.client;
860 #upsert_code
861 }
862 }
863
864 pub struct DeleteManyQuery<'a> {
865 client: &'a DatabaseClient,
866 r#where: filter::#_where_input,
867 }
868
869 impl<'a> DeleteManyQuery<'a> {
870 pub async fn exec(self) -> Result<u64, FerriormError> {
871 match self.client {
872 DatabaseClient::Postgres(_) => {
873 let qb = build_delete_many_query::<sqlx::Postgres>(#delete_sql, &self.r#where);
874 self.client.execute_pg(qb).await
875 }
876 DatabaseClient::Sqlite(_) => {
877 let qb = build_delete_many_query::<sqlx::Sqlite>(#delete_sql, &self.r#where);
878 self.client.execute_sqlite(qb).await
879 }
880 }
881 }
882 }
883 }
884}
885
886fn gen_insert_code(model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
889 let _model_ident = format_ident!("{}", model.name);
890
891 let required: Vec<&Field> = scalar_fields
893 .iter()
894 .copied()
895 .filter(|f| !f.has_default() && !f.is_updated_at)
896 .collect();
897
898 let optional: Vec<&Field> = scalar_fields
900 .iter()
901 .copied()
902 .filter(|f| f.has_default() && !f.is_updated_at)
903 .collect();
904
905 let updated_at: Vec<&Field> = scalar_fields
907 .iter()
908 .copied()
909 .filter(|f| f.is_updated_at)
910 .collect();
911
912 let mut col_pushes = vec![];
914 let mut val_pushes = vec![];
915
916 for f in &required {
918 let db_name = &f.db_name;
919 let field_ident = format_ident!("{}", to_snake_case(&f.name));
920 col_pushes.push(quote! { cols.push(#db_name); });
921 val_pushes.push(quote! { sep.push_bind(self.data.#field_ident); });
922 }
923
924 for f in &optional {
926 let db_name = &f.db_name;
927 let field_ident = format_ident!("{}", to_snake_case(&f.name));
928 let default_expr = gen_default_expr(f, &f.field_type);
929
930 col_pushes.push(quote! { cols.push(#db_name); });
931 val_pushes.push(quote! {
932 let val = self.data.#field_ident.unwrap_or_else(|| #default_expr);
933 sep.push_bind(val);
934 });
935 }
936
937 for f in &updated_at {
939 let db_name = &f.db_name;
940 col_pushes.push(quote! { cols.push(#db_name); });
941 val_pushes.push(quote! { sep.push_bind(chrono::Utc::now()); });
942 }
943
944 let insert_start = format!(r#"INSERT INTO "{}""#, table_name);
945
946 quote! {
949 macro_rules! build_insert {
951 ($qb_type:ty) => {{
952 let mut cols: Vec<&str> = Vec::new();
953 #(#col_pushes)*
954
955 let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#insert_start);
956 qb.push(" (");
957 for (i, col) in cols.iter().enumerate() {
958 if i > 0 { qb.push(", "); }
959 qb.push("\"");
960 qb.push(*col);
961 qb.push("\"");
962 }
963 qb.push(") VALUES (");
964 {
965 let mut sep = qb.separated(", ");
966 #(#val_pushes)*
967 }
968 qb.push(") RETURNING *");
969 qb
970 }};
971 }
972
973 match client {
974 DatabaseClient::Postgres(_) => {
975 let qb = build_insert!(sqlx::Postgres);
976 client.fetch_one_pg(qb).await
977 }
978 DatabaseClient::Sqlite(_) => {
979 let qb = build_insert!(sqlx::Sqlite);
980 client.fetch_one_sqlite(qb).await
981 }
982 }
983 }
984}
985
986fn gen_default_expr(field: &Field, field_type: &FieldKind) -> TokenStream {
988 use ferriorm_core::ast::DefaultValue;
989
990 match &field.default {
991 Some(DefaultValue::Uuid) => quote! { uuid::Uuid::new_v4().to_string() },
992 Some(DefaultValue::Cuid) => quote! { uuid::Uuid::new_v4().to_string() }, Some(DefaultValue::Now) => quote! { chrono::Utc::now() },
994 Some(DefaultValue::AutoIncrement) => quote! { 0i32 }, Some(DefaultValue::Literal(lit)) => {
996 use ferriorm_core::ast::LiteralValue;
997 match lit {
998 LiteralValue::String(s) => quote! { #s.to_string() },
999 LiteralValue::Int(i) => {
1000 match field_type {
1002 FieldKind::Scalar(ScalarType::Float) => {
1003 let val = *i as f64;
1004 quote! { #val }
1005 }
1006 FieldKind::Scalar(ScalarType::BigInt) => quote! { #i },
1007 _ => {
1008 let val = *i as i32;
1010 quote! { #val }
1011 }
1012 }
1013 }
1014 LiteralValue::Float(f) => quote! { #f },
1015 LiteralValue::Bool(b) => quote! { #b },
1016 }
1017 }
1018 Some(DefaultValue::EnumVariant(v)) => {
1019 let variant = format_ident!("{}", v);
1021 if let FieldKind::Enum(enum_name) = &field.field_type {
1022 let enum_ident = format_ident!("{}", enum_name);
1023 quote! { super::enums::#enum_ident::#variant }
1024 } else {
1025 quote! { Default::default() }
1026 }
1027 }
1028 None => quote! { Default::default() },
1029 }
1030}
1031
1032fn gen_update_code(model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
1035 let _model_ident = format_ident!("{}", model.name);
1036
1037 let updatable: Vec<&Field> = scalar_fields
1039 .iter()
1040 .copied()
1041 .filter(|f| !f.is_id && !f.is_updated_at)
1042 .collect();
1043
1044 let updated_at: Vec<&Field> = scalar_fields
1045 .iter()
1046 .copied()
1047 .filter(|f| f.is_updated_at)
1048 .collect();
1049
1050 let update_start = format!(r#"UPDATE "{}" SET "#, table_name);
1051
1052 let set_arms: Vec<TokenStream> = updatable
1054 .iter()
1055 .map(|f| {
1056 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1057 let db_name = &f.db_name;
1058 quote! {
1059 if let Some(SetValue::Set(v)) = self.data.#field_ident {
1060 if !first_set { qb.push(", "); }
1061 first_set = false;
1062 qb.push(concat!("\"", #db_name, "\" = "));
1063 qb.push_bind(v);
1064 }
1065 }
1066 })
1067 .collect();
1068
1069 let updated_at_arms: Vec<TokenStream> = updated_at
1070 .iter()
1071 .map(|f| {
1072 let db_name = &f.db_name;
1073 quote! {
1074 if !first_set { qb.push(", "); }
1075 first_set = false;
1076 qb.push(concat!("\"", #db_name, "\" = "));
1077 qb.push_bind(chrono::Utc::now());
1078 }
1079 })
1080 .collect();
1081
1082 quote! {
1085 macro_rules! build_update {
1086 ($qb_type:ty) => {{
1087 let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#update_start);
1088 let mut first_set = true;
1089 #(#set_arms)*
1090 #(#updated_at_arms)*
1091
1092 if first_set {
1093 return Err(FerriormError::Query("No fields to update".into()));
1094 }
1095
1096 qb.push(" WHERE 1=1");
1097 self.r#where.build_where(&mut qb);
1098 qb.push(" RETURNING *");
1099 qb
1100 }};
1101 }
1102
1103 match client {
1104 DatabaseClient::Postgres(_) => {
1105 let qb = build_update!(sqlx::Postgres);
1106 client.fetch_one_pg(qb).await
1107 }
1108 DatabaseClient::Sqlite(_) => {
1109 let qb = build_update!(sqlx::Sqlite);
1110 client.fetch_one_sqlite(qb).await
1111 }
1112 }
1113 }
1114}
1115
1116fn gen_update_many_code(_model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
1119 let updatable: Vec<&Field> = scalar_fields
1121 .iter()
1122 .copied()
1123 .filter(|f| !f.is_id && !f.is_updated_at)
1124 .collect();
1125
1126 let updated_at: Vec<&Field> = scalar_fields
1127 .iter()
1128 .copied()
1129 .filter(|f| f.is_updated_at)
1130 .collect();
1131
1132 let update_start = format!(r#"UPDATE "{}" SET "#, table_name);
1133
1134 let set_arms: Vec<TokenStream> = updatable
1136 .iter()
1137 .map(|f| {
1138 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1139 let db_name = &f.db_name;
1140 quote! {
1141 if let Some(SetValue::Set(v)) = self.data.#field_ident {
1142 if !first_set { qb.push(", "); }
1143 first_set = false;
1144 qb.push(concat!("\"", #db_name, "\" = "));
1145 qb.push_bind(v);
1146 }
1147 }
1148 })
1149 .collect();
1150
1151 let updated_at_arms: Vec<TokenStream> = updated_at
1152 .iter()
1153 .map(|f| {
1154 let db_name = &f.db_name;
1155 quote! {
1156 if !first_set { qb.push(", "); }
1157 first_set = false;
1158 qb.push(concat!("\"", #db_name, "\" = "));
1159 qb.push_bind(chrono::Utc::now());
1160 }
1161 })
1162 .collect();
1163
1164 quote! {
1165 macro_rules! build_update_many {
1166 ($qb_type:ty) => {{
1167 let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#update_start);
1168 let mut first_set = true;
1169 #(#set_arms)*
1170 #(#updated_at_arms)*
1171
1172 if first_set {
1173 return Ok(0);
1174 }
1175
1176 qb.push(" WHERE 1=1");
1177 self.r#where.build_where(&mut qb);
1178 qb
1179 }};
1180 }
1181
1182 match client {
1183 DatabaseClient::Postgres(_) => {
1184 let qb = build_update_many!(sqlx::Postgres);
1185 client.execute_pg(qb).await
1186 }
1187 DatabaseClient::Sqlite(_) => {
1188 let qb = build_update_many!(sqlx::Sqlite);
1189 client.execute_sqlite(qb).await
1190 }
1191 }
1192 }
1193}
1194
/// Category of an aggregatable field.
///
/// NOTE(review): not referenced in the visible part of this file; presumably
/// consumed by the aggregate-type generation further down. Confirm. The two
/// variants appear to mirror the aggregatable scalar kinds checked in
/// `gen_actions` (Int / BigInt / Float versus DateTime).
enum AggregateKind {
    /// Numeric field (presumably Int, BigInt or Float).
    Numeric,
    /// Date-time field.
    DateTime,
}
1204
1205fn gen_upsert_code(model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
1208 let pk_db_names: Vec<String> = model
1210 .primary_key
1211 .fields
1212 .iter()
1213 .filter_map(|pk| {
1214 model
1215 .fields
1216 .iter()
1217 .find(|f| f.name == *pk || to_snake_case(&f.name) == *pk)
1218 .map(|f| f.db_name.clone())
1219 })
1220 .collect();
1221 let pk_conflict_cols = pk_db_names
1222 .iter()
1223 .map(|c| format!("\"{}\"", c))
1224 .collect::<Vec<_>>()
1225 .join(", ");
1226
1227 let required: Vec<&Field> = scalar_fields
1229 .iter()
1230 .copied()
1231 .filter(|f| !f.has_default() && !f.is_updated_at)
1232 .collect();
1233 let optional: Vec<&Field> = scalar_fields
1234 .iter()
1235 .copied()
1236 .filter(|f| f.has_default() && !f.is_updated_at)
1237 .collect();
1238 let updated_at: Vec<&Field> = scalar_fields
1239 .iter()
1240 .copied()
1241 .filter(|f| f.is_updated_at)
1242 .collect();
1243
1244 let mut col_pushes = vec![];
1245 let mut val_pushes = vec![];
1246
1247 for f in &required {
1248 let db_name = &f.db_name;
1249 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1250 col_pushes.push(quote! { cols.push(#db_name); });
1251 val_pushes.push(quote! { sep.push_bind(self.create.#field_ident); });
1252 }
1253 for f in &optional {
1254 let db_name = &f.db_name;
1255 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1256 let default_expr = gen_default_expr(f, &f.field_type);
1257 col_pushes.push(quote! { cols.push(#db_name); });
1258 val_pushes.push(quote! {
1259 let val = self.create.#field_ident.unwrap_or_else(|| #default_expr);
1260 sep.push_bind(val);
1261 });
1262 }
1263 for f in &updated_at {
1264 let db_name = &f.db_name;
1265 col_pushes.push(quote! { cols.push(#db_name); });
1266 val_pushes.push(quote! { sep.push_bind(chrono::Utc::now()); });
1267 }
1268
1269 let updatable: Vec<&Field> = scalar_fields
1271 .iter()
1272 .copied()
1273 .filter(|f| !f.is_id && !f.is_updated_at)
1274 .collect();
1275
1276 let set_arms: Vec<TokenStream> = updatable
1277 .iter()
1278 .map(|f| {
1279 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1280 let db_name = &f.db_name;
1281 quote! {
1282 if let Some(SetValue::Set(v)) = self.update.#field_ident {
1283 if !first_set { qb.push(", "); }
1284 first_set = false;
1285 qb.push(concat!("\"", #db_name, "\" = "));
1286 qb.push_bind(v);
1287 }
1288 }
1289 })
1290 .collect();
1291
1292 let updated_at_set: Vec<TokenStream> = updated_at
1293 .iter()
1294 .map(|f| {
1295 let db_name = &f.db_name;
1296 quote! {
1297 if !first_set { qb.push(", "); }
1298 first_set = false;
1299 qb.push(concat!("\"", #db_name, "\" = "));
1300 qb.push_bind(chrono::Utc::now());
1301 }
1302 })
1303 .collect();
1304
1305 let insert_start = format!(r#"INSERT INTO "{}""#, table_name);
1306 let conflict_clause = format!(" ON CONFLICT ({}) DO UPDATE SET ", pk_conflict_cols);
1307 let noop_set = format!(
1308 r#""{}" = "{}""#,
1309 pk_db_names.first().unwrap_or(&"id".to_string()),
1310 pk_db_names.first().unwrap_or(&"id".to_string()),
1311 );
1312
1313 quote! {
1314 macro_rules! build_upsert {
1315 ($qb_type:ty) => {{
1316 let mut cols: Vec<&str> = Vec::new();
1317 #(#col_pushes)*
1318
1319 let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#insert_start);
1320 qb.push(" (");
1321 for (i, col) in cols.iter().enumerate() {
1322 if i > 0 { qb.push(", "); }
1323 qb.push("\"");
1324 qb.push(*col);
1325 qb.push("\"");
1326 }
1327 qb.push(") VALUES (");
1328 {
1329 let mut sep = qb.separated(", ");
1330 #(#val_pushes)*
1331 }
1332 qb.push(")");
1333 qb.push(#conflict_clause);
1334
1335 let mut first_set = true;
1336 #(#set_arms)*
1337 #(#updated_at_set)*
1338
1339 if first_set {
1340 qb.push(#noop_set);
1342 }
1343
1344 qb.push(" RETURNING *");
1345 qb
1346 }};
1347 }
1348
1349 match client {
1350 DatabaseClient::Postgres(_) => {
1351 let qb = build_upsert!(sqlx::Postgres);
1352 client.fetch_one_pg(qb).await
1353 }
1354 DatabaseClient::Sqlite(_) => {
1355 let qb = build_upsert!(sqlx::Sqlite);
1356 client.fetch_one_sqlite(qb).await
1357 }
1358 }
1359 }
1360}
1361
1362fn gen_aggregate_types(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
1363 let aggregate_field_name = format_ident!("{}AggregateField", model.name);
1364 let aggregate_result_name = format_ident!("{}AggregateResult", model.name);
1365 let _where_input = format_ident!("{}WhereInput", model.name);
1366 let table_name = &model.db_name;
1367
1368 let agg_fields: Vec<(&Field, AggregateKind)> = scalar_fields
1370 .iter()
1371 .filter_map(|f| match &f.field_type {
1372 FieldKind::Scalar(ScalarType::Int | ScalarType::BigInt | ScalarType::Float) => {
1373 Some((*f, AggregateKind::Numeric))
1374 }
1375 FieldKind::Scalar(ScalarType::DateTime) => Some((*f, AggregateKind::DateTime)),
1376 _ => None,
1377 })
1378 .collect();
1379
1380 if agg_fields.is_empty() {
1381 return quote! {};
1382 }
1383
1384 let enum_variants: Vec<TokenStream> = agg_fields
1386 .iter()
1387 .map(|(f, _)| {
1388 let variant = format_ident!("{}", to_pascal_case(&f.name));
1389 quote! { #variant }
1390 })
1391 .collect();
1392
1393 let db_name_arms: Vec<TokenStream> = agg_fields
1395 .iter()
1396 .map(|(f, _)| {
1397 let variant = format_ident!("{}", to_pascal_case(&f.name));
1398 let db_name = &f.db_name;
1399 quote! { Self::#variant => #db_name }
1400 })
1401 .collect();
1402
1403 let mut result_fields = Vec::new();
1405 for (f, kind) in &agg_fields {
1406 let snake = to_snake_case(&f.name);
1407 let orig_ty = rust_type_tokens(
1408 &Field {
1409 is_optional: false,
1410 ..(*f).clone()
1411 },
1412 ModuleDepth::TopLevel,
1413 );
1414
1415 match kind {
1416 AggregateKind::Numeric => {
1417 let avg_name = format_ident!("avg_{}", snake);
1418 let sum_name = format_ident!("sum_{}", snake);
1419 let min_name = format_ident!("min_{}", snake);
1420 let max_name = format_ident!("max_{}", snake);
1421 result_fields.push(quote! { #[sqlx(default)] pub #avg_name: Option<f64> });
1422 result_fields.push(quote! { #[sqlx(default)] pub #sum_name: Option<f64> });
1423 result_fields.push(quote! { #[sqlx(default)] pub #min_name: Option<#orig_ty> });
1424 result_fields.push(quote! { #[sqlx(default)] pub #max_name: Option<#orig_ty> });
1425 }
1426 AggregateKind::DateTime => {
1427 let min_name = format_ident!("min_{}", snake);
1428 let max_name = format_ident!("max_{}", snake);
1429 result_fields.push(quote! { #[sqlx(default)] pub #min_name: Option<#orig_ty> });
1430 result_fields.push(quote! { #[sqlx(default)] pub #max_name: Option<#orig_ty> });
1431 }
1432 }
1433 }
1434
1435 let numeric_arms: Vec<TokenStream> = agg_fields
1437 .iter()
1438 .filter(|(_, kind)| matches!(kind, AggregateKind::Numeric))
1439 .map(|(f, _)| {
1440 let variant = format_ident!("{}", to_pascal_case(&f.name));
1441 quote! { Self::#variant => true }
1442 })
1443 .collect();
1444
1445 let has_numeric = !numeric_arms.is_empty();
1446 let is_numeric_method = if has_numeric {
1447 quote! {
1448 fn is_numeric(&self) -> bool {
1449 match self {
1450 #(#numeric_arms,)*
1451 #[allow(unreachable_patterns)]
1452 _ => false,
1453 }
1454 }
1455 }
1456 } else {
1457 quote! {
1458 fn is_numeric(&self) -> bool { false }
1459 }
1460 };
1461
1462 let mut alias_arms = Vec::new();
1464 for (f, kind) in &agg_fields {
1465 let variant = format_ident!("{}", to_pascal_case(&f.name));
1466 let snake = to_snake_case(&f.name);
1467 let prefixes = match kind {
1468 AggregateKind::Numeric => vec!["avg", "sum", "min", "max"],
1469 AggregateKind::DateTime => vec!["min", "max"],
1470 };
1471 for prefix in prefixes {
1472 let alias_str = format!("{}_{}", prefix, snake);
1473 alias_arms.push(quote! { (#prefix, Self::#variant) => #alias_str });
1474 }
1475 }
1476
1477 let agg_select_base = format!(r#"SELECT {{}} FROM "{}" WHERE 1=1"#, table_name);
1478
1479 quote! {
1480 #[derive(Debug, Clone, Copy)]
1481 pub enum #aggregate_field_name {
1482 #(#enum_variants),*
1483 }
1484
1485 impl #aggregate_field_name {
1486 pub fn db_name(&self) -> &'static str {
1487 match self {
1488 #(#db_name_arms,)*
1489 }
1490 }
1491
1492 fn alias(&self, prefix: &'static str) -> &'static str {
1493 match (prefix, self) {
1494 #(#alias_arms,)*
1495 _ => unreachable!(),
1496 }
1497 }
1498
1499 #is_numeric_method
1500 }
1501
1502 #[derive(Debug, Clone, sqlx::FromRow, Serialize, Deserialize)]
1503 pub struct #aggregate_result_name {
1504 #(#result_fields,)*
1505 }
1506
1507 pub struct AggregateQuery<'a> {
1508 client: &'a DatabaseClient,
1509 r#where: filter::#_where_input,
1510 ops: Vec<(&'static str, &'static str, &'static str)>,
1511 }
1512
1513 impl<'a> AggregateQuery<'a> {
1514 pub fn avg(mut self, field: #aggregate_field_name) -> Self {
1515 assert!(field.is_numeric(), "avg() is only supported on numeric fields");
1516 let db_name = field.db_name();
1517 let alias = field.alias("avg");
1518 self.ops.push(("AVG", db_name, alias));
1519 self
1520 }
1521
1522 pub fn sum(mut self, field: #aggregate_field_name) -> Self {
1523 assert!(field.is_numeric(), "sum() is only supported on numeric fields");
1524 let db_name = field.db_name();
1525 let alias = field.alias("sum");
1526 self.ops.push(("SUM", db_name, alias));
1527 self
1528 }
1529
1530 pub fn min(mut self, field: #aggregate_field_name) -> Self {
1531 let db_name = field.db_name();
1532 let alias = field.alias("min");
1533 self.ops.push(("MIN", db_name, alias));
1534 self
1535 }
1536
1537 pub fn max(mut self, field: #aggregate_field_name) -> Self {
1538 let db_name = field.db_name();
1539 let alias = field.alias("max");
1540 self.ops.push(("MAX", db_name, alias));
1541 self
1542 }
1543
1544 pub async fn exec(self) -> Result<#aggregate_result_name, FerriormError> {
1545 if self.ops.is_empty() {
1546 return Err(FerriormError::Query("No aggregate operations specified".into()));
1547 }
1548
1549 let selections: Vec<String> = self.ops.iter()
1550 .map(|(func, col, alias)| format!(r#"{}("{}") as "{}""#, func, col, alias))
1551 .collect();
1552 let select_clause = selections.join(", ");
1553 let base_sql = format!(#agg_select_base, select_clause);
1554
1555 match self.client {
1556 DatabaseClient::Postgres(_) => {
1557 let mut qb = sqlx::QueryBuilder::<sqlx::Postgres>::new(&base_sql);
1558 self.r#where.build_where(&mut qb);
1559 self.client.fetch_one_pg(qb).await
1560 }
1561 DatabaseClient::Sqlite(_) => {
1562 let mut qb = sqlx::QueryBuilder::<sqlx::Sqlite>::new(&base_sql);
1563 self.r#where.build_where(&mut qb);
1564 self.client.fetch_one_sqlite(qb).await
1565 }
1566 }
1567 }
1568 }
1569 }
1570}
1571
1572fn gen_select_types(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
1575 let select_name = format_ident!("{}Select", model.name);
1576 let partial_name = format_ident!("{}Partial", model.name);
1577 let _where_input = format_ident!("{}WhereInput", model.name);
1578 let _where_unique = format_ident!("{}WhereUniqueInput", model.name);
1579 let order_by_name = format_ident!("{}OrderByInput", model.name);
1580 let table_name = &model.db_name;
1581
1582 let select_fields: Vec<TokenStream> = scalar_fields
1584 .iter()
1585 .map(|f| {
1586 let name = format_ident!("{}", to_snake_case(&f.name));
1587 quote! { pub #name: bool }
1588 })
1589 .collect();
1590
1591 let partial_fields: Vec<TokenStream> = scalar_fields
1594 .iter()
1595 .map(|f| {
1596 let name = format_ident!("{}", to_snake_case(&f.name));
1597 let db_name = &f.db_name;
1598 let base_ty = rust_type_tokens(
1600 &Field {
1601 is_optional: false,
1602 ..(*f).clone()
1603 },
1604 ModuleDepth::TopLevel,
1605 );
1606 let rename = if db_name != &to_snake_case(&f.name) {
1607 quote! { #[sqlx(rename = #db_name)] }
1608 } else {
1609 quote! {}
1610 };
1611 quote! { #[sqlx(default)] #rename pub #name: Option<#base_ty> }
1613 })
1614 .collect();
1615
1616 let select_col_arms: Vec<TokenStream> = scalar_fields
1618 .iter()
1619 .map(|f| {
1620 let name = format_ident!("{}", to_snake_case(&f.name));
1621 let db_name = &f.db_name;
1622 let col_expr = format!(r#""{}""#, db_name);
1623 quote! {
1624 if select.#name { cols.push(#col_expr); }
1625 }
1626 })
1627 .collect();
1628
1629 let select_sql_prefix = format!(r#"SELECT {{}} FROM "{}" WHERE 1=1"#, table_name);
1630
1631 quote! {
1632 #[derive(Debug, Clone, Default)]
1633 pub struct #select_name {
1634 #(#select_fields,)*
1635 }
1636
1637 #[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
1638 #[sqlx(rename_all = "snake_case")]
1639 pub struct #partial_name {
1640 #(#partial_fields,)*
1641 }
1642
1643 fn build_select_columns(select: &#select_name) -> String {
1644 let mut cols = Vec::new();
1645 #(#select_col_arms)*
1646 if cols.is_empty() {
1647 "*".to_string()
1648 } else {
1649 cols.join(", ")
1650 }
1651 }
1652
1653 pub struct FindManySelectQuery<'a> {
1656 client: &'a DatabaseClient,
1657 r#where: filter::#_where_input,
1658 order_by: Vec<order::#order_by_name>,
1659 skip: Option<i64>,
1660 take: Option<i64>,
1661 select: #select_name,
1662 }
1663
1664 impl<'a> FindManySelectQuery<'a> {
1665 pub fn order_by(mut self, order: order::#order_by_name) -> Self {
1666 self.order_by.push(order);
1667 self
1668 }
1669
1670 pub fn skip(mut self, n: i64) -> Self {
1671 self.skip = Some(n);
1672 self
1673 }
1674
1675 pub fn take(mut self, n: i64) -> Self {
1676 self.take = Some(n);
1677 self
1678 }
1679
1680 pub async fn exec(self) -> Result<Vec<#partial_name>, FerriormError> {
1681 let cols = build_select_columns(&self.select);
1682 let base_sql = format!(#select_sql_prefix, cols);
1683
1684 match self.client {
1685 DatabaseClient::Postgres(_) => {
1686 let qb = build_select_query::<sqlx::Postgres>(
1687 &base_sql, &self.r#where, &self.order_by, self.take, self.skip,
1688 );
1689 self.client.fetch_all_pg(qb).await
1690 }
1691 DatabaseClient::Sqlite(_) => {
1692 let qb = build_select_query::<sqlx::Sqlite>(
1693 &base_sql, &self.r#where, &self.order_by, self.take, self.skip,
1694 );
1695 self.client.fetch_all_sqlite(qb).await
1696 }
1697 }
1698 }
1699 }
1700
1701 pub struct FindUniqueSelectQuery<'a> {
1704 client: &'a DatabaseClient,
1705 r#where: filter::#_where_unique,
1706 select: #select_name,
1707 }
1708
1709 impl<'a> FindUniqueSelectQuery<'a> {
1710 pub async fn exec(self) -> Result<Option<#partial_name>, FerriormError> {
1711 let cols = build_select_columns(&self.select);
1712 let base_sql = format!(#select_sql_prefix, cols);
1713
1714 match self.client {
1715 DatabaseClient::Postgres(_) => {
1716 let qb = build_unique_select_query::<sqlx::Postgres>(
1717 &base_sql, &self.r#where,
1718 );
1719 self.client.fetch_optional_pg(qb).await
1720 }
1721 DatabaseClient::Sqlite(_) => {
1722 let qb = build_unique_select_query::<sqlx::Sqlite>(
1723 &base_sql, &self.r#where,
1724 );
1725 self.client.fetch_optional_sqlite(qb).await
1726 }
1727 }
1728 }
1729 }
1730
1731 pub struct FindFirstSelectQuery<'a> {
1734 client: &'a DatabaseClient,
1735 r#where: filter::#_where_input,
1736 order_by: Vec<order::#order_by_name>,
1737 select: #select_name,
1738 }
1739
1740 impl<'a> FindFirstSelectQuery<'a> {
1741 pub fn order_by(mut self, order: order::#order_by_name) -> Self {
1742 self.order_by.push(order);
1743 self
1744 }
1745
1746 pub async fn exec(self) -> Result<Option<#partial_name>, FerriormError> {
1747 let cols = build_select_columns(&self.select);
1748 let base_sql = format!(#select_sql_prefix, cols);
1749
1750 match self.client {
1751 DatabaseClient::Postgres(_) => {
1752 let qb = build_select_query::<sqlx::Postgres>(
1753 &base_sql, &self.r#where, &self.order_by, Some(1), None,
1754 );
1755 self.client.fetch_optional_pg(qb).await
1756 }
1757 DatabaseClient::Sqlite(_) => {
1758 let qb = build_select_query::<sqlx::Sqlite>(
1759 &base_sql, &self.r#where, &self.order_by, Some(1), None,
1760 );
1761 self.client.fetch_optional_sqlite(qb).await
1762 }
1763 }
1764 }
1765 }
1766 }
1767}