1use crate::error::Result;
4use std::collections::{HashMap, HashSet};
5use std::fs;
6use std::path::Path;
7use tracing::debug;
8
9use crate::config::CodegenConfig;
10use crate::parser::{ColumnMetadata, TableMetadata};
11
12use super::naming::{
13 escape_field_name, generate_delete_by_method_name, generate_find_by_list_method_name,
14 generate_find_by_method_name, generate_update_by_method_name, pluralize, to_struct_name,
15};
16use super::type_resolver::TypeResolver;
17
/// Priority of a lookup derived from the primary key.
/// When several sources describe the same column list, the lowest value wins.
const PRIORITY_PRIMARY_KEY: u8 = 1;
/// Priority of a lookup derived from a UNIQUE index.
const PRIORITY_UNIQUE_INDEX: u8 = 2;
/// Priority of a lookup derived from a non-unique index.
const PRIORITY_NON_UNIQUE_INDEX: u8 = 3;
/// Priority of a lookup derived from a single-column foreign key.
const PRIORITY_FOREIGN_KEY: u8 = 4;
23
24#[derive(Debug, Clone)]
26struct MethodSignature {
27 columns: Vec<String>,
28 method_name: String,
29 priority: u8,
30 is_unique: bool,
31 source: String,
32}
33
34impl MethodSignature {
35 fn new(columns: Vec<String>, priority: u8, is_unique: bool, source: &str) -> Self {
36 let method_name = generate_find_by_method_name(&columns);
37 Self {
38 columns,
39 method_name,
40 priority,
41 is_unique,
42 source: source.to_string(),
43 }
44 }
45}
46
47pub fn generate_daos(tables: &[TableMetadata], config: &CodegenConfig) -> Result<()> {
49 let output_dir = &config.output_dao_dir;
50 fs::create_dir_all(output_dir)?;
51
52 let mut mod_content = String::new();
54 mod_content.push_str("// Generated DAO functions\n\n");
55
56 for table in tables {
57 let file_name = heck::AsSnakeCase(&table.name).to_string();
58 mod_content.push_str(&format!("pub mod {};\n", file_name));
59 }
60
61 let mod_path = output_dir.join("mod.rs");
62 fs::write(&mod_path, mod_content)?;
63 super::format_file(&mod_path);
64
65 for table in tables {
67 generate_dao_file(table, output_dir, &config.models_module)?;
68 }
69
70 Ok(())
71}
72
73fn generate_dao_file(table: &TableMetadata, output_dir: &Path, models_module: &str) -> Result<()> {
75 let struct_name = to_struct_name(&table.name);
76 let file_name = format!("{}.rs", heck::AsSnakeCase(&table.name));
77 debug!("Generating DAO for {} -> {}", struct_name, file_name);
78
79 let mut code = String::new();
80
81 let column_map: HashMap<&str, &ColumnMetadata> =
83 table.columns.iter().map(|c| (c.name.as_str(), c)).collect();
84
85 let mut needs_chrono = false;
87 let mut needs_decimal = false;
88
89 for col in &table.columns {
90 let rust_type = TypeResolver::resolve(col, &table.name);
91 if rust_type.needs_chrono() {
92 needs_chrono = true;
93 }
94 if rust_type.needs_decimal() {
95 needs_decimal = true;
96 }
97 }
98
99 let has_enums = table.columns.iter().any(|c| c.is_enum());
101
102 code.push_str("use rdbi::{Pool, Query, Result};\n");
104 code.push_str(&format!("use crate::{}::{};\n", models_module, struct_name));
105
106 if has_enums {
107 for col in &table.columns {
109 if col.is_enum() {
110 let enum_name = super::naming::to_enum_name(&table.name, &col.name);
111 code.push_str(&format!("use crate::{}::{};\n", models_module, enum_name));
112 }
113 }
114 }
115
116 if needs_chrono {
117 code.push_str("#[allow(unused_imports)]\n");
118 code.push_str("use chrono::{NaiveDate, NaiveDateTime, NaiveTime};\n");
119 }
120 if needs_decimal {
121 code.push_str("#[allow(unused_imports)]\n");
122 code.push_str("use rust_decimal::Decimal;\n");
123 }
124
125 code.push('\n');
126
127 let select_columns = build_select_columns(table);
129
130 code.push_str(&generate_find_all(table, &struct_name, &select_columns));
132
133 code.push_str(&generate_count_all(table));
135
136 if let Some(pk) = &table.primary_key {
138 code.push_str(&generate_pk_methods(
139 table,
140 pk,
141 &column_map,
142 &struct_name,
143 &select_columns,
144 ));
145 }
146
147 code.push_str(&generate_insert_methods(table, &struct_name, &column_map));
149
150 code.push_str(&generate_insert_plain_method(table, &column_map));
152
153 code.push_str(&generate_insert_all_method(table, &struct_name));
155
156 code.push_str(&generate_upsert_method(table, &struct_name));
158
159 if table.primary_key.is_some() {
161 code.push_str(&generate_update_methods(table, &struct_name, &column_map));
162 code.push_str(&generate_update_plain_method(table, &column_map));
164 }
165
166 let signatures = collect_method_signatures(table);
168 for sig in signatures.values() {
169 if sig.source == "PRIMARY_KEY" {
171 continue;
172 }
173 code.push_str(&generate_find_by_method(
174 table,
175 sig,
176 &column_map,
177 &struct_name,
178 &select_columns,
179 ));
180 }
181
182 code.push_str(&generate_find_by_list_methods(
184 table,
185 &column_map,
186 &struct_name,
187 &select_columns,
188 ));
189
190 code.push_str(&generate_composite_enum_list_methods(
192 table,
193 &column_map,
194 &struct_name,
195 &select_columns,
196 ));
197
198 code.push_str(&generate_pagination_methods(
200 table,
201 &struct_name,
202 &select_columns,
203 models_module,
204 ));
205
206 let file_path = output_dir.join(&file_name);
207 fs::write(&file_path, code)?;
208 super::format_file(&file_path);
209 Ok(())
210}
211
212fn build_select_columns(table: &TableMetadata) -> String {
214 table
215 .columns
216 .iter()
217 .map(|c| format!("`{}`", c.name))
218 .collect::<Vec<_>>()
219 .join(", ")
220}
221
/// Builds an AND-joined equality predicate with one `?` placeholder per
/// column, e.g. `` `a` = ? AND `b` = ? ``. Returns "" for no columns.
fn build_where_clause(columns: &[String]) -> String {
    let mut clause = String::new();
    for (i, name) in columns.iter().enumerate() {
        if i != 0 {
            clause.push_str(" AND ");
        }
        clause.push_str(&format!("`{}` = ?", name));
    }
    clause
}
230
231fn build_params(
233 columns: &[String],
234 column_map: &HashMap<&str, &ColumnMetadata>,
235 table_name: &str,
236) -> String {
237 columns
238 .iter()
239 .map(|c| {
240 let col = column_map.get(c.as_str()).unwrap();
241 let rust_type = TypeResolver::resolve(col, table_name);
242 let param_type = rust_type.to_param_type_string();
243 format!("{}: {}", escape_field_name(c), param_type)
244 })
245 .collect::<Vec<_>>()
246 .join(", ")
247}
248
249fn generate_bind_section(columns: &[String]) -> String {
251 columns
252 .iter()
253 .map(|c| format!(" .bind({})", escape_field_name(c)))
254 .collect::<Vec<_>>()
255 .join("\n")
256}
257
258fn generate_find_all(table: &TableMetadata, struct_name: &str, select_columns: &str) -> String {
260 format!(
261 r#"/// Find all records
262pub async fn find_all<P: Pool>(pool: &P) -> Result<Vec<{struct_name}>> {{
263 Query::new("SELECT {select_columns} FROM `{table_name}`")
264 .fetch_all(pool)
265 .await
266}}
267
268"#,
269 struct_name = struct_name,
270 select_columns = select_columns,
271 table_name = table.name,
272 )
273}
274
275fn generate_count_all(table: &TableMetadata) -> String {
277 format!(
278 r#"/// Count all records
279pub async fn count_all<P: Pool>(pool: &P) -> Result<i64> {{
280 Query::new("SELECT COUNT(*) FROM `{table_name}`")
281 .fetch_scalar(pool)
282 .await
283}}
284
285"#,
286 table_name = table.name,
287 )
288}
289
290fn generate_pk_methods(
292 table: &TableMetadata,
293 pk: &crate::parser::PrimaryKey,
294 column_map: &HashMap<&str, &ColumnMetadata>,
295 struct_name: &str,
296 select_columns: &str,
297) -> String {
298 let mut code = String::new();
299
300 let method_name = generate_find_by_method_name(&pk.columns);
301 let params = build_params(&pk.columns, column_map, &table.name);
302 let where_clause = build_where_clause(&pk.columns);
303 let bind_section = generate_bind_section(&pk.columns);
304
305 code.push_str(&format!(
307 r#"/// Find by primary key
308pub async fn {method_name}<P: Pool>(pool: &P, {params}) -> Result<Option<{struct_name}>> {{
309 Query::new("SELECT {select_columns} FROM `{table_name}` WHERE {where_clause}")
310{bind_section}
311 .fetch_optional(pool)
312 .await
313}}
314
315"#,
316 method_name = method_name,
317 params = params,
318 struct_name = struct_name,
319 select_columns = select_columns,
320 table_name = table.name,
321 where_clause = where_clause,
322 bind_section = bind_section,
323 ));
324
325 let delete_method = generate_delete_by_method_name(&pk.columns);
327 code.push_str(&format!(
328 r#"/// Delete by primary key
329pub async fn {delete_method}<P: Pool>(pool: &P, {params}) -> Result<u64> {{
330 Query::new("DELETE FROM `{table_name}` WHERE {where_clause}")
331{bind_section}
332 .execute(pool)
333 .await
334 .map(|r| r.rows_affected)
335}}
336
337"#,
338 delete_method = delete_method,
339 params = params,
340 table_name = table.name,
341 where_clause = where_clause,
342 bind_section = bind_section,
343 ));
344
345 code
346}
347
348fn generate_insert_methods(
350 table: &TableMetadata,
351 struct_name: &str,
352 _column_map: &HashMap<&str, &ColumnMetadata>,
353) -> String {
354 let mut code = String::new();
355
356 let insert_columns: Vec<&ColumnMetadata> = table
358 .columns
359 .iter()
360 .filter(|c| !c.is_auto_increment)
361 .collect();
362
363 if insert_columns.is_empty() {
364 return code;
365 }
366
367 let column_list = insert_columns
368 .iter()
369 .map(|c| format!("`{}`", c.name))
370 .collect::<Vec<_>>()
371 .join(", ");
372
373 let placeholders = insert_columns
374 .iter()
375 .map(|_| "?")
376 .collect::<Vec<_>>()
377 .join(", ");
378
379 let bind_fields = insert_columns
380 .iter()
381 .map(|c| {
382 let field = escape_field_name(&c.name);
383 let rust_type = TypeResolver::resolve(c, &table.name);
384 if rust_type.is_copy() {
385 format!(" .bind(entity.{})", field)
386 } else {
387 format!(" .bind(&entity.{})", field)
388 }
389 })
390 .collect::<Vec<_>>()
391 .join("\n");
392
393 code.push_str(&format!(
395 r#"/// Insert a new record
396pub async fn insert<P: Pool>(pool: &P, entity: &{struct_name}) -> Result<u64> {{
397 Query::new("INSERT INTO `{table_name}` ({column_list}) VALUES ({placeholders})")
398{bind_fields}
399 .execute(pool)
400 .await
401 .map(|r| r.last_insert_id.unwrap_or(0))
402}}
403
404"#,
405 struct_name = struct_name,
406 table_name = table.name,
407 column_list = column_list,
408 placeholders = placeholders,
409 bind_fields = bind_fields,
410 ));
411
412 code
413}
414
415fn generate_insert_plain_method(
417 table: &TableMetadata,
418 column_map: &HashMap<&str, &ColumnMetadata>,
419) -> String {
420 let insert_columns: Vec<&ColumnMetadata> = table
422 .columns
423 .iter()
424 .filter(|c| !c.is_auto_increment)
425 .collect();
426
427 if insert_columns.is_empty() {
428 return String::new();
429 }
430
431 let column_names: Vec<String> = insert_columns.iter().map(|c| c.name.clone()).collect();
432 let params = build_params(&column_names, column_map, &table.name);
433
434 let column_list = insert_columns
435 .iter()
436 .map(|c| format!("`{}`", c.name))
437 .collect::<Vec<_>>()
438 .join(", ");
439
440 let placeholders = insert_columns
441 .iter()
442 .map(|_| "?")
443 .collect::<Vec<_>>()
444 .join(", ");
445
446 let bind_section = generate_bind_section(&column_names);
447
448 format!(
449 r#"/// Insert a new record with individual parameters
450#[allow(clippy::too_many_arguments)]
451pub async fn insert_plain<P: Pool>(pool: &P, {params}) -> Result<u64> {{
452 Query::new("INSERT INTO `{table_name}` ({column_list}) VALUES ({placeholders})")
453{bind_section}
454 .execute(pool)
455 .await
456 .map(|r| r.last_insert_id.unwrap_or(0))
457}}
458
459"#,
460 params = params,
461 table_name = table.name,
462 column_list = column_list,
463 placeholders = placeholders,
464 bind_section = bind_section,
465 )
466}
467
468fn generate_insert_all_method(table: &TableMetadata, struct_name: &str) -> String {
470 let insert_columns: Vec<&ColumnMetadata> = table
472 .columns
473 .iter()
474 .filter(|c| !c.is_auto_increment)
475 .collect();
476
477 if insert_columns.is_empty() {
478 return String::new();
479 }
480
481 format!(
482 r#"/// Insert multiple records in a single batch
483pub async fn insert_all<P: Pool>(pool: &P, entities: &[{struct_name}]) -> Result<u64> {{
484 rdbi::BatchInsert::new("{table_name}", entities)
485 .execute(pool)
486 .await
487 .map(|r| r.rows_affected)
488}}
489
490"#,
491 struct_name = struct_name,
492 table_name = table.name,
493 )
494}
495
496fn has_unique_index(table: &TableMetadata) -> bool {
498 table.indexes.iter().any(|idx| idx.unique)
499}
500
501fn generate_upsert_method(table: &TableMetadata, struct_name: &str) -> String {
503 if table.primary_key.is_none() && !has_unique_index(table) {
505 return String::new();
506 }
507
508 let insert_columns: Vec<&ColumnMetadata> = table
510 .columns
511 .iter()
512 .filter(|c| !c.is_auto_increment)
513 .collect();
514
515 if insert_columns.is_empty() {
516 return String::new();
517 }
518
519 let pk_columns: HashSet<&str> = table
521 .primary_key
522 .as_ref()
523 .map(|pk| pk.columns.iter().map(|s| s.as_str()).collect())
524 .unwrap_or_default();
525
526 let update_columns: Vec<&ColumnMetadata> = insert_columns
528 .iter()
529 .filter(|c| !pk_columns.contains(c.name.as_str()))
530 .copied()
531 .collect();
532
533 if update_columns.is_empty() {
535 return String::new();
536 }
537
538 let column_list = insert_columns
539 .iter()
540 .map(|c| format!("`{}`", c.name))
541 .collect::<Vec<_>>()
542 .join(", ");
543
544 let placeholders = insert_columns
545 .iter()
546 .map(|_| "?")
547 .collect::<Vec<_>>()
548 .join(", ");
549
550 let update_clause = update_columns
551 .iter()
552 .map(|c| format!("`{name}` = VALUES(`{name}`)", name = c.name))
553 .collect::<Vec<_>>()
554 .join(", ");
555
556 let bind_fields = insert_columns
557 .iter()
558 .map(|c| {
559 let field = escape_field_name(&c.name);
560 let rust_type = TypeResolver::resolve(c, &table.name);
561 if rust_type.is_copy() {
562 format!(" .bind(entity.{})", field)
563 } else {
564 format!(" .bind(&entity.{})", field)
565 }
566 })
567 .collect::<Vec<_>>()
568 .join("\n");
569
570 format!(
571 r#"/// Upsert a record (insert or update on duplicate key)
572/// Returns rows_affected: 1 if inserted, 2 if updated
573pub async fn upsert<P: Pool>(pool: &P, entity: &{struct_name}) -> Result<u64> {{
574 Query::new("INSERT INTO `{table_name}` ({column_list}) VALUES ({placeholders}) \
575 ON DUPLICATE KEY UPDATE {update_clause}")
576{bind_fields}
577 .execute(pool)
578 .await
579 .map(|r| r.rows_affected)
580}}
581
582"#,
583 struct_name = struct_name,
584 table_name = table.name,
585 column_list = column_list,
586 placeholders = placeholders,
587 update_clause = update_clause,
588 bind_fields = bind_fields,
589 )
590}
591
592fn generate_update_methods(
594 table: &TableMetadata,
595 struct_name: &str,
596 column_map: &HashMap<&str, &ColumnMetadata>,
597) -> String {
598 let mut code = String::new();
599
600 let pk = table.primary_key.as_ref().unwrap();
601
602 let update_columns: Vec<&ColumnMetadata> = table
604 .columns
605 .iter()
606 .filter(|c| !pk.columns.contains(&c.name))
607 .collect();
608
609 if update_columns.is_empty() {
610 return code;
611 }
612
613 let set_clause = update_columns
614 .iter()
615 .map(|c| format!("`{}` = ?", c.name))
616 .collect::<Vec<_>>()
617 .join(", ");
618
619 let where_clause = build_where_clause(&pk.columns);
620
621 let bind_fields: Vec<String> = update_columns
623 .iter()
624 .map(|c| {
625 let field = escape_field_name(&c.name);
626 let rust_type = TypeResolver::resolve(c, &table.name);
627 if rust_type.is_copy() {
628 format!(" .bind(entity.{})", field)
629 } else {
630 format!(" .bind(&entity.{})", field)
631 }
632 })
633 .chain(pk.columns.iter().map(|c| {
634 let field = escape_field_name(c);
635 let col = column_map.get(c.as_str()).unwrap();
636 let rust_type = TypeResolver::resolve(col, &table.name);
637 if rust_type.is_copy() {
638 format!(" .bind(entity.{})", field)
639 } else {
640 format!(" .bind(&entity.{})", field)
641 }
642 }))
643 .collect();
644
645 code.push_str(&format!(
647 r#"/// Update a record by primary key
648pub async fn update<P: Pool>(pool: &P, entity: &{struct_name}) -> Result<u64> {{
649 Query::new("UPDATE `{table_name}` SET {set_clause} WHERE {where_clause}")
650{bind_fields}
651 .execute(pool)
652 .await
653 .map(|r| r.rows_affected)
654}}
655
656"#,
657 struct_name = struct_name,
658 table_name = table.name,
659 set_clause = set_clause,
660 where_clause = where_clause,
661 bind_fields = bind_fields.join("\n"),
662 ));
663
664 code
665}
666
667fn generate_update_plain_method(
669 table: &TableMetadata,
670 column_map: &HashMap<&str, &ColumnMetadata>,
671) -> String {
672 let pk = table.primary_key.as_ref().unwrap();
673
674 let update_columns: Vec<&ColumnMetadata> = table
676 .columns
677 .iter()
678 .filter(|c| !pk.columns.contains(&c.name))
679 .collect();
680
681 if update_columns.is_empty() {
682 return String::new();
683 }
684
685 let method_name = generate_update_by_method_name(&pk.columns);
687
688 let pk_params = build_params(&pk.columns, column_map, &table.name);
690 let update_column_names: Vec<String> = update_columns.iter().map(|c| c.name.clone()).collect();
691 let update_params = build_params(&update_column_names, column_map, &table.name);
692 let all_params = format!("{}, {}", pk_params, update_params);
693
694 let set_clause = update_columns
695 .iter()
696 .map(|c| format!("`{}` = ?", c.name))
697 .collect::<Vec<_>>()
698 .join(", ");
699
700 let where_clause = build_where_clause(&pk.columns);
701
702 let bind_section_update = generate_bind_section(&update_column_names);
704 let bind_section_pk = generate_bind_section(&pk.columns);
705
706 format!(
707 r#"/// Update a record by primary key with individual parameters
708#[allow(clippy::too_many_arguments)]
709pub async fn {method_name}<P: Pool>(pool: &P, {all_params}) -> Result<u64> {{
710 Query::new("UPDATE `{table_name}` SET {set_clause} WHERE {where_clause}")
711{bind_section_update}
712{bind_section_pk}
713 .execute(pool)
714 .await
715 .map(|r| r.rows_affected)
716}}
717
718"#,
719 method_name = method_name,
720 all_params = all_params,
721 table_name = table.name,
722 set_clause = set_clause,
723 where_clause = where_clause,
724 bind_section_update = bind_section_update,
725 bind_section_pk = bind_section_pk,
726 )
727}
728
729fn collect_method_signatures(table: &TableMetadata) -> HashMap<Vec<String>, MethodSignature> {
731 let mut signatures: HashMap<Vec<String>, MethodSignature> = HashMap::new();
732
733 if let Some(pk) = &table.primary_key {
735 let sig = MethodSignature::new(
736 pk.columns.clone(),
737 PRIORITY_PRIMARY_KEY,
738 true,
739 "PRIMARY_KEY",
740 );
741 signatures.insert(pk.columns.clone(), sig);
742 }
743
744 for index in &table.indexes {
746 let priority = if index.unique {
747 PRIORITY_UNIQUE_INDEX
748 } else {
749 PRIORITY_NON_UNIQUE_INDEX
750 };
751 let source = if index.unique {
752 "UNIQUE_INDEX"
753 } else {
754 "NON_UNIQUE_INDEX"
755 };
756 let sig = MethodSignature::new(index.columns.clone(), priority, index.unique, source);
757
758 if let Some(existing) = signatures.get(&index.columns) {
760 if sig.priority < existing.priority {
761 signatures.insert(index.columns.clone(), sig);
762 }
763 } else {
764 signatures.insert(index.columns.clone(), sig);
765 }
766 }
767
768 for fk in &table.foreign_keys {
770 let columns = vec![fk.column_name.clone()];
771 let sig = MethodSignature::new(columns.clone(), PRIORITY_FOREIGN_KEY, false, "FOREIGN_KEY");
772
773 signatures.entry(columns).or_insert(sig);
774 }
775
776 signatures
777}
778
779fn generate_find_by_method(
781 table: &TableMetadata,
782 sig: &MethodSignature,
783 column_map: &HashMap<&str, &ColumnMetadata>,
784 struct_name: &str,
785 select_columns: &str,
786) -> String {
787 let params = build_params(&sig.columns, column_map, &table.name);
788
789 let (return_type, fetch_method) = if sig.is_unique {
790 (format!("Option<{}>", struct_name), "fetch_optional")
791 } else {
792 (format!("Vec<{}>", struct_name), "fetch_all")
793 };
794
795 let return_desc = if sig.is_unique {
796 "Option (unique)"
797 } else {
798 "Vec (non-unique)"
799 };
800
801 let has_nullable = sig.columns.iter().any(|c| {
803 column_map
804 .get(c.as_str())
805 .map(|col| col.nullable)
806 .unwrap_or(false)
807 });
808
809 if has_nullable {
810 generate_find_by_method_nullable(
812 table,
813 sig,
814 column_map,
815 select_columns,
816 ¶ms,
817 &return_type,
818 fetch_method,
819 return_desc,
820 )
821 } else {
822 let where_clause = build_where_clause(&sig.columns);
824 let bind_section = generate_bind_section(&sig.columns);
825
826 format!(
827 r#"/// Find by {source}: returns {return_desc}
828pub async fn {method_name}<P: Pool>(pool: &P, {params}) -> Result<{return_type}> {{
829 Query::new("SELECT {select_columns} FROM `{table_name}` WHERE {where_clause}")
830{bind_section}
831 .{fetch_method}(pool)
832 .await
833}}
834
835"#,
836 source = sig.source.to_lowercase().replace('_', " "),
837 return_desc = return_desc,
838 method_name = sig.method_name,
839 params = params,
840 return_type = return_type,
841 select_columns = select_columns,
842 table_name = table.name,
843 where_clause = where_clause,
844 bind_section = bind_section,
845 fetch_method = fetch_method,
846 )
847 }
848}
849
850#[allow(clippy::too_many_arguments)]
852fn generate_find_by_method_nullable(
853 table: &TableMetadata,
854 sig: &MethodSignature,
855 column_map: &HashMap<&str, &ColumnMetadata>,
856 select_columns: &str,
857 params: &str,
858 return_type: &str,
859 fetch_method: &str,
860 return_desc: &str,
861) -> String {
862 let mut where_parts = Vec::new();
864 let mut bind_parts = Vec::new();
865
866 for col in &sig.columns {
867 let col_meta = column_map.get(col.as_str()).unwrap();
868 let field_name = escape_field_name(col);
869
870 if col_meta.nullable {
871 where_parts.push(format!(
873 r#"if {field}.is_some() {{ "`{col}` = ?" }} else {{ "`{col}` IS NULL" }}"#,
874 field = field_name,
875 col = col,
876 ));
877 bind_parts.push(format!(
878 r#"if let Some(v) = {field}.as_ref() {{ query = query.bind(v); }}"#,
879 field = field_name,
880 ));
881 } else {
882 where_parts.push(format!(r#""`{}` = ?""#, col));
883 bind_parts.push(format!(r#"query = query.bind({});"#, field_name,));
884 }
885 }
886
887 let where_expr = if where_parts.len() == 1 {
888 where_parts[0].clone()
889 } else {
890 let parts = where_parts
892 .iter()
893 .map(|p| format!("({})", p))
894 .collect::<Vec<_>>()
895 .join(", ");
896 format!("vec![{}].join(\" AND \")", parts)
897 };
898
899 let bind_code = bind_parts.join("\n ");
900
901 format!(
902 r#"/// Find by {source}: returns {return_desc}
903pub async fn {method_name}<P: Pool>(pool: &P, {params}) -> Result<{return_type}> {{
904 let where_clause = {where_expr};
905 let sql = format!("SELECT {select_columns} FROM `{table_name}` WHERE {{}}", where_clause);
906 let mut query = rdbi::DynamicQuery::new(sql);
907 {bind_code}
908 query.{fetch_method}(pool).await
909}}
910
911"#,
912 source = sig.source.to_lowercase().replace('_', " "),
913 return_desc = return_desc,
914 method_name = sig.method_name,
915 params = params,
916 return_type = return_type,
917 select_columns = select_columns,
918 table_name = table.name,
919 where_expr = where_expr,
920 bind_code = bind_code,
921 fetch_method = fetch_method,
922 )
923}
924
925fn generate_find_by_list_methods(
927 table: &TableMetadata,
928 column_map: &HashMap<&str, &ColumnMetadata>,
929 struct_name: &str,
930 select_columns: &str,
931) -> String {
932 let mut code = String::new();
933 let mut processed: HashSet<String> = HashSet::new();
934
935 if let Some(pk) = &table.primary_key {
937 if pk.columns.len() == 1 {
938 let col = &pk.columns[0];
939 code.push_str(&generate_single_find_by_list(
940 table,
941 col,
942 column_map,
943 struct_name,
944 select_columns,
945 ));
946 processed.insert(col.clone());
947 }
948 }
949
950 for index in &table.indexes {
952 if index.columns.len() == 1 {
953 let col = &index.columns[0];
954 if !processed.contains(col) {
955 code.push_str(&generate_single_find_by_list(
956 table,
957 col,
958 column_map,
959 struct_name,
960 select_columns,
961 ));
962 processed.insert(col.clone());
963 }
964 }
965 }
966
967 code
968}
969
970fn generate_single_find_by_list(
972 table: &TableMetadata,
973 column_name: &str,
974 column_map: &HashMap<&str, &ColumnMetadata>,
975 struct_name: &str,
976 select_columns: &str,
977) -> String {
978 let method_name = generate_find_by_list_method_name(column_name);
979 let param_name = pluralize(&escape_field_name(column_name));
980 let column = column_map.get(column_name).unwrap();
981 let rust_type = TypeResolver::resolve(column, &table.name);
982
983 let inner_type = rust_type.inner_type().to_type_string();
985
986 let column_name_plural = pluralize(column_name);
987 format!(
988 r#"/// Find by list of {column_name_plural} (IN clause)
989pub async fn {method_name}<P: Pool>(pool: &P, {param_name}: &[{inner_type}]) -> Result<Vec<{struct_name}>> {{
990 if {param_name}.is_empty() {{
991 return Ok(Vec::new());
992 }}
993 let placeholders = {param_name}.iter().map(|_| "?").collect::<Vec<_>>().join(",");
994 let query = format!(
995 "SELECT {select_columns} FROM `{table_name}` WHERE `{column_name}` IN ({{}})",
996 placeholders
997 );
998 rdbi::DynamicQuery::new(query)
999 .bind_all({param_name})
1000 .fetch_all(pool)
1001 .await
1002}}
1003
1004"#,
1005 column_name_plural = column_name_plural,
1006 column_name = column_name,
1007 method_name = method_name,
1008 param_name = param_name,
1009 inner_type = inner_type,
1010 struct_name = struct_name,
1011 select_columns = select_columns,
1012 table_name = table.name,
1013 )
1014}
1015
1016fn generate_composite_enum_list_methods(
1019 table: &TableMetadata,
1020 column_map: &HashMap<&str, &ColumnMetadata>,
1021 struct_name: &str,
1022 select_columns: &str,
1023) -> String {
1024 let mut code = String::new();
1025
1026 for index in &table.indexes {
1027 if index.columns.len() <= 1 {
1029 continue;
1030 }
1031
1032 let enum_columns: HashSet<&str> = index
1034 .columns
1035 .iter()
1036 .filter(|col_name| {
1037 column_map
1038 .get(col_name.as_str())
1039 .map(|col| col.is_enum())
1040 .unwrap_or(false)
1041 })
1042 .map(|s| s.as_str())
1043 .collect();
1044
1045 if enum_columns.is_empty() {
1047 continue;
1048 }
1049
1050 let first_column = &index.columns[0];
1052 if enum_columns.contains(first_column.as_str()) {
1053 continue;
1054 }
1055
1056 code.push_str(&generate_composite_enum_list_method(
1057 table,
1058 &index.columns,
1059 &enum_columns,
1060 column_map,
1061 struct_name,
1062 select_columns,
1063 ));
1064 }
1065
1066 code
1067}
1068
1069fn generate_composite_enum_list_method(
1071 table: &TableMetadata,
1072 columns: &[String],
1073 enum_columns: &HashSet<&str>,
1074 column_map: &HashMap<&str, &ColumnMetadata>,
1075 struct_name: &str,
1076 select_columns: &str,
1077) -> String {
1078 let method_name = generate_composite_enum_method_name(columns, enum_columns);
1080
1081 let mut params_parts = Vec::new();
1083
1084 for col_name in columns {
1085 let col = column_map.get(col_name.as_str()).unwrap();
1086 let rust_type = TypeResolver::resolve(col, &table.name);
1087 let is_enum = enum_columns.contains(col_name.as_str());
1088
1089 if is_enum {
1090 let param_name = pluralize(&escape_field_name(col_name));
1092 let inner_type = rust_type.inner_type().to_type_string();
1093 params_parts.push(format!("{}: &[{}]", param_name, inner_type));
1094 } else {
1095 let param_name = escape_field_name(col_name);
1097 let param_type = rust_type.to_param_type_string();
1098 params_parts.push(format!("{}: {}", param_name, param_type));
1099 }
1100 }
1101
1102 let params = params_parts.join(", ");
1103
1104 let where_clause_static: Vec<String> = columns
1106 .iter()
1107 .map(|col_name| {
1108 if enum_columns.contains(col_name.as_str()) {
1109 format!("`{}` IN ({{}})", col_name) } else {
1111 format!("`{}` = ?", col_name)
1112 }
1113 })
1114 .collect();
1115
1116 let mut bind_code = String::new();
1118
1119 for col_name in columns {
1121 if !enum_columns.contains(col_name.as_str()) {
1122 let param_name = escape_field_name(col_name);
1123 bind_code.push_str(&format!(" .bind({})\n", param_name));
1124 }
1125 }
1126
1127 for col_name in columns {
1129 if enum_columns.contains(col_name.as_str()) {
1130 let param_name = pluralize(&escape_field_name(col_name));
1131 bind_code.push_str(&format!(" .bind_all({})\n", param_name));
1132 }
1133 }
1134
1135 let column_desc: Vec<String> = columns
1137 .iter()
1138 .map(|col| {
1139 if enum_columns.contains(col.as_str()) {
1140 pluralize(col)
1141 } else {
1142 col.clone()
1143 }
1144 })
1145 .collect();
1146
1147 let enum_col_names: Vec<&str> = columns
1149 .iter()
1150 .filter(|c| enum_columns.contains(c.as_str()))
1151 .map(|s| s.as_str())
1152 .collect();
1153
1154 let in_clause_builders: Vec<String> = enum_col_names
1156 .iter()
1157 .map(|col| {
1158 let param_name = pluralize(&escape_field_name(col));
1159 format!(
1160 "{param_name}.iter().map(|_| \"?\").collect::<Vec<_>>().join(\",\")",
1161 param_name = param_name
1162 )
1163 })
1164 .collect();
1165
1166 let format_args = in_clause_builders.join(", ");
1168
1169 format!(
1170 r#"/// Find by {column_desc} (composite index with IN clause for enum columns)
1171pub async fn {method_name}<P: Pool>(pool: &P, {params}) -> Result<Vec<{struct_name}>> {{
1172 // Check for empty enum lists
1173{empty_checks}
1174 // Build IN clause placeholders for enum columns
1175 let where_clause = format!("{where_template}", {format_args});
1176 let query = format!(
1177 "SELECT {select_columns} FROM `{table_name}` WHERE {{}}",
1178 where_clause
1179 );
1180 rdbi::DynamicQuery::new(query)
1181{bind_code} .fetch_all(pool)
1182 .await
1183}}
1184
1185"#,
1186 column_desc = column_desc.join(" and "),
1187 method_name = method_name,
1188 params = params,
1189 struct_name = struct_name,
1190 select_columns = select_columns,
1191 table_name = table.name,
1192 where_template = where_clause_static.join(" AND "),
1193 format_args = format_args,
1194 bind_code = bind_code,
1195 empty_checks = generate_empty_checks(columns, enum_columns),
1196 )
1197}
1198
1199fn generate_empty_checks(columns: &[String], enum_columns: &HashSet<&str>) -> String {
1201 let mut checks = String::new();
1202 for col_name in columns {
1203 if enum_columns.contains(col_name.as_str()) {
1204 let param_name = pluralize(&escape_field_name(col_name));
1205 checks.push_str(&format!(
1206 " if {}.is_empty() {{ return Ok(Vec::new()); }}\n",
1207 param_name
1208 ));
1209 }
1210 }
1211 checks
1212}
1213
1214fn generate_composite_enum_method_name(columns: &[String], enum_columns: &HashSet<&str>) -> String {
1217 let mut parts = Vec::new();
1218 for col in columns {
1219 if enum_columns.contains(col.as_str()) {
1220 parts.push(pluralize(col));
1221 } else {
1222 parts.push(col.clone());
1223 }
1224 }
1225 generate_find_by_method_name(&parts)
1226}
1227
1228fn generate_pagination_methods(
1230 table: &TableMetadata,
1231 struct_name: &str,
1232 select_columns: &str,
1233 models_module: &str,
1234) -> String {
1235 let sort_by_enum = format!("{}SortBy", struct_name);
1236
1237 format!(
1238 r#"/// Find all records with pagination and sorting
1239pub async fn find_all_paginated<P: Pool>(
1240 pool: &P,
1241 limit: i32,
1242 offset: i32,
1243 sort_by: crate::{models_module}::{sort_by_enum},
1244 sort_dir: crate::{models_module}::SortDirection,
1245) -> Result<Vec<{struct_name}>> {{
1246 let order_clause = format!("{{}} {{}}", sort_by.as_sql(), sort_dir.as_sql());
1247 let query = format!(
1248 "SELECT {select_columns} FROM `{table_name}` ORDER BY {{}} LIMIT ? OFFSET ?",
1249 order_clause
1250 );
1251 rdbi::DynamicQuery::new(query)
1252 .bind(limit)
1253 .bind(offset)
1254 .fetch_all(pool)
1255 .await
1256}}
1257
1258/// Get paginated result with total count
1259pub async fn get_paginated_result<P: Pool>(
1260 pool: &P,
1261 page_size: i32,
1262 current_page: i32,
1263 sort_by: crate::{models_module}::{sort_by_enum},
1264 sort_dir: crate::{models_module}::SortDirection,
1265) -> Result<crate::{models_module}::PaginatedResult<{struct_name}>> {{
1266 let page_size = page_size.max(1);
1267 let current_page = current_page.max(1);
1268 let offset = (current_page - 1) * page_size;
1269
1270 let total_count = count_all(pool).await?;
1271 let items = find_all_paginated(pool, page_size, offset, sort_by, sort_dir).await?;
1272
1273 Ok(crate::{models_module}::PaginatedResult::new(
1274 items,
1275 total_count,
1276 current_page,
1277 page_size,
1278 ))
1279}}
1280
1281"#,
1282 struct_name = struct_name,
1283 select_columns = select_columns,
1284 table_name = table.name,
1285 models_module = models_module,
1286 sort_by_enum = sort_by_enum,
1287 )
1288}
1289
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::{IndexMetadata, PrimaryKey};

    /// Converts string slices into the owned `Vec<String>` form used by the metadata types.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    /// Builds a non-nullable, signed column with no default value and no comment.
    fn column(
        name: &str,
        data_type: &str,
        is_auto_increment: bool,
        enum_values: Option<&[&str]>,
    ) -> ColumnMetadata {
        ColumnMetadata {
            name: name.to_string(),
            data_type: data_type.to_string(),
            nullable: false,
            default_value: None,
            is_auto_increment,
            is_unsigned: false,
            enum_values: enum_values.map(owned),
            comment: None,
        }
    }

    /// Builds an index named `name` over `columns`.
    fn index(name: &str, columns: &[&str], unique: bool) -> IndexMetadata {
        IndexMetadata {
            name: name.to_string(),
            columns: owned(columns),
            unique,
        }
    }

    /// A `users` table with:
    /// * an auto-increment `id` primary key,
    /// * a unique index on `email`,
    /// * a non-unique index on the `status` enum column.
    fn make_table() -> TableMetadata {
        TableMetadata {
            name: "users".to_string(),
            comment: None,
            columns: vec![
                column("id", "BIGINT", true, None),
                column("email", "VARCHAR(255)", false, None),
                column("status", "ENUM", false, Some(&["ACTIVE", "INACTIVE"][..])),
            ],
            indexes: vec![
                index("email_unique", &["email"], true),
                index("idx_status", &["status"], false),
            ],
            foreign_keys: Vec::new(),
            primary_key: Some(PrimaryKey {
                columns: owned(&["id"]),
            }),
        }
    }

    /// Asserts that every needle occurs somewhere in `code`.
    fn assert_contains_all(code: &str, needles: &[&str]) {
        for needle in needles {
            assert!(
                code.contains(*needle),
                "generated code is missing `{}`",
                needle
            );
        }
    }

    #[test]
    fn test_collect_method_signatures() {
        let table = make_table();
        let sigs = collect_method_signatures(&table);
        assert_eq!(sigs.len(), 3);

        let expected = [
            ("id", true, PRIORITY_PRIMARY_KEY),
            ("email", true, PRIORITY_UNIQUE_INDEX),
            ("status", false, PRIORITY_NON_UNIQUE_INDEX),
        ];
        for &(column_name, unique, priority) in expected.iter() {
            let sig = sigs
                .get(&owned(&[column_name]))
                .unwrap_or_else(|| panic!("no signature for `{}`", column_name));
            assert_eq!(sig.is_unique, unique, "is_unique for `{}`", column_name);
            assert_eq!(sig.priority, priority, "priority for `{}`", column_name);
        }
    }

    #[test]
    fn test_build_select_columns() {
        let table = make_table();
        let cols = build_select_columns(&table);
        assert_contains_all(&cols, &["`id`", "`email`", "`status`"]);
    }

    #[test]
    fn test_build_where_clause() {
        assert_eq!(build_where_clause(&["id".to_string()]), "`id` = ?");
        assert_eq!(
            build_where_clause(&["user_id".to_string(), "role_id".to_string()]),
            "`user_id` = ? AND `role_id` = ?"
        );
    }

    #[test]
    fn test_generate_upsert_method() {
        let table = make_table();
        let code = generate_upsert_method(&table, "Users");

        assert_contains_all(
            &code,
            &[
                "pub async fn upsert",
                "ON DUPLICATE KEY UPDATE",
                "`email` = VALUES(`email`)",
                "`status` = VALUES(`status`)",
            ],
        );
        // The primary key column must not appear in the update list.
        assert!(!code.contains("`id` = VALUES(`id`)"));
    }

    #[test]
    fn test_generate_upsert_method_no_pk() {
        let mut table = make_table();
        table.primary_key = None;
        table.indexes.clear();

        // With neither a primary key nor any index, no upsert is generated.
        assert!(generate_upsert_method(&table, "Users").is_empty());
    }

    #[test]
    fn test_generate_insert_all_method() {
        let table = make_table();
        let code = generate_insert_all_method(&table, "Users");
        assert_contains_all(&code, &["pub async fn insert_all", "rdbi::BatchInsert::new"]);
    }

    #[test]
    fn test_generate_pagination_methods() {
        let table = make_table();
        let select_columns = build_select_columns(&table);
        let code = generate_pagination_methods(&table, "Users", &select_columns, "models");

        assert_contains_all(
            &code,
            &[
                "pub async fn find_all_paginated",
                "limit: i32",
                "offset: i32",
                "UsersSortBy",
                "SortDirection",
                "pub async fn get_paginated_result",
                "PaginatedResult<Users>",
            ],
        );
    }
}