Skip to main content

rdbi_codegen/codegen/
dao_generator.rs

1//! DAO generator - generates async rdbi query functions from table metadata
2
3use crate::error::Result;
4use std::collections::{HashMap, HashSet};
5use std::fs;
6use std::path::Path;
7use tracing::debug;
8
9use crate::config::CodegenConfig;
10use crate::parser::{ColumnMetadata, TableMetadata};
11
12use super::naming::{
13    escape_field_name, generate_delete_by_method_name, generate_find_by_list_method_name,
14    generate_find_by_method_name, generate_update_by_method_name, pluralize, to_struct_name,
15};
16use super::type_resolver::TypeResolver;
17
// Priority levels for method signature deduplication: when several sources
// (primary key, indexes, foreign keys) yield a finder over the same column
// list, the source with the numerically lowest priority is kept.

/// Priority of a finder derived from the primary key (highest precedence).
const PRIORITY_PRIMARY_KEY: u8 = 1;
/// Priority of a finder derived from a UNIQUE index.
const PRIORITY_UNIQUE_INDEX: u8 = 2;
/// Priority of a finder derived from a non-unique index.
const PRIORITY_NON_UNIQUE_INDEX: u8 = 3;
/// Priority of a finder derived from a foreign key (lowest precedence).
const PRIORITY_FOREIGN_KEY: u8 = 4;
23
24/// Represents a method signature for deduplication
25#[derive(Debug, Clone)]
26struct MethodSignature {
27    columns: Vec<String>,
28    method_name: String,
29    priority: u8,
30    is_unique: bool,
31    source: String,
32}
33
34impl MethodSignature {
35    fn new(columns: Vec<String>, priority: u8, is_unique: bool, source: &str) -> Self {
36        let method_name = generate_find_by_method_name(&columns);
37        Self {
38            columns,
39            method_name,
40            priority,
41            is_unique,
42            source: source.to_string(),
43        }
44    }
45}
46
47/// Generate DAO files for all tables
48pub fn generate_daos(tables: &[TableMetadata], config: &CodegenConfig) -> Result<()> {
49    let output_dir = &config.output_dao_dir;
50    fs::create_dir_all(output_dir)?;
51
52    // Generate mod.rs
53    let mut mod_content = String::new();
54    mod_content.push_str("// Generated DAO functions\n\n");
55
56    for table in tables {
57        let file_name = heck::AsSnakeCase(&table.name).to_string();
58        mod_content.push_str(&format!(
59            "#[allow(dead_code, clippy::all)]\npub mod {};\n",
60            file_name
61        ));
62    }
63
64    let mod_path = output_dir.join("mod.rs");
65    fs::write(&mod_path, mod_content)?;
66
67    // Generate each DAO file
68    for table in tables {
69        generate_dao_file(table, output_dir, &config.models_module)?;
70    }
71
72    Ok(())
73}
74
75/// Generate a single DAO file for a table
76fn generate_dao_file(table: &TableMetadata, output_dir: &Path, models_module: &str) -> Result<()> {
77    let struct_name = to_struct_name(&table.name);
78    let file_name = format!("{}.rs", heck::AsSnakeCase(&table.name));
79    debug!("Generating DAO for {} -> {}", struct_name, file_name);
80
81    let mut code = String::new();
82
83    // Build column map for quick lookup
84    let column_map: HashMap<&str, &ColumnMetadata> =
85        table.columns.iter().map(|c| (c.name.as_str(), c)).collect();
86
87    // Collect import requirements
88    let mut needs_chrono = false;
89    let mut needs_decimal = false;
90
91    for col in &table.columns {
92        let rust_type = TypeResolver::resolve(col, &table.name);
93        if rust_type.needs_chrono() {
94            needs_chrono = true;
95        }
96        if rust_type.needs_decimal() {
97            needs_decimal = true;
98        }
99    }
100
101    // Check if we have enum columns for imports
102    let has_enums = table.columns.iter().any(|c| c.is_enum());
103
104    // Generate imports
105    code.push_str("use rdbi::{Pool, Query, Result};\n");
106    code.push_str(&format!("use crate::{}::{};\n", models_module, struct_name));
107
108    if has_enums {
109        // Import enum types from the same struct module
110        for col in &table.columns {
111            if col.is_enum() {
112                let enum_name = super::naming::to_enum_name(&table.name, &col.name);
113                code.push_str(&format!("use crate::{}::{};\n", models_module, enum_name));
114            }
115        }
116    }
117
118    if needs_chrono {
119        code.push_str("#[allow(unused_imports)]\n");
120        code.push_str("use chrono::{NaiveDate, NaiveDateTime, NaiveTime};\n");
121    }
122    if needs_decimal {
123        code.push_str("#[allow(unused_imports)]\n");
124        code.push_str("use rust_decimal::Decimal;\n");
125    }
126
127    code.push('\n');
128
129    // Generate select columns list
130    let select_columns = build_select_columns(table);
131
132    // Generate find_all
133    code.push_str(&generate_find_all(table, &struct_name, &select_columns));
134
135    // Generate count_all
136    code.push_str(&generate_count_all(table));
137
138    // Generate primary key methods
139    if let Some(pk) = &table.primary_key {
140        code.push_str(&generate_pk_methods(
141            table,
142            pk,
143            &column_map,
144            &struct_name,
145            &select_columns,
146        ));
147    }
148
149    // Generate insert methods
150    code.push_str(&generate_insert_methods(table, &struct_name, &column_map));
151
152    // Generate insert_plain method (individual params)
153    code.push_str(&generate_insert_plain_method(table, &column_map));
154
155    // Generate batch insert method
156    code.push_str(&generate_insert_all_method(table, &struct_name));
157
158    // Generate upsert method
159    code.push_str(&generate_upsert_method(table, &struct_name));
160
161    // Generate update methods
162    if table.primary_key.is_some() {
163        code.push_str(&generate_update_methods(table, &struct_name, &column_map));
164        // Generate update_plain method (individual params)
165        code.push_str(&generate_update_plain_method(table, &column_map));
166    }
167
168    // Generate index-aware findBy methods
169    let signatures = collect_method_signatures(table);
170    for sig in signatures.values() {
171        // Skip if this is the primary key (already generated separately)
172        if sig.source == "PRIMARY_KEY" {
173            continue;
174        }
175        code.push_str(&generate_find_by_method(
176            table,
177            sig,
178            &column_map,
179            &struct_name,
180            &select_columns,
181        ));
182    }
183
184    // Generate list-based findBy methods for single-column indexes
185    code.push_str(&generate_find_by_list_methods(
186        table,
187        &column_map,
188        &struct_name,
189        &select_columns,
190    ));
191
192    // Generate composite enum list methods (e.g., find_by_user_id_and_device_types)
193    code.push_str(&generate_composite_enum_list_methods(
194        table,
195        &column_map,
196        &struct_name,
197        &select_columns,
198    ));
199
200    // Generate pagination methods
201    code.push_str(&generate_pagination_methods(
202        table,
203        &struct_name,
204        &select_columns,
205        models_module,
206    ));
207
208    let file_path = output_dir.join(&file_name);
209    fs::write(&file_path, code)?;
210    Ok(())
211}
212
213/// Build the SELECT columns list
214fn build_select_columns(table: &TableMetadata) -> String {
215    table
216        .columns
217        .iter()
218        .map(|c| format!("`{}`", c.name))
219        .collect::<Vec<_>>()
220        .join(", ")
221}
222
/// Build an equality WHERE clause: `` `a` = ? AND `b` = ? ``.
///
/// There is one placeholder per column, in the given order. An empty slice
/// yields an empty string.
fn build_where_clause(columns: &[String]) -> String {
    let mut clause = String::new();
    for (i, name) in columns.iter().enumerate() {
        if i > 0 {
            clause.push_str(" AND ");
        }
        clause.push('`');
        clause.push_str(name);
        clause.push_str("` = ?");
    }
    clause
}
231
232/// Build parameter list for a function signature
233fn build_params(
234    columns: &[String],
235    column_map: &HashMap<&str, &ColumnMetadata>,
236    table_name: &str,
237) -> String {
238    columns
239        .iter()
240        .map(|c| {
241            let col = column_map.get(c.as_str()).unwrap();
242            let rust_type = TypeResolver::resolve(col, table_name);
243            let param_type = rust_type.to_param_type_string();
244            format!("{}: {}", escape_field_name(c), param_type)
245        })
246        .collect::<Vec<_>>()
247        .join(", ")
248}
249
250/// Generate bind calls for query
251fn generate_bind_section(columns: &[String]) -> String {
252    columns
253        .iter()
254        .map(|c| format!("        .bind({})", escape_field_name(c)))
255        .collect::<Vec<_>>()
256        .join("\n")
257}
258
259/// Generate find_all function
260fn generate_find_all(table: &TableMetadata, struct_name: &str, select_columns: &str) -> String {
261    format!(
262        r#"/// Find all records
263pub async fn find_all<P: Pool>(pool: &P) -> Result<Vec<{struct_name}>> {{
264    Query::new("SELECT {select_columns} FROM `{table_name}`")
265        .fetch_all(pool)
266        .await
267}}
268
269"#,
270        struct_name = struct_name,
271        select_columns = select_columns,
272        table_name = table.name,
273    )
274}
275
276/// Generate count_all function
277fn generate_count_all(table: &TableMetadata) -> String {
278    format!(
279        r#"/// Count all records
280pub async fn count_all<P: Pool>(pool: &P) -> Result<i64> {{
281    Query::new("SELECT COUNT(*) FROM `{table_name}`")
282        .fetch_scalar(pool)
283        .await
284}}
285
286"#,
287        table_name = table.name,
288    )
289}
290
291/// Generate primary key methods (find, delete)
292fn generate_pk_methods(
293    table: &TableMetadata,
294    pk: &crate::parser::PrimaryKey,
295    column_map: &HashMap<&str, &ColumnMetadata>,
296    struct_name: &str,
297    select_columns: &str,
298) -> String {
299    let mut code = String::new();
300
301    let method_name = generate_find_by_method_name(&pk.columns);
302    let params = build_params(&pk.columns, column_map, &table.name);
303    let where_clause = build_where_clause(&pk.columns);
304    let bind_section = generate_bind_section(&pk.columns);
305
306    // find_by_pk
307    code.push_str(&format!(
308        r#"/// Find by primary key
309pub async fn {method_name}<P: Pool>(pool: &P, {params}) -> Result<Option<{struct_name}>> {{
310    Query::new("SELECT {select_columns} FROM `{table_name}` WHERE {where_clause}")
311{bind_section}
312        .fetch_optional(pool)
313        .await
314}}
315
316"#,
317        method_name = method_name,
318        params = params,
319        struct_name = struct_name,
320        select_columns = select_columns,
321        table_name = table.name,
322        where_clause = where_clause,
323        bind_section = bind_section,
324    ));
325
326    // delete_by_pk
327    let delete_method = generate_delete_by_method_name(&pk.columns);
328    code.push_str(&format!(
329        r#"/// Delete by primary key
330pub async fn {delete_method}<P: Pool>(pool: &P, {params}) -> Result<u64> {{
331    Query::new("DELETE FROM `{table_name}` WHERE {where_clause}")
332{bind_section}
333        .execute(pool)
334        .await
335        .map(|r| r.rows_affected)
336}}
337
338"#,
339        delete_method = delete_method,
340        params = params,
341        table_name = table.name,
342        where_clause = where_clause,
343        bind_section = bind_section,
344    ));
345
346    code
347}
348
349/// Generate insert methods
350fn generate_insert_methods(
351    table: &TableMetadata,
352    struct_name: &str,
353    _column_map: &HashMap<&str, &ColumnMetadata>,
354) -> String {
355    let mut code = String::new();
356
357    // Get non-auto-increment columns for insert
358    let insert_columns: Vec<&ColumnMetadata> = table
359        .columns
360        .iter()
361        .filter(|c| !c.is_auto_increment)
362        .collect();
363
364    if insert_columns.is_empty() {
365        return code;
366    }
367
368    let column_list = insert_columns
369        .iter()
370        .map(|c| format!("`{}`", c.name))
371        .collect::<Vec<_>>()
372        .join(", ");
373
374    let placeholders = insert_columns
375        .iter()
376        .map(|_| "?")
377        .collect::<Vec<_>>()
378        .join(", ");
379
380    let bind_fields = insert_columns
381        .iter()
382        .map(|c| {
383            let field = escape_field_name(&c.name);
384            let rust_type = TypeResolver::resolve(c, &table.name);
385            if rust_type.is_copy() {
386                format!("        .bind(entity.{})", field)
387            } else {
388                format!("        .bind(&entity.{})", field)
389            }
390        })
391        .collect::<Vec<_>>()
392        .join("\n");
393
394    // insert (with entity)
395    code.push_str(&format!(
396        r#"/// Insert a new record
397pub async fn insert<P: Pool>(pool: &P, entity: &{struct_name}) -> Result<u64> {{
398    Query::new("INSERT INTO `{table_name}` ({column_list}) VALUES ({placeholders})")
399{bind_fields}
400        .execute(pool)
401        .await
402        .map(|r| r.last_insert_id.unwrap_or(0))
403}}
404
405"#,
406        struct_name = struct_name,
407        table_name = table.name,
408        column_list = column_list,
409        placeholders = placeholders,
410        bind_fields = bind_fields,
411    ));
412
413    code
414}
415
416/// Generate insert_plain method with individual parameters
417fn generate_insert_plain_method(
418    table: &TableMetadata,
419    column_map: &HashMap<&str, &ColumnMetadata>,
420) -> String {
421    // Get non-auto-increment columns for insert
422    let insert_columns: Vec<&ColumnMetadata> = table
423        .columns
424        .iter()
425        .filter(|c| !c.is_auto_increment)
426        .collect();
427
428    if insert_columns.is_empty() {
429        return String::new();
430    }
431
432    let column_names: Vec<String> = insert_columns.iter().map(|c| c.name.clone()).collect();
433    let params = build_params(&column_names, column_map, &table.name);
434
435    let column_list = insert_columns
436        .iter()
437        .map(|c| format!("`{}`", c.name))
438        .collect::<Vec<_>>()
439        .join(", ");
440
441    let placeholders = insert_columns
442        .iter()
443        .map(|_| "?")
444        .collect::<Vec<_>>()
445        .join(", ");
446
447    let bind_section = generate_bind_section(&column_names);
448
449    format!(
450        r#"/// Insert a new record with individual parameters
451#[allow(clippy::too_many_arguments)]
452pub async fn insert_plain<P: Pool>(pool: &P, {params}) -> Result<u64> {{
453    Query::new("INSERT INTO `{table_name}` ({column_list}) VALUES ({placeholders})")
454{bind_section}
455        .execute(pool)
456        .await
457        .map(|r| r.last_insert_id.unwrap_or(0))
458}}
459
460"#,
461        params = params,
462        table_name = table.name,
463        column_list = column_list,
464        placeholders = placeholders,
465        bind_section = bind_section,
466    )
467}
468
469/// Generate batch insert method (insert_all) using BatchInsert
470fn generate_insert_all_method(table: &TableMetadata, struct_name: &str) -> String {
471    // Get non-auto-increment columns for insert
472    let insert_columns: Vec<&ColumnMetadata> = table
473        .columns
474        .iter()
475        .filter(|c| !c.is_auto_increment)
476        .collect();
477
478    if insert_columns.is_empty() {
479        return String::new();
480    }
481
482    format!(
483        r#"/// Insert multiple records in a single batch
484pub async fn insert_all<P: Pool>(pool: &P, entities: &[{struct_name}]) -> Result<u64> {{
485    rdbi::BatchInsert::new("{table_name}", entities)
486        .execute(pool)
487        .await
488        .map(|r| r.rows_affected)
489}}
490
491"#,
492        struct_name = struct_name,
493        table_name = table.name,
494    )
495}
496
497/// Check if table has a unique index (excluding primary key)
498fn has_unique_index(table: &TableMetadata) -> bool {
499    table.indexes.iter().any(|idx| idx.unique)
500}
501
502/// Generate upsert method (INSERT ... ON DUPLICATE KEY UPDATE)
503fn generate_upsert_method(table: &TableMetadata, struct_name: &str) -> String {
504    // Only generate if table has primary key or unique index
505    if table.primary_key.is_none() && !has_unique_index(table) {
506        return String::new();
507    }
508
509    // Get non-auto-increment columns for insert
510    let insert_columns: Vec<&ColumnMetadata> = table
511        .columns
512        .iter()
513        .filter(|c| !c.is_auto_increment)
514        .collect();
515
516    if insert_columns.is_empty() {
517        return String::new();
518    }
519
520    // Get primary key columns for exclusion from UPDATE clause
521    let pk_columns: HashSet<&str> = table
522        .primary_key
523        .as_ref()
524        .map(|pk| pk.columns.iter().map(|s| s.as_str()).collect())
525        .unwrap_or_default();
526
527    // Columns to update on duplicate key (all non-PK, non-auto-increment columns)
528    let update_columns: Vec<&ColumnMetadata> = insert_columns
529        .iter()
530        .filter(|c| !pk_columns.contains(c.name.as_str()))
531        .copied()
532        .collect();
533
534    // If no columns to update, skip upsert generation
535    if update_columns.is_empty() {
536        return String::new();
537    }
538
539    let column_list = insert_columns
540        .iter()
541        .map(|c| format!("`{}`", c.name))
542        .collect::<Vec<_>>()
543        .join(", ");
544
545    let placeholders = insert_columns
546        .iter()
547        .map(|_| "?")
548        .collect::<Vec<_>>()
549        .join(", ");
550
551    let update_clause = update_columns
552        .iter()
553        .map(|c| format!("`{name}` = VALUES(`{name}`)", name = c.name))
554        .collect::<Vec<_>>()
555        .join(", ");
556
557    let bind_fields = insert_columns
558        .iter()
559        .map(|c| {
560            let field = escape_field_name(&c.name);
561            let rust_type = TypeResolver::resolve(c, &table.name);
562            if rust_type.is_copy() {
563                format!("        .bind(entity.{})", field)
564            } else {
565                format!("        .bind(&entity.{})", field)
566            }
567        })
568        .collect::<Vec<_>>()
569        .join("\n");
570
571    format!(
572        r#"/// Upsert a record (insert or update on duplicate key)
573/// Returns rows_affected: 1 if inserted, 2 if updated
574pub async fn upsert<P: Pool>(pool: &P, entity: &{struct_name}) -> Result<u64> {{
575    Query::new("INSERT INTO `{table_name}` ({column_list}) VALUES ({placeholders}) \
576         ON DUPLICATE KEY UPDATE {update_clause}")
577{bind_fields}
578        .execute(pool)
579        .await
580        .map(|r| r.rows_affected)
581}}
582
583"#,
584        struct_name = struct_name,
585        table_name = table.name,
586        column_list = column_list,
587        placeholders = placeholders,
588        update_clause = update_clause,
589        bind_fields = bind_fields,
590    )
591}
592
593/// Generate update methods
594fn generate_update_methods(
595    table: &TableMetadata,
596    struct_name: &str,
597    column_map: &HashMap<&str, &ColumnMetadata>,
598) -> String {
599    let mut code = String::new();
600
601    let pk = table.primary_key.as_ref().unwrap();
602
603    // Get non-PK columns for SET clause
604    let update_columns: Vec<&ColumnMetadata> = table
605        .columns
606        .iter()
607        .filter(|c| !pk.columns.contains(&c.name))
608        .collect();
609
610    if update_columns.is_empty() {
611        return code;
612    }
613
614    let set_clause = update_columns
615        .iter()
616        .map(|c| format!("`{}` = ?", c.name))
617        .collect::<Vec<_>>()
618        .join(", ");
619
620    let where_clause = build_where_clause(&pk.columns);
621
622    // Bind update columns first, then PK columns
623    let bind_fields: Vec<String> = update_columns
624        .iter()
625        .map(|c| {
626            let field = escape_field_name(&c.name);
627            let rust_type = TypeResolver::resolve(c, &table.name);
628            if rust_type.is_copy() {
629                format!("        .bind(entity.{})", field)
630            } else {
631                format!("        .bind(&entity.{})", field)
632            }
633        })
634        .chain(pk.columns.iter().map(|c| {
635            let field = escape_field_name(c);
636            let col = column_map.get(c.as_str()).unwrap();
637            let rust_type = TypeResolver::resolve(col, &table.name);
638            if rust_type.is_copy() {
639                format!("        .bind(entity.{})", field)
640            } else {
641                format!("        .bind(&entity.{})", field)
642            }
643        }))
644        .collect();
645
646    // update_by_bean (using entity)
647    code.push_str(&format!(
648        r#"/// Update a record by primary key
649pub async fn update<P: Pool>(pool: &P, entity: &{struct_name}) -> Result<u64> {{
650    Query::new("UPDATE `{table_name}` SET {set_clause} WHERE {where_clause}")
651{bind_fields}
652        .execute(pool)
653        .await
654        .map(|r| r.rows_affected)
655}}
656
657"#,
658        struct_name = struct_name,
659        table_name = table.name,
660        set_clause = set_clause,
661        where_clause = where_clause,
662        bind_fields = bind_fields.join("\n"),
663    ));
664
665    code
666}
667
668/// Generate update_plain method with individual parameters (update_by_<pk>)
669fn generate_update_plain_method(
670    table: &TableMetadata,
671    column_map: &HashMap<&str, &ColumnMetadata>,
672) -> String {
673    let pk = table.primary_key.as_ref().unwrap();
674
675    // Get non-PK columns for SET clause
676    let update_columns: Vec<&ColumnMetadata> = table
677        .columns
678        .iter()
679        .filter(|c| !pk.columns.contains(&c.name))
680        .collect();
681
682    if update_columns.is_empty() {
683        return String::new();
684    }
685
686    // Build method name based on PK columns
687    let method_name = generate_update_by_method_name(&pk.columns);
688
689    // Build params: PK columns first, then update columns
690    let pk_params = build_params(&pk.columns, column_map, &table.name);
691    let update_column_names: Vec<String> = update_columns.iter().map(|c| c.name.clone()).collect();
692    let update_params = build_params(&update_column_names, column_map, &table.name);
693    let all_params = format!("{}, {}", pk_params, update_params);
694
695    let set_clause = update_columns
696        .iter()
697        .map(|c| format!("`{}` = ?", c.name))
698        .collect::<Vec<_>>()
699        .join(", ");
700
701    let where_clause = build_where_clause(&pk.columns);
702
703    // Bind update columns first (for SET), then PK columns (for WHERE)
704    let bind_section_update = generate_bind_section(&update_column_names);
705    let bind_section_pk = generate_bind_section(&pk.columns);
706
707    format!(
708        r#"/// Update a record by primary key with individual parameters
709#[allow(clippy::too_many_arguments)]
710pub async fn {method_name}<P: Pool>(pool: &P, {all_params}) -> Result<u64> {{
711    Query::new("UPDATE `{table_name}` SET {set_clause} WHERE {where_clause}")
712{bind_section_update}
713{bind_section_pk}
714        .execute(pool)
715        .await
716        .map(|r| r.rows_affected)
717}}
718
719"#,
720        method_name = method_name,
721        all_params = all_params,
722        table_name = table.name,
723        set_clause = set_clause,
724        where_clause = where_clause,
725        bind_section_update = bind_section_update,
726        bind_section_pk = bind_section_pk,
727    )
728}
729
730/// Collect method signatures with priority-based deduplication
731fn collect_method_signatures(table: &TableMetadata) -> HashMap<Vec<String>, MethodSignature> {
732    let mut signatures: HashMap<Vec<String>, MethodSignature> = HashMap::new();
733
734    // Priority 1: Primary key
735    if let Some(pk) = &table.primary_key {
736        let sig = MethodSignature::new(
737            pk.columns.clone(),
738            PRIORITY_PRIMARY_KEY,
739            true,
740            "PRIMARY_KEY",
741        );
742        signatures.insert(pk.columns.clone(), sig);
743    }
744
745    // Priority 2 & 3: Indexes
746    for index in &table.indexes {
747        let priority = if index.unique {
748            PRIORITY_UNIQUE_INDEX
749        } else {
750            PRIORITY_NON_UNIQUE_INDEX
751        };
752        let source = if index.unique {
753            "UNIQUE_INDEX"
754        } else {
755            "NON_UNIQUE_INDEX"
756        };
757        let sig = MethodSignature::new(index.columns.clone(), priority, index.unique, source);
758
759        // Only add if no higher priority signature exists
760        if let Some(existing) = signatures.get(&index.columns) {
761            if sig.priority < existing.priority {
762                signatures.insert(index.columns.clone(), sig);
763            }
764        } else {
765            signatures.insert(index.columns.clone(), sig);
766        }
767    }
768
769    // Priority 4: Foreign keys
770    for fk in &table.foreign_keys {
771        let columns = vec![fk.column_name.clone()];
772        let sig = MethodSignature::new(columns.clone(), PRIORITY_FOREIGN_KEY, false, "FOREIGN_KEY");
773
774        signatures.entry(columns).or_insert(sig);
775    }
776
777    signatures
778}
779
780/// Generate a find_by method for an index/FK
781fn generate_find_by_method(
782    table: &TableMetadata,
783    sig: &MethodSignature,
784    column_map: &HashMap<&str, &ColumnMetadata>,
785    struct_name: &str,
786    select_columns: &str,
787) -> String {
788    let params = build_params(&sig.columns, column_map, &table.name);
789
790    let (return_type, fetch_method) = if sig.is_unique {
791        (format!("Option<{}>", struct_name), "fetch_optional")
792    } else {
793        (format!("Vec<{}>", struct_name), "fetch_all")
794    };
795
796    let return_desc = if sig.is_unique {
797        "Option (unique)"
798    } else {
799        "Vec (non-unique)"
800    };
801
802    // Check if any column is nullable - if so, we need dynamic SQL for IS NULL handling
803    let has_nullable = sig.columns.iter().any(|c| {
804        column_map
805            .get(c.as_str())
806            .map(|col| col.nullable)
807            .unwrap_or(false)
808    });
809
810    if has_nullable {
811        // Generate dynamic query that handles NULL values properly
812        generate_find_by_method_nullable(
813            table,
814            sig,
815            column_map,
816            select_columns,
817            &params,
818            &return_type,
819            fetch_method,
820            return_desc,
821        )
822    } else {
823        // Use static query for non-nullable columns
824        let where_clause = build_where_clause(&sig.columns);
825        let bind_section = generate_bind_section(&sig.columns);
826
827        format!(
828            r#"/// Find by {source}: returns {return_desc}
829pub async fn {method_name}<P: Pool>(pool: &P, {params}) -> Result<{return_type}> {{
830    Query::new("SELECT {select_columns} FROM `{table_name}` WHERE {where_clause}")
831{bind_section}
832        .{fetch_method}(pool)
833        .await
834}}
835
836"#,
837            source = sig.source.to_lowercase().replace('_', " "),
838            return_desc = return_desc,
839            method_name = sig.method_name,
840            params = params,
841            return_type = return_type,
842            select_columns = select_columns,
843            table_name = table.name,
844            where_clause = where_clause,
845            bind_section = bind_section,
846            fetch_method = fetch_method,
847        )
848    }
849}
850
851/// Generate a find_by method that handles nullable columns with IS NULL
852#[allow(clippy::too_many_arguments)]
853fn generate_find_by_method_nullable(
854    table: &TableMetadata,
855    sig: &MethodSignature,
856    column_map: &HashMap<&str, &ColumnMetadata>,
857    select_columns: &str,
858    params: &str,
859    return_type: &str,
860    fetch_method: &str,
861    return_desc: &str,
862) -> String {
863    // Build the where clause conditions and bind logic
864    let mut where_parts = Vec::new();
865    let mut bind_parts = Vec::new();
866
867    for col in &sig.columns {
868        let col_meta = column_map.get(col.as_str()).unwrap();
869        let field_name = escape_field_name(col);
870
871        if col_meta.nullable {
872            // For nullable columns, check if value is None and use IS NULL
873            where_parts.push(format!(
874                r#"if {field}.is_some() {{ "`{col}` = ?" }} else {{ "`{col}` IS NULL" }}"#,
875                field = field_name,
876                col = col,
877            ));
878            bind_parts.push(format!(
879                r#"if let Some(v) = {field}.as_ref() {{ query = query.bind(v); }}"#,
880                field = field_name,
881            ));
882        } else {
883            where_parts.push(format!(r#""`{}` = ?""#, col));
884            bind_parts.push(format!(r#"query = query.bind({});"#, field_name,));
885        }
886    }
887
888    let where_expr = if where_parts.len() == 1 {
889        where_parts[0].clone()
890    } else {
891        // Join with " AND "
892        let parts = where_parts
893            .iter()
894            .map(|p| format!("({})", p))
895            .collect::<Vec<_>>()
896            .join(", ");
897        format!("vec![{}].join(\" AND \")", parts)
898    };
899
900    let bind_code = bind_parts.join("\n        ");
901
902    format!(
903        r#"/// Find by {source}: returns {return_desc}
904pub async fn {method_name}<P: Pool>(pool: &P, {params}) -> Result<{return_type}> {{
905    let where_clause = {where_expr};
906    let sql = format!("SELECT {select_columns} FROM `{table_name}` WHERE {{}}", where_clause);
907    let mut query = rdbi::DynamicQuery::new(sql);
908    {bind_code}
909    query.{fetch_method}(pool).await
910}}
911
912"#,
913        source = sig.source.to_lowercase().replace('_', " "),
914        return_desc = return_desc,
915        method_name = sig.method_name,
916        params = params,
917        return_type = return_type,
918        select_columns = select_columns,
919        table_name = table.name,
920        where_expr = where_expr,
921        bind_code = bind_code,
922        fetch_method = fetch_method,
923    )
924}
925
926/// Generate list-based findBy methods for single-column indexes
927fn generate_find_by_list_methods(
928    table: &TableMetadata,
929    column_map: &HashMap<&str, &ColumnMetadata>,
930    struct_name: &str,
931    select_columns: &str,
932) -> String {
933    let mut code = String::new();
934    let mut processed: HashSet<String> = HashSet::new();
935
936    // Primary key (if single column)
937    if let Some(pk) = &table.primary_key {
938        if pk.columns.len() == 1 {
939            let col = &pk.columns[0];
940            code.push_str(&generate_single_find_by_list(
941                table,
942                col,
943                column_map,
944                struct_name,
945                select_columns,
946            ));
947            processed.insert(col.clone());
948        }
949    }
950
951    // Single-column indexes
952    for index in &table.indexes {
953        if index.columns.len() == 1 {
954            let col = &index.columns[0];
955            if !processed.contains(col) {
956                code.push_str(&generate_single_find_by_list(
957                    table,
958                    col,
959                    column_map,
960                    struct_name,
961                    select_columns,
962                ));
963                processed.insert(col.clone());
964            }
965        }
966    }
967
968    code
969}
970
971/// Generate a single find_by_<column>s method (using IN clause)
972fn generate_single_find_by_list(
973    table: &TableMetadata,
974    column_name: &str,
975    column_map: &HashMap<&str, &ColumnMetadata>,
976    struct_name: &str,
977    select_columns: &str,
978) -> String {
979    let method_name = generate_find_by_list_method_name(column_name);
980    let param_name = pluralize(&escape_field_name(column_name));
981    let column = column_map.get(column_name).unwrap();
982    let rust_type = TypeResolver::resolve(column, &table.name);
983
984    // Get the inner type (unwrap Option if nullable)
985    let inner_type = rust_type.inner_type().to_type_string();
986
987    let column_name_plural = pluralize(column_name);
988    format!(
989        r#"/// Find by list of {column_name_plural} (IN clause)
990pub async fn {method_name}<P: Pool>(pool: &P, {param_name}: &[{inner_type}]) -> Result<Vec<{struct_name}>> {{
991    if {param_name}.is_empty() {{
992        return Ok(Vec::new());
993    }}
994    let placeholders = {param_name}.iter().map(|_| "?").collect::<Vec<_>>().join(",");
995    let query = format!(
996        "SELECT {select_columns} FROM `{table_name}` WHERE `{column_name}` IN ({{}})",
997        placeholders
998    );
999    rdbi::DynamicQuery::new(query)
1000        .bind_all({param_name})
1001        .fetch_all(pool)
1002        .await
1003}}
1004
1005"#,
1006        column_name_plural = column_name_plural,
1007        column_name = column_name,
1008        method_name = method_name,
1009        param_name = param_name,
1010        inner_type = inner_type,
1011        struct_name = struct_name,
1012        select_columns = select_columns,
1013        table_name = table.name,
1014    )
1015}
1016
1017/// Generate composite enum list methods for multi-column indexes with enum columns
1018/// Example: find_by_user_id_and_device_types(user_id, &[DeviceType])
1019fn generate_composite_enum_list_methods(
1020    table: &TableMetadata,
1021    column_map: &HashMap<&str, &ColumnMetadata>,
1022    struct_name: &str,
1023    select_columns: &str,
1024) -> String {
1025    let mut code = String::new();
1026
1027    for index in &table.indexes {
1028        // Skip single-column indexes (handled by generate_find_by_list_methods)
1029        if index.columns.len() <= 1 {
1030            continue;
1031        }
1032
1033        // Identify which columns are enums
1034        let enum_columns: HashSet<&str> = index
1035            .columns
1036            .iter()
1037            .filter(|col_name| {
1038                column_map
1039                    .get(col_name.as_str())
1040                    .map(|col| col.is_enum())
1041                    .unwrap_or(false)
1042            })
1043            .map(|s| s.as_str())
1044            .collect();
1045
1046        // Skip if no enum columns
1047        if enum_columns.is_empty() {
1048            continue;
1049        }
1050
1051        // Skip if first column is enum (for optimal index usage, equality should be on leading column)
1052        let first_column = &index.columns[0];
1053        if enum_columns.contains(first_column.as_str()) {
1054            continue;
1055        }
1056
1057        code.push_str(&generate_composite_enum_list_method(
1058            table,
1059            &index.columns,
1060            &enum_columns,
1061            column_map,
1062            struct_name,
1063            select_columns,
1064        ));
1065    }
1066
1067    code
1068}
1069
1070/// Generate a single composite enum list method
1071fn generate_composite_enum_list_method(
1072    table: &TableMetadata,
1073    columns: &[String],
1074    enum_columns: &HashSet<&str>,
1075    column_map: &HashMap<&str, &ColumnMetadata>,
1076    struct_name: &str,
1077    select_columns: &str,
1078) -> String {
1079    // Build method name: pluralize enum column names
1080    let method_name = generate_composite_enum_method_name(columns, enum_columns);
1081
1082    // Build params and WHERE clause parts
1083    let mut params_parts = Vec::new();
1084
1085    for col_name in columns {
1086        let col = column_map.get(col_name.as_str()).unwrap();
1087        let rust_type = TypeResolver::resolve(col, &table.name);
1088        let is_enum = enum_columns.contains(col_name.as_str());
1089
1090        if is_enum {
1091            // Enum column uses list parameter
1092            let param_name = pluralize(&escape_field_name(col_name));
1093            let inner_type = rust_type.inner_type().to_type_string();
1094            params_parts.push(format!("{}: &[{}]", param_name, inner_type));
1095        } else {
1096            // Non-enum column uses single value
1097            let param_name = escape_field_name(col_name);
1098            let param_type = rust_type.to_param_type_string();
1099            params_parts.push(format!("{}: {}", param_name, param_type));
1100        }
1101    }
1102
1103    let params = params_parts.join(", ");
1104
1105    // Build WHERE clause with proper placeholders
1106    let where_clause_static: Vec<String> = columns
1107        .iter()
1108        .map(|col_name| {
1109            if enum_columns.contains(col_name.as_str()) {
1110                format!("`{}` IN ({{}})", col_name) // placeholder for IN clause
1111            } else {
1112                format!("`{}` = ?", col_name)
1113            }
1114        })
1115        .collect();
1116
1117    // Build the bind section
1118    let mut bind_code = String::new();
1119
1120    // First bind non-enum (single value) columns
1121    for col_name in columns {
1122        if !enum_columns.contains(col_name.as_str()) {
1123            let param_name = escape_field_name(col_name);
1124            bind_code.push_str(&format!("        .bind({})\n", param_name));
1125        }
1126    }
1127
1128    // Then bind enum (list) columns
1129    for col_name in columns {
1130        if enum_columns.contains(col_name.as_str()) {
1131            let param_name = pluralize(&escape_field_name(col_name));
1132            bind_code.push_str(&format!("        .bind_all({})\n", param_name));
1133        }
1134    }
1135
1136    // Build the column name description for doc comment
1137    let column_desc: Vec<String> = columns
1138        .iter()
1139        .map(|col| {
1140            if enum_columns.contains(col.as_str()) {
1141                pluralize(col)
1142            } else {
1143                col.clone()
1144            }
1145        })
1146        .collect();
1147
1148    // Build dynamic WHERE clause construction
1149    let enum_col_names: Vec<&str> = columns
1150        .iter()
1151        .filter(|c| enum_columns.contains(c.as_str()))
1152        .map(|s| s.as_str())
1153        .collect();
1154
1155    // Generate the IN clause placeholders dynamically
1156    let in_clause_builders: Vec<String> = enum_col_names
1157        .iter()
1158        .map(|col| {
1159            let param_name = pluralize(&escape_field_name(col));
1160            format!(
1161                "{param_name}.iter().map(|_| \"?\").collect::<Vec<_>>().join(\",\")",
1162                param_name = param_name
1163            )
1164        })
1165        .collect();
1166
1167    // Build format args for the WHERE clause
1168    let format_args = in_clause_builders.join(", ");
1169
1170    format!(
1171        r#"/// Find by {column_desc} (composite index with IN clause for enum columns)
1172pub async fn {method_name}<P: Pool>(pool: &P, {params}) -> Result<Vec<{struct_name}>> {{
1173    // Check for empty enum lists
1174{empty_checks}
1175    // Build IN clause placeholders for enum columns
1176    let where_clause = format!("{where_template}", {format_args});
1177    let query = format!(
1178        "SELECT {select_columns} FROM `{table_name}` WHERE {{}}",
1179        where_clause
1180    );
1181    rdbi::DynamicQuery::new(query)
1182{bind_code}        .fetch_all(pool)
1183        .await
1184}}
1185
1186"#,
1187        column_desc = column_desc.join(" and "),
1188        method_name = method_name,
1189        params = params,
1190        struct_name = struct_name,
1191        select_columns = select_columns,
1192        table_name = table.name,
1193        where_template = where_clause_static.join(" AND "),
1194        format_args = format_args,
1195        bind_code = bind_code,
1196        empty_checks = generate_empty_checks(columns, enum_columns),
1197    )
1198}
1199
1200/// Generate empty checks for enum list parameters
1201fn generate_empty_checks(columns: &[String], enum_columns: &HashSet<&str>) -> String {
1202    let mut checks = String::new();
1203    for col_name in columns {
1204        if enum_columns.contains(col_name.as_str()) {
1205            let param_name = pluralize(&escape_field_name(col_name));
1206            checks.push_str(&format!(
1207                "    if {}.is_empty() {{ return Ok(Vec::new()); }}\n",
1208                param_name
1209            ));
1210        }
1211    }
1212    checks
1213}
1214
1215/// Generate method name for composite enum queries
1216/// Example: ["user_id", "device_type"] with enum_columns={"device_type"} -> "find_by_user_id_and_device_types"
1217fn generate_composite_enum_method_name(columns: &[String], enum_columns: &HashSet<&str>) -> String {
1218    let mut parts = Vec::new();
1219    for col in columns {
1220        if enum_columns.contains(col.as_str()) {
1221            parts.push(pluralize(col));
1222        } else {
1223            parts.push(col.clone());
1224        }
1225    }
1226    generate_find_by_method_name(&parts)
1227}
1228
1229/// Generate pagination methods (find_all_paginated, get_paginated_result)
1230fn generate_pagination_methods(
1231    table: &TableMetadata,
1232    struct_name: &str,
1233    select_columns: &str,
1234    models_module: &str,
1235) -> String {
1236    let sort_by_enum = format!("{}SortBy", struct_name);
1237
1238    format!(
1239        r#"/// Find all records with pagination and sorting
1240pub async fn find_all_paginated<P: Pool>(
1241    pool: &P,
1242    limit: i32,
1243    offset: i32,
1244    sort_by: crate::{models_module}::{sort_by_enum},
1245    sort_dir: crate::{models_module}::SortDirection,
1246) -> Result<Vec<{struct_name}>> {{
1247    let order_clause = format!("{{}} {{}}", sort_by.as_sql(), sort_dir.as_sql());
1248    let query = format!(
1249        "SELECT {select_columns} FROM `{table_name}` ORDER BY {{}} LIMIT ? OFFSET ?",
1250        order_clause
1251    );
1252    rdbi::DynamicQuery::new(query)
1253        .bind(limit)
1254        .bind(offset)
1255        .fetch_all(pool)
1256        .await
1257}}
1258
1259/// Get paginated result with total count
1260pub async fn get_paginated_result<P: Pool>(
1261    pool: &P,
1262    page_size: i32,
1263    current_page: i32,
1264    sort_by: crate::{models_module}::{sort_by_enum},
1265    sort_dir: crate::{models_module}::SortDirection,
1266) -> Result<crate::{models_module}::PaginatedResult<{struct_name}>> {{
1267    let page_size = page_size.max(1);
1268    let current_page = current_page.max(1);
1269    let offset = (current_page - 1) * page_size;
1270
1271    let total_count = count_all(pool).await?;
1272    let items = find_all_paginated(pool, page_size, offset, sort_by, sort_dir).await?;
1273
1274    Ok(crate::{models_module}::PaginatedResult::new(
1275        items,
1276        total_count,
1277        current_page,
1278        page_size,
1279    ))
1280}}
1281
1282"#,
1283        struct_name = struct_name,
1284        select_columns = select_columns,
1285        table_name = table.name,
1286        models_module = models_module,
1287        sort_by_enum = sort_by_enum,
1288    )
1289}
1290
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::{IndexMetadata, PrimaryKey};

    /// Build a non-nullable, signed column with no default and no comment.
    fn column(
        name: &str,
        data_type: &str,
        is_auto_increment: bool,
        enum_values: Option<&[&str]>,
    ) -> ColumnMetadata {
        ColumnMetadata {
            name: name.to_string(),
            data_type: data_type.to_string(),
            nullable: false,
            default_value: None,
            is_auto_increment,
            is_unsigned: false,
            enum_values: enum_values
                .map(|values| values.iter().map(|value| value.to_string()).collect()),
            comment: None,
        }
    }

    /// Build an index over exactly one column.
    fn single_column_index(name: &str, column_name: &str, unique: bool) -> IndexMetadata {
        IndexMetadata {
            name: name.to_string(),
            columns: vec![column_name.to_string()],
            unique,
        }
    }

    /// Build the column-list key for a one-column signature.
    fn key(column_name: &str) -> Vec<String> {
        vec![column_name.to_string()]
    }

    /// Fixture: `users` table with
    /// - `id BIGINT AUTO_INCREMENT` as primary key,
    /// - `email VARCHAR(255)` with a unique index,
    /// - `status ENUM('ACTIVE','INACTIVE')` with a plain index.
    fn make_table() -> TableMetadata {
        TableMetadata {
            name: "users".to_string(),
            comment: None,
            columns: vec![
                column("id", "BIGINT", true, None),
                column("email", "VARCHAR(255)", false, None),
                column("status", "ENUM", false, Some(&["ACTIVE", "INACTIVE"][..])),
            ],
            indexes: vec![
                single_column_index("email_unique", "email", true),
                single_column_index("idx_status", "status", false),
            ],
            foreign_keys: vec![],
            primary_key: Some(PrimaryKey { columns: key("id") }),
        }
    }

    #[test]
    fn test_collect_method_signatures() {
        let table = make_table();
        let sigs = collect_method_signatures(&table);

        // One signature each: id (PK), email (unique), status (non-unique)
        assert_eq!(sigs.len(), 3);

        let expectations = [
            ("id", true, PRIORITY_PRIMARY_KEY),
            ("email", true, PRIORITY_UNIQUE_INDEX),
            ("status", false, PRIORITY_NON_UNIQUE_INDEX),
        ];
        for (col, is_unique, priority) in expectations {
            let sig = sigs.get(&key(col)).unwrap();
            assert_eq!(sig.is_unique, is_unique, "is_unique for {}", col);
            assert_eq!(sig.priority, priority, "priority for {}", col);
        }
    }

    #[test]
    fn test_build_select_columns() {
        let table = make_table();
        let cols = build_select_columns(&table);
        for quoted in ["`id`", "`email`", "`status`"] {
            assert!(cols.contains(quoted), "missing {}", quoted);
        }
    }

    #[test]
    fn test_build_where_clause() {
        assert_eq!(build_where_clause(&key("id")), "`id` = ?");

        let composite = vec!["user_id".to_string(), "role_id".to_string()];
        assert_eq!(
            build_where_clause(&composite),
            "`user_id` = ? AND `role_id` = ?"
        );
    }

    #[test]
    fn test_generate_upsert_method() {
        let table = make_table();
        let code = generate_upsert_method(&table, "Users");

        // Function header, conflict clause, and updates for every non-PK column
        for expected in [
            "pub async fn upsert",
            "ON DUPLICATE KEY UPDATE",
            "`email` = VALUES(`email`)",
            "`status` = VALUES(`status`)",
        ] {
            assert!(code.contains(expected), "missing {:?}", expected);
        }
        // The PK column (id) must never be overwritten
        assert!(!code.contains("`id` = VALUES(`id`)"));
    }

    #[test]
    fn test_generate_upsert_method_no_pk() {
        let mut table = make_table();
        table.primary_key = None;
        table.indexes.clear();

        // Without a PK or unique index there is no conflict target, so nothing is emitted
        assert!(generate_upsert_method(&table, "Users").is_empty());
    }

    #[test]
    fn test_generate_insert_all_method() {
        let table = make_table();
        let code = generate_insert_all_method(&table, "Users");

        for expected in ["pub async fn insert_all", "rdbi::BatchInsert::new"] {
            assert!(code.contains(expected), "missing {:?}", expected);
        }
    }

    #[test]
    fn test_generate_pagination_methods() {
        let table = make_table();
        let select_columns = build_select_columns(&table);
        let code = generate_pagination_methods(&table, "Users", &select_columns, "models");

        for expected in [
            "pub async fn find_all_paginated",
            "limit: i32",
            "offset: i32",
            "UsersSortBy",
            "SortDirection",
            "pub async fn get_paginated_result",
            "PaginatedResult<Users>",
        ] {
            assert!(code.contains(expected), "missing {:?}", expected);
        }
    }
}