1use crate::error::Result;
4use std::collections::{HashMap, HashSet};
5use std::fs;
6use std::path::Path;
7use tracing::debug;
8
9use crate::config::CodegenConfig;
10use crate::parser::{ColumnMetadata, TableMetadata};
11
12use super::naming::{
13 escape_field_name, generate_delete_by_method_name, generate_find_by_list_method_name,
14 generate_find_by_method_name, generate_update_by_method_name, pluralize, to_struct_name,
15};
16use super::type_resolver::TypeResolver;
17
/// Priority of a finder derived from the primary key. When two sources yield the same
/// column list, the lower priority value is kept (see `collect_method_signatures`).
const PRIORITY_PRIMARY_KEY: u8 = 1;
/// Priority of a finder derived from a unique index.
const PRIORITY_UNIQUE_INDEX: u8 = 2;
/// Priority of a finder derived from a non-unique index.
const PRIORITY_NON_UNIQUE_INDEX: u8 = 3;
/// Priority of a finder derived from a single foreign-key column.
const PRIORITY_FOREIGN_KEY: u8 = 4;
23
24#[derive(Debug, Clone)]
26struct MethodSignature {
27 columns: Vec<String>,
28 method_name: String,
29 priority: u8,
30 is_unique: bool,
31 source: String,
32}
33
34impl MethodSignature {
35 fn new(columns: Vec<String>, priority: u8, is_unique: bool, source: &str) -> Self {
36 let method_name = generate_find_by_method_name(&columns);
37 Self {
38 columns,
39 method_name,
40 priority,
41 is_unique,
42 source: source.to_string(),
43 }
44 }
45}
46
47pub fn generate_daos(tables: &[TableMetadata], config: &CodegenConfig) -> Result<()> {
49 let output_dir = &config.output_dao_dir;
50 fs::create_dir_all(output_dir)?;
51
52 let mut mod_content = String::new();
54 mod_content.push_str("// Generated DAO functions\n\n");
55
56 for table in tables {
57 let file_name = heck::AsSnakeCase(&table.name).to_string();
58 mod_content.push_str(&format!("pub mod {};\n", file_name));
59 }
60
61 fs::write(output_dir.join("mod.rs"), mod_content)?;
62
63 for table in tables {
65 generate_dao_file(table, output_dir, &config.models_module)?;
66 }
67
68 Ok(())
69}
70
71fn generate_dao_file(table: &TableMetadata, output_dir: &Path, models_module: &str) -> Result<()> {
73 let struct_name = to_struct_name(&table.name);
74 let file_name = format!("{}.rs", heck::AsSnakeCase(&table.name));
75 debug!("Generating DAO for {} -> {}", struct_name, file_name);
76
77 let mut code = String::new();
78
79 let column_map: HashMap<&str, &ColumnMetadata> =
81 table.columns.iter().map(|c| (c.name.as_str(), c)).collect();
82
83 let mut needs_chrono = false;
85 let mut needs_decimal = false;
86
87 for col in &table.columns {
88 let rust_type = TypeResolver::resolve(col, &table.name);
89 if rust_type.needs_chrono() {
90 needs_chrono = true;
91 }
92 if rust_type.needs_decimal() {
93 needs_decimal = true;
94 }
95 }
96
97 let has_enums = table.columns.iter().any(|c| c.is_enum());
99
100 code.push_str("use rdbi::{Pool, Query, Result};\n");
102 code.push_str(&format!("use crate::{}::{};\n", models_module, struct_name));
103
104 if has_enums {
105 for col in &table.columns {
107 if col.is_enum() {
108 let enum_name = super::naming::to_enum_name(&table.name, &col.name);
109 code.push_str(&format!("use crate::{}::{};\n", models_module, enum_name));
110 }
111 }
112 }
113
114 if needs_chrono {
115 code.push_str("#[allow(unused_imports)]\n");
116 code.push_str("use chrono::{NaiveDate, NaiveDateTime, NaiveTime};\n");
117 }
118 if needs_decimal {
119 code.push_str("#[allow(unused_imports)]\n");
120 code.push_str("use rust_decimal::Decimal;\n");
121 }
122
123 code.push('\n');
124
125 let select_columns = build_select_columns(table);
127
128 code.push_str(&generate_find_all(table, &struct_name, &select_columns));
130
131 code.push_str(&generate_count_all(table));
133
134 if let Some(pk) = &table.primary_key {
136 code.push_str(&generate_pk_methods(
137 table,
138 pk,
139 &column_map,
140 &struct_name,
141 &select_columns,
142 ));
143 }
144
145 code.push_str(&generate_insert_methods(table, &struct_name, &column_map));
147
148 code.push_str(&generate_insert_plain_method(table, &column_map));
150
151 code.push_str(&generate_insert_all_method(table, &struct_name));
153
154 code.push_str(&generate_upsert_method(table, &struct_name));
156
157 if table.primary_key.is_some() {
159 code.push_str(&generate_update_methods(table, &struct_name, &column_map));
160 code.push_str(&generate_update_plain_method(table, &column_map));
162 }
163
164 let signatures = collect_method_signatures(table);
166 for sig in signatures.values() {
167 if sig.source == "PRIMARY_KEY" {
169 continue;
170 }
171 code.push_str(&generate_find_by_method(
172 table,
173 sig,
174 &column_map,
175 &struct_name,
176 &select_columns,
177 ));
178 }
179
180 code.push_str(&generate_find_by_list_methods(
182 table,
183 &column_map,
184 &struct_name,
185 &select_columns,
186 ));
187
188 code.push_str(&generate_composite_enum_list_methods(
190 table,
191 &column_map,
192 &struct_name,
193 &select_columns,
194 ));
195
196 code.push_str(&generate_pagination_methods(
198 table,
199 &struct_name,
200 &select_columns,
201 models_module,
202 ));
203
204 fs::write(output_dir.join(&file_name), code)?;
205 Ok(())
206}
207
208fn build_select_columns(table: &TableMetadata) -> String {
210 table
211 .columns
212 .iter()
213 .map(|c| format!("`{}`", c.name))
214 .collect::<Vec<_>>()
215 .join(", ")
216}
217
/// Builds an `AND`-joined equality predicate with one `?` placeholder per column, in the
/// given order (e.g. `` `a` = ? AND `b` = ? ``). An empty slice yields an empty string.
fn build_where_clause(columns: &[String]) -> String {
    let mut clause = String::new();
    for (idx, name) in columns.iter().enumerate() {
        if idx != 0 {
            clause.push_str(" AND ");
        }
        clause.push_str(&format!("`{}` = ?", name));
    }
    clause
}
226
227fn build_params(
229 columns: &[String],
230 column_map: &HashMap<&str, &ColumnMetadata>,
231 table_name: &str,
232) -> String {
233 columns
234 .iter()
235 .map(|c| {
236 let col = column_map.get(c.as_str()).unwrap();
237 let rust_type = TypeResolver::resolve(col, table_name);
238 let param_type = rust_type.to_param_type_string();
239 format!("{}: {}", escape_field_name(c), param_type)
240 })
241 .collect::<Vec<_>>()
242 .join(", ")
243}
244
245fn generate_bind_section(columns: &[String]) -> String {
247 columns
248 .iter()
249 .map(|c| format!(" .bind({})", escape_field_name(c)))
250 .collect::<Vec<_>>()
251 .join("\n")
252}
253
254fn generate_find_all(table: &TableMetadata, struct_name: &str, select_columns: &str) -> String {
256 format!(
257 r#"/// Find all records
258pub async fn find_all<P: Pool>(pool: &P) -> Result<Vec<{struct_name}>> {{
259 Query::new("SELECT {select_columns} FROM `{table_name}`")
260 .fetch_all(pool)
261 .await
262}}
263
264"#,
265 struct_name = struct_name,
266 select_columns = select_columns,
267 table_name = table.name,
268 )
269}
270
271fn generate_count_all(table: &TableMetadata) -> String {
273 format!(
274 r#"/// Count all records
275pub async fn count_all<P: Pool>(pool: &P) -> Result<i64> {{
276 Query::new("SELECT COUNT(*) FROM `{table_name}`")
277 .fetch_scalar(pool)
278 .await
279}}
280
281"#,
282 table_name = table.name,
283 )
284}
285
286fn generate_pk_methods(
288 table: &TableMetadata,
289 pk: &crate::parser::PrimaryKey,
290 column_map: &HashMap<&str, &ColumnMetadata>,
291 struct_name: &str,
292 select_columns: &str,
293) -> String {
294 let mut code = String::new();
295
296 let method_name = generate_find_by_method_name(&pk.columns);
297 let params = build_params(&pk.columns, column_map, &table.name);
298 let where_clause = build_where_clause(&pk.columns);
299 let bind_section = generate_bind_section(&pk.columns);
300
301 code.push_str(&format!(
303 r#"/// Find by primary key
304pub async fn {method_name}<P: Pool>(pool: &P, {params}) -> Result<Option<{struct_name}>> {{
305 Query::new("SELECT {select_columns} FROM `{table_name}` WHERE {where_clause}")
306{bind_section}
307 .fetch_optional(pool)
308 .await
309}}
310
311"#,
312 method_name = method_name,
313 params = params,
314 struct_name = struct_name,
315 select_columns = select_columns,
316 table_name = table.name,
317 where_clause = where_clause,
318 bind_section = bind_section,
319 ));
320
321 let delete_method = generate_delete_by_method_name(&pk.columns);
323 code.push_str(&format!(
324 r#"/// Delete by primary key
325pub async fn {delete_method}<P: Pool>(pool: &P, {params}) -> Result<u64> {{
326 Query::new("DELETE FROM `{table_name}` WHERE {where_clause}")
327{bind_section}
328 .execute(pool)
329 .await
330 .map(|r| r.rows_affected)
331}}
332
333"#,
334 delete_method = delete_method,
335 params = params,
336 table_name = table.name,
337 where_clause = where_clause,
338 bind_section = bind_section,
339 ));
340
341 code
342}
343
344fn generate_insert_methods(
346 table: &TableMetadata,
347 struct_name: &str,
348 _column_map: &HashMap<&str, &ColumnMetadata>,
349) -> String {
350 let mut code = String::new();
351
352 let insert_columns: Vec<&ColumnMetadata> = table
354 .columns
355 .iter()
356 .filter(|c| !c.is_auto_increment)
357 .collect();
358
359 if insert_columns.is_empty() {
360 return code;
361 }
362
363 let column_list = insert_columns
364 .iter()
365 .map(|c| format!("`{}`", c.name))
366 .collect::<Vec<_>>()
367 .join(", ");
368
369 let placeholders = insert_columns
370 .iter()
371 .map(|_| "?")
372 .collect::<Vec<_>>()
373 .join(", ");
374
375 let bind_fields = insert_columns
376 .iter()
377 .map(|c| {
378 let field = escape_field_name(&c.name);
379 let rust_type = TypeResolver::resolve(c, &table.name);
380 if rust_type.is_copy() {
381 format!(" .bind(entity.{})", field)
382 } else {
383 format!(" .bind(&entity.{})", field)
384 }
385 })
386 .collect::<Vec<_>>()
387 .join("\n");
388
389 code.push_str(&format!(
391 r#"/// Insert a new record
392pub async fn insert<P: Pool>(pool: &P, entity: &{struct_name}) -> Result<u64> {{
393 Query::new("INSERT INTO `{table_name}` ({column_list}) VALUES ({placeholders})")
394{bind_fields}
395 .execute(pool)
396 .await
397 .map(|r| r.last_insert_id.unwrap_or(0))
398}}
399
400"#,
401 struct_name = struct_name,
402 table_name = table.name,
403 column_list = column_list,
404 placeholders = placeholders,
405 bind_fields = bind_fields,
406 ));
407
408 code
409}
410
411fn generate_insert_plain_method(
413 table: &TableMetadata,
414 column_map: &HashMap<&str, &ColumnMetadata>,
415) -> String {
416 let insert_columns: Vec<&ColumnMetadata> = table
418 .columns
419 .iter()
420 .filter(|c| !c.is_auto_increment)
421 .collect();
422
423 if insert_columns.is_empty() {
424 return String::new();
425 }
426
427 let column_names: Vec<String> = insert_columns.iter().map(|c| c.name.clone()).collect();
428 let params = build_params(&column_names, column_map, &table.name);
429
430 let column_list = insert_columns
431 .iter()
432 .map(|c| format!("`{}`", c.name))
433 .collect::<Vec<_>>()
434 .join(", ");
435
436 let placeholders = insert_columns
437 .iter()
438 .map(|_| "?")
439 .collect::<Vec<_>>()
440 .join(", ");
441
442 let bind_section = generate_bind_section(&column_names);
443
444 format!(
445 r#"/// Insert a new record with individual parameters
446#[allow(clippy::too_many_arguments)]
447pub async fn insert_plain<P: Pool>(pool: &P, {params}) -> Result<u64> {{
448 Query::new("INSERT INTO `{table_name}` ({column_list}) VALUES ({placeholders})")
449{bind_section}
450 .execute(pool)
451 .await
452 .map(|r| r.last_insert_id.unwrap_or(0))
453}}
454
455"#,
456 params = params,
457 table_name = table.name,
458 column_list = column_list,
459 placeholders = placeholders,
460 bind_section = bind_section,
461 )
462}
463
464fn generate_insert_all_method(table: &TableMetadata, struct_name: &str) -> String {
466 let insert_columns: Vec<&ColumnMetadata> = table
468 .columns
469 .iter()
470 .filter(|c| !c.is_auto_increment)
471 .collect();
472
473 if insert_columns.is_empty() {
474 return String::new();
475 }
476
477 format!(
478 r#"/// Insert multiple records in a single batch
479pub async fn insert_all<P: Pool>(pool: &P, entities: &[{struct_name}]) -> Result<u64> {{
480 rdbi::BatchInsert::new("{table_name}", entities)
481 .execute(pool)
482 .await
483 .map(|r| r.rows_affected)
484}}
485
486"#,
487 struct_name = struct_name,
488 table_name = table.name,
489 )
490}
491
492fn has_unique_index(table: &TableMetadata) -> bool {
494 table.indexes.iter().any(|idx| idx.unique)
495}
496
497fn generate_upsert_method(table: &TableMetadata, struct_name: &str) -> String {
499 if table.primary_key.is_none() && !has_unique_index(table) {
501 return String::new();
502 }
503
504 let insert_columns: Vec<&ColumnMetadata> = table
506 .columns
507 .iter()
508 .filter(|c| !c.is_auto_increment)
509 .collect();
510
511 if insert_columns.is_empty() {
512 return String::new();
513 }
514
515 let pk_columns: HashSet<&str> = table
517 .primary_key
518 .as_ref()
519 .map(|pk| pk.columns.iter().map(|s| s.as_str()).collect())
520 .unwrap_or_default();
521
522 let update_columns: Vec<&ColumnMetadata> = insert_columns
524 .iter()
525 .filter(|c| !pk_columns.contains(c.name.as_str()))
526 .copied()
527 .collect();
528
529 if update_columns.is_empty() {
531 return String::new();
532 }
533
534 let column_list = insert_columns
535 .iter()
536 .map(|c| format!("`{}`", c.name))
537 .collect::<Vec<_>>()
538 .join(", ");
539
540 let placeholders = insert_columns
541 .iter()
542 .map(|_| "?")
543 .collect::<Vec<_>>()
544 .join(", ");
545
546 let update_clause = update_columns
547 .iter()
548 .map(|c| format!("`{name}` = VALUES(`{name}`)", name = c.name))
549 .collect::<Vec<_>>()
550 .join(", ");
551
552 let bind_fields = insert_columns
553 .iter()
554 .map(|c| {
555 let field = escape_field_name(&c.name);
556 let rust_type = TypeResolver::resolve(c, &table.name);
557 if rust_type.is_copy() {
558 format!(" .bind(entity.{})", field)
559 } else {
560 format!(" .bind(&entity.{})", field)
561 }
562 })
563 .collect::<Vec<_>>()
564 .join("\n");
565
566 format!(
567 r#"/// Upsert a record (insert or update on duplicate key)
568/// Returns rows_affected: 1 if inserted, 2 if updated
569pub async fn upsert<P: Pool>(pool: &P, entity: &{struct_name}) -> Result<u64> {{
570 Query::new("INSERT INTO `{table_name}` ({column_list}) VALUES ({placeholders}) \
571 ON DUPLICATE KEY UPDATE {update_clause}")
572{bind_fields}
573 .execute(pool)
574 .await
575 .map(|r| r.rows_affected)
576}}
577
578"#,
579 struct_name = struct_name,
580 table_name = table.name,
581 column_list = column_list,
582 placeholders = placeholders,
583 update_clause = update_clause,
584 bind_fields = bind_fields,
585 )
586}
587
588fn generate_update_methods(
590 table: &TableMetadata,
591 struct_name: &str,
592 column_map: &HashMap<&str, &ColumnMetadata>,
593) -> String {
594 let mut code = String::new();
595
596 let pk = table.primary_key.as_ref().unwrap();
597
598 let update_columns: Vec<&ColumnMetadata> = table
600 .columns
601 .iter()
602 .filter(|c| !pk.columns.contains(&c.name))
603 .collect();
604
605 if update_columns.is_empty() {
606 return code;
607 }
608
609 let set_clause = update_columns
610 .iter()
611 .map(|c| format!("`{}` = ?", c.name))
612 .collect::<Vec<_>>()
613 .join(", ");
614
615 let where_clause = build_where_clause(&pk.columns);
616
617 let bind_fields: Vec<String> = update_columns
619 .iter()
620 .map(|c| {
621 let field = escape_field_name(&c.name);
622 let rust_type = TypeResolver::resolve(c, &table.name);
623 if rust_type.is_copy() {
624 format!(" .bind(entity.{})", field)
625 } else {
626 format!(" .bind(&entity.{})", field)
627 }
628 })
629 .chain(pk.columns.iter().map(|c| {
630 let field = escape_field_name(c);
631 let col = column_map.get(c.as_str()).unwrap();
632 let rust_type = TypeResolver::resolve(col, &table.name);
633 if rust_type.is_copy() {
634 format!(" .bind(entity.{})", field)
635 } else {
636 format!(" .bind(&entity.{})", field)
637 }
638 }))
639 .collect();
640
641 code.push_str(&format!(
643 r#"/// Update a record by primary key
644pub async fn update<P: Pool>(pool: &P, entity: &{struct_name}) -> Result<u64> {{
645 Query::new("UPDATE `{table_name}` SET {set_clause} WHERE {where_clause}")
646{bind_fields}
647 .execute(pool)
648 .await
649 .map(|r| r.rows_affected)
650}}
651
652"#,
653 struct_name = struct_name,
654 table_name = table.name,
655 set_clause = set_clause,
656 where_clause = where_clause,
657 bind_fields = bind_fields.join("\n"),
658 ));
659
660 code
661}
662
663fn generate_update_plain_method(
665 table: &TableMetadata,
666 column_map: &HashMap<&str, &ColumnMetadata>,
667) -> String {
668 let pk = table.primary_key.as_ref().unwrap();
669
670 let update_columns: Vec<&ColumnMetadata> = table
672 .columns
673 .iter()
674 .filter(|c| !pk.columns.contains(&c.name))
675 .collect();
676
677 if update_columns.is_empty() {
678 return String::new();
679 }
680
681 let method_name = generate_update_by_method_name(&pk.columns);
683
684 let pk_params = build_params(&pk.columns, column_map, &table.name);
686 let update_column_names: Vec<String> = update_columns.iter().map(|c| c.name.clone()).collect();
687 let update_params = build_params(&update_column_names, column_map, &table.name);
688 let all_params = format!("{}, {}", pk_params, update_params);
689
690 let set_clause = update_columns
691 .iter()
692 .map(|c| format!("`{}` = ?", c.name))
693 .collect::<Vec<_>>()
694 .join(", ");
695
696 let where_clause = build_where_clause(&pk.columns);
697
698 let bind_section_update = generate_bind_section(&update_column_names);
700 let bind_section_pk = generate_bind_section(&pk.columns);
701
702 format!(
703 r#"/// Update a record by primary key with individual parameters
704#[allow(clippy::too_many_arguments)]
705pub async fn {method_name}<P: Pool>(pool: &P, {all_params}) -> Result<u64> {{
706 Query::new("UPDATE `{table_name}` SET {set_clause} WHERE {where_clause}")
707{bind_section_update}
708{bind_section_pk}
709 .execute(pool)
710 .await
711 .map(|r| r.rows_affected)
712}}
713
714"#,
715 method_name = method_name,
716 all_params = all_params,
717 table_name = table.name,
718 set_clause = set_clause,
719 where_clause = where_clause,
720 bind_section_update = bind_section_update,
721 bind_section_pk = bind_section_pk,
722 )
723}
724
725fn collect_method_signatures(table: &TableMetadata) -> HashMap<Vec<String>, MethodSignature> {
727 let mut signatures: HashMap<Vec<String>, MethodSignature> = HashMap::new();
728
729 if let Some(pk) = &table.primary_key {
731 let sig = MethodSignature::new(
732 pk.columns.clone(),
733 PRIORITY_PRIMARY_KEY,
734 true,
735 "PRIMARY_KEY",
736 );
737 signatures.insert(pk.columns.clone(), sig);
738 }
739
740 for index in &table.indexes {
742 let priority = if index.unique {
743 PRIORITY_UNIQUE_INDEX
744 } else {
745 PRIORITY_NON_UNIQUE_INDEX
746 };
747 let source = if index.unique {
748 "UNIQUE_INDEX"
749 } else {
750 "NON_UNIQUE_INDEX"
751 };
752 let sig = MethodSignature::new(index.columns.clone(), priority, index.unique, source);
753
754 if let Some(existing) = signatures.get(&index.columns) {
756 if sig.priority < existing.priority {
757 signatures.insert(index.columns.clone(), sig);
758 }
759 } else {
760 signatures.insert(index.columns.clone(), sig);
761 }
762 }
763
764 for fk in &table.foreign_keys {
766 let columns = vec![fk.column_name.clone()];
767 let sig = MethodSignature::new(columns.clone(), PRIORITY_FOREIGN_KEY, false, "FOREIGN_KEY");
768
769 signatures.entry(columns).or_insert(sig);
770 }
771
772 signatures
773}
774
775fn generate_find_by_method(
777 table: &TableMetadata,
778 sig: &MethodSignature,
779 column_map: &HashMap<&str, &ColumnMetadata>,
780 struct_name: &str,
781 select_columns: &str,
782) -> String {
783 let params = build_params(&sig.columns, column_map, &table.name);
784
785 let (return_type, fetch_method) = if sig.is_unique {
786 (format!("Option<{}>", struct_name), "fetch_optional")
787 } else {
788 (format!("Vec<{}>", struct_name), "fetch_all")
789 };
790
791 let return_desc = if sig.is_unique {
792 "Option (unique)"
793 } else {
794 "Vec (non-unique)"
795 };
796
797 let has_nullable = sig.columns.iter().any(|c| {
799 column_map
800 .get(c.as_str())
801 .map(|col| col.nullable)
802 .unwrap_or(false)
803 });
804
805 if has_nullable {
806 generate_find_by_method_nullable(
808 table,
809 sig,
810 column_map,
811 select_columns,
812 ¶ms,
813 &return_type,
814 fetch_method,
815 return_desc,
816 )
817 } else {
818 let where_clause = build_where_clause(&sig.columns);
820 let bind_section = generate_bind_section(&sig.columns);
821
822 format!(
823 r#"/// Find by {source}: returns {return_desc}
824pub async fn {method_name}<P: Pool>(pool: &P, {params}) -> Result<{return_type}> {{
825 Query::new("SELECT {select_columns} FROM `{table_name}` WHERE {where_clause}")
826{bind_section}
827 .{fetch_method}(pool)
828 .await
829}}
830
831"#,
832 source = sig.source.to_lowercase().replace('_', " "),
833 return_desc = return_desc,
834 method_name = sig.method_name,
835 params = params,
836 return_type = return_type,
837 select_columns = select_columns,
838 table_name = table.name,
839 where_clause = where_clause,
840 bind_section = bind_section,
841 fetch_method = fetch_method,
842 )
843 }
844}
845
846#[allow(clippy::too_many_arguments)]
848fn generate_find_by_method_nullable(
849 table: &TableMetadata,
850 sig: &MethodSignature,
851 column_map: &HashMap<&str, &ColumnMetadata>,
852 select_columns: &str,
853 params: &str,
854 return_type: &str,
855 fetch_method: &str,
856 return_desc: &str,
857) -> String {
858 let mut where_parts = Vec::new();
860 let mut bind_parts = Vec::new();
861
862 for col in &sig.columns {
863 let col_meta = column_map.get(col.as_str()).unwrap();
864 let field_name = escape_field_name(col);
865
866 if col_meta.nullable {
867 where_parts.push(format!(
869 r#"if {field}.is_some() {{ "`{col}` = ?" }} else {{ "`{col}` IS NULL" }}"#,
870 field = field_name,
871 col = col,
872 ));
873 bind_parts.push(format!(
874 r#"if let Some(v) = {field}.as_ref() {{ query = query.bind(v); }}"#,
875 field = field_name,
876 ));
877 } else {
878 where_parts.push(format!(r#""`{}` = ?""#, col));
879 bind_parts.push(format!(r#"query = query.bind({});"#, field_name,));
880 }
881 }
882
883 let where_expr = if where_parts.len() == 1 {
884 where_parts[0].clone()
885 } else {
886 let parts = where_parts
888 .iter()
889 .map(|p| format!("({})", p))
890 .collect::<Vec<_>>()
891 .join(", ");
892 format!("vec![{}].join(\" AND \")", parts)
893 };
894
895 let bind_code = bind_parts.join("\n ");
896
897 format!(
898 r#"/// Find by {source}: returns {return_desc}
899pub async fn {method_name}<P: Pool>(pool: &P, {params}) -> Result<{return_type}> {{
900 let where_clause = {where_expr};
901 let sql = format!("SELECT {select_columns} FROM `{table_name}` WHERE {{}}", where_clause);
902 let mut query = rdbi::DynamicQuery::new(sql);
903 {bind_code}
904 query.{fetch_method}(pool).await
905}}
906
907"#,
908 source = sig.source.to_lowercase().replace('_', " "),
909 return_desc = return_desc,
910 method_name = sig.method_name,
911 params = params,
912 return_type = return_type,
913 select_columns = select_columns,
914 table_name = table.name,
915 where_expr = where_expr,
916 bind_code = bind_code,
917 fetch_method = fetch_method,
918 )
919}
920
921fn generate_find_by_list_methods(
923 table: &TableMetadata,
924 column_map: &HashMap<&str, &ColumnMetadata>,
925 struct_name: &str,
926 select_columns: &str,
927) -> String {
928 let mut code = String::new();
929 let mut processed: HashSet<String> = HashSet::new();
930
931 if let Some(pk) = &table.primary_key {
933 if pk.columns.len() == 1 {
934 let col = &pk.columns[0];
935 code.push_str(&generate_single_find_by_list(
936 table,
937 col,
938 column_map,
939 struct_name,
940 select_columns,
941 ));
942 processed.insert(col.clone());
943 }
944 }
945
946 for index in &table.indexes {
948 if index.columns.len() == 1 {
949 let col = &index.columns[0];
950 if !processed.contains(col) {
951 code.push_str(&generate_single_find_by_list(
952 table,
953 col,
954 column_map,
955 struct_name,
956 select_columns,
957 ));
958 processed.insert(col.clone());
959 }
960 }
961 }
962
963 code
964}
965
966fn generate_single_find_by_list(
968 table: &TableMetadata,
969 column_name: &str,
970 column_map: &HashMap<&str, &ColumnMetadata>,
971 struct_name: &str,
972 select_columns: &str,
973) -> String {
974 let method_name = generate_find_by_list_method_name(column_name);
975 let param_name = pluralize(&escape_field_name(column_name));
976 let column = column_map.get(column_name).unwrap();
977 let rust_type = TypeResolver::resolve(column, &table.name);
978
979 let inner_type = rust_type.inner_type().to_type_string();
981
982 let column_name_plural = pluralize(column_name);
983 format!(
984 r#"/// Find by list of {column_name_plural} (IN clause)
985pub async fn {method_name}<P: Pool>(pool: &P, {param_name}: &[{inner_type}]) -> Result<Vec<{struct_name}>> {{
986 if {param_name}.is_empty() {{
987 return Ok(Vec::new());
988 }}
989 let placeholders = {param_name}.iter().map(|_| "?").collect::<Vec<_>>().join(",");
990 let query = format!(
991 "SELECT {select_columns} FROM `{table_name}` WHERE `{column_name}` IN ({{}})",
992 placeholders
993 );
994 rdbi::DynamicQuery::new(query)
995 .bind_all({param_name})
996 .fetch_all(pool)
997 .await
998}}
999
1000"#,
1001 column_name_plural = column_name_plural,
1002 column_name = column_name,
1003 method_name = method_name,
1004 param_name = param_name,
1005 inner_type = inner_type,
1006 struct_name = struct_name,
1007 select_columns = select_columns,
1008 table_name = table.name,
1009 )
1010}
1011
1012fn generate_composite_enum_list_methods(
1015 table: &TableMetadata,
1016 column_map: &HashMap<&str, &ColumnMetadata>,
1017 struct_name: &str,
1018 select_columns: &str,
1019) -> String {
1020 let mut code = String::new();
1021
1022 for index in &table.indexes {
1023 if index.columns.len() <= 1 {
1025 continue;
1026 }
1027
1028 let enum_columns: HashSet<&str> = index
1030 .columns
1031 .iter()
1032 .filter(|col_name| {
1033 column_map
1034 .get(col_name.as_str())
1035 .map(|col| col.is_enum())
1036 .unwrap_or(false)
1037 })
1038 .map(|s| s.as_str())
1039 .collect();
1040
1041 if enum_columns.is_empty() {
1043 continue;
1044 }
1045
1046 let first_column = &index.columns[0];
1048 if enum_columns.contains(first_column.as_str()) {
1049 continue;
1050 }
1051
1052 code.push_str(&generate_composite_enum_list_method(
1053 table,
1054 &index.columns,
1055 &enum_columns,
1056 column_map,
1057 struct_name,
1058 select_columns,
1059 ));
1060 }
1061
1062 code
1063}
1064
1065fn generate_composite_enum_list_method(
1067 table: &TableMetadata,
1068 columns: &[String],
1069 enum_columns: &HashSet<&str>,
1070 column_map: &HashMap<&str, &ColumnMetadata>,
1071 struct_name: &str,
1072 select_columns: &str,
1073) -> String {
1074 let method_name = generate_composite_enum_method_name(columns, enum_columns);
1076
1077 let mut params_parts = Vec::new();
1079
1080 for col_name in columns {
1081 let col = column_map.get(col_name.as_str()).unwrap();
1082 let rust_type = TypeResolver::resolve(col, &table.name);
1083 let is_enum = enum_columns.contains(col_name.as_str());
1084
1085 if is_enum {
1086 let param_name = pluralize(&escape_field_name(col_name));
1088 let inner_type = rust_type.inner_type().to_type_string();
1089 params_parts.push(format!("{}: &[{}]", param_name, inner_type));
1090 } else {
1091 let param_name = escape_field_name(col_name);
1093 let param_type = rust_type.to_param_type_string();
1094 params_parts.push(format!("{}: {}", param_name, param_type));
1095 }
1096 }
1097
1098 let params = params_parts.join(", ");
1099
1100 let where_clause_static: Vec<String> = columns
1102 .iter()
1103 .map(|col_name| {
1104 if enum_columns.contains(col_name.as_str()) {
1105 format!("`{}` IN ({{}})", col_name) } else {
1107 format!("`{}` = ?", col_name)
1108 }
1109 })
1110 .collect();
1111
1112 let mut bind_code = String::new();
1114
1115 for col_name in columns {
1117 if !enum_columns.contains(col_name.as_str()) {
1118 let param_name = escape_field_name(col_name);
1119 bind_code.push_str(&format!(" .bind({})\n", param_name));
1120 }
1121 }
1122
1123 for col_name in columns {
1125 if enum_columns.contains(col_name.as_str()) {
1126 let param_name = pluralize(&escape_field_name(col_name));
1127 bind_code.push_str(&format!(" .bind_all({})\n", param_name));
1128 }
1129 }
1130
1131 let column_desc: Vec<String> = columns
1133 .iter()
1134 .map(|col| {
1135 if enum_columns.contains(col.as_str()) {
1136 pluralize(col)
1137 } else {
1138 col.clone()
1139 }
1140 })
1141 .collect();
1142
1143 let enum_col_names: Vec<&str> = columns
1145 .iter()
1146 .filter(|c| enum_columns.contains(c.as_str()))
1147 .map(|s| s.as_str())
1148 .collect();
1149
1150 let in_clause_builders: Vec<String> = enum_col_names
1152 .iter()
1153 .map(|col| {
1154 let param_name = pluralize(&escape_field_name(col));
1155 format!(
1156 "{param_name}.iter().map(|_| \"?\").collect::<Vec<_>>().join(\",\")",
1157 param_name = param_name
1158 )
1159 })
1160 .collect();
1161
1162 let format_args = in_clause_builders.join(", ");
1164
1165 format!(
1166 r#"/// Find by {column_desc} (composite index with IN clause for enum columns)
1167pub async fn {method_name}<P: Pool>(pool: &P, {params}) -> Result<Vec<{struct_name}>> {{
1168 // Check for empty enum lists
1169{empty_checks}
1170 // Build IN clause placeholders for enum columns
1171 let where_clause = format!("{where_template}", {format_args});
1172 let query = format!(
1173 "SELECT {select_columns} FROM `{table_name}` WHERE {{}}",
1174 where_clause
1175 );
1176 rdbi::DynamicQuery::new(query)
1177{bind_code} .fetch_all(pool)
1178 .await
1179}}
1180
1181"#,
1182 column_desc = column_desc.join(" and "),
1183 method_name = method_name,
1184 params = params,
1185 struct_name = struct_name,
1186 select_columns = select_columns,
1187 table_name = table.name,
1188 where_template = where_clause_static.join(" AND "),
1189 format_args = format_args,
1190 bind_code = bind_code,
1191 empty_checks = generate_empty_checks(columns, enum_columns),
1192 )
1193}
1194
1195fn generate_empty_checks(columns: &[String], enum_columns: &HashSet<&str>) -> String {
1197 let mut checks = String::new();
1198 for col_name in columns {
1199 if enum_columns.contains(col_name.as_str()) {
1200 let param_name = pluralize(&escape_field_name(col_name));
1201 checks.push_str(&format!(
1202 " if {}.is_empty() {{ return Ok(Vec::new()); }}\n",
1203 param_name
1204 ));
1205 }
1206 }
1207 checks
1208}
1209
1210fn generate_composite_enum_method_name(columns: &[String], enum_columns: &HashSet<&str>) -> String {
1213 let mut parts = Vec::new();
1214 for col in columns {
1215 if enum_columns.contains(col.as_str()) {
1216 parts.push(pluralize(col));
1217 } else {
1218 parts.push(col.clone());
1219 }
1220 }
1221 generate_find_by_method_name(&parts)
1222}
1223
1224fn generate_pagination_methods(
1226 table: &TableMetadata,
1227 struct_name: &str,
1228 select_columns: &str,
1229 models_module: &str,
1230) -> String {
1231 let sort_by_enum = format!("{}SortBy", struct_name);
1232
1233 format!(
1234 r#"/// Find all records with pagination and sorting
1235pub async fn find_all_paginated<P: Pool>(
1236 pool: &P,
1237 limit: i32,
1238 offset: i32,
1239 sort_by: crate::{models_module}::{sort_by_enum},
1240 sort_dir: crate::{models_module}::SortDirection,
1241) -> Result<Vec<{struct_name}>> {{
1242 let order_clause = format!("{{}} {{}}", sort_by.as_sql(), sort_dir.as_sql());
1243 let query = format!(
1244 "SELECT {select_columns} FROM `{table_name}` ORDER BY {{}} LIMIT ? OFFSET ?",
1245 order_clause
1246 );
1247 rdbi::DynamicQuery::new(query)
1248 .bind(limit)
1249 .bind(offset)
1250 .fetch_all(pool)
1251 .await
1252}}
1253
1254/// Get paginated result with total count
1255pub async fn get_paginated_result<P: Pool>(
1256 pool: &P,
1257 page_size: i32,
1258 current_page: i32,
1259 sort_by: crate::{models_module}::{sort_by_enum},
1260 sort_dir: crate::{models_module}::SortDirection,
1261) -> Result<crate::{models_module}::PaginatedResult<{struct_name}>> {{
1262 let page_size = page_size.max(1);
1263 let current_page = current_page.max(1);
1264 let offset = (current_page - 1) * page_size;
1265
1266 let total_count = count_all(pool).await?;
1267 let items = find_all_paginated(pool, page_size, offset, sort_by, sort_dir).await?;
1268
1269 Ok(crate::{models_module}::PaginatedResult::new(
1270 items,
1271 total_count,
1272 current_page,
1273 page_size,
1274 ))
1275}}
1276
1277"#,
1278 struct_name = struct_name,
1279 select_columns = select_columns,
1280 table_name = table.name,
1281 models_module = models_module,
1282 sort_by_enum = sort_by_enum,
1283 )
1284}
1285
// Unit tests for the DAO generator helpers, all driven by a small `users` fixture table.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::{IndexMetadata, PrimaryKey};

    /// Converts borrowed names into the owned `Vec<String>` form the metadata types use.
    fn owned_strings(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    /// Builds a NOT NULL, signed column with no default and no comment.
    /// An empty `enum_values` slice means the column carries no ENUM values.
    fn column_fixture(
        name: &str,
        data_type: &str,
        is_auto_increment: bool,
        enum_values: &[&str],
    ) -> ColumnMetadata {
        ColumnMetadata {
            name: name.to_string(),
            data_type: data_type.to_string(),
            nullable: false,
            default_value: None,
            is_auto_increment,
            is_unsigned: false,
            enum_values: if enum_values.is_empty() {
                None
            } else {
                Some(owned_strings(enum_values))
            },
            comment: None,
        }
    }

    /// Builds an index named `name` over `columns`.
    fn index_fixture(name: &str, columns: &[&str], unique: bool) -> IndexMetadata {
        IndexMetadata {
            name: name.to_string(),
            columns: owned_strings(columns),
            unique,
        }
    }

    /// Fixture `users` table:
    /// - primary key `id` (auto-increment BIGINT);
    /// - unique index `email_unique` on `email`;
    /// - non-unique index `idx_status` on the ENUM column `status`.
    fn make_table() -> TableMetadata {
        TableMetadata {
            name: "users".to_string(),
            comment: None,
            columns: vec![
                column_fixture("id", "BIGINT", true, &[]),
                column_fixture("email", "VARCHAR(255)", false, &[]),
                column_fixture("status", "ENUM", false, &["ACTIVE", "INACTIVE"]),
            ],
            indexes: vec![
                index_fixture("email_unique", &["email"], true),
                index_fixture("idx_status", &["status"], false),
            ],
            foreign_keys: vec![],
            primary_key: Some(PrimaryKey {
                columns: owned_strings(&["id"]),
            }),
        }
    }

    #[test]
    fn test_collect_method_signatures() {
        let table = make_table();
        let signatures = collect_method_signatures(&table);

        assert_eq!(signatures.len(), 3);

        // (column, expected uniqueness, expected priority)
        let expectations = [
            ("id", true, PRIORITY_PRIMARY_KEY),
            ("email", true, PRIORITY_UNIQUE_INDEX),
            ("status", false, PRIORITY_NON_UNIQUE_INDEX),
        ];
        for (col_name, unique, priority) in expectations {
            let signature = signatures
                .get(&owned_strings(&[col_name]))
                .unwrap_or_else(|| panic!("no signature for {}", col_name));
            assert_eq!(signature.is_unique, unique, "uniqueness of {}", col_name);
            assert_eq!(signature.priority, priority, "priority of {}", col_name);
        }
    }

    #[test]
    fn test_build_select_columns() {
        let table = make_table();
        let columns = build_select_columns(&table);
        for quoted in ["`id`", "`email`", "`status`"] {
            assert!(columns.contains(quoted), "missing {}", quoted);
        }
    }

    #[test]
    fn test_build_where_clause() {
        let where_for = |columns: &[&str]| build_where_clause(&owned_strings(columns));

        assert_eq!(where_for(&["id"]), "`id` = ?");
        assert_eq!(
            where_for(&["user_id", "role_id"]),
            "`user_id` = ? AND `role_id` = ?"
        );
    }

    #[test]
    fn test_generate_upsert_method() {
        let table = make_table();
        let code = generate_upsert_method(&table, "Users");

        assert!(code.contains("pub async fn upsert"));
        assert!(code.contains("ON DUPLICATE KEY UPDATE"));
        // The primary-key column must not appear in the update assignments.
        assert!(!code.contains("`id` = VALUES(`id`)"));
        for col_name in ["email", "status"] {
            let assignment = format!("`{0}` = VALUES(`{0}`)", col_name);
            assert!(code.contains(&assignment), "missing update for {}", col_name);
        }
    }

    #[test]
    fn test_generate_upsert_method_no_pk() {
        // Without a primary key or any index there is no conflict target, so no upsert.
        let table = TableMetadata {
            primary_key: None,
            indexes: Vec::new(),
            ..make_table()
        };
        assert!(generate_upsert_method(&table, "Users").is_empty());
    }

    #[test]
    fn test_generate_insert_all_method() {
        let table = make_table();
        let code = generate_insert_all_method(&table, "Users");
        for needle in ["pub async fn insert_all", "rdbi::BatchInsert::new"] {
            assert!(code.contains(needle), "missing `{}`", needle);
        }
    }

    #[test]
    fn test_generate_pagination_methods() {
        let table = make_table();
        let select_columns = build_select_columns(&table);
        let code = generate_pagination_methods(&table, "Users", &select_columns, "models");

        let required = [
            "pub async fn find_all_paginated",
            "limit: i32",
            "offset: i32",
            "UsersSortBy",
            "SortDirection",
            "pub async fn get_paginated_result",
            "PaginatedResult<Users>",
        ];
        for needle in required {
            assert!(code.contains(needle), "missing `{}`", needle);
        }
    }
}