1use crate::error::Result;
4use std::collections::{HashMap, HashSet};
5use std::fs;
6use std::path::Path;
7use tracing::debug;
8
9use crate::config::CodegenConfig;
10use crate::parser::{ColumnMetadata, TableMetadata};
11
12use super::naming::{
13 escape_field_name, generate_delete_by_method_name, generate_find_by_list_method_name,
14 generate_find_by_method_name, generate_update_by_method_name, pluralize, to_struct_name,
15};
16use super::type_resolver::TypeResolver;
17
// Ranking of lookup sources. When two sources describe the same column list,
// the one with the smaller value is kept (see `collect_method_signatures`).

/// Primary key lookups outrank everything else.
const PRIORITY_PRIMARY_KEY: u8 = 1;
/// Unique index lookups come right after the primary key.
const PRIORITY_UNIQUE_INDEX: u8 = PRIORITY_PRIMARY_KEY + 1;
/// Non-unique index lookups.
const PRIORITY_NON_UNIQUE_INDEX: u8 = PRIORITY_UNIQUE_INDEX + 1;
/// Foreign key lookups rank last.
const PRIORITY_FOREIGN_KEY: u8 = PRIORITY_NON_UNIQUE_INDEX + 1;
23
/// A candidate `find_by_*` lookup derived from a primary key, an index, or a
/// foreign key.
#[derive(Clone, Debug)]
struct MethodSignature {
    /// Column names the lookup filters on, in key/index order.
    columns: Vec<String>,
    /// Name of the generated Rust function.
    method_name: String,
    /// Smaller values win when several sources share the same column list.
    priority: u8,
    /// `true` yields `Option<_>` (at most one row), `false` yields `Vec<_>`.
    is_unique: bool,
    /// Origin tag: "PRIMARY_KEY", "UNIQUE_INDEX", "NON_UNIQUE_INDEX" or "FOREIGN_KEY".
    source: String,
}
33
34impl MethodSignature {
35 fn new(columns: Vec<String>, priority: u8, is_unique: bool, source: &str) -> Self {
36 let method_name = generate_find_by_method_name(&columns);
37 Self {
38 columns,
39 method_name,
40 priority,
41 is_unique,
42 source: source.to_string(),
43 }
44 }
45}
46
47pub fn generate_daos(tables: &[TableMetadata], config: &CodegenConfig) -> Result<()> {
49 let output_dir = &config.output_dao_dir;
50 fs::create_dir_all(output_dir)?;
51
52 let mut mod_content = String::new();
54 mod_content.push_str("// Generated DAO functions\n\n");
55
56 for table in tables {
57 let file_name = heck::AsSnakeCase(&table.name).to_string();
58 mod_content.push_str(&format!("pub mod {};\n", file_name));
59 }
60
61 let mod_path = output_dir.join("mod.rs");
62 fs::write(&mod_path, mod_content)?;
63
64 for table in tables {
66 generate_dao_file(table, output_dir, &config.models_module)?;
67 }
68
69 Ok(())
70}
71
72fn generate_dao_file(table: &TableMetadata, output_dir: &Path, models_module: &str) -> Result<()> {
74 let struct_name = to_struct_name(&table.name);
75 let file_name = format!("{}.rs", heck::AsSnakeCase(&table.name));
76 debug!("Generating DAO for {} -> {}", struct_name, file_name);
77
78 let mut code = String::new();
79
80 let column_map: HashMap<&str, &ColumnMetadata> =
82 table.columns.iter().map(|c| (c.name.as_str(), c)).collect();
83
84 let mut needs_chrono = false;
86 let mut needs_decimal = false;
87
88 for col in &table.columns {
89 let rust_type = TypeResolver::resolve(col, &table.name);
90 if rust_type.needs_chrono() {
91 needs_chrono = true;
92 }
93 if rust_type.needs_decimal() {
94 needs_decimal = true;
95 }
96 }
97
98 let has_enums = table.columns.iter().any(|c| c.is_enum());
100
101 code.push_str("use rdbi::{Pool, Query, Result};\n");
103 code.push_str(&format!("use crate::{}::{};\n", models_module, struct_name));
104
105 if has_enums {
106 for col in &table.columns {
108 if col.is_enum() {
109 let enum_name = super::naming::to_enum_name(&table.name, &col.name);
110 code.push_str(&format!("use crate::{}::{};\n", models_module, enum_name));
111 }
112 }
113 }
114
115 if needs_chrono {
116 code.push_str("#[allow(unused_imports)]\n");
117 code.push_str("use chrono::{NaiveDate, NaiveDateTime, NaiveTime};\n");
118 }
119 if needs_decimal {
120 code.push_str("#[allow(unused_imports)]\n");
121 code.push_str("use rust_decimal::Decimal;\n");
122 }
123
124 code.push('\n');
125
126 let select_columns = build_select_columns(table);
128
129 code.push_str(&generate_find_all(table, &struct_name, &select_columns));
131
132 code.push_str(&generate_count_all(table));
134
135 if let Some(pk) = &table.primary_key {
137 code.push_str(&generate_pk_methods(
138 table,
139 pk,
140 &column_map,
141 &struct_name,
142 &select_columns,
143 ));
144 }
145
146 code.push_str(&generate_insert_methods(table, &struct_name, &column_map));
148
149 code.push_str(&generate_insert_plain_method(table, &column_map));
151
152 code.push_str(&generate_insert_all_method(table, &struct_name));
154
155 code.push_str(&generate_upsert_method(table, &struct_name));
157
158 if table.primary_key.is_some() {
160 code.push_str(&generate_update_methods(table, &struct_name, &column_map));
161 code.push_str(&generate_update_plain_method(table, &column_map));
163 }
164
165 let signatures = collect_method_signatures(table);
167 for sig in signatures.values() {
168 if sig.source == "PRIMARY_KEY" {
170 continue;
171 }
172 code.push_str(&generate_find_by_method(
173 table,
174 sig,
175 &column_map,
176 &struct_name,
177 &select_columns,
178 ));
179 }
180
181 code.push_str(&generate_find_by_list_methods(
183 table,
184 &column_map,
185 &struct_name,
186 &select_columns,
187 ));
188
189 code.push_str(&generate_composite_enum_list_methods(
191 table,
192 &column_map,
193 &struct_name,
194 &select_columns,
195 ));
196
197 code.push_str(&generate_pagination_methods(
199 table,
200 &struct_name,
201 &select_columns,
202 models_module,
203 ));
204
205 let file_path = output_dir.join(&file_name);
206 fs::write(&file_path, code)?;
207 Ok(())
208}
209
210fn build_select_columns(table: &TableMetadata) -> String {
212 table
213 .columns
214 .iter()
215 .map(|c| format!("`{}`", c.name))
216 .collect::<Vec<_>>()
217 .join(", ")
218}
219
/// Builds an equality WHERE clause with one positional placeholder per column,
/// e.g. `` `a` = ? AND `b` = ? ``. Returns an empty string for no columns.
fn build_where_clause(columns: &[String]) -> String {
    let mut clause = String::new();
    for column in columns {
        if !clause.is_empty() {
            clause.push_str(" AND ");
        }
        clause.push_str(&format!("`{}` = ?", column));
    }
    clause
}
228
229fn build_params(
231 columns: &[String],
232 column_map: &HashMap<&str, &ColumnMetadata>,
233 table_name: &str,
234) -> String {
235 columns
236 .iter()
237 .map(|c| {
238 let col = column_map.get(c.as_str()).unwrap();
239 let rust_type = TypeResolver::resolve(col, table_name);
240 let param_type = rust_type.to_param_type_string();
241 format!("{}: {}", escape_field_name(c), param_type)
242 })
243 .collect::<Vec<_>>()
244 .join(", ")
245}
246
247fn generate_bind_section(columns: &[String]) -> String {
249 columns
250 .iter()
251 .map(|c| format!(" .bind({})", escape_field_name(c)))
252 .collect::<Vec<_>>()
253 .join("\n")
254}
255
256fn generate_find_all(table: &TableMetadata, struct_name: &str, select_columns: &str) -> String {
258 format!(
259 r#"/// Find all records
260pub async fn find_all<P: Pool>(pool: &P) -> Result<Vec<{struct_name}>> {{
261 Query::new("SELECT {select_columns} FROM `{table_name}`")
262 .fetch_all(pool)
263 .await
264}}
265
266"#,
267 struct_name = struct_name,
268 select_columns = select_columns,
269 table_name = table.name,
270 )
271}
272
273fn generate_count_all(table: &TableMetadata) -> String {
275 format!(
276 r#"/// Count all records
277pub async fn count_all<P: Pool>(pool: &P) -> Result<i64> {{
278 Query::new("SELECT COUNT(*) FROM `{table_name}`")
279 .fetch_scalar(pool)
280 .await
281}}
282
283"#,
284 table_name = table.name,
285 )
286}
287
288fn generate_pk_methods(
290 table: &TableMetadata,
291 pk: &crate::parser::PrimaryKey,
292 column_map: &HashMap<&str, &ColumnMetadata>,
293 struct_name: &str,
294 select_columns: &str,
295) -> String {
296 let mut code = String::new();
297
298 let method_name = generate_find_by_method_name(&pk.columns);
299 let params = build_params(&pk.columns, column_map, &table.name);
300 let where_clause = build_where_clause(&pk.columns);
301 let bind_section = generate_bind_section(&pk.columns);
302
303 code.push_str(&format!(
305 r#"/// Find by primary key
306pub async fn {method_name}<P: Pool>(pool: &P, {params}) -> Result<Option<{struct_name}>> {{
307 Query::new("SELECT {select_columns} FROM `{table_name}` WHERE {where_clause}")
308{bind_section}
309 .fetch_optional(pool)
310 .await
311}}
312
313"#,
314 method_name = method_name,
315 params = params,
316 struct_name = struct_name,
317 select_columns = select_columns,
318 table_name = table.name,
319 where_clause = where_clause,
320 bind_section = bind_section,
321 ));
322
323 let delete_method = generate_delete_by_method_name(&pk.columns);
325 code.push_str(&format!(
326 r#"/// Delete by primary key
327pub async fn {delete_method}<P: Pool>(pool: &P, {params}) -> Result<u64> {{
328 Query::new("DELETE FROM `{table_name}` WHERE {where_clause}")
329{bind_section}
330 .execute(pool)
331 .await
332 .map(|r| r.rows_affected)
333}}
334
335"#,
336 delete_method = delete_method,
337 params = params,
338 table_name = table.name,
339 where_clause = where_clause,
340 bind_section = bind_section,
341 ));
342
343 code
344}
345
346fn generate_insert_methods(
348 table: &TableMetadata,
349 struct_name: &str,
350 _column_map: &HashMap<&str, &ColumnMetadata>,
351) -> String {
352 let mut code = String::new();
353
354 let insert_columns: Vec<&ColumnMetadata> = table
356 .columns
357 .iter()
358 .filter(|c| !c.is_auto_increment)
359 .collect();
360
361 if insert_columns.is_empty() {
362 return code;
363 }
364
365 let column_list = insert_columns
366 .iter()
367 .map(|c| format!("`{}`", c.name))
368 .collect::<Vec<_>>()
369 .join(", ");
370
371 let placeholders = insert_columns
372 .iter()
373 .map(|_| "?")
374 .collect::<Vec<_>>()
375 .join(", ");
376
377 let bind_fields = insert_columns
378 .iter()
379 .map(|c| {
380 let field = escape_field_name(&c.name);
381 let rust_type = TypeResolver::resolve(c, &table.name);
382 if rust_type.is_copy() {
383 format!(" .bind(entity.{})", field)
384 } else {
385 format!(" .bind(&entity.{})", field)
386 }
387 })
388 .collect::<Vec<_>>()
389 .join("\n");
390
391 code.push_str(&format!(
393 r#"/// Insert a new record
394pub async fn insert<P: Pool>(pool: &P, entity: &{struct_name}) -> Result<u64> {{
395 Query::new("INSERT INTO `{table_name}` ({column_list}) VALUES ({placeholders})")
396{bind_fields}
397 .execute(pool)
398 .await
399 .map(|r| r.last_insert_id.unwrap_or(0))
400}}
401
402"#,
403 struct_name = struct_name,
404 table_name = table.name,
405 column_list = column_list,
406 placeholders = placeholders,
407 bind_fields = bind_fields,
408 ));
409
410 code
411}
412
413fn generate_insert_plain_method(
415 table: &TableMetadata,
416 column_map: &HashMap<&str, &ColumnMetadata>,
417) -> String {
418 let insert_columns: Vec<&ColumnMetadata> = table
420 .columns
421 .iter()
422 .filter(|c| !c.is_auto_increment)
423 .collect();
424
425 if insert_columns.is_empty() {
426 return String::new();
427 }
428
429 let column_names: Vec<String> = insert_columns.iter().map(|c| c.name.clone()).collect();
430 let params = build_params(&column_names, column_map, &table.name);
431
432 let column_list = insert_columns
433 .iter()
434 .map(|c| format!("`{}`", c.name))
435 .collect::<Vec<_>>()
436 .join(", ");
437
438 let placeholders = insert_columns
439 .iter()
440 .map(|_| "?")
441 .collect::<Vec<_>>()
442 .join(", ");
443
444 let bind_section = generate_bind_section(&column_names);
445
446 format!(
447 r#"/// Insert a new record with individual parameters
448#[allow(clippy::too_many_arguments)]
449pub async fn insert_plain<P: Pool>(pool: &P, {params}) -> Result<u64> {{
450 Query::new("INSERT INTO `{table_name}` ({column_list}) VALUES ({placeholders})")
451{bind_section}
452 .execute(pool)
453 .await
454 .map(|r| r.last_insert_id.unwrap_or(0))
455}}
456
457"#,
458 params = params,
459 table_name = table.name,
460 column_list = column_list,
461 placeholders = placeholders,
462 bind_section = bind_section,
463 )
464}
465
466fn generate_insert_all_method(table: &TableMetadata, struct_name: &str) -> String {
468 let insert_columns: Vec<&ColumnMetadata> = table
470 .columns
471 .iter()
472 .filter(|c| !c.is_auto_increment)
473 .collect();
474
475 if insert_columns.is_empty() {
476 return String::new();
477 }
478
479 format!(
480 r#"/// Insert multiple records in a single batch
481pub async fn insert_all<P: Pool>(pool: &P, entities: &[{struct_name}]) -> Result<u64> {{
482 rdbi::BatchInsert::new("{table_name}", entities)
483 .execute(pool)
484 .await
485 .map(|r| r.rows_affected)
486}}
487
488"#,
489 struct_name = struct_name,
490 table_name = table.name,
491 )
492}
493
494fn has_unique_index(table: &TableMetadata) -> bool {
496 table.indexes.iter().any(|idx| idx.unique)
497}
498
499fn generate_upsert_method(table: &TableMetadata, struct_name: &str) -> String {
501 if table.primary_key.is_none() && !has_unique_index(table) {
503 return String::new();
504 }
505
506 let insert_columns: Vec<&ColumnMetadata> = table
508 .columns
509 .iter()
510 .filter(|c| !c.is_auto_increment)
511 .collect();
512
513 if insert_columns.is_empty() {
514 return String::new();
515 }
516
517 let pk_columns: HashSet<&str> = table
519 .primary_key
520 .as_ref()
521 .map(|pk| pk.columns.iter().map(|s| s.as_str()).collect())
522 .unwrap_or_default();
523
524 let update_columns: Vec<&ColumnMetadata> = insert_columns
526 .iter()
527 .filter(|c| !pk_columns.contains(c.name.as_str()))
528 .copied()
529 .collect();
530
531 if update_columns.is_empty() {
533 return String::new();
534 }
535
536 let column_list = insert_columns
537 .iter()
538 .map(|c| format!("`{}`", c.name))
539 .collect::<Vec<_>>()
540 .join(", ");
541
542 let placeholders = insert_columns
543 .iter()
544 .map(|_| "?")
545 .collect::<Vec<_>>()
546 .join(", ");
547
548 let update_clause = update_columns
549 .iter()
550 .map(|c| format!("`{name}` = VALUES(`{name}`)", name = c.name))
551 .collect::<Vec<_>>()
552 .join(", ");
553
554 let bind_fields = insert_columns
555 .iter()
556 .map(|c| {
557 let field = escape_field_name(&c.name);
558 let rust_type = TypeResolver::resolve(c, &table.name);
559 if rust_type.is_copy() {
560 format!(" .bind(entity.{})", field)
561 } else {
562 format!(" .bind(&entity.{})", field)
563 }
564 })
565 .collect::<Vec<_>>()
566 .join("\n");
567
568 format!(
569 r#"/// Upsert a record (insert or update on duplicate key)
570/// Returns rows_affected: 1 if inserted, 2 if updated
571pub async fn upsert<P: Pool>(pool: &P, entity: &{struct_name}) -> Result<u64> {{
572 Query::new("INSERT INTO `{table_name}` ({column_list}) VALUES ({placeholders}) \
573 ON DUPLICATE KEY UPDATE {update_clause}")
574{bind_fields}
575 .execute(pool)
576 .await
577 .map(|r| r.rows_affected)
578}}
579
580"#,
581 struct_name = struct_name,
582 table_name = table.name,
583 column_list = column_list,
584 placeholders = placeholders,
585 update_clause = update_clause,
586 bind_fields = bind_fields,
587 )
588}
589
590fn generate_update_methods(
592 table: &TableMetadata,
593 struct_name: &str,
594 column_map: &HashMap<&str, &ColumnMetadata>,
595) -> String {
596 let mut code = String::new();
597
598 let pk = table.primary_key.as_ref().unwrap();
599
600 let update_columns: Vec<&ColumnMetadata> = table
602 .columns
603 .iter()
604 .filter(|c| !pk.columns.contains(&c.name))
605 .collect();
606
607 if update_columns.is_empty() {
608 return code;
609 }
610
611 let set_clause = update_columns
612 .iter()
613 .map(|c| format!("`{}` = ?", c.name))
614 .collect::<Vec<_>>()
615 .join(", ");
616
617 let where_clause = build_where_clause(&pk.columns);
618
619 let bind_fields: Vec<String> = update_columns
621 .iter()
622 .map(|c| {
623 let field = escape_field_name(&c.name);
624 let rust_type = TypeResolver::resolve(c, &table.name);
625 if rust_type.is_copy() {
626 format!(" .bind(entity.{})", field)
627 } else {
628 format!(" .bind(&entity.{})", field)
629 }
630 })
631 .chain(pk.columns.iter().map(|c| {
632 let field = escape_field_name(c);
633 let col = column_map.get(c.as_str()).unwrap();
634 let rust_type = TypeResolver::resolve(col, &table.name);
635 if rust_type.is_copy() {
636 format!(" .bind(entity.{})", field)
637 } else {
638 format!(" .bind(&entity.{})", field)
639 }
640 }))
641 .collect();
642
643 code.push_str(&format!(
645 r#"/// Update a record by primary key
646pub async fn update<P: Pool>(pool: &P, entity: &{struct_name}) -> Result<u64> {{
647 Query::new("UPDATE `{table_name}` SET {set_clause} WHERE {where_clause}")
648{bind_fields}
649 .execute(pool)
650 .await
651 .map(|r| r.rows_affected)
652}}
653
654"#,
655 struct_name = struct_name,
656 table_name = table.name,
657 set_clause = set_clause,
658 where_clause = where_clause,
659 bind_fields = bind_fields.join("\n"),
660 ));
661
662 code
663}
664
665fn generate_update_plain_method(
667 table: &TableMetadata,
668 column_map: &HashMap<&str, &ColumnMetadata>,
669) -> String {
670 let pk = table.primary_key.as_ref().unwrap();
671
672 let update_columns: Vec<&ColumnMetadata> = table
674 .columns
675 .iter()
676 .filter(|c| !pk.columns.contains(&c.name))
677 .collect();
678
679 if update_columns.is_empty() {
680 return String::new();
681 }
682
683 let method_name = generate_update_by_method_name(&pk.columns);
685
686 let pk_params = build_params(&pk.columns, column_map, &table.name);
688 let update_column_names: Vec<String> = update_columns.iter().map(|c| c.name.clone()).collect();
689 let update_params = build_params(&update_column_names, column_map, &table.name);
690 let all_params = format!("{}, {}", pk_params, update_params);
691
692 let set_clause = update_columns
693 .iter()
694 .map(|c| format!("`{}` = ?", c.name))
695 .collect::<Vec<_>>()
696 .join(", ");
697
698 let where_clause = build_where_clause(&pk.columns);
699
700 let bind_section_update = generate_bind_section(&update_column_names);
702 let bind_section_pk = generate_bind_section(&pk.columns);
703
704 format!(
705 r#"/// Update a record by primary key with individual parameters
706#[allow(clippy::too_many_arguments)]
707pub async fn {method_name}<P: Pool>(pool: &P, {all_params}) -> Result<u64> {{
708 Query::new("UPDATE `{table_name}` SET {set_clause} WHERE {where_clause}")
709{bind_section_update}
710{bind_section_pk}
711 .execute(pool)
712 .await
713 .map(|r| r.rows_affected)
714}}
715
716"#,
717 method_name = method_name,
718 all_params = all_params,
719 table_name = table.name,
720 set_clause = set_clause,
721 where_clause = where_clause,
722 bind_section_update = bind_section_update,
723 bind_section_pk = bind_section_pk,
724 )
725}
726
727fn collect_method_signatures(table: &TableMetadata) -> HashMap<Vec<String>, MethodSignature> {
729 let mut signatures: HashMap<Vec<String>, MethodSignature> = HashMap::new();
730
731 if let Some(pk) = &table.primary_key {
733 let sig = MethodSignature::new(
734 pk.columns.clone(),
735 PRIORITY_PRIMARY_KEY,
736 true,
737 "PRIMARY_KEY",
738 );
739 signatures.insert(pk.columns.clone(), sig);
740 }
741
742 for index in &table.indexes {
744 let priority = if index.unique {
745 PRIORITY_UNIQUE_INDEX
746 } else {
747 PRIORITY_NON_UNIQUE_INDEX
748 };
749 let source = if index.unique {
750 "UNIQUE_INDEX"
751 } else {
752 "NON_UNIQUE_INDEX"
753 };
754 let sig = MethodSignature::new(index.columns.clone(), priority, index.unique, source);
755
756 if let Some(existing) = signatures.get(&index.columns) {
758 if sig.priority < existing.priority {
759 signatures.insert(index.columns.clone(), sig);
760 }
761 } else {
762 signatures.insert(index.columns.clone(), sig);
763 }
764 }
765
766 for fk in &table.foreign_keys {
768 let columns = vec![fk.column_name.clone()];
769 let sig = MethodSignature::new(columns.clone(), PRIORITY_FOREIGN_KEY, false, "FOREIGN_KEY");
770
771 signatures.entry(columns).or_insert(sig);
772 }
773
774 signatures
775}
776
777fn generate_find_by_method(
779 table: &TableMetadata,
780 sig: &MethodSignature,
781 column_map: &HashMap<&str, &ColumnMetadata>,
782 struct_name: &str,
783 select_columns: &str,
784) -> String {
785 let params = build_params(&sig.columns, column_map, &table.name);
786
787 let (return_type, fetch_method) = if sig.is_unique {
788 (format!("Option<{}>", struct_name), "fetch_optional")
789 } else {
790 (format!("Vec<{}>", struct_name), "fetch_all")
791 };
792
793 let return_desc = if sig.is_unique {
794 "Option (unique)"
795 } else {
796 "Vec (non-unique)"
797 };
798
799 let has_nullable = sig.columns.iter().any(|c| {
801 column_map
802 .get(c.as_str())
803 .map(|col| col.nullable)
804 .unwrap_or(false)
805 });
806
807 if has_nullable {
808 generate_find_by_method_nullable(
810 table,
811 sig,
812 column_map,
813 select_columns,
814 ¶ms,
815 &return_type,
816 fetch_method,
817 return_desc,
818 )
819 } else {
820 let where_clause = build_where_clause(&sig.columns);
822 let bind_section = generate_bind_section(&sig.columns);
823
824 format!(
825 r#"/// Find by {source}: returns {return_desc}
826pub async fn {method_name}<P: Pool>(pool: &P, {params}) -> Result<{return_type}> {{
827 Query::new("SELECT {select_columns} FROM `{table_name}` WHERE {where_clause}")
828{bind_section}
829 .{fetch_method}(pool)
830 .await
831}}
832
833"#,
834 source = sig.source.to_lowercase().replace('_', " "),
835 return_desc = return_desc,
836 method_name = sig.method_name,
837 params = params,
838 return_type = return_type,
839 select_columns = select_columns,
840 table_name = table.name,
841 where_clause = where_clause,
842 bind_section = bind_section,
843 fetch_method = fetch_method,
844 )
845 }
846}
847
848#[allow(clippy::too_many_arguments)]
850fn generate_find_by_method_nullable(
851 table: &TableMetadata,
852 sig: &MethodSignature,
853 column_map: &HashMap<&str, &ColumnMetadata>,
854 select_columns: &str,
855 params: &str,
856 return_type: &str,
857 fetch_method: &str,
858 return_desc: &str,
859) -> String {
860 let mut where_parts = Vec::new();
862 let mut bind_parts = Vec::new();
863
864 for col in &sig.columns {
865 let col_meta = column_map.get(col.as_str()).unwrap();
866 let field_name = escape_field_name(col);
867
868 if col_meta.nullable {
869 where_parts.push(format!(
871 r#"if {field}.is_some() {{ "`{col}` = ?" }} else {{ "`{col}` IS NULL" }}"#,
872 field = field_name,
873 col = col,
874 ));
875 bind_parts.push(format!(
876 r#"if let Some(v) = {field}.as_ref() {{ query = query.bind(v); }}"#,
877 field = field_name,
878 ));
879 } else {
880 where_parts.push(format!(r#""`{}` = ?""#, col));
881 bind_parts.push(format!(r#"query = query.bind({});"#, field_name,));
882 }
883 }
884
885 let where_expr = if where_parts.len() == 1 {
886 where_parts[0].clone()
887 } else {
888 let parts = where_parts
890 .iter()
891 .map(|p| format!("({})", p))
892 .collect::<Vec<_>>()
893 .join(", ");
894 format!("vec![{}].join(\" AND \")", parts)
895 };
896
897 let bind_code = bind_parts.join("\n ");
898
899 format!(
900 r#"/// Find by {source}: returns {return_desc}
901pub async fn {method_name}<P: Pool>(pool: &P, {params}) -> Result<{return_type}> {{
902 let where_clause = {where_expr};
903 let sql = format!("SELECT {select_columns} FROM `{table_name}` WHERE {{}}", where_clause);
904 let mut query = rdbi::DynamicQuery::new(sql);
905 {bind_code}
906 query.{fetch_method}(pool).await
907}}
908
909"#,
910 source = sig.source.to_lowercase().replace('_', " "),
911 return_desc = return_desc,
912 method_name = sig.method_name,
913 params = params,
914 return_type = return_type,
915 select_columns = select_columns,
916 table_name = table.name,
917 where_expr = where_expr,
918 bind_code = bind_code,
919 fetch_method = fetch_method,
920 )
921}
922
923fn generate_find_by_list_methods(
925 table: &TableMetadata,
926 column_map: &HashMap<&str, &ColumnMetadata>,
927 struct_name: &str,
928 select_columns: &str,
929) -> String {
930 let mut code = String::new();
931 let mut processed: HashSet<String> = HashSet::new();
932
933 if let Some(pk) = &table.primary_key {
935 if pk.columns.len() == 1 {
936 let col = &pk.columns[0];
937 code.push_str(&generate_single_find_by_list(
938 table,
939 col,
940 column_map,
941 struct_name,
942 select_columns,
943 ));
944 processed.insert(col.clone());
945 }
946 }
947
948 for index in &table.indexes {
950 if index.columns.len() == 1 {
951 let col = &index.columns[0];
952 if !processed.contains(col) {
953 code.push_str(&generate_single_find_by_list(
954 table,
955 col,
956 column_map,
957 struct_name,
958 select_columns,
959 ));
960 processed.insert(col.clone());
961 }
962 }
963 }
964
965 code
966}
967
968fn generate_single_find_by_list(
970 table: &TableMetadata,
971 column_name: &str,
972 column_map: &HashMap<&str, &ColumnMetadata>,
973 struct_name: &str,
974 select_columns: &str,
975) -> String {
976 let method_name = generate_find_by_list_method_name(column_name);
977 let param_name = pluralize(&escape_field_name(column_name));
978 let column = column_map.get(column_name).unwrap();
979 let rust_type = TypeResolver::resolve(column, &table.name);
980
981 let inner_type = rust_type.inner_type().to_type_string();
983
984 let column_name_plural = pluralize(column_name);
985 format!(
986 r#"/// Find by list of {column_name_plural} (IN clause)
987pub async fn {method_name}<P: Pool>(pool: &P, {param_name}: &[{inner_type}]) -> Result<Vec<{struct_name}>> {{
988 if {param_name}.is_empty() {{
989 return Ok(Vec::new());
990 }}
991 let placeholders = {param_name}.iter().map(|_| "?").collect::<Vec<_>>().join(",");
992 let query = format!(
993 "SELECT {select_columns} FROM `{table_name}` WHERE `{column_name}` IN ({{}})",
994 placeholders
995 );
996 rdbi::DynamicQuery::new(query)
997 .bind_all({param_name})
998 .fetch_all(pool)
999 .await
1000}}
1001
1002"#,
1003 column_name_plural = column_name_plural,
1004 column_name = column_name,
1005 method_name = method_name,
1006 param_name = param_name,
1007 inner_type = inner_type,
1008 struct_name = struct_name,
1009 select_columns = select_columns,
1010 table_name = table.name,
1011 )
1012}
1013
1014fn generate_composite_enum_list_methods(
1017 table: &TableMetadata,
1018 column_map: &HashMap<&str, &ColumnMetadata>,
1019 struct_name: &str,
1020 select_columns: &str,
1021) -> String {
1022 let mut code = String::new();
1023
1024 for index in &table.indexes {
1025 if index.columns.len() <= 1 {
1027 continue;
1028 }
1029
1030 let enum_columns: HashSet<&str> = index
1032 .columns
1033 .iter()
1034 .filter(|col_name| {
1035 column_map
1036 .get(col_name.as_str())
1037 .map(|col| col.is_enum())
1038 .unwrap_or(false)
1039 })
1040 .map(|s| s.as_str())
1041 .collect();
1042
1043 if enum_columns.is_empty() {
1045 continue;
1046 }
1047
1048 let first_column = &index.columns[0];
1050 if enum_columns.contains(first_column.as_str()) {
1051 continue;
1052 }
1053
1054 code.push_str(&generate_composite_enum_list_method(
1055 table,
1056 &index.columns,
1057 &enum_columns,
1058 column_map,
1059 struct_name,
1060 select_columns,
1061 ));
1062 }
1063
1064 code
1065}
1066
1067fn generate_composite_enum_list_method(
1069 table: &TableMetadata,
1070 columns: &[String],
1071 enum_columns: &HashSet<&str>,
1072 column_map: &HashMap<&str, &ColumnMetadata>,
1073 struct_name: &str,
1074 select_columns: &str,
1075) -> String {
1076 let method_name = generate_composite_enum_method_name(columns, enum_columns);
1078
1079 let mut params_parts = Vec::new();
1081
1082 for col_name in columns {
1083 let col = column_map.get(col_name.as_str()).unwrap();
1084 let rust_type = TypeResolver::resolve(col, &table.name);
1085 let is_enum = enum_columns.contains(col_name.as_str());
1086
1087 if is_enum {
1088 let param_name = pluralize(&escape_field_name(col_name));
1090 let inner_type = rust_type.inner_type().to_type_string();
1091 params_parts.push(format!("{}: &[{}]", param_name, inner_type));
1092 } else {
1093 let param_name = escape_field_name(col_name);
1095 let param_type = rust_type.to_param_type_string();
1096 params_parts.push(format!("{}: {}", param_name, param_type));
1097 }
1098 }
1099
1100 let params = params_parts.join(", ");
1101
1102 let where_clause_static: Vec<String> = columns
1104 .iter()
1105 .map(|col_name| {
1106 if enum_columns.contains(col_name.as_str()) {
1107 format!("`{}` IN ({{}})", col_name) } else {
1109 format!("`{}` = ?", col_name)
1110 }
1111 })
1112 .collect();
1113
1114 let mut bind_code = String::new();
1116
1117 for col_name in columns {
1119 if !enum_columns.contains(col_name.as_str()) {
1120 let param_name = escape_field_name(col_name);
1121 bind_code.push_str(&format!(" .bind({})\n", param_name));
1122 }
1123 }
1124
1125 for col_name in columns {
1127 if enum_columns.contains(col_name.as_str()) {
1128 let param_name = pluralize(&escape_field_name(col_name));
1129 bind_code.push_str(&format!(" .bind_all({})\n", param_name));
1130 }
1131 }
1132
1133 let column_desc: Vec<String> = columns
1135 .iter()
1136 .map(|col| {
1137 if enum_columns.contains(col.as_str()) {
1138 pluralize(col)
1139 } else {
1140 col.clone()
1141 }
1142 })
1143 .collect();
1144
1145 let enum_col_names: Vec<&str> = columns
1147 .iter()
1148 .filter(|c| enum_columns.contains(c.as_str()))
1149 .map(|s| s.as_str())
1150 .collect();
1151
1152 let in_clause_builders: Vec<String> = enum_col_names
1154 .iter()
1155 .map(|col| {
1156 let param_name = pluralize(&escape_field_name(col));
1157 format!(
1158 "{param_name}.iter().map(|_| \"?\").collect::<Vec<_>>().join(\",\")",
1159 param_name = param_name
1160 )
1161 })
1162 .collect();
1163
1164 let format_args = in_clause_builders.join(", ");
1166
1167 format!(
1168 r#"/// Find by {column_desc} (composite index with IN clause for enum columns)
1169pub async fn {method_name}<P: Pool>(pool: &P, {params}) -> Result<Vec<{struct_name}>> {{
1170 // Check for empty enum lists
1171{empty_checks}
1172 // Build IN clause placeholders for enum columns
1173 let where_clause = format!("{where_template}", {format_args});
1174 let query = format!(
1175 "SELECT {select_columns} FROM `{table_name}` WHERE {{}}",
1176 where_clause
1177 );
1178 rdbi::DynamicQuery::new(query)
1179{bind_code} .fetch_all(pool)
1180 .await
1181}}
1182
1183"#,
1184 column_desc = column_desc.join(" and "),
1185 method_name = method_name,
1186 params = params,
1187 struct_name = struct_name,
1188 select_columns = select_columns,
1189 table_name = table.name,
1190 where_template = where_clause_static.join(" AND "),
1191 format_args = format_args,
1192 bind_code = bind_code,
1193 empty_checks = generate_empty_checks(columns, enum_columns),
1194 )
1195}
1196
1197fn generate_empty_checks(columns: &[String], enum_columns: &HashSet<&str>) -> String {
1199 let mut checks = String::new();
1200 for col_name in columns {
1201 if enum_columns.contains(col_name.as_str()) {
1202 let param_name = pluralize(&escape_field_name(col_name));
1203 checks.push_str(&format!(
1204 " if {}.is_empty() {{ return Ok(Vec::new()); }}\n",
1205 param_name
1206 ));
1207 }
1208 }
1209 checks
1210}
1211
1212fn generate_composite_enum_method_name(columns: &[String], enum_columns: &HashSet<&str>) -> String {
1215 let mut parts = Vec::new();
1216 for col in columns {
1217 if enum_columns.contains(col.as_str()) {
1218 parts.push(pluralize(col));
1219 } else {
1220 parts.push(col.clone());
1221 }
1222 }
1223 generate_find_by_method_name(&parts)
1224}
1225
1226fn generate_pagination_methods(
1228 table: &TableMetadata,
1229 struct_name: &str,
1230 select_columns: &str,
1231 models_module: &str,
1232) -> String {
1233 let sort_by_enum = format!("{}SortBy", struct_name);
1234
1235 format!(
1236 r#"/// Find all records with pagination and sorting
1237pub async fn find_all_paginated<P: Pool>(
1238 pool: &P,
1239 limit: i32,
1240 offset: i32,
1241 sort_by: crate::{models_module}::{sort_by_enum},
1242 sort_dir: crate::{models_module}::SortDirection,
1243) -> Result<Vec<{struct_name}>> {{
1244 let order_clause = format!("{{}} {{}}", sort_by.as_sql(), sort_dir.as_sql());
1245 let query = format!(
1246 "SELECT {select_columns} FROM `{table_name}` ORDER BY {{}} LIMIT ? OFFSET ?",
1247 order_clause
1248 );
1249 rdbi::DynamicQuery::new(query)
1250 .bind(limit)
1251 .bind(offset)
1252 .fetch_all(pool)
1253 .await
1254}}
1255
1256/// Get paginated result with total count
1257pub async fn get_paginated_result<P: Pool>(
1258 pool: &P,
1259 page_size: i32,
1260 current_page: i32,
1261 sort_by: crate::{models_module}::{sort_by_enum},
1262 sort_dir: crate::{models_module}::SortDirection,
1263) -> Result<crate::{models_module}::PaginatedResult<{struct_name}>> {{
1264 let page_size = page_size.max(1);
1265 let current_page = current_page.max(1);
1266 let offset = (current_page - 1) * page_size;
1267
1268 let total_count = count_all(pool).await?;
1269 let items = find_all_paginated(pool, page_size, offset, sort_by, sort_dir).await?;
1270
1271 Ok(crate::{models_module}::PaginatedResult::new(
1272 items,
1273 total_count,
1274 current_page,
1275 page_size,
1276 ))
1277}}
1278
1279"#,
1280 struct_name = struct_name,
1281 select_columns = select_columns,
1282 table_name = table.name,
1283 models_module = models_module,
1284 sort_by_enum = sort_by_enum,
1285 )
1286}
1287
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::{IndexMetadata, PrimaryKey};

    /// Builds a non-nullable, signed column with no default value and no comment.
    fn column(
        name: &str,
        data_type: &str,
        is_auto_increment: bool,
        enum_values: Option<Vec<String>>,
    ) -> ColumnMetadata {
        ColumnMetadata {
            name: name.to_string(),
            data_type: data_type.to_string(),
            nullable: false,
            default_value: None,
            is_auto_increment,
            is_unsigned: false,
            enum_values,
            comment: None,
        }
    }

    /// Builds an index over a single column.
    fn single_column_index(name: &str, column: &str, unique: bool) -> IndexMetadata {
        IndexMetadata {
            name: name.to_string(),
            columns: vec![column.to_string()],
            unique,
        }
    }

    /// Fixture describing a `users` table with these keys and indexes:
    /// - an auto-increment `id` primary key;
    /// - a unique index on `email`;
    /// - a non-unique index on the `status` ENUM column.
    fn make_table() -> TableMetadata {
        let status_values = vec!["ACTIVE".to_string(), "INACTIVE".to_string()];

        TableMetadata {
            name: "users".to_string(),
            comment: None,
            columns: vec![
                column("id", "BIGINT", true, None),
                column("email", "VARCHAR(255)", false, None),
                column("status", "ENUM", false, Some(status_values)),
            ],
            indexes: vec![
                single_column_index("email_unique", "email", true),
                single_column_index("idx_status", "status", false),
            ],
            foreign_keys: vec![],
            primary_key: Some(PrimaryKey {
                columns: vec!["id".to_string()],
            }),
        }
    }

    #[test]
    fn test_collect_method_signatures() {
        let table = make_table();
        let sigs = collect_method_signatures(&table);

        assert_eq!(sigs.len(), 3);

        // (column, expected uniqueness, expected priority)
        let expectations = [
            ("id", true, PRIORITY_PRIMARY_KEY),
            ("email", true, PRIORITY_UNIQUE_INDEX),
            ("status", false, PRIORITY_NON_UNIQUE_INDEX),
        ];
        for &(column_name, expected_unique, expected_priority) in &expectations {
            let sig = sigs.get(&vec![column_name.to_string()]).unwrap();
            assert_eq!(sig.is_unique, expected_unique);
            assert_eq!(sig.priority, expected_priority);
        }
    }

    #[test]
    fn test_build_select_columns() {
        let table = make_table();
        let cols = build_select_columns(&table);

        for quoted in ["`id`", "`email`", "`status`"].iter() {
            assert!(cols.contains(quoted));
        }
    }

    #[test]
    fn test_build_where_clause() {
        let single = build_where_clause(&["id".to_string()]);
        assert_eq!(single, "`id` = ?");

        let composite = build_where_clause(&["user_id".to_string(), "role_id".to_string()]);
        assert_eq!(composite, "`user_id` = ? AND `role_id` = ?");
    }

    #[test]
    fn test_generate_upsert_method() {
        let table = make_table();
        let code = generate_upsert_method(&table, "Users");

        let required = [
            "pub async fn upsert",
            "ON DUPLICATE KEY UPDATE",
            "`email` = VALUES(`email`)",
            "`status` = VALUES(`status`)",
        ];
        for snippet in required.iter() {
            assert!(code.contains(snippet));
        }

        // The primary key column must not be rewritten on conflict.
        assert!(!code.contains("`id` = VALUES(`id`)"));
    }

    #[test]
    fn test_generate_upsert_method_no_pk() {
        let mut table = make_table();
        table.primary_key = None;
        table.indexes.clear();

        assert!(generate_upsert_method(&table, "Users").is_empty());
    }

    #[test]
    fn test_generate_insert_all_method() {
        let table = make_table();
        let code = generate_insert_all_method(&table, "Users");

        for snippet in ["pub async fn insert_all", "rdbi::BatchInsert::new"].iter() {
            assert!(code.contains(snippet));
        }
    }

    #[test]
    fn test_generate_pagination_methods() {
        let table = make_table();
        let select_columns = build_select_columns(&table);
        let code = generate_pagination_methods(&table, "Users", &select_columns, "models");

        let required = [
            "pub async fn find_all_paginated",
            "limit: i32",
            "offset: i32",
            "UsersSortBy",
            "SortDirection",
            "pub async fn get_paginated_result",
            "PaginatedResult<Users>",
        ];
        for snippet in required.iter() {
            assert!(code.contains(snippet));
        }
    }
}