Skip to main content

rdbi_codegen/codegen/
dao_generator.rs

1//! DAO generator - generates async rdbi query functions from table metadata
2
3use crate::error::Result;
4use std::collections::{HashMap, HashSet};
5use std::fs;
6use std::path::Path;
7use tracing::debug;
8
9use crate::config::CodegenConfig;
10use crate::parser::{ColumnMetadata, TableMetadata};
11
12use super::naming::{
13    escape_field_name, generate_delete_by_method_name, generate_find_by_list_method_name,
14    generate_find_by_method_name, generate_update_by_method_name, pluralize, to_struct_name,
15};
16use super::type_resolver::TypeResolver;
17
/// Priority levels for method signature deduplication.
///
/// A lower value wins: when several table features produce a finder for the same
/// column list, the signature with the smallest priority is kept.
const PRIORITY_PRIMARY_KEY: u8 = 1;
/// Unique (non-primary) index — ranks directly below the primary key.
const PRIORITY_UNIQUE_INDEX: u8 = PRIORITY_PRIMARY_KEY + 1;
/// Non-unique index.
const PRIORITY_NON_UNIQUE_INDEX: u8 = PRIORITY_UNIQUE_INDEX + 1;
/// Foreign key — only used when nothing else already covers the same columns.
const PRIORITY_FOREIGN_KEY: u8 = PRIORITY_NON_UNIQUE_INDEX + 1;
23
24/// Represents a method signature for deduplication
25#[derive(Debug, Clone)]
26struct MethodSignature {
27    columns: Vec<String>,
28    method_name: String,
29    priority: u8,
30    is_unique: bool,
31    source: String,
32}
33
34impl MethodSignature {
35    fn new(columns: Vec<String>, priority: u8, is_unique: bool, source: &str) -> Self {
36        let method_name = generate_find_by_method_name(&columns);
37        Self {
38            columns,
39            method_name,
40            priority,
41            is_unique,
42            source: source.to_string(),
43        }
44    }
45}
46
47/// Generate DAO files for all tables
48pub fn generate_daos(tables: &[TableMetadata], config: &CodegenConfig) -> Result<()> {
49    let output_dir = &config.output_dao_dir;
50    fs::create_dir_all(output_dir)?;
51
52    // Generate mod.rs
53    let mut mod_content = String::new();
54    mod_content.push_str("// Generated DAO functions\n\n");
55
56    for table in tables {
57        let file_name = heck::AsSnakeCase(&table.name).to_string();
58        mod_content.push_str(&format!("pub mod {};\n", file_name));
59    }
60
61    fs::write(output_dir.join("mod.rs"), mod_content)?;
62
63    // Generate each DAO file
64    for table in tables {
65        generate_dao_file(table, output_dir, &config.models_module)?;
66    }
67
68    Ok(())
69}
70
71/// Generate a single DAO file for a table
72fn generate_dao_file(table: &TableMetadata, output_dir: &Path, models_module: &str) -> Result<()> {
73    let struct_name = to_struct_name(&table.name);
74    let file_name = format!("{}.rs", heck::AsSnakeCase(&table.name));
75    debug!("Generating DAO for {} -> {}", struct_name, file_name);
76
77    let mut code = String::new();
78
79    // Build column map for quick lookup
80    let column_map: HashMap<&str, &ColumnMetadata> =
81        table.columns.iter().map(|c| (c.name.as_str(), c)).collect();
82
83    // Collect import requirements
84    let mut needs_chrono = false;
85    let mut needs_decimal = false;
86
87    for col in &table.columns {
88        let rust_type = TypeResolver::resolve(col, &table.name);
89        if rust_type.needs_chrono() {
90            needs_chrono = true;
91        }
92        if rust_type.needs_decimal() {
93            needs_decimal = true;
94        }
95    }
96
97    // Check if we have enum columns for imports
98    let has_enums = table.columns.iter().any(|c| c.is_enum());
99
100    // Generate imports
101    code.push_str("use rdbi::{Pool, Query, Result};\n");
102    code.push_str(&format!("use crate::{}::{};\n", models_module, struct_name));
103
104    if has_enums {
105        // Import enum types from the same struct module
106        for col in &table.columns {
107            if col.is_enum() {
108                let enum_name = super::naming::to_enum_name(&table.name, &col.name);
109                code.push_str(&format!("use crate::{}::{};\n", models_module, enum_name));
110            }
111        }
112    }
113
114    if needs_chrono {
115        code.push_str("#[allow(unused_imports)]\n");
116        code.push_str("use chrono::{NaiveDate, NaiveDateTime, NaiveTime};\n");
117    }
118    if needs_decimal {
119        code.push_str("#[allow(unused_imports)]\n");
120        code.push_str("use rust_decimal::Decimal;\n");
121    }
122
123    code.push('\n');
124
125    // Generate select columns list
126    let select_columns = build_select_columns(table);
127
128    // Generate find_all
129    code.push_str(&generate_find_all(table, &struct_name, &select_columns));
130
131    // Generate count_all
132    code.push_str(&generate_count_all(table));
133
134    // Generate primary key methods
135    if let Some(pk) = &table.primary_key {
136        code.push_str(&generate_pk_methods(
137            table,
138            pk,
139            &column_map,
140            &struct_name,
141            &select_columns,
142        ));
143    }
144
145    // Generate insert methods
146    code.push_str(&generate_insert_methods(table, &struct_name, &column_map));
147
148    // Generate insert_plain method (individual params)
149    code.push_str(&generate_insert_plain_method(table, &column_map));
150
151    // Generate batch insert method
152    code.push_str(&generate_insert_all_method(table, &struct_name));
153
154    // Generate upsert method
155    code.push_str(&generate_upsert_method(table, &struct_name));
156
157    // Generate update methods
158    if table.primary_key.is_some() {
159        code.push_str(&generate_update_methods(table, &struct_name, &column_map));
160        // Generate update_plain method (individual params)
161        code.push_str(&generate_update_plain_method(table, &column_map));
162    }
163
164    // Generate index-aware findBy methods
165    let signatures = collect_method_signatures(table);
166    for sig in signatures.values() {
167        // Skip if this is the primary key (already generated separately)
168        if sig.source == "PRIMARY_KEY" {
169            continue;
170        }
171        code.push_str(&generate_find_by_method(
172            table,
173            sig,
174            &column_map,
175            &struct_name,
176            &select_columns,
177        ));
178    }
179
180    // Generate list-based findBy methods for single-column indexes
181    code.push_str(&generate_find_by_list_methods(
182        table,
183        &column_map,
184        &struct_name,
185        &select_columns,
186    ));
187
188    // Generate composite enum list methods (e.g., find_by_user_id_and_device_types)
189    code.push_str(&generate_composite_enum_list_methods(
190        table,
191        &column_map,
192        &struct_name,
193        &select_columns,
194    ));
195
196    // Generate pagination methods
197    code.push_str(&generate_pagination_methods(
198        table,
199        &struct_name,
200        &select_columns,
201        models_module,
202    ));
203
204    fs::write(output_dir.join(&file_name), code)?;
205    Ok(())
206}
207
208/// Build the SELECT columns list
209fn build_select_columns(table: &TableMetadata) -> String {
210    table
211        .columns
212        .iter()
213        .map(|c| format!("`{}`", c.name))
214        .collect::<Vec<_>>()
215        .join(", ")
216}
217
/// Build an equality `WHERE` clause, e.g. ``"`a` = ? AND `b` = ?"`` for `["a", "b"]`.
///
/// Returns an empty string when `columns` is empty.
fn build_where_clause(columns: &[String]) -> String {
    columns.iter().fold(String::new(), |mut clause, column| {
        // Every rendered condition is non-empty, so this marks "not the first one"
        if !clause.is_empty() {
            clause.push_str(" AND ");
        }
        clause.push('`');
        clause.push_str(column);
        clause.push_str("` = ?");
        clause
    })
}
226
227/// Build parameter list for a function signature
228fn build_params(
229    columns: &[String],
230    column_map: &HashMap<&str, &ColumnMetadata>,
231    table_name: &str,
232) -> String {
233    columns
234        .iter()
235        .map(|c| {
236            let col = column_map.get(c.as_str()).unwrap();
237            let rust_type = TypeResolver::resolve(col, table_name);
238            let param_type = rust_type.to_param_type_string();
239            format!("{}: {}", escape_field_name(c), param_type)
240        })
241        .collect::<Vec<_>>()
242        .join(", ")
243}
244
245/// Generate bind calls for query
246fn generate_bind_section(columns: &[String]) -> String {
247    columns
248        .iter()
249        .map(|c| format!("        .bind({})", escape_field_name(c)))
250        .collect::<Vec<_>>()
251        .join("\n")
252}
253
254/// Generate find_all function
255fn generate_find_all(table: &TableMetadata, struct_name: &str, select_columns: &str) -> String {
256    format!(
257        r#"/// Find all records
258pub async fn find_all<P: Pool>(pool: &P) -> Result<Vec<{struct_name}>> {{
259    Query::new("SELECT {select_columns} FROM `{table_name}`")
260        .fetch_all(pool)
261        .await
262}}
263
264"#,
265        struct_name = struct_name,
266        select_columns = select_columns,
267        table_name = table.name,
268    )
269}
270
271/// Generate count_all function
272fn generate_count_all(table: &TableMetadata) -> String {
273    format!(
274        r#"/// Count all records
275pub async fn count_all<P: Pool>(pool: &P) -> Result<i64> {{
276    Query::new("SELECT COUNT(*) FROM `{table_name}`")
277        .fetch_scalar(pool)
278        .await
279}}
280
281"#,
282        table_name = table.name,
283    )
284}
285
286/// Generate primary key methods (find, delete)
287fn generate_pk_methods(
288    table: &TableMetadata,
289    pk: &crate::parser::PrimaryKey,
290    column_map: &HashMap<&str, &ColumnMetadata>,
291    struct_name: &str,
292    select_columns: &str,
293) -> String {
294    let mut code = String::new();
295
296    let method_name = generate_find_by_method_name(&pk.columns);
297    let params = build_params(&pk.columns, column_map, &table.name);
298    let where_clause = build_where_clause(&pk.columns);
299    let bind_section = generate_bind_section(&pk.columns);
300
301    // find_by_pk
302    code.push_str(&format!(
303        r#"/// Find by primary key
304pub async fn {method_name}<P: Pool>(pool: &P, {params}) -> Result<Option<{struct_name}>> {{
305    Query::new("SELECT {select_columns} FROM `{table_name}` WHERE {where_clause}")
306{bind_section}
307        .fetch_optional(pool)
308        .await
309}}
310
311"#,
312        method_name = method_name,
313        params = params,
314        struct_name = struct_name,
315        select_columns = select_columns,
316        table_name = table.name,
317        where_clause = where_clause,
318        bind_section = bind_section,
319    ));
320
321    // delete_by_pk
322    let delete_method = generate_delete_by_method_name(&pk.columns);
323    code.push_str(&format!(
324        r#"/// Delete by primary key
325pub async fn {delete_method}<P: Pool>(pool: &P, {params}) -> Result<u64> {{
326    Query::new("DELETE FROM `{table_name}` WHERE {where_clause}")
327{bind_section}
328        .execute(pool)
329        .await
330        .map(|r| r.rows_affected)
331}}
332
333"#,
334        delete_method = delete_method,
335        params = params,
336        table_name = table.name,
337        where_clause = where_clause,
338        bind_section = bind_section,
339    ));
340
341    code
342}
343
344/// Generate insert methods
345fn generate_insert_methods(
346    table: &TableMetadata,
347    struct_name: &str,
348    _column_map: &HashMap<&str, &ColumnMetadata>,
349) -> String {
350    let mut code = String::new();
351
352    // Get non-auto-increment columns for insert
353    let insert_columns: Vec<&ColumnMetadata> = table
354        .columns
355        .iter()
356        .filter(|c| !c.is_auto_increment)
357        .collect();
358
359    if insert_columns.is_empty() {
360        return code;
361    }
362
363    let column_list = insert_columns
364        .iter()
365        .map(|c| format!("`{}`", c.name))
366        .collect::<Vec<_>>()
367        .join(", ");
368
369    let placeholders = insert_columns
370        .iter()
371        .map(|_| "?")
372        .collect::<Vec<_>>()
373        .join(", ");
374
375    let bind_fields = insert_columns
376        .iter()
377        .map(|c| {
378            let field = escape_field_name(&c.name);
379            format!("        .bind(&entity.{})", field)
380        })
381        .collect::<Vec<_>>()
382        .join("\n");
383
384    // insert (with entity)
385    code.push_str(&format!(
386        r#"/// Insert a new record
387pub async fn insert<P: Pool>(pool: &P, entity: &{struct_name}) -> Result<u64> {{
388    Query::new("INSERT INTO `{table_name}` ({column_list}) VALUES ({placeholders})")
389{bind_fields}
390        .execute(pool)
391        .await
392        .map(|r| r.last_insert_id.unwrap_or(0))
393}}
394
395"#,
396        struct_name = struct_name,
397        table_name = table.name,
398        column_list = column_list,
399        placeholders = placeholders,
400        bind_fields = bind_fields,
401    ));
402
403    code
404}
405
406/// Generate insert_plain method with individual parameters
407fn generate_insert_plain_method(
408    table: &TableMetadata,
409    column_map: &HashMap<&str, &ColumnMetadata>,
410) -> String {
411    // Get non-auto-increment columns for insert
412    let insert_columns: Vec<&ColumnMetadata> = table
413        .columns
414        .iter()
415        .filter(|c| !c.is_auto_increment)
416        .collect();
417
418    if insert_columns.is_empty() {
419        return String::new();
420    }
421
422    let column_names: Vec<String> = insert_columns.iter().map(|c| c.name.clone()).collect();
423    let params = build_params(&column_names, column_map, &table.name);
424
425    let column_list = insert_columns
426        .iter()
427        .map(|c| format!("`{}`", c.name))
428        .collect::<Vec<_>>()
429        .join(", ");
430
431    let placeholders = insert_columns
432        .iter()
433        .map(|_| "?")
434        .collect::<Vec<_>>()
435        .join(", ");
436
437    let bind_section = generate_bind_section(&column_names);
438
439    format!(
440        r#"/// Insert a new record with individual parameters
441pub async fn insert_plain<P: Pool>(pool: &P, {params}) -> Result<u64> {{
442    Query::new("INSERT INTO `{table_name}` ({column_list}) VALUES ({placeholders})")
443{bind_section}
444        .execute(pool)
445        .await
446        .map(|r| r.last_insert_id.unwrap_or(0))
447}}
448
449"#,
450        params = params,
451        table_name = table.name,
452        column_list = column_list,
453        placeholders = placeholders,
454        bind_section = bind_section,
455    )
456}
457
458/// Generate batch insert method (insert_all) using BatchInsert
459fn generate_insert_all_method(table: &TableMetadata, struct_name: &str) -> String {
460    // Get non-auto-increment columns for insert
461    let insert_columns: Vec<&ColumnMetadata> = table
462        .columns
463        .iter()
464        .filter(|c| !c.is_auto_increment)
465        .collect();
466
467    if insert_columns.is_empty() {
468        return String::new();
469    }
470
471    format!(
472        r#"/// Insert multiple records in a single batch
473pub async fn insert_all<P: Pool>(pool: &P, entities: &[{struct_name}]) -> Result<u64> {{
474    rdbi::BatchInsert::new("{table_name}", entities)
475        .execute(pool)
476        .await
477        .map(|r| r.rows_affected)
478}}
479
480"#,
481        struct_name = struct_name,
482        table_name = table.name,
483    )
484}
485
486/// Check if table has a unique index (excluding primary key)
487fn has_unique_index(table: &TableMetadata) -> bool {
488    table.indexes.iter().any(|idx| idx.unique)
489}
490
491/// Generate upsert method (INSERT ... ON DUPLICATE KEY UPDATE)
492fn generate_upsert_method(table: &TableMetadata, struct_name: &str) -> String {
493    // Only generate if table has primary key or unique index
494    if table.primary_key.is_none() && !has_unique_index(table) {
495        return String::new();
496    }
497
498    // Get non-auto-increment columns for insert
499    let insert_columns: Vec<&ColumnMetadata> = table
500        .columns
501        .iter()
502        .filter(|c| !c.is_auto_increment)
503        .collect();
504
505    if insert_columns.is_empty() {
506        return String::new();
507    }
508
509    // Get primary key columns for exclusion from UPDATE clause
510    let pk_columns: HashSet<&str> = table
511        .primary_key
512        .as_ref()
513        .map(|pk| pk.columns.iter().map(|s| s.as_str()).collect())
514        .unwrap_or_default();
515
516    // Columns to update on duplicate key (all non-PK, non-auto-increment columns)
517    let update_columns: Vec<&ColumnMetadata> = insert_columns
518        .iter()
519        .filter(|c| !pk_columns.contains(c.name.as_str()))
520        .copied()
521        .collect();
522
523    // If no columns to update, skip upsert generation
524    if update_columns.is_empty() {
525        return String::new();
526    }
527
528    let column_list = insert_columns
529        .iter()
530        .map(|c| format!("`{}`", c.name))
531        .collect::<Vec<_>>()
532        .join(", ");
533
534    let placeholders = insert_columns
535        .iter()
536        .map(|_| "?")
537        .collect::<Vec<_>>()
538        .join(", ");
539
540    let update_clause = update_columns
541        .iter()
542        .map(|c| format!("`{name}` = VALUES(`{name}`)", name = c.name))
543        .collect::<Vec<_>>()
544        .join(", ");
545
546    let bind_fields = insert_columns
547        .iter()
548        .map(|c| {
549            let field = escape_field_name(&c.name);
550            format!("        .bind(&entity.{})", field)
551        })
552        .collect::<Vec<_>>()
553        .join("\n");
554
555    format!(
556        r#"/// Upsert a record (insert or update on duplicate key)
557/// Returns rows_affected: 1 if inserted, 2 if updated
558pub async fn upsert<P: Pool>(pool: &P, entity: &{struct_name}) -> Result<u64> {{
559    Query::new("INSERT INTO `{table_name}` ({column_list}) VALUES ({placeholders}) \
560         ON DUPLICATE KEY UPDATE {update_clause}")
561{bind_fields}
562        .execute(pool)
563        .await
564        .map(|r| r.rows_affected)
565}}
566
567"#,
568        struct_name = struct_name,
569        table_name = table.name,
570        column_list = column_list,
571        placeholders = placeholders,
572        update_clause = update_clause,
573        bind_fields = bind_fields,
574    )
575}
576
577/// Generate update methods
578fn generate_update_methods(
579    table: &TableMetadata,
580    struct_name: &str,
581    _column_map: &HashMap<&str, &ColumnMetadata>,
582) -> String {
583    let mut code = String::new();
584
585    let pk = table.primary_key.as_ref().unwrap();
586
587    // Get non-PK columns for SET clause
588    let update_columns: Vec<&ColumnMetadata> = table
589        .columns
590        .iter()
591        .filter(|c| !pk.columns.contains(&c.name))
592        .collect();
593
594    if update_columns.is_empty() {
595        return code;
596    }
597
598    let set_clause = update_columns
599        .iter()
600        .map(|c| format!("`{}` = ?", c.name))
601        .collect::<Vec<_>>()
602        .join(", ");
603
604    let where_clause = build_where_clause(&pk.columns);
605
606    // Bind update columns first, then PK columns
607    let bind_fields: Vec<String> = update_columns
608        .iter()
609        .map(|c| {
610            let field = escape_field_name(&c.name);
611            format!("        .bind(&entity.{})", field)
612        })
613        .chain(pk.columns.iter().map(|c| {
614            let field = escape_field_name(c);
615            format!("        .bind(&entity.{})", field)
616        }))
617        .collect();
618
619    // update_by_bean (using entity)
620    code.push_str(&format!(
621        r#"/// Update a record by primary key
622pub async fn update<P: Pool>(pool: &P, entity: &{struct_name}) -> Result<u64> {{
623    Query::new("UPDATE `{table_name}` SET {set_clause} WHERE {where_clause}")
624{bind_fields}
625        .execute(pool)
626        .await
627        .map(|r| r.rows_affected)
628}}
629
630"#,
631        struct_name = struct_name,
632        table_name = table.name,
633        set_clause = set_clause,
634        where_clause = where_clause,
635        bind_fields = bind_fields.join("\n"),
636    ));
637
638    code
639}
640
641/// Generate update_plain method with individual parameters (update_by_<pk>)
642fn generate_update_plain_method(
643    table: &TableMetadata,
644    column_map: &HashMap<&str, &ColumnMetadata>,
645) -> String {
646    let pk = table.primary_key.as_ref().unwrap();
647
648    // Get non-PK columns for SET clause
649    let update_columns: Vec<&ColumnMetadata> = table
650        .columns
651        .iter()
652        .filter(|c| !pk.columns.contains(&c.name))
653        .collect();
654
655    if update_columns.is_empty() {
656        return String::new();
657    }
658
659    // Build method name based on PK columns
660    let method_name = generate_update_by_method_name(&pk.columns);
661
662    // Build params: PK columns first, then update columns
663    let pk_params = build_params(&pk.columns, column_map, &table.name);
664    let update_column_names: Vec<String> = update_columns.iter().map(|c| c.name.clone()).collect();
665    let update_params = build_params(&update_column_names, column_map, &table.name);
666    let all_params = format!("{}, {}", pk_params, update_params);
667
668    let set_clause = update_columns
669        .iter()
670        .map(|c| format!("`{}` = ?", c.name))
671        .collect::<Vec<_>>()
672        .join(", ");
673
674    let where_clause = build_where_clause(&pk.columns);
675
676    // Bind update columns first (for SET), then PK columns (for WHERE)
677    let bind_section_update = generate_bind_section(&update_column_names);
678    let bind_section_pk = generate_bind_section(&pk.columns);
679
680    format!(
681        r#"/// Update a record by primary key with individual parameters
682pub async fn {method_name}<P: Pool>(pool: &P, {all_params}) -> Result<u64> {{
683    Query::new("UPDATE `{table_name}` SET {set_clause} WHERE {where_clause}")
684{bind_section_update}
685{bind_section_pk}
686        .execute(pool)
687        .await
688        .map(|r| r.rows_affected)
689}}
690
691"#,
692        method_name = method_name,
693        all_params = all_params,
694        table_name = table.name,
695        set_clause = set_clause,
696        where_clause = where_clause,
697        bind_section_update = bind_section_update,
698        bind_section_pk = bind_section_pk,
699    )
700}
701
702/// Collect method signatures with priority-based deduplication
703fn collect_method_signatures(table: &TableMetadata) -> HashMap<Vec<String>, MethodSignature> {
704    let mut signatures: HashMap<Vec<String>, MethodSignature> = HashMap::new();
705
706    // Priority 1: Primary key
707    if let Some(pk) = &table.primary_key {
708        let sig = MethodSignature::new(
709            pk.columns.clone(),
710            PRIORITY_PRIMARY_KEY,
711            true,
712            "PRIMARY_KEY",
713        );
714        signatures.insert(pk.columns.clone(), sig);
715    }
716
717    // Priority 2 & 3: Indexes
718    for index in &table.indexes {
719        let priority = if index.unique {
720            PRIORITY_UNIQUE_INDEX
721        } else {
722            PRIORITY_NON_UNIQUE_INDEX
723        };
724        let source = if index.unique {
725            "UNIQUE_INDEX"
726        } else {
727            "NON_UNIQUE_INDEX"
728        };
729        let sig = MethodSignature::new(index.columns.clone(), priority, index.unique, source);
730
731        // Only add if no higher priority signature exists
732        if let Some(existing) = signatures.get(&index.columns) {
733            if sig.priority < existing.priority {
734                signatures.insert(index.columns.clone(), sig);
735            }
736        } else {
737            signatures.insert(index.columns.clone(), sig);
738        }
739    }
740
741    // Priority 4: Foreign keys
742    for fk in &table.foreign_keys {
743        let columns = vec![fk.column_name.clone()];
744        let sig = MethodSignature::new(columns.clone(), PRIORITY_FOREIGN_KEY, false, "FOREIGN_KEY");
745
746        signatures.entry(columns).or_insert(sig);
747    }
748
749    signatures
750}
751
752/// Generate a find_by method for an index/FK
753fn generate_find_by_method(
754    table: &TableMetadata,
755    sig: &MethodSignature,
756    column_map: &HashMap<&str, &ColumnMetadata>,
757    struct_name: &str,
758    select_columns: &str,
759) -> String {
760    let params = build_params(&sig.columns, column_map, &table.name);
761
762    let (return_type, fetch_method) = if sig.is_unique {
763        (format!("Option<{}>", struct_name), "fetch_optional")
764    } else {
765        (format!("Vec<{}>", struct_name), "fetch_all")
766    };
767
768    let return_desc = if sig.is_unique {
769        "Option (unique)"
770    } else {
771        "Vec (non-unique)"
772    };
773
774    // Check if any column is nullable - if so, we need dynamic SQL for IS NULL handling
775    let has_nullable = sig.columns.iter().any(|c| {
776        column_map
777            .get(c.as_str())
778            .map(|col| col.nullable)
779            .unwrap_or(false)
780    });
781
782    if has_nullable {
783        // Generate dynamic query that handles NULL values properly
784        generate_find_by_method_nullable(
785            table,
786            sig,
787            column_map,
788            select_columns,
789            &params,
790            &return_type,
791            fetch_method,
792            return_desc,
793        )
794    } else {
795        // Use static query for non-nullable columns
796        let where_clause = build_where_clause(&sig.columns);
797        let bind_section = generate_bind_section(&sig.columns);
798
799        format!(
800            r#"/// Find by {source}: returns {return_desc}
801pub async fn {method_name}<P: Pool>(pool: &P, {params}) -> Result<{return_type}> {{
802    Query::new("SELECT {select_columns} FROM `{table_name}` WHERE {where_clause}")
803{bind_section}
804        .{fetch_method}(pool)
805        .await
806}}
807
808"#,
809            source = sig.source.to_lowercase().replace('_', " "),
810            return_desc = return_desc,
811            method_name = sig.method_name,
812            params = params,
813            return_type = return_type,
814            select_columns = select_columns,
815            table_name = table.name,
816            where_clause = where_clause,
817            bind_section = bind_section,
818            fetch_method = fetch_method,
819        )
820    }
821}
822
823/// Generate a find_by method that handles nullable columns with IS NULL
824#[allow(clippy::too_many_arguments)]
825fn generate_find_by_method_nullable(
826    table: &TableMetadata,
827    sig: &MethodSignature,
828    column_map: &HashMap<&str, &ColumnMetadata>,
829    select_columns: &str,
830    params: &str,
831    return_type: &str,
832    fetch_method: &str,
833    return_desc: &str,
834) -> String {
835    // Build the where clause conditions and bind logic
836    let mut where_parts = Vec::new();
837    let mut bind_parts = Vec::new();
838
839    for col in &sig.columns {
840        let col_meta = column_map.get(col.as_str()).unwrap();
841        let field_name = escape_field_name(col);
842
843        if col_meta.nullable {
844            // For nullable columns, check if value is None and use IS NULL
845            where_parts.push(format!(
846                r#"if {field}.is_some() {{ "`{col}` = ?" }} else {{ "`{col}` IS NULL" }}"#,
847                field = field_name,
848                col = col,
849            ));
850            bind_parts.push(format!(
851                r#"if let Some(v) = {field}.as_ref() {{ query = query.bind(v); }}"#,
852                field = field_name,
853            ));
854        } else {
855            where_parts.push(format!(r#""`{}` = ?""#, col));
856            bind_parts.push(format!(r#"query = query.bind({});"#, field_name,));
857        }
858    }
859
860    let where_expr = if where_parts.len() == 1 {
861        where_parts[0].clone()
862    } else {
863        // Join with " AND "
864        let parts = where_parts
865            .iter()
866            .map(|p| format!("({})", p))
867            .collect::<Vec<_>>()
868            .join(", ");
869        format!("vec![{}].join(\" AND \")", parts)
870    };
871
872    let bind_code = bind_parts.join("\n        ");
873
874    format!(
875        r#"/// Find by {source}: returns {return_desc}
876pub async fn {method_name}<P: Pool>(pool: &P, {params}) -> Result<{return_type}> {{
877    let where_clause = {where_expr};
878    let sql = format!("SELECT {select_columns} FROM `{table_name}` WHERE {{}}", where_clause);
879    let mut query = rdbi::DynamicQuery::new(sql);
880    {bind_code}
881    query.{fetch_method}(pool).await
882}}
883
884"#,
885        source = sig.source.to_lowercase().replace('_', " "),
886        return_desc = return_desc,
887        method_name = sig.method_name,
888        params = params,
889        return_type = return_type,
890        select_columns = select_columns,
891        table_name = table.name,
892        where_expr = where_expr,
893        bind_code = bind_code,
894        fetch_method = fetch_method,
895    )
896}
897
898/// Generate list-based findBy methods for single-column indexes
899fn generate_find_by_list_methods(
900    table: &TableMetadata,
901    column_map: &HashMap<&str, &ColumnMetadata>,
902    struct_name: &str,
903    select_columns: &str,
904) -> String {
905    let mut code = String::new();
906    let mut processed: HashSet<String> = HashSet::new();
907
908    // Primary key (if single column)
909    if let Some(pk) = &table.primary_key {
910        if pk.columns.len() == 1 {
911            let col = &pk.columns[0];
912            code.push_str(&generate_single_find_by_list(
913                table,
914                col,
915                column_map,
916                struct_name,
917                select_columns,
918            ));
919            processed.insert(col.clone());
920        }
921    }
922
923    // Single-column indexes
924    for index in &table.indexes {
925        if index.columns.len() == 1 {
926            let col = &index.columns[0];
927            if !processed.contains(col) {
928                code.push_str(&generate_single_find_by_list(
929                    table,
930                    col,
931                    column_map,
932                    struct_name,
933                    select_columns,
934                ));
935                processed.insert(col.clone());
936            }
937        }
938    }
939
940    code
941}
942
943/// Generate a single find_by_<column>s method (using IN clause)
944fn generate_single_find_by_list(
945    table: &TableMetadata,
946    column_name: &str,
947    column_map: &HashMap<&str, &ColumnMetadata>,
948    struct_name: &str,
949    select_columns: &str,
950) -> String {
951    let method_name = generate_find_by_list_method_name(column_name);
952    let param_name = pluralize(&escape_field_name(column_name));
953    let column = column_map.get(column_name).unwrap();
954    let rust_type = TypeResolver::resolve(column, &table.name);
955
956    // Get the inner type (unwrap Option if nullable)
957    let inner_type = rust_type.inner_type().to_type_string();
958
959    let column_name_plural = pluralize(column_name);
960    format!(
961        r#"/// Find by list of {column_name_plural} (IN clause)
962pub async fn {method_name}<P: Pool>(pool: &P, {param_name}: &[{inner_type}]) -> Result<Vec<{struct_name}>> {{
963    if {param_name}.is_empty() {{
964        return Ok(Vec::new());
965    }}
966    let placeholders = {param_name}.iter().map(|_| "?").collect::<Vec<_>>().join(",");
967    let query = format!(
968        "SELECT {select_columns} FROM `{table_name}` WHERE `{column_name}` IN ({{}})",
969        placeholders
970    );
971    rdbi::DynamicQuery::new(query)
972        .bind_all({param_name})
973        .fetch_all(pool)
974        .await
975}}
976
977"#,
978        column_name_plural = column_name_plural,
979        column_name = column_name,
980        method_name = method_name,
981        param_name = param_name,
982        inner_type = inner_type,
983        struct_name = struct_name,
984        select_columns = select_columns,
985        table_name = table.name,
986    )
987}
988
989/// Generate composite enum list methods for multi-column indexes with enum columns
990/// Example: find_by_user_id_and_device_types(user_id, &[DeviceType])
991fn generate_composite_enum_list_methods(
992    table: &TableMetadata,
993    column_map: &HashMap<&str, &ColumnMetadata>,
994    struct_name: &str,
995    select_columns: &str,
996) -> String {
997    let mut code = String::new();
998
999    for index in &table.indexes {
1000        // Skip single-column indexes (handled by generate_find_by_list_methods)
1001        if index.columns.len() <= 1 {
1002            continue;
1003        }
1004
1005        // Identify which columns are enums
1006        let enum_columns: HashSet<&str> = index
1007            .columns
1008            .iter()
1009            .filter(|col_name| {
1010                column_map
1011                    .get(col_name.as_str())
1012                    .map(|col| col.is_enum())
1013                    .unwrap_or(false)
1014            })
1015            .map(|s| s.as_str())
1016            .collect();
1017
1018        // Skip if no enum columns
1019        if enum_columns.is_empty() {
1020            continue;
1021        }
1022
1023        // Skip if first column is enum (for optimal index usage, equality should be on leading column)
1024        let first_column = &index.columns[0];
1025        if enum_columns.contains(first_column.as_str()) {
1026            continue;
1027        }
1028
1029        code.push_str(&generate_composite_enum_list_method(
1030            table,
1031            &index.columns,
1032            &enum_columns,
1033            column_map,
1034            struct_name,
1035            select_columns,
1036        ));
1037    }
1038
1039    code
1040}
1041
1042/// Generate a single composite enum list method
1043fn generate_composite_enum_list_method(
1044    table: &TableMetadata,
1045    columns: &[String],
1046    enum_columns: &HashSet<&str>,
1047    column_map: &HashMap<&str, &ColumnMetadata>,
1048    struct_name: &str,
1049    select_columns: &str,
1050) -> String {
1051    // Build method name: pluralize enum column names
1052    let method_name = generate_composite_enum_method_name(columns, enum_columns);
1053
1054    // Build params and WHERE clause parts
1055    let mut params_parts = Vec::new();
1056
1057    for col_name in columns {
1058        let col = column_map.get(col_name.as_str()).unwrap();
1059        let rust_type = TypeResolver::resolve(col, &table.name);
1060        let is_enum = enum_columns.contains(col_name.as_str());
1061
1062        if is_enum {
1063            // Enum column uses list parameter
1064            let param_name = pluralize(&escape_field_name(col_name));
1065            let inner_type = rust_type.inner_type().to_type_string();
1066            params_parts.push(format!("{}: &[{}]", param_name, inner_type));
1067        } else {
1068            // Non-enum column uses single value
1069            let param_name = escape_field_name(col_name);
1070            let param_type = rust_type.to_param_type_string();
1071            params_parts.push(format!("{}: {}", param_name, param_type));
1072        }
1073    }
1074
1075    let params = params_parts.join(", ");
1076
1077    // Build WHERE clause with proper placeholders
1078    let where_clause_static: Vec<String> = columns
1079        .iter()
1080        .map(|col_name| {
1081            if enum_columns.contains(col_name.as_str()) {
1082                format!("`{}` IN ({{}})", col_name) // placeholder for IN clause
1083            } else {
1084                format!("`{}` = ?", col_name)
1085            }
1086        })
1087        .collect();
1088
1089    // Build the bind section
1090    let mut bind_code = String::new();
1091
1092    // First bind non-enum (single value) columns
1093    for col_name in columns {
1094        if !enum_columns.contains(col_name.as_str()) {
1095            let param_name = escape_field_name(col_name);
1096            bind_code.push_str(&format!("        .bind({})\n", param_name));
1097        }
1098    }
1099
1100    // Then bind enum (list) columns
1101    for col_name in columns {
1102        if enum_columns.contains(col_name.as_str()) {
1103            let param_name = pluralize(&escape_field_name(col_name));
1104            bind_code.push_str(&format!("        .bind_all({})\n", param_name));
1105        }
1106    }
1107
1108    // Build the column name description for doc comment
1109    let column_desc: Vec<String> = columns
1110        .iter()
1111        .map(|col| {
1112            if enum_columns.contains(col.as_str()) {
1113                pluralize(col)
1114            } else {
1115                col.clone()
1116            }
1117        })
1118        .collect();
1119
1120    // Build dynamic WHERE clause construction
1121    let enum_col_names: Vec<&str> = columns
1122        .iter()
1123        .filter(|c| enum_columns.contains(c.as_str()))
1124        .map(|s| s.as_str())
1125        .collect();
1126
1127    // Generate the IN clause placeholders dynamically
1128    let in_clause_builders: Vec<String> = enum_col_names
1129        .iter()
1130        .map(|col| {
1131            let param_name = pluralize(&escape_field_name(col));
1132            format!(
1133                "{param_name}.iter().map(|_| \"?\").collect::<Vec<_>>().join(\",\")",
1134                param_name = param_name
1135            )
1136        })
1137        .collect();
1138
1139    // Build format args for the WHERE clause
1140    let format_args = in_clause_builders.join(", ");
1141
1142    format!(
1143        r#"/// Find by {column_desc} (composite index with IN clause for enum columns)
1144pub async fn {method_name}<P: Pool>(pool: &P, {params}) -> Result<Vec<{struct_name}>> {{
1145    // Check for empty enum lists
1146{empty_checks}
1147    // Build IN clause placeholders for enum columns
1148    let where_clause = format!("{where_template}", {format_args});
1149    let query = format!(
1150        "SELECT {select_columns} FROM `{table_name}` WHERE {{}}",
1151        where_clause
1152    );
1153    rdbi::DynamicQuery::new(query)
1154{bind_code}        .fetch_all(pool)
1155        .await
1156}}
1157
1158"#,
1159        column_desc = column_desc.join(" and "),
1160        method_name = method_name,
1161        params = params,
1162        struct_name = struct_name,
1163        select_columns = select_columns,
1164        table_name = table.name,
1165        where_template = where_clause_static.join(" AND "),
1166        format_args = format_args,
1167        bind_code = bind_code,
1168        empty_checks = generate_empty_checks(columns, enum_columns),
1169    )
1170}
1171
1172/// Generate empty checks for enum list parameters
1173fn generate_empty_checks(columns: &[String], enum_columns: &HashSet<&str>) -> String {
1174    let mut checks = String::new();
1175    for col_name in columns {
1176        if enum_columns.contains(col_name.as_str()) {
1177            let param_name = pluralize(&escape_field_name(col_name));
1178            checks.push_str(&format!(
1179                "    if {}.is_empty() {{ return Ok(Vec::new()); }}\n",
1180                param_name
1181            ));
1182        }
1183    }
1184    checks
1185}
1186
1187/// Generate method name for composite enum queries
1188/// Example: ["user_id", "device_type"] with enum_columns={"device_type"} -> "find_by_user_id_and_device_types"
1189fn generate_composite_enum_method_name(columns: &[String], enum_columns: &HashSet<&str>) -> String {
1190    let mut parts = Vec::new();
1191    for col in columns {
1192        if enum_columns.contains(col.as_str()) {
1193            parts.push(pluralize(col));
1194        } else {
1195            parts.push(col.clone());
1196        }
1197    }
1198    generate_find_by_method_name(&parts)
1199}
1200
1201/// Generate pagination methods (find_all_paginated, get_paginated_result)
1202fn generate_pagination_methods(
1203    table: &TableMetadata,
1204    struct_name: &str,
1205    select_columns: &str,
1206    models_module: &str,
1207) -> String {
1208    let sort_by_enum = format!("{}SortBy", struct_name);
1209
1210    format!(
1211        r#"/// Find all records with pagination and sorting
1212pub async fn find_all_paginated<P: Pool>(
1213    pool: &P,
1214    limit: i32,
1215    offset: i32,
1216    sort_by: crate::{models_module}::{sort_by_enum},
1217    sort_dir: crate::{models_module}::SortDirection,
1218) -> Result<Vec<{struct_name}>> {{
1219    let order_clause = format!("{{}} {{}}", sort_by.as_sql(), sort_dir.as_sql());
1220    let query = format!(
1221        "SELECT {select_columns} FROM `{table_name}` ORDER BY {{}} LIMIT ? OFFSET ?",
1222        order_clause
1223    );
1224    rdbi::DynamicQuery::new(query)
1225        .bind(limit)
1226        .bind(offset)
1227        .fetch_all(pool)
1228        .await
1229}}
1230
1231/// Get paginated result with total count
1232pub async fn get_paginated_result<P: Pool>(
1233    pool: &P,
1234    page_size: i32,
1235    current_page: i32,
1236    sort_by: crate::{models_module}::{sort_by_enum},
1237    sort_dir: crate::{models_module}::SortDirection,
1238) -> Result<crate::{models_module}::PaginatedResult<{struct_name}>> {{
1239    let page_size = page_size.max(1);
1240    let current_page = current_page.max(1);
1241    let offset = (current_page - 1) * page_size;
1242
1243    let total_count = count_all(pool).await?;
1244    let items = find_all_paginated(pool, page_size, offset, sort_by, sort_dir).await?;
1245
1246    Ok(crate::{models_module}::PaginatedResult::new(
1247        items,
1248        total_count,
1249        current_page,
1250        page_size,
1251    ))
1252}}
1253
1254"#,
1255        struct_name = struct_name,
1256        select_columns = select_columns,
1257        table_name = table.name,
1258        models_module = models_module,
1259        sort_by_enum = sort_by_enum,
1260    )
1261}
1262
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::{IndexMetadata, PrimaryKey};

    /// Fixture helper: a NOT NULL, signed column with no default and no comment.
    fn fixture_column(
        name: &str,
        data_type: &str,
        is_auto_increment: bool,
        enum_values: Option<&[&str]>,
    ) -> ColumnMetadata {
        ColumnMetadata {
            name: name.to_string(),
            data_type: data_type.to_string(),
            nullable: false,
            default_value: None,
            is_auto_increment,
            is_unsigned: false,
            enum_values: enum_values.map(|values| values.iter().map(|v| v.to_string()).collect()),
            comment: None,
        }
    }

    /// Fixture helper: an index over exactly one column.
    fn fixture_index(name: &str, column: &str, unique: bool) -> IndexMetadata {
        IndexMetadata {
            name: name.to_string(),
            columns: vec![column.to_string()],
            unique,
        }
    }

    /// `users` table with columns:
    /// - `id`: BIGINT, auto-increment primary key;
    /// - `email`: VARCHAR(255), unique index;
    /// - `status`: ENUM, non-unique index.
    fn make_table() -> TableMetadata {
        TableMetadata {
            name: "users".to_string(),
            comment: None,
            columns: vec![
                fixture_column("id", "BIGINT", true, None),
                fixture_column("email", "VARCHAR(255)", false, None),
                fixture_column("status", "ENUM", false, Some(&["ACTIVE", "INACTIVE"][..])),
            ],
            indexes: vec![
                fixture_index("email_unique", "email", true),
                fixture_index("idx_status", "status", false),
            ],
            foreign_keys: Vec::new(),
            primary_key: Some(PrimaryKey {
                columns: vec!["id".to_string()],
            }),
        }
    }

    #[test]
    fn test_collect_method_signatures() {
        let table = make_table();
        let sigs = collect_method_signatures(&table);

        // One signature each for: id (PK), email (unique), status (non-unique)
        assert_eq!(sigs.len(), 3);

        let key = |column: &str| vec![column.to_string()];
        let expected = [
            ("id", true, PRIORITY_PRIMARY_KEY),
            ("email", true, PRIORITY_UNIQUE_INDEX),
            ("status", false, PRIORITY_NON_UNIQUE_INDEX),
        ];
        for &(column, unique, priority) in expected.iter() {
            let sig = sigs.get(&key(column)).unwrap();
            assert_eq!(sig.is_unique, unique, "is_unique for `{}`", column);
            assert_eq!(sig.priority, priority, "priority for `{}`", column);
        }
    }

    #[test]
    fn test_build_select_columns() {
        let table = make_table();
        let cols = build_select_columns(&table);
        for &quoted in &["`id`", "`email`", "`status`"] {
            assert!(cols.contains(quoted), "select list is missing {}", quoted);
        }
    }

    #[test]
    fn test_build_where_clause() {
        let cases: [(&[&str], &str); 2] = [
            (&["id"], "`id` = ?"),
            (&["user_id", "role_id"], "`user_id` = ? AND `role_id` = ?"),
        ];
        for &(columns, expected) in cases.iter() {
            let owned: Vec<String> = columns.iter().map(|c| c.to_string()).collect();
            assert_eq!(build_where_clause(&owned[..]), expected);
        }
    }

    #[test]
    fn test_generate_upsert_method() {
        let table = make_table();
        let code = generate_upsert_method(&table, "Users");

        // The upsert must exist and overwrite every non-PK column on conflict.
        for &needle in &[
            "pub async fn upsert",
            "ON DUPLICATE KEY UPDATE",
            "`email` = VALUES(`email`)",
            "`status` = VALUES(`status`)",
        ] {
            assert!(code.contains(needle), "generated upsert lacks {}", needle);
        }
        // The PK column must never be overwritten.
        assert!(!code.contains("`id` = VALUES(`id`)"));
    }

    #[test]
    fn test_generate_upsert_method_no_pk() {
        let mut table = make_table();
        table.primary_key = None;
        table.indexes.clear();

        // Without a PK or unique index there is no conflict key, so nothing is generated.
        let code = generate_upsert_method(&table, "Users");
        assert!(code.is_empty());
    }

    #[test]
    fn test_generate_insert_all_method() {
        let table = make_table();
        let code = generate_insert_all_method(&table, "Users");

        for &needle in &["pub async fn insert_all", "rdbi::BatchInsert::new"] {
            assert!(code.contains(needle), "generated insert_all lacks {}", needle);
        }
    }

    #[test]
    fn test_generate_pagination_methods() {
        let table = make_table();
        let select_columns = build_select_columns(&table);
        let code = generate_pagination_methods(&table, "Users", &select_columns, "models");

        for &needle in &[
            "pub async fn find_all_paginated",
            "limit: i32",
            "offset: i32",
            "UsersSortBy",
            "SortDirection",
            "pub async fn get_paginated_result",
            "PaginatedResult<Users>",
        ] {
            assert!(code.contains(needle), "generated pagination code lacks {}", needle);
        }
    }
}