1use crate::error::Result;
4use std::collections::{HashMap, HashSet};
5use std::fs;
6use std::path::Path;
7use tracing::debug;
8
9use crate::config::CodegenConfig;
10use crate::parser::{ColumnMetadata, TableMetadata};
11
12use super::naming::{
13 escape_field_name, generate_delete_by_method_name, generate_find_by_list_method_name,
14 generate_find_by_method_name, generate_update_by_method_name, pluralize, to_struct_name,
15};
16use super::type_resolver::TypeResolver;
17
/// Precedence used by `collect_method_signatures` when several key sources cover
/// the same column list. An index only replaces an existing entry whose value is
/// strictly higher, and a foreign key is only added when no entry exists yet.
const PRIORITY_PRIMARY_KEY: u8 = 1;
/// Signature derived from a UNIQUE index.
const PRIORITY_UNIQUE_INDEX: u8 = 2;
/// Signature derived from a non-unique index.
const PRIORITY_NON_UNIQUE_INDEX: u8 = 3;
/// Signature derived from a single-column foreign key.
const PRIORITY_FOREIGN_KEY: u8 = 4;
23
24#[derive(Debug, Clone)]
26struct MethodSignature {
27 columns: Vec<String>,
28 method_name: String,
29 priority: u8,
30 is_unique: bool,
31 source: String,
32}
33
34impl MethodSignature {
35 fn new(columns: Vec<String>, priority: u8, is_unique: bool, source: &str) -> Self {
36 let method_name = generate_find_by_method_name(&columns);
37 Self {
38 columns,
39 method_name,
40 priority,
41 is_unique,
42 source: source.to_string(),
43 }
44 }
45}
46
47pub fn generate_daos(tables: &[TableMetadata], config: &CodegenConfig) -> Result<()> {
49 let output_dir = &config.output_dao_dir;
50 fs::create_dir_all(output_dir)?;
51
52 let mut mod_content = String::new();
54 mod_content.push_str("// Generated DAO functions\n\n");
55
56 for table in tables {
57 let file_name = heck::AsSnakeCase(&table.name).to_string();
58 mod_content.push_str(&format!("pub mod {};\n", file_name));
59 }
60
61 fs::write(output_dir.join("mod.rs"), mod_content)?;
62
63 for table in tables {
65 generate_dao_file(table, output_dir, &config.models_module)?;
66 }
67
68 Ok(())
69}
70
71fn generate_dao_file(table: &TableMetadata, output_dir: &Path, models_module: &str) -> Result<()> {
73 let struct_name = to_struct_name(&table.name);
74 let file_name = format!("{}.rs", heck::AsSnakeCase(&table.name));
75 debug!("Generating DAO for {} -> {}", struct_name, file_name);
76
77 let mut code = String::new();
78
79 let column_map: HashMap<&str, &ColumnMetadata> =
81 table.columns.iter().map(|c| (c.name.as_str(), c)).collect();
82
83 let mut needs_chrono = false;
85 let mut needs_decimal = false;
86
87 for col in &table.columns {
88 let rust_type = TypeResolver::resolve(col, &table.name);
89 if rust_type.needs_chrono() {
90 needs_chrono = true;
91 }
92 if rust_type.needs_decimal() {
93 needs_decimal = true;
94 }
95 }
96
97 let has_enums = table.columns.iter().any(|c| c.is_enum());
99
100 code.push_str("use rdbi::{Pool, Query, Result};\n");
102 code.push_str(&format!("use crate::{}::{};\n", models_module, struct_name));
103
104 if has_enums {
105 for col in &table.columns {
107 if col.is_enum() {
108 let enum_name = super::naming::to_enum_name(&table.name, &col.name);
109 code.push_str(&format!("use crate::{}::{};\n", models_module, enum_name));
110 }
111 }
112 }
113
114 if needs_chrono {
115 code.push_str("#[allow(unused_imports)]\n");
116 code.push_str("use chrono::{NaiveDate, NaiveDateTime, NaiveTime};\n");
117 }
118 if needs_decimal {
119 code.push_str("#[allow(unused_imports)]\n");
120 code.push_str("use rust_decimal::Decimal;\n");
121 }
122
123 code.push('\n');
124
125 let select_columns = build_select_columns(table);
127
128 code.push_str(&generate_find_all(table, &struct_name, &select_columns));
130
131 code.push_str(&generate_count_all(table));
133
134 if let Some(pk) = &table.primary_key {
136 code.push_str(&generate_pk_methods(
137 table,
138 pk,
139 &column_map,
140 &struct_name,
141 &select_columns,
142 ));
143 }
144
145 code.push_str(&generate_insert_methods(table, &struct_name, &column_map));
147
148 code.push_str(&generate_insert_plain_method(table, &column_map));
150
151 code.push_str(&generate_insert_all_method(table, &struct_name));
153
154 code.push_str(&generate_upsert_method(table, &struct_name));
156
157 if table.primary_key.is_some() {
159 code.push_str(&generate_update_methods(table, &struct_name, &column_map));
160 code.push_str(&generate_update_plain_method(table, &column_map));
162 }
163
164 let signatures = collect_method_signatures(table);
166 for sig in signatures.values() {
167 if sig.source == "PRIMARY_KEY" {
169 continue;
170 }
171 code.push_str(&generate_find_by_method(
172 table,
173 sig,
174 &column_map,
175 &struct_name,
176 &select_columns,
177 ));
178 }
179
180 code.push_str(&generate_find_by_list_methods(
182 table,
183 &column_map,
184 &struct_name,
185 &select_columns,
186 ));
187
188 code.push_str(&generate_composite_enum_list_methods(
190 table,
191 &column_map,
192 &struct_name,
193 &select_columns,
194 ));
195
196 code.push_str(&generate_pagination_methods(
198 table,
199 &struct_name,
200 &select_columns,
201 models_module,
202 ));
203
204 fs::write(output_dir.join(&file_name), code)?;
205 Ok(())
206}
207
208fn build_select_columns(table: &TableMetadata) -> String {
210 table
211 .columns
212 .iter()
213 .map(|c| format!("`{}`", c.name))
214 .collect::<Vec<_>>()
215 .join(", ")
216}
217
/// Builds an equality filter such as `` `a` = ? AND `b` = ? ``, with one
/// placeholder per column. Returns an empty string for an empty slice.
fn build_where_clause(columns: &[String]) -> String {
    let mut clause = String::new();
    for (idx, name) in columns.iter().enumerate() {
        if idx != 0 {
            clause.push_str(" AND ");
        }
        clause.push_str(&format!("`{}` = ?", name));
    }
    clause
}
226
227fn build_params(
229 columns: &[String],
230 column_map: &HashMap<&str, &ColumnMetadata>,
231 table_name: &str,
232) -> String {
233 columns
234 .iter()
235 .map(|c| {
236 let col = column_map.get(c.as_str()).unwrap();
237 let rust_type = TypeResolver::resolve(col, table_name);
238 let param_type = rust_type.to_param_type_string();
239 format!("{}: {}", escape_field_name(c), param_type)
240 })
241 .collect::<Vec<_>>()
242 .join(", ")
243}
244
245fn generate_bind_section(columns: &[String]) -> String {
247 columns
248 .iter()
249 .map(|c| format!(" .bind({})", escape_field_name(c)))
250 .collect::<Vec<_>>()
251 .join("\n")
252}
253
254fn generate_find_all(table: &TableMetadata, struct_name: &str, select_columns: &str) -> String {
256 format!(
257 r#"/// Find all records
258pub async fn find_all<P: Pool>(pool: &P) -> Result<Vec<{struct_name}>> {{
259 Query::new("SELECT {select_columns} FROM `{table_name}`")
260 .fetch_all(pool)
261 .await
262}}
263
264"#,
265 struct_name = struct_name,
266 select_columns = select_columns,
267 table_name = table.name,
268 )
269}
270
271fn generate_count_all(table: &TableMetadata) -> String {
273 format!(
274 r#"/// Count all records
275pub async fn count_all<P: Pool>(pool: &P) -> Result<i64> {{
276 Query::new("SELECT COUNT(*) FROM `{table_name}`")
277 .fetch_scalar(pool)
278 .await
279}}
280
281"#,
282 table_name = table.name,
283 )
284}
285
286fn generate_pk_methods(
288 table: &TableMetadata,
289 pk: &crate::parser::PrimaryKey,
290 column_map: &HashMap<&str, &ColumnMetadata>,
291 struct_name: &str,
292 select_columns: &str,
293) -> String {
294 let mut code = String::new();
295
296 let method_name = generate_find_by_method_name(&pk.columns);
297 let params = build_params(&pk.columns, column_map, &table.name);
298 let where_clause = build_where_clause(&pk.columns);
299 let bind_section = generate_bind_section(&pk.columns);
300
301 code.push_str(&format!(
303 r#"/// Find by primary key
304pub async fn {method_name}<P: Pool>(pool: &P, {params}) -> Result<Option<{struct_name}>> {{
305 Query::new("SELECT {select_columns} FROM `{table_name}` WHERE {where_clause}")
306{bind_section}
307 .fetch_optional(pool)
308 .await
309}}
310
311"#,
312 method_name = method_name,
313 params = params,
314 struct_name = struct_name,
315 select_columns = select_columns,
316 table_name = table.name,
317 where_clause = where_clause,
318 bind_section = bind_section,
319 ));
320
321 let delete_method = generate_delete_by_method_name(&pk.columns);
323 code.push_str(&format!(
324 r#"/// Delete by primary key
325pub async fn {delete_method}<P: Pool>(pool: &P, {params}) -> Result<u64> {{
326 Query::new("DELETE FROM `{table_name}` WHERE {where_clause}")
327{bind_section}
328 .execute(pool)
329 .await
330 .map(|r| r.rows_affected)
331}}
332
333"#,
334 delete_method = delete_method,
335 params = params,
336 table_name = table.name,
337 where_clause = where_clause,
338 bind_section = bind_section,
339 ));
340
341 code
342}
343
344fn generate_insert_methods(
346 table: &TableMetadata,
347 struct_name: &str,
348 _column_map: &HashMap<&str, &ColumnMetadata>,
349) -> String {
350 let mut code = String::new();
351
352 let insert_columns: Vec<&ColumnMetadata> = table
354 .columns
355 .iter()
356 .filter(|c| !c.is_auto_increment)
357 .collect();
358
359 if insert_columns.is_empty() {
360 return code;
361 }
362
363 let column_list = insert_columns
364 .iter()
365 .map(|c| format!("`{}`", c.name))
366 .collect::<Vec<_>>()
367 .join(", ");
368
369 let placeholders = insert_columns
370 .iter()
371 .map(|_| "?")
372 .collect::<Vec<_>>()
373 .join(", ");
374
375 let bind_fields = insert_columns
376 .iter()
377 .map(|c| {
378 let field = escape_field_name(&c.name);
379 format!(" .bind(&entity.{})", field)
380 })
381 .collect::<Vec<_>>()
382 .join("\n");
383
384 code.push_str(&format!(
386 r#"/// Insert a new record
387pub async fn insert<P: Pool>(pool: &P, entity: &{struct_name}) -> Result<u64> {{
388 Query::new("INSERT INTO `{table_name}` ({column_list}) VALUES ({placeholders})")
389{bind_fields}
390 .execute(pool)
391 .await
392 .map(|r| r.last_insert_id.unwrap_or(0))
393}}
394
395"#,
396 struct_name = struct_name,
397 table_name = table.name,
398 column_list = column_list,
399 placeholders = placeholders,
400 bind_fields = bind_fields,
401 ));
402
403 code
404}
405
406fn generate_insert_plain_method(
408 table: &TableMetadata,
409 column_map: &HashMap<&str, &ColumnMetadata>,
410) -> String {
411 let insert_columns: Vec<&ColumnMetadata> = table
413 .columns
414 .iter()
415 .filter(|c| !c.is_auto_increment)
416 .collect();
417
418 if insert_columns.is_empty() {
419 return String::new();
420 }
421
422 let column_names: Vec<String> = insert_columns.iter().map(|c| c.name.clone()).collect();
423 let params = build_params(&column_names, column_map, &table.name);
424
425 let column_list = insert_columns
426 .iter()
427 .map(|c| format!("`{}`", c.name))
428 .collect::<Vec<_>>()
429 .join(", ");
430
431 let placeholders = insert_columns
432 .iter()
433 .map(|_| "?")
434 .collect::<Vec<_>>()
435 .join(", ");
436
437 let bind_section = generate_bind_section(&column_names);
438
439 format!(
440 r#"/// Insert a new record with individual parameters
441pub async fn insert_plain<P: Pool>(pool: &P, {params}) -> Result<u64> {{
442 Query::new("INSERT INTO `{table_name}` ({column_list}) VALUES ({placeholders})")
443{bind_section}
444 .execute(pool)
445 .await
446 .map(|r| r.last_insert_id.unwrap_or(0))
447}}
448
449"#,
450 params = params,
451 table_name = table.name,
452 column_list = column_list,
453 placeholders = placeholders,
454 bind_section = bind_section,
455 )
456}
457
458fn generate_insert_all_method(table: &TableMetadata, struct_name: &str) -> String {
460 let insert_columns: Vec<&ColumnMetadata> = table
462 .columns
463 .iter()
464 .filter(|c| !c.is_auto_increment)
465 .collect();
466
467 if insert_columns.is_empty() {
468 return String::new();
469 }
470
471 format!(
472 r#"/// Insert multiple records in a single batch
473pub async fn insert_all<P: Pool>(pool: &P, entities: &[{struct_name}]) -> Result<u64> {{
474 rdbi::BatchInsert::new("{table_name}", entities)
475 .execute(pool)
476 .await
477 .map(|r| r.rows_affected)
478}}
479
480"#,
481 struct_name = struct_name,
482 table_name = table.name,
483 )
484}
485
486fn has_unique_index(table: &TableMetadata) -> bool {
488 table.indexes.iter().any(|idx| idx.unique)
489}
490
491fn generate_upsert_method(table: &TableMetadata, struct_name: &str) -> String {
493 if table.primary_key.is_none() && !has_unique_index(table) {
495 return String::new();
496 }
497
498 let insert_columns: Vec<&ColumnMetadata> = table
500 .columns
501 .iter()
502 .filter(|c| !c.is_auto_increment)
503 .collect();
504
505 if insert_columns.is_empty() {
506 return String::new();
507 }
508
509 let pk_columns: HashSet<&str> = table
511 .primary_key
512 .as_ref()
513 .map(|pk| pk.columns.iter().map(|s| s.as_str()).collect())
514 .unwrap_or_default();
515
516 let update_columns: Vec<&ColumnMetadata> = insert_columns
518 .iter()
519 .filter(|c| !pk_columns.contains(c.name.as_str()))
520 .copied()
521 .collect();
522
523 if update_columns.is_empty() {
525 return String::new();
526 }
527
528 let column_list = insert_columns
529 .iter()
530 .map(|c| format!("`{}`", c.name))
531 .collect::<Vec<_>>()
532 .join(", ");
533
534 let placeholders = insert_columns
535 .iter()
536 .map(|_| "?")
537 .collect::<Vec<_>>()
538 .join(", ");
539
540 let update_clause = update_columns
541 .iter()
542 .map(|c| format!("`{name}` = VALUES(`{name}`)", name = c.name))
543 .collect::<Vec<_>>()
544 .join(", ");
545
546 let bind_fields = insert_columns
547 .iter()
548 .map(|c| {
549 let field = escape_field_name(&c.name);
550 format!(" .bind(&entity.{})", field)
551 })
552 .collect::<Vec<_>>()
553 .join("\n");
554
555 format!(
556 r#"/// Upsert a record (insert or update on duplicate key)
557/// Returns rows_affected: 1 if inserted, 2 if updated
558pub async fn upsert<P: Pool>(pool: &P, entity: &{struct_name}) -> Result<u64> {{
559 Query::new("INSERT INTO `{table_name}` ({column_list}) VALUES ({placeholders}) \
560 ON DUPLICATE KEY UPDATE {update_clause}")
561{bind_fields}
562 .execute(pool)
563 .await
564 .map(|r| r.rows_affected)
565}}
566
567"#,
568 struct_name = struct_name,
569 table_name = table.name,
570 column_list = column_list,
571 placeholders = placeholders,
572 update_clause = update_clause,
573 bind_fields = bind_fields,
574 )
575}
576
577fn generate_update_methods(
579 table: &TableMetadata,
580 struct_name: &str,
581 _column_map: &HashMap<&str, &ColumnMetadata>,
582) -> String {
583 let mut code = String::new();
584
585 let pk = table.primary_key.as_ref().unwrap();
586
587 let update_columns: Vec<&ColumnMetadata> = table
589 .columns
590 .iter()
591 .filter(|c| !pk.columns.contains(&c.name))
592 .collect();
593
594 if update_columns.is_empty() {
595 return code;
596 }
597
598 let set_clause = update_columns
599 .iter()
600 .map(|c| format!("`{}` = ?", c.name))
601 .collect::<Vec<_>>()
602 .join(", ");
603
604 let where_clause = build_where_clause(&pk.columns);
605
606 let bind_fields: Vec<String> = update_columns
608 .iter()
609 .map(|c| {
610 let field = escape_field_name(&c.name);
611 format!(" .bind(&entity.{})", field)
612 })
613 .chain(pk.columns.iter().map(|c| {
614 let field = escape_field_name(c);
615 format!(" .bind(&entity.{})", field)
616 }))
617 .collect();
618
619 code.push_str(&format!(
621 r#"/// Update a record by primary key
622pub async fn update<P: Pool>(pool: &P, entity: &{struct_name}) -> Result<u64> {{
623 Query::new("UPDATE `{table_name}` SET {set_clause} WHERE {where_clause}")
624{bind_fields}
625 .execute(pool)
626 .await
627 .map(|r| r.rows_affected)
628}}
629
630"#,
631 struct_name = struct_name,
632 table_name = table.name,
633 set_clause = set_clause,
634 where_clause = where_clause,
635 bind_fields = bind_fields.join("\n"),
636 ));
637
638 code
639}
640
641fn generate_update_plain_method(
643 table: &TableMetadata,
644 column_map: &HashMap<&str, &ColumnMetadata>,
645) -> String {
646 let pk = table.primary_key.as_ref().unwrap();
647
648 let update_columns: Vec<&ColumnMetadata> = table
650 .columns
651 .iter()
652 .filter(|c| !pk.columns.contains(&c.name))
653 .collect();
654
655 if update_columns.is_empty() {
656 return String::new();
657 }
658
659 let method_name = generate_update_by_method_name(&pk.columns);
661
662 let pk_params = build_params(&pk.columns, column_map, &table.name);
664 let update_column_names: Vec<String> = update_columns.iter().map(|c| c.name.clone()).collect();
665 let update_params = build_params(&update_column_names, column_map, &table.name);
666 let all_params = format!("{}, {}", pk_params, update_params);
667
668 let set_clause = update_columns
669 .iter()
670 .map(|c| format!("`{}` = ?", c.name))
671 .collect::<Vec<_>>()
672 .join(", ");
673
674 let where_clause = build_where_clause(&pk.columns);
675
676 let bind_section_update = generate_bind_section(&update_column_names);
678 let bind_section_pk = generate_bind_section(&pk.columns);
679
680 format!(
681 r#"/// Update a record by primary key with individual parameters
682pub async fn {method_name}<P: Pool>(pool: &P, {all_params}) -> Result<u64> {{
683 Query::new("UPDATE `{table_name}` SET {set_clause} WHERE {where_clause}")
684{bind_section_update}
685{bind_section_pk}
686 .execute(pool)
687 .await
688 .map(|r| r.rows_affected)
689}}
690
691"#,
692 method_name = method_name,
693 all_params = all_params,
694 table_name = table.name,
695 set_clause = set_clause,
696 where_clause = where_clause,
697 bind_section_update = bind_section_update,
698 bind_section_pk = bind_section_pk,
699 )
700}
701
702fn collect_method_signatures(table: &TableMetadata) -> HashMap<Vec<String>, MethodSignature> {
704 let mut signatures: HashMap<Vec<String>, MethodSignature> = HashMap::new();
705
706 if let Some(pk) = &table.primary_key {
708 let sig = MethodSignature::new(
709 pk.columns.clone(),
710 PRIORITY_PRIMARY_KEY,
711 true,
712 "PRIMARY_KEY",
713 );
714 signatures.insert(pk.columns.clone(), sig);
715 }
716
717 for index in &table.indexes {
719 let priority = if index.unique {
720 PRIORITY_UNIQUE_INDEX
721 } else {
722 PRIORITY_NON_UNIQUE_INDEX
723 };
724 let source = if index.unique {
725 "UNIQUE_INDEX"
726 } else {
727 "NON_UNIQUE_INDEX"
728 };
729 let sig = MethodSignature::new(index.columns.clone(), priority, index.unique, source);
730
731 if let Some(existing) = signatures.get(&index.columns) {
733 if sig.priority < existing.priority {
734 signatures.insert(index.columns.clone(), sig);
735 }
736 } else {
737 signatures.insert(index.columns.clone(), sig);
738 }
739 }
740
741 for fk in &table.foreign_keys {
743 let columns = vec![fk.column_name.clone()];
744 let sig = MethodSignature::new(columns.clone(), PRIORITY_FOREIGN_KEY, false, "FOREIGN_KEY");
745
746 signatures.entry(columns).or_insert(sig);
747 }
748
749 signatures
750}
751
752fn generate_find_by_method(
754 table: &TableMetadata,
755 sig: &MethodSignature,
756 column_map: &HashMap<&str, &ColumnMetadata>,
757 struct_name: &str,
758 select_columns: &str,
759) -> String {
760 let params = build_params(&sig.columns, column_map, &table.name);
761
762 let (return_type, fetch_method) = if sig.is_unique {
763 (format!("Option<{}>", struct_name), "fetch_optional")
764 } else {
765 (format!("Vec<{}>", struct_name), "fetch_all")
766 };
767
768 let return_desc = if sig.is_unique {
769 "Option (unique)"
770 } else {
771 "Vec (non-unique)"
772 };
773
774 let has_nullable = sig.columns.iter().any(|c| {
776 column_map
777 .get(c.as_str())
778 .map(|col| col.nullable)
779 .unwrap_or(false)
780 });
781
782 if has_nullable {
783 generate_find_by_method_nullable(
785 table,
786 sig,
787 column_map,
788 select_columns,
789 ¶ms,
790 &return_type,
791 fetch_method,
792 return_desc,
793 )
794 } else {
795 let where_clause = build_where_clause(&sig.columns);
797 let bind_section = generate_bind_section(&sig.columns);
798
799 format!(
800 r#"/// Find by {source}: returns {return_desc}
801pub async fn {method_name}<P: Pool>(pool: &P, {params}) -> Result<{return_type}> {{
802 Query::new("SELECT {select_columns} FROM `{table_name}` WHERE {where_clause}")
803{bind_section}
804 .{fetch_method}(pool)
805 .await
806}}
807
808"#,
809 source = sig.source.to_lowercase().replace('_', " "),
810 return_desc = return_desc,
811 method_name = sig.method_name,
812 params = params,
813 return_type = return_type,
814 select_columns = select_columns,
815 table_name = table.name,
816 where_clause = where_clause,
817 bind_section = bind_section,
818 fetch_method = fetch_method,
819 )
820 }
821}
822
823#[allow(clippy::too_many_arguments)]
825fn generate_find_by_method_nullable(
826 table: &TableMetadata,
827 sig: &MethodSignature,
828 column_map: &HashMap<&str, &ColumnMetadata>,
829 select_columns: &str,
830 params: &str,
831 return_type: &str,
832 fetch_method: &str,
833 return_desc: &str,
834) -> String {
835 let mut where_parts = Vec::new();
837 let mut bind_parts = Vec::new();
838
839 for col in &sig.columns {
840 let col_meta = column_map.get(col.as_str()).unwrap();
841 let field_name = escape_field_name(col);
842
843 if col_meta.nullable {
844 where_parts.push(format!(
846 r#"if {field}.is_some() {{ "`{col}` = ?" }} else {{ "`{col}` IS NULL" }}"#,
847 field = field_name,
848 col = col,
849 ));
850 bind_parts.push(format!(
851 r#"if let Some(v) = {field}.as_ref() {{ query = query.bind(v); }}"#,
852 field = field_name,
853 ));
854 } else {
855 where_parts.push(format!(r#""`{}` = ?""#, col));
856 bind_parts.push(format!(r#"query = query.bind({});"#, field_name,));
857 }
858 }
859
860 let where_expr = if where_parts.len() == 1 {
861 where_parts[0].clone()
862 } else {
863 let parts = where_parts
865 .iter()
866 .map(|p| format!("({})", p))
867 .collect::<Vec<_>>()
868 .join(", ");
869 format!("vec![{}].join(\" AND \")", parts)
870 };
871
872 let bind_code = bind_parts.join("\n ");
873
874 format!(
875 r#"/// Find by {source}: returns {return_desc}
876pub async fn {method_name}<P: Pool>(pool: &P, {params}) -> Result<{return_type}> {{
877 let where_clause = {where_expr};
878 let sql = format!("SELECT {select_columns} FROM `{table_name}` WHERE {{}}", where_clause);
879 let mut query = rdbi::DynamicQuery::new(sql);
880 {bind_code}
881 query.{fetch_method}(pool).await
882}}
883
884"#,
885 source = sig.source.to_lowercase().replace('_', " "),
886 return_desc = return_desc,
887 method_name = sig.method_name,
888 params = params,
889 return_type = return_type,
890 select_columns = select_columns,
891 table_name = table.name,
892 where_expr = where_expr,
893 bind_code = bind_code,
894 fetch_method = fetch_method,
895 )
896}
897
898fn generate_find_by_list_methods(
900 table: &TableMetadata,
901 column_map: &HashMap<&str, &ColumnMetadata>,
902 struct_name: &str,
903 select_columns: &str,
904) -> String {
905 let mut code = String::new();
906 let mut processed: HashSet<String> = HashSet::new();
907
908 if let Some(pk) = &table.primary_key {
910 if pk.columns.len() == 1 {
911 let col = &pk.columns[0];
912 code.push_str(&generate_single_find_by_list(
913 table,
914 col,
915 column_map,
916 struct_name,
917 select_columns,
918 ));
919 processed.insert(col.clone());
920 }
921 }
922
923 for index in &table.indexes {
925 if index.columns.len() == 1 {
926 let col = &index.columns[0];
927 if !processed.contains(col) {
928 code.push_str(&generate_single_find_by_list(
929 table,
930 col,
931 column_map,
932 struct_name,
933 select_columns,
934 ));
935 processed.insert(col.clone());
936 }
937 }
938 }
939
940 code
941}
942
943fn generate_single_find_by_list(
945 table: &TableMetadata,
946 column_name: &str,
947 column_map: &HashMap<&str, &ColumnMetadata>,
948 struct_name: &str,
949 select_columns: &str,
950) -> String {
951 let method_name = generate_find_by_list_method_name(column_name);
952 let param_name = pluralize(&escape_field_name(column_name));
953 let column = column_map.get(column_name).unwrap();
954 let rust_type = TypeResolver::resolve(column, &table.name);
955
956 let inner_type = rust_type.inner_type().to_type_string();
958
959 let column_name_plural = pluralize(column_name);
960 format!(
961 r#"/// Find by list of {column_name_plural} (IN clause)
962pub async fn {method_name}<P: Pool>(pool: &P, {param_name}: &[{inner_type}]) -> Result<Vec<{struct_name}>> {{
963 if {param_name}.is_empty() {{
964 return Ok(Vec::new());
965 }}
966 let placeholders = {param_name}.iter().map(|_| "?").collect::<Vec<_>>().join(",");
967 let query = format!(
968 "SELECT {select_columns} FROM `{table_name}` WHERE `{column_name}` IN ({{}})",
969 placeholders
970 );
971 rdbi::DynamicQuery::new(query)
972 .bind_all({param_name})
973 .fetch_all(pool)
974 .await
975}}
976
977"#,
978 column_name_plural = column_name_plural,
979 column_name = column_name,
980 method_name = method_name,
981 param_name = param_name,
982 inner_type = inner_type,
983 struct_name = struct_name,
984 select_columns = select_columns,
985 table_name = table.name,
986 )
987}
988
989fn generate_composite_enum_list_methods(
992 table: &TableMetadata,
993 column_map: &HashMap<&str, &ColumnMetadata>,
994 struct_name: &str,
995 select_columns: &str,
996) -> String {
997 let mut code = String::new();
998
999 for index in &table.indexes {
1000 if index.columns.len() <= 1 {
1002 continue;
1003 }
1004
1005 let enum_columns: HashSet<&str> = index
1007 .columns
1008 .iter()
1009 .filter(|col_name| {
1010 column_map
1011 .get(col_name.as_str())
1012 .map(|col| col.is_enum())
1013 .unwrap_or(false)
1014 })
1015 .map(|s| s.as_str())
1016 .collect();
1017
1018 if enum_columns.is_empty() {
1020 continue;
1021 }
1022
1023 let first_column = &index.columns[0];
1025 if enum_columns.contains(first_column.as_str()) {
1026 continue;
1027 }
1028
1029 code.push_str(&generate_composite_enum_list_method(
1030 table,
1031 &index.columns,
1032 &enum_columns,
1033 column_map,
1034 struct_name,
1035 select_columns,
1036 ));
1037 }
1038
1039 code
1040}
1041
1042fn generate_composite_enum_list_method(
1044 table: &TableMetadata,
1045 columns: &[String],
1046 enum_columns: &HashSet<&str>,
1047 column_map: &HashMap<&str, &ColumnMetadata>,
1048 struct_name: &str,
1049 select_columns: &str,
1050) -> String {
1051 let method_name = generate_composite_enum_method_name(columns, enum_columns);
1053
1054 let mut params_parts = Vec::new();
1056
1057 for col_name in columns {
1058 let col = column_map.get(col_name.as_str()).unwrap();
1059 let rust_type = TypeResolver::resolve(col, &table.name);
1060 let is_enum = enum_columns.contains(col_name.as_str());
1061
1062 if is_enum {
1063 let param_name = pluralize(&escape_field_name(col_name));
1065 let inner_type = rust_type.inner_type().to_type_string();
1066 params_parts.push(format!("{}: &[{}]", param_name, inner_type));
1067 } else {
1068 let param_name = escape_field_name(col_name);
1070 let param_type = rust_type.to_param_type_string();
1071 params_parts.push(format!("{}: {}", param_name, param_type));
1072 }
1073 }
1074
1075 let params = params_parts.join(", ");
1076
1077 let where_clause_static: Vec<String> = columns
1079 .iter()
1080 .map(|col_name| {
1081 if enum_columns.contains(col_name.as_str()) {
1082 format!("`{}` IN ({{}})", col_name) } else {
1084 format!("`{}` = ?", col_name)
1085 }
1086 })
1087 .collect();
1088
1089 let mut bind_code = String::new();
1091
1092 for col_name in columns {
1094 if !enum_columns.contains(col_name.as_str()) {
1095 let param_name = escape_field_name(col_name);
1096 bind_code.push_str(&format!(" .bind({})\n", param_name));
1097 }
1098 }
1099
1100 for col_name in columns {
1102 if enum_columns.contains(col_name.as_str()) {
1103 let param_name = pluralize(&escape_field_name(col_name));
1104 bind_code.push_str(&format!(" .bind_all({})\n", param_name));
1105 }
1106 }
1107
1108 let column_desc: Vec<String> = columns
1110 .iter()
1111 .map(|col| {
1112 if enum_columns.contains(col.as_str()) {
1113 pluralize(col)
1114 } else {
1115 col.clone()
1116 }
1117 })
1118 .collect();
1119
1120 let enum_col_names: Vec<&str> = columns
1122 .iter()
1123 .filter(|c| enum_columns.contains(c.as_str()))
1124 .map(|s| s.as_str())
1125 .collect();
1126
1127 let in_clause_builders: Vec<String> = enum_col_names
1129 .iter()
1130 .map(|col| {
1131 let param_name = pluralize(&escape_field_name(col));
1132 format!(
1133 "{param_name}.iter().map(|_| \"?\").collect::<Vec<_>>().join(\",\")",
1134 param_name = param_name
1135 )
1136 })
1137 .collect();
1138
1139 let format_args = in_clause_builders.join(", ");
1141
1142 format!(
1143 r#"/// Find by {column_desc} (composite index with IN clause for enum columns)
1144pub async fn {method_name}<P: Pool>(pool: &P, {params}) -> Result<Vec<{struct_name}>> {{
1145 // Check for empty enum lists
1146{empty_checks}
1147 // Build IN clause placeholders for enum columns
1148 let where_clause = format!("{where_template}", {format_args});
1149 let query = format!(
1150 "SELECT {select_columns} FROM `{table_name}` WHERE {{}}",
1151 where_clause
1152 );
1153 rdbi::DynamicQuery::new(query)
1154{bind_code} .fetch_all(pool)
1155 .await
1156}}
1157
1158"#,
1159 column_desc = column_desc.join(" and "),
1160 method_name = method_name,
1161 params = params,
1162 struct_name = struct_name,
1163 select_columns = select_columns,
1164 table_name = table.name,
1165 where_template = where_clause_static.join(" AND "),
1166 format_args = format_args,
1167 bind_code = bind_code,
1168 empty_checks = generate_empty_checks(columns, enum_columns),
1169 )
1170}
1171
1172fn generate_empty_checks(columns: &[String], enum_columns: &HashSet<&str>) -> String {
1174 let mut checks = String::new();
1175 for col_name in columns {
1176 if enum_columns.contains(col_name.as_str()) {
1177 let param_name = pluralize(&escape_field_name(col_name));
1178 checks.push_str(&format!(
1179 " if {}.is_empty() {{ return Ok(Vec::new()); }}\n",
1180 param_name
1181 ));
1182 }
1183 }
1184 checks
1185}
1186
1187fn generate_composite_enum_method_name(columns: &[String], enum_columns: &HashSet<&str>) -> String {
1190 let mut parts = Vec::new();
1191 for col in columns {
1192 if enum_columns.contains(col.as_str()) {
1193 parts.push(pluralize(col));
1194 } else {
1195 parts.push(col.clone());
1196 }
1197 }
1198 generate_find_by_method_name(&parts)
1199}
1200
1201fn generate_pagination_methods(
1203 table: &TableMetadata,
1204 struct_name: &str,
1205 select_columns: &str,
1206 models_module: &str,
1207) -> String {
1208 let sort_by_enum = format!("{}SortBy", struct_name);
1209
1210 format!(
1211 r#"/// Find all records with pagination and sorting
1212pub async fn find_all_paginated<P: Pool>(
1213 pool: &P,
1214 limit: i32,
1215 offset: i32,
1216 sort_by: crate::{models_module}::{sort_by_enum},
1217 sort_dir: crate::{models_module}::SortDirection,
1218) -> Result<Vec<{struct_name}>> {{
1219 let order_clause = format!("{{}} {{}}", sort_by.as_sql(), sort_dir.as_sql());
1220 let query = format!(
1221 "SELECT {select_columns} FROM `{table_name}` ORDER BY {{}} LIMIT ? OFFSET ?",
1222 order_clause
1223 );
1224 rdbi::DynamicQuery::new(query)
1225 .bind(limit)
1226 .bind(offset)
1227 .fetch_all(pool)
1228 .await
1229}}
1230
1231/// Get paginated result with total count
1232pub async fn get_paginated_result<P: Pool>(
1233 pool: &P,
1234 page_size: i32,
1235 current_page: i32,
1236 sort_by: crate::{models_module}::{sort_by_enum},
1237 sort_dir: crate::{models_module}::SortDirection,
1238) -> Result<crate::{models_module}::PaginatedResult<{struct_name}>> {{
1239 let page_size = page_size.max(1);
1240 let current_page = current_page.max(1);
1241 let offset = (current_page - 1) * page_size;
1242
1243 let total_count = count_all(pool).await?;
1244 let items = find_all_paginated(pool, page_size, offset, sort_by, sort_dir).await?;
1245
1246 Ok(crate::{models_module}::PaginatedResult::new(
1247 items,
1248 total_count,
1249 current_page,
1250 page_size,
1251 ))
1252}}
1253
1254"#,
1255 struct_name = struct_name,
1256 select_columns = select_columns,
1257 table_name = table.name,
1258 models_module = models_module,
1259 sort_by_enum = sort_by_enum,
1260 )
1261}
1262
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::{IndexMetadata, PrimaryKey};

    // Fixture for the tests below. Table `users` has:
    // - an auto-increment `id` primary key;
    // - a unique index on `email`;
    // - a non-unique index on the enum column `status`.
    fn make_table() -> TableMetadata {
        TableMetadata {
            name: "users".to_string(),
            comment: None,
            columns: vec![
                ColumnMetadata {
                    name: "id".to_string(),
                    data_type: "BIGINT".to_string(),
                    nullable: false,
                    default_value: None,
                    is_auto_increment: true,
                    is_unsigned: false,
                    enum_values: None,
                    comment: None,
                },
                ColumnMetadata {
                    name: "email".to_string(),
                    data_type: "VARCHAR(255)".to_string(),
                    nullable: false,
                    default_value: None,
                    is_auto_increment: false,
                    is_unsigned: false,
                    enum_values: None,
                    comment: None,
                },
                ColumnMetadata {
                    name: "status".to_string(),
                    data_type: "ENUM".to_string(),
                    nullable: false,
                    default_value: None,
                    is_auto_increment: false,
                    is_unsigned: false,
                    enum_values: Some(vec!["ACTIVE".to_string(), "INACTIVE".to_string()]),
                    comment: None,
                },
            ],
            indexes: vec![
                IndexMetadata {
                    name: "email_unique".to_string(),
                    columns: vec!["email".to_string()],
                    unique: true,
                },
                IndexMetadata {
                    name: "idx_status".to_string(),
                    columns: vec!["status".to_string()],
                    unique: false,
                },
            ],
            foreign_keys: vec![],
            primary_key: Some(PrimaryKey {
                columns: vec!["id".to_string()],
            }),
        }
    }

    // One signature per distinct column list; uniqueness and priority follow
    // the source (PK, unique index, non-unique index).
    #[test]
    fn test_collect_method_signatures() {
        let table = make_table();
        let sigs = collect_method_signatures(&table);

        assert_eq!(sigs.len(), 3);

        let id_sig = sigs.get(&vec!["id".to_string()]).unwrap();
        assert!(id_sig.is_unique);
        assert_eq!(id_sig.priority, PRIORITY_PRIMARY_KEY);

        let email_sig = sigs.get(&vec!["email".to_string()]).unwrap();
        assert!(email_sig.is_unique);
        assert_eq!(email_sig.priority, PRIORITY_UNIQUE_INDEX);

        let status_sig = sigs.get(&vec!["status".to_string()]).unwrap();
        assert!(!status_sig.is_unique);
        assert_eq!(status_sig.priority, PRIORITY_NON_UNIQUE_INDEX);
    }

    // Every column appears backtick-quoted in the SELECT list.
    #[test]
    fn test_build_select_columns() {
        let table = make_table();
        let cols = build_select_columns(&table);
        assert!(cols.contains("`id`"));
        assert!(cols.contains("`email`"));
        assert!(cols.contains("`status`"));
    }

    // Single and multi-column equality clauses.
    #[test]
    fn test_build_where_clause() {
        let clause = build_where_clause(&["id".to_string()]);
        assert_eq!(clause, "`id` = ?");

        let clause = build_where_clause(&["user_id".to_string(), "role_id".to_string()]);
        assert_eq!(clause, "`user_id` = ? AND `role_id` = ?");
    }

    // The primary key `id` is excluded from the update list; the other columns
    // are refreshed from VALUES().
    #[test]
    fn test_generate_upsert_method() {
        let table = make_table();
        let code = generate_upsert_method(&table, "Users");

        assert!(code.contains("pub async fn upsert"));
        assert!(code.contains("ON DUPLICATE KEY UPDATE"));
        assert!(!code.contains("`id` = VALUES(`id`)"));
        assert!(code.contains("`email` = VALUES(`email`)"));
        assert!(code.contains("`status` = VALUES(`status`)"));
    }

    // With no primary key and no unique index, no upsert is generated.
    #[test]
    fn test_generate_upsert_method_no_pk() {
        let mut table = make_table();
        table.primary_key = None;
        table.indexes.clear();

        let code = generate_upsert_method(&table, "Users");
        assert!(code.is_empty());
    }

    // Batch insert delegates to rdbi::BatchInsert.
    #[test]
    fn test_generate_insert_all_method() {
        let table = make_table();
        let code = generate_insert_all_method(&table, "Users");

        assert!(code.contains("pub async fn insert_all"));
        assert!(code.contains("rdbi::BatchInsert::new"));
    }

    // Both pagination helpers are emitted with the expected parameter and
    // model types.
    #[test]
    fn test_generate_pagination_methods() {
        let table = make_table();
        let select_columns = build_select_columns(&table);
        let code = generate_pagination_methods(&table, "Users", &select_columns, "models");

        assert!(code.contains("pub async fn find_all_paginated"));
        assert!(code.contains("limit: i32"));
        assert!(code.contains("offset: i32"));
        assert!(code.contains("UsersSortBy"));
        assert!(code.contains("SortDirection"));
        assert!(code.contains("pub async fn get_paginated_result"));
        assert!(code.contains("PaginatedResult<Users>"));
    }
}