Skip to main content

rdbi_codegen/codegen/
dao_generator.rs

1//! DAO generator - generates async rdbi query functions from table metadata
2
3use crate::error::Result;
4use std::collections::{HashMap, HashSet};
5use std::fs;
6use std::path::Path;
7use tracing::debug;
8
9use crate::config::CodegenConfig;
10use crate::parser::{ColumnMetadata, TableMetadata};
11
12use super::naming::{
13    escape_field_name, generate_delete_by_method_name, generate_find_by_list_method_name,
14    generate_find_by_method_name, generate_update_by_method_name, pluralize, to_struct_name,
15};
16use super::type_resolver::TypeResolver;
17
// Priority levels for method signature deduplication.
//
// When the primary key, an index and/or a foreign key cover the same column list,
// `collect_method_signatures` keeps the signature with the numerically LOWEST value.

/// Primary key: always wins; its finder is emitted by `generate_pk_methods`.
const PRIORITY_PRIMARY_KEY: u8 = 1;
/// Unique index: beats non-unique indexes and foreign keys on the same columns.
const PRIORITY_UNIQUE_INDEX: u8 = 2;
/// Non-unique index: beats a foreign key on the same column.
const PRIORITY_NON_UNIQUE_INDEX: u8 = 3;
/// Foreign key: only used when no key or index already covers the column.
const PRIORITY_FOREIGN_KEY: u8 = 4;
23
24/// Represents a method signature for deduplication
25#[derive(Debug, Clone)]
26struct MethodSignature {
27    columns: Vec<String>,
28    method_name: String,
29    priority: u8,
30    is_unique: bool,
31    source: String,
32}
33
34impl MethodSignature {
35    fn new(columns: Vec<String>, priority: u8, is_unique: bool, source: &str) -> Self {
36        let method_name = generate_find_by_method_name(&columns);
37        Self {
38            columns,
39            method_name,
40            priority,
41            is_unique,
42            source: source.to_string(),
43        }
44    }
45}
46
47/// Generate DAO files for all tables
48pub fn generate_daos(tables: &[TableMetadata], config: &CodegenConfig) -> Result<()> {
49    let output_dir = &config.output_dao_dir;
50    fs::create_dir_all(output_dir)?;
51
52    // Generate mod.rs
53    let mut mod_content = String::new();
54    mod_content.push_str("// Generated DAO functions\n\n");
55
56    for table in tables {
57        let file_name = heck::AsSnakeCase(&table.name).to_string();
58        mod_content.push_str(&format!("pub mod {};\n", file_name));
59    }
60
61    let mod_path = output_dir.join("mod.rs");
62    fs::write(&mod_path, mod_content)?;
63
64    // Generate each DAO file
65    for table in tables {
66        generate_dao_file(table, output_dir, &config.models_module)?;
67    }
68
69    Ok(())
70}
71
72/// Generate a single DAO file for a table
73fn generate_dao_file(table: &TableMetadata, output_dir: &Path, models_module: &str) -> Result<()> {
74    let struct_name = to_struct_name(&table.name);
75    let file_name = format!("{}.rs", heck::AsSnakeCase(&table.name));
76    debug!("Generating DAO for {} -> {}", struct_name, file_name);
77
78    let mut code = String::new();
79
80    // Build column map for quick lookup
81    let column_map: HashMap<&str, &ColumnMetadata> =
82        table.columns.iter().map(|c| (c.name.as_str(), c)).collect();
83
84    // Collect import requirements
85    let mut needs_chrono = false;
86    let mut needs_decimal = false;
87
88    for col in &table.columns {
89        let rust_type = TypeResolver::resolve(col, &table.name);
90        if rust_type.needs_chrono() {
91            needs_chrono = true;
92        }
93        if rust_type.needs_decimal() {
94            needs_decimal = true;
95        }
96    }
97
98    // Check if we have enum columns for imports
99    let has_enums = table.columns.iter().any(|c| c.is_enum());
100
101    // Generate imports
102    code.push_str("use rdbi::{Pool, Query, Result};\n");
103    code.push_str(&format!("use crate::{}::{};\n", models_module, struct_name));
104
105    if has_enums {
106        // Import enum types from the same struct module
107        for col in &table.columns {
108            if col.is_enum() {
109                let enum_name = super::naming::to_enum_name(&table.name, &col.name);
110                code.push_str(&format!("use crate::{}::{};\n", models_module, enum_name));
111            }
112        }
113    }
114
115    if needs_chrono {
116        code.push_str("#[allow(unused_imports)]\n");
117        code.push_str("use chrono::{NaiveDate, NaiveDateTime, NaiveTime};\n");
118    }
119    if needs_decimal {
120        code.push_str("#[allow(unused_imports)]\n");
121        code.push_str("use rust_decimal::Decimal;\n");
122    }
123
124    code.push('\n');
125
126    // Generate select columns list
127    let select_columns = build_select_columns(table);
128
129    // Generate find_all
130    code.push_str(&generate_find_all(table, &struct_name, &select_columns));
131
132    // Generate count_all
133    code.push_str(&generate_count_all(table));
134
135    // Generate primary key methods
136    if let Some(pk) = &table.primary_key {
137        code.push_str(&generate_pk_methods(
138            table,
139            pk,
140            &column_map,
141            &struct_name,
142            &select_columns,
143        ));
144    }
145
146    // Generate insert methods
147    code.push_str(&generate_insert_methods(table, &struct_name, &column_map));
148
149    // Generate insert_plain method (individual params)
150    code.push_str(&generate_insert_plain_method(table, &column_map));
151
152    // Generate batch insert method
153    code.push_str(&generate_insert_all_method(table, &struct_name));
154
155    // Generate upsert method
156    code.push_str(&generate_upsert_method(table, &struct_name));
157
158    // Generate update methods
159    if table.primary_key.is_some() {
160        code.push_str(&generate_update_methods(table, &struct_name, &column_map));
161        // Generate update_plain method (individual params)
162        code.push_str(&generate_update_plain_method(table, &column_map));
163    }
164
165    // Generate index-aware findBy methods
166    let signatures = collect_method_signatures(table);
167    for sig in signatures.values() {
168        // Skip if this is the primary key (already generated separately)
169        if sig.source == "PRIMARY_KEY" {
170            continue;
171        }
172        code.push_str(&generate_find_by_method(
173            table,
174            sig,
175            &column_map,
176            &struct_name,
177            &select_columns,
178        ));
179    }
180
181    // Generate list-based findBy methods for single-column indexes
182    code.push_str(&generate_find_by_list_methods(
183        table,
184        &column_map,
185        &struct_name,
186        &select_columns,
187    ));
188
189    // Generate composite enum list methods (e.g., find_by_user_id_and_device_types)
190    code.push_str(&generate_composite_enum_list_methods(
191        table,
192        &column_map,
193        &struct_name,
194        &select_columns,
195    ));
196
197    // Generate pagination methods
198    code.push_str(&generate_pagination_methods(
199        table,
200        &struct_name,
201        &select_columns,
202        models_module,
203    ));
204
205    let file_path = output_dir.join(&file_name);
206    fs::write(&file_path, code)?;
207    Ok(())
208}
209
210/// Build the SELECT columns list
211fn build_select_columns(table: &TableMetadata) -> String {
212    table
213        .columns
214        .iter()
215        .map(|c| format!("`{}`", c.name))
216        .collect::<Vec<_>>()
217        .join(", ")
218}
219
/// Build the WHERE clause for columns
///
/// Produces `` `a` = ? AND `b` = ? `` with one positional placeholder per column, in
/// order. An empty slice yields an empty string.
fn build_where_clause(columns: &[String]) -> String {
    let mut clause = String::new();
    for (position, column) in columns.iter().enumerate() {
        if position > 0 {
            clause.push_str(" AND ");
        }
        clause.push('`');
        clause.push_str(column);
        clause.push_str("` = ?");
    }
    clause
}
228
229/// Build parameter list for a function signature
230fn build_params(
231    columns: &[String],
232    column_map: &HashMap<&str, &ColumnMetadata>,
233    table_name: &str,
234) -> String {
235    columns
236        .iter()
237        .map(|c| {
238            let col = column_map.get(c.as_str()).unwrap();
239            let rust_type = TypeResolver::resolve(col, table_name);
240            let param_type = rust_type.to_param_type_string();
241            format!("{}: {}", escape_field_name(c), param_type)
242        })
243        .collect::<Vec<_>>()
244        .join(", ")
245}
246
247/// Generate bind calls for query
248fn generate_bind_section(columns: &[String]) -> String {
249    columns
250        .iter()
251        .map(|c| format!("        .bind({})", escape_field_name(c)))
252        .collect::<Vec<_>>()
253        .join("\n")
254}
255
256/// Generate find_all function
257fn generate_find_all(table: &TableMetadata, struct_name: &str, select_columns: &str) -> String {
258    format!(
259        r#"/// Find all records
260pub async fn find_all<P: Pool>(pool: &P) -> Result<Vec<{struct_name}>> {{
261    Query::new("SELECT {select_columns} FROM `{table_name}`")
262        .fetch_all(pool)
263        .await
264}}
265
266"#,
267        struct_name = struct_name,
268        select_columns = select_columns,
269        table_name = table.name,
270    )
271}
272
273/// Generate count_all function
274fn generate_count_all(table: &TableMetadata) -> String {
275    format!(
276        r#"/// Count all records
277pub async fn count_all<P: Pool>(pool: &P) -> Result<i64> {{
278    Query::new("SELECT COUNT(*) FROM `{table_name}`")
279        .fetch_scalar(pool)
280        .await
281}}
282
283"#,
284        table_name = table.name,
285    )
286}
287
288/// Generate primary key methods (find, delete)
289fn generate_pk_methods(
290    table: &TableMetadata,
291    pk: &crate::parser::PrimaryKey,
292    column_map: &HashMap<&str, &ColumnMetadata>,
293    struct_name: &str,
294    select_columns: &str,
295) -> String {
296    let mut code = String::new();
297
298    let method_name = generate_find_by_method_name(&pk.columns);
299    let params = build_params(&pk.columns, column_map, &table.name);
300    let where_clause = build_where_clause(&pk.columns);
301    let bind_section = generate_bind_section(&pk.columns);
302
303    // find_by_pk
304    code.push_str(&format!(
305        r#"/// Find by primary key
306pub async fn {method_name}<P: Pool>(pool: &P, {params}) -> Result<Option<{struct_name}>> {{
307    Query::new("SELECT {select_columns} FROM `{table_name}` WHERE {where_clause}")
308{bind_section}
309        .fetch_optional(pool)
310        .await
311}}
312
313"#,
314        method_name = method_name,
315        params = params,
316        struct_name = struct_name,
317        select_columns = select_columns,
318        table_name = table.name,
319        where_clause = where_clause,
320        bind_section = bind_section,
321    ));
322
323    // delete_by_pk
324    let delete_method = generate_delete_by_method_name(&pk.columns);
325    code.push_str(&format!(
326        r#"/// Delete by primary key
327pub async fn {delete_method}<P: Pool>(pool: &P, {params}) -> Result<u64> {{
328    Query::new("DELETE FROM `{table_name}` WHERE {where_clause}")
329{bind_section}
330        .execute(pool)
331        .await
332        .map(|r| r.rows_affected)
333}}
334
335"#,
336        delete_method = delete_method,
337        params = params,
338        table_name = table.name,
339        where_clause = where_clause,
340        bind_section = bind_section,
341    ));
342
343    code
344}
345
346/// Generate insert methods
347fn generate_insert_methods(
348    table: &TableMetadata,
349    struct_name: &str,
350    _column_map: &HashMap<&str, &ColumnMetadata>,
351) -> String {
352    let mut code = String::new();
353
354    // Get non-auto-increment columns for insert
355    let insert_columns: Vec<&ColumnMetadata> = table
356        .columns
357        .iter()
358        .filter(|c| !c.is_auto_increment)
359        .collect();
360
361    if insert_columns.is_empty() {
362        return code;
363    }
364
365    let column_list = insert_columns
366        .iter()
367        .map(|c| format!("`{}`", c.name))
368        .collect::<Vec<_>>()
369        .join(", ");
370
371    let placeholders = insert_columns
372        .iter()
373        .map(|_| "?")
374        .collect::<Vec<_>>()
375        .join(", ");
376
377    let bind_fields = insert_columns
378        .iter()
379        .map(|c| {
380            let field = escape_field_name(&c.name);
381            let rust_type = TypeResolver::resolve(c, &table.name);
382            if rust_type.is_copy() {
383                format!("        .bind(entity.{})", field)
384            } else {
385                format!("        .bind(&entity.{})", field)
386            }
387        })
388        .collect::<Vec<_>>()
389        .join("\n");
390
391    // insert (with entity)
392    code.push_str(&format!(
393        r#"/// Insert a new record
394pub async fn insert<P: Pool>(pool: &P, entity: &{struct_name}) -> Result<u64> {{
395    Query::new("INSERT INTO `{table_name}` ({column_list}) VALUES ({placeholders})")
396{bind_fields}
397        .execute(pool)
398        .await
399        .map(|r| r.last_insert_id.unwrap_or(0))
400}}
401
402"#,
403        struct_name = struct_name,
404        table_name = table.name,
405        column_list = column_list,
406        placeholders = placeholders,
407        bind_fields = bind_fields,
408    ));
409
410    code
411}
412
413/// Generate insert_plain method with individual parameters
414fn generate_insert_plain_method(
415    table: &TableMetadata,
416    column_map: &HashMap<&str, &ColumnMetadata>,
417) -> String {
418    // Get non-auto-increment columns for insert
419    let insert_columns: Vec<&ColumnMetadata> = table
420        .columns
421        .iter()
422        .filter(|c| !c.is_auto_increment)
423        .collect();
424
425    if insert_columns.is_empty() {
426        return String::new();
427    }
428
429    let column_names: Vec<String> = insert_columns.iter().map(|c| c.name.clone()).collect();
430    let params = build_params(&column_names, column_map, &table.name);
431
432    let column_list = insert_columns
433        .iter()
434        .map(|c| format!("`{}`", c.name))
435        .collect::<Vec<_>>()
436        .join(", ");
437
438    let placeholders = insert_columns
439        .iter()
440        .map(|_| "?")
441        .collect::<Vec<_>>()
442        .join(", ");
443
444    let bind_section = generate_bind_section(&column_names);
445
446    format!(
447        r#"/// Insert a new record with individual parameters
448#[allow(clippy::too_many_arguments)]
449pub async fn insert_plain<P: Pool>(pool: &P, {params}) -> Result<u64> {{
450    Query::new("INSERT INTO `{table_name}` ({column_list}) VALUES ({placeholders})")
451{bind_section}
452        .execute(pool)
453        .await
454        .map(|r| r.last_insert_id.unwrap_or(0))
455}}
456
457"#,
458        params = params,
459        table_name = table.name,
460        column_list = column_list,
461        placeholders = placeholders,
462        bind_section = bind_section,
463    )
464}
465
466/// Generate batch insert method (insert_all) using BatchInsert
467fn generate_insert_all_method(table: &TableMetadata, struct_name: &str) -> String {
468    // Get non-auto-increment columns for insert
469    let insert_columns: Vec<&ColumnMetadata> = table
470        .columns
471        .iter()
472        .filter(|c| !c.is_auto_increment)
473        .collect();
474
475    if insert_columns.is_empty() {
476        return String::new();
477    }
478
479    format!(
480        r#"/// Insert multiple records in a single batch
481pub async fn insert_all<P: Pool>(pool: &P, entities: &[{struct_name}]) -> Result<u64> {{
482    rdbi::BatchInsert::new("{table_name}", entities)
483        .execute(pool)
484        .await
485        .map(|r| r.rows_affected)
486}}
487
488"#,
489        struct_name = struct_name,
490        table_name = table.name,
491    )
492}
493
494/// Check if table has a unique index (excluding primary key)
495fn has_unique_index(table: &TableMetadata) -> bool {
496    table.indexes.iter().any(|idx| idx.unique)
497}
498
499/// Generate upsert method (INSERT ... ON DUPLICATE KEY UPDATE)
500fn generate_upsert_method(table: &TableMetadata, struct_name: &str) -> String {
501    // Only generate if table has primary key or unique index
502    if table.primary_key.is_none() && !has_unique_index(table) {
503        return String::new();
504    }
505
506    // Get non-auto-increment columns for insert
507    let insert_columns: Vec<&ColumnMetadata> = table
508        .columns
509        .iter()
510        .filter(|c| !c.is_auto_increment)
511        .collect();
512
513    if insert_columns.is_empty() {
514        return String::new();
515    }
516
517    // Get primary key columns for exclusion from UPDATE clause
518    let pk_columns: HashSet<&str> = table
519        .primary_key
520        .as_ref()
521        .map(|pk| pk.columns.iter().map(|s| s.as_str()).collect())
522        .unwrap_or_default();
523
524    // Columns to update on duplicate key (all non-PK, non-auto-increment columns)
525    let update_columns: Vec<&ColumnMetadata> = insert_columns
526        .iter()
527        .filter(|c| !pk_columns.contains(c.name.as_str()))
528        .copied()
529        .collect();
530
531    // If no columns to update, skip upsert generation
532    if update_columns.is_empty() {
533        return String::new();
534    }
535
536    let column_list = insert_columns
537        .iter()
538        .map(|c| format!("`{}`", c.name))
539        .collect::<Vec<_>>()
540        .join(", ");
541
542    let placeholders = insert_columns
543        .iter()
544        .map(|_| "?")
545        .collect::<Vec<_>>()
546        .join(", ");
547
548    let update_clause = update_columns
549        .iter()
550        .map(|c| format!("`{name}` = VALUES(`{name}`)", name = c.name))
551        .collect::<Vec<_>>()
552        .join(", ");
553
554    let bind_fields = insert_columns
555        .iter()
556        .map(|c| {
557            let field = escape_field_name(&c.name);
558            let rust_type = TypeResolver::resolve(c, &table.name);
559            if rust_type.is_copy() {
560                format!("        .bind(entity.{})", field)
561            } else {
562                format!("        .bind(&entity.{})", field)
563            }
564        })
565        .collect::<Vec<_>>()
566        .join("\n");
567
568    format!(
569        r#"/// Upsert a record (insert or update on duplicate key)
570/// Returns rows_affected: 1 if inserted, 2 if updated
571pub async fn upsert<P: Pool>(pool: &P, entity: &{struct_name}) -> Result<u64> {{
572    Query::new("INSERT INTO `{table_name}` ({column_list}) VALUES ({placeholders}) \
573         ON DUPLICATE KEY UPDATE {update_clause}")
574{bind_fields}
575        .execute(pool)
576        .await
577        .map(|r| r.rows_affected)
578}}
579
580"#,
581        struct_name = struct_name,
582        table_name = table.name,
583        column_list = column_list,
584        placeholders = placeholders,
585        update_clause = update_clause,
586        bind_fields = bind_fields,
587    )
588}
589
590/// Generate update methods
591fn generate_update_methods(
592    table: &TableMetadata,
593    struct_name: &str,
594    column_map: &HashMap<&str, &ColumnMetadata>,
595) -> String {
596    let mut code = String::new();
597
598    let pk = table.primary_key.as_ref().unwrap();
599
600    // Get non-PK columns for SET clause
601    let update_columns: Vec<&ColumnMetadata> = table
602        .columns
603        .iter()
604        .filter(|c| !pk.columns.contains(&c.name))
605        .collect();
606
607    if update_columns.is_empty() {
608        return code;
609    }
610
611    let set_clause = update_columns
612        .iter()
613        .map(|c| format!("`{}` = ?", c.name))
614        .collect::<Vec<_>>()
615        .join(", ");
616
617    let where_clause = build_where_clause(&pk.columns);
618
619    // Bind update columns first, then PK columns
620    let bind_fields: Vec<String> = update_columns
621        .iter()
622        .map(|c| {
623            let field = escape_field_name(&c.name);
624            let rust_type = TypeResolver::resolve(c, &table.name);
625            if rust_type.is_copy() {
626                format!("        .bind(entity.{})", field)
627            } else {
628                format!("        .bind(&entity.{})", field)
629            }
630        })
631        .chain(pk.columns.iter().map(|c| {
632            let field = escape_field_name(c);
633            let col = column_map.get(c.as_str()).unwrap();
634            let rust_type = TypeResolver::resolve(col, &table.name);
635            if rust_type.is_copy() {
636                format!("        .bind(entity.{})", field)
637            } else {
638                format!("        .bind(&entity.{})", field)
639            }
640        }))
641        .collect();
642
643    // update_by_bean (using entity)
644    code.push_str(&format!(
645        r#"/// Update a record by primary key
646pub async fn update<P: Pool>(pool: &P, entity: &{struct_name}) -> Result<u64> {{
647    Query::new("UPDATE `{table_name}` SET {set_clause} WHERE {where_clause}")
648{bind_fields}
649        .execute(pool)
650        .await
651        .map(|r| r.rows_affected)
652}}
653
654"#,
655        struct_name = struct_name,
656        table_name = table.name,
657        set_clause = set_clause,
658        where_clause = where_clause,
659        bind_fields = bind_fields.join("\n"),
660    ));
661
662    code
663}
664
665/// Generate update_plain method with individual parameters (update_by_<pk>)
666fn generate_update_plain_method(
667    table: &TableMetadata,
668    column_map: &HashMap<&str, &ColumnMetadata>,
669) -> String {
670    let pk = table.primary_key.as_ref().unwrap();
671
672    // Get non-PK columns for SET clause
673    let update_columns: Vec<&ColumnMetadata> = table
674        .columns
675        .iter()
676        .filter(|c| !pk.columns.contains(&c.name))
677        .collect();
678
679    if update_columns.is_empty() {
680        return String::new();
681    }
682
683    // Build method name based on PK columns
684    let method_name = generate_update_by_method_name(&pk.columns);
685
686    // Build params: PK columns first, then update columns
687    let pk_params = build_params(&pk.columns, column_map, &table.name);
688    let update_column_names: Vec<String> = update_columns.iter().map(|c| c.name.clone()).collect();
689    let update_params = build_params(&update_column_names, column_map, &table.name);
690    let all_params = format!("{}, {}", pk_params, update_params);
691
692    let set_clause = update_columns
693        .iter()
694        .map(|c| format!("`{}` = ?", c.name))
695        .collect::<Vec<_>>()
696        .join(", ");
697
698    let where_clause = build_where_clause(&pk.columns);
699
700    // Bind update columns first (for SET), then PK columns (for WHERE)
701    let bind_section_update = generate_bind_section(&update_column_names);
702    let bind_section_pk = generate_bind_section(&pk.columns);
703
704    format!(
705        r#"/// Update a record by primary key with individual parameters
706#[allow(clippy::too_many_arguments)]
707pub async fn {method_name}<P: Pool>(pool: &P, {all_params}) -> Result<u64> {{
708    Query::new("UPDATE `{table_name}` SET {set_clause} WHERE {where_clause}")
709{bind_section_update}
710{bind_section_pk}
711        .execute(pool)
712        .await
713        .map(|r| r.rows_affected)
714}}
715
716"#,
717        method_name = method_name,
718        all_params = all_params,
719        table_name = table.name,
720        set_clause = set_clause,
721        where_clause = where_clause,
722        bind_section_update = bind_section_update,
723        bind_section_pk = bind_section_pk,
724    )
725}
726
727/// Collect method signatures with priority-based deduplication
728fn collect_method_signatures(table: &TableMetadata) -> HashMap<Vec<String>, MethodSignature> {
729    let mut signatures: HashMap<Vec<String>, MethodSignature> = HashMap::new();
730
731    // Priority 1: Primary key
732    if let Some(pk) = &table.primary_key {
733        let sig = MethodSignature::new(
734            pk.columns.clone(),
735            PRIORITY_PRIMARY_KEY,
736            true,
737            "PRIMARY_KEY",
738        );
739        signatures.insert(pk.columns.clone(), sig);
740    }
741
742    // Priority 2 & 3: Indexes
743    for index in &table.indexes {
744        let priority = if index.unique {
745            PRIORITY_UNIQUE_INDEX
746        } else {
747            PRIORITY_NON_UNIQUE_INDEX
748        };
749        let source = if index.unique {
750            "UNIQUE_INDEX"
751        } else {
752            "NON_UNIQUE_INDEX"
753        };
754        let sig = MethodSignature::new(index.columns.clone(), priority, index.unique, source);
755
756        // Only add if no higher priority signature exists
757        if let Some(existing) = signatures.get(&index.columns) {
758            if sig.priority < existing.priority {
759                signatures.insert(index.columns.clone(), sig);
760            }
761        } else {
762            signatures.insert(index.columns.clone(), sig);
763        }
764    }
765
766    // Priority 4: Foreign keys
767    for fk in &table.foreign_keys {
768        let columns = vec![fk.column_name.clone()];
769        let sig = MethodSignature::new(columns.clone(), PRIORITY_FOREIGN_KEY, false, "FOREIGN_KEY");
770
771        signatures.entry(columns).or_insert(sig);
772    }
773
774    signatures
775}
776
777/// Generate a find_by method for an index/FK
778fn generate_find_by_method(
779    table: &TableMetadata,
780    sig: &MethodSignature,
781    column_map: &HashMap<&str, &ColumnMetadata>,
782    struct_name: &str,
783    select_columns: &str,
784) -> String {
785    let params = build_params(&sig.columns, column_map, &table.name);
786
787    let (return_type, fetch_method) = if sig.is_unique {
788        (format!("Option<{}>", struct_name), "fetch_optional")
789    } else {
790        (format!("Vec<{}>", struct_name), "fetch_all")
791    };
792
793    let return_desc = if sig.is_unique {
794        "Option (unique)"
795    } else {
796        "Vec (non-unique)"
797    };
798
799    // Check if any column is nullable - if so, we need dynamic SQL for IS NULL handling
800    let has_nullable = sig.columns.iter().any(|c| {
801        column_map
802            .get(c.as_str())
803            .map(|col| col.nullable)
804            .unwrap_or(false)
805    });
806
807    if has_nullable {
808        // Generate dynamic query that handles NULL values properly
809        generate_find_by_method_nullable(
810            table,
811            sig,
812            column_map,
813            select_columns,
814            &params,
815            &return_type,
816            fetch_method,
817            return_desc,
818        )
819    } else {
820        // Use static query for non-nullable columns
821        let where_clause = build_where_clause(&sig.columns);
822        let bind_section = generate_bind_section(&sig.columns);
823
824        format!(
825            r#"/// Find by {source}: returns {return_desc}
826pub async fn {method_name}<P: Pool>(pool: &P, {params}) -> Result<{return_type}> {{
827    Query::new("SELECT {select_columns} FROM `{table_name}` WHERE {where_clause}")
828{bind_section}
829        .{fetch_method}(pool)
830        .await
831}}
832
833"#,
834            source = sig.source.to_lowercase().replace('_', " "),
835            return_desc = return_desc,
836            method_name = sig.method_name,
837            params = params,
838            return_type = return_type,
839            select_columns = select_columns,
840            table_name = table.name,
841            where_clause = where_clause,
842            bind_section = bind_section,
843            fetch_method = fetch_method,
844        )
845    }
846}
847
848/// Generate a find_by method that handles nullable columns with IS NULL
849#[allow(clippy::too_many_arguments)]
850fn generate_find_by_method_nullable(
851    table: &TableMetadata,
852    sig: &MethodSignature,
853    column_map: &HashMap<&str, &ColumnMetadata>,
854    select_columns: &str,
855    params: &str,
856    return_type: &str,
857    fetch_method: &str,
858    return_desc: &str,
859) -> String {
860    // Build the where clause conditions and bind logic
861    let mut where_parts = Vec::new();
862    let mut bind_parts = Vec::new();
863
864    for col in &sig.columns {
865        let col_meta = column_map.get(col.as_str()).unwrap();
866        let field_name = escape_field_name(col);
867
868        if col_meta.nullable {
869            // For nullable columns, check if value is None and use IS NULL
870            where_parts.push(format!(
871                r#"if {field}.is_some() {{ "`{col}` = ?" }} else {{ "`{col}` IS NULL" }}"#,
872                field = field_name,
873                col = col,
874            ));
875            bind_parts.push(format!(
876                r#"if let Some(v) = {field}.as_ref() {{ query = query.bind(v); }}"#,
877                field = field_name,
878            ));
879        } else {
880            where_parts.push(format!(r#""`{}` = ?""#, col));
881            bind_parts.push(format!(r#"query = query.bind({});"#, field_name,));
882        }
883    }
884
885    let where_expr = if where_parts.len() == 1 {
886        where_parts[0].clone()
887    } else {
888        // Join with " AND "
889        let parts = where_parts
890            .iter()
891            .map(|p| format!("({})", p))
892            .collect::<Vec<_>>()
893            .join(", ");
894        format!("vec![{}].join(\" AND \")", parts)
895    };
896
897    let bind_code = bind_parts.join("\n        ");
898
899    format!(
900        r#"/// Find by {source}: returns {return_desc}
901pub async fn {method_name}<P: Pool>(pool: &P, {params}) -> Result<{return_type}> {{
902    let where_clause = {where_expr};
903    let sql = format!("SELECT {select_columns} FROM `{table_name}` WHERE {{}}", where_clause);
904    let mut query = rdbi::DynamicQuery::new(sql);
905    {bind_code}
906    query.{fetch_method}(pool).await
907}}
908
909"#,
910        source = sig.source.to_lowercase().replace('_', " "),
911        return_desc = return_desc,
912        method_name = sig.method_name,
913        params = params,
914        return_type = return_type,
915        select_columns = select_columns,
916        table_name = table.name,
917        where_expr = where_expr,
918        bind_code = bind_code,
919        fetch_method = fetch_method,
920    )
921}
922
923/// Generate list-based findBy methods for single-column indexes
924fn generate_find_by_list_methods(
925    table: &TableMetadata,
926    column_map: &HashMap<&str, &ColumnMetadata>,
927    struct_name: &str,
928    select_columns: &str,
929) -> String {
930    let mut code = String::new();
931    let mut processed: HashSet<String> = HashSet::new();
932
933    // Primary key (if single column)
934    if let Some(pk) = &table.primary_key {
935        if pk.columns.len() == 1 {
936            let col = &pk.columns[0];
937            code.push_str(&generate_single_find_by_list(
938                table,
939                col,
940                column_map,
941                struct_name,
942                select_columns,
943            ));
944            processed.insert(col.clone());
945        }
946    }
947
948    // Single-column indexes
949    for index in &table.indexes {
950        if index.columns.len() == 1 {
951            let col = &index.columns[0];
952            if !processed.contains(col) {
953                code.push_str(&generate_single_find_by_list(
954                    table,
955                    col,
956                    column_map,
957                    struct_name,
958                    select_columns,
959                ));
960                processed.insert(col.clone());
961            }
962        }
963    }
964
965    code
966}
967
968/// Generate a single find_by_<column>s method (using IN clause)
969fn generate_single_find_by_list(
970    table: &TableMetadata,
971    column_name: &str,
972    column_map: &HashMap<&str, &ColumnMetadata>,
973    struct_name: &str,
974    select_columns: &str,
975) -> String {
976    let method_name = generate_find_by_list_method_name(column_name);
977    let param_name = pluralize(&escape_field_name(column_name));
978    let column = column_map.get(column_name).unwrap();
979    let rust_type = TypeResolver::resolve(column, &table.name);
980
981    // Get the inner type (unwrap Option if nullable)
982    let inner_type = rust_type.inner_type().to_type_string();
983
984    let column_name_plural = pluralize(column_name);
985    format!(
986        r#"/// Find by list of {column_name_plural} (IN clause)
987pub async fn {method_name}<P: Pool>(pool: &P, {param_name}: &[{inner_type}]) -> Result<Vec<{struct_name}>> {{
988    if {param_name}.is_empty() {{
989        return Ok(Vec::new());
990    }}
991    let placeholders = {param_name}.iter().map(|_| "?").collect::<Vec<_>>().join(",");
992    let query = format!(
993        "SELECT {select_columns} FROM `{table_name}` WHERE `{column_name}` IN ({{}})",
994        placeholders
995    );
996    rdbi::DynamicQuery::new(query)
997        .bind_all({param_name})
998        .fetch_all(pool)
999        .await
1000}}
1001
1002"#,
1003        column_name_plural = column_name_plural,
1004        column_name = column_name,
1005        method_name = method_name,
1006        param_name = param_name,
1007        inner_type = inner_type,
1008        struct_name = struct_name,
1009        select_columns = select_columns,
1010        table_name = table.name,
1011    )
1012}
1013
1014/// Generate composite enum list methods for multi-column indexes with enum columns
1015/// Example: find_by_user_id_and_device_types(user_id, &[DeviceType])
1016fn generate_composite_enum_list_methods(
1017    table: &TableMetadata,
1018    column_map: &HashMap<&str, &ColumnMetadata>,
1019    struct_name: &str,
1020    select_columns: &str,
1021) -> String {
1022    let mut code = String::new();
1023
1024    for index in &table.indexes {
1025        // Skip single-column indexes (handled by generate_find_by_list_methods)
1026        if index.columns.len() <= 1 {
1027            continue;
1028        }
1029
1030        // Identify which columns are enums
1031        let enum_columns: HashSet<&str> = index
1032            .columns
1033            .iter()
1034            .filter(|col_name| {
1035                column_map
1036                    .get(col_name.as_str())
1037                    .map(|col| col.is_enum())
1038                    .unwrap_or(false)
1039            })
1040            .map(|s| s.as_str())
1041            .collect();
1042
1043        // Skip if no enum columns
1044        if enum_columns.is_empty() {
1045            continue;
1046        }
1047
1048        // Skip if first column is enum (for optimal index usage, equality should be on leading column)
1049        let first_column = &index.columns[0];
1050        if enum_columns.contains(first_column.as_str()) {
1051            continue;
1052        }
1053
1054        code.push_str(&generate_composite_enum_list_method(
1055            table,
1056            &index.columns,
1057            &enum_columns,
1058            column_map,
1059            struct_name,
1060            select_columns,
1061        ));
1062    }
1063
1064    code
1065}
1066
1067/// Generate a single composite enum list method
1068fn generate_composite_enum_list_method(
1069    table: &TableMetadata,
1070    columns: &[String],
1071    enum_columns: &HashSet<&str>,
1072    column_map: &HashMap<&str, &ColumnMetadata>,
1073    struct_name: &str,
1074    select_columns: &str,
1075) -> String {
1076    // Build method name: pluralize enum column names
1077    let method_name = generate_composite_enum_method_name(columns, enum_columns);
1078
1079    // Build params and WHERE clause parts
1080    let mut params_parts = Vec::new();
1081
1082    for col_name in columns {
1083        let col = column_map.get(col_name.as_str()).unwrap();
1084        let rust_type = TypeResolver::resolve(col, &table.name);
1085        let is_enum = enum_columns.contains(col_name.as_str());
1086
1087        if is_enum {
1088            // Enum column uses list parameter
1089            let param_name = pluralize(&escape_field_name(col_name));
1090            let inner_type = rust_type.inner_type().to_type_string();
1091            params_parts.push(format!("{}: &[{}]", param_name, inner_type));
1092        } else {
1093            // Non-enum column uses single value
1094            let param_name = escape_field_name(col_name);
1095            let param_type = rust_type.to_param_type_string();
1096            params_parts.push(format!("{}: {}", param_name, param_type));
1097        }
1098    }
1099
1100    let params = params_parts.join(", ");
1101
1102    // Build WHERE clause with proper placeholders
1103    let where_clause_static: Vec<String> = columns
1104        .iter()
1105        .map(|col_name| {
1106            if enum_columns.contains(col_name.as_str()) {
1107                format!("`{}` IN ({{}})", col_name) // placeholder for IN clause
1108            } else {
1109                format!("`{}` = ?", col_name)
1110            }
1111        })
1112        .collect();
1113
1114    // Build the bind section
1115    let mut bind_code = String::new();
1116
1117    // First bind non-enum (single value) columns
1118    for col_name in columns {
1119        if !enum_columns.contains(col_name.as_str()) {
1120            let param_name = escape_field_name(col_name);
1121            bind_code.push_str(&format!("        .bind({})\n", param_name));
1122        }
1123    }
1124
1125    // Then bind enum (list) columns
1126    for col_name in columns {
1127        if enum_columns.contains(col_name.as_str()) {
1128            let param_name = pluralize(&escape_field_name(col_name));
1129            bind_code.push_str(&format!("        .bind_all({})\n", param_name));
1130        }
1131    }
1132
1133    // Build the column name description for doc comment
1134    let column_desc: Vec<String> = columns
1135        .iter()
1136        .map(|col| {
1137            if enum_columns.contains(col.as_str()) {
1138                pluralize(col)
1139            } else {
1140                col.clone()
1141            }
1142        })
1143        .collect();
1144
1145    // Build dynamic WHERE clause construction
1146    let enum_col_names: Vec<&str> = columns
1147        .iter()
1148        .filter(|c| enum_columns.contains(c.as_str()))
1149        .map(|s| s.as_str())
1150        .collect();
1151
1152    // Generate the IN clause placeholders dynamically
1153    let in_clause_builders: Vec<String> = enum_col_names
1154        .iter()
1155        .map(|col| {
1156            let param_name = pluralize(&escape_field_name(col));
1157            format!(
1158                "{param_name}.iter().map(|_| \"?\").collect::<Vec<_>>().join(\",\")",
1159                param_name = param_name
1160            )
1161        })
1162        .collect();
1163
1164    // Build format args for the WHERE clause
1165    let format_args = in_clause_builders.join(", ");
1166
1167    format!(
1168        r#"/// Find by {column_desc} (composite index with IN clause for enum columns)
1169pub async fn {method_name}<P: Pool>(pool: &P, {params}) -> Result<Vec<{struct_name}>> {{
1170    // Check for empty enum lists
1171{empty_checks}
1172    // Build IN clause placeholders for enum columns
1173    let where_clause = format!("{where_template}", {format_args});
1174    let query = format!(
1175        "SELECT {select_columns} FROM `{table_name}` WHERE {{}}",
1176        where_clause
1177    );
1178    rdbi::DynamicQuery::new(query)
1179{bind_code}        .fetch_all(pool)
1180        .await
1181}}
1182
1183"#,
1184        column_desc = column_desc.join(" and "),
1185        method_name = method_name,
1186        params = params,
1187        struct_name = struct_name,
1188        select_columns = select_columns,
1189        table_name = table.name,
1190        where_template = where_clause_static.join(" AND "),
1191        format_args = format_args,
1192        bind_code = bind_code,
1193        empty_checks = generate_empty_checks(columns, enum_columns),
1194    )
1195}
1196
1197/// Generate empty checks for enum list parameters
1198fn generate_empty_checks(columns: &[String], enum_columns: &HashSet<&str>) -> String {
1199    let mut checks = String::new();
1200    for col_name in columns {
1201        if enum_columns.contains(col_name.as_str()) {
1202            let param_name = pluralize(&escape_field_name(col_name));
1203            checks.push_str(&format!(
1204                "    if {}.is_empty() {{ return Ok(Vec::new()); }}\n",
1205                param_name
1206            ));
1207        }
1208    }
1209    checks
1210}
1211
1212/// Generate method name for composite enum queries
1213/// Example: ["user_id", "device_type"] with enum_columns={"device_type"} -> "find_by_user_id_and_device_types"
1214fn generate_composite_enum_method_name(columns: &[String], enum_columns: &HashSet<&str>) -> String {
1215    let mut parts = Vec::new();
1216    for col in columns {
1217        if enum_columns.contains(col.as_str()) {
1218            parts.push(pluralize(col));
1219        } else {
1220            parts.push(col.clone());
1221        }
1222    }
1223    generate_find_by_method_name(&parts)
1224}
1225
1226/// Generate pagination methods (find_all_paginated, get_paginated_result)
1227fn generate_pagination_methods(
1228    table: &TableMetadata,
1229    struct_name: &str,
1230    select_columns: &str,
1231    models_module: &str,
1232) -> String {
1233    let sort_by_enum = format!("{}SortBy", struct_name);
1234
1235    format!(
1236        r#"/// Find all records with pagination and sorting
1237pub async fn find_all_paginated<P: Pool>(
1238    pool: &P,
1239    limit: i32,
1240    offset: i32,
1241    sort_by: crate::{models_module}::{sort_by_enum},
1242    sort_dir: crate::{models_module}::SortDirection,
1243) -> Result<Vec<{struct_name}>> {{
1244    let order_clause = format!("{{}} {{}}", sort_by.as_sql(), sort_dir.as_sql());
1245    let query = format!(
1246        "SELECT {select_columns} FROM `{table_name}` ORDER BY {{}} LIMIT ? OFFSET ?",
1247        order_clause
1248    );
1249    rdbi::DynamicQuery::new(query)
1250        .bind(limit)
1251        .bind(offset)
1252        .fetch_all(pool)
1253        .await
1254}}
1255
1256/// Get paginated result with total count
1257pub async fn get_paginated_result<P: Pool>(
1258    pool: &P,
1259    page_size: i32,
1260    current_page: i32,
1261    sort_by: crate::{models_module}::{sort_by_enum},
1262    sort_dir: crate::{models_module}::SortDirection,
1263) -> Result<crate::{models_module}::PaginatedResult<{struct_name}>> {{
1264    let page_size = page_size.max(1);
1265    let current_page = current_page.max(1);
1266    let offset = (current_page - 1) * page_size;
1267
1268    let total_count = count_all(pool).await?;
1269    let items = find_all_paginated(pool, page_size, offset, sort_by, sort_dir).await?;
1270
1271    Ok(crate::{models_module}::PaginatedResult::new(
1272        items,
1273        total_count,
1274        current_page,
1275        page_size,
1276    ))
1277}}
1278
1279"#,
1280        struct_name = struct_name,
1281        select_columns = select_columns,
1282        table_name = table.name,
1283        models_module = models_module,
1284        sort_by_enum = sort_by_enum,
1285    )
1286}
1287
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::{IndexMetadata, PrimaryKey};

    /// Builds a non-nullable, signed column with no default value and no comment.
    fn column(
        name: &str,
        data_type: &str,
        is_auto_increment: bool,
        enum_values: Option<&[&str]>,
    ) -> ColumnMetadata {
        ColumnMetadata {
            name: name.to_string(),
            data_type: data_type.to_string(),
            nullable: false,
            default_value: None,
            is_auto_increment,
            is_unsigned: false,
            enum_values: enum_values
                .map(|values| values.iter().map(|v| v.to_string()).collect::<Vec<String>>()),
            comment: None,
        }
    }

    /// Builds an index covering exactly one column.
    fn single_column_index(name: &str, column: &str, unique: bool) -> IndexMetadata {
        IndexMetadata {
            name: name.to_string(),
            columns: vec![column.to_string()],
            unique,
        }
    }

    /// Asserts that every fragment occurs somewhere in `code`.
    fn assert_contains_all(code: &str, fragments: &[&str]) {
        for fragment in fragments {
            assert!(
                code.contains(*fragment),
                "expected generated code to contain {:?}",
                fragment
            );
        }
    }

    /// Fixture: `users` table with a BIGINT auto-increment primary key `id`,
    /// a uniquely indexed `email`, and an ENUM `status` with a non-unique index.
    fn make_table() -> TableMetadata {
        TableMetadata {
            name: "users".to_string(),
            comment: None,
            columns: vec![
                column("id", "BIGINT", true, None),
                column("email", "VARCHAR(255)", false, None),
                column("status", "ENUM", false, Some(&["ACTIVE", "INACTIVE"][..])),
            ],
            indexes: vec![
                single_column_index("email_unique", "email", true),
                single_column_index("idx_status", "status", false),
            ],
            foreign_keys: vec![],
            primary_key: Some(PrimaryKey {
                columns: vec!["id".to_string()],
            }),
        }
    }

    #[test]
    fn test_collect_method_signatures() {
        let table = make_table();
        let signatures = collect_method_signatures(&table);

        // One signature each: id (PK), email (unique index), status (non-unique index)
        assert_eq!(signatures.len(), 3);

        let expected = [
            ("id", true, PRIORITY_PRIMARY_KEY),
            ("email", true, PRIORITY_UNIQUE_INDEX),
            ("status", false, PRIORITY_NON_UNIQUE_INDEX),
        ];
        for &(col, unique, priority) in expected.iter() {
            let sig = signatures
                .get(&vec![col.to_string()])
                .unwrap_or_else(|| panic!("missing signature for `{}`", col));
            assert_eq!(sig.is_unique, unique, "is_unique for `{}`", col);
            assert_eq!(sig.priority, priority, "priority for `{}`", col);
        }
    }

    #[test]
    fn test_build_select_columns() {
        let table = make_table();
        let select = build_select_columns(&table);
        assert_contains_all(&select, &["`id`", "`email`", "`status`"]);
    }

    #[test]
    fn test_build_where_clause() {
        let cases: [(&[&str], &str); 2] = [
            (&["id"], "`id` = ?"),
            (&["user_id", "role_id"], "`user_id` = ? AND `role_id` = ?"),
        ];
        for (cols, expected) in cases.iter() {
            let owned: Vec<String> = cols.iter().map(|c| c.to_string()).collect();
            assert_eq!(build_where_clause(&owned), *expected);
        }
    }

    #[test]
    fn test_generate_upsert_method() {
        let table = make_table();
        let code = generate_upsert_method(&table, "Users");

        assert_contains_all(
            &code,
            &[
                "pub async fn upsert",
                "ON DUPLICATE KEY UPDATE",
                // Non-PK columns are refreshed from the incoming row
                "`email` = VALUES(`email`)",
                "`status` = VALUES(`status`)",
            ],
        );
        // The PK column itself must never be overwritten
        assert!(!code.contains("`id` = VALUES(`id`)"));
    }

    #[test]
    fn test_generate_upsert_method_no_pk() {
        // Neither a primary key nor any unique index: no conflict target to upsert on
        let table = TableMetadata {
            primary_key: None,
            indexes: Vec::new(),
            ..make_table()
        };
        assert!(generate_upsert_method(&table, "Users").is_empty());
    }

    #[test]
    fn test_generate_insert_all_method() {
        let table = make_table();
        let code = generate_insert_all_method(&table, "Users");
        assert_contains_all(&code, &["pub async fn insert_all", "rdbi::BatchInsert::new"]);
    }

    #[test]
    fn test_generate_pagination_methods() {
        let table = make_table();
        let select_columns = build_select_columns(&table);
        let code = generate_pagination_methods(&table, "Users", &select_columns, "models");

        assert_contains_all(
            &code,
            &[
                "pub async fn find_all_paginated",
                "limit: i32",
                "offset: i32",
                "UsersSortBy",
                "SortDirection",
                "pub async fn get_paginated_result",
                "PaginatedResult<Users>",
            ],
        );
    }
}