Skip to main content

rdbi_codegen/codegen/
dao_generator.rs

1//! DAO generator - generates async rdbi query functions from table metadata
2
3use crate::error::Result;
4use std::collections::{HashMap, HashSet};
5use std::fs;
6use std::path::Path;
7use tracing::debug;
8
9use crate::config::CodegenConfig;
10use crate::parser::{ColumnMetadata, TableMetadata};
11
12use super::naming::{
13    escape_field_name, generate_delete_by_method_name, generate_find_by_list_method_name,
14    generate_find_by_method_name, generate_update_by_method_name, pluralize, to_struct_name,
15};
16use super::type_resolver::TypeResolver;
17
// Ranks used when several schema sources (primary key, indexes, foreign keys)
// describe the same column list: the lowest number is kept, so each level is
// defined as one step below the previous one.

/// Primary-key signatures outrank every other source.
const PRIORITY_PRIMARY_KEY: u8 = 1;
/// Unique index signatures.
const PRIORITY_UNIQUE_INDEX: u8 = PRIORITY_PRIMARY_KEY + 1;
/// Non-unique index signatures.
const PRIORITY_NON_UNIQUE_INDEX: u8 = PRIORITY_UNIQUE_INDEX + 1;
/// Foreign-key signatures; they never replace an existing entry.
const PRIORITY_FOREIGN_KEY: u8 = PRIORITY_NON_UNIQUE_INDEX + 1;
23
/// A candidate `find_by_*` finder, identified by the columns it filters on.
///
/// Different schema sources can describe the same column list; `priority`
/// decides which description survives deduplication.
#[derive(Clone, Debug)]
struct MethodSignature {
    /// Columns compared in the generated `WHERE` clause, in order.
    columns: Vec<String>,
    /// Name of the generated Rust function.
    method_name: String,
    /// Deduplication rank; a lower value takes precedence.
    priority: u8,
    /// Whether the source is a unique key; selects an `Option` (unique) or
    /// `Vec` (non-unique) return type for the generated finder.
    is_unique: bool,
    /// Origin tag: `"PRIMARY_KEY"`, `"UNIQUE_INDEX"`, `"NON_UNIQUE_INDEX"` or
    /// `"FOREIGN_KEY"`.
    source: String,
}
33
34impl MethodSignature {
35    fn new(columns: Vec<String>, priority: u8, is_unique: bool, source: &str) -> Self {
36        let method_name = generate_find_by_method_name(&columns);
37        Self {
38            columns,
39            method_name,
40            priority,
41            is_unique,
42            source: source.to_string(),
43        }
44    }
45}
46
47/// Generate DAO files for all tables
48pub fn generate_daos(tables: &[TableMetadata], config: &CodegenConfig) -> Result<()> {
49    let output_dir = &config.output_dao_dir;
50    fs::create_dir_all(output_dir)?;
51
52    // Generate mod.rs
53    let mut mod_content = String::new();
54    mod_content.push_str("// Generated DAO functions\n\n");
55
56    for table in tables {
57        let file_name = heck::AsSnakeCase(&table.name).to_string();
58        mod_content.push_str(&format!("pub mod {};\n", file_name));
59    }
60
61    fs::write(output_dir.join("mod.rs"), mod_content)?;
62
63    // Generate each DAO file
64    for table in tables {
65        generate_dao_file(table, output_dir, &config.models_module)?;
66    }
67
68    Ok(())
69}
70
71/// Generate a single DAO file for a table
72fn generate_dao_file(table: &TableMetadata, output_dir: &Path, models_module: &str) -> Result<()> {
73    let struct_name = to_struct_name(&table.name);
74    let file_name = format!("{}.rs", heck::AsSnakeCase(&table.name));
75    debug!("Generating DAO for {} -> {}", struct_name, file_name);
76
77    let mut code = String::new();
78
79    // Build column map for quick lookup
80    let column_map: HashMap<&str, &ColumnMetadata> =
81        table.columns.iter().map(|c| (c.name.as_str(), c)).collect();
82
83    // Collect import requirements
84    let mut needs_chrono = false;
85    let mut needs_decimal = false;
86
87    for col in &table.columns {
88        let rust_type = TypeResolver::resolve(col, &table.name);
89        if rust_type.needs_chrono() {
90            needs_chrono = true;
91        }
92        if rust_type.needs_decimal() {
93            needs_decimal = true;
94        }
95    }
96
97    // Check if we have enum columns for imports
98    let has_enums = table.columns.iter().any(|c| c.is_enum());
99
100    // Generate imports
101    code.push_str("use rdbi::{Pool, Query, Result};\n");
102    code.push_str(&format!("use crate::{}::{};\n", models_module, struct_name));
103
104    if has_enums {
105        // Import enum types from the same struct module
106        for col in &table.columns {
107            if col.is_enum() {
108                let enum_name = super::naming::to_enum_name(&table.name, &col.name);
109                code.push_str(&format!("use crate::{}::{};\n", models_module, enum_name));
110            }
111        }
112    }
113
114    if needs_chrono {
115        code.push_str("#[allow(unused_imports)]\n");
116        code.push_str("use chrono::{NaiveDate, NaiveDateTime, NaiveTime};\n");
117    }
118    if needs_decimal {
119        code.push_str("#[allow(unused_imports)]\n");
120        code.push_str("use rust_decimal::Decimal;\n");
121    }
122
123    code.push('\n');
124
125    // Generate select columns list
126    let select_columns = build_select_columns(table);
127
128    // Generate find_all
129    code.push_str(&generate_find_all(table, &struct_name, &select_columns));
130
131    // Generate count_all
132    code.push_str(&generate_count_all(table));
133
134    // Generate primary key methods
135    if let Some(pk) = &table.primary_key {
136        code.push_str(&generate_pk_methods(
137            table,
138            pk,
139            &column_map,
140            &struct_name,
141            &select_columns,
142        ));
143    }
144
145    // Generate insert methods
146    code.push_str(&generate_insert_methods(table, &struct_name, &column_map));
147
148    // Generate insert_plain method (individual params)
149    code.push_str(&generate_insert_plain_method(table, &column_map));
150
151    // Generate batch insert method
152    code.push_str(&generate_insert_all_method(table, &struct_name));
153
154    // Generate upsert method
155    code.push_str(&generate_upsert_method(table, &struct_name));
156
157    // Generate update methods
158    if table.primary_key.is_some() {
159        code.push_str(&generate_update_methods(table, &struct_name, &column_map));
160        // Generate update_plain method (individual params)
161        code.push_str(&generate_update_plain_method(table, &column_map));
162    }
163
164    // Generate index-aware findBy methods
165    let signatures = collect_method_signatures(table);
166    for sig in signatures.values() {
167        // Skip if this is the primary key (already generated separately)
168        if sig.source == "PRIMARY_KEY" {
169            continue;
170        }
171        code.push_str(&generate_find_by_method(
172            table,
173            sig,
174            &column_map,
175            &struct_name,
176            &select_columns,
177        ));
178    }
179
180    // Generate list-based findBy methods for single-column indexes
181    code.push_str(&generate_find_by_list_methods(
182        table,
183        &column_map,
184        &struct_name,
185        &select_columns,
186    ));
187
188    // Generate composite enum list methods (e.g., find_by_user_id_and_device_types)
189    code.push_str(&generate_composite_enum_list_methods(
190        table,
191        &column_map,
192        &struct_name,
193        &select_columns,
194    ));
195
196    // Generate pagination methods
197    code.push_str(&generate_pagination_methods(
198        table,
199        &struct_name,
200        &select_columns,
201        models_module,
202    ));
203
204    fs::write(output_dir.join(&file_name), code)?;
205    Ok(())
206}
207
208/// Build the SELECT columns list
209fn build_select_columns(table: &TableMetadata) -> String {
210    table
211        .columns
212        .iter()
213        .map(|c| format!("`{}`", c.name))
214        .collect::<Vec<_>>()
215        .join(", ")
216}
217
/// Build a `WHERE` predicate that compares every column for equality.
///
/// Each column becomes `` `col` = ? `` and the terms are joined with ` AND `.
/// An empty slice yields an empty string.
fn build_where_clause(columns: &[String]) -> String {
    let mut clause = String::new();
    for column in columns {
        if !clause.is_empty() {
            clause.push_str(" AND ");
        }
        clause.push('`');
        clause.push_str(column);
        clause.push_str("` = ?");
    }
    clause
}
226
227/// Build parameter list for a function signature
228fn build_params(
229    columns: &[String],
230    column_map: &HashMap<&str, &ColumnMetadata>,
231    table_name: &str,
232) -> String {
233    columns
234        .iter()
235        .map(|c| {
236            let col = column_map.get(c.as_str()).unwrap();
237            let rust_type = TypeResolver::resolve(col, table_name);
238            let param_type = rust_type.to_param_type_string();
239            format!("{}: {}", escape_field_name(c), param_type)
240        })
241        .collect::<Vec<_>>()
242        .join(", ")
243}
244
245/// Generate bind calls for query
246fn generate_bind_section(columns: &[String]) -> String {
247    columns
248        .iter()
249        .map(|c| format!("        .bind({})", escape_field_name(c)))
250        .collect::<Vec<_>>()
251        .join("\n")
252}
253
254/// Generate find_all function
255fn generate_find_all(table: &TableMetadata, struct_name: &str, select_columns: &str) -> String {
256    format!(
257        r#"/// Find all records
258pub async fn find_all<P: Pool>(pool: &P) -> Result<Vec<{struct_name}>> {{
259    Query::new("SELECT {select_columns} FROM `{table_name}`")
260        .fetch_all(pool)
261        .await
262}}
263
264"#,
265        struct_name = struct_name,
266        select_columns = select_columns,
267        table_name = table.name,
268    )
269}
270
271/// Generate count_all function
272fn generate_count_all(table: &TableMetadata) -> String {
273    format!(
274        r#"/// Count all records
275pub async fn count_all<P: Pool>(pool: &P) -> Result<i64> {{
276    Query::new("SELECT COUNT(*) FROM `{table_name}`")
277        .fetch_scalar(pool)
278        .await
279}}
280
281"#,
282        table_name = table.name,
283    )
284}
285
286/// Generate primary key methods (find, delete)
287fn generate_pk_methods(
288    table: &TableMetadata,
289    pk: &crate::parser::PrimaryKey,
290    column_map: &HashMap<&str, &ColumnMetadata>,
291    struct_name: &str,
292    select_columns: &str,
293) -> String {
294    let mut code = String::new();
295
296    let method_name = generate_find_by_method_name(&pk.columns);
297    let params = build_params(&pk.columns, column_map, &table.name);
298    let where_clause = build_where_clause(&pk.columns);
299    let bind_section = generate_bind_section(&pk.columns);
300
301    // find_by_pk
302    code.push_str(&format!(
303        r#"/// Find by primary key
304pub async fn {method_name}<P: Pool>(pool: &P, {params}) -> Result<Option<{struct_name}>> {{
305    Query::new("SELECT {select_columns} FROM `{table_name}` WHERE {where_clause}")
306{bind_section}
307        .fetch_optional(pool)
308        .await
309}}
310
311"#,
312        method_name = method_name,
313        params = params,
314        struct_name = struct_name,
315        select_columns = select_columns,
316        table_name = table.name,
317        where_clause = where_clause,
318        bind_section = bind_section,
319    ));
320
321    // delete_by_pk
322    let delete_method = generate_delete_by_method_name(&pk.columns);
323    code.push_str(&format!(
324        r#"/// Delete by primary key
325pub async fn {delete_method}<P: Pool>(pool: &P, {params}) -> Result<u64> {{
326    Query::new("DELETE FROM `{table_name}` WHERE {where_clause}")
327{bind_section}
328        .execute(pool)
329        .await
330        .map(|r| r.rows_affected)
331}}
332
333"#,
334        delete_method = delete_method,
335        params = params,
336        table_name = table.name,
337        where_clause = where_clause,
338        bind_section = bind_section,
339    ));
340
341    code
342}
343
344/// Generate insert methods
345fn generate_insert_methods(
346    table: &TableMetadata,
347    struct_name: &str,
348    _column_map: &HashMap<&str, &ColumnMetadata>,
349) -> String {
350    let mut code = String::new();
351
352    // Get non-auto-increment columns for insert
353    let insert_columns: Vec<&ColumnMetadata> = table
354        .columns
355        .iter()
356        .filter(|c| !c.is_auto_increment)
357        .collect();
358
359    if insert_columns.is_empty() {
360        return code;
361    }
362
363    let column_list = insert_columns
364        .iter()
365        .map(|c| format!("`{}`", c.name))
366        .collect::<Vec<_>>()
367        .join(", ");
368
369    let placeholders = insert_columns
370        .iter()
371        .map(|_| "?")
372        .collect::<Vec<_>>()
373        .join(", ");
374
375    let bind_fields = insert_columns
376        .iter()
377        .map(|c| {
378            let field = escape_field_name(&c.name);
379            let rust_type = TypeResolver::resolve(c, &table.name);
380            if rust_type.is_copy() {
381                format!("        .bind(entity.{})", field)
382            } else {
383                format!("        .bind(&entity.{})", field)
384            }
385        })
386        .collect::<Vec<_>>()
387        .join("\n");
388
389    // insert (with entity)
390    code.push_str(&format!(
391        r#"/// Insert a new record
392pub async fn insert<P: Pool>(pool: &P, entity: &{struct_name}) -> Result<u64> {{
393    Query::new("INSERT INTO `{table_name}` ({column_list}) VALUES ({placeholders})")
394{bind_fields}
395        .execute(pool)
396        .await
397        .map(|r| r.last_insert_id.unwrap_or(0))
398}}
399
400"#,
401        struct_name = struct_name,
402        table_name = table.name,
403        column_list = column_list,
404        placeholders = placeholders,
405        bind_fields = bind_fields,
406    ));
407
408    code
409}
410
411/// Generate insert_plain method with individual parameters
412fn generate_insert_plain_method(
413    table: &TableMetadata,
414    column_map: &HashMap<&str, &ColumnMetadata>,
415) -> String {
416    // Get non-auto-increment columns for insert
417    let insert_columns: Vec<&ColumnMetadata> = table
418        .columns
419        .iter()
420        .filter(|c| !c.is_auto_increment)
421        .collect();
422
423    if insert_columns.is_empty() {
424        return String::new();
425    }
426
427    let column_names: Vec<String> = insert_columns.iter().map(|c| c.name.clone()).collect();
428    let params = build_params(&column_names, column_map, &table.name);
429
430    let column_list = insert_columns
431        .iter()
432        .map(|c| format!("`{}`", c.name))
433        .collect::<Vec<_>>()
434        .join(", ");
435
436    let placeholders = insert_columns
437        .iter()
438        .map(|_| "?")
439        .collect::<Vec<_>>()
440        .join(", ");
441
442    let bind_section = generate_bind_section(&column_names);
443
444    format!(
445        r#"/// Insert a new record with individual parameters
446#[allow(clippy::too_many_arguments)]
447pub async fn insert_plain<P: Pool>(pool: &P, {params}) -> Result<u64> {{
448    Query::new("INSERT INTO `{table_name}` ({column_list}) VALUES ({placeholders})")
449{bind_section}
450        .execute(pool)
451        .await
452        .map(|r| r.last_insert_id.unwrap_or(0))
453}}
454
455"#,
456        params = params,
457        table_name = table.name,
458        column_list = column_list,
459        placeholders = placeholders,
460        bind_section = bind_section,
461    )
462}
463
464/// Generate batch insert method (insert_all) using BatchInsert
465fn generate_insert_all_method(table: &TableMetadata, struct_name: &str) -> String {
466    // Get non-auto-increment columns for insert
467    let insert_columns: Vec<&ColumnMetadata> = table
468        .columns
469        .iter()
470        .filter(|c| !c.is_auto_increment)
471        .collect();
472
473    if insert_columns.is_empty() {
474        return String::new();
475    }
476
477    format!(
478        r#"/// Insert multiple records in a single batch
479pub async fn insert_all<P: Pool>(pool: &P, entities: &[{struct_name}]) -> Result<u64> {{
480    rdbi::BatchInsert::new("{table_name}", entities)
481        .execute(pool)
482        .await
483        .map(|r| r.rows_affected)
484}}
485
486"#,
487        struct_name = struct_name,
488        table_name = table.name,
489    )
490}
491
492/// Check if table has a unique index (excluding primary key)
493fn has_unique_index(table: &TableMetadata) -> bool {
494    table.indexes.iter().any(|idx| idx.unique)
495}
496
497/// Generate upsert method (INSERT ... ON DUPLICATE KEY UPDATE)
498fn generate_upsert_method(table: &TableMetadata, struct_name: &str) -> String {
499    // Only generate if table has primary key or unique index
500    if table.primary_key.is_none() && !has_unique_index(table) {
501        return String::new();
502    }
503
504    // Get non-auto-increment columns for insert
505    let insert_columns: Vec<&ColumnMetadata> = table
506        .columns
507        .iter()
508        .filter(|c| !c.is_auto_increment)
509        .collect();
510
511    if insert_columns.is_empty() {
512        return String::new();
513    }
514
515    // Get primary key columns for exclusion from UPDATE clause
516    let pk_columns: HashSet<&str> = table
517        .primary_key
518        .as_ref()
519        .map(|pk| pk.columns.iter().map(|s| s.as_str()).collect())
520        .unwrap_or_default();
521
522    // Columns to update on duplicate key (all non-PK, non-auto-increment columns)
523    let update_columns: Vec<&ColumnMetadata> = insert_columns
524        .iter()
525        .filter(|c| !pk_columns.contains(c.name.as_str()))
526        .copied()
527        .collect();
528
529    // If no columns to update, skip upsert generation
530    if update_columns.is_empty() {
531        return String::new();
532    }
533
534    let column_list = insert_columns
535        .iter()
536        .map(|c| format!("`{}`", c.name))
537        .collect::<Vec<_>>()
538        .join(", ");
539
540    let placeholders = insert_columns
541        .iter()
542        .map(|_| "?")
543        .collect::<Vec<_>>()
544        .join(", ");
545
546    let update_clause = update_columns
547        .iter()
548        .map(|c| format!("`{name}` = VALUES(`{name}`)", name = c.name))
549        .collect::<Vec<_>>()
550        .join(", ");
551
552    let bind_fields = insert_columns
553        .iter()
554        .map(|c| {
555            let field = escape_field_name(&c.name);
556            let rust_type = TypeResolver::resolve(c, &table.name);
557            if rust_type.is_copy() {
558                format!("        .bind(entity.{})", field)
559            } else {
560                format!("        .bind(&entity.{})", field)
561            }
562        })
563        .collect::<Vec<_>>()
564        .join("\n");
565
566    format!(
567        r#"/// Upsert a record (insert or update on duplicate key)
568/// Returns rows_affected: 1 if inserted, 2 if updated
569pub async fn upsert<P: Pool>(pool: &P, entity: &{struct_name}) -> Result<u64> {{
570    Query::new("INSERT INTO `{table_name}` ({column_list}) VALUES ({placeholders}) \
571         ON DUPLICATE KEY UPDATE {update_clause}")
572{bind_fields}
573        .execute(pool)
574        .await
575        .map(|r| r.rows_affected)
576}}
577
578"#,
579        struct_name = struct_name,
580        table_name = table.name,
581        column_list = column_list,
582        placeholders = placeholders,
583        update_clause = update_clause,
584        bind_fields = bind_fields,
585    )
586}
587
588/// Generate update methods
589fn generate_update_methods(
590    table: &TableMetadata,
591    struct_name: &str,
592    column_map: &HashMap<&str, &ColumnMetadata>,
593) -> String {
594    let mut code = String::new();
595
596    let pk = table.primary_key.as_ref().unwrap();
597
598    // Get non-PK columns for SET clause
599    let update_columns: Vec<&ColumnMetadata> = table
600        .columns
601        .iter()
602        .filter(|c| !pk.columns.contains(&c.name))
603        .collect();
604
605    if update_columns.is_empty() {
606        return code;
607    }
608
609    let set_clause = update_columns
610        .iter()
611        .map(|c| format!("`{}` = ?", c.name))
612        .collect::<Vec<_>>()
613        .join(", ");
614
615    let where_clause = build_where_clause(&pk.columns);
616
617    // Bind update columns first, then PK columns
618    let bind_fields: Vec<String> = update_columns
619        .iter()
620        .map(|c| {
621            let field = escape_field_name(&c.name);
622            let rust_type = TypeResolver::resolve(c, &table.name);
623            if rust_type.is_copy() {
624                format!("        .bind(entity.{})", field)
625            } else {
626                format!("        .bind(&entity.{})", field)
627            }
628        })
629        .chain(pk.columns.iter().map(|c| {
630            let field = escape_field_name(c);
631            let col = column_map.get(c.as_str()).unwrap();
632            let rust_type = TypeResolver::resolve(col, &table.name);
633            if rust_type.is_copy() {
634                format!("        .bind(entity.{})", field)
635            } else {
636                format!("        .bind(&entity.{})", field)
637            }
638        }))
639        .collect();
640
641    // update_by_bean (using entity)
642    code.push_str(&format!(
643        r#"/// Update a record by primary key
644pub async fn update<P: Pool>(pool: &P, entity: &{struct_name}) -> Result<u64> {{
645    Query::new("UPDATE `{table_name}` SET {set_clause} WHERE {where_clause}")
646{bind_fields}
647        .execute(pool)
648        .await
649        .map(|r| r.rows_affected)
650}}
651
652"#,
653        struct_name = struct_name,
654        table_name = table.name,
655        set_clause = set_clause,
656        where_clause = where_clause,
657        bind_fields = bind_fields.join("\n"),
658    ));
659
660    code
661}
662
663/// Generate update_plain method with individual parameters (update_by_<pk>)
664fn generate_update_plain_method(
665    table: &TableMetadata,
666    column_map: &HashMap<&str, &ColumnMetadata>,
667) -> String {
668    let pk = table.primary_key.as_ref().unwrap();
669
670    // Get non-PK columns for SET clause
671    let update_columns: Vec<&ColumnMetadata> = table
672        .columns
673        .iter()
674        .filter(|c| !pk.columns.contains(&c.name))
675        .collect();
676
677    if update_columns.is_empty() {
678        return String::new();
679    }
680
681    // Build method name based on PK columns
682    let method_name = generate_update_by_method_name(&pk.columns);
683
684    // Build params: PK columns first, then update columns
685    let pk_params = build_params(&pk.columns, column_map, &table.name);
686    let update_column_names: Vec<String> = update_columns.iter().map(|c| c.name.clone()).collect();
687    let update_params = build_params(&update_column_names, column_map, &table.name);
688    let all_params = format!("{}, {}", pk_params, update_params);
689
690    let set_clause = update_columns
691        .iter()
692        .map(|c| format!("`{}` = ?", c.name))
693        .collect::<Vec<_>>()
694        .join(", ");
695
696    let where_clause = build_where_clause(&pk.columns);
697
698    // Bind update columns first (for SET), then PK columns (for WHERE)
699    let bind_section_update = generate_bind_section(&update_column_names);
700    let bind_section_pk = generate_bind_section(&pk.columns);
701
702    format!(
703        r#"/// Update a record by primary key with individual parameters
704#[allow(clippy::too_many_arguments)]
705pub async fn {method_name}<P: Pool>(pool: &P, {all_params}) -> Result<u64> {{
706    Query::new("UPDATE `{table_name}` SET {set_clause} WHERE {where_clause}")
707{bind_section_update}
708{bind_section_pk}
709        .execute(pool)
710        .await
711        .map(|r| r.rows_affected)
712}}
713
714"#,
715        method_name = method_name,
716        all_params = all_params,
717        table_name = table.name,
718        set_clause = set_clause,
719        where_clause = where_clause,
720        bind_section_update = bind_section_update,
721        bind_section_pk = bind_section_pk,
722    )
723}
724
725/// Collect method signatures with priority-based deduplication
726fn collect_method_signatures(table: &TableMetadata) -> HashMap<Vec<String>, MethodSignature> {
727    let mut signatures: HashMap<Vec<String>, MethodSignature> = HashMap::new();
728
729    // Priority 1: Primary key
730    if let Some(pk) = &table.primary_key {
731        let sig = MethodSignature::new(
732            pk.columns.clone(),
733            PRIORITY_PRIMARY_KEY,
734            true,
735            "PRIMARY_KEY",
736        );
737        signatures.insert(pk.columns.clone(), sig);
738    }
739
740    // Priority 2 & 3: Indexes
741    for index in &table.indexes {
742        let priority = if index.unique {
743            PRIORITY_UNIQUE_INDEX
744        } else {
745            PRIORITY_NON_UNIQUE_INDEX
746        };
747        let source = if index.unique {
748            "UNIQUE_INDEX"
749        } else {
750            "NON_UNIQUE_INDEX"
751        };
752        let sig = MethodSignature::new(index.columns.clone(), priority, index.unique, source);
753
754        // Only add if no higher priority signature exists
755        if let Some(existing) = signatures.get(&index.columns) {
756            if sig.priority < existing.priority {
757                signatures.insert(index.columns.clone(), sig);
758            }
759        } else {
760            signatures.insert(index.columns.clone(), sig);
761        }
762    }
763
764    // Priority 4: Foreign keys
765    for fk in &table.foreign_keys {
766        let columns = vec![fk.column_name.clone()];
767        let sig = MethodSignature::new(columns.clone(), PRIORITY_FOREIGN_KEY, false, "FOREIGN_KEY");
768
769        signatures.entry(columns).or_insert(sig);
770    }
771
772    signatures
773}
774
775/// Generate a find_by method for an index/FK
776fn generate_find_by_method(
777    table: &TableMetadata,
778    sig: &MethodSignature,
779    column_map: &HashMap<&str, &ColumnMetadata>,
780    struct_name: &str,
781    select_columns: &str,
782) -> String {
783    let params = build_params(&sig.columns, column_map, &table.name);
784
785    let (return_type, fetch_method) = if sig.is_unique {
786        (format!("Option<{}>", struct_name), "fetch_optional")
787    } else {
788        (format!("Vec<{}>", struct_name), "fetch_all")
789    };
790
791    let return_desc = if sig.is_unique {
792        "Option (unique)"
793    } else {
794        "Vec (non-unique)"
795    };
796
797    // Check if any column is nullable - if so, we need dynamic SQL for IS NULL handling
798    let has_nullable = sig.columns.iter().any(|c| {
799        column_map
800            .get(c.as_str())
801            .map(|col| col.nullable)
802            .unwrap_or(false)
803    });
804
805    if has_nullable {
806        // Generate dynamic query that handles NULL values properly
807        generate_find_by_method_nullable(
808            table,
809            sig,
810            column_map,
811            select_columns,
812            &params,
813            &return_type,
814            fetch_method,
815            return_desc,
816        )
817    } else {
818        // Use static query for non-nullable columns
819        let where_clause = build_where_clause(&sig.columns);
820        let bind_section = generate_bind_section(&sig.columns);
821
822        format!(
823            r#"/// Find by {source}: returns {return_desc}
824pub async fn {method_name}<P: Pool>(pool: &P, {params}) -> Result<{return_type}> {{
825    Query::new("SELECT {select_columns} FROM `{table_name}` WHERE {where_clause}")
826{bind_section}
827        .{fetch_method}(pool)
828        .await
829}}
830
831"#,
832            source = sig.source.to_lowercase().replace('_', " "),
833            return_desc = return_desc,
834            method_name = sig.method_name,
835            params = params,
836            return_type = return_type,
837            select_columns = select_columns,
838            table_name = table.name,
839            where_clause = where_clause,
840            bind_section = bind_section,
841            fetch_method = fetch_method,
842        )
843    }
844}
845
846/// Generate a find_by method that handles nullable columns with IS NULL
847#[allow(clippy::too_many_arguments)]
848fn generate_find_by_method_nullable(
849    table: &TableMetadata,
850    sig: &MethodSignature,
851    column_map: &HashMap<&str, &ColumnMetadata>,
852    select_columns: &str,
853    params: &str,
854    return_type: &str,
855    fetch_method: &str,
856    return_desc: &str,
857) -> String {
858    // Build the where clause conditions and bind logic
859    let mut where_parts = Vec::new();
860    let mut bind_parts = Vec::new();
861
862    for col in &sig.columns {
863        let col_meta = column_map.get(col.as_str()).unwrap();
864        let field_name = escape_field_name(col);
865
866        if col_meta.nullable {
867            // For nullable columns, check if value is None and use IS NULL
868            where_parts.push(format!(
869                r#"if {field}.is_some() {{ "`{col}` = ?" }} else {{ "`{col}` IS NULL" }}"#,
870                field = field_name,
871                col = col,
872            ));
873            bind_parts.push(format!(
874                r#"if let Some(v) = {field}.as_ref() {{ query = query.bind(v); }}"#,
875                field = field_name,
876            ));
877        } else {
878            where_parts.push(format!(r#""`{}` = ?""#, col));
879            bind_parts.push(format!(r#"query = query.bind({});"#, field_name,));
880        }
881    }
882
883    let where_expr = if where_parts.len() == 1 {
884        where_parts[0].clone()
885    } else {
886        // Join with " AND "
887        let parts = where_parts
888            .iter()
889            .map(|p| format!("({})", p))
890            .collect::<Vec<_>>()
891            .join(", ");
892        format!("vec![{}].join(\" AND \")", parts)
893    };
894
895    let bind_code = bind_parts.join("\n        ");
896
897    format!(
898        r#"/// Find by {source}: returns {return_desc}
899pub async fn {method_name}<P: Pool>(pool: &P, {params}) -> Result<{return_type}> {{
900    let where_clause = {where_expr};
901    let sql = format!("SELECT {select_columns} FROM `{table_name}` WHERE {{}}", where_clause);
902    let mut query = rdbi::DynamicQuery::new(sql);
903    {bind_code}
904    query.{fetch_method}(pool).await
905}}
906
907"#,
908        source = sig.source.to_lowercase().replace('_', " "),
909        return_desc = return_desc,
910        method_name = sig.method_name,
911        params = params,
912        return_type = return_type,
913        select_columns = select_columns,
914        table_name = table.name,
915        where_expr = where_expr,
916        bind_code = bind_code,
917        fetch_method = fetch_method,
918    )
919}
920
921/// Generate list-based findBy methods for single-column indexes
922fn generate_find_by_list_methods(
923    table: &TableMetadata,
924    column_map: &HashMap<&str, &ColumnMetadata>,
925    struct_name: &str,
926    select_columns: &str,
927) -> String {
928    let mut code = String::new();
929    let mut processed: HashSet<String> = HashSet::new();
930
931    // Primary key (if single column)
932    if let Some(pk) = &table.primary_key {
933        if pk.columns.len() == 1 {
934            let col = &pk.columns[0];
935            code.push_str(&generate_single_find_by_list(
936                table,
937                col,
938                column_map,
939                struct_name,
940                select_columns,
941            ));
942            processed.insert(col.clone());
943        }
944    }
945
946    // Single-column indexes
947    for index in &table.indexes {
948        if index.columns.len() == 1 {
949            let col = &index.columns[0];
950            if !processed.contains(col) {
951                code.push_str(&generate_single_find_by_list(
952                    table,
953                    col,
954                    column_map,
955                    struct_name,
956                    select_columns,
957                ));
958                processed.insert(col.clone());
959            }
960        }
961    }
962
963    code
964}
965
966/// Generate a single find_by_<column>s method (using IN clause)
967fn generate_single_find_by_list(
968    table: &TableMetadata,
969    column_name: &str,
970    column_map: &HashMap<&str, &ColumnMetadata>,
971    struct_name: &str,
972    select_columns: &str,
973) -> String {
974    let method_name = generate_find_by_list_method_name(column_name);
975    let param_name = pluralize(&escape_field_name(column_name));
976    let column = column_map.get(column_name).unwrap();
977    let rust_type = TypeResolver::resolve(column, &table.name);
978
979    // Get the inner type (unwrap Option if nullable)
980    let inner_type = rust_type.inner_type().to_type_string();
981
982    let column_name_plural = pluralize(column_name);
983    format!(
984        r#"/// Find by list of {column_name_plural} (IN clause)
985pub async fn {method_name}<P: Pool>(pool: &P, {param_name}: &[{inner_type}]) -> Result<Vec<{struct_name}>> {{
986    if {param_name}.is_empty() {{
987        return Ok(Vec::new());
988    }}
989    let placeholders = {param_name}.iter().map(|_| "?").collect::<Vec<_>>().join(",");
990    let query = format!(
991        "SELECT {select_columns} FROM `{table_name}` WHERE `{column_name}` IN ({{}})",
992        placeholders
993    );
994    rdbi::DynamicQuery::new(query)
995        .bind_all({param_name})
996        .fetch_all(pool)
997        .await
998}}
999
1000"#,
1001        column_name_plural = column_name_plural,
1002        column_name = column_name,
1003        method_name = method_name,
1004        param_name = param_name,
1005        inner_type = inner_type,
1006        struct_name = struct_name,
1007        select_columns = select_columns,
1008        table_name = table.name,
1009    )
1010}
1011
1012/// Generate composite enum list methods for multi-column indexes with enum columns
1013/// Example: find_by_user_id_and_device_types(user_id, &[DeviceType])
1014fn generate_composite_enum_list_methods(
1015    table: &TableMetadata,
1016    column_map: &HashMap<&str, &ColumnMetadata>,
1017    struct_name: &str,
1018    select_columns: &str,
1019) -> String {
1020    let mut code = String::new();
1021
1022    for index in &table.indexes {
1023        // Skip single-column indexes (handled by generate_find_by_list_methods)
1024        if index.columns.len() <= 1 {
1025            continue;
1026        }
1027
1028        // Identify which columns are enums
1029        let enum_columns: HashSet<&str> = index
1030            .columns
1031            .iter()
1032            .filter(|col_name| {
1033                column_map
1034                    .get(col_name.as_str())
1035                    .map(|col| col.is_enum())
1036                    .unwrap_or(false)
1037            })
1038            .map(|s| s.as_str())
1039            .collect();
1040
1041        // Skip if no enum columns
1042        if enum_columns.is_empty() {
1043            continue;
1044        }
1045
1046        // Skip if first column is enum (for optimal index usage, equality should be on leading column)
1047        let first_column = &index.columns[0];
1048        if enum_columns.contains(first_column.as_str()) {
1049            continue;
1050        }
1051
1052        code.push_str(&generate_composite_enum_list_method(
1053            table,
1054            &index.columns,
1055            &enum_columns,
1056            column_map,
1057            struct_name,
1058            select_columns,
1059        ));
1060    }
1061
1062    code
1063}
1064
1065/// Generate a single composite enum list method
1066fn generate_composite_enum_list_method(
1067    table: &TableMetadata,
1068    columns: &[String],
1069    enum_columns: &HashSet<&str>,
1070    column_map: &HashMap<&str, &ColumnMetadata>,
1071    struct_name: &str,
1072    select_columns: &str,
1073) -> String {
1074    // Build method name: pluralize enum column names
1075    let method_name = generate_composite_enum_method_name(columns, enum_columns);
1076
1077    // Build params and WHERE clause parts
1078    let mut params_parts = Vec::new();
1079
1080    for col_name in columns {
1081        let col = column_map.get(col_name.as_str()).unwrap();
1082        let rust_type = TypeResolver::resolve(col, &table.name);
1083        let is_enum = enum_columns.contains(col_name.as_str());
1084
1085        if is_enum {
1086            // Enum column uses list parameter
1087            let param_name = pluralize(&escape_field_name(col_name));
1088            let inner_type = rust_type.inner_type().to_type_string();
1089            params_parts.push(format!("{}: &[{}]", param_name, inner_type));
1090        } else {
1091            // Non-enum column uses single value
1092            let param_name = escape_field_name(col_name);
1093            let param_type = rust_type.to_param_type_string();
1094            params_parts.push(format!("{}: {}", param_name, param_type));
1095        }
1096    }
1097
1098    let params = params_parts.join(", ");
1099
1100    // Build WHERE clause with proper placeholders
1101    let where_clause_static: Vec<String> = columns
1102        .iter()
1103        .map(|col_name| {
1104            if enum_columns.contains(col_name.as_str()) {
1105                format!("`{}` IN ({{}})", col_name) // placeholder for IN clause
1106            } else {
1107                format!("`{}` = ?", col_name)
1108            }
1109        })
1110        .collect();
1111
1112    // Build the bind section
1113    let mut bind_code = String::new();
1114
1115    // First bind non-enum (single value) columns
1116    for col_name in columns {
1117        if !enum_columns.contains(col_name.as_str()) {
1118            let param_name = escape_field_name(col_name);
1119            bind_code.push_str(&format!("        .bind({})\n", param_name));
1120        }
1121    }
1122
1123    // Then bind enum (list) columns
1124    for col_name in columns {
1125        if enum_columns.contains(col_name.as_str()) {
1126            let param_name = pluralize(&escape_field_name(col_name));
1127            bind_code.push_str(&format!("        .bind_all({})\n", param_name));
1128        }
1129    }
1130
1131    // Build the column name description for doc comment
1132    let column_desc: Vec<String> = columns
1133        .iter()
1134        .map(|col| {
1135            if enum_columns.contains(col.as_str()) {
1136                pluralize(col)
1137            } else {
1138                col.clone()
1139            }
1140        })
1141        .collect();
1142
1143    // Build dynamic WHERE clause construction
1144    let enum_col_names: Vec<&str> = columns
1145        .iter()
1146        .filter(|c| enum_columns.contains(c.as_str()))
1147        .map(|s| s.as_str())
1148        .collect();
1149
1150    // Generate the IN clause placeholders dynamically
1151    let in_clause_builders: Vec<String> = enum_col_names
1152        .iter()
1153        .map(|col| {
1154            let param_name = pluralize(&escape_field_name(col));
1155            format!(
1156                "{param_name}.iter().map(|_| \"?\").collect::<Vec<_>>().join(\",\")",
1157                param_name = param_name
1158            )
1159        })
1160        .collect();
1161
1162    // Build format args for the WHERE clause
1163    let format_args = in_clause_builders.join(", ");
1164
1165    format!(
1166        r#"/// Find by {column_desc} (composite index with IN clause for enum columns)
1167pub async fn {method_name}<P: Pool>(pool: &P, {params}) -> Result<Vec<{struct_name}>> {{
1168    // Check for empty enum lists
1169{empty_checks}
1170    // Build IN clause placeholders for enum columns
1171    let where_clause = format!("{where_template}", {format_args});
1172    let query = format!(
1173        "SELECT {select_columns} FROM `{table_name}` WHERE {{}}",
1174        where_clause
1175    );
1176    rdbi::DynamicQuery::new(query)
1177{bind_code}        .fetch_all(pool)
1178        .await
1179}}
1180
1181"#,
1182        column_desc = column_desc.join(" and "),
1183        method_name = method_name,
1184        params = params,
1185        struct_name = struct_name,
1186        select_columns = select_columns,
1187        table_name = table.name,
1188        where_template = where_clause_static.join(" AND "),
1189        format_args = format_args,
1190        bind_code = bind_code,
1191        empty_checks = generate_empty_checks(columns, enum_columns),
1192    )
1193}
1194
1195/// Generate empty checks for enum list parameters
1196fn generate_empty_checks(columns: &[String], enum_columns: &HashSet<&str>) -> String {
1197    let mut checks = String::new();
1198    for col_name in columns {
1199        if enum_columns.contains(col_name.as_str()) {
1200            let param_name = pluralize(&escape_field_name(col_name));
1201            checks.push_str(&format!(
1202                "    if {}.is_empty() {{ return Ok(Vec::new()); }}\n",
1203                param_name
1204            ));
1205        }
1206    }
1207    checks
1208}
1209
1210/// Generate method name for composite enum queries
1211/// Example: ["user_id", "device_type"] with enum_columns={"device_type"} -> "find_by_user_id_and_device_types"
1212fn generate_composite_enum_method_name(columns: &[String], enum_columns: &HashSet<&str>) -> String {
1213    let mut parts = Vec::new();
1214    for col in columns {
1215        if enum_columns.contains(col.as_str()) {
1216            parts.push(pluralize(col));
1217        } else {
1218            parts.push(col.clone());
1219        }
1220    }
1221    generate_find_by_method_name(&parts)
1222}
1223
1224/// Generate pagination methods (find_all_paginated, get_paginated_result)
1225fn generate_pagination_methods(
1226    table: &TableMetadata,
1227    struct_name: &str,
1228    select_columns: &str,
1229    models_module: &str,
1230) -> String {
1231    let sort_by_enum = format!("{}SortBy", struct_name);
1232
1233    format!(
1234        r#"/// Find all records with pagination and sorting
1235pub async fn find_all_paginated<P: Pool>(
1236    pool: &P,
1237    limit: i32,
1238    offset: i32,
1239    sort_by: crate::{models_module}::{sort_by_enum},
1240    sort_dir: crate::{models_module}::SortDirection,
1241) -> Result<Vec<{struct_name}>> {{
1242    let order_clause = format!("{{}} {{}}", sort_by.as_sql(), sort_dir.as_sql());
1243    let query = format!(
1244        "SELECT {select_columns} FROM `{table_name}` ORDER BY {{}} LIMIT ? OFFSET ?",
1245        order_clause
1246    );
1247    rdbi::DynamicQuery::new(query)
1248        .bind(limit)
1249        .bind(offset)
1250        .fetch_all(pool)
1251        .await
1252}}
1253
1254/// Get paginated result with total count
1255pub async fn get_paginated_result<P: Pool>(
1256    pool: &P,
1257    page_size: i32,
1258    current_page: i32,
1259    sort_by: crate::{models_module}::{sort_by_enum},
1260    sort_dir: crate::{models_module}::SortDirection,
1261) -> Result<crate::{models_module}::PaginatedResult<{struct_name}>> {{
1262    let page_size = page_size.max(1);
1263    let current_page = current_page.max(1);
1264    let offset = (current_page - 1) * page_size;
1265
1266    let total_count = count_all(pool).await?;
1267    let items = find_all_paginated(pool, page_size, offset, sort_by, sort_dir).await?;
1268
1269    Ok(crate::{models_module}::PaginatedResult::new(
1270        items,
1271        total_count,
1272        current_page,
1273        page_size,
1274    ))
1275}}
1276
1277"#,
1278        struct_name = struct_name,
1279        select_columns = select_columns,
1280        table_name = table.name,
1281        models_module = models_module,
1282        sort_by_enum = sort_by_enum,
1283    )
1284}
1285
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::{IndexMetadata, PrimaryKey};

    /// Fixture column: non-nullable, signed, no default value and no comment.
    fn column(
        name: &str,
        data_type: &str,
        is_auto_increment: bool,
        enum_values: Option<Vec<String>>,
    ) -> ColumnMetadata {
        ColumnMetadata {
            name: name.to_string(),
            data_type: data_type.to_string(),
            nullable: false,
            default_value: None,
            is_auto_increment,
            is_unsigned: false,
            enum_values,
            comment: None,
        }
    }

    /// Fixture single-column index.
    fn index(name: &str, column_name: &str, unique: bool) -> IndexMetadata {
        IndexMetadata {
            name: name.to_string(),
            columns: vec![column_name.to_string()],
            unique,
        }
    }

    /// Converts string literals into the owned `Vec<String>` used by metadata types.
    fn strings(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    /// `users` table: `id` BIGINT auto-increment PK, `email` VARCHAR(255) with a unique
    /// index, `status` ENUM('ACTIVE','INACTIVE') with a non-unique index.
    fn make_table() -> TableMetadata {
        TableMetadata {
            name: "users".to_string(),
            comment: None,
            columns: vec![
                column("id", "BIGINT", true, None),
                column("email", "VARCHAR(255)", false, None),
                column("status", "ENUM", false, Some(strings(&["ACTIVE", "INACTIVE"]))),
            ],
            indexes: vec![
                index("email_unique", "email", true),
                index("idx_status", "status", false),
            ],
            foreign_keys: vec![],
            primary_key: Some(PrimaryKey {
                columns: strings(&["id"]),
            }),
        }
    }

    #[test]
    fn test_collect_method_signatures() {
        let table = make_table();
        let sigs = collect_method_signatures(&table);

        // One signature each for id (PK), email (unique), status (non-unique).
        assert_eq!(sigs.len(), 3);

        let expected = [
            ("id", true, PRIORITY_PRIMARY_KEY),
            ("email", true, PRIORITY_UNIQUE_INDEX),
            ("status", false, PRIORITY_NON_UNIQUE_INDEX),
        ];
        for &(column_name, is_unique, priority) in expected.iter() {
            let sig = sigs
                .get(&strings(&[column_name]))
                .unwrap_or_else(|| panic!("no signature for `{}`", column_name));
            assert_eq!(sig.is_unique, is_unique, "uniqueness of `{}`", column_name);
            assert_eq!(sig.priority, priority, "priority of `{}`", column_name);
        }
    }

    #[test]
    fn test_build_select_columns() {
        let table = make_table();
        let cols = build_select_columns(&table);
        for &quoted in ["`id`", "`email`", "`status`"].iter() {
            assert!(cols.contains(quoted), "select list is missing {}", quoted);
        }
    }

    #[test]
    fn test_build_where_clause() {
        assert_eq!(build_where_clause(&strings(&["id"])), "`id` = ?");
        assert_eq!(
            build_where_clause(&strings(&["user_id", "role_id"])),
            "`user_id` = ? AND `role_id` = ?"
        );
    }

    #[test]
    fn test_generate_upsert_method() {
        let table = make_table();
        let code = generate_upsert_method(&table, "Users");

        assert!(code.contains("pub async fn upsert"));
        assert!(code.contains("ON DUPLICATE KEY UPDATE"));
        // The PK column must not appear in the update list...
        assert!(!code.contains("`id` = VALUES(`id`)"));
        // ...while every non-PK column must.
        for &clause in ["`email` = VALUES(`email`)", "`status` = VALUES(`status`)"].iter() {
            assert!(code.contains(clause), "missing update clause {}", clause);
        }
    }

    #[test]
    fn test_generate_upsert_method_no_pk() {
        let table = TableMetadata {
            primary_key: None,
            indexes: Vec::new(),
            ..make_table()
        };

        // Without a PK or unique index there is nothing to upsert on.
        let code = generate_upsert_method(&table, "Users");
        assert!(code.is_empty());
    }

    #[test]
    fn test_generate_insert_all_method() {
        let table = make_table();
        let code = generate_insert_all_method(&table, "Users");

        assert!(code.contains("pub async fn insert_all"));
        assert!(code.contains("rdbi::BatchInsert::new"));
    }

    #[test]
    fn test_generate_pagination_methods() {
        let table = make_table();
        let select_columns = build_select_columns(&table);
        let code = generate_pagination_methods(&table, "Users", &select_columns, "models");

        let expected_fragments = [
            "pub async fn find_all_paginated",
            "limit: i32",
            "offset: i32",
            "UsersSortBy",
            "SortDirection",
            "pub async fn get_paginated_result",
            "PaginatedResult<Users>",
        ];
        for &fragment in expected_fragments.iter() {
            assert!(code.contains(fragment), "generated code is missing {}", fragment);
        }
    }
}