1use crate::error::Result;
4use std::collections::{HashMap, HashSet};
5use std::fs;
6use std::path::Path;
7use tracing::debug;
8
9use crate::config::CodegenConfig;
10use crate::parser::{ColumnMetadata, TableMetadata};
11
12use super::naming::{
13 escape_field_name, generate_delete_by_method_name, generate_find_by_list_method_name,
14 generate_find_by_method_name, generate_update_by_method_name, pluralize, to_struct_name,
15};
16use super::type_resolver::TypeResolver;
17
18const PRIORITY_PRIMARY_KEY: u8 = 1;
20const PRIORITY_UNIQUE_INDEX: u8 = 2;
21const PRIORITY_NON_UNIQUE_INDEX: u8 = 3;
22const PRIORITY_FOREIGN_KEY: u8 = 4;
23
24#[derive(Debug, Clone)]
26struct MethodSignature {
27 columns: Vec<String>,
28 method_name: String,
29 priority: u8,
30 is_unique: bool,
31 source: String,
32}
33
34impl MethodSignature {
35 fn new(columns: Vec<String>, priority: u8, is_unique: bool, source: &str) -> Self {
36 let method_name = generate_find_by_method_name(&columns);
37 Self {
38 columns,
39 method_name,
40 priority,
41 is_unique,
42 source: source.to_string(),
43 }
44 }
45}
46
47pub fn generate_daos(tables: &[TableMetadata], config: &CodegenConfig) -> Result<()> {
49 let output_dir = &config.output_dao_dir;
50 fs::create_dir_all(output_dir)?;
51
52 let mut mod_content = String::new();
54 mod_content.push_str("// Generated DAO functions\n\n");
55
56 for table in tables {
57 let file_name = heck::AsSnakeCase(&table.name).to_string();
58 mod_content.push_str(&format!(
59 "#[allow(dead_code, clippy::all)]\npub mod {};\n",
60 file_name
61 ));
62 }
63
64 let mod_path = output_dir.join("mod.rs");
65 fs::write(&mod_path, mod_content)?;
66
67 for table in tables {
69 generate_dao_file(table, output_dir, &config.models_module)?;
70 }
71
72 Ok(())
73}
74
75fn generate_dao_file(table: &TableMetadata, output_dir: &Path, models_module: &str) -> Result<()> {
77 let struct_name = to_struct_name(&table.name);
78 let file_name = format!("{}.rs", heck::AsSnakeCase(&table.name));
79 debug!("Generating DAO for {} -> {}", struct_name, file_name);
80
81 let mut code = String::new();
82
83 let column_map: HashMap<&str, &ColumnMetadata> =
85 table.columns.iter().map(|c| (c.name.as_str(), c)).collect();
86
87 let mut needs_chrono = false;
89 let mut needs_decimal = false;
90
91 for col in &table.columns {
92 let rust_type = TypeResolver::resolve(col, &table.name);
93 if rust_type.needs_chrono() {
94 needs_chrono = true;
95 }
96 if rust_type.needs_decimal() {
97 needs_decimal = true;
98 }
99 }
100
101 let has_enums = table.columns.iter().any(|c| c.is_enum());
103
104 code.push_str("use rdbi::{Pool, Query, Result};\n");
106 code.push_str(&format!("use crate::{}::{};\n", models_module, struct_name));
107
108 if has_enums {
109 for col in &table.columns {
111 if col.is_enum() {
112 let enum_name = super::naming::to_enum_name(&table.name, &col.name);
113 code.push_str(&format!("use crate::{}::{};\n", models_module, enum_name));
114 }
115 }
116 }
117
118 if needs_chrono {
119 code.push_str("#[allow(unused_imports)]\n");
120 code.push_str("use chrono::{NaiveDate, NaiveDateTime, NaiveTime};\n");
121 }
122 if needs_decimal {
123 code.push_str("#[allow(unused_imports)]\n");
124 code.push_str("use rust_decimal::Decimal;\n");
125 }
126
127 code.push('\n');
128
129 let select_columns = build_select_columns(table);
131
132 code.push_str(&generate_find_all(table, &struct_name, &select_columns));
134
135 code.push_str(&generate_count_all(table));
137
138 if let Some(pk) = &table.primary_key {
140 code.push_str(&generate_pk_methods(
141 table,
142 pk,
143 &column_map,
144 &struct_name,
145 &select_columns,
146 ));
147 }
148
149 code.push_str(&generate_insert_methods(table, &struct_name, &column_map));
151
152 code.push_str(&generate_insert_plain_method(table, &column_map));
154
155 code.push_str(&generate_insert_all_method(table, &struct_name));
157
158 code.push_str(&generate_upsert_method(table, &struct_name));
160
161 if table.primary_key.is_some() {
163 code.push_str(&generate_update_methods(table, &struct_name, &column_map));
164 code.push_str(&generate_update_plain_method(table, &column_map));
166 }
167
168 let signatures = collect_method_signatures(table);
170 for sig in signatures.values() {
171 if sig.source == "PRIMARY_KEY" {
173 continue;
174 }
175 code.push_str(&generate_find_by_method(
176 table,
177 sig,
178 &column_map,
179 &struct_name,
180 &select_columns,
181 ));
182 }
183
184 code.push_str(&generate_find_by_list_methods(
186 table,
187 &column_map,
188 &struct_name,
189 &select_columns,
190 ));
191
192 code.push_str(&generate_composite_enum_list_methods(
194 table,
195 &column_map,
196 &struct_name,
197 &select_columns,
198 ));
199
200 code.push_str(&generate_pagination_methods(
202 table,
203 &struct_name,
204 &select_columns,
205 models_module,
206 ));
207
208 let file_path = output_dir.join(&file_name);
209 fs::write(&file_path, code)?;
210 Ok(())
211}
212
213fn build_select_columns(table: &TableMetadata) -> String {
215 table
216 .columns
217 .iter()
218 .map(|c| format!("`{}`", c.name))
219 .collect::<Vec<_>>()
220 .join(", ")
221}
222
/// Builds an equality filter over `columns` joined with `AND`,
/// e.g. `` `a` = ? AND `b` = ? ``.
///
/// Returns an empty string when `columns` is empty.
fn build_where_clause(columns: &[String]) -> String {
    let conditions: Vec<String> = columns
        .iter()
        .map(|name| ["`", name.as_str(), "` = ?"].concat())
        .collect();
    conditions.join(" AND ")
}
231
232fn build_params(
234 columns: &[String],
235 column_map: &HashMap<&str, &ColumnMetadata>,
236 table_name: &str,
237) -> String {
238 columns
239 .iter()
240 .map(|c| {
241 let col = column_map.get(c.as_str()).unwrap();
242 let rust_type = TypeResolver::resolve(col, table_name);
243 let param_type = rust_type.to_param_type_string();
244 format!("{}: {}", escape_field_name(c), param_type)
245 })
246 .collect::<Vec<_>>()
247 .join(", ")
248}
249
250fn generate_bind_section(columns: &[String]) -> String {
252 columns
253 .iter()
254 .map(|c| format!(" .bind({})", escape_field_name(c)))
255 .collect::<Vec<_>>()
256 .join("\n")
257}
258
259fn generate_find_all(table: &TableMetadata, struct_name: &str, select_columns: &str) -> String {
261 format!(
262 r#"/// Find all records
263pub async fn find_all<P: Pool>(pool: &P) -> Result<Vec<{struct_name}>> {{
264 Query::new("SELECT {select_columns} FROM `{table_name}`")
265 .fetch_all(pool)
266 .await
267}}
268
269"#,
270 struct_name = struct_name,
271 select_columns = select_columns,
272 table_name = table.name,
273 )
274}
275
276fn generate_count_all(table: &TableMetadata) -> String {
278 format!(
279 r#"/// Count all records
280pub async fn count_all<P: Pool>(pool: &P) -> Result<i64> {{
281 Query::new("SELECT COUNT(*) FROM `{table_name}`")
282 .fetch_scalar(pool)
283 .await
284}}
285
286"#,
287 table_name = table.name,
288 )
289}
290
291fn generate_pk_methods(
293 table: &TableMetadata,
294 pk: &crate::parser::PrimaryKey,
295 column_map: &HashMap<&str, &ColumnMetadata>,
296 struct_name: &str,
297 select_columns: &str,
298) -> String {
299 let mut code = String::new();
300
301 let method_name = generate_find_by_method_name(&pk.columns);
302 let params = build_params(&pk.columns, column_map, &table.name);
303 let where_clause = build_where_clause(&pk.columns);
304 let bind_section = generate_bind_section(&pk.columns);
305
306 code.push_str(&format!(
308 r#"/// Find by primary key
309pub async fn {method_name}<P: Pool>(pool: &P, {params}) -> Result<Option<{struct_name}>> {{
310 Query::new("SELECT {select_columns} FROM `{table_name}` WHERE {where_clause}")
311{bind_section}
312 .fetch_optional(pool)
313 .await
314}}
315
316"#,
317 method_name = method_name,
318 params = params,
319 struct_name = struct_name,
320 select_columns = select_columns,
321 table_name = table.name,
322 where_clause = where_clause,
323 bind_section = bind_section,
324 ));
325
326 let delete_method = generate_delete_by_method_name(&pk.columns);
328 code.push_str(&format!(
329 r#"/// Delete by primary key
330pub async fn {delete_method}<P: Pool>(pool: &P, {params}) -> Result<u64> {{
331 Query::new("DELETE FROM `{table_name}` WHERE {where_clause}")
332{bind_section}
333 .execute(pool)
334 .await
335 .map(|r| r.rows_affected)
336}}
337
338"#,
339 delete_method = delete_method,
340 params = params,
341 table_name = table.name,
342 where_clause = where_clause,
343 bind_section = bind_section,
344 ));
345
346 code
347}
348
349fn generate_insert_methods(
351 table: &TableMetadata,
352 struct_name: &str,
353 _column_map: &HashMap<&str, &ColumnMetadata>,
354) -> String {
355 let mut code = String::new();
356
357 let insert_columns: Vec<&ColumnMetadata> = table
359 .columns
360 .iter()
361 .filter(|c| !c.is_auto_increment)
362 .collect();
363
364 if insert_columns.is_empty() {
365 return code;
366 }
367
368 let column_list = insert_columns
369 .iter()
370 .map(|c| format!("`{}`", c.name))
371 .collect::<Vec<_>>()
372 .join(", ");
373
374 let placeholders = insert_columns
375 .iter()
376 .map(|_| "?")
377 .collect::<Vec<_>>()
378 .join(", ");
379
380 let bind_fields = insert_columns
381 .iter()
382 .map(|c| {
383 let field = escape_field_name(&c.name);
384 let rust_type = TypeResolver::resolve(c, &table.name);
385 if rust_type.is_copy() {
386 format!(" .bind(entity.{})", field)
387 } else {
388 format!(" .bind(&entity.{})", field)
389 }
390 })
391 .collect::<Vec<_>>()
392 .join("\n");
393
394 code.push_str(&format!(
396 r#"/// Insert a new record
397pub async fn insert<P: Pool>(pool: &P, entity: &{struct_name}) -> Result<u64> {{
398 Query::new("INSERT INTO `{table_name}` ({column_list}) VALUES ({placeholders})")
399{bind_fields}
400 .execute(pool)
401 .await
402 .map(|r| r.last_insert_id.unwrap_or(0))
403}}
404
405"#,
406 struct_name = struct_name,
407 table_name = table.name,
408 column_list = column_list,
409 placeholders = placeholders,
410 bind_fields = bind_fields,
411 ));
412
413 code
414}
415
416fn generate_insert_plain_method(
418 table: &TableMetadata,
419 column_map: &HashMap<&str, &ColumnMetadata>,
420) -> String {
421 let insert_columns: Vec<&ColumnMetadata> = table
423 .columns
424 .iter()
425 .filter(|c| !c.is_auto_increment)
426 .collect();
427
428 if insert_columns.is_empty() {
429 return String::new();
430 }
431
432 let column_names: Vec<String> = insert_columns.iter().map(|c| c.name.clone()).collect();
433 let params = build_params(&column_names, column_map, &table.name);
434
435 let column_list = insert_columns
436 .iter()
437 .map(|c| format!("`{}`", c.name))
438 .collect::<Vec<_>>()
439 .join(", ");
440
441 let placeholders = insert_columns
442 .iter()
443 .map(|_| "?")
444 .collect::<Vec<_>>()
445 .join(", ");
446
447 let bind_section = generate_bind_section(&column_names);
448
449 format!(
450 r#"/// Insert a new record with individual parameters
451#[allow(clippy::too_many_arguments)]
452pub async fn insert_plain<P: Pool>(pool: &P, {params}) -> Result<u64> {{
453 Query::new("INSERT INTO `{table_name}` ({column_list}) VALUES ({placeholders})")
454{bind_section}
455 .execute(pool)
456 .await
457 .map(|r| r.last_insert_id.unwrap_or(0))
458}}
459
460"#,
461 params = params,
462 table_name = table.name,
463 column_list = column_list,
464 placeholders = placeholders,
465 bind_section = bind_section,
466 )
467}
468
469fn generate_insert_all_method(table: &TableMetadata, struct_name: &str) -> String {
471 let insert_columns: Vec<&ColumnMetadata> = table
473 .columns
474 .iter()
475 .filter(|c| !c.is_auto_increment)
476 .collect();
477
478 if insert_columns.is_empty() {
479 return String::new();
480 }
481
482 format!(
483 r#"/// Insert multiple records in a single batch
484pub async fn insert_all<P: Pool>(pool: &P, entities: &[{struct_name}]) -> Result<u64> {{
485 rdbi::BatchInsert::new("{table_name}", entities)
486 .execute(pool)
487 .await
488 .map(|r| r.rows_affected)
489}}
490
491"#,
492 struct_name = struct_name,
493 table_name = table.name,
494 )
495}
496
497fn has_unique_index(table: &TableMetadata) -> bool {
499 table.indexes.iter().any(|idx| idx.unique)
500}
501
502fn generate_upsert_method(table: &TableMetadata, struct_name: &str) -> String {
504 if table.primary_key.is_none() && !has_unique_index(table) {
506 return String::new();
507 }
508
509 let insert_columns: Vec<&ColumnMetadata> = table
511 .columns
512 .iter()
513 .filter(|c| !c.is_auto_increment)
514 .collect();
515
516 if insert_columns.is_empty() {
517 return String::new();
518 }
519
520 let pk_columns: HashSet<&str> = table
522 .primary_key
523 .as_ref()
524 .map(|pk| pk.columns.iter().map(|s| s.as_str()).collect())
525 .unwrap_or_default();
526
527 let update_columns: Vec<&ColumnMetadata> = insert_columns
529 .iter()
530 .filter(|c| !pk_columns.contains(c.name.as_str()))
531 .copied()
532 .collect();
533
534 if update_columns.is_empty() {
536 return String::new();
537 }
538
539 let column_list = insert_columns
540 .iter()
541 .map(|c| format!("`{}`", c.name))
542 .collect::<Vec<_>>()
543 .join(", ");
544
545 let placeholders = insert_columns
546 .iter()
547 .map(|_| "?")
548 .collect::<Vec<_>>()
549 .join(", ");
550
551 let update_clause = update_columns
552 .iter()
553 .map(|c| format!("`{name}` = VALUES(`{name}`)", name = c.name))
554 .collect::<Vec<_>>()
555 .join(", ");
556
557 let bind_fields = insert_columns
558 .iter()
559 .map(|c| {
560 let field = escape_field_name(&c.name);
561 let rust_type = TypeResolver::resolve(c, &table.name);
562 if rust_type.is_copy() {
563 format!(" .bind(entity.{})", field)
564 } else {
565 format!(" .bind(&entity.{})", field)
566 }
567 })
568 .collect::<Vec<_>>()
569 .join("\n");
570
571 format!(
572 r#"/// Upsert a record (insert or update on duplicate key)
573/// Returns rows_affected: 1 if inserted, 2 if updated
574pub async fn upsert<P: Pool>(pool: &P, entity: &{struct_name}) -> Result<u64> {{
575 Query::new("INSERT INTO `{table_name}` ({column_list}) VALUES ({placeholders}) \
576 ON DUPLICATE KEY UPDATE {update_clause}")
577{bind_fields}
578 .execute(pool)
579 .await
580 .map(|r| r.rows_affected)
581}}
582
583"#,
584 struct_name = struct_name,
585 table_name = table.name,
586 column_list = column_list,
587 placeholders = placeholders,
588 update_clause = update_clause,
589 bind_fields = bind_fields,
590 )
591}
592
593fn generate_update_methods(
595 table: &TableMetadata,
596 struct_name: &str,
597 column_map: &HashMap<&str, &ColumnMetadata>,
598) -> String {
599 let mut code = String::new();
600
601 let pk = table.primary_key.as_ref().unwrap();
602
603 let update_columns: Vec<&ColumnMetadata> = table
605 .columns
606 .iter()
607 .filter(|c| !pk.columns.contains(&c.name))
608 .collect();
609
610 if update_columns.is_empty() {
611 return code;
612 }
613
614 let set_clause = update_columns
615 .iter()
616 .map(|c| format!("`{}` = ?", c.name))
617 .collect::<Vec<_>>()
618 .join(", ");
619
620 let where_clause = build_where_clause(&pk.columns);
621
622 let bind_fields: Vec<String> = update_columns
624 .iter()
625 .map(|c| {
626 let field = escape_field_name(&c.name);
627 let rust_type = TypeResolver::resolve(c, &table.name);
628 if rust_type.is_copy() {
629 format!(" .bind(entity.{})", field)
630 } else {
631 format!(" .bind(&entity.{})", field)
632 }
633 })
634 .chain(pk.columns.iter().map(|c| {
635 let field = escape_field_name(c);
636 let col = column_map.get(c.as_str()).unwrap();
637 let rust_type = TypeResolver::resolve(col, &table.name);
638 if rust_type.is_copy() {
639 format!(" .bind(entity.{})", field)
640 } else {
641 format!(" .bind(&entity.{})", field)
642 }
643 }))
644 .collect();
645
646 code.push_str(&format!(
648 r#"/// Update a record by primary key
649pub async fn update<P: Pool>(pool: &P, entity: &{struct_name}) -> Result<u64> {{
650 Query::new("UPDATE `{table_name}` SET {set_clause} WHERE {where_clause}")
651{bind_fields}
652 .execute(pool)
653 .await
654 .map(|r| r.rows_affected)
655}}
656
657"#,
658 struct_name = struct_name,
659 table_name = table.name,
660 set_clause = set_clause,
661 where_clause = where_clause,
662 bind_fields = bind_fields.join("\n"),
663 ));
664
665 code
666}
667
668fn generate_update_plain_method(
670 table: &TableMetadata,
671 column_map: &HashMap<&str, &ColumnMetadata>,
672) -> String {
673 let pk = table.primary_key.as_ref().unwrap();
674
675 let update_columns: Vec<&ColumnMetadata> = table
677 .columns
678 .iter()
679 .filter(|c| !pk.columns.contains(&c.name))
680 .collect();
681
682 if update_columns.is_empty() {
683 return String::new();
684 }
685
686 let method_name = generate_update_by_method_name(&pk.columns);
688
689 let pk_params = build_params(&pk.columns, column_map, &table.name);
691 let update_column_names: Vec<String> = update_columns.iter().map(|c| c.name.clone()).collect();
692 let update_params = build_params(&update_column_names, column_map, &table.name);
693 let all_params = format!("{}, {}", pk_params, update_params);
694
695 let set_clause = update_columns
696 .iter()
697 .map(|c| format!("`{}` = ?", c.name))
698 .collect::<Vec<_>>()
699 .join(", ");
700
701 let where_clause = build_where_clause(&pk.columns);
702
703 let bind_section_update = generate_bind_section(&update_column_names);
705 let bind_section_pk = generate_bind_section(&pk.columns);
706
707 format!(
708 r#"/// Update a record by primary key with individual parameters
709#[allow(clippy::too_many_arguments)]
710pub async fn {method_name}<P: Pool>(pool: &P, {all_params}) -> Result<u64> {{
711 Query::new("UPDATE `{table_name}` SET {set_clause} WHERE {where_clause}")
712{bind_section_update}
713{bind_section_pk}
714 .execute(pool)
715 .await
716 .map(|r| r.rows_affected)
717}}
718
719"#,
720 method_name = method_name,
721 all_params = all_params,
722 table_name = table.name,
723 set_clause = set_clause,
724 where_clause = where_clause,
725 bind_section_update = bind_section_update,
726 bind_section_pk = bind_section_pk,
727 )
728}
729
730fn collect_method_signatures(table: &TableMetadata) -> HashMap<Vec<String>, MethodSignature> {
732 let mut signatures: HashMap<Vec<String>, MethodSignature> = HashMap::new();
733
734 if let Some(pk) = &table.primary_key {
736 let sig = MethodSignature::new(
737 pk.columns.clone(),
738 PRIORITY_PRIMARY_KEY,
739 true,
740 "PRIMARY_KEY",
741 );
742 signatures.insert(pk.columns.clone(), sig);
743 }
744
745 for index in &table.indexes {
747 let priority = if index.unique {
748 PRIORITY_UNIQUE_INDEX
749 } else {
750 PRIORITY_NON_UNIQUE_INDEX
751 };
752 let source = if index.unique {
753 "UNIQUE_INDEX"
754 } else {
755 "NON_UNIQUE_INDEX"
756 };
757 let sig = MethodSignature::new(index.columns.clone(), priority, index.unique, source);
758
759 if let Some(existing) = signatures.get(&index.columns) {
761 if sig.priority < existing.priority {
762 signatures.insert(index.columns.clone(), sig);
763 }
764 } else {
765 signatures.insert(index.columns.clone(), sig);
766 }
767 }
768
769 for fk in &table.foreign_keys {
771 let columns = vec![fk.column_name.clone()];
772 let sig = MethodSignature::new(columns.clone(), PRIORITY_FOREIGN_KEY, false, "FOREIGN_KEY");
773
774 signatures.entry(columns).or_insert(sig);
775 }
776
777 signatures
778}
779
780fn generate_find_by_method(
782 table: &TableMetadata,
783 sig: &MethodSignature,
784 column_map: &HashMap<&str, &ColumnMetadata>,
785 struct_name: &str,
786 select_columns: &str,
787) -> String {
788 let params = build_params(&sig.columns, column_map, &table.name);
789
790 let (return_type, fetch_method) = if sig.is_unique {
791 (format!("Option<{}>", struct_name), "fetch_optional")
792 } else {
793 (format!("Vec<{}>", struct_name), "fetch_all")
794 };
795
796 let return_desc = if sig.is_unique {
797 "Option (unique)"
798 } else {
799 "Vec (non-unique)"
800 };
801
802 let has_nullable = sig.columns.iter().any(|c| {
804 column_map
805 .get(c.as_str())
806 .map(|col| col.nullable)
807 .unwrap_or(false)
808 });
809
810 if has_nullable {
811 generate_find_by_method_nullable(
813 table,
814 sig,
815 column_map,
816 select_columns,
817 ¶ms,
818 &return_type,
819 fetch_method,
820 return_desc,
821 )
822 } else {
823 let where_clause = build_where_clause(&sig.columns);
825 let bind_section = generate_bind_section(&sig.columns);
826
827 format!(
828 r#"/// Find by {source}: returns {return_desc}
829pub async fn {method_name}<P: Pool>(pool: &P, {params}) -> Result<{return_type}> {{
830 Query::new("SELECT {select_columns} FROM `{table_name}` WHERE {where_clause}")
831{bind_section}
832 .{fetch_method}(pool)
833 .await
834}}
835
836"#,
837 source = sig.source.to_lowercase().replace('_', " "),
838 return_desc = return_desc,
839 method_name = sig.method_name,
840 params = params,
841 return_type = return_type,
842 select_columns = select_columns,
843 table_name = table.name,
844 where_clause = where_clause,
845 bind_section = bind_section,
846 fetch_method = fetch_method,
847 )
848 }
849}
850
851#[allow(clippy::too_many_arguments)]
853fn generate_find_by_method_nullable(
854 table: &TableMetadata,
855 sig: &MethodSignature,
856 column_map: &HashMap<&str, &ColumnMetadata>,
857 select_columns: &str,
858 params: &str,
859 return_type: &str,
860 fetch_method: &str,
861 return_desc: &str,
862) -> String {
863 let mut where_parts = Vec::new();
865 let mut bind_parts = Vec::new();
866
867 for col in &sig.columns {
868 let col_meta = column_map.get(col.as_str()).unwrap();
869 let field_name = escape_field_name(col);
870
871 if col_meta.nullable {
872 where_parts.push(format!(
874 r#"if {field}.is_some() {{ "`{col}` = ?" }} else {{ "`{col}` IS NULL" }}"#,
875 field = field_name,
876 col = col,
877 ));
878 bind_parts.push(format!(
879 r#"if let Some(v) = {field}.as_ref() {{ query = query.bind(v); }}"#,
880 field = field_name,
881 ));
882 } else {
883 where_parts.push(format!(r#""`{}` = ?""#, col));
884 bind_parts.push(format!(r#"query = query.bind({});"#, field_name,));
885 }
886 }
887
888 let where_expr = if where_parts.len() == 1 {
889 where_parts[0].clone()
890 } else {
891 let parts = where_parts
893 .iter()
894 .map(|p| format!("({})", p))
895 .collect::<Vec<_>>()
896 .join(", ");
897 format!("vec![{}].join(\" AND \")", parts)
898 };
899
900 let bind_code = bind_parts.join("\n ");
901
902 format!(
903 r#"/// Find by {source}: returns {return_desc}
904pub async fn {method_name}<P: Pool>(pool: &P, {params}) -> Result<{return_type}> {{
905 let where_clause = {where_expr};
906 let sql = format!("SELECT {select_columns} FROM `{table_name}` WHERE {{}}", where_clause);
907 let mut query = rdbi::DynamicQuery::new(sql);
908 {bind_code}
909 query.{fetch_method}(pool).await
910}}
911
912"#,
913 source = sig.source.to_lowercase().replace('_', " "),
914 return_desc = return_desc,
915 method_name = sig.method_name,
916 params = params,
917 return_type = return_type,
918 select_columns = select_columns,
919 table_name = table.name,
920 where_expr = where_expr,
921 bind_code = bind_code,
922 fetch_method = fetch_method,
923 )
924}
925
926fn generate_find_by_list_methods(
928 table: &TableMetadata,
929 column_map: &HashMap<&str, &ColumnMetadata>,
930 struct_name: &str,
931 select_columns: &str,
932) -> String {
933 let mut code = String::new();
934 let mut processed: HashSet<String> = HashSet::new();
935
936 if let Some(pk) = &table.primary_key {
938 if pk.columns.len() == 1 {
939 let col = &pk.columns[0];
940 code.push_str(&generate_single_find_by_list(
941 table,
942 col,
943 column_map,
944 struct_name,
945 select_columns,
946 ));
947 processed.insert(col.clone());
948 }
949 }
950
951 for index in &table.indexes {
953 if index.columns.len() == 1 {
954 let col = &index.columns[0];
955 if !processed.contains(col) {
956 code.push_str(&generate_single_find_by_list(
957 table,
958 col,
959 column_map,
960 struct_name,
961 select_columns,
962 ));
963 processed.insert(col.clone());
964 }
965 }
966 }
967
968 code
969}
970
971fn generate_single_find_by_list(
973 table: &TableMetadata,
974 column_name: &str,
975 column_map: &HashMap<&str, &ColumnMetadata>,
976 struct_name: &str,
977 select_columns: &str,
978) -> String {
979 let method_name = generate_find_by_list_method_name(column_name);
980 let param_name = pluralize(&escape_field_name(column_name));
981 let column = column_map.get(column_name).unwrap();
982 let rust_type = TypeResolver::resolve(column, &table.name);
983
984 let inner_type = rust_type.inner_type().to_type_string();
986
987 let column_name_plural = pluralize(column_name);
988 format!(
989 r#"/// Find by list of {column_name_plural} (IN clause)
990pub async fn {method_name}<P: Pool>(pool: &P, {param_name}: &[{inner_type}]) -> Result<Vec<{struct_name}>> {{
991 if {param_name}.is_empty() {{
992 return Ok(Vec::new());
993 }}
994 let placeholders = {param_name}.iter().map(|_| "?").collect::<Vec<_>>().join(",");
995 let query = format!(
996 "SELECT {select_columns} FROM `{table_name}` WHERE `{column_name}` IN ({{}})",
997 placeholders
998 );
999 rdbi::DynamicQuery::new(query)
1000 .bind_all({param_name})
1001 .fetch_all(pool)
1002 .await
1003}}
1004
1005"#,
1006 column_name_plural = column_name_plural,
1007 column_name = column_name,
1008 method_name = method_name,
1009 param_name = param_name,
1010 inner_type = inner_type,
1011 struct_name = struct_name,
1012 select_columns = select_columns,
1013 table_name = table.name,
1014 )
1015}
1016
1017fn generate_composite_enum_list_methods(
1020 table: &TableMetadata,
1021 column_map: &HashMap<&str, &ColumnMetadata>,
1022 struct_name: &str,
1023 select_columns: &str,
1024) -> String {
1025 let mut code = String::new();
1026
1027 for index in &table.indexes {
1028 if index.columns.len() <= 1 {
1030 continue;
1031 }
1032
1033 let enum_columns: HashSet<&str> = index
1035 .columns
1036 .iter()
1037 .filter(|col_name| {
1038 column_map
1039 .get(col_name.as_str())
1040 .map(|col| col.is_enum())
1041 .unwrap_or(false)
1042 })
1043 .map(|s| s.as_str())
1044 .collect();
1045
1046 if enum_columns.is_empty() {
1048 continue;
1049 }
1050
1051 let first_column = &index.columns[0];
1053 if enum_columns.contains(first_column.as_str()) {
1054 continue;
1055 }
1056
1057 code.push_str(&generate_composite_enum_list_method(
1058 table,
1059 &index.columns,
1060 &enum_columns,
1061 column_map,
1062 struct_name,
1063 select_columns,
1064 ));
1065 }
1066
1067 code
1068}
1069
1070fn generate_composite_enum_list_method(
1072 table: &TableMetadata,
1073 columns: &[String],
1074 enum_columns: &HashSet<&str>,
1075 column_map: &HashMap<&str, &ColumnMetadata>,
1076 struct_name: &str,
1077 select_columns: &str,
1078) -> String {
1079 let method_name = generate_composite_enum_method_name(columns, enum_columns);
1081
1082 let mut params_parts = Vec::new();
1084
1085 for col_name in columns {
1086 let col = column_map.get(col_name.as_str()).unwrap();
1087 let rust_type = TypeResolver::resolve(col, &table.name);
1088 let is_enum = enum_columns.contains(col_name.as_str());
1089
1090 if is_enum {
1091 let param_name = pluralize(&escape_field_name(col_name));
1093 let inner_type = rust_type.inner_type().to_type_string();
1094 params_parts.push(format!("{}: &[{}]", param_name, inner_type));
1095 } else {
1096 let param_name = escape_field_name(col_name);
1098 let param_type = rust_type.to_param_type_string();
1099 params_parts.push(format!("{}: {}", param_name, param_type));
1100 }
1101 }
1102
1103 let params = params_parts.join(", ");
1104
1105 let where_clause_static: Vec<String> = columns
1107 .iter()
1108 .map(|col_name| {
1109 if enum_columns.contains(col_name.as_str()) {
1110 format!("`{}` IN ({{}})", col_name) } else {
1112 format!("`{}` = ?", col_name)
1113 }
1114 })
1115 .collect();
1116
1117 let mut bind_code = String::new();
1119
1120 for col_name in columns {
1122 if !enum_columns.contains(col_name.as_str()) {
1123 let param_name = escape_field_name(col_name);
1124 bind_code.push_str(&format!(" .bind({})\n", param_name));
1125 }
1126 }
1127
1128 for col_name in columns {
1130 if enum_columns.contains(col_name.as_str()) {
1131 let param_name = pluralize(&escape_field_name(col_name));
1132 bind_code.push_str(&format!(" .bind_all({})\n", param_name));
1133 }
1134 }
1135
1136 let column_desc: Vec<String> = columns
1138 .iter()
1139 .map(|col| {
1140 if enum_columns.contains(col.as_str()) {
1141 pluralize(col)
1142 } else {
1143 col.clone()
1144 }
1145 })
1146 .collect();
1147
1148 let enum_col_names: Vec<&str> = columns
1150 .iter()
1151 .filter(|c| enum_columns.contains(c.as_str()))
1152 .map(|s| s.as_str())
1153 .collect();
1154
1155 let in_clause_builders: Vec<String> = enum_col_names
1157 .iter()
1158 .map(|col| {
1159 let param_name = pluralize(&escape_field_name(col));
1160 format!(
1161 "{param_name}.iter().map(|_| \"?\").collect::<Vec<_>>().join(\",\")",
1162 param_name = param_name
1163 )
1164 })
1165 .collect();
1166
1167 let format_args = in_clause_builders.join(", ");
1169
1170 format!(
1171 r#"/// Find by {column_desc} (composite index with IN clause for enum columns)
1172pub async fn {method_name}<P: Pool>(pool: &P, {params}) -> Result<Vec<{struct_name}>> {{
1173 // Check for empty enum lists
1174{empty_checks}
1175 // Build IN clause placeholders for enum columns
1176 let where_clause = format!("{where_template}", {format_args});
1177 let query = format!(
1178 "SELECT {select_columns} FROM `{table_name}` WHERE {{}}",
1179 where_clause
1180 );
1181 rdbi::DynamicQuery::new(query)
1182{bind_code} .fetch_all(pool)
1183 .await
1184}}
1185
1186"#,
1187 column_desc = column_desc.join(" and "),
1188 method_name = method_name,
1189 params = params,
1190 struct_name = struct_name,
1191 select_columns = select_columns,
1192 table_name = table.name,
1193 where_template = where_clause_static.join(" AND "),
1194 format_args = format_args,
1195 bind_code = bind_code,
1196 empty_checks = generate_empty_checks(columns, enum_columns),
1197 )
1198}
1199
1200fn generate_empty_checks(columns: &[String], enum_columns: &HashSet<&str>) -> String {
1202 let mut checks = String::new();
1203 for col_name in columns {
1204 if enum_columns.contains(col_name.as_str()) {
1205 let param_name = pluralize(&escape_field_name(col_name));
1206 checks.push_str(&format!(
1207 " if {}.is_empty() {{ return Ok(Vec::new()); }}\n",
1208 param_name
1209 ));
1210 }
1211 }
1212 checks
1213}
1214
1215fn generate_composite_enum_method_name(columns: &[String], enum_columns: &HashSet<&str>) -> String {
1218 let mut parts = Vec::new();
1219 for col in columns {
1220 if enum_columns.contains(col.as_str()) {
1221 parts.push(pluralize(col));
1222 } else {
1223 parts.push(col.clone());
1224 }
1225 }
1226 generate_find_by_method_name(&parts)
1227}
1228
1229fn generate_pagination_methods(
1231 table: &TableMetadata,
1232 struct_name: &str,
1233 select_columns: &str,
1234 models_module: &str,
1235) -> String {
1236 let sort_by_enum = format!("{}SortBy", struct_name);
1237
1238 format!(
1239 r#"/// Find all records with pagination and sorting
1240pub async fn find_all_paginated<P: Pool>(
1241 pool: &P,
1242 limit: i32,
1243 offset: i32,
1244 sort_by: crate::{models_module}::{sort_by_enum},
1245 sort_dir: crate::{models_module}::SortDirection,
1246) -> Result<Vec<{struct_name}>> {{
1247 let order_clause = format!("{{}} {{}}", sort_by.as_sql(), sort_dir.as_sql());
1248 let query = format!(
1249 "SELECT {select_columns} FROM `{table_name}` ORDER BY {{}} LIMIT ? OFFSET ?",
1250 order_clause
1251 );
1252 rdbi::DynamicQuery::new(query)
1253 .bind(limit)
1254 .bind(offset)
1255 .fetch_all(pool)
1256 .await
1257}}
1258
1259/// Get paginated result with total count
1260pub async fn get_paginated_result<P: Pool>(
1261 pool: &P,
1262 page_size: i32,
1263 current_page: i32,
1264 sort_by: crate::{models_module}::{sort_by_enum},
1265 sort_dir: crate::{models_module}::SortDirection,
1266) -> Result<crate::{models_module}::PaginatedResult<{struct_name}>> {{
1267 let page_size = page_size.max(1);
1268 let current_page = current_page.max(1);
1269 let offset = (current_page - 1) * page_size;
1270
1271 let total_count = count_all(pool).await?;
1272 let items = find_all_paginated(pool, page_size, offset, sort_by, sort_dir).await?;
1273
1274 Ok(crate::{models_module}::PaginatedResult::new(
1275 items,
1276 total_count,
1277 current_page,
1278 page_size,
1279 ))
1280}}
1281
1282"#,
1283 struct_name = struct_name,
1284 select_columns = select_columns,
1285 table_name = table.name,
1286 models_module = models_module,
1287 sort_by_enum = sort_by_enum,
1288 )
1289}
1290
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::{IndexMetadata, PrimaryKey};

    /// Converts a slice of column names into the owned `Vec<String>` form used by the metadata.
    fn cols(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    /// Builds a non-nullable, signed column with no default value and no comment.
    fn column(
        name: &str,
        data_type: &str,
        is_auto_increment: bool,
        enum_values: Option<&[&str]>,
    ) -> ColumnMetadata {
        ColumnMetadata {
            name: name.to_string(),
            data_type: data_type.to_string(),
            nullable: false,
            default_value: None,
            is_auto_increment,
            is_unsigned: false,
            enum_values: enum_values.map(cols),
            comment: None,
        }
    }

    /// Builds an index over `columns` with the given uniqueness.
    fn index(name: &str, columns: &[&str], unique: bool) -> IndexMetadata {
        IndexMetadata {
            name: name.to_string(),
            columns: cols(columns),
            unique,
        }
    }

    /// Fixture: a `users` table with an auto-increment primary key `id`, a unique
    /// index on `email`, and a non-unique index on the ENUM column `status`.
    fn make_table() -> TableMetadata {
        TableMetadata {
            name: "users".to_string(),
            comment: None,
            columns: vec![
                column("id", "BIGINT", true, None),
                column("email", "VARCHAR(255)", false, None),
                column("status", "ENUM", false, Some(&["ACTIVE", "INACTIVE"][..])),
            ],
            indexes: vec![
                index("email_unique", &["email"], true),
                index("idx_status", &["status"], false),
            ],
            foreign_keys: vec![],
            primary_key: Some(PrimaryKey {
                columns: cols(&["id"]),
            }),
        }
    }

    #[test]
    fn test_collect_method_signatures() {
        let table = make_table();
        let sigs = collect_method_signatures(&table);

        // One signature per distinct key: PK, unique index, non-unique index.
        assert_eq!(sigs.len(), 3);

        // (column, expected uniqueness, expected priority)
        let expectations = [
            ("id", true, PRIORITY_PRIMARY_KEY),
            ("email", true, PRIORITY_UNIQUE_INDEX),
            ("status", false, PRIORITY_NON_UNIQUE_INDEX),
        ];
        for (name, unique, priority) in expectations {
            let sig = sigs.get(&cols(&[name])).unwrap();
            assert_eq!(sig.is_unique, unique, "uniqueness of {}", name);
            assert_eq!(sig.priority, priority, "priority of {}", name);
        }
    }

    #[test]
    fn test_build_select_columns() {
        let table = make_table();
        let select = build_select_columns(&table);
        for quoted in ["`id`", "`email`", "`status`"] {
            assert!(select.contains(quoted), "missing {}", quoted);
        }
    }

    #[test]
    fn test_build_where_clause() {
        assert_eq!(build_where_clause(&cols(&["id"])), "`id` = ?");
        assert_eq!(
            build_where_clause(&cols(&["user_id", "role_id"])),
            "`user_id` = ? AND `role_id` = ?"
        );
    }

    #[test]
    fn test_generate_upsert_method() {
        let table = make_table();
        let code = generate_upsert_method(&table, "Users");

        assert!(code.contains("pub async fn upsert"));
        assert!(code.contains("ON DUPLICATE KEY UPDATE"));
        // The primary key column is excluded from the update list...
        assert!(!code.contains("`id` = VALUES(`id`)"));
        // ...while the remaining columns are updated from the inserted values.
        for update in ["`email` = VALUES(`email`)", "`status` = VALUES(`status`)"] {
            assert!(code.contains(update), "missing update clause {}", update);
        }
    }

    #[test]
    fn test_generate_upsert_method_no_pk() {
        let mut table = make_table();
        table.primary_key = None;
        table.indexes.clear();

        // Without a primary key or any index, no upsert code is emitted.
        assert!(generate_upsert_method(&table, "Users").is_empty());
    }

    #[test]
    fn test_generate_insert_all_method() {
        let table = make_table();
        let code = generate_insert_all_method(&table, "Users");

        for needle in ["pub async fn insert_all", "rdbi::BatchInsert::new"] {
            assert!(code.contains(needle), "missing {}", needle);
        }
    }

    #[test]
    fn test_generate_pagination_methods() {
        let table = make_table();
        let select_columns = build_select_columns(&table);
        let code = generate_pagination_methods(&table, "Users", &select_columns, "models");

        let needles = [
            "pub async fn find_all_paginated",
            "limit: i32",
            "offset: i32",
            "UsersSortBy",
            "SortDirection",
            "pub async fn get_paginated_result",
            "PaginatedResult<Users>",
        ];
        for needle in needles {
            assert!(code.contains(needle), "missing {}", needle);
        }
    }
}