1use ferriorm_core::schema::{Field, FieldKind, Model};
15use ferriorm_core::types::ScalarType;
16use ferriorm_core::utils::{to_pascal_case, to_snake_case};
17use proc_macro2::TokenStream;
18use quote::{format_ident, quote};
19
20use crate::rust_type::{ModuleDepth, filter_type_tokens, rust_type_tokens};
21
22#[must_use]
24pub fn generate_model_module(model: &Model) -> TokenStream {
25 let scalar_fields: Vec<&Field> = model.fields.iter().filter(|f| f.is_scalar()).collect();
26
27 let data_struct = gen_data_struct(model, &scalar_fields);
28 let filter_module = gen_filter_module(model, &scalar_fields);
29 let data_module = gen_data_module(model, &scalar_fields);
30 let order_module = gen_order_module(model, &scalar_fields);
31 let actions_struct = gen_actions(model, &scalar_fields);
32 let query_builders = gen_query_builders(model, &scalar_fields);
33 let aggregate_types = gen_aggregate_types(model, &scalar_fields);
34 let select_types = gen_select_types(model, &scalar_fields);
35
36 quote! {
37 #![allow(unused_imports, dead_code, unused_variables, clippy::all, clippy::pedantic, clippy::nursery)]
38
39 use serde::{Deserialize, Serialize};
40 use ferriorm_runtime::prelude::*;
41 use ferriorm_runtime::prelude::sqlx;
42 use ferriorm_runtime::prelude::chrono;
43 use ferriorm_runtime::prelude::uuid;
44
45 #data_struct
46 #filter_module
47 #data_module
48 #order_module
49 #actions_struct
50 #query_builders
51 #aggregate_types
52 #select_types
53 }
54}
55
56fn gen_data_struct(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
59 let struct_name = format_ident!("{}", model.name);
60 let table_name = &model.db_name;
61
62 let fields: Vec<TokenStream> = scalar_fields
63 .iter()
64 .map(|f| {
65 let name = format_ident!("{}", to_snake_case(&f.name));
66 let ty = rust_type_tokens(f, ModuleDepth::TopLevel);
67 let db_name = &f.db_name;
68 if db_name == &to_snake_case(&f.name) {
69 quote! { pub #name: #ty }
70 } else {
71 quote! { #[sqlx(rename = #db_name)] pub #name: #ty }
72 }
73 })
74 .collect();
75
76 quote! {
77 #[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
78 #[sqlx(rename_all = "snake_case")]
79 pub struct #struct_name {
80 #(#fields),*
81 }
82
83 impl #struct_name {
84 pub const TABLE_NAME: &'static str = #table_name;
85 }
86 }
87}
88
89#[allow(clippy::too_many_lines)]
92fn gen_filter_module(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
93 let where_input = format_ident!("{}WhereInput", model.name);
94 let where_unique = format_ident!("{}WhereUniqueInput", model.name);
95
96 let where_fields: Vec<TokenStream> = scalar_fields
97 .iter()
98 .filter_map(|f| {
99 let filter_ty = filter_type_tokens(f, ModuleDepth::Nested)?;
100 let name = format_ident!("{}", to_snake_case(&f.name));
101 Some(quote! { pub #name: Option<#filter_ty> })
102 })
103 .collect();
104
105 let single_unique_variants: Vec<TokenStream> = scalar_fields
106 .iter()
107 .filter(|f| f.is_id || f.is_unique)
108 .map(|f| {
109 let variant = format_ident!("{}", to_pascal_case(&f.name));
110 let ty = rust_type_tokens(f, ModuleDepth::Nested);
111 quote! { #variant(#ty) }
112 })
113 .collect();
114
115 let compound_unique_variants: Vec<TokenStream> = model
116 .unique_constraints
117 .iter()
118 .map(|uc| {
119 let variant = format_ident!("{}", compound_variant_name(&uc.fields));
120 let struct_fields = compound_variant_fields(model, &uc.fields);
121 quote! { #variant { #(#struct_fields),* } }
122 })
123 .collect();
124
125 let unique_variants: Vec<TokenStream> = single_unique_variants
126 .into_iter()
127 .chain(compound_unique_variants)
128 .collect();
129
130 let db_bounds = collect_db_bounds(scalar_fields);
132 let where_arms = gen_where_arms(scalar_fields);
133 let unique_arms = gen_unique_where_arms(model, scalar_fields);
134 let conflict_target_arms = gen_conflict_target_arms(model, scalar_fields);
135 let first_conflict_col_arms = gen_first_conflict_col_arms(model, scalar_fields);
136
137 quote! {
138 pub mod filter {
139 use ferriorm_runtime::prelude::*;
140
141 #[derive(Debug, Clone, Default)]
142 pub struct #where_input {
143 #(#where_fields,)*
144 pub and: Option<Vec<#where_input>>,
145 pub or: Option<Vec<#where_input>>,
146 pub not: Option<Box<#where_input>>,
147 }
148
149 #[derive(Debug, Clone)]
150 pub enum #where_unique {
151 #(#unique_variants),*
152 }
153
154 impl #where_input {
155 pub(crate) fn build_where<'args, DB: sqlx::Database>(
156 &self,
157 qb: &mut sqlx::QueryBuilder<'args, DB>,
158 )
159 where
160 #(#db_bounds,)*
161 {
162 #(#where_arms)*
163
164 if let Some(conditions) = &self.and {
165 for c in conditions {
166 c.build_where(qb);
167 }
168 }
169 if let Some(conditions) = &self.or {
170 if !conditions.is_empty() {
171 qb.push(" AND (");
172 for (i, c) in conditions.iter().enumerate() {
173 if i > 0 { qb.push(" OR "); }
174 qb.push("(1=1");
175 c.build_where(qb);
176 qb.push(")");
177 }
178 qb.push(")");
179 }
180 }
181 if let Some(c) = &self.not {
182 qb.push(" AND NOT (1=1");
183 c.build_where(qb);
184 qb.push(")");
185 }
186 }
187 }
188
189 impl #where_unique {
190 pub(crate) fn build_where<'args, DB: sqlx::Database>(
191 &self,
192 qb: &mut sqlx::QueryBuilder<'args, DB>,
193 )
194 where
195 #(#db_bounds,)*
196 {
197 match self {
198 #(#unique_arms)*
199 }
200 }
201 }
202
203 impl #where_unique {
204 #[allow(dead_code)]
205 pub(crate) fn conflict_target(&self) -> &'static str {
206 match self {
207 #(#conflict_target_arms)*
208 }
209 }
210
211 #[allow(dead_code)]
212 pub(crate) fn first_conflict_col(&self) -> &'static str {
213 match self {
214 #(#first_conflict_col_arms)*
215 }
216 }
217 }
218 }
219 }
220}
221
222fn collect_db_bounds(scalar_fields: &[&Field]) -> Vec<TokenStream> {
224 let mut seen = std::collections::HashSet::new();
225 let mut bounds = Vec::new();
226
227 seen.insert("i64");
229 bounds.push(quote! { i64: sqlx::Type<DB> + for<'e> sqlx::Encode<'e, DB> });
230
231 for f in scalar_fields {
232 match &f.field_type {
233 FieldKind::Scalar(scalar) => {
234 let key = scalar.rust_type();
235 if seen.insert(key)
236 && let Some(ty) = scalar_bound_tokens(scalar)
237 {
238 bounds.push(quote! { #ty: sqlx::Type<DB> + for<'e> sqlx::Encode<'e, DB> });
239 bounds.push(
241 quote! { Option<#ty>: sqlx::Type<DB> + for<'e> sqlx::Encode<'e, DB> },
242 );
243 }
244 }
245 FieldKind::Enum(_) | FieldKind::Model(_) => {}
246 }
247 }
248
249 bounds
250}
251
252fn scalar_bound_tokens(scalar: &ScalarType) -> Option<TokenStream> {
253 match scalar {
254 ScalarType::String => Some(quote! { String }),
255 ScalarType::Int => Some(quote! { i32 }),
256 ScalarType::BigInt => Some(quote! { i64 }),
257 ScalarType::Float => Some(quote! { f64 }),
258 ScalarType::Boolean => Some(quote! { bool }),
259 ScalarType::DateTime => Some(quote! { chrono::DateTime<chrono::Utc> }),
260 ScalarType::Bytes => Some(quote! { Vec<u8> }),
261 ScalarType::Json | ScalarType::Decimal => None,
262 }
263}
264
265fn gen_where_arms(scalar_fields: &[&Field]) -> Vec<TokenStream> {
267 scalar_fields
268 .iter()
269 .filter_map(|f| {
270 if !matches!(&f.field_type, FieldKind::Scalar(_)) {
272 return None;
273 }
274 let field_ident = format_ident!("{}", to_snake_case(&f.name));
275 let db_name = &f.db_name;
276 let is_string = matches!(&f.field_type, FieldKind::Scalar(ScalarType::String));
277 let is_comparable = matches!(
278 &f.field_type,
279 FieldKind::Scalar(
280 ScalarType::Int | ScalarType::BigInt | ScalarType::Float | ScalarType::DateTime
281 )
282 );
283
284 let mut arms = vec![];
285
286 if f.is_optional {
287 arms.push(quote! {
291 if let Some(v) = &filter.equals {
292 match v {
293 None => {
294 qb.push(concat!(" AND \"", #db_name, "\" IS NULL"));
295 }
296 Some(inner) => {
297 qb.push(concat!(" AND \"", #db_name, "\" = "));
298 qb.push_bind(inner.clone());
299 }
300 }
301 }
302 if let Some(v) = &filter.not {
303 match v {
304 None => {
305 qb.push(concat!(" AND \"", #db_name, "\" IS NOT NULL"));
306 }
307 Some(inner) => {
308 qb.push(concat!(" AND \"", #db_name, "\" != "));
309 qb.push_bind(inner.clone());
310 }
311 }
312 }
313 });
314 } else {
315 arms.push(quote! {
316 if let Some(v) = &filter.equals {
317 qb.push(concat!(" AND \"", #db_name, "\" = "));
318 qb.push_bind(v.clone());
319 }
320 if let Some(v) = &filter.not {
321 qb.push(concat!(" AND \"", #db_name, "\" != "));
322 qb.push_bind(v.clone());
323 }
324 });
325 }
326
327 if is_string {
328 arms.push(quote! {
329 if let Some(v) = &filter.contains {
330 qb.push(concat!(" AND \"", #db_name, "\" LIKE "));
331 qb.push_bind(format!("%{}%", v));
332 }
333 if let Some(v) = &filter.starts_with {
334 qb.push(concat!(" AND \"", #db_name, "\" LIKE "));
335 qb.push_bind(format!("{}%", v));
336 }
337 if let Some(v) = &filter.ends_with {
338 qb.push(concat!(" AND \"", #db_name, "\" LIKE "));
339 qb.push_bind(format!("%{}", v));
340 }
341 });
342 }
343
344 if is_comparable {
345 arms.push(quote! {
346 if let Some(v) = &filter.gt {
347 qb.push(concat!(" AND \"", #db_name, "\" > "));
348 qb.push_bind(v.clone());
349 }
350 if let Some(v) = &filter.gte {
351 qb.push(concat!(" AND \"", #db_name, "\" >= "));
352 qb.push_bind(v.clone());
353 }
354 if let Some(v) = &filter.lt {
355 qb.push(concat!(" AND \"", #db_name, "\" < "));
356 qb.push_bind(v.clone());
357 }
358 if let Some(v) = &filter.lte {
359 qb.push(concat!(" AND \"", #db_name, "\" <= "));
360 qb.push_bind(v.clone());
361 }
362 });
363 }
364
365 Some(quote! {
366 if let Some(filter) = &self.#field_ident {
367 #(#arms)*
368 }
369 })
370 })
371 .collect()
372}
373
374fn gen_unique_where_arms(model: &Model, scalar_fields: &[&Field]) -> Vec<TokenStream> {
375 let mut arms: Vec<TokenStream> = scalar_fields
376 .iter()
377 .filter(|f| f.is_id || f.is_unique)
378 .map(|f| {
379 let variant = format_ident!("{}", to_pascal_case(&f.name));
380 let db_name = &f.db_name;
381 quote! {
382 Self::#variant(v) => {
383 qb.push(concat!(" AND \"", #db_name, "\" = "));
384 qb.push_bind(v.clone());
385 }
386 }
387 })
388 .collect();
389
390 for uc in &model.unique_constraints {
391 let variant = format_ident!("{}", compound_variant_name(&uc.fields));
392 let idents: Vec<_> = uc
393 .fields
394 .iter()
395 .map(|name| format_ident!("{}", to_snake_case(name)))
396 .collect();
397 let binds: Vec<TokenStream> = uc
398 .fields
399 .iter()
400 .map(|name| {
401 let ident = format_ident!("{}", to_snake_case(name));
402 let db_name = resolve_db_name(model, name);
403 quote! {
404 qb.push(concat!(" AND \"", #db_name, "\" = "));
405 qb.push_bind(#ident.clone());
406 }
407 })
408 .collect();
409 arms.push(quote! {
410 Self::#variant { #(#idents),* } => {
411 #(#binds)*
412 }
413 });
414 }
415
416 arms
417}
418
419fn gen_conflict_target_arms(model: &Model, scalar_fields: &[&Field]) -> Vec<TokenStream> {
420 let mut arms: Vec<TokenStream> = scalar_fields
421 .iter()
422 .filter(|f| f.is_id || f.is_unique)
423 .map(|f| {
424 let variant = format_ident!("{}", to_pascal_case(&f.name));
425 let target = format!("(\"{}\")", f.db_name);
426 quote! { Self::#variant(_) => #target, }
427 })
428 .collect();
429
430 for uc in &model.unique_constraints {
431 let variant = format_ident!("{}", compound_variant_name(&uc.fields));
432 let cols: Vec<String> = uc
433 .fields
434 .iter()
435 .map(|n| format!("\"{}\"", resolve_db_name(model, n)))
436 .collect();
437 let target = format!("({})", cols.join(", "));
438 arms.push(quote! { Self::#variant { .. } => #target, });
439 }
440
441 arms
442}
443
444fn gen_first_conflict_col_arms(model: &Model, scalar_fields: &[&Field]) -> Vec<TokenStream> {
445 let mut arms: Vec<TokenStream> = scalar_fields
446 .iter()
447 .filter(|f| f.is_id || f.is_unique)
448 .map(|f| {
449 let variant = format_ident!("{}", to_pascal_case(&f.name));
450 let col = format!("\"{}\"", f.db_name);
451 quote! { Self::#variant(_) => #col, }
452 })
453 .collect();
454
455 for uc in &model.unique_constraints {
456 let variant = format_ident!("{}", compound_variant_name(&uc.fields));
457 let first = uc
458 .fields
459 .first()
460 .map_or_else(String::new, |n| resolve_db_name(model, n));
461 let col = format!("\"{first}\"");
462 arms.push(quote! { Self::#variant { .. } => #col, });
463 }
464
465 arms
466}
467
468fn compound_variant_name(fields: &[String]) -> String {
470 fields.iter().map(|f| to_pascal_case(f)).collect()
471}
472
473fn compound_variant_fields(model: &Model, fields: &[String]) -> Vec<TokenStream> {
476 fields
477 .iter()
478 .filter_map(|field_name| {
479 let field = model.fields.iter().find(|f| f.name == *field_name)?;
480 let ident = format_ident!("{}", to_snake_case(field_name));
481 let ty = rust_type_tokens(field, ModuleDepth::Nested);
482 Some(quote! { #ident: #ty })
483 })
484 .collect()
485}
486
487fn resolve_db_name(model: &Model, field_name: &str) -> String {
489 model
490 .fields
491 .iter()
492 .find(|f| f.name == field_name)
493 .map_or_else(|| to_snake_case(field_name), |f| f.db_name.clone())
494}
495
496fn gen_data_module(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
499 let create_name = format_ident!("{}CreateInput", model.name);
500 let update_name = format_ident!("{}UpdateInput", model.name);
501
502 let required_fields: Vec<TokenStream> = scalar_fields
503 .iter()
504 .filter(|f| !f.has_default() && !f.is_updated_at)
505 .map(|f| {
506 let name = format_ident!("{}", to_snake_case(&f.name));
507 let ty = rust_type_tokens(f, ModuleDepth::Nested);
508 quote! { pub #name: #ty }
509 })
510 .collect();
511
512 let optional_fields: Vec<TokenStream> = scalar_fields
513 .iter()
514 .filter(|f| f.has_default() && !f.is_updated_at)
515 .map(|f| {
516 let name = format_ident!("{}", to_snake_case(&f.name));
517 let base_ty = rust_type_tokens(f, ModuleDepth::Nested);
518 quote! { pub #name: Option<#base_ty> }
519 })
520 .collect();
521
522 let update_fields: Vec<TokenStream> = scalar_fields
523 .iter()
524 .filter(|f| !f.is_id && !f.is_updated_at)
525 .map(|f| {
526 let name = format_ident!("{}", to_snake_case(&f.name));
527 let ty = rust_type_tokens(f, ModuleDepth::Nested);
528 quote! { pub #name: Option<SetValue<#ty>> }
529 })
530 .collect();
531
532 quote! {
533 pub mod data {
534 use ferriorm_runtime::prelude::*;
535
536 #[derive(Debug, Clone)]
537 pub struct #create_name {
538 #(#required_fields,)*
539 #(#optional_fields,)*
540 }
541
542 #[derive(Debug, Clone, Default)]
546 pub struct #update_name {
547 #(#update_fields,)*
548 }
549 }
550 }
551}
552
553fn gen_order_module(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
556 let order_name = format_ident!("{}OrderByInput", model.name);
557
558 let variants: Vec<TokenStream> = scalar_fields
559 .iter()
560 .map(|f| {
561 let variant = format_ident!("{}", to_pascal_case(&f.name));
562 quote! { #variant(SortOrder) }
563 })
564 .collect();
565
566 let order_arms: Vec<TokenStream> = scalar_fields
567 .iter()
568 .map(|f| {
569 let variant = format_ident!("{}", to_pascal_case(&f.name));
570 let db_name = &f.db_name;
571 quote! {
572 Self::#variant(order) => {
573 qb.push(concat!("\"", #db_name, "\" "));
574 qb.push(order.as_sql());
575 }
576 }
577 })
578 .collect();
579
580 quote! {
581 pub mod order {
582 use ferriorm_runtime::prelude::*;
583
584 #[derive(Debug, Clone)]
585 pub enum #order_name {
586 #(#variants),*
587 }
588
589 impl #order_name {
590 pub(crate) fn build_order_by<'args, DB: sqlx::Database>(
591 &self,
592 qb: &mut sqlx::QueryBuilder<'args, DB>,
593 ) {
594 match self {
595 #(#order_arms)*
596 }
597 }
598 }
599 }
600 }
601}
602
603fn gen_actions(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
606 let _model_ident = format_ident!("{}", model.name);
607 let actions_name = format_ident!("{}Actions", model.name);
608 let where_input = format_ident!("{}WhereInput", model.name);
609 let where_unique = format_ident!("{}WhereUniqueInput", model.name);
610 let create_input = format_ident!("{}CreateInput", model.name);
611 let update_input = format_ident!("{}UpdateInput", model.name);
612 let _order_by = format_ident!("{}OrderByInput", model.name);
613
614 let has_agg_fields = scalar_fields.iter().any(|f| {
616 matches!(
617 &f.field_type,
618 FieldKind::Scalar(
619 ScalarType::Int | ScalarType::BigInt | ScalarType::Float | ScalarType::DateTime
620 )
621 )
622 });
623 let aggregate_method = if has_agg_fields {
624 quote! {
625 pub fn aggregate(&self, r#where: filter::#where_input) -> AggregateQuery<'a> {
626 AggregateQuery { client: self.client, r#where, ops: vec![] }
627 }
628 }
629 } else {
630 quote! {}
631 };
632
633 quote! {
634 pub struct #actions_name<'a> {
635 client: &'a DatabaseClient,
636 }
637
638 impl<'a> #actions_name<'a> {
639 pub fn new(client: &'a DatabaseClient) -> Self { Self { client } }
640
641 pub fn find_unique(&self, r#where: filter::#where_unique) -> FindUniqueQuery<'a> {
642 FindUniqueQuery { client: self.client, r#where }
643 }
644
645 pub fn find_first(&self, r#where: filter::#where_input) -> FindFirstQuery<'a> {
646 FindFirstQuery { client: self.client, r#where, order_by: vec![] }
647 }
648
649 pub fn find_many(&self, r#where: filter::#where_input) -> FindManyQuery<'a> {
650 FindManyQuery { client: self.client, r#where, order_by: vec![], skip: None, take: None }
651 }
652
653 pub fn create(&self, data: data::#create_input) -> CreateQuery<'a> {
654 CreateQuery { client: self.client, data }
655 }
656
657 pub fn update(&self, r#where: filter::#where_unique, data: data::#update_input) -> UpdateQuery<'a> {
658 UpdateQuery { client: self.client, r#where, data }
659 }
660
661 pub fn update_first(&self, r#where: filter::#where_input, data: data::#update_input) -> UpdateFirstQuery<'a> {
665 UpdateFirstQuery { client: self.client, r#where, data }
666 }
667
668 pub fn delete(&self, r#where: filter::#where_unique) -> DeleteQuery<'a> {
669 DeleteQuery { client: self.client, r#where }
670 }
671
672 pub fn count(&self, r#where: filter::#where_input) -> CountQuery<'a> {
673 CountQuery { client: self.client, r#where }
674 }
675
676 pub fn create_many(&self, data: Vec<data::#create_input>) -> CreateManyQuery<'a> {
677 CreateManyQuery { client: self.client, data }
678 }
679
680 pub fn update_many(&self, r#where: filter::#where_input, data: data::#update_input) -> UpdateManyQuery<'a> {
681 UpdateManyQuery { client: self.client, r#where, data }
682 }
683
684 pub fn delete_many(&self, r#where: filter::#where_input) -> DeleteManyQuery<'a> {
685 DeleteManyQuery { client: self.client, r#where }
686 }
687
688 pub fn upsert(
689 &self,
690 r#where: filter::#where_unique,
691 create: data::#create_input,
692 update: data::#update_input,
693 ) -> UpsertQuery<'a> {
694 UpsertQuery { client: self.client, r#where, create, update }
695 }
696
697 #aggregate_method
698 }
699 }
700}
701
/// Emits the free SQL-building helpers and every query-builder struct
/// (`FindUniqueQuery`, `FindFirstQuery`, `FindManyQuery`, `CreateQuery`,
/// `CreateIgnoreQuery`, `UpdateQuery`, `UpdateFirstQuery`, `DeleteQuery`,
/// `CountQuery`, `CreateManyQuery`, `UpdateManyQuery`, `UpsertQuery`,
/// `DeleteManyQuery`) for `model`.
///
/// The base `SELECT` / `COUNT` / `DELETE` statements all end in `WHERE 1=1`
/// so that the generated `build_where` methods can append `" AND ..."`
/// fragments unconditionally. Each `exec` matches on the `DatabaseClient`
/// variant (`Postgres` or `Sqlite`) and builds the query for that backend.
///
/// The bodies of the insert / update / upsert `exec` methods come from the
/// dedicated `gen_*_code` helpers and are spliced in after a
/// `let client = self.client;` binding.
#[allow(clippy::too_many_lines)]
fn gen_query_builders(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
    let model_ident = format_ident!("{}", model.name);
    let table_name = &model.db_name;
    // NOTE(review): the underscore-prefixed identifiers below are
    // interpolated into the template, except `_partial_struct`,
    // `_aggregate_result` and `_aggregate_field`, which are never used in
    // this function.
    let _where_input = format_ident!("{}WhereInput", model.name);
    let _where_unique = format_ident!("{}WhereUniqueInput", model.name);
    let _create_input = format_ident!("{}CreateInput", model.name);
    let _update_input = format_ident!("{}UpdateInput", model.name);
    let order_by = format_ident!("{}OrderByInput", model.name);
    let _select_struct = format_ident!("{}Select", model.name);
    let _partial_struct = format_ident!("{}Partial", model.name);
    let _aggregate_result = format_ident!("{}AggregateResult", model.name);
    let _aggregate_field = format_ident!("{}AggregateField", model.name);
    let db_bounds = collect_db_bounds(scalar_fields);

    // Base statements; the trailing `WHERE 1=1` lets filters append " AND ...".
    let select_sql = format!(r#"SELECT * FROM "{table_name}" WHERE 1=1"#);
    let count_sql = format!(r#"SELECT COUNT(*) as "count" FROM "{table_name}" WHERE 1=1"#);
    let delete_sql = format!(r#"DELETE FROM "{table_name}" WHERE 1=1"#);

    // Bodies of the write-path `exec` methods.
    let insert_code = gen_insert_code(model, scalar_fields, table_name);
    let insert_ignore_code = gen_insert_ignore_code(model, scalar_fields, table_name);
    let update_code = gen_update_code(model, scalar_fields, table_name);
    let update_first_code = gen_update_first_code(model, scalar_fields, table_name);
    let update_many_code = gen_update_many_code(model, scalar_fields, table_name);
    let upsert_code = gen_upsert_code(model, scalar_fields, table_name);

    quote! {
        // Appends " ORDER BY a, b, ..." when at least one ordering is given.
        fn build_order_by<'args, DB: sqlx::Database>(
            orders: &[order::#order_by],
            qb: &mut sqlx::QueryBuilder<'args, DB>,
        ) {
            if !orders.is_empty() {
                qb.push(" ORDER BY ");
                for (i, ob) in orders.iter().enumerate() {
                    if i > 0 { qb.push(", "); }
                    ob.build_order_by(qb);
                }
            }
        }

        // Base SQL + filters + ordering, then optional bound LIMIT / OFFSET.
        fn build_select_query<'args, DB: sqlx::Database>(
            base_sql: &str,
            where_input: &filter::#_where_input,
            orders: &[order::#order_by],
            take: Option<i64>,
            skip: Option<i64>,
        ) -> sqlx::QueryBuilder<'args, DB>
        where
            #(#db_bounds,)*
        {
            let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
            where_input.build_where(&mut qb);
            build_order_by(orders, &mut qb);
            if let Some(take) = take {
                qb.push(" LIMIT ");
                qb.push_bind(take);
            }
            if let Some(skip) = skip {
                qb.push(" OFFSET ");
                qb.push_bind(skip);
            }
            qb
        }

        // Base SQL + unique filter, capped with a literal " LIMIT 1".
        fn build_unique_select_query<'args, DB: sqlx::Database>(
            base_sql: &str,
            where_unique: &filter::#_where_unique,
        ) -> sqlx::QueryBuilder<'args, DB>
        where
            #(#db_bounds,)*
        {
            let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
            where_unique.build_where(&mut qb);
            qb.push(" LIMIT 1");
            qb
        }

        // Base SQL + unique filter + " RETURNING *" so the deleted row comes back.
        fn build_delete_query<'args, DB: sqlx::Database>(
            base_sql: &str,
            where_unique: &filter::#_where_unique,
        ) -> sqlx::QueryBuilder<'args, DB>
        where
            #(#db_bounds,)*
        {
            let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
            where_unique.build_where(&mut qb);
            qb.push(" RETURNING *");
            qb
        }

        // Base SQL + filters only.
        fn build_count_query<'args, DB: sqlx::Database>(
            base_sql: &str,
            where_input: &filter::#_where_input,
        ) -> sqlx::QueryBuilder<'args, DB>
        where
            #(#db_bounds,)*
        {
            let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
            where_input.build_where(&mut qb);
            qb
        }

        // Base SQL + filters only; same shape as `build_count_query`.
        fn build_delete_many_query<'args, DB: sqlx::Database>(
            base_sql: &str,
            where_input: &filter::#_where_input,
        ) -> sqlx::QueryBuilder<'args, DB>
        where
            #(#db_bounds,)*
        {
            let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
            where_input.build_where(&mut qb);
            qb
        }

        // Lookup by a unique key; yields at most one row.
        pub struct FindUniqueQuery<'a> {
            client: &'a DatabaseClient,
            r#where: filter::#_where_unique,
        }

        impl<'a> FindUniqueQuery<'a> {
            pub fn select(self, select: #_select_struct) -> FindUniqueSelectQuery<'a> {
                FindUniqueSelectQuery { client: self.client, r#where: self.r#where, select }
            }

            pub async fn exec(self) -> Result<Option<#model_ident>, FerriormError> {
                match self.client {
                    DatabaseClient::Postgres(_) => {
                        let qb = build_unique_select_query::<sqlx::Postgres>(#select_sql, &self.r#where);
                        self.client.fetch_optional_pg(qb).await
                    }
                    DatabaseClient::Sqlite(_) => {
                        let qb = build_unique_select_query::<sqlx::Sqlite>(#select_sql, &self.r#where);
                        self.client.fetch_optional_sqlite(qb).await
                    }
                }
            }
        }

        // Filtered, ordered lookup of the first matching row.
        pub struct FindFirstQuery<'a> {
            client: &'a DatabaseClient,
            r#where: filter::#_where_input,
            order_by: Vec<order::#order_by>,
        }

        impl<'a> FindFirstQuery<'a> {
            pub fn order_by(mut self, order: order::#order_by) -> Self {
                self.order_by.push(order);
                self
            }

            pub fn select(self, select: #_select_struct) -> FindFirstSelectQuery<'a> {
                FindFirstSelectQuery { client: self.client, r#where: self.r#where, order_by: self.order_by, select }
            }

            pub async fn exec(self) -> Result<Option<#model_ident>, FerriormError> {
                match self.client {
                    DatabaseClient::Postgres(_) => {
                        // `take` is fixed to 1, no offset.
                        let qb = build_select_query::<sqlx::Postgres>(#select_sql, &self.r#where, &self.order_by, Some(1), None);
                        self.client.fetch_optional_pg(qb).await
                    }
                    DatabaseClient::Sqlite(_) => {
                        let qb = build_select_query::<sqlx::Sqlite>(#select_sql, &self.r#where, &self.order_by, Some(1), None);
                        self.client.fetch_optional_sqlite(qb).await
                    }
                }
            }
        }

        // Filtered, ordered, optionally paginated lookup of many rows.
        pub struct FindManyQuery<'a> {
            client: &'a DatabaseClient,
            r#where: filter::#_where_input,
            order_by: Vec<order::#order_by>,
            skip: Option<i64>,
            take: Option<i64>,
        }

        impl<'a> FindManyQuery<'a> {
            pub fn order_by(mut self, order: order::#order_by) -> Self {
                self.order_by.push(order);
                self
            }

            pub fn skip(mut self, n: i64) -> Self {
                self.skip = Some(n);
                self
            }

            pub fn take(mut self, n: i64) -> Self {
                self.take = Some(n);
                self
            }

            pub fn select(self, select: #_select_struct) -> FindManySelectQuery<'a> {
                FindManySelectQuery {
                    client: self.client,
                    r#where: self.r#where,
                    order_by: self.order_by,
                    skip: self.skip,
                    take: self.take,
                    select,
                }
            }

            pub async fn exec(self) -> Result<Vec<#model_ident>, FerriormError> {
                match self.client {
                    DatabaseClient::Postgres(_) => {
                        let qb = build_select_query::<sqlx::Postgres>(#select_sql, &self.r#where, &self.order_by, self.take, self.skip);
                        self.client.fetch_all_pg(qb).await
                    }
                    DatabaseClient::Sqlite(_) => {
                        let qb = build_select_query::<sqlx::Sqlite>(#select_sql, &self.r#where, &self.order_by, self.take, self.skip);
                        self.client.fetch_all_sqlite(qb).await
                    }
                }
            }
        }

        // Single-row insert; the body comes from `gen_insert_code`.
        pub struct CreateQuery<'a> {
            client: &'a DatabaseClient,
            data: data::#_create_input,
        }

        impl<'a> CreateQuery<'a> {
            pub async fn exec(self) -> Result<#model_ident, FerriormError> {
                let client = self.client;
                #insert_code
            }

            // Converts this insert into the variant whose `exec` returns `Option`.
            pub fn on_conflict_ignore(self) -> CreateIgnoreQuery<'a> {
                CreateIgnoreQuery { client: self.client, data: self.data }
            }
        }

        // Insert variant; the body comes from `gen_insert_ignore_code`.
        pub struct CreateIgnoreQuery<'a> {
            client: &'a DatabaseClient,
            data: data::#_create_input,
        }

        impl<'a> CreateIgnoreQuery<'a> {
            pub async fn exec(self) -> Result<Option<#model_ident>, FerriormError> {
                let client = self.client;
                #insert_ignore_code
            }
        }

        // Update addressed by a unique key; the body comes from `gen_update_code`.
        pub struct UpdateQuery<'a> {
            client: &'a DatabaseClient,
            r#where: filter::#_where_unique,
            data: data::#_update_input,
        }

        impl<'a> UpdateQuery<'a> {
            pub async fn exec(self) -> Result<#model_ident, FerriormError> {
                let client = self.client;
                #update_code
            }
        }

        // Update addressed by a general filter; the body comes from
        // `gen_update_first_code`.
        pub struct UpdateFirstQuery<'a> {
            client: &'a DatabaseClient,
            r#where: filter::#_where_input,
            data: data::#_update_input,
        }

        impl<'a> UpdateFirstQuery<'a> {
            pub async fn exec(self) -> Result<Option<#model_ident>, FerriormError> {
                let client = self.client;
                #update_first_code
            }
        }

        // Delete addressed by a unique key; fetches exactly one returned row.
        pub struct DeleteQuery<'a> {
            client: &'a DatabaseClient,
            r#where: filter::#_where_unique,
        }

        impl<'a> DeleteQuery<'a> {
            pub async fn exec(self) -> Result<#model_ident, FerriormError> {
                match self.client {
                    DatabaseClient::Postgres(_) => {
                        let qb = build_delete_query::<sqlx::Postgres>(#delete_sql, &self.r#where);
                        self.client.fetch_one_pg(qb).await
                    }
                    DatabaseClient::Sqlite(_) => {
                        let qb = build_delete_query::<sqlx::Sqlite>(#delete_sql, &self.r#where);
                        self.client.fetch_one_sqlite(qb).await
                    }
                }
            }
        }

        // Row shape of the COUNT statement, whose single column is aliased "count".
        #[derive(sqlx::FromRow)]
        struct CountResult { count: i64 }

        pub struct CountQuery<'a> {
            client: &'a DatabaseClient,
            r#where: filter::#_where_input,
        }

        impl<'a> CountQuery<'a> {
            pub async fn exec(self) -> Result<i64, FerriormError> {
                let row: CountResult = match self.client {
                    DatabaseClient::Postgres(_) => {
                        let qb = build_count_query::<sqlx::Postgres>(#count_sql, &self.r#where);
                        self.client.fetch_one_pg(qb).await?
                    }
                    DatabaseClient::Sqlite(_) => {
                        let qb = build_count_query::<sqlx::Sqlite>(#count_sql, &self.r#where);
                        self.client.fetch_one_sqlite(qb).await?
                    }
                };
                Ok(row.count)
            }
        }

        // Multi-row insert, implemented as one `CreateQuery` per item.
        pub struct CreateManyQuery<'a> {
            client: &'a DatabaseClient,
            data: Vec<data::#_create_input>,
        }

        impl<'a> CreateManyQuery<'a> {
            pub async fn exec(self) -> Result<u64, FerriormError> {
                if self.data.is_empty() { return Ok(0); }
                // The count is taken up front because the loop consumes `self.data`.
                let count = self.data.len() as u64;
                // NOTE(review): rows are inserted one statement at a time and no
                // transaction is opened here, so an error part-way through returns
                // early after earlier inserts were already issued.
                for item in self.data {
                    CreateQuery { client: self.client, data: item }.exec().await?;
                }
                Ok(count)
            }
        }

        // Bulk update by filter; the body comes from `gen_update_many_code`.
        pub struct UpdateManyQuery<'a> {
            client: &'a DatabaseClient,
            r#where: filter::#_where_input,
            data: data::#_update_input,
        }

        impl<'a> UpdateManyQuery<'a> {
            pub async fn exec(self) -> Result<u64, FerriormError> {
                let client = self.client;
                #update_many_code
            }
        }

        // Insert-or-update by unique key; the body comes from `gen_upsert_code`.
        pub struct UpsertQuery<'a> {
            client: &'a DatabaseClient,
            r#where: filter::#_where_unique,
            create: data::#_create_input,
            update: data::#_update_input,
        }

        impl<'a> UpsertQuery<'a> {
            pub async fn exec(self) -> Result<#model_ident, FerriormError> {
                let client = self.client;
                #upsert_code
            }
        }

        // Bulk delete by filter; runs through the client's `execute_*` methods.
        pub struct DeleteManyQuery<'a> {
            client: &'a DatabaseClient,
            r#where: filter::#_where_input,
        }

        impl<'a> DeleteManyQuery<'a> {
            pub async fn exec(self) -> Result<u64, FerriormError> {
                match self.client {
                    DatabaseClient::Postgres(_) => {
                        let qb = build_delete_many_query::<sqlx::Postgres>(#delete_sql, &self.r#where);
                        self.client.execute_pg(qb).await
                    }
                    DatabaseClient::Sqlite(_) => {
                        let qb = build_delete_many_query::<sqlx::Sqlite>(#delete_sql, &self.r#where);
                        self.client.execute_sqlite(qb).await
                    }
                }
            }
        }
    }
}
1092
1093fn gen_insert_code(model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
1096 let _model_ident = format_ident!("{}", model.name);
1097
1098 let required: Vec<&Field> = scalar_fields
1100 .iter()
1101 .copied()
1102 .filter(|f| !f.has_default() && !f.is_updated_at)
1103 .collect();
1104
1105 let optional: Vec<&Field> = scalar_fields
1107 .iter()
1108 .copied()
1109 .filter(|f| f.has_default() && !f.is_updated_at)
1110 .collect();
1111
1112 let updated_at: Vec<&Field> = scalar_fields
1114 .iter()
1115 .copied()
1116 .filter(|f| f.is_updated_at)
1117 .collect();
1118
1119 let mut col_pushes = vec![];
1121 let mut val_pushes = vec![];
1122
1123 for f in &required {
1125 let db_name = &f.db_name;
1126 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1127 col_pushes.push(quote! { cols.push(#db_name); });
1128 val_pushes.push(quote! { sep.push_bind(self.data.#field_ident); });
1129 }
1130
1131 for f in &optional {
1133 let db_name = &f.db_name;
1134 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1135 if is_autoincrement(f) {
1136 col_pushes.push(quote! {
1140 if self.data.#field_ident.is_some() { cols.push(#db_name); }
1141 });
1142 val_pushes.push(quote! {
1143 if let Some(val) = self.data.#field_ident {
1144 sep.push_bind(val);
1145 }
1146 });
1147 } else {
1148 let default_expr = gen_default_expr(f, &f.field_type);
1149 col_pushes.push(quote! { cols.push(#db_name); });
1150 val_pushes.push(quote! {
1151 let val = self.data.#field_ident.unwrap_or_else(|| #default_expr);
1152 sep.push_bind(val);
1153 });
1154 }
1155 }
1156
1157 for f in &updated_at {
1159 let db_name = &f.db_name;
1160 col_pushes.push(quote! { cols.push(#db_name); });
1161 val_pushes.push(quote! { sep.push_bind(chrono::Utc::now()); });
1162 }
1163
1164 let insert_start = format!(r#"INSERT INTO "{table_name}""#);
1165
1166 quote! {
1169 macro_rules! build_insert {
1171 ($qb_type:ty) => {{
1172 let mut cols: Vec<&str> = Vec::new();
1173 #(#col_pushes)*
1174
1175 let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#insert_start);
1176 qb.push(" (");
1177 for (i, col) in cols.iter().enumerate() {
1178 if i > 0 { qb.push(", "); }
1179 qb.push("\"");
1180 qb.push(*col);
1181 qb.push("\"");
1182 }
1183 qb.push(") VALUES (");
1184 {
1185 let mut sep = qb.separated(", ");
1186 #(#val_pushes)*
1187 }
1188 qb.push(") RETURNING *");
1189 qb
1190 }};
1191 }
1192
1193 match client {
1194 DatabaseClient::Postgres(_) => {
1195 let qb = build_insert!(sqlx::Postgres);
1196 client.fetch_one_pg(qb).await
1197 }
1198 DatabaseClient::Sqlite(_) => {
1199 let qb = build_insert!(sqlx::Sqlite);
1200 client.fetch_one_sqlite(qb).await
1201 }
1202 }
1203 }
1204}
1205
1206fn gen_insert_ignore_code(
1207 _model: &Model,
1208 scalar_fields: &[&Field],
1209 table_name: &str,
1210) -> TokenStream {
1211 let required: Vec<&Field> = scalar_fields
1212 .iter()
1213 .copied()
1214 .filter(|f| !f.has_default() && !f.is_updated_at)
1215 .collect();
1216 let optional: Vec<&Field> = scalar_fields
1217 .iter()
1218 .copied()
1219 .filter(|f| f.has_default() && !f.is_updated_at)
1220 .collect();
1221 let updated_at: Vec<&Field> = scalar_fields
1222 .iter()
1223 .copied()
1224 .filter(|f| f.is_updated_at)
1225 .collect();
1226
1227 let mut col_pushes = vec![];
1228 let mut val_pushes = vec![];
1229
1230 for f in &required {
1231 let db_name = &f.db_name;
1232 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1233 col_pushes.push(quote! { cols.push(#db_name); });
1234 val_pushes.push(quote! { sep.push_bind(self.data.#field_ident); });
1235 }
1236 for f in &optional {
1237 let db_name = &f.db_name;
1238 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1239 if is_autoincrement(f) {
1240 col_pushes.push(quote! {
1241 if self.data.#field_ident.is_some() { cols.push(#db_name); }
1242 });
1243 val_pushes.push(quote! {
1244 if let Some(val) = self.data.#field_ident {
1245 sep.push_bind(val);
1246 }
1247 });
1248 } else {
1249 let default_expr = gen_default_expr(f, &f.field_type);
1250 col_pushes.push(quote! { cols.push(#db_name); });
1251 val_pushes.push(quote! {
1252 let val = self.data.#field_ident.unwrap_or_else(|| #default_expr);
1253 sep.push_bind(val);
1254 });
1255 }
1256 }
1257 for f in &updated_at {
1258 let db_name = &f.db_name;
1259 col_pushes.push(quote! { cols.push(#db_name); });
1260 val_pushes.push(quote! { sep.push_bind(chrono::Utc::now()); });
1261 }
1262
1263 let pg_insert_start = format!(r#"INSERT INTO "{table_name}""#);
1264 let sqlite_insert_start = format!(r#"INSERT OR IGNORE INTO "{table_name}""#);
1265
1266 quote! {
1267 macro_rules! build_insert_ignore {
1268 ($qb_type:ty, $head:expr, $tail:expr) => {{
1269 let mut cols: Vec<&str> = Vec::new();
1270 #(#col_pushes)*
1271
1272 let mut qb = sqlx::QueryBuilder::<$qb_type>::new($head);
1273 qb.push(" (");
1274 for (i, col) in cols.iter().enumerate() {
1275 if i > 0 { qb.push(", "); }
1276 qb.push("\"");
1277 qb.push(*col);
1278 qb.push("\"");
1279 }
1280 qb.push(") VALUES (");
1281 {
1282 let mut sep = qb.separated(", ");
1283 #(#val_pushes)*
1284 }
1285 qb.push(")");
1286 qb.push($tail);
1287 qb.push(" RETURNING *");
1288 qb
1289 }};
1290 }
1291
1292 match client {
1293 DatabaseClient::Postgres(_) => {
1294 let qb = build_insert_ignore!(sqlx::Postgres, #pg_insert_start, " ON CONFLICT DO NOTHING");
1295 client.fetch_optional_pg(qb).await
1296 }
1297 DatabaseClient::Sqlite(_) => {
1298 let qb = build_insert_ignore!(sqlx::Sqlite, #sqlite_insert_start, "");
1299 client.fetch_optional_sqlite(qb).await
1300 }
1301 }
1302 }
1303}
1304
1305fn is_autoincrement(field: &Field) -> bool {
1309 matches!(
1310 field.default,
1311 Some(ferriorm_core::ast::DefaultValue::AutoIncrement)
1312 )
1313}
1314
1315fn gen_default_expr(field: &Field, field_type: &FieldKind) -> TokenStream {
1317 use ferriorm_core::ast::DefaultValue;
1318
1319 match &field.default {
1320 Some(DefaultValue::Uuid | DefaultValue::Cuid) => {
1321 quote! { uuid::Uuid::new_v4().to_string() }
1322 }
1323 Some(DefaultValue::Now) => quote! { chrono::Utc::now() },
1324 Some(DefaultValue::AutoIncrement) => quote! { 0i32 },
1329 Some(DefaultValue::Literal(lit)) => {
1330 use ferriorm_core::ast::LiteralValue;
1331 match lit {
1332 LiteralValue::String(s) => quote! { #s.to_string() },
1333 LiteralValue::Int(i) => {
1334 #[allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
1336 match field_type {
1337 FieldKind::Scalar(ScalarType::Float) => {
1338 let val = *i as f64;
1339 quote! { #val }
1340 }
1341 FieldKind::Scalar(ScalarType::BigInt) => quote! { #i },
1342 FieldKind::Scalar(ScalarType::Int)
1344 if field.db_type.as_ref().is_some_and(|(ty, _)| ty == "BigInt") =>
1345 {
1346 quote! { #i }
1347 }
1348 _ => {
1349 let val = *i as i32;
1351 quote! { #val }
1352 }
1353 }
1354 }
1355 LiteralValue::Float(f) => quote! { #f },
1356 LiteralValue::Bool(b) => quote! { #b },
1357 }
1358 }
1359 Some(DefaultValue::EnumVariant(v)) => {
1360 let variant = format_ident!("{}", v);
1362 if let FieldKind::Enum(enum_name) = &field.field_type {
1363 let enum_ident = format_ident!("{}", enum_name);
1364 quote! { super::enums::#enum_ident::#variant }
1365 } else {
1366 quote! { Default::default() }
1367 }
1368 }
1369 None => quote! { Default::default() },
1370 }
1371}
1372
1373fn gen_update_code(model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
1376 let _model_ident = format_ident!("{}", model.name);
1377
1378 let updatable: Vec<&Field> = scalar_fields
1380 .iter()
1381 .copied()
1382 .filter(|f| !f.is_id && !f.is_updated_at)
1383 .collect();
1384
1385 let updated_at: Vec<&Field> = scalar_fields
1386 .iter()
1387 .copied()
1388 .filter(|f| f.is_updated_at)
1389 .collect();
1390
1391 let update_start = format!(r#"UPDATE "{table_name}" SET "#);
1392
1393 let set_arms: Vec<TokenStream> = updatable
1395 .iter()
1396 .map(|f| {
1397 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1398 let db_name = &f.db_name;
1399 quote! {
1400 if let Some(SetValue::Set(v)) = self.data.#field_ident {
1401 if !first_set { qb.push(", "); }
1402 first_set = false;
1403 qb.push(concat!("\"", #db_name, "\" = "));
1404 qb.push_bind(v);
1405 }
1406 }
1407 })
1408 .collect();
1409
1410 let updated_at_arms: Vec<TokenStream> = updated_at
1411 .iter()
1412 .map(|f| {
1413 let db_name = &f.db_name;
1414 quote! {
1415 if !first_set { qb.push(", "); }
1416 first_set = false;
1417 qb.push(concat!("\"", #db_name, "\" = "));
1418 qb.push_bind(chrono::Utc::now());
1419 }
1420 })
1421 .collect();
1422
1423 quote! {
1426 macro_rules! build_update {
1427 ($qb_type:ty) => {{
1428 let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#update_start);
1429 let mut first_set = true;
1430 #(#set_arms)*
1431 #(#updated_at_arms)*
1432
1433 if first_set {
1434 return Err(FerriormError::Query("No fields to update".into()));
1435 }
1436
1437 qb.push(" WHERE 1=1");
1438 self.r#where.build_where(&mut qb);
1439 qb.push(" RETURNING *");
1440 qb
1441 }};
1442 }
1443
1444 match client {
1445 DatabaseClient::Postgres(_) => {
1446 let qb = build_update!(sqlx::Postgres);
1447 client.fetch_one_pg(qb).await
1448 }
1449 DatabaseClient::Sqlite(_) => {
1450 let qb = build_update!(sqlx::Sqlite);
1451 client.fetch_one_sqlite(qb).await
1452 }
1453 }
1454 }
1455}
1456
1457fn gen_update_first_code(
1460 _model: &Model,
1461 scalar_fields: &[&Field],
1462 table_name: &str,
1463) -> TokenStream {
1464 let updatable: Vec<&Field> = scalar_fields
1465 .iter()
1466 .copied()
1467 .filter(|f| !f.is_id && !f.is_updated_at)
1468 .collect();
1469
1470 let updated_at: Vec<&Field> = scalar_fields
1471 .iter()
1472 .copied()
1473 .filter(|f| f.is_updated_at)
1474 .collect();
1475
1476 let update_start = format!(r#"UPDATE "{table_name}" SET "#);
1477
1478 let set_arms: Vec<TokenStream> = updatable
1479 .iter()
1480 .map(|f| {
1481 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1482 let db_name = &f.db_name;
1483 quote! {
1484 if let Some(SetValue::Set(v)) = self.data.#field_ident {
1485 if !first_set { qb.push(", "); }
1486 first_set = false;
1487 qb.push(concat!("\"", #db_name, "\" = "));
1488 qb.push_bind(v);
1489 }
1490 }
1491 })
1492 .collect();
1493
1494 let updated_at_arms: Vec<TokenStream> = updated_at
1495 .iter()
1496 .map(|f| {
1497 let db_name = &f.db_name;
1498 quote! {
1499 if !first_set { qb.push(", "); }
1500 first_set = false;
1501 qb.push(concat!("\"", #db_name, "\" = "));
1502 qb.push_bind(chrono::Utc::now());
1503 }
1504 })
1505 .collect();
1506
1507 quote! {
1508 macro_rules! build_update_first {
1509 ($qb_type:ty) => {{
1510 let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#update_start);
1511 let mut first_set = true;
1512 #(#set_arms)*
1513 #(#updated_at_arms)*
1514
1515 if first_set {
1516 return Err(FerriormError::Query("No fields to update".into()));
1517 }
1518
1519 qb.push(" WHERE 1=1");
1520 self.r#where.build_where(&mut qb);
1521 qb.push(" RETURNING *");
1522 qb
1523 }};
1524 }
1525
1526 match client {
1527 DatabaseClient::Postgres(_) => {
1528 let qb = build_update_first!(sqlx::Postgres);
1529 client.fetch_optional_pg(qb).await
1530 }
1531 DatabaseClient::Sqlite(_) => {
1532 let qb = build_update_first!(sqlx::Sqlite);
1533 client.fetch_optional_sqlite(qb).await
1534 }
1535 }
1536 }
1537}
1538
1539fn gen_update_many_code(_model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
1542 let updatable: Vec<&Field> = scalar_fields
1544 .iter()
1545 .copied()
1546 .filter(|f| !f.is_id && !f.is_updated_at)
1547 .collect();
1548
1549 let updated_at: Vec<&Field> = scalar_fields
1550 .iter()
1551 .copied()
1552 .filter(|f| f.is_updated_at)
1553 .collect();
1554
1555 let update_start = format!(r#"UPDATE "{table_name}" SET "#);
1556
1557 let set_arms: Vec<TokenStream> = updatable
1559 .iter()
1560 .map(|f| {
1561 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1562 let db_name = &f.db_name;
1563 quote! {
1564 if let Some(SetValue::Set(v)) = self.data.#field_ident {
1565 if !first_set { qb.push(", "); }
1566 first_set = false;
1567 qb.push(concat!("\"", #db_name, "\" = "));
1568 qb.push_bind(v);
1569 }
1570 }
1571 })
1572 .collect();
1573
1574 let updated_at_arms: Vec<TokenStream> = updated_at
1575 .iter()
1576 .map(|f| {
1577 let db_name = &f.db_name;
1578 quote! {
1579 if !first_set { qb.push(", "); }
1580 first_set = false;
1581 qb.push(concat!("\"", #db_name, "\" = "));
1582 qb.push_bind(chrono::Utc::now());
1583 }
1584 })
1585 .collect();
1586
1587 quote! {
1588 macro_rules! build_update_many {
1589 ($qb_type:ty) => {{
1590 let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#update_start);
1591 let mut first_set = true;
1592 #(#set_arms)*
1593 #(#updated_at_arms)*
1594
1595 if first_set {
1596 return Ok(0);
1597 }
1598
1599 qb.push(" WHERE 1=1");
1600 self.r#where.build_where(&mut qb);
1601 qb
1602 }};
1603 }
1604
1605 match client {
1606 DatabaseClient::Postgres(_) => {
1607 let qb = build_update_many!(sqlx::Postgres);
1608 client.execute_pg(qb).await
1609 }
1610 DatabaseClient::Sqlite(_) => {
1611 let qb = build_update_many!(sqlx::Sqlite);
1612 client.execute_sqlite(qb).await
1613 }
1614 }
1615 }
1616}
1617
/// Classifies a scalar field by which aggregate functions the generated
/// aggregate query exposes for it.
enum AggregateKind {
    /// `Int`, `BigInt` and `Float` fields: `avg`, `sum`, `min` and `max`.
    Numeric,
    /// `DateTime` fields: `min` and `max` only.
    DateTime,
}
1627
1628#[allow(clippy::too_many_lines)]
1631fn gen_upsert_code(_model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
1632 let required: Vec<&Field> = scalar_fields
1634 .iter()
1635 .copied()
1636 .filter(|f| !f.has_default() && !f.is_updated_at)
1637 .collect();
1638 let optional: Vec<&Field> = scalar_fields
1639 .iter()
1640 .copied()
1641 .filter(|f| f.has_default() && !f.is_updated_at)
1642 .collect();
1643 let updated_at: Vec<&Field> = scalar_fields
1644 .iter()
1645 .copied()
1646 .filter(|f| f.is_updated_at)
1647 .collect();
1648
1649 let mut col_pushes = vec![];
1650 let mut val_pushes = vec![];
1651
1652 for f in &required {
1653 let db_name = &f.db_name;
1654 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1655 col_pushes.push(quote! { cols.push(#db_name); });
1656 val_pushes.push(quote! { sep.push_bind(self.create.#field_ident); });
1657 }
1658 for f in &optional {
1659 let db_name = &f.db_name;
1660 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1661 if is_autoincrement(f) {
1662 col_pushes.push(quote! {
1663 if self.create.#field_ident.is_some() { cols.push(#db_name); }
1664 });
1665 val_pushes.push(quote! {
1666 if let Some(val) = self.create.#field_ident {
1667 sep.push_bind(val);
1668 }
1669 });
1670 } else {
1671 let default_expr = gen_default_expr(f, &f.field_type);
1672 col_pushes.push(quote! { cols.push(#db_name); });
1673 val_pushes.push(quote! {
1674 let val = self.create.#field_ident.unwrap_or_else(|| #default_expr);
1675 sep.push_bind(val);
1676 });
1677 }
1678 }
1679 for f in &updated_at {
1680 let db_name = &f.db_name;
1681 col_pushes.push(quote! { cols.push(#db_name); });
1682 val_pushes.push(quote! { sep.push_bind(chrono::Utc::now()); });
1683 }
1684
1685 let updatable: Vec<&Field> = scalar_fields
1687 .iter()
1688 .copied()
1689 .filter(|f| !f.is_id && !f.is_updated_at)
1690 .collect();
1691
1692 let set_arms: Vec<TokenStream> = updatable
1693 .iter()
1694 .map(|f| {
1695 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1696 let db_name = &f.db_name;
1697 quote! {
1698 if let Some(SetValue::Set(v)) = self.update.#field_ident {
1699 if !first_set { qb.push(", "); }
1700 first_set = false;
1701 qb.push(concat!("\"", #db_name, "\" = "));
1702 qb.push_bind(v);
1703 }
1704 }
1705 })
1706 .collect();
1707
1708 let updated_at_set: Vec<TokenStream> = updated_at
1709 .iter()
1710 .map(|f| {
1711 let db_name = &f.db_name;
1712 quote! {
1713 if !first_set { qb.push(", "); }
1714 first_set = false;
1715 qb.push(concat!("\"", #db_name, "\" = "));
1716 qb.push_bind(chrono::Utc::now());
1717 }
1718 })
1719 .collect();
1720
1721 let insert_start = format!(r#"INSERT INTO "{table_name}""#);
1722
1723 quote! {
1724 let conflict_target = self.r#where.conflict_target();
1725 let first_conflict_col = self.r#where.first_conflict_col();
1726
1727 macro_rules! build_upsert {
1728 ($qb_type:ty) => {{
1729 let mut cols: Vec<&str> = Vec::new();
1730 #(#col_pushes)*
1731
1732 let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#insert_start);
1733 qb.push(" (");
1734 for (i, col) in cols.iter().enumerate() {
1735 if i > 0 { qb.push(", "); }
1736 qb.push("\"");
1737 qb.push(*col);
1738 qb.push("\"");
1739 }
1740 qb.push(") VALUES (");
1741 {
1742 let mut sep = qb.separated(", ");
1743 #(#val_pushes)*
1744 }
1745 qb.push(")");
1746 qb.push(" ON CONFLICT ");
1747 qb.push(conflict_target);
1748 qb.push(" DO UPDATE SET ");
1749
1750 let mut first_set = true;
1751 #(#set_arms)*
1752 #(#updated_at_set)*
1753
1754 if first_set {
1755 qb.push(first_conflict_col);
1758 qb.push(" = ");
1759 qb.push(first_conflict_col);
1760 }
1761
1762 qb.push(" RETURNING *");
1763 qb
1764 }};
1765 }
1766
1767 match client {
1768 DatabaseClient::Postgres(_) => {
1769 let qb = build_upsert!(sqlx::Postgres);
1770 client.fetch_one_pg(qb).await
1771 }
1772 DatabaseClient::Sqlite(_) => {
1773 let qb = build_upsert!(sqlx::Sqlite);
1774 client.fetch_one_sqlite(qb).await
1775 }
1776 }
1777 }
1778}
1779
1780#[allow(clippy::too_many_lines)]
1781fn gen_aggregate_types(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
1782 let aggregate_field_name = format_ident!("{}AggregateField", model.name);
1783 let aggregate_result_name = format_ident!("{}AggregateResult", model.name);
1784 let _where_input = format_ident!("{}WhereInput", model.name);
1785 let table_name = &model.db_name;
1786
1787 let agg_fields: Vec<(&Field, AggregateKind)> = scalar_fields
1789 .iter()
1790 .filter_map(|f| match &f.field_type {
1791 FieldKind::Scalar(ScalarType::Int | ScalarType::BigInt | ScalarType::Float) => {
1792 Some((*f, AggregateKind::Numeric))
1793 }
1794 FieldKind::Scalar(ScalarType::DateTime) => Some((*f, AggregateKind::DateTime)),
1795 _ => None,
1796 })
1797 .collect();
1798
1799 if agg_fields.is_empty() {
1800 return quote! {};
1801 }
1802
1803 let enum_variants: Vec<TokenStream> = agg_fields
1805 .iter()
1806 .map(|(f, _)| {
1807 let variant = format_ident!("{}", to_pascal_case(&f.name));
1808 quote! { #variant }
1809 })
1810 .collect();
1811
1812 let db_name_arms: Vec<TokenStream> = agg_fields
1814 .iter()
1815 .map(|(f, _)| {
1816 let variant = format_ident!("{}", to_pascal_case(&f.name));
1817 let db_name = &f.db_name;
1818 quote! { Self::#variant => #db_name }
1819 })
1820 .collect();
1821
1822 let mut result_fields = Vec::new();
1824 for (f, kind) in &agg_fields {
1825 let snake = to_snake_case(&f.name);
1826 let orig_ty = rust_type_tokens(
1827 &Field {
1828 is_optional: false,
1829 ..(*f).clone()
1830 },
1831 ModuleDepth::TopLevel,
1832 );
1833
1834 match kind {
1835 AggregateKind::Numeric => {
1836 let avg_name = format_ident!("avg_{}", snake);
1837 let sum_name = format_ident!("sum_{}", snake);
1838 let min_name = format_ident!("min_{}", snake);
1839 let max_name = format_ident!("max_{}", snake);
1840 result_fields.push(quote! { #[sqlx(default)] pub #avg_name: Option<f64> });
1841 result_fields.push(quote! { #[sqlx(default)] pub #sum_name: Option<f64> });
1842 result_fields.push(quote! { #[sqlx(default)] pub #min_name: Option<#orig_ty> });
1843 result_fields.push(quote! { #[sqlx(default)] pub #max_name: Option<#orig_ty> });
1844 }
1845 AggregateKind::DateTime => {
1846 let min_name = format_ident!("min_{}", snake);
1847 let max_name = format_ident!("max_{}", snake);
1848 result_fields.push(quote! { #[sqlx(default)] pub #min_name: Option<#orig_ty> });
1849 result_fields.push(quote! { #[sqlx(default)] pub #max_name: Option<#orig_ty> });
1850 }
1851 }
1852 }
1853
1854 let numeric_arms: Vec<TokenStream> = agg_fields
1856 .iter()
1857 .filter(|(_, kind)| matches!(kind, AggregateKind::Numeric))
1858 .map(|(f, _)| {
1859 let variant = format_ident!("{}", to_pascal_case(&f.name));
1860 quote! { Self::#variant => true }
1861 })
1862 .collect();
1863
1864 let has_numeric = !numeric_arms.is_empty();
1865 let is_numeric_method = if has_numeric {
1866 quote! {
1867 fn is_numeric(&self) -> bool {
1868 match self {
1869 #(#numeric_arms,)*
1870 #[allow(unreachable_patterns)]
1871 _ => false,
1872 }
1873 }
1874 }
1875 } else {
1876 quote! {
1877 fn is_numeric(&self) -> bool { false }
1878 }
1879 };
1880
1881 let mut alias_arms = Vec::new();
1883 for (f, kind) in &agg_fields {
1884 let variant = format_ident!("{}", to_pascal_case(&f.name));
1885 let snake = to_snake_case(&f.name);
1886 let prefixes = match kind {
1887 AggregateKind::Numeric => vec!["avg", "sum", "min", "max"],
1888 AggregateKind::DateTime => vec!["min", "max"],
1889 };
1890 for prefix in prefixes {
1891 let alias_str = format!("{prefix}_{snake}");
1892 alias_arms.push(quote! { (#prefix, Self::#variant) => #alias_str });
1893 }
1894 }
1895
1896 let agg_select_base = format!(r#"SELECT {{}} FROM "{table_name}" WHERE 1=1"#);
1897
1898 quote! {
1899 #[derive(Debug, Clone, Copy)]
1900 pub enum #aggregate_field_name {
1901 #(#enum_variants),*
1902 }
1903
1904 impl #aggregate_field_name {
1905 pub fn db_name(&self) -> &'static str {
1906 match self {
1907 #(#db_name_arms,)*
1908 }
1909 }
1910
1911 fn alias(&self, prefix: &'static str) -> &'static str {
1912 match (prefix, self) {
1913 #(#alias_arms,)*
1914 _ => unreachable!(),
1915 }
1916 }
1917
1918 #is_numeric_method
1919 }
1920
1921 #[derive(Debug, Clone, sqlx::FromRow, Serialize, Deserialize)]
1922 pub struct #aggregate_result_name {
1923 #(#result_fields,)*
1924 }
1925
1926 pub struct AggregateQuery<'a> {
1927 client: &'a DatabaseClient,
1928 r#where: filter::#_where_input,
1929 ops: Vec<(&'static str, &'static str, &'static str)>,
1930 }
1931
1932 impl<'a> AggregateQuery<'a> {
1933 pub fn avg(mut self, field: #aggregate_field_name) -> Self {
1934 assert!(field.is_numeric(), "avg() is only supported on numeric fields");
1935 let db_name = field.db_name();
1936 let alias = field.alias("avg");
1937 self.ops.push(("AVG", db_name, alias));
1938 self
1939 }
1940
1941 pub fn sum(mut self, field: #aggregate_field_name) -> Self {
1942 assert!(field.is_numeric(), "sum() is only supported on numeric fields");
1943 let db_name = field.db_name();
1944 let alias = field.alias("sum");
1945 self.ops.push(("SUM", db_name, alias));
1946 self
1947 }
1948
1949 pub fn min(mut self, field: #aggregate_field_name) -> Self {
1950 let db_name = field.db_name();
1951 let alias = field.alias("min");
1952 self.ops.push(("MIN", db_name, alias));
1953 self
1954 }
1955
1956 pub fn max(mut self, field: #aggregate_field_name) -> Self {
1957 let db_name = field.db_name();
1958 let alias = field.alias("max");
1959 self.ops.push(("MAX", db_name, alias));
1960 self
1961 }
1962
1963 pub async fn exec(self) -> Result<#aggregate_result_name, FerriormError> {
1964 if self.ops.is_empty() {
1965 return Err(FerriormError::Query("No aggregate operations specified".into()));
1966 }
1967
1968 let selections: Vec<String> = self.ops.iter()
1969 .map(|(func, col, alias)| format!(r#"{}("{}") as "{}""#, func, col, alias))
1970 .collect();
1971 let select_clause = selections.join(", ");
1972 let base_sql = format!(#agg_select_base, select_clause);
1973
1974 match self.client {
1975 DatabaseClient::Postgres(_) => {
1976 let mut qb = sqlx::QueryBuilder::<sqlx::Postgres>::new(&base_sql);
1977 self.r#where.build_where(&mut qb);
1978 self.client.fetch_one_pg(qb).await
1979 }
1980 DatabaseClient::Sqlite(_) => {
1981 let mut qb = sqlx::QueryBuilder::<sqlx::Sqlite>::new(&base_sql);
1982 self.r#where.build_where(&mut qb);
1983 self.client.fetch_one_sqlite(qb).await
1984 }
1985 }
1986 }
1987 }
1988 }
1989}
1990
1991#[allow(clippy::too_many_lines)]
1994fn gen_select_types(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
1995 let select_name = format_ident!("{}Select", model.name);
1996 let partial_name = format_ident!("{}Partial", model.name);
1997 let _where_input = format_ident!("{}WhereInput", model.name);
1998 let _where_unique = format_ident!("{}WhereUniqueInput", model.name);
1999 let order_by_name = format_ident!("{}OrderByInput", model.name);
2000 let table_name = &model.db_name;
2001
2002 let select_fields: Vec<TokenStream> = scalar_fields
2004 .iter()
2005 .map(|f| {
2006 let name = format_ident!("{}", to_snake_case(&f.name));
2007 quote! { pub #name: bool }
2008 })
2009 .collect();
2010
2011 let partial_fields: Vec<TokenStream> = scalar_fields
2014 .iter()
2015 .map(|f| {
2016 let name = format_ident!("{}", to_snake_case(&f.name));
2017 let db_name = &f.db_name;
2018 let base_ty = rust_type_tokens(
2020 &Field {
2021 is_optional: false,
2022 ..(*f).clone()
2023 },
2024 ModuleDepth::TopLevel,
2025 );
2026 let rename = if db_name == &to_snake_case(&f.name) {
2027 quote! {}
2028 } else {
2029 quote! { #[sqlx(rename = #db_name)] }
2030 };
2031 quote! { #[sqlx(default)] #rename pub #name: Option<#base_ty> }
2033 })
2034 .collect();
2035
2036 let select_col_arms: Vec<TokenStream> = scalar_fields
2038 .iter()
2039 .map(|f| {
2040 let name = format_ident!("{}", to_snake_case(&f.name));
2041 let db_name = &f.db_name;
2042 let col_expr = format!(r#""{db_name}""#);
2043 quote! {
2044 if select.#name { cols.push(#col_expr); }
2045 }
2046 })
2047 .collect();
2048
2049 let select_sql_prefix = format!(r#"SELECT {{}} FROM "{table_name}" WHERE 1=1"#);
2050
2051 quote! {
2052 #[derive(Debug, Clone, Default)]
2053 pub struct #select_name {
2054 #(#select_fields,)*
2055 }
2056
2057 #[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
2058 #[sqlx(rename_all = "snake_case")]
2059 pub struct #partial_name {
2060 #(#partial_fields,)*
2061 }
2062
2063 fn build_select_columns(select: &#select_name) -> String {
2064 let mut cols = Vec::new();
2065 #(#select_col_arms)*
2066 if cols.is_empty() {
2067 "*".to_string()
2068 } else {
2069 cols.join(", ")
2070 }
2071 }
2072
2073 pub struct FindManySelectQuery<'a> {
2076 client: &'a DatabaseClient,
2077 r#where: filter::#_where_input,
2078 order_by: Vec<order::#order_by_name>,
2079 skip: Option<i64>,
2080 take: Option<i64>,
2081 select: #select_name,
2082 }
2083
2084 impl<'a> FindManySelectQuery<'a> {
2085 pub fn order_by(mut self, order: order::#order_by_name) -> Self {
2086 self.order_by.push(order);
2087 self
2088 }
2089
2090 pub fn skip(mut self, n: i64) -> Self {
2091 self.skip = Some(n);
2092 self
2093 }
2094
2095 pub fn take(mut self, n: i64) -> Self {
2096 self.take = Some(n);
2097 self
2098 }
2099
2100 pub async fn exec(self) -> Result<Vec<#partial_name>, FerriormError> {
2101 let cols = build_select_columns(&self.select);
2102 let base_sql = format!(#select_sql_prefix, cols);
2103
2104 match self.client {
2105 DatabaseClient::Postgres(_) => {
2106 let qb = build_select_query::<sqlx::Postgres>(
2107 &base_sql, &self.r#where, &self.order_by, self.take, self.skip,
2108 );
2109 self.client.fetch_all_pg(qb).await
2110 }
2111 DatabaseClient::Sqlite(_) => {
2112 let qb = build_select_query::<sqlx::Sqlite>(
2113 &base_sql, &self.r#where, &self.order_by, self.take, self.skip,
2114 );
2115 self.client.fetch_all_sqlite(qb).await
2116 }
2117 }
2118 }
2119 }
2120
2121 pub struct FindUniqueSelectQuery<'a> {
2124 client: &'a DatabaseClient,
2125 r#where: filter::#_where_unique,
2126 select: #select_name,
2127 }
2128
2129 impl<'a> FindUniqueSelectQuery<'a> {
2130 pub async fn exec(self) -> Result<Option<#partial_name>, FerriormError> {
2131 let cols = build_select_columns(&self.select);
2132 let base_sql = format!(#select_sql_prefix, cols);
2133
2134 match self.client {
2135 DatabaseClient::Postgres(_) => {
2136 let qb = build_unique_select_query::<sqlx::Postgres>(
2137 &base_sql, &self.r#where,
2138 );
2139 self.client.fetch_optional_pg(qb).await
2140 }
2141 DatabaseClient::Sqlite(_) => {
2142 let qb = build_unique_select_query::<sqlx::Sqlite>(
2143 &base_sql, &self.r#where,
2144 );
2145 self.client.fetch_optional_sqlite(qb).await
2146 }
2147 }
2148 }
2149 }
2150
2151 pub struct FindFirstSelectQuery<'a> {
2154 client: &'a DatabaseClient,
2155 r#where: filter::#_where_input,
2156 order_by: Vec<order::#order_by_name>,
2157 select: #select_name,
2158 }
2159
2160 impl<'a> FindFirstSelectQuery<'a> {
2161 pub fn order_by(mut self, order: order::#order_by_name) -> Self {
2162 self.order_by.push(order);
2163 self
2164 }
2165
2166 pub async fn exec(self) -> Result<Option<#partial_name>, FerriormError> {
2167 let cols = build_select_columns(&self.select);
2168 let base_sql = format!(#select_sql_prefix, cols);
2169
2170 match self.client {
2171 DatabaseClient::Postgres(_) => {
2172 let qb = build_select_query::<sqlx::Postgres>(
2173 &base_sql, &self.r#where, &self.order_by, Some(1), None,
2174 );
2175 self.client.fetch_optional_pg(qb).await
2176 }
2177 DatabaseClient::Sqlite(_) => {
2178 let qb = build_select_query::<sqlx::Sqlite>(
2179 &base_sql, &self.r#where, &self.order_by, Some(1), None,
2180 );
2181 self.client.fetch_optional_sqlite(qb).await
2182 }
2183 }
2184 }
2185 }
2186 }
2187}