1use ferriorm_core::schema::{Field, FieldKind, Model};
15use ferriorm_core::types::ScalarType;
16use ferriorm_core::utils::{to_pascal_case, to_snake_case};
17use proc_macro2::TokenStream;
18use quote::{format_ident, quote};
19
20use crate::rust_type::{ModuleDepth, filter_type_tokens, rust_type_tokens};
21
22pub fn generate_model_module(model: &Model) -> TokenStream {
24 let scalar_fields: Vec<&Field> = model.fields.iter().filter(|f| f.is_scalar()).collect();
25
26 let data_struct = gen_data_struct(model, &scalar_fields);
27 let filter_module = gen_filter_module(model, &scalar_fields);
28 let data_module = gen_data_module(model, &scalar_fields);
29 let order_module = gen_order_module(model, &scalar_fields);
30 let actions_struct = gen_actions(model);
31 let query_builders = gen_query_builders(model, &scalar_fields);
32
33 quote! {
34 #![allow(unused_imports, dead_code, clippy::all, unused_variables)]
35
36 use serde::{Deserialize, Serialize};
37 use ferriorm_runtime::prelude::*;
38
39 #data_struct
40 #filter_module
41 #data_module
42 #order_module
43 #actions_struct
44 #query_builders
45 }
46}
47
48fn gen_data_struct(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
51 let struct_name = format_ident!("{}", model.name);
52 let table_name = &model.db_name;
53
54 let fields: Vec<TokenStream> = scalar_fields
55 .iter()
56 .map(|f| {
57 let name = format_ident!("{}", to_snake_case(&f.name));
58 let ty = rust_type_tokens(f, ModuleDepth::TopLevel);
59 let db_name = &f.db_name;
60 if db_name != &to_snake_case(&f.name) {
61 quote! { #[sqlx(rename = #db_name)] pub #name: #ty }
62 } else {
63 quote! { pub #name: #ty }
64 }
65 })
66 .collect();
67
68 quote! {
69 #[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
70 #[sqlx(rename_all = "snake_case")]
71 pub struct #struct_name {
72 #(#fields),*
73 }
74
75 impl #struct_name {
76 pub const TABLE_NAME: &'static str = #table_name;
77 }
78 }
79}
80
81fn gen_filter_module(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
84 let where_input = format_ident!("{}WhereInput", model.name);
85 let where_unique = format_ident!("{}WhereUniqueInput", model.name);
86
87 let where_fields: Vec<TokenStream> = scalar_fields
88 .iter()
89 .filter_map(|f| {
90 let filter_ty = filter_type_tokens(f, ModuleDepth::Nested)?;
91 let name = format_ident!("{}", to_snake_case(&f.name));
92 Some(quote! { pub #name: Option<#filter_ty> })
93 })
94 .collect();
95
96 let unique_variants: Vec<TokenStream> = scalar_fields
97 .iter()
98 .filter(|f| f.is_id || f.is_unique)
99 .map(|f| {
100 let variant = format_ident!("{}", to_pascal_case(&f.name));
101 let ty = rust_type_tokens(f, ModuleDepth::Nested);
102 quote! { #variant(#ty) }
103 })
104 .collect();
105
106 let db_bounds = collect_db_bounds(scalar_fields);
108 let where_arms = gen_where_arms(scalar_fields);
109 let unique_arms = gen_unique_where_arms(scalar_fields);
110
111 quote! {
112 pub mod filter {
113 use ferriorm_runtime::prelude::*;
114
115 #[derive(Debug, Clone, Default)]
116 pub struct #where_input {
117 #(#where_fields,)*
118 pub and: Option<Vec<#where_input>>,
119 pub or: Option<Vec<#where_input>>,
120 pub not: Option<Box<#where_input>>,
121 }
122
123 #[derive(Debug, Clone)]
124 pub enum #where_unique {
125 #(#unique_variants),*
126 }
127
128 impl #where_input {
129 pub(crate) fn build_where<'args, DB: sqlx::Database>(
130 &self,
131 qb: &mut sqlx::QueryBuilder<'args, DB>,
132 )
133 where
134 #(#db_bounds,)*
135 {
136 #(#where_arms)*
137
138 if let Some(conditions) = &self.and {
139 for c in conditions {
140 c.build_where(qb);
141 }
142 }
143 if let Some(conditions) = &self.or {
144 if !conditions.is_empty() {
145 qb.push(" AND (");
146 for (i, c) in conditions.iter().enumerate() {
147 if i > 0 { qb.push(" OR "); }
148 qb.push("(1=1");
149 c.build_where(qb);
150 qb.push(")");
151 }
152 qb.push(")");
153 }
154 }
155 if let Some(c) = &self.not {
156 qb.push(" AND NOT (1=1");
157 c.build_where(qb);
158 qb.push(")");
159 }
160 }
161 }
162
163 impl #where_unique {
164 pub(crate) fn build_where<'args, DB: sqlx::Database>(
165 &self,
166 qb: &mut sqlx::QueryBuilder<'args, DB>,
167 )
168 where
169 #(#db_bounds,)*
170 {
171 match self {
172 #(#unique_arms)*
173 }
174 }
175 }
176 }
177 }
178}
179
180fn collect_db_bounds(scalar_fields: &[&Field]) -> Vec<TokenStream> {
182 let mut seen = std::collections::HashSet::new();
183 let mut bounds = Vec::new();
184
185 seen.insert("i64");
187 bounds.push(quote! { i64: sqlx::Type<DB> + for<'e> sqlx::Encode<'e, DB> });
188
189 for f in scalar_fields {
190 match &f.field_type {
191 FieldKind::Scalar(scalar) => {
192 let key = scalar.rust_type();
193 if seen.insert(key)
194 && let Some(ty) = scalar_bound_tokens(scalar)
195 {
196 bounds.push(quote! { #ty: sqlx::Type<DB> + for<'e> sqlx::Encode<'e, DB> });
197 bounds.push(
199 quote! { Option<#ty>: sqlx::Type<DB> + for<'e> sqlx::Encode<'e, DB> },
200 );
201 }
202 }
203 FieldKind::Enum(_) => {}
204 _ => {}
205 }
206 }
207
208 bounds
209}
210
211fn scalar_bound_tokens(scalar: &ScalarType) -> Option<TokenStream> {
212 match scalar {
213 ScalarType::String => Some(quote! { String }),
214 ScalarType::Int => Some(quote! { i32 }),
215 ScalarType::BigInt => Some(quote! { i64 }),
216 ScalarType::Float => Some(quote! { f64 }),
217 ScalarType::Boolean => Some(quote! { bool }),
218 ScalarType::DateTime => Some(quote! { chrono::DateTime<chrono::Utc> }),
219 ScalarType::Bytes => Some(quote! { Vec<u8> }),
220 ScalarType::Json | ScalarType::Decimal => None,
221 }
222}
223
224fn gen_where_arms(scalar_fields: &[&Field]) -> Vec<TokenStream> {
226 scalar_fields
227 .iter()
228 .filter_map(|f| {
229 if !matches!(&f.field_type, FieldKind::Scalar(_)) {
231 return None;
232 }
233 let field_ident = format_ident!("{}", to_snake_case(&f.name));
234 let db_name = &f.db_name;
235 let is_string = matches!(&f.field_type, FieldKind::Scalar(ScalarType::String));
236 let is_comparable = matches!(
237 &f.field_type,
238 FieldKind::Scalar(
239 ScalarType::Int | ScalarType::BigInt | ScalarType::Float | ScalarType::DateTime
240 )
241 );
242
243 let mut arms = vec![];
244
245 arms.push(quote! {
246 if let Some(v) = &filter.equals {
247 qb.push(concat!(" AND \"", #db_name, "\" = "));
248 qb.push_bind(v.clone());
249 }
250 if let Some(v) = &filter.not {
251 qb.push(concat!(" AND \"", #db_name, "\" != "));
252 qb.push_bind(v.clone());
253 }
254 });
255
256 if is_string {
257 arms.push(quote! {
258 if let Some(v) = &filter.contains {
259 qb.push(concat!(" AND \"", #db_name, "\" LIKE "));
260 qb.push_bind(format!("%{}%", v));
261 }
262 if let Some(v) = &filter.starts_with {
263 qb.push(concat!(" AND \"", #db_name, "\" LIKE "));
264 qb.push_bind(format!("{}%", v));
265 }
266 if let Some(v) = &filter.ends_with {
267 qb.push(concat!(" AND \"", #db_name, "\" LIKE "));
268 qb.push_bind(format!("%{}", v));
269 }
270 });
271 }
272
273 if is_comparable {
274 arms.push(quote! {
275 if let Some(v) = &filter.gt {
276 qb.push(concat!(" AND \"", #db_name, "\" > "));
277 qb.push_bind(v.clone());
278 }
279 if let Some(v) = &filter.gte {
280 qb.push(concat!(" AND \"", #db_name, "\" >= "));
281 qb.push_bind(v.clone());
282 }
283 if let Some(v) = &filter.lt {
284 qb.push(concat!(" AND \"", #db_name, "\" < "));
285 qb.push_bind(v.clone());
286 }
287 if let Some(v) = &filter.lte {
288 qb.push(concat!(" AND \"", #db_name, "\" <= "));
289 qb.push_bind(v.clone());
290 }
291 });
292 }
293
294 Some(quote! {
295 if let Some(filter) = &self.#field_ident {
296 #(#arms)*
297 }
298 })
299 })
300 .collect()
301}
302
303fn gen_unique_where_arms(scalar_fields: &[&Field]) -> Vec<TokenStream> {
304 let _where_unique = format_ident!(
305 "{}WhereUniqueInput",
306 "" );
308 scalar_fields
309 .iter()
310 .filter(|f| f.is_id || f.is_unique)
311 .map(|f| {
312 let variant = format_ident!("{}", to_pascal_case(&f.name));
313 let db_name = &f.db_name;
314 quote! {
315 Self::#variant(v) => {
316 qb.push(concat!(" AND \"", #db_name, "\" = "));
317 qb.push_bind(v.clone());
318 }
319 }
320 })
321 .collect()
322}
323
324fn gen_data_module(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
327 let create_name = format_ident!("{}CreateInput", model.name);
328 let update_name = format_ident!("{}UpdateInput", model.name);
329
330 let required_fields: Vec<TokenStream> = scalar_fields
331 .iter()
332 .filter(|f| !f.has_default() && !f.is_updated_at)
333 .map(|f| {
334 let name = format_ident!("{}", to_snake_case(&f.name));
335 let ty = rust_type_tokens(f, ModuleDepth::Nested);
336 quote! { pub #name: #ty }
337 })
338 .collect();
339
340 let optional_fields: Vec<TokenStream> = scalar_fields
341 .iter()
342 .filter(|f| f.has_default() && !f.is_updated_at)
343 .map(|f| {
344 let name = format_ident!("{}", to_snake_case(&f.name));
345 let base_ty = rust_type_tokens(f, ModuleDepth::Nested);
346 quote! { pub #name: Option<#base_ty> }
347 })
348 .collect();
349
350 let update_fields: Vec<TokenStream> = scalar_fields
351 .iter()
352 .filter(|f| !f.is_id && !f.is_updated_at)
353 .map(|f| {
354 let name = format_ident!("{}", to_snake_case(&f.name));
355 let ty = rust_type_tokens(f, ModuleDepth::Nested);
356 quote! { pub #name: Option<SetValue<#ty>> }
357 })
358 .collect();
359
360 quote! {
361 pub mod data {
362 use ferriorm_runtime::prelude::*;
363
364 #[derive(Debug, Clone)]
365 pub struct #create_name {
366 #(#required_fields,)*
367 #(#optional_fields,)*
368 }
369
370 #[derive(Debug, Clone, Default)]
371 pub struct #update_name {
372 #(#update_fields,)*
373 }
374 }
375 }
376}
377
378fn gen_order_module(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
381 let order_name = format_ident!("{}OrderByInput", model.name);
382
383 let variants: Vec<TokenStream> = scalar_fields
384 .iter()
385 .map(|f| {
386 let variant = format_ident!("{}", to_pascal_case(&f.name));
387 quote! { #variant(SortOrder) }
388 })
389 .collect();
390
391 let order_arms: Vec<TokenStream> = scalar_fields
392 .iter()
393 .map(|f| {
394 let variant = format_ident!("{}", to_pascal_case(&f.name));
395 let db_name = &f.db_name;
396 quote! {
397 Self::#variant(order) => {
398 qb.push(concat!("\"", #db_name, "\" "));
399 qb.push(order.as_sql());
400 }
401 }
402 })
403 .collect();
404
405 quote! {
406 pub mod order {
407 use ferriorm_runtime::prelude::*;
408
409 #[derive(Debug, Clone)]
410 pub enum #order_name {
411 #(#variants),*
412 }
413
414 impl #order_name {
415 pub(crate) fn build_order_by<'args, DB: sqlx::Database>(
416 &self,
417 qb: &mut sqlx::QueryBuilder<'args, DB>,
418 ) {
419 match self {
420 #(#order_arms)*
421 }
422 }
423 }
424 }
425 }
426}
427
428fn gen_actions(model: &Model) -> TokenStream {
431 let _model_ident = format_ident!("{}", model.name);
432 let actions_name = format_ident!("{}Actions", model.name);
433 let where_input = format_ident!("{}WhereInput", model.name);
434 let where_unique = format_ident!("{}WhereUniqueInput", model.name);
435 let create_input = format_ident!("{}CreateInput", model.name);
436 let update_input = format_ident!("{}UpdateInput", model.name);
437 let _order_by = format_ident!("{}OrderByInput", model.name);
438
439 quote! {
440 pub struct #actions_name<'a> {
441 client: &'a DatabaseClient,
442 }
443
444 impl<'a> #actions_name<'a> {
445 pub fn new(client: &'a DatabaseClient) -> Self { Self { client } }
446
447 pub fn find_unique(&self, r#where: filter::#where_unique) -> FindUniqueQuery<'a> {
448 FindUniqueQuery { client: self.client, r#where }
449 }
450
451 pub fn find_first(&self, r#where: filter::#where_input) -> FindFirstQuery<'a> {
452 FindFirstQuery { client: self.client, r#where, order_by: vec![] }
453 }
454
455 pub fn find_many(&self, r#where: filter::#where_input) -> FindManyQuery<'a> {
456 FindManyQuery { client: self.client, r#where, order_by: vec![], skip: None, take: None }
457 }
458
459 pub fn create(&self, data: data::#create_input) -> CreateQuery<'a> {
460 CreateQuery { client: self.client, data }
461 }
462
463 pub fn update(&self, r#where: filter::#where_unique, data: data::#update_input) -> UpdateQuery<'a> {
464 UpdateQuery { client: self.client, r#where, data }
465 }
466
467 pub fn delete(&self, r#where: filter::#where_unique) -> DeleteQuery<'a> {
468 DeleteQuery { client: self.client, r#where }
469 }
470
471 pub fn count(&self, r#where: filter::#where_input) -> CountQuery<'a> {
472 CountQuery { client: self.client, r#where }
473 }
474
475 pub fn create_many(&self, data: Vec<data::#create_input>) -> CreateManyQuery<'a> {
476 CreateManyQuery { client: self.client, data }
477 }
478
479 pub fn update_many(&self, r#where: filter::#where_input, data: data::#update_input) -> UpdateManyQuery<'a> {
480 UpdateManyQuery { client: self.client, r#where, data }
481 }
482
483 pub fn delete_many(&self, r#where: filter::#where_input) -> DeleteManyQuery<'a> {
484 DeleteManyQuery { client: self.client, r#where }
485 }
486 }
487 }
488}
489
490fn gen_query_builders(model: &Model, scalar_fields: &[&Field]) -> TokenStream {
493 let model_ident = format_ident!("{}", model.name);
494 let table_name = &model.db_name;
495 let _where_input = format_ident!("{}WhereInput", model.name);
496 let _where_unique = format_ident!("{}WhereUniqueInput", model.name);
497 let _create_input = format_ident!("{}CreateInput", model.name);
498 let _update_input = format_ident!("{}UpdateInput", model.name);
499 let order_by = format_ident!("{}OrderByInput", model.name);
500 let db_bounds = collect_db_bounds(scalar_fields);
501
502 let select_sql = format!(r#"SELECT * FROM "{}" WHERE 1=1"#, table_name);
503 let count_sql = format!(
504 r#"SELECT COUNT(*) as "count" FROM "{}" WHERE 1=1"#,
505 table_name
506 );
507 let delete_sql = format!(r#"DELETE FROM "{}" WHERE 1=1"#, table_name);
508
509 let insert_code = gen_insert_code(model, scalar_fields, table_name);
510 let update_code = gen_update_code(model, scalar_fields, table_name);
511 let update_many_code = gen_update_many_code(model, scalar_fields, table_name);
512
513 quote! {
514 fn build_order_by<'args, DB: sqlx::Database>(
516 orders: &[order::#order_by],
517 qb: &mut sqlx::QueryBuilder<'args, DB>,
518 ) {
519 if !orders.is_empty() {
520 qb.push(" ORDER BY ");
521 for (i, ob) in orders.iter().enumerate() {
522 if i > 0 { qb.push(", "); }
523 ob.build_order_by(qb);
524 }
525 }
526 }
527
528 fn build_select_query<'args, DB: sqlx::Database>(
530 base_sql: &str,
531 where_input: &filter::#_where_input,
532 orders: &[order::#order_by],
533 take: Option<i64>,
534 skip: Option<i64>,
535 ) -> sqlx::QueryBuilder<'args, DB>
536 where
537 #(#db_bounds,)*
538 {
539 let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
540 where_input.build_where(&mut qb);
541 build_order_by(orders, &mut qb);
542 if let Some(take) = take {
543 qb.push(" LIMIT ");
544 qb.push_bind(take);
545 }
546 if let Some(skip) = skip {
547 qb.push(" OFFSET ");
548 qb.push_bind(skip);
549 }
550 qb
551 }
552
553 fn build_unique_select_query<'args, DB: sqlx::Database>(
555 base_sql: &str,
556 where_unique: &filter::#_where_unique,
557 ) -> sqlx::QueryBuilder<'args, DB>
558 where
559 #(#db_bounds,)*
560 {
561 let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
562 where_unique.build_where(&mut qb);
563 qb.push(" LIMIT 1");
564 qb
565 }
566
567 fn build_delete_query<'args, DB: sqlx::Database>(
569 base_sql: &str,
570 where_unique: &filter::#_where_unique,
571 ) -> sqlx::QueryBuilder<'args, DB>
572 where
573 #(#db_bounds,)*
574 {
575 let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
576 where_unique.build_where(&mut qb);
577 qb.push(" RETURNING *");
578 qb
579 }
580
581 fn build_count_query<'args, DB: sqlx::Database>(
583 base_sql: &str,
584 where_input: &filter::#_where_input,
585 ) -> sqlx::QueryBuilder<'args, DB>
586 where
587 #(#db_bounds,)*
588 {
589 let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
590 where_input.build_where(&mut qb);
591 qb
592 }
593
594 fn build_delete_many_query<'args, DB: sqlx::Database>(
596 base_sql: &str,
597 where_input: &filter::#_where_input,
598 ) -> sqlx::QueryBuilder<'args, DB>
599 where
600 #(#db_bounds,)*
601 {
602 let mut qb = sqlx::QueryBuilder::<DB>::new(base_sql);
603 where_input.build_where(&mut qb);
604 qb
605 }
606
607 pub struct FindUniqueQuery<'a> {
608 client: &'a DatabaseClient,
609 r#where: filter::#_where_unique,
610 }
611
612 impl<'a> FindUniqueQuery<'a> {
613 pub async fn exec(self) -> Result<Option<#model_ident>, FerriormError> {
614 match self.client {
615 DatabaseClient::Postgres(_) => {
616 let qb = build_unique_select_query::<sqlx::Postgres>(#select_sql, &self.r#where);
617 self.client.fetch_optional_pg(qb).await
618 }
619 DatabaseClient::Sqlite(_) => {
620 let qb = build_unique_select_query::<sqlx::Sqlite>(#select_sql, &self.r#where);
621 self.client.fetch_optional_sqlite(qb).await
622 }
623 }
624 }
625 }
626
627 pub struct FindFirstQuery<'a> {
628 client: &'a DatabaseClient,
629 r#where: filter::#_where_input,
630 order_by: Vec<order::#order_by>,
631 }
632
633 impl<'a> FindFirstQuery<'a> {
634 pub fn order_by(mut self, order: order::#order_by) -> Self {
635 self.order_by.push(order);
636 self
637 }
638
639 pub async fn exec(self) -> Result<Option<#model_ident>, FerriormError> {
640 match self.client {
641 DatabaseClient::Postgres(_) => {
642 let qb = build_select_query::<sqlx::Postgres>(#select_sql, &self.r#where, &self.order_by, Some(1), None);
643 self.client.fetch_optional_pg(qb).await
644 }
645 DatabaseClient::Sqlite(_) => {
646 let qb = build_select_query::<sqlx::Sqlite>(#select_sql, &self.r#where, &self.order_by, Some(1), None);
647 self.client.fetch_optional_sqlite(qb).await
648 }
649 }
650 }
651 }
652
653 pub struct FindManyQuery<'a> {
654 client: &'a DatabaseClient,
655 r#where: filter::#_where_input,
656 order_by: Vec<order::#order_by>,
657 skip: Option<i64>,
658 take: Option<i64>,
659 }
660
661 impl<'a> FindManyQuery<'a> {
662 pub fn order_by(mut self, order: order::#order_by) -> Self {
663 self.order_by.push(order);
664 self
665 }
666
667 pub fn skip(mut self, n: i64) -> Self {
668 self.skip = Some(n);
669 self
670 }
671
672 pub fn take(mut self, n: i64) -> Self {
673 self.take = Some(n);
674 self
675 }
676
677 pub async fn exec(self) -> Result<Vec<#model_ident>, FerriormError> {
678 match self.client {
679 DatabaseClient::Postgres(_) => {
680 let qb = build_select_query::<sqlx::Postgres>(#select_sql, &self.r#where, &self.order_by, self.take, self.skip);
681 self.client.fetch_all_pg(qb).await
682 }
683 DatabaseClient::Sqlite(_) => {
684 let qb = build_select_query::<sqlx::Sqlite>(#select_sql, &self.r#where, &self.order_by, self.take, self.skip);
685 self.client.fetch_all_sqlite(qb).await
686 }
687 }
688 }
689 }
690
691 pub struct CreateQuery<'a> {
692 client: &'a DatabaseClient,
693 data: data::#_create_input,
694 }
695
696 impl<'a> CreateQuery<'a> {
697 pub async fn exec(self) -> Result<#model_ident, FerriormError> {
698 let client = self.client;
699 #insert_code
700 }
701 }
702
703 pub struct UpdateQuery<'a> {
704 client: &'a DatabaseClient,
705 r#where: filter::#_where_unique,
706 data: data::#_update_input,
707 }
708
709 impl<'a> UpdateQuery<'a> {
710 pub async fn exec(self) -> Result<#model_ident, FerriormError> {
711 let client = self.client;
712 #update_code
713 }
714 }
715
716 pub struct DeleteQuery<'a> {
717 client: &'a DatabaseClient,
718 r#where: filter::#_where_unique,
719 }
720
721 impl<'a> DeleteQuery<'a> {
722 pub async fn exec(self) -> Result<#model_ident, FerriormError> {
723 match self.client {
724 DatabaseClient::Postgres(_) => {
725 let qb = build_delete_query::<sqlx::Postgres>(#delete_sql, &self.r#where);
726 self.client.fetch_one_pg(qb).await
727 }
728 DatabaseClient::Sqlite(_) => {
729 let qb = build_delete_query::<sqlx::Sqlite>(#delete_sql, &self.r#where);
730 self.client.fetch_one_sqlite(qb).await
731 }
732 }
733 }
734 }
735
736 #[derive(sqlx::FromRow)]
737 struct CountResult { count: i64 }
738
739 pub struct CountQuery<'a> {
740 client: &'a DatabaseClient,
741 r#where: filter::#_where_input,
742 }
743
744 impl<'a> CountQuery<'a> {
745 pub async fn exec(self) -> Result<i64, FerriormError> {
746 let row: CountResult = match self.client {
747 DatabaseClient::Postgres(_) => {
748 let qb = build_count_query::<sqlx::Postgres>(#count_sql, &self.r#where);
749 self.client.fetch_one_pg(qb).await?
750 }
751 DatabaseClient::Sqlite(_) => {
752 let qb = build_count_query::<sqlx::Sqlite>(#count_sql, &self.r#where);
753 self.client.fetch_one_sqlite(qb).await?
754 }
755 };
756 Ok(row.count)
757 }
758 }
759
760 pub struct CreateManyQuery<'a> {
761 client: &'a DatabaseClient,
762 data: Vec<data::#_create_input>,
763 }
764
765 impl<'a> CreateManyQuery<'a> {
766 pub async fn exec(self) -> Result<u64, FerriormError> {
767 if self.data.is_empty() { return Ok(0); }
768 let count = self.data.len() as u64;
769 for item in self.data {
770 CreateQuery { client: self.client, data: item }.exec().await?;
771 }
772 Ok(count)
773 }
774 }
775
776 pub struct UpdateManyQuery<'a> {
777 client: &'a DatabaseClient,
778 r#where: filter::#_where_input,
779 data: data::#_update_input,
780 }
781
782 impl<'a> UpdateManyQuery<'a> {
783 pub async fn exec(self) -> Result<u64, FerriormError> {
784 let client = self.client;
785 #update_many_code
786 }
787 }
788
789 pub struct DeleteManyQuery<'a> {
790 client: &'a DatabaseClient,
791 r#where: filter::#_where_input,
792 }
793
794 impl<'a> DeleteManyQuery<'a> {
795 pub async fn exec(self) -> Result<u64, FerriormError> {
796 match self.client {
797 DatabaseClient::Postgres(_) => {
798 let qb = build_delete_many_query::<sqlx::Postgres>(#delete_sql, &self.r#where);
799 self.client.execute_pg(qb).await
800 }
801 DatabaseClient::Sqlite(_) => {
802 let qb = build_delete_many_query::<sqlx::Sqlite>(#delete_sql, &self.r#where);
803 self.client.execute_sqlite(qb).await
804 }
805 }
806 }
807 }
808 }
809}
810
811fn gen_insert_code(model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
814 let _model_ident = format_ident!("{}", model.name);
815
816 let required: Vec<&Field> = scalar_fields
818 .iter()
819 .copied()
820 .filter(|f| !f.has_default() && !f.is_updated_at)
821 .collect();
822
823 let optional: Vec<&Field> = scalar_fields
825 .iter()
826 .copied()
827 .filter(|f| f.has_default() && !f.is_updated_at)
828 .collect();
829
830 let updated_at: Vec<&Field> = scalar_fields
832 .iter()
833 .copied()
834 .filter(|f| f.is_updated_at)
835 .collect();
836
837 let mut col_pushes = vec![];
839 let mut val_pushes = vec![];
840
841 for f in &required {
843 let db_name = &f.db_name;
844 let field_ident = format_ident!("{}", to_snake_case(&f.name));
845 col_pushes.push(quote! { cols.push(#db_name); });
846 val_pushes.push(quote! { sep.push_bind(self.data.#field_ident); });
847 }
848
849 for f in &optional {
851 let db_name = &f.db_name;
852 let field_ident = format_ident!("{}", to_snake_case(&f.name));
853 let default_expr = gen_default_expr(f);
854
855 col_pushes.push(quote! { cols.push(#db_name); });
856 val_pushes.push(quote! {
857 let val = self.data.#field_ident.unwrap_or_else(|| #default_expr);
858 sep.push_bind(val);
859 });
860 }
861
862 for f in &updated_at {
864 let db_name = &f.db_name;
865 col_pushes.push(quote! { cols.push(#db_name); });
866 val_pushes.push(quote! { sep.push_bind(chrono::Utc::now()); });
867 }
868
869 let insert_start = format!(r#"INSERT INTO "{}""#, table_name);
870
871 quote! {
874 macro_rules! build_insert {
876 ($qb_type:ty) => {{
877 let mut cols: Vec<&str> = Vec::new();
878 #(#col_pushes)*
879
880 let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#insert_start);
881 qb.push(" (");
882 for (i, col) in cols.iter().enumerate() {
883 if i > 0 { qb.push(", "); }
884 qb.push("\"");
885 qb.push(*col);
886 qb.push("\"");
887 }
888 qb.push(") VALUES (");
889 {
890 let mut sep = qb.separated(", ");
891 #(#val_pushes)*
892 }
893 qb.push(") RETURNING *");
894 qb
895 }};
896 }
897
898 match client {
899 DatabaseClient::Postgres(_) => {
900 let qb = build_insert!(sqlx::Postgres);
901 client.fetch_one_pg(qb).await
902 }
903 DatabaseClient::Sqlite(_) => {
904 let qb = build_insert!(sqlx::Sqlite);
905 client.fetch_one_sqlite(qb).await
906 }
907 }
908 }
909}
910
911fn gen_default_expr(field: &Field) -> TokenStream {
913 use ferriorm_core::ast::DefaultValue;
914
915 match &field.default {
916 Some(DefaultValue::Uuid) => quote! { uuid::Uuid::new_v4().to_string() },
917 Some(DefaultValue::Cuid) => quote! { uuid::Uuid::new_v4().to_string() }, Some(DefaultValue::Now) => quote! { chrono::Utc::now() },
919 Some(DefaultValue::AutoIncrement) => quote! { 0 }, Some(DefaultValue::Literal(lit)) => {
921 use ferriorm_core::ast::LiteralValue;
922 match lit {
923 LiteralValue::String(s) => quote! { #s.to_string() },
924 LiteralValue::Int(i) => quote! { #i },
925 LiteralValue::Float(f) => quote! { #f },
926 LiteralValue::Bool(b) => quote! { #b },
927 }
928 }
929 Some(DefaultValue::EnumVariant(v)) => {
930 let variant = format_ident!("{}", v);
932 if let FieldKind::Enum(enum_name) = &field.field_type {
933 let enum_ident = format_ident!("{}", enum_name);
934 quote! { super::enums::#enum_ident::#variant }
935 } else {
936 quote! { Default::default() }
937 }
938 }
939 None => quote! { Default::default() },
940 }
941}
942
943fn gen_update_code(model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
946 let _model_ident = format_ident!("{}", model.name);
947
948 let updatable: Vec<&Field> = scalar_fields
950 .iter()
951 .copied()
952 .filter(|f| !f.is_id && !f.is_updated_at)
953 .collect();
954
955 let updated_at: Vec<&Field> = scalar_fields
956 .iter()
957 .copied()
958 .filter(|f| f.is_updated_at)
959 .collect();
960
961 let update_start = format!(r#"UPDATE "{}" SET "#, table_name);
962
963 let set_arms: Vec<TokenStream> = updatable
965 .iter()
966 .map(|f| {
967 let field_ident = format_ident!("{}", to_snake_case(&f.name));
968 let db_name = &f.db_name;
969 quote! {
970 if let Some(SetValue::Set(v)) = self.data.#field_ident {
971 if !first_set { qb.push(", "); }
972 first_set = false;
973 qb.push(concat!("\"", #db_name, "\" = "));
974 qb.push_bind(v);
975 }
976 }
977 })
978 .collect();
979
980 let updated_at_arms: Vec<TokenStream> = updated_at
981 .iter()
982 .map(|f| {
983 let db_name = &f.db_name;
984 quote! {
985 if !first_set { qb.push(", "); }
986 first_set = false;
987 qb.push(concat!("\"", #db_name, "\" = "));
988 qb.push_bind(chrono::Utc::now());
989 }
990 })
991 .collect();
992
993 quote! {
996 macro_rules! build_update {
997 ($qb_type:ty) => {{
998 let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#update_start);
999 let mut first_set = true;
1000 #(#set_arms)*
1001 #(#updated_at_arms)*
1002
1003 if first_set {
1004 return Err(FerriormError::Query("No fields to update".into()));
1005 }
1006
1007 qb.push(" WHERE 1=1");
1008 self.r#where.build_where(&mut qb);
1009 qb.push(" RETURNING *");
1010 qb
1011 }};
1012 }
1013
1014 match client {
1015 DatabaseClient::Postgres(_) => {
1016 let qb = build_update!(sqlx::Postgres);
1017 client.fetch_one_pg(qb).await
1018 }
1019 DatabaseClient::Sqlite(_) => {
1020 let qb = build_update!(sqlx::Sqlite);
1021 client.fetch_one_sqlite(qb).await
1022 }
1023 }
1024 }
1025}
1026
1027fn gen_update_many_code(_model: &Model, scalar_fields: &[&Field], table_name: &str) -> TokenStream {
1030 let updatable: Vec<&Field> = scalar_fields
1032 .iter()
1033 .copied()
1034 .filter(|f| !f.is_id && !f.is_updated_at)
1035 .collect();
1036
1037 let updated_at: Vec<&Field> = scalar_fields
1038 .iter()
1039 .copied()
1040 .filter(|f| f.is_updated_at)
1041 .collect();
1042
1043 let update_start = format!(r#"UPDATE "{}" SET "#, table_name);
1044
1045 let set_arms: Vec<TokenStream> = updatable
1047 .iter()
1048 .map(|f| {
1049 let field_ident = format_ident!("{}", to_snake_case(&f.name));
1050 let db_name = &f.db_name;
1051 quote! {
1052 if let Some(SetValue::Set(v)) = self.data.#field_ident {
1053 if !first_set { qb.push(", "); }
1054 first_set = false;
1055 qb.push(concat!("\"", #db_name, "\" = "));
1056 qb.push_bind(v);
1057 }
1058 }
1059 })
1060 .collect();
1061
1062 let updated_at_arms: Vec<TokenStream> = updated_at
1063 .iter()
1064 .map(|f| {
1065 let db_name = &f.db_name;
1066 quote! {
1067 if !first_set { qb.push(", "); }
1068 first_set = false;
1069 qb.push(concat!("\"", #db_name, "\" = "));
1070 qb.push_bind(chrono::Utc::now());
1071 }
1072 })
1073 .collect();
1074
1075 quote! {
1076 macro_rules! build_update_many {
1077 ($qb_type:ty) => {{
1078 let mut qb = sqlx::QueryBuilder::<$qb_type>::new(#update_start);
1079 let mut first_set = true;
1080 #(#set_arms)*
1081 #(#updated_at_arms)*
1082
1083 if first_set {
1084 return Ok(0);
1085 }
1086
1087 qb.push(" WHERE 1=1");
1088 self.r#where.build_where(&mut qb);
1089 qb
1090 }};
1091 }
1092
1093 match client {
1094 DatabaseClient::Postgres(_) => {
1095 let qb = build_update_many!(sqlx::Postgres);
1096 client.execute_pg(qb).await
1097 }
1098 DatabaseClient::Sqlite(_) => {
1099 let qb = build_update_many!(sqlx::Sqlite);
1100 client.execute_sqlite(qb).await
1101 }
1102 }
1103 }
1104}