1use ferriorm_core::schema::{Field, FieldKind, Model};
15use ferriorm_core::types::ScalarType;
16use ferriorm_core::utils::{to_pascal_case, to_snake_case};
17use proc_macro2::TokenStream;
18use quote::{format_ident, quote};
19
20use crate::rust_type::{ModuleDepth, filter_type_tokens, rust_type_tokens};
21
22#[must_use]
24pub fn generate_model_module(model: &Model) -> TokenStream {
25 let scalar_fields: Vec<&Field> = model.fields.iter().filter(|f| f.is_scalar()).collect();
26
27 let data_struct = gen_data_struct(model, &scalar_fields);
28 let filter_module = gen_filter_module(model, &scalar_fields);
29 let data_module = gen_data_module(model, &scalar_fields);
30 let order_module = gen_order_module(model, &scalar_fields);
31 let actions_struct = gen_actions(model, &scalar_fields);
32 let query_builders = gen_query_builders(model, &scalar_fields);
33 let aggregate_types = gen_aggregate_types(model, &scalar_fields);
34 let select_types = gen_select_types(model, &scalar_fields);
35
36 quote! {
37 #![allow(unused_imports, dead_code, unused_variables, clippy::all, clippy::pedantic, clippy::nursery)]
38
39 use serde::{Deserialize, Serialize};
40 use ferriorm_runtime::prelude::*;
41 use ferriorm_runtime::prelude::sqlx;
42 use ferriorm_runtime::prelude::chrono;
43 use ferriorm_runtime::prelude::uuid;
44
45 #data_struct
46 #filter_module
47 #data_module
48 #order_module
49 #actions_struct
50 #query_builders
51 #aggregate_types
52 #select_types
53 }
54}
55
56fn gen_data_struct(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
59 let struct_name = format_ident!("{}", model.name);
60 let table_name = &model.db_name;
61
62 let fields: Vec<TokenStream> = scalar_fields
63 .iter()
64 .map(|f| {
65 let name = format_ident!("{}", to_snake_case(&f.name));
66 let ty = rust_type_tokens(f, ModuleDepth::TopLevel);
67 let db_name = &f.db_name;
68 if db_name == &to_snake_case(&f.name) {
69 quote! { pub #name: #ty }
70 } else {
71 quote! { #[sqlx(rename = #db_name)] pub #name: #ty }
72 }
73 })
74 .collect();
75
76 quote! {
77 #[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
78 #[sqlx(rename_all = "snake_case")]
79 pub struct #struct_name {
80 #(#fields),*
81 }
82
83 impl #struct_name {
84 pub const TABLE_NAME: &'static str = #table_name;
85 }
86 }
87}
88
89#[allow(clippy::too_many_lines)]
92fn gen_filter_module(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
93 let where_input = format_ident!("{}WhereInput", model.name);
94 let where_unique = format_ident!("{}WhereUniqueInput", model.name);
95
96 let where_fields: Vec<TokenStream> = scalar_fields
97 .iter()
98 .filter_map(|f| {
99 let filter_ty = filter_type_tokens(f, ModuleDepth::Nested)?;
100 let name = format_ident!("{}", to_snake_case(&f.name));
101 Some(quote! { pub #name: Option<#filter_ty> })
102 })
103 .collect();
104
105 let single_unique_variants: Vec<TokenStream> = scalar_fields
106 .iter()
107 .filter(|f| f.is_id || f.is_unique)
108 .map(|f| {
109 let variant = format_ident!("{}", to_pascal_case(&f.name));
110 let ty = rust_type_tokens(f, ModuleDepth::Nested);
111 quote! { #variant(#ty) }
112 })
113 .collect();
114
115 let compound_unique_variants: Vec<TokenStream> = model
116 .unique_constraints
117 .iter()
118 .map(|uc| {
119 let variant = format_ident!("{}", compound_variant_name(&uc.fields));
120 let struct_fields = compound_variant_fields(model, &uc.fields);
121 quote! { #variant { #(#struct_fields),* } }
122 })
123 .collect();
124
125 let unique_variants: Vec<TokenStream> = single_unique_variants
126 .into_iter()
127 .chain(compound_unique_variants)
128 .collect();
129
130 let db_bounds = collect_db_bounds(scalar_fields);
132 let where_arms = gen_where_arms(scalar_fields);
133 let unique_arms = gen_unique_where_arms(model, scalar_fields);
134 let conflict_target_arms = gen_conflict_target_arms(model, scalar_fields);
135 let first_conflict_col_arms = gen_first_conflict_col_arms(model, scalar_fields);
136
137 quote! {
138 pub mod filter {
139 use ferriorm_runtime::prelude::*;
140
141 #[derive(Debug, Clone, Default)]
142 pub struct #where_input {
143 #(#where_fields,)*
144 pub and: Option<Vec<#where_input>>,
145 pub or: Option<Vec<#where_input>>,
146 pub not: Option<Box<#where_input>>,
147 }
148
149 #[derive(Debug, Clone)]
150 pub enum #where_unique {
151 #(#unique_variants),*
152 }
153
154 impl #where_input {
155 pub(crate) fn build_where<'args, DB: sqlx::Database>(
156 &self,
157 qb: &mut sqlx::QueryBuilder<'args, DB>,
158 )
159 where
160 #(#db_bounds,)*
161 {
162 #(#where_arms)*
163
164 if let Some(conditions) = &self.and {
165 for c in conditions {
166 c.build_where(qb);
167 }
168 }
169 if let Some(conditions) = &self.or {
170 if !conditions.is_empty() {
171 qb.push(" AND (");
172 for (i, c) in conditions.iter().enumerate() {
173 if i > 0 { qb.push(" OR "); }
174 qb.push("(1=1");
175 c.build_where(qb);
176 qb.push(")");
177 }
178 qb.push(")");
179 }
180 }
181 if let Some(c) = &self.not {
182 qb.push(" AND NOT (1=1");
183 c.build_where(qb);
184 qb.push(")");
185 }
186 }
187 }
188
189 impl #where_unique {
190 pub(crate) fn build_where<'args, DB: sqlx::Database>(
191 &self,
192 qb: &mut sqlx::QueryBuilder<'args, DB>,
193 )
194 where
195 #(#db_bounds,)*
196 {
197 match self {
198 #(#unique_arms)*
199 }
200 }
201 }
202
203 impl #where_unique {
204 #[allow(dead_code)]
205 pub(crate) fn conflict_target(&self) -> &'static str {
206 match self {
207 #(#conflict_target_arms)*
208 }
209 }
210
211 #[allow(dead_code)]
212 pub(crate) fn first_conflict_col(&self) -> &'static str {
213 match self {
214 #(#first_conflict_col_arms)*
215 }
216 }
217 }
218 }
219 }
220}
221
222fn collect_db_bounds(scalar_fields: &[&Field]) -> Vec<TokenStream> {
224 let mut seen = std::collections::HashSet::new();
225 let mut bounds = Vec::new();
226
227 seen.insert("i64");
229 bounds.push(quote! { i64: sqlx::Type<DB> + for<'e> sqlx::Encode<'e, DB> });
230
231 for f in scalar_fields {
232 match &f.field_type {
233 FieldKind::Scalar(scalar) => {
234 let key = scalar.rust_type();
235 if seen.insert(key)
236 && let Some(ty) = scalar_bound_tokens(scalar)
237 {
238 bounds.push(quote! { #ty: sqlx::Type<DB> + for<'e> sqlx::Encode<'e, DB> });
239 bounds.push(
241 quote! { Option<#ty>: sqlx::Type<DB> + for<'e> sqlx::Encode<'e, DB> },
242 );
243 }
244 }
245 FieldKind::Enum(_) | FieldKind::Model(_) => {}
246 }
247 }
248
249 bounds
250}
251
252fn scalar_bound_tokens(scalar: &ScalarType) -> Option<TokenStream> {
253 match scalar {
254 ScalarType::String => Some(quote! { String }),
255 ScalarType::Int => Some(quote! { i32 }),
256 ScalarType::BigInt => Some(quote! { i64 }),
257 ScalarType::Float => Some(quote! { f64 }),
258 ScalarType::Boolean => Some(quote! { bool }),
259 ScalarType::DateTime => Some(quote! { chrono::DateTime<chrono::Utc> }),
260 ScalarType::Bytes => Some(quote! { Vec<u8> }),
261 ScalarType::Json | ScalarType::Decimal => None,
262 }
263}
264
265fn gen_where_arms(scalar_fields: &[&Field]) -> Vec<TokenStream> {
267 scalar_fields
268 .iter()
269 .filter_map(|f| {
270 if !matches!(&f.field_type, FieldKind::Scalar(_)) {
272 return None;
273 }
274 let field_ident = format_ident!("{}", to_snake_case(&f.name));
275 let db_name = &f.db_name;
276 let is_string = matches!(&f.field_type, FieldKind::Scalar(ScalarType::String));
277 let is_comparable = matches!(
278 &f.field_type,
279 FieldKind::Scalar(
280 ScalarType::Int | ScalarType::BigInt | ScalarType::Float | ScalarType::DateTime
281 )
282 );
283
284 let mut arms = vec![];
285
286 if f.is_optional {
287 arms.push(quote! {
291 if let Some(v) = &filter.equals {
292 match v {
293 None => {
294 qb.push(concat!(" AND \"", #db_name, "\" IS NULL"));
295 }
296 Some(inner) => {
297 qb.push(concat!(" AND \"", #db_name, "\" = "));
298 qb.push_bind(inner.clone());
299 }
300 }
301 }
302 if let Some(v) = &filter.not {
303 match v {
304 None => {
305 qb.push(concat!(" AND \"", #db_name, "\" IS NOT NULL"));
306 }
307 Some(inner) => {
308 qb.push(concat!(" AND \"", #db_name, "\" != "));
309 qb.push_bind(inner.clone());
310 }
311 }
312 }
313 });
314 } else {
315 arms.push(quote! {
316 if let Some(v) = &filter.equals {
317 qb.push(concat!(" AND \"", #db_name, "\" = "));
318 qb.push_bind(v.clone());
319 }
320 if let Some(v) = &filter.not {
321 qb.push(concat!(" AND \"", #db_name, "\" != "));
322 qb.push_bind(v.clone());
323 }
324 });
325 }
326
327 if is_string {
328 arms.push(quote! {
329 if let Some(v) = &filter.contains {
330 qb.push(concat!(" AND \"", #db_name, "\" LIKE "));
331 qb.push_bind(format!("%{}%", v));
332 }
333 if let Some(v) = &filter.starts_with {
334 qb.push(concat!(" AND \"", #db_name, "\" LIKE "));
335 qb.push_bind(format!("{}%", v));
336 }
337 if let Some(v) = &filter.ends_with {
338 qb.push(concat!(" AND \"", #db_name, "\" LIKE "));
339 qb.push_bind(format!("%{}", v));
340 }
341 });
342 }
343
344 if is_comparable {
345 arms.push(quote! {
346 if let Some(v) = &filter.gt {
347 qb.push(concat!(" AND \"", #db_name, "\" > "));
348 qb.push_bind(v.clone());
349 }
350 if let Some(v) = &filter.gte {
351 qb.push(concat!(" AND \"", #db_name, "\" >= "));
352 qb.push_bind(v.clone());
353 }
354 if let Some(v) = &filter.lt {
355 qb.push(concat!(" AND \"", #db_name, "\" < "));
356 qb.push_bind(v.clone());
357 }
358 if let Some(v) = &filter.lte {
359 qb.push(concat!(" AND \"", #db_name, "\" <= "));
360 qb.push_bind(v.clone());
361 }
362 });
363 }
364
365 Some(quote! {
366 if let Some(filter) = &self.#field_ident {
367 #(#arms)*
368 }
369 })
370 })
371 .collect()
372}
373
374fn gen_unique_where_arms(model: &Model, scalar_fields: &[&Field]) -> Vec<TokenStream> {
375 let mut arms: Vec<TokenStream> = scalar_fields
376 .iter()
377 .filter(|f| f.is_id || f.is_unique)
378 .map(|f| {
379 let variant = format_ident!("{}", to_pascal_case(&f.name));
380 let db_name = &f.db_name;
381 quote! {
382 Self::#variant(v) => {
383 qb.push(concat!(" AND \"", #db_name, "\" = "));
384 qb.push_bind(v.clone());
385 }
386 }
387 })
388 .collect();
389
390 for uc in &model.unique_constraints {
391 let variant = format_ident!("{}", compound_variant_name(&uc.fields));
392 let idents: Vec<_> = uc
393 .fields
394 .iter()
395 .map(|name| format_ident!("{}", to_snake_case(name)))
396 .collect();
397 let binds: Vec<TokenStream> = uc
398 .fields
399 .iter()
400 .map(|name| {
401 let ident = format_ident!("{}", to_snake_case(name));
402 let db_name = resolve_db_name(model, name);
403 quote! {
404 qb.push(concat!(" AND \"", #db_name, "\" = "));
405 qb.push_bind(#ident.clone());
406 }
407 })
408 .collect();
409 arms.push(quote! {
410 Self::#variant { #(#idents),* } => {
411 #(#binds)*
412 }
413 });
414 }
415
416 arms
417}
418
419fn gen_conflict_target_arms(model: &Model, scalar_fields: &[&Field]) -> Vec<TokenStream> {
420 let mut arms: Vec<TokenStream> = scalar_fields
421 .iter()
422 .filter(|f| f.is_id || f.is_unique)
423 .map(|f| {
424 let variant = format_ident!("{}", to_pascal_case(&f.name));
425 let target = format!("(\"{}\")", f.db_name);
426 quote! { Self::#variant(_) => #target, }
427 })
428 .collect();
429
430 for uc in &model.unique_constraints {
431 let variant = format_ident!("{}", compound_variant_name(&uc.fields));
432 let cols: Vec<String> = uc
433 .fields
434 .iter()
435 .map(|n| format!("\"{}\"", resolve_db_name(model, n)))
436 .collect();
437 let target = format!("({})", cols.join(", "));
438 arms.push(quote! { Self::#variant { .. } => #target, });
439 }
440
441 arms
442}
443
444fn gen_first_conflict_col_arms(model: &Model, scalar_fields: &[&Field]) -> Vec<TokenStream> {
445 let mut arms: Vec<TokenStream> = scalar_fields
446 .iter()
447 .filter(|f| f.is_id || f.is_unique)
448 .map(|f| {
449 let variant = format_ident!("{}", to_pascal_case(&f.name));
450 let col = format!("\"{}\"", f.db_name);
451 quote! { Self::#variant(_) => #col, }
452 })
453 .collect();
454
455 for uc in &model.unique_constraints {
456 let variant = format_ident!("{}", compound_variant_name(&uc.fields));
457 let first = uc
458 .fields
459 .first()
460 .map_or_else(String::new, |n| resolve_db_name(model, n));
461 let col = format!("\"{first}\"");
462 arms.push(quote! { Self::#variant { .. } => #col, });
463 }
464
465 arms
466}
467
468fn compound_variant_name(fields: &[String]) -> String {
470 fields.iter().map(|f| to_pascal_case(f)).collect()
471}
472
473fn compound_variant_fields(model: &Model, fields: &[String]) -> Vec<TokenStream> {
476 fields
477 .iter()
478 .filter_map(|field_name| {
479 let field = model.fields.iter().find(|f| f.name == *field_name)?;
480 let ident = format_ident!("{}", to_snake_case(field_name));
481 let ty = rust_type_tokens(field, ModuleDepth::Nested);
482 Some(quote! { #ident: #ty })
483 })
484 .collect()
485}
486
487fn resolve_db_name(model: &Model, field_name: &str) -> String {
489 model
490 .fields
491 .iter()
492 .find(|f| f.name == field_name)
493 .map_or_else(|| to_snake_case(field_name), |f| f.db_name.clone())
494}
495
496fn gen_data_module(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
499 let create_name = format_ident!("{}CreateInput", model.name);
500 let update_name = format_ident!("{}UpdateInput", model.name);
501
502 let required_fields: Vec<TokenStream> = scalar_fields
503 .iter()
504 .filter(|f| !f.has_default() && !f.is_updated_at)
505 .map(|f| {
506 let name = format_ident!("{}", to_snake_case(&f.name));
507 let ty = rust_type_tokens(f, ModuleDepth::Nested);
508 quote! { pub #name: #ty }
509 })
510 .collect();
511
512 let optional_fields: Vec<TokenStream> = scalar_fields
513 .iter()
514 .filter(|f| f.has_default() && !f.is_updated_at)
515 .map(|f| {
516 let name = format_ident!("{}", to_snake_case(&f.name));
517 let base_ty = rust_type_tokens(f, ModuleDepth::Nested);
518 quote! { pub #name: Option<#base_ty> }
519 })
520 .collect();
521
522 let update_fields: Vec<TokenStream> = scalar_fields
523 .iter()
524 .filter(|f| !f.is_id && !f.is_updated_at)
525 .map(|f| {
526 let name = format_ident!("{}", to_snake_case(&f.name));
527 let ty = rust_type_tokens(f, ModuleDepth::Nested);
528 quote! { pub #name: Option<SetValue<#ty>> }
529 })
530 .collect();
531
532 quote! {
533 pub mod data {
534 use ferriorm_runtime::prelude::*;
535
536 #[derive(Debug, Clone)]
537 pub struct #create_name {
538 #(#required_fields,)*
539 #(#optional_fields,)*
540 }
541
542 #[derive(Debug, Clone, Default)]
546 pub struct #update_name {
547 #(#update_fields,)*
548 }
549 }
550 }
551}
552
553fn gen_order_module(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
556 let order_name = format_ident!("{}OrderByInput", model.name);
557
558 let variants: Vec<TokenStream> = scalar_fields
559 .iter()
560 .map(|f| {
561 let variant = format_ident!("{}", to_pascal_case(&f.name));
562 quote! { #variant(SortOrder) }
563 })
564 .collect();
565
566 let order_arms: Vec<TokenStream> = scalar_fields
567 .iter()
568 .map(|f| {
569 let variant = format_ident!("{}", to_pascal_case(&f.name));
570 let db_name = &f.db_name;
571 quote! {
572 Self::#variant(order) => {
573 qb.push(concat!("\"", #db_name, "\" "));
574 qb.push(order.as_sql());
575 }
576 }
577 })
578 .collect();
579
580 quote! {
581 pub mod order {
582 use ferriorm_runtime::prelude::*;
583
584 #[derive(Debug, Clone)]
585 pub enum #order_name {
586 #(#variants),*
587 }
588
589 impl #order_name {
590 pub(crate) fn build_order_by<'args, DB: sqlx::Database>(
591 &self,
592 qb: &mut sqlx::QueryBuilder<'args, DB>,
593 ) {
594 match self {
595 #(#order_arms)*
596 }
597 }
598 }
599 }
600 }
601}
602
603fn gen_actions(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
606 let _model_ident = format_ident!("{}", model.name);
607 let actions_name = format_ident!("{}Actions", model.name);
608 let where_input = format_ident!("{}WhereInput", model.name);
609 let where_unique = format_ident!("{}WhereUniqueInput", model.name);
610 let create_input = format_ident!("{}CreateInput", model.name);
611 let update_input = format_ident!("{}UpdateInput", model.name);
612 let _order_by = format_ident!("{}OrderByInput", model.name);
613
614 let has_agg_fields = scalar_fields.iter().any(|f| {
616 matches!(
617 &f.field_type,
618 FieldKind::Scalar(
619 ScalarType::Int | ScalarType::BigInt | ScalarType::Float | ScalarType::DateTime
620 )
621 )
622 });
623 let aggregate_method = if has_agg_fields {
624 quote! {
625 pub fn aggregate(&self, r#where: filter::#where_input) -> AggregateQuery<'a> {
626 AggregateQuery { client: self.client, r#where, ops: vec![] }
627 }
628 }
629 } else {
630 quote! {}
631 };
632
633 quote! {
634 pub struct #actions_name<'a> {
635 client: &'a DatabaseClient,
636 }
637
638 impl<'a> #actions_name<'a> {
639 pub fn new(client: &'a DatabaseClient) -> Self { Self { client } }
640
641 pub fn find_unique(&self, r#where: filter::#where_unique) -> FindUniqueQuery<'a> {
642 FindUniqueQuery { client: self.client, r#where }
643 }
644
645 pub fn find_first(&self, r#where: filter::#where_input) -> FindFirstQuery<'a> {
646 FindFirstQuery { client: self.client, r#where, order_by: vec![] }
647 }
648
649 pub fn find_many(&self, r#where: filter::#where_input) -> FindManyQuery<'a> {
650 FindManyQuery { client: self.client, r#where, order_by: vec![], skip: None, take: None }
651 }
652
653 pub fn create(&self, data: data::#create_input) -> CreateQuery<'a> {
654 CreateQuery { client: self.client, data }
655 }
656
657 pub fn update(&self, r#where: filter::#where_unique, data: data::#update_input) -> UpdateQuery<'a> {
658 UpdateQuery { client: self.client, r#where, data }
659 }
660
661 pub fn update_first(&self, r#where: filter::#where_input, data: data::#update_input) -> UpdateFirstQuery<'a> {
665 UpdateFirstQuery { client: self.client, r#where, data }
666 }
667
668 pub fn delete(&self, r#where: filter::#where_unique) -> DeleteQuery<'a> {
669 DeleteQuery { client: self.client, r#where }
670 }
671
672 pub fn count(&self, r#where: filter::#where_input) -> CountQuery<'a> {
673 CountQuery { client: self.client, r#where }
674 }
675
676 pub fn create_many(&self, data: Vec<data::#create_input>) -> CreateManyQuery<'a> {
677 CreateManyQuery { client: self.client, data }
678 }
679
680 pub fn update_many(&self, r#where: filter::#where_input, data: data::#update_input) -> UpdateManyQuery<'a> {
681 UpdateManyQuery { client: self.client, r#where, data }
682 }
683
684 pub fn delete_many(&self, r#where: filter::#where_input) -> DeleteManyQuery<'a> {
685 DeleteManyQuery { client: self.client, r#where }
686 }
687
688 pub fn upsert(
689 &self,
690 r#where: filter::#where_unique,
691 create: data::#create_input,
692 update: data::#update_input,
693 ) -> UpsertQuery<'a> {
694 UpsertQuery { client: self.client, r#where, create, update }
695 }
696
697 #aggregate_method
698 }
699 }
700}
701
702#[allow(clippy::too_many_lines)]
705fn gen_query_builders(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
706 let model_ident = format_ident!("{}", model.name);
707 let table_name = &model.db_name;
708 let _where_input = format_ident!("{}WhereInput", model.name);
709 let _where_unique = format_ident!("{}WhereUniqueInput", model.name);
710 let _create_input = format_ident!("{}CreateInput", model.name);
711 let _update_input = format_ident!("{}UpdateInput", model.name);
712 let order_by = format_ident!("{}OrderByInput", model.name);
713 let _select_struct = format_ident!("{}Select", model.name);
714 let _partial_struct = format_ident!("{}Partial", model.name);
715 let _aggregate_result = format_ident!("{}AggregateResult", model.name);
716 let _aggregate_field = format_ident!("{}AggregateField", model.name);
717 let db_bounds = collect_db_bounds(scalar_fields);
718
719 let select_sql = format!(r#"SELECT * FROM "{table_name}" WHERE 1=1"#);
720 let count_sql = format!(r#"SELECT COUNT(*) as "count" FROM "{table_name}" WHERE 1=1"#);
721 let delete_sql = format!(r#"DELETE FROM "{table_name}" WHERE 1=1"#);
722
723 let insert_code = gen_insert_code(model, scalar_fields, table_name);
724 let insert_ignore_code = gen_insert_ignore_code(model, scalar_fields, table_name);
725 let update_code = gen_update_code(model, scalar_fields, table_name);
726 let update_first_code = gen_update_first_code(model, scalar_fields, table_name);
727 let update_many_code = gen_update_many_code(model, scalar_fields, table_name);
728 let upsert_code = gen_upsert_code(model, scalar_fields, table_name);
729
730 quote! {
731 fn build_order_by<'args, DB: sqlx::Database>(
733 orders: &[order::#order_by],
734 qb: &mut sqlx::QueryBuilder<'args, DB>,
735 ) {
736 if !orders.is_empty() {
737 qb.push(" ORDER BY ");
738 for (i, ob) in orders.iter().enumerate() {
739 if i > 0 { qb.push(", "); }
740 ob.build_order_by(qb);
741 }
742 }
743 }
744
745 fn build_select_query<'args, DB: sqlx::Database>(
747 base_sql: &str,
748 where_input: &filter::#_where_input,
749 orders: &[order::#order_by],
750 take: Option<i64>,
751 skip: Option<i64>,
752 ) -> sqlx::QueryBuilder<'args, DB>
753 where
754 #(#db_bounds,)*
755 {
756 let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
757 where_input.build_where(&mut qb);
758 build_order_by(orders, &mut qb);
759 if let Some(take) = take {
760 qb.push(" LIMIT ");
761 qb.push_bind(take);
762 }
763 if let Some(skip) = skip {
764 qb.push(" OFFSET ");
765 qb.push_bind(skip);
766 }
767 qb
768 }
769
770 fn build_unique_select_query<'args, DB: sqlx::Database>(
772 base_sql: &str,
773 where_unique: &filter::#_where_unique,
774 ) -> sqlx::QueryBuilder<'args, DB>
775 where
776 #(#db_bounds,)*
777 {
778 let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
779 where_unique.build_where(&mut qb);
780 qb.push(" LIMIT 1");
781 qb
782 }
783
784 fn build_delete_query<'args, DB: sqlx::Database>(
786 base_sql: &str,
787 where_unique: &filter::#_where_unique,
788 ) -> sqlx::QueryBuilder<'args, DB>
789 where
790 #(#db_bounds,)*
791 {
792 let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
793 where_unique.build_where(&mut qb);
794 qb.push(" RETURNING *");
795 qb
796 }
797
798 fn build_count_query<'args, DB: sqlx::Database>(
800 base_sql: &str,
801 where_input: &filter::#_where_input,
802 ) -> sqlx::QueryBuilder<'args, DB>
803 where
804 #(#db_bounds,)*
805 {
806 let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
807 where_input.build_where(&mut qb);
808 qb
809 }
810
811 fn build_delete_many_query<'args, DB: sqlx::Database>(
813 base_sql: &str,
814 where_input: &filter::#_where_input,
815 ) -> sqlx::QueryBuilder<'args, DB>
816 where
817 #(#db_bounds,)*
818 {
819 let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
820 where_input.build_where(&mut qb);
821 qb
822 }
823
824 pub struct FindUniqueQuery<'a> {
825 client: &'a DatabaseClient,
826 r#where: filter::#_where_unique,
827 }
828
829 impl<'a> FindUniqueQuery<'a> {
830 pub fn select(self, select: #_select_struct) -> FindUniqueSelectQuery<'a> {
831 FindUniqueSelectQuery { client: self.client, r#where: self.r#where, select }
832 }
833
834 pub async fn exec(self) -> Result<Option<#model_ident>, FerriormError> {
835 match self.client {
836 DatabaseClient::Postgres(_) => {
837 let qb = build_unique_select_query::<sqlx::Postgres>(#select_sql, &self.r#where);
838 self.client.fetch_optional_pg(qb).await
839 }
840 DatabaseClient::Sqlite(_) => {
841 let qb = build_unique_select_query::<sqlx::Sqlite>(#select_sql, &self.r#where);
842 self.client.fetch_optional_sqlite(qb).await
843 }
844 }
845 }
846 }
847
848 pub struct FindFirstQuery<'a> {
849 client: &'a DatabaseClient,
850 r#where: filter::#_where_input,
851 order_by: Vec<order::#order_by>,
852 }
853
854 impl<'a> FindFirstQuery<'a> {
855 pub fn order_by(mut self, order: order::#order_by) -> Self {
856 self.order_by.push(order);
857 self
858 }
859
860 pub fn select(self, select: #_select_struct) -> FindFirstSelectQuery<'a> {
861 FindFirstSelectQuery { client: self.client, r#where: self.r#where, order_by: self.order_by, select }
862 }
863
864 pub async fn exec(self) -> Result<Option<#model_ident>, FerriormError> {
865 match self.client {
866 DatabaseClient::Postgres(_) => {
867 let qb = build_select_query::<sqlx::Postgres>(#select_sql, &self.r#where, &self.order_by, Some(1), None);
868 self.client.fetch_optional_pg(qb).await
869 }
870 DatabaseClient::Sqlite(_) => {
871 let qb = build_select_query::<sqlx::Sqlite>(#select_sql, &self.r#where, &self.order_by, Some(1), None);
872 self.client.fetch_optional_sqlite(qb).await
873 }
874 }
875 }
876 }
877
878 pub struct FindManyQuery<'a> {
879 client: &'a DatabaseClient,
880 r#where: filter::#_where_input,
881 order_by: Vec<order::#order_by>,
882 skip: Option<i64>,
883 take: Option<i64>,
884 }
885
886 impl<'a> FindManyQuery<'a> {
887 pub fn order_by(mut self, order: order::#order_by) -> Self {
888 self.order_by.push(order);
889 self
890 }
891
892 pub fn skip(mut self, n: i64) -> Self {
893 self.skip = Some(n);
894 self
895 }
896
897 pub fn take(mut self, n: i64) -> Self {
898 self.take = Some(n);
899 self
900 }
901
902 pub fn select(self, select: #_select_struct) -> FindManySelectQuery<'a> {
903 FindManySelectQuery {
904 client: self.client,
905 r#where: self.r#where,
906 order_by: self.order_by,
907 skip: self.skip,
908 take: self.take,
909 select,
910 }
911 }
912
913 pub async fn exec(self) -> Result<Vec<#model_ident>, FerriormError> {
914 match self.client {
915 DatabaseClient::Postgres(_) => {
916 let qb = build_select_query::<sqlx::Postgres>(#select_sql, &self.r#where, &self.order_by, self.take, self.skip);
917 self.client.fetch_all_pg(qb).await
918 }
919 DatabaseClient::Sqlite(_) => {
920 let qb = build_select_query::<sqlx::Sqlite>(#select_sql, &self.r#where, &self.order_by, self.take, self.skip);
921 self.client.fetch_all_sqlite(qb).await
922 }
923 }
924 }
925 }
926
927 pub struct CreateQuery<'a> {
928 client: &'a DatabaseClient,
929 data: data::#_create_input,
930 }
931
932 impl<'a> CreateQuery<'a> {
933 pub async fn exec(self) -> Result<#model_ident, FerriormError> {
934 let client = self.client;
935 #insert_code
936 }
937
938 pub fn on_conflict_ignore(self) -> CreateIgnoreQuery<'a> {
942 CreateIgnoreQuery { client: self.client, data: self.data }
943 }
944 }
945
946 pub struct CreateIgnoreQuery<'a> {
947 client: &'a DatabaseClient,
948 data: data::#_create_input,
949 }
950
951 impl<'a> CreateIgnoreQuery<'a> {
952 pub async fn exec(self) -> Result<Option<#model_ident>, FerriormError> {
953 let client = self.client;
954 #insert_ignore_code
955 }
956 }
957
958 pub struct UpdateQuery<'a> {
959 client: &'a DatabaseClient,
960 r#where: filter::#_where_unique,
961 data: data::#_update_input,
962 }
963
964 impl<'a> UpdateQuery<'a> {
965 pub async fn exec(self) -> Result<#model_ident, FerriormError> {
966 let client = self.client;
967 #update_code
968 }
969 }
970
971 pub struct UpdateFirstQuery<'a> {
972 client: &'a DatabaseClient,
973 r#where: filter::#_where_input,
974 data: data::#_update_input,
975 }
976
977 impl<'a> UpdateFirstQuery<'a> {
978 pub async fn exec(self) -> Result<Option<#model_ident>, FerriormError> {
979 let client = self.client;
980 #update_first_code
981 }
982 }
983
984 pub struct DeleteQuery<'a> {
985 client: &'a DatabaseClient,
986 r#where: filter::#_where_unique,
987 }
988
989 impl<'a> DeleteQuery<'a> {
990 pub async fn exec(self) -> Result<#model_ident, FerriormError> {
991 match self.client {
992 DatabaseClient::Postgres(_) => {
993 let qb = build_delete_query::<sqlx::Postgres>(#delete_sql, &self.r#where);
994 self.client.fetch_one_pg(qb).await
995 }
996 DatabaseClient::Sqlite(_) => {
997 let qb = build_delete_query::<sqlx::Sqlite>(#delete_sql, &self.r#where);
998 self.client.fetch_one_sqlite(qb).await
999 }
1000 }
1001 }
1002 }
1003
1004 #[derive(sqlx::FromRow)]
1005 struct CountResult { count: i64 }
1006
1007 pub struct CountQuery<'a> {
1008 client: &'a DatabaseClient,
1009 r#where: filter::#_where_input,
1010 }
1011
1012 impl<'a> CountQuery<'a> {
1013 pub async fn exec(self) -> Result<i64, FerriormError> {
1014 let row: CountResult = match self.client {
1015 DatabaseClient::Postgres(_) => {
1016 let qb = build_count_query::<sqlx::Postgres>(#count_sql, &self.r#where);
1017 self.client.fetch_one_pg(qb).await?
1018 }
1019 DatabaseClient::Sqlite(_) => {
1020 let qb = build_count_query::<sqlx::Sqlite>(#count_sql, &self.r#where);
1021 self.client.fetch_one_sqlite(qb).await?
1022 }
1023 };
1024 Ok(row.count)
1025 }
1026 }
1027
1028 pub struct CreateManyQuery<'a> {
1029 client: &'a DatabaseClient,
1030 data: Vec<data::#_create_input>,
1031 }
1032
1033 impl<'a> CreateManyQuery<'a> {
1034 pub async fn exec(self) -> Result<u64, FerriormError> {
1035 if self.data.is_empty() { return Ok(0); }
1036 let count = self.data.len() as u64;
1037 for item in self.data {
1038 CreateQuery { client: self.client, data: item }.exec().await?;
1039 }
1040 Ok(count)
1041 }
1042 }
1043
1044 pub struct UpdateManyQuery<'a> {
1045 client: &'a DatabaseClient,
1046 r#where: filter::#_where_input,
1047 data: data::#_update_input,
1048 }
1049
1050 impl<'a> UpdateManyQuery<'a> {
1051 pub async fn exec(self) -> Result<u64, FerriormError> {
1052 let client = self.client;
1053 #update_many_code
1054 }
1055 }
1056
1057 pub struct UpsertQuery<'a> {
1058 client: &'a DatabaseClient,
1059 r#where: filter::#_where_unique,
1060 create: data::#_create_input,
1061 update: data::#_update_input,
1062 }
1063
1064 impl<'a> UpsertQuery<'a> {
1065 pub async fn exec(self) -> Result<#model_ident, FerriormError> {
1066 let client = self.client;
1067 #upsert_code
1068 }
1069 }
1070
1071 pub struct DeleteManyQuery<'a> {
1072 client: &'a DatabaseClient,
1073 r#where: filter::#_where_input,
1074 }
1075
1076 impl<'a> DeleteManyQuery<'a> {
1077 pub async fn exec(self) -> Result<u64, FerriormError> {
1078 match self.client {
1079 DatabaseClient::Postgres(_) => {
1080 let qb = build_delete_many_query::<sqlx::Postgres>(#delete_sql, &self.r#where);
1081 self.client.execute_pg(qb).await
1082 }
1083 DatabaseClient::Sqlite(_) => {
1084 let qb = build_delete_many_query::<sqlx::Sqlite>(#delete_sql, &self.r#where);
1085 self.client.execute_sqlite(qb).await
1086 }
1087 }
1088 }
1089 }
1090 }
1091}
1092
1093fn gen_insert_code(model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
1096 let _model_ident = format_ident!("{}", model.name);
1097
1098 let required: Vec<&Field> = scalar_fields
1100 .iter()
1101 .copied()
1102 .filter(|f| !f.has_default() && !f.is_updated_at)
1103 .collect();
1104
1105 let optional: Vec<&Field> = scalar_fields
1107 .iter()
1108 .copied()
1109 .filter(|f| f.has_default() && !f.is_updated_at)
1110 .collect();
1111
1112 let updated_at: Vec<&Field> = scalar_fields
1114 .iter()
1115 .copied()
1116 .filter(|f| f.is_updated_at)
1117 .collect();
1118
1119 let mut col_pushes = vec![];
1121 let mut val_pushes = vec![];
1122
1123 for f in &required {
1125 let db_name = &f.db_name;
1126 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1127 col_pushes.push(quote! { cols.push(#db_name); });
1128 val_pushes.push(quote! { sep.push_bind(self.data.#field_ident); });
1129 }
1130
1131 for f in &optional {
1133 let db_name = &f.db_name;
1134 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1135 let default_expr = gen_default_expr(f, &f.field_type);
1136
1137 col_pushes.push(quote! { cols.push(#db_name); });
1138 val_pushes.push(quote! {
1139 let val = self.data.#field_ident.unwrap_or_else(|| #default_expr);
1140 sep.push_bind(val);
1141 });
1142 }
1143
1144 for f in &updated_at {
1146 let db_name = &f.db_name;
1147 col_pushes.push(quote! { cols.push(#db_name); });
1148 val_pushes.push(quote! { sep.push_bind(chrono::Utc::now()); });
1149 }
1150
1151 let insert_start = format!(r#"INSERT INTO "{table_name}""#);
1152
1153 quote! {
1156 macro_rules! build_insert {
1158 ($qb_type:ty) => {{
1159 let mut cols: Vec<&str> = Vec::new();
1160 #(#col_pushes)*
1161
1162 let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#insert_start);
1163 qb.push(" (");
1164 for (i, col) in cols.iter().enumerate() {
1165 if i > 0 { qb.push(", "); }
1166 qb.push("\"");
1167 qb.push(*col);
1168 qb.push("\"");
1169 }
1170 qb.push(") VALUES (");
1171 {
1172 let mut sep = qb.separated(", ");
1173 #(#val_pushes)*
1174 }
1175 qb.push(") RETURNING *");
1176 qb
1177 }};
1178 }
1179
1180 match client {
1181 DatabaseClient::Postgres(_) => {
1182 let qb = build_insert!(sqlx::Postgres);
1183 client.fetch_one_pg(qb).await
1184 }
1185 DatabaseClient::Sqlite(_) => {
1186 let qb = build_insert!(sqlx::Sqlite);
1187 client.fetch_one_sqlite(qb).await
1188 }
1189 }
1190 }
1191}
1192
1193fn gen_insert_ignore_code(
1194 _model: &Model,
1195 scalar_fields: &[&Field],
1196 table_name: &str,
1197) -> TokenStream {
1198 let required: Vec<&Field> = scalar_fields
1199 .iter()
1200 .copied()
1201 .filter(|f| !f.has_default() && !f.is_updated_at)
1202 .collect();
1203 let optional: Vec<&Field> = scalar_fields
1204 .iter()
1205 .copied()
1206 .filter(|f| f.has_default() && !f.is_updated_at)
1207 .collect();
1208 let updated_at: Vec<&Field> = scalar_fields
1209 .iter()
1210 .copied()
1211 .filter(|f| f.is_updated_at)
1212 .collect();
1213
1214 let mut col_pushes = vec![];
1215 let mut val_pushes = vec![];
1216
1217 for f in &required {
1218 let db_name = &f.db_name;
1219 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1220 col_pushes.push(quote! { cols.push(#db_name); });
1221 val_pushes.push(quote! { sep.push_bind(self.data.#field_ident); });
1222 }
1223 for f in &optional {
1224 let db_name = &f.db_name;
1225 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1226 let default_expr = gen_default_expr(f, &f.field_type);
1227 col_pushes.push(quote! { cols.push(#db_name); });
1228 val_pushes.push(quote! {
1229 let val = self.data.#field_ident.unwrap_or_else(|| #default_expr);
1230 sep.push_bind(val);
1231 });
1232 }
1233 for f in &updated_at {
1234 let db_name = &f.db_name;
1235 col_pushes.push(quote! { cols.push(#db_name); });
1236 val_pushes.push(quote! { sep.push_bind(chrono::Utc::now()); });
1237 }
1238
1239 let pg_insert_start = format!(r#"INSERT INTO "{table_name}""#);
1240 let sqlite_insert_start = format!(r#"INSERT OR IGNORE INTO "{table_name}""#);
1241
1242 quote! {
1243 macro_rules! build_insert_ignore {
1244 ($qb_type:ty, $head:expr, $tail:expr) => {{
1245 let mut cols: Vec<&str> = Vec::new();
1246 #(#col_pushes)*
1247
1248 let mut qb = sqlx::QueryBuilder::<$qb_type>::new($head);
1249 qb.push(" (");
1250 for (i, col) in cols.iter().enumerate() {
1251 if i > 0 { qb.push(", "); }
1252 qb.push("\"");
1253 qb.push(*col);
1254 qb.push("\"");
1255 }
1256 qb.push(") VALUES (");
1257 {
1258 let mut sep = qb.separated(", ");
1259 #(#val_pushes)*
1260 }
1261 qb.push(")");
1262 qb.push($tail);
1263 qb.push(" RETURNING *");
1264 qb
1265 }};
1266 }
1267
1268 match client {
1269 DatabaseClient::Postgres(_) => {
1270 let qb = build_insert_ignore!(sqlx::Postgres, #pg_insert_start, " ON CONFLICT DO NOTHING");
1271 client.fetch_optional_pg(qb).await
1272 }
1273 DatabaseClient::Sqlite(_) => {
1274 let qb = build_insert_ignore!(sqlx::Sqlite, #sqlite_insert_start, "");
1275 client.fetch_optional_sqlite(qb).await
1276 }
1277 }
1278 }
1279}
1280
1281fn gen_default_expr(field: &Field, field_type: &FieldKind) -> TokenStream {
1283 use ferriorm_core::ast::DefaultValue;
1284
1285 match &field.default {
1286 Some(DefaultValue::Uuid | DefaultValue::Cuid) => {
1287 quote! { uuid::Uuid::new_v4().to_string() }
1288 }
1289 Some(DefaultValue::Now) => quote! { chrono::Utc::now() },
1290 Some(DefaultValue::AutoIncrement) => quote! { 0i32 }, Some(DefaultValue::Literal(lit)) => {
1292 use ferriorm_core::ast::LiteralValue;
1293 match lit {
1294 LiteralValue::String(s) => quote! { #s.to_string() },
1295 LiteralValue::Int(i) => {
1296 #[allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
1298 match field_type {
1299 FieldKind::Scalar(ScalarType::Float) => {
1300 let val = *i as f64;
1301 quote! { #val }
1302 }
1303 FieldKind::Scalar(ScalarType::BigInt) => quote! { #i },
1304 FieldKind::Scalar(ScalarType::Int)
1306 if field.db_type.as_ref().is_some_and(|(ty, _)| ty == "BigInt") =>
1307 {
1308 quote! { #i }
1309 }
1310 _ => {
1311 let val = *i as i32;
1313 quote! { #val }
1314 }
1315 }
1316 }
1317 LiteralValue::Float(f) => quote! { #f },
1318 LiteralValue::Bool(b) => quote! { #b },
1319 }
1320 }
1321 Some(DefaultValue::EnumVariant(v)) => {
1322 let variant = format_ident!("{}", v);
1324 if let FieldKind::Enum(enum_name) = &field.field_type {
1325 let enum_ident = format_ident!("{}", enum_name);
1326 quote! { super::enums::#enum_ident::#variant }
1327 } else {
1328 quote! { Default::default() }
1329 }
1330 }
1331 None => quote! { Default::default() },
1332 }
1333}
1334
1335fn gen_update_code(model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
1338 let _model_ident = format_ident!("{}", model.name);
1339
1340 let updatable: Vec<&Field> = scalar_fields
1342 .iter()
1343 .copied()
1344 .filter(|f| !f.is_id && !f.is_updated_at)
1345 .collect();
1346
1347 let updated_at: Vec<&Field> = scalar_fields
1348 .iter()
1349 .copied()
1350 .filter(|f| f.is_updated_at)
1351 .collect();
1352
1353 let update_start = format!(r#"UPDATE "{table_name}" SET "#);
1354
1355 let set_arms: Vec<TokenStream> = updatable
1357 .iter()
1358 .map(|f| {
1359 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1360 let db_name = &f.db_name;
1361 quote! {
1362 if let Some(SetValue::Set(v)) = self.data.#field_ident {
1363 if !first_set { qb.push(", "); }
1364 first_set = false;
1365 qb.push(concat!("\"", #db_name, "\" = "));
1366 qb.push_bind(v);
1367 }
1368 }
1369 })
1370 .collect();
1371
1372 let updated_at_arms: Vec<TokenStream> = updated_at
1373 .iter()
1374 .map(|f| {
1375 let db_name = &f.db_name;
1376 quote! {
1377 if !first_set { qb.push(", "); }
1378 first_set = false;
1379 qb.push(concat!("\"", #db_name, "\" = "));
1380 qb.push_bind(chrono::Utc::now());
1381 }
1382 })
1383 .collect();
1384
1385 quote! {
1388 macro_rules! build_update {
1389 ($qb_type:ty) => {{
1390 let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#update_start);
1391 let mut first_set = true;
1392 #(#set_arms)*
1393 #(#updated_at_arms)*
1394
1395 if first_set {
1396 return Err(FerriormError::Query("No fields to update".into()));
1397 }
1398
1399 qb.push(" WHERE 1=1");
1400 self.r#where.build_where(&mut qb);
1401 qb.push(" RETURNING *");
1402 qb
1403 }};
1404 }
1405
1406 match client {
1407 DatabaseClient::Postgres(_) => {
1408 let qb = build_update!(sqlx::Postgres);
1409 client.fetch_one_pg(qb).await
1410 }
1411 DatabaseClient::Sqlite(_) => {
1412 let qb = build_update!(sqlx::Sqlite);
1413 client.fetch_one_sqlite(qb).await
1414 }
1415 }
1416 }
1417}
1418
1419fn gen_update_first_code(
1422 _model: &Model,
1423 scalar_fields: &[&Field],
1424 table_name: &str,
1425) -> TokenStream {
1426 let updatable: Vec<&Field> = scalar_fields
1427 .iter()
1428 .copied()
1429 .filter(|f| !f.is_id && !f.is_updated_at)
1430 .collect();
1431
1432 let updated_at: Vec<&Field> = scalar_fields
1433 .iter()
1434 .copied()
1435 .filter(|f| f.is_updated_at)
1436 .collect();
1437
1438 let update_start = format!(r#"UPDATE "{table_name}" SET "#);
1439
1440 let set_arms: Vec<TokenStream> = updatable
1441 .iter()
1442 .map(|f| {
1443 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1444 let db_name = &f.db_name;
1445 quote! {
1446 if let Some(SetValue::Set(v)) = self.data.#field_ident {
1447 if !first_set { qb.push(", "); }
1448 first_set = false;
1449 qb.push(concat!("\"", #db_name, "\" = "));
1450 qb.push_bind(v);
1451 }
1452 }
1453 })
1454 .collect();
1455
1456 let updated_at_arms: Vec<TokenStream> = updated_at
1457 .iter()
1458 .map(|f| {
1459 let db_name = &f.db_name;
1460 quote! {
1461 if !first_set { qb.push(", "); }
1462 first_set = false;
1463 qb.push(concat!("\"", #db_name, "\" = "));
1464 qb.push_bind(chrono::Utc::now());
1465 }
1466 })
1467 .collect();
1468
1469 quote! {
1470 macro_rules! build_update_first {
1471 ($qb_type:ty) => {{
1472 let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#update_start);
1473 let mut first_set = true;
1474 #(#set_arms)*
1475 #(#updated_at_arms)*
1476
1477 if first_set {
1478 return Err(FerriormError::Query("No fields to update".into()));
1479 }
1480
1481 qb.push(" WHERE 1=1");
1482 self.r#where.build_where(&mut qb);
1483 qb.push(" RETURNING *");
1484 qb
1485 }};
1486 }
1487
1488 match client {
1489 DatabaseClient::Postgres(_) => {
1490 let qb = build_update_first!(sqlx::Postgres);
1491 client.fetch_optional_pg(qb).await
1492 }
1493 DatabaseClient::Sqlite(_) => {
1494 let qb = build_update_first!(sqlx::Sqlite);
1495 client.fetch_optional_sqlite(qb).await
1496 }
1497 }
1498 }
1499}
1500
1501fn gen_update_many_code(_model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
1504 let updatable: Vec<&Field> = scalar_fields
1506 .iter()
1507 .copied()
1508 .filter(|f| !f.is_id && !f.is_updated_at)
1509 .collect();
1510
1511 let updated_at: Vec<&Field> = scalar_fields
1512 .iter()
1513 .copied()
1514 .filter(|f| f.is_updated_at)
1515 .collect();
1516
1517 let update_start = format!(r#"UPDATE "{table_name}" SET "#);
1518
1519 let set_arms: Vec<TokenStream> = updatable
1521 .iter()
1522 .map(|f| {
1523 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1524 let db_name = &f.db_name;
1525 quote! {
1526 if let Some(SetValue::Set(v)) = self.data.#field_ident {
1527 if !first_set { qb.push(", "); }
1528 first_set = false;
1529 qb.push(concat!("\"", #db_name, "\" = "));
1530 qb.push_bind(v);
1531 }
1532 }
1533 })
1534 .collect();
1535
1536 let updated_at_arms: Vec<TokenStream> = updated_at
1537 .iter()
1538 .map(|f| {
1539 let db_name = &f.db_name;
1540 quote! {
1541 if !first_set { qb.push(", "); }
1542 first_set = false;
1543 qb.push(concat!("\"", #db_name, "\" = "));
1544 qb.push_bind(chrono::Utc::now());
1545 }
1546 })
1547 .collect();
1548
1549 quote! {
1550 macro_rules! build_update_many {
1551 ($qb_type:ty) => {{
1552 let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#update_start);
1553 let mut first_set = true;
1554 #(#set_arms)*
1555 #(#updated_at_arms)*
1556
1557 if first_set {
1558 return Ok(0);
1559 }
1560
1561 qb.push(" WHERE 1=1");
1562 self.r#where.build_where(&mut qb);
1563 qb
1564 }};
1565 }
1566
1567 match client {
1568 DatabaseClient::Postgres(_) => {
1569 let qb = build_update_many!(sqlx::Postgres);
1570 client.execute_pg(qb).await
1571 }
1572 DatabaseClient::Sqlite(_) => {
1573 let qb = build_update_many!(sqlx::Sqlite);
1574 client.execute_sqlite(qb).await
1575 }
1576 }
1577 }
1578}
1579
/// Which aggregate functions `gen_aggregate_types` exposes for a field.
enum AggregateKind {
    /// `Int`, `BigInt` and `Float` scalars: `avg`, `sum`, `min` and `max`.
    Numeric,
    /// `DateTime` scalars: only `min` and `max`.
    DateTime,
}
1589
1590#[allow(clippy::too_many_lines)]
1593fn gen_upsert_code(_model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
1594 let required: Vec<&Field> = scalar_fields
1596 .iter()
1597 .copied()
1598 .filter(|f| !f.has_default() && !f.is_updated_at)
1599 .collect();
1600 let optional: Vec<&Field> = scalar_fields
1601 .iter()
1602 .copied()
1603 .filter(|f| f.has_default() && !f.is_updated_at)
1604 .collect();
1605 let updated_at: Vec<&Field> = scalar_fields
1606 .iter()
1607 .copied()
1608 .filter(|f| f.is_updated_at)
1609 .collect();
1610
1611 let mut col_pushes = vec![];
1612 let mut val_pushes = vec![];
1613
1614 for f in &required {
1615 let db_name = &f.db_name;
1616 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1617 col_pushes.push(quote! { cols.push(#db_name); });
1618 val_pushes.push(quote! { sep.push_bind(self.create.#field_ident); });
1619 }
1620 for f in &optional {
1621 let db_name = &f.db_name;
1622 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1623 let default_expr = gen_default_expr(f, &f.field_type);
1624 col_pushes.push(quote! { cols.push(#db_name); });
1625 val_pushes.push(quote! {
1626 let val = self.create.#field_ident.unwrap_or_else(|| #default_expr);
1627 sep.push_bind(val);
1628 });
1629 }
1630 for f in &updated_at {
1631 let db_name = &f.db_name;
1632 col_pushes.push(quote! { cols.push(#db_name); });
1633 val_pushes.push(quote! { sep.push_bind(chrono::Utc::now()); });
1634 }
1635
1636 let updatable: Vec<&Field> = scalar_fields
1638 .iter()
1639 .copied()
1640 .filter(|f| !f.is_id && !f.is_updated_at)
1641 .collect();
1642
1643 let set_arms: Vec<TokenStream> = updatable
1644 .iter()
1645 .map(|f| {
1646 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1647 let db_name = &f.db_name;
1648 quote! {
1649 if let Some(SetValue::Set(v)) = self.update.#field_ident {
1650 if !first_set { qb.push(", "); }
1651 first_set = false;
1652 qb.push(concat!("\"", #db_name, "\" = "));
1653 qb.push_bind(v);
1654 }
1655 }
1656 })
1657 .collect();
1658
1659 let updated_at_set: Vec<TokenStream> = updated_at
1660 .iter()
1661 .map(|f| {
1662 let db_name = &f.db_name;
1663 quote! {
1664 if !first_set { qb.push(", "); }
1665 first_set = false;
1666 qb.push(concat!("\"", #db_name, "\" = "));
1667 qb.push_bind(chrono::Utc::now());
1668 }
1669 })
1670 .collect();
1671
1672 let insert_start = format!(r#"INSERT INTO "{table_name}""#);
1673
1674 quote! {
1675 let conflict_target = self.r#where.conflict_target();
1676 let first_conflict_col = self.r#where.first_conflict_col();
1677
1678 macro_rules! build_upsert {
1679 ($qb_type:ty) => {{
1680 let mut cols: Vec<&str> = Vec::new();
1681 #(#col_pushes)*
1682
1683 let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#insert_start);
1684 qb.push(" (");
1685 for (i, col) in cols.iter().enumerate() {
1686 if i > 0 { qb.push(", "); }
1687 qb.push("\"");
1688 qb.push(*col);
1689 qb.push("\"");
1690 }
1691 qb.push(") VALUES (");
1692 {
1693 let mut sep = qb.separated(", ");
1694 #(#val_pushes)*
1695 }
1696 qb.push(")");
1697 qb.push(" ON CONFLICT ");
1698 qb.push(conflict_target);
1699 qb.push(" DO UPDATE SET ");
1700
1701 let mut first_set = true;
1702 #(#set_arms)*
1703 #(#updated_at_set)*
1704
1705 if first_set {
1706 qb.push(first_conflict_col);
1709 qb.push(" = ");
1710 qb.push(first_conflict_col);
1711 }
1712
1713 qb.push(" RETURNING *");
1714 qb
1715 }};
1716 }
1717
1718 match client {
1719 DatabaseClient::Postgres(_) => {
1720 let qb = build_upsert!(sqlx::Postgres);
1721 client.fetch_one_pg(qb).await
1722 }
1723 DatabaseClient::Sqlite(_) => {
1724 let qb = build_upsert!(sqlx::Sqlite);
1725 client.fetch_one_sqlite(qb).await
1726 }
1727 }
1728 }
1729}
1730
1731#[allow(clippy::too_many_lines)]
1732fn gen_aggregate_types(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
1733 let aggregate_field_name = format_ident!("{}AggregateField", model.name);
1734 let aggregate_result_name = format_ident!("{}AggregateResult", model.name);
1735 let _where_input = format_ident!("{}WhereInput", model.name);
1736 let table_name = &model.db_name;
1737
1738 let agg_fields: Vec<(&Field, AggregateKind)> = scalar_fields
1740 .iter()
1741 .filter_map(|f| match &f.field_type {
1742 FieldKind::Scalar(ScalarType::Int | ScalarType::BigInt | ScalarType::Float) => {
1743 Some((*f, AggregateKind::Numeric))
1744 }
1745 FieldKind::Scalar(ScalarType::DateTime) => Some((*f, AggregateKind::DateTime)),
1746 _ => None,
1747 })
1748 .collect();
1749
1750 if agg_fields.is_empty() {
1751 return quote! {};
1752 }
1753
1754 let enum_variants: Vec<TokenStream> = agg_fields
1756 .iter()
1757 .map(|(f, _)| {
1758 let variant = format_ident!("{}", to_pascal_case(&f.name));
1759 quote! { #variant }
1760 })
1761 .collect();
1762
1763 let db_name_arms: Vec<TokenStream> = agg_fields
1765 .iter()
1766 .map(|(f, _)| {
1767 let variant = format_ident!("{}", to_pascal_case(&f.name));
1768 let db_name = &f.db_name;
1769 quote! { Self::#variant => #db_name }
1770 })
1771 .collect();
1772
1773 let mut result_fields = Vec::new();
1775 for (f, kind) in &agg_fields {
1776 let snake = to_snake_case(&f.name);
1777 let orig_ty = rust_type_tokens(
1778 &Field {
1779 is_optional: false,
1780 ..(*f).clone()
1781 },
1782 ModuleDepth::TopLevel,
1783 );
1784
1785 match kind {
1786 AggregateKind::Numeric => {
1787 let avg_name = format_ident!("avg_{}", snake);
1788 let sum_name = format_ident!("sum_{}", snake);
1789 let min_name = format_ident!("min_{}", snake);
1790 let max_name = format_ident!("max_{}", snake);
1791 result_fields.push(quote! { #[sqlx(default)] pub #avg_name: Option<f64> });
1792 result_fields.push(quote! { #[sqlx(default)] pub #sum_name: Option<f64> });
1793 result_fields.push(quote! { #[sqlx(default)] pub #min_name: Option<#orig_ty> });
1794 result_fields.push(quote! { #[sqlx(default)] pub #max_name: Option<#orig_ty> });
1795 }
1796 AggregateKind::DateTime => {
1797 let min_name = format_ident!("min_{}", snake);
1798 let max_name = format_ident!("max_{}", snake);
1799 result_fields.push(quote! { #[sqlx(default)] pub #min_name: Option<#orig_ty> });
1800 result_fields.push(quote! { #[sqlx(default)] pub #max_name: Option<#orig_ty> });
1801 }
1802 }
1803 }
1804
1805 let numeric_arms: Vec<TokenStream> = agg_fields
1807 .iter()
1808 .filter(|(_, kind)| matches!(kind, AggregateKind::Numeric))
1809 .map(|(f, _)| {
1810 let variant = format_ident!("{}", to_pascal_case(&f.name));
1811 quote! { Self::#variant => true }
1812 })
1813 .collect();
1814
1815 let has_numeric = !numeric_arms.is_empty();
1816 let is_numeric_method = if has_numeric {
1817 quote! {
1818 fn is_numeric(&self) -> bool {
1819 match self {
1820 #(#numeric_arms,)*
1821 #[allow(unreachable_patterns)]
1822 _ => false,
1823 }
1824 }
1825 }
1826 } else {
1827 quote! {
1828 fn is_numeric(&self) -> bool { false }
1829 }
1830 };
1831
1832 let mut alias_arms = Vec::new();
1834 for (f, kind) in &agg_fields {
1835 let variant = format_ident!("{}", to_pascal_case(&f.name));
1836 let snake = to_snake_case(&f.name);
1837 let prefixes = match kind {
1838 AggregateKind::Numeric => vec!["avg", "sum", "min", "max"],
1839 AggregateKind::DateTime => vec!["min", "max"],
1840 };
1841 for prefix in prefixes {
1842 let alias_str = format!("{prefix}_{snake}");
1843 alias_arms.push(quote! { (#prefix, Self::#variant) => #alias_str });
1844 }
1845 }
1846
1847 let agg_select_base = format!(r#"SELECT {{}} FROM "{table_name}" WHERE 1=1"#);
1848
1849 quote! {
1850 #[derive(Debug, Clone, Copy)]
1851 pub enum #aggregate_field_name {
1852 #(#enum_variants),*
1853 }
1854
1855 impl #aggregate_field_name {
1856 pub fn db_name(&self) -> &'static str {
1857 match self {
1858 #(#db_name_arms,)*
1859 }
1860 }
1861
1862 fn alias(&self, prefix: &'static str) -> &'static str {
1863 match (prefix, self) {
1864 #(#alias_arms,)*
1865 _ => unreachable!(),
1866 }
1867 }
1868
1869 #is_numeric_method
1870 }
1871
1872 #[derive(Debug, Clone, sqlx::FromRow, Serialize, Deserialize)]
1873 pub struct #aggregate_result_name {
1874 #(#result_fields,)*
1875 }
1876
1877 pub struct AggregateQuery<'a> {
1878 client: &'a DatabaseClient,
1879 r#where: filter::#_where_input,
1880 ops: Vec<(&'static str, &'static str, &'static str)>,
1881 }
1882
1883 impl<'a> AggregateQuery<'a> {
1884 pub fn avg(mut self, field: #aggregate_field_name) -> Self {
1885 assert!(field.is_numeric(), "avg() is only supported on numeric fields");
1886 let db_name = field.db_name();
1887 let alias = field.alias("avg");
1888 self.ops.push(("AVG", db_name, alias));
1889 self
1890 }
1891
1892 pub fn sum(mut self, field: #aggregate_field_name) -> Self {
1893 assert!(field.is_numeric(), "sum() is only supported on numeric fields");
1894 let db_name = field.db_name();
1895 let alias = field.alias("sum");
1896 self.ops.push(("SUM", db_name, alias));
1897 self
1898 }
1899
1900 pub fn min(mut self, field: #aggregate_field_name) -> Self {
1901 let db_name = field.db_name();
1902 let alias = field.alias("min");
1903 self.ops.push(("MIN", db_name, alias));
1904 self
1905 }
1906
1907 pub fn max(mut self, field: #aggregate_field_name) -> Self {
1908 let db_name = field.db_name();
1909 let alias = field.alias("max");
1910 self.ops.push(("MAX", db_name, alias));
1911 self
1912 }
1913
1914 pub async fn exec(self) -> Result<#aggregate_result_name, FerriormError> {
1915 if self.ops.is_empty() {
1916 return Err(FerriormError::Query("No aggregate operations specified".into()));
1917 }
1918
1919 let selections: Vec<String> = self.ops.iter()
1920 .map(|(func, col, alias)| format!(r#"{}("{}") as "{}""#, func, col, alias))
1921 .collect();
1922 let select_clause = selections.join(", ");
1923 let base_sql = format!(#agg_select_base, select_clause);
1924
1925 match self.client {
1926 DatabaseClient::Postgres(_) => {
1927 let mut qb = sqlx::QueryBuilder::<sqlx::Postgres>::new(&base_sql);
1928 self.r#where.build_where(&mut qb);
1929 self.client.fetch_one_pg(qb).await
1930 }
1931 DatabaseClient::Sqlite(_) => {
1932 let mut qb = sqlx::QueryBuilder::<sqlx::Sqlite>::new(&base_sql);
1933 self.r#where.build_where(&mut qb);
1934 self.client.fetch_one_sqlite(qb).await
1935 }
1936 }
1937 }
1938 }
1939 }
1940}
1941
1942#[allow(clippy::too_many_lines)]
1945fn gen_select_types(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
1946 let select_name = format_ident!("{}Select", model.name);
1947 let partial_name = format_ident!("{}Partial", model.name);
1948 let _where_input = format_ident!("{}WhereInput", model.name);
1949 let _where_unique = format_ident!("{}WhereUniqueInput", model.name);
1950 let order_by_name = format_ident!("{}OrderByInput", model.name);
1951 let table_name = &model.db_name;
1952
1953 let select_fields: Vec<TokenStream> = scalar_fields
1955 .iter()
1956 .map(|f| {
1957 let name = format_ident!("{}", to_snake_case(&f.name));
1958 quote! { pub #name: bool }
1959 })
1960 .collect();
1961
1962 let partial_fields: Vec<TokenStream> = scalar_fields
1965 .iter()
1966 .map(|f| {
1967 let name = format_ident!("{}", to_snake_case(&f.name));
1968 let db_name = &f.db_name;
1969 let base_ty = rust_type_tokens(
1971 &Field {
1972 is_optional: false,
1973 ..(*f).clone()
1974 },
1975 ModuleDepth::TopLevel,
1976 );
1977 let rename = if db_name == &to_snake_case(&f.name) {
1978 quote! {}
1979 } else {
1980 quote! { #[sqlx(rename = #db_name)] }
1981 };
1982 quote! { #[sqlx(default)] #rename pub #name: Option<#base_ty> }
1984 })
1985 .collect();
1986
1987 let select_col_arms: Vec<TokenStream> = scalar_fields
1989 .iter()
1990 .map(|f| {
1991 let name = format_ident!("{}", to_snake_case(&f.name));
1992 let db_name = &f.db_name;
1993 let col_expr = format!(r#""{db_name}""#);
1994 quote! {
1995 if select.#name { cols.push(#col_expr); }
1996 }
1997 })
1998 .collect();
1999
2000 let select_sql_prefix = format!(r#"SELECT {{}} FROM "{table_name}" WHERE 1=1"#);
2001
2002 quote! {
2003 #[derive(Debug, Clone, Default)]
2004 pub struct #select_name {
2005 #(#select_fields,)*
2006 }
2007
2008 #[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
2009 #[sqlx(rename_all = "snake_case")]
2010 pub struct #partial_name {
2011 #(#partial_fields,)*
2012 }
2013
2014 fn build_select_columns(select: &#select_name) -> String {
2015 let mut cols = Vec::new();
2016 #(#select_col_arms)*
2017 if cols.is_empty() {
2018 "*".to_string()
2019 } else {
2020 cols.join(", ")
2021 }
2022 }
2023
2024 pub struct FindManySelectQuery<'a> {
2027 client: &'a DatabaseClient,
2028 r#where: filter::#_where_input,
2029 order_by: Vec<order::#order_by_name>,
2030 skip: Option<i64>,
2031 take: Option<i64>,
2032 select: #select_name,
2033 }
2034
2035 impl<'a> FindManySelectQuery<'a> {
2036 pub fn order_by(mut self, order: order::#order_by_name) -> Self {
2037 self.order_by.push(order);
2038 self
2039 }
2040
2041 pub fn skip(mut self, n: i64) -> Self {
2042 self.skip = Some(n);
2043 self
2044 }
2045
2046 pub fn take(mut self, n: i64) -> Self {
2047 self.take = Some(n);
2048 self
2049 }
2050
2051 pub async fn exec(self) -> Result<Vec<#partial_name>, FerriormError> {
2052 let cols = build_select_columns(&self.select);
2053 let base_sql = format!(#select_sql_prefix, cols);
2054
2055 match self.client {
2056 DatabaseClient::Postgres(_) => {
2057 let qb = build_select_query::<sqlx::Postgres>(
2058 &base_sql, &self.r#where, &self.order_by, self.take, self.skip,
2059 );
2060 self.client.fetch_all_pg(qb).await
2061 }
2062 DatabaseClient::Sqlite(_) => {
2063 let qb = build_select_query::<sqlx::Sqlite>(
2064 &base_sql, &self.r#where, &self.order_by, self.take, self.skip,
2065 );
2066 self.client.fetch_all_sqlite(qb).await
2067 }
2068 }
2069 }
2070 }
2071
2072 pub struct FindUniqueSelectQuery<'a> {
2075 client: &'a DatabaseClient,
2076 r#where: filter::#_where_unique,
2077 select: #select_name,
2078 }
2079
2080 impl<'a> FindUniqueSelectQuery<'a> {
2081 pub async fn exec(self) -> Result<Option<#partial_name>, FerriormError> {
2082 let cols = build_select_columns(&self.select);
2083 let base_sql = format!(#select_sql_prefix, cols);
2084
2085 match self.client {
2086 DatabaseClient::Postgres(_) => {
2087 let qb = build_unique_select_query::<sqlx::Postgres>(
2088 &base_sql, &self.r#where,
2089 );
2090 self.client.fetch_optional_pg(qb).await
2091 }
2092 DatabaseClient::Sqlite(_) => {
2093 let qb = build_unique_select_query::<sqlx::Sqlite>(
2094 &base_sql, &self.r#where,
2095 );
2096 self.client.fetch_optional_sqlite(qb).await
2097 }
2098 }
2099 }
2100 }
2101
2102 pub struct FindFirstSelectQuery<'a> {
2105 client: &'a DatabaseClient,
2106 r#where: filter::#_where_input,
2107 order_by: Vec<order::#order_by_name>,
2108 select: #select_name,
2109 }
2110
2111 impl<'a> FindFirstSelectQuery<'a> {
2112 pub fn order_by(mut self, order: order::#order_by_name) -> Self {
2113 self.order_by.push(order);
2114 self
2115 }
2116
2117 pub async fn exec(self) -> Result<Option<#partial_name>, FerriormError> {
2118 let cols = build_select_columns(&self.select);
2119 let base_sql = format!(#select_sql_prefix, cols);
2120
2121 match self.client {
2122 DatabaseClient::Postgres(_) => {
2123 let qb = build_select_query::<sqlx::Postgres>(
2124 &base_sql, &self.r#where, &self.order_by, Some(1), None,
2125 );
2126 self.client.fetch_optional_pg(qb).await
2127 }
2128 DatabaseClient::Sqlite(_) => {
2129 let qb = build_select_query::<sqlx::Sqlite>(
2130 &base_sql, &self.r#where, &self.order_by, Some(1), None,
2131 );
2132 self.client.fetch_optional_sqlite(qb).await
2133 }
2134 }
2135 }
2136 }
2137 }
2138}