Skip to main content

rdbi_codegen/codegen/
dao_generator.rs

1//! DAO generator - generates async rdbi query functions from table metadata
2
3use crate::error::Result;
4use std::collections::{HashMap, HashSet};
5use std::fs;
6use std::path::Path;
7use tracing::debug;
8
9use crate::config::CodegenConfig;
10use crate::parser::{ColumnMetadata, TableMetadata};
11
12use super::naming::{
13    escape_field_name, generate_delete_by_method_name, generate_find_by_list_method_name,
14    generate_find_by_method_name, generate_update_by_method_name, pluralize, to_struct_name,
15};
16use super::type_resolver::TypeResolver;
17
// Priority levels used when deduplicating `find_by_*` signatures that cover the
// same column list. A smaller value wins: primary key beats unique index,
// which beats non-unique index, which beats foreign key.

/// Priority of a signature derived from the primary key (highest priority).
const PRIORITY_PRIMARY_KEY: u8 = 1;
/// Priority of a signature derived from a unique secondary index.
const PRIORITY_UNIQUE_INDEX: u8 = PRIORITY_PRIMARY_KEY + 1;
/// Priority of a signature derived from a non-unique secondary index.
const PRIORITY_NON_UNIQUE_INDEX: u8 = PRIORITY_UNIQUE_INDEX + 1;
/// Priority of a signature derived from a foreign key (lowest priority).
const PRIORITY_FOREIGN_KEY: u8 = PRIORITY_NON_UNIQUE_INDEX + 1;
23
24/// Represents a method signature for deduplication
25#[derive(Debug, Clone)]
26struct MethodSignature {
27    columns: Vec<String>,
28    method_name: String,
29    priority: u8,
30    is_unique: bool,
31    source: String,
32}
33
34impl MethodSignature {
35    fn new(columns: Vec<String>, priority: u8, is_unique: bool, source: &str) -> Self {
36        let method_name = generate_find_by_method_name(&columns);
37        Self {
38            columns,
39            method_name,
40            priority,
41            is_unique,
42            source: source.to_string(),
43        }
44    }
45}
46
47/// Generate DAO files for all tables
48pub fn generate_daos(tables: &[TableMetadata], config: &CodegenConfig) -> Result<()> {
49    let output_dir = &config.output_dao_dir;
50    fs::create_dir_all(output_dir)?;
51
52    // Generate mod.rs
53    let mut mod_content = String::new();
54    mod_content.push_str("// Generated DAO functions\n\n");
55
56    for table in tables {
57        let file_name = heck::AsSnakeCase(&table.name).to_string();
58        mod_content.push_str(&format!("pub mod {};\n", file_name));
59    }
60
61    let mod_path = output_dir.join("mod.rs");
62    fs::write(&mod_path, mod_content)?;
63    super::format_file(&mod_path);
64
65    // Generate each DAO file
66    for table in tables {
67        generate_dao_file(table, output_dir, &config.models_module)?;
68    }
69
70    Ok(())
71}
72
73/// Generate a single DAO file for a table
74fn generate_dao_file(table: &TableMetadata, output_dir: &Path, models_module: &str) -> Result<()> {
75    let struct_name = to_struct_name(&table.name);
76    let file_name = format!("{}.rs", heck::AsSnakeCase(&table.name));
77    debug!("Generating DAO for {} -> {}", struct_name, file_name);
78
79    let mut code = String::new();
80
81    // Build column map for quick lookup
82    let column_map: HashMap<&str, &ColumnMetadata> =
83        table.columns.iter().map(|c| (c.name.as_str(), c)).collect();
84
85    // Collect import requirements
86    let mut needs_chrono = false;
87    let mut needs_decimal = false;
88
89    for col in &table.columns {
90        let rust_type = TypeResolver::resolve(col, &table.name);
91        if rust_type.needs_chrono() {
92            needs_chrono = true;
93        }
94        if rust_type.needs_decimal() {
95            needs_decimal = true;
96        }
97    }
98
99    // Check if we have enum columns for imports
100    let has_enums = table.columns.iter().any(|c| c.is_enum());
101
102    // Generate imports
103    code.push_str("use rdbi::{Pool, Query, Result};\n");
104    code.push_str(&format!("use crate::{}::{};\n", models_module, struct_name));
105
106    if has_enums {
107        // Import enum types from the same struct module
108        for col in &table.columns {
109            if col.is_enum() {
110                let enum_name = super::naming::to_enum_name(&table.name, &col.name);
111                code.push_str(&format!("use crate::{}::{};\n", models_module, enum_name));
112            }
113        }
114    }
115
116    if needs_chrono {
117        code.push_str("#[allow(unused_imports)]\n");
118        code.push_str("use chrono::{NaiveDate, NaiveDateTime, NaiveTime};\n");
119    }
120    if needs_decimal {
121        code.push_str("#[allow(unused_imports)]\n");
122        code.push_str("use rust_decimal::Decimal;\n");
123    }
124
125    code.push('\n');
126
127    // Generate select columns list
128    let select_columns = build_select_columns(table);
129
130    // Generate find_all
131    code.push_str(&generate_find_all(table, &struct_name, &select_columns));
132
133    // Generate count_all
134    code.push_str(&generate_count_all(table));
135
136    // Generate primary key methods
137    if let Some(pk) = &table.primary_key {
138        code.push_str(&generate_pk_methods(
139            table,
140            pk,
141            &column_map,
142            &struct_name,
143            &select_columns,
144        ));
145    }
146
147    // Generate insert methods
148    code.push_str(&generate_insert_methods(table, &struct_name, &column_map));
149
150    // Generate insert_plain method (individual params)
151    code.push_str(&generate_insert_plain_method(table, &column_map));
152
153    // Generate batch insert method
154    code.push_str(&generate_insert_all_method(table, &struct_name));
155
156    // Generate upsert method
157    code.push_str(&generate_upsert_method(table, &struct_name));
158
159    // Generate update methods
160    if table.primary_key.is_some() {
161        code.push_str(&generate_update_methods(table, &struct_name, &column_map));
162        // Generate update_plain method (individual params)
163        code.push_str(&generate_update_plain_method(table, &column_map));
164    }
165
166    // Generate index-aware findBy methods
167    let signatures = collect_method_signatures(table);
168    for sig in signatures.values() {
169        // Skip if this is the primary key (already generated separately)
170        if sig.source == "PRIMARY_KEY" {
171            continue;
172        }
173        code.push_str(&generate_find_by_method(
174            table,
175            sig,
176            &column_map,
177            &struct_name,
178            &select_columns,
179        ));
180    }
181
182    // Generate list-based findBy methods for single-column indexes
183    code.push_str(&generate_find_by_list_methods(
184        table,
185        &column_map,
186        &struct_name,
187        &select_columns,
188    ));
189
190    // Generate composite enum list methods (e.g., find_by_user_id_and_device_types)
191    code.push_str(&generate_composite_enum_list_methods(
192        table,
193        &column_map,
194        &struct_name,
195        &select_columns,
196    ));
197
198    // Generate pagination methods
199    code.push_str(&generate_pagination_methods(
200        table,
201        &struct_name,
202        &select_columns,
203        models_module,
204    ));
205
206    let file_path = output_dir.join(&file_name);
207    fs::write(&file_path, code)?;
208    super::format_file(&file_path);
209    Ok(())
210}
211
212/// Build the SELECT columns list
213fn build_select_columns(table: &TableMetadata) -> String {
214    table
215        .columns
216        .iter()
217        .map(|c| format!("`{}`", c.name))
218        .collect::<Vec<_>>()
219        .join(", ")
220}
221
/// Build an `AND`-joined equality predicate with one `?` placeholder per
/// column, e.g. `` `a` = ? AND `b` = ? ``. Returns an empty string for no columns.
fn build_where_clause(columns: &[String]) -> String {
    let mut clause = String::new();
    for (position, column) in columns.iter().enumerate() {
        if position > 0 {
            clause.push_str(" AND ");
        }
        clause.push('`');
        clause.push_str(column);
        clause.push_str("` = ?");
    }
    clause
}
230
231/// Build parameter list for a function signature
232fn build_params(
233    columns: &[String],
234    column_map: &HashMap<&str, &ColumnMetadata>,
235    table_name: &str,
236) -> String {
237    columns
238        .iter()
239        .map(|c| {
240            let col = column_map.get(c.as_str()).unwrap();
241            let rust_type = TypeResolver::resolve(col, table_name);
242            let param_type = rust_type.to_param_type_string();
243            format!("{}: {}", escape_field_name(c), param_type)
244        })
245        .collect::<Vec<_>>()
246        .join(", ")
247}
248
249/// Generate bind calls for query
250fn generate_bind_section(columns: &[String]) -> String {
251    columns
252        .iter()
253        .map(|c| format!("        .bind({})", escape_field_name(c)))
254        .collect::<Vec<_>>()
255        .join("\n")
256}
257
258/// Generate find_all function
259fn generate_find_all(table: &TableMetadata, struct_name: &str, select_columns: &str) -> String {
260    format!(
261        r#"/// Find all records
262pub async fn find_all<P: Pool>(pool: &P) -> Result<Vec<{struct_name}>> {{
263    Query::new("SELECT {select_columns} FROM `{table_name}`")
264        .fetch_all(pool)
265        .await
266}}
267
268"#,
269        struct_name = struct_name,
270        select_columns = select_columns,
271        table_name = table.name,
272    )
273}
274
275/// Generate count_all function
276fn generate_count_all(table: &TableMetadata) -> String {
277    format!(
278        r#"/// Count all records
279pub async fn count_all<P: Pool>(pool: &P) -> Result<i64> {{
280    Query::new("SELECT COUNT(*) FROM `{table_name}`")
281        .fetch_scalar(pool)
282        .await
283}}
284
285"#,
286        table_name = table.name,
287    )
288}
289
290/// Generate primary key methods (find, delete)
291fn generate_pk_methods(
292    table: &TableMetadata,
293    pk: &crate::parser::PrimaryKey,
294    column_map: &HashMap<&str, &ColumnMetadata>,
295    struct_name: &str,
296    select_columns: &str,
297) -> String {
298    let mut code = String::new();
299
300    let method_name = generate_find_by_method_name(&pk.columns);
301    let params = build_params(&pk.columns, column_map, &table.name);
302    let where_clause = build_where_clause(&pk.columns);
303    let bind_section = generate_bind_section(&pk.columns);
304
305    // find_by_pk
306    code.push_str(&format!(
307        r#"/// Find by primary key
308pub async fn {method_name}<P: Pool>(pool: &P, {params}) -> Result<Option<{struct_name}>> {{
309    Query::new("SELECT {select_columns} FROM `{table_name}` WHERE {where_clause}")
310{bind_section}
311        .fetch_optional(pool)
312        .await
313}}
314
315"#,
316        method_name = method_name,
317        params = params,
318        struct_name = struct_name,
319        select_columns = select_columns,
320        table_name = table.name,
321        where_clause = where_clause,
322        bind_section = bind_section,
323    ));
324
325    // delete_by_pk
326    let delete_method = generate_delete_by_method_name(&pk.columns);
327    code.push_str(&format!(
328        r#"/// Delete by primary key
329pub async fn {delete_method}<P: Pool>(pool: &P, {params}) -> Result<u64> {{
330    Query::new("DELETE FROM `{table_name}` WHERE {where_clause}")
331{bind_section}
332        .execute(pool)
333        .await
334        .map(|r| r.rows_affected)
335}}
336
337"#,
338        delete_method = delete_method,
339        params = params,
340        table_name = table.name,
341        where_clause = where_clause,
342        bind_section = bind_section,
343    ));
344
345    code
346}
347
348/// Generate insert methods
349fn generate_insert_methods(
350    table: &TableMetadata,
351    struct_name: &str,
352    _column_map: &HashMap<&str, &ColumnMetadata>,
353) -> String {
354    let mut code = String::new();
355
356    // Get non-auto-increment columns for insert
357    let insert_columns: Vec<&ColumnMetadata> = table
358        .columns
359        .iter()
360        .filter(|c| !c.is_auto_increment)
361        .collect();
362
363    if insert_columns.is_empty() {
364        return code;
365    }
366
367    let column_list = insert_columns
368        .iter()
369        .map(|c| format!("`{}`", c.name))
370        .collect::<Vec<_>>()
371        .join(", ");
372
373    let placeholders = insert_columns
374        .iter()
375        .map(|_| "?")
376        .collect::<Vec<_>>()
377        .join(", ");
378
379    let bind_fields = insert_columns
380        .iter()
381        .map(|c| {
382            let field = escape_field_name(&c.name);
383            let rust_type = TypeResolver::resolve(c, &table.name);
384            if rust_type.is_copy() {
385                format!("        .bind(entity.{})", field)
386            } else {
387                format!("        .bind(&entity.{})", field)
388            }
389        })
390        .collect::<Vec<_>>()
391        .join("\n");
392
393    // insert (with entity)
394    code.push_str(&format!(
395        r#"/// Insert a new record
396pub async fn insert<P: Pool>(pool: &P, entity: &{struct_name}) -> Result<u64> {{
397    Query::new("INSERT INTO `{table_name}` ({column_list}) VALUES ({placeholders})")
398{bind_fields}
399        .execute(pool)
400        .await
401        .map(|r| r.last_insert_id.unwrap_or(0))
402}}
403
404"#,
405        struct_name = struct_name,
406        table_name = table.name,
407        column_list = column_list,
408        placeholders = placeholders,
409        bind_fields = bind_fields,
410    ));
411
412    code
413}
414
415/// Generate insert_plain method with individual parameters
416fn generate_insert_plain_method(
417    table: &TableMetadata,
418    column_map: &HashMap<&str, &ColumnMetadata>,
419) -> String {
420    // Get non-auto-increment columns for insert
421    let insert_columns: Vec<&ColumnMetadata> = table
422        .columns
423        .iter()
424        .filter(|c| !c.is_auto_increment)
425        .collect();
426
427    if insert_columns.is_empty() {
428        return String::new();
429    }
430
431    let column_names: Vec<String> = insert_columns.iter().map(|c| c.name.clone()).collect();
432    let params = build_params(&column_names, column_map, &table.name);
433
434    let column_list = insert_columns
435        .iter()
436        .map(|c| format!("`{}`", c.name))
437        .collect::<Vec<_>>()
438        .join(", ");
439
440    let placeholders = insert_columns
441        .iter()
442        .map(|_| "?")
443        .collect::<Vec<_>>()
444        .join(", ");
445
446    let bind_section = generate_bind_section(&column_names);
447
448    format!(
449        r#"/// Insert a new record with individual parameters
450#[allow(clippy::too_many_arguments)]
451pub async fn insert_plain<P: Pool>(pool: &P, {params}) -> Result<u64> {{
452    Query::new("INSERT INTO `{table_name}` ({column_list}) VALUES ({placeholders})")
453{bind_section}
454        .execute(pool)
455        .await
456        .map(|r| r.last_insert_id.unwrap_or(0))
457}}
458
459"#,
460        params = params,
461        table_name = table.name,
462        column_list = column_list,
463        placeholders = placeholders,
464        bind_section = bind_section,
465    )
466}
467
468/// Generate batch insert method (insert_all) using BatchInsert
469fn generate_insert_all_method(table: &TableMetadata, struct_name: &str) -> String {
470    // Get non-auto-increment columns for insert
471    let insert_columns: Vec<&ColumnMetadata> = table
472        .columns
473        .iter()
474        .filter(|c| !c.is_auto_increment)
475        .collect();
476
477    if insert_columns.is_empty() {
478        return String::new();
479    }
480
481    format!(
482        r#"/// Insert multiple records in a single batch
483pub async fn insert_all<P: Pool>(pool: &P, entities: &[{struct_name}]) -> Result<u64> {{
484    rdbi::BatchInsert::new("{table_name}", entities)
485        .execute(pool)
486        .await
487        .map(|r| r.rows_affected)
488}}
489
490"#,
491        struct_name = struct_name,
492        table_name = table.name,
493    )
494}
495
496/// Check if table has a unique index (excluding primary key)
497fn has_unique_index(table: &TableMetadata) -> bool {
498    table.indexes.iter().any(|idx| idx.unique)
499}
500
501/// Generate upsert method (INSERT ... ON DUPLICATE KEY UPDATE)
502fn generate_upsert_method(table: &TableMetadata, struct_name: &str) -> String {
503    // Only generate if table has primary key or unique index
504    if table.primary_key.is_none() && !has_unique_index(table) {
505        return String::new();
506    }
507
508    // Get non-auto-increment columns for insert
509    let insert_columns: Vec<&ColumnMetadata> = table
510        .columns
511        .iter()
512        .filter(|c| !c.is_auto_increment)
513        .collect();
514
515    if insert_columns.is_empty() {
516        return String::new();
517    }
518
519    // Get primary key columns for exclusion from UPDATE clause
520    let pk_columns: HashSet<&str> = table
521        .primary_key
522        .as_ref()
523        .map(|pk| pk.columns.iter().map(|s| s.as_str()).collect())
524        .unwrap_or_default();
525
526    // Columns to update on duplicate key (all non-PK, non-auto-increment columns)
527    let update_columns: Vec<&ColumnMetadata> = insert_columns
528        .iter()
529        .filter(|c| !pk_columns.contains(c.name.as_str()))
530        .copied()
531        .collect();
532
533    // If no columns to update, skip upsert generation
534    if update_columns.is_empty() {
535        return String::new();
536    }
537
538    let column_list = insert_columns
539        .iter()
540        .map(|c| format!("`{}`", c.name))
541        .collect::<Vec<_>>()
542        .join(", ");
543
544    let placeholders = insert_columns
545        .iter()
546        .map(|_| "?")
547        .collect::<Vec<_>>()
548        .join(", ");
549
550    let update_clause = update_columns
551        .iter()
552        .map(|c| format!("`{name}` = VALUES(`{name}`)", name = c.name))
553        .collect::<Vec<_>>()
554        .join(", ");
555
556    let bind_fields = insert_columns
557        .iter()
558        .map(|c| {
559            let field = escape_field_name(&c.name);
560            let rust_type = TypeResolver::resolve(c, &table.name);
561            if rust_type.is_copy() {
562                format!("        .bind(entity.{})", field)
563            } else {
564                format!("        .bind(&entity.{})", field)
565            }
566        })
567        .collect::<Vec<_>>()
568        .join("\n");
569
570    format!(
571        r#"/// Upsert a record (insert or update on duplicate key)
572/// Returns rows_affected: 1 if inserted, 2 if updated
573pub async fn upsert<P: Pool>(pool: &P, entity: &{struct_name}) -> Result<u64> {{
574    Query::new("INSERT INTO `{table_name}` ({column_list}) VALUES ({placeholders}) \
575         ON DUPLICATE KEY UPDATE {update_clause}")
576{bind_fields}
577        .execute(pool)
578        .await
579        .map(|r| r.rows_affected)
580}}
581
582"#,
583        struct_name = struct_name,
584        table_name = table.name,
585        column_list = column_list,
586        placeholders = placeholders,
587        update_clause = update_clause,
588        bind_fields = bind_fields,
589    )
590}
591
592/// Generate update methods
593fn generate_update_methods(
594    table: &TableMetadata,
595    struct_name: &str,
596    column_map: &HashMap<&str, &ColumnMetadata>,
597) -> String {
598    let mut code = String::new();
599
600    let pk = table.primary_key.as_ref().unwrap();
601
602    // Get non-PK columns for SET clause
603    let update_columns: Vec<&ColumnMetadata> = table
604        .columns
605        .iter()
606        .filter(|c| !pk.columns.contains(&c.name))
607        .collect();
608
609    if update_columns.is_empty() {
610        return code;
611    }
612
613    let set_clause = update_columns
614        .iter()
615        .map(|c| format!("`{}` = ?", c.name))
616        .collect::<Vec<_>>()
617        .join(", ");
618
619    let where_clause = build_where_clause(&pk.columns);
620
621    // Bind update columns first, then PK columns
622    let bind_fields: Vec<String> = update_columns
623        .iter()
624        .map(|c| {
625            let field = escape_field_name(&c.name);
626            let rust_type = TypeResolver::resolve(c, &table.name);
627            if rust_type.is_copy() {
628                format!("        .bind(entity.{})", field)
629            } else {
630                format!("        .bind(&entity.{})", field)
631            }
632        })
633        .chain(pk.columns.iter().map(|c| {
634            let field = escape_field_name(c);
635            let col = column_map.get(c.as_str()).unwrap();
636            let rust_type = TypeResolver::resolve(col, &table.name);
637            if rust_type.is_copy() {
638                format!("        .bind(entity.{})", field)
639            } else {
640                format!("        .bind(&entity.{})", field)
641            }
642        }))
643        .collect();
644
645    // update_by_bean (using entity)
646    code.push_str(&format!(
647        r#"/// Update a record by primary key
648pub async fn update<P: Pool>(pool: &P, entity: &{struct_name}) -> Result<u64> {{
649    Query::new("UPDATE `{table_name}` SET {set_clause} WHERE {where_clause}")
650{bind_fields}
651        .execute(pool)
652        .await
653        .map(|r| r.rows_affected)
654}}
655
656"#,
657        struct_name = struct_name,
658        table_name = table.name,
659        set_clause = set_clause,
660        where_clause = where_clause,
661        bind_fields = bind_fields.join("\n"),
662    ));
663
664    code
665}
666
667/// Generate update_plain method with individual parameters (update_by_<pk>)
668fn generate_update_plain_method(
669    table: &TableMetadata,
670    column_map: &HashMap<&str, &ColumnMetadata>,
671) -> String {
672    let pk = table.primary_key.as_ref().unwrap();
673
674    // Get non-PK columns for SET clause
675    let update_columns: Vec<&ColumnMetadata> = table
676        .columns
677        .iter()
678        .filter(|c| !pk.columns.contains(&c.name))
679        .collect();
680
681    if update_columns.is_empty() {
682        return String::new();
683    }
684
685    // Build method name based on PK columns
686    let method_name = generate_update_by_method_name(&pk.columns);
687
688    // Build params: PK columns first, then update columns
689    let pk_params = build_params(&pk.columns, column_map, &table.name);
690    let update_column_names: Vec<String> = update_columns.iter().map(|c| c.name.clone()).collect();
691    let update_params = build_params(&update_column_names, column_map, &table.name);
692    let all_params = format!("{}, {}", pk_params, update_params);
693
694    let set_clause = update_columns
695        .iter()
696        .map(|c| format!("`{}` = ?", c.name))
697        .collect::<Vec<_>>()
698        .join(", ");
699
700    let where_clause = build_where_clause(&pk.columns);
701
702    // Bind update columns first (for SET), then PK columns (for WHERE)
703    let bind_section_update = generate_bind_section(&update_column_names);
704    let bind_section_pk = generate_bind_section(&pk.columns);
705
706    format!(
707        r#"/// Update a record by primary key with individual parameters
708#[allow(clippy::too_many_arguments)]
709pub async fn {method_name}<P: Pool>(pool: &P, {all_params}) -> Result<u64> {{
710    Query::new("UPDATE `{table_name}` SET {set_clause} WHERE {where_clause}")
711{bind_section_update}
712{bind_section_pk}
713        .execute(pool)
714        .await
715        .map(|r| r.rows_affected)
716}}
717
718"#,
719        method_name = method_name,
720        all_params = all_params,
721        table_name = table.name,
722        set_clause = set_clause,
723        where_clause = where_clause,
724        bind_section_update = bind_section_update,
725        bind_section_pk = bind_section_pk,
726    )
727}
728
729/// Collect method signatures with priority-based deduplication
730fn collect_method_signatures(table: &TableMetadata) -> HashMap<Vec<String>, MethodSignature> {
731    let mut signatures: HashMap<Vec<String>, MethodSignature> = HashMap::new();
732
733    // Priority 1: Primary key
734    if let Some(pk) = &table.primary_key {
735        let sig = MethodSignature::new(
736            pk.columns.clone(),
737            PRIORITY_PRIMARY_KEY,
738            true,
739            "PRIMARY_KEY",
740        );
741        signatures.insert(pk.columns.clone(), sig);
742    }
743
744    // Priority 2 & 3: Indexes
745    for index in &table.indexes {
746        let priority = if index.unique {
747            PRIORITY_UNIQUE_INDEX
748        } else {
749            PRIORITY_NON_UNIQUE_INDEX
750        };
751        let source = if index.unique {
752            "UNIQUE_INDEX"
753        } else {
754            "NON_UNIQUE_INDEX"
755        };
756        let sig = MethodSignature::new(index.columns.clone(), priority, index.unique, source);
757
758        // Only add if no higher priority signature exists
759        if let Some(existing) = signatures.get(&index.columns) {
760            if sig.priority < existing.priority {
761                signatures.insert(index.columns.clone(), sig);
762            }
763        } else {
764            signatures.insert(index.columns.clone(), sig);
765        }
766    }
767
768    // Priority 4: Foreign keys
769    for fk in &table.foreign_keys {
770        let columns = vec![fk.column_name.clone()];
771        let sig = MethodSignature::new(columns.clone(), PRIORITY_FOREIGN_KEY, false, "FOREIGN_KEY");
772
773        signatures.entry(columns).or_insert(sig);
774    }
775
776    signatures
777}
778
779/// Generate a find_by method for an index/FK
780fn generate_find_by_method(
781    table: &TableMetadata,
782    sig: &MethodSignature,
783    column_map: &HashMap<&str, &ColumnMetadata>,
784    struct_name: &str,
785    select_columns: &str,
786) -> String {
787    let params = build_params(&sig.columns, column_map, &table.name);
788
789    let (return_type, fetch_method) = if sig.is_unique {
790        (format!("Option<{}>", struct_name), "fetch_optional")
791    } else {
792        (format!("Vec<{}>", struct_name), "fetch_all")
793    };
794
795    let return_desc = if sig.is_unique {
796        "Option (unique)"
797    } else {
798        "Vec (non-unique)"
799    };
800
801    // Check if any column is nullable - if so, we need dynamic SQL for IS NULL handling
802    let has_nullable = sig.columns.iter().any(|c| {
803        column_map
804            .get(c.as_str())
805            .map(|col| col.nullable)
806            .unwrap_or(false)
807    });
808
809    if has_nullable {
810        // Generate dynamic query that handles NULL values properly
811        generate_find_by_method_nullable(
812            table,
813            sig,
814            column_map,
815            select_columns,
816            &params,
817            &return_type,
818            fetch_method,
819            return_desc,
820        )
821    } else {
822        // Use static query for non-nullable columns
823        let where_clause = build_where_clause(&sig.columns);
824        let bind_section = generate_bind_section(&sig.columns);
825
826        format!(
827            r#"/// Find by {source}: returns {return_desc}
828pub async fn {method_name}<P: Pool>(pool: &P, {params}) -> Result<{return_type}> {{
829    Query::new("SELECT {select_columns} FROM `{table_name}` WHERE {where_clause}")
830{bind_section}
831        .{fetch_method}(pool)
832        .await
833}}
834
835"#,
836            source = sig.source.to_lowercase().replace('_', " "),
837            return_desc = return_desc,
838            method_name = sig.method_name,
839            params = params,
840            return_type = return_type,
841            select_columns = select_columns,
842            table_name = table.name,
843            where_clause = where_clause,
844            bind_section = bind_section,
845            fetch_method = fetch_method,
846        )
847    }
848}
849
850/// Generate a find_by method that handles nullable columns with IS NULL
851#[allow(clippy::too_many_arguments)]
852fn generate_find_by_method_nullable(
853    table: &TableMetadata,
854    sig: &MethodSignature,
855    column_map: &HashMap<&str, &ColumnMetadata>,
856    select_columns: &str,
857    params: &str,
858    return_type: &str,
859    fetch_method: &str,
860    return_desc: &str,
861) -> String {
862    // Build the where clause conditions and bind logic
863    let mut where_parts = Vec::new();
864    let mut bind_parts = Vec::new();
865
866    for col in &sig.columns {
867        let col_meta = column_map.get(col.as_str()).unwrap();
868        let field_name = escape_field_name(col);
869
870        if col_meta.nullable {
871            // For nullable columns, check if value is None and use IS NULL
872            where_parts.push(format!(
873                r#"if {field}.is_some() {{ "`{col}` = ?" }} else {{ "`{col}` IS NULL" }}"#,
874                field = field_name,
875                col = col,
876            ));
877            bind_parts.push(format!(
878                r#"if let Some(v) = {field}.as_ref() {{ query = query.bind(v); }}"#,
879                field = field_name,
880            ));
881        } else {
882            where_parts.push(format!(r#""`{}` = ?""#, col));
883            bind_parts.push(format!(r#"query = query.bind({});"#, field_name,));
884        }
885    }
886
887    let where_expr = if where_parts.len() == 1 {
888        where_parts[0].clone()
889    } else {
890        // Join with " AND "
891        let parts = where_parts
892            .iter()
893            .map(|p| format!("({})", p))
894            .collect::<Vec<_>>()
895            .join(", ");
896        format!("vec![{}].join(\" AND \")", parts)
897    };
898
899    let bind_code = bind_parts.join("\n        ");
900
901    format!(
902        r#"/// Find by {source}: returns {return_desc}
903pub async fn {method_name}<P: Pool>(pool: &P, {params}) -> Result<{return_type}> {{
904    let where_clause = {where_expr};
905    let sql = format!("SELECT {select_columns} FROM `{table_name}` WHERE {{}}", where_clause);
906    let mut query = rdbi::DynamicQuery::new(sql);
907    {bind_code}
908    query.{fetch_method}(pool).await
909}}
910
911"#,
912        source = sig.source.to_lowercase().replace('_', " "),
913        return_desc = return_desc,
914        method_name = sig.method_name,
915        params = params,
916        return_type = return_type,
917        select_columns = select_columns,
918        table_name = table.name,
919        where_expr = where_expr,
920        bind_code = bind_code,
921        fetch_method = fetch_method,
922    )
923}
924
925/// Generate list-based findBy methods for single-column indexes
926fn generate_find_by_list_methods(
927    table: &TableMetadata,
928    column_map: &HashMap<&str, &ColumnMetadata>,
929    struct_name: &str,
930    select_columns: &str,
931) -> String {
932    let mut code = String::new();
933    let mut processed: HashSet<String> = HashSet::new();
934
935    // Primary key (if single column)
936    if let Some(pk) = &table.primary_key {
937        if pk.columns.len() == 1 {
938            let col = &pk.columns[0];
939            code.push_str(&generate_single_find_by_list(
940                table,
941                col,
942                column_map,
943                struct_name,
944                select_columns,
945            ));
946            processed.insert(col.clone());
947        }
948    }
949
950    // Single-column indexes
951    for index in &table.indexes {
952        if index.columns.len() == 1 {
953            let col = &index.columns[0];
954            if !processed.contains(col) {
955                code.push_str(&generate_single_find_by_list(
956                    table,
957                    col,
958                    column_map,
959                    struct_name,
960                    select_columns,
961                ));
962                processed.insert(col.clone());
963            }
964        }
965    }
966
967    code
968}
969
970/// Generate a single find_by_<column>s method (using IN clause)
971fn generate_single_find_by_list(
972    table: &TableMetadata,
973    column_name: &str,
974    column_map: &HashMap<&str, &ColumnMetadata>,
975    struct_name: &str,
976    select_columns: &str,
977) -> String {
978    let method_name = generate_find_by_list_method_name(column_name);
979    let param_name = pluralize(&escape_field_name(column_name));
980    let column = column_map.get(column_name).unwrap();
981    let rust_type = TypeResolver::resolve(column, &table.name);
982
983    // Get the inner type (unwrap Option if nullable)
984    let inner_type = rust_type.inner_type().to_type_string();
985
986    let column_name_plural = pluralize(column_name);
987    format!(
988        r#"/// Find by list of {column_name_plural} (IN clause)
989pub async fn {method_name}<P: Pool>(pool: &P, {param_name}: &[{inner_type}]) -> Result<Vec<{struct_name}>> {{
990    if {param_name}.is_empty() {{
991        return Ok(Vec::new());
992    }}
993    let placeholders = {param_name}.iter().map(|_| "?").collect::<Vec<_>>().join(",");
994    let query = format!(
995        "SELECT {select_columns} FROM `{table_name}` WHERE `{column_name}` IN ({{}})",
996        placeholders
997    );
998    rdbi::DynamicQuery::new(query)
999        .bind_all({param_name})
1000        .fetch_all(pool)
1001        .await
1002}}
1003
1004"#,
1005        column_name_plural = column_name_plural,
1006        column_name = column_name,
1007        method_name = method_name,
1008        param_name = param_name,
1009        inner_type = inner_type,
1010        struct_name = struct_name,
1011        select_columns = select_columns,
1012        table_name = table.name,
1013    )
1014}
1015
1016/// Generate composite enum list methods for multi-column indexes with enum columns
1017/// Example: find_by_user_id_and_device_types(user_id, &[DeviceType])
1018fn generate_composite_enum_list_methods(
1019    table: &TableMetadata,
1020    column_map: &HashMap<&str, &ColumnMetadata>,
1021    struct_name: &str,
1022    select_columns: &str,
1023) -> String {
1024    let mut code = String::new();
1025
1026    for index in &table.indexes {
1027        // Skip single-column indexes (handled by generate_find_by_list_methods)
1028        if index.columns.len() <= 1 {
1029            continue;
1030        }
1031
1032        // Identify which columns are enums
1033        let enum_columns: HashSet<&str> = index
1034            .columns
1035            .iter()
1036            .filter(|col_name| {
1037                column_map
1038                    .get(col_name.as_str())
1039                    .map(|col| col.is_enum())
1040                    .unwrap_or(false)
1041            })
1042            .map(|s| s.as_str())
1043            .collect();
1044
1045        // Skip if no enum columns
1046        if enum_columns.is_empty() {
1047            continue;
1048        }
1049
1050        // Skip if first column is enum (for optimal index usage, equality should be on leading column)
1051        let first_column = &index.columns[0];
1052        if enum_columns.contains(first_column.as_str()) {
1053            continue;
1054        }
1055
1056        code.push_str(&generate_composite_enum_list_method(
1057            table,
1058            &index.columns,
1059            &enum_columns,
1060            column_map,
1061            struct_name,
1062            select_columns,
1063        ));
1064    }
1065
1066    code
1067}
1068
1069/// Generate a single composite enum list method
1070fn generate_composite_enum_list_method(
1071    table: &TableMetadata,
1072    columns: &[String],
1073    enum_columns: &HashSet<&str>,
1074    column_map: &HashMap<&str, &ColumnMetadata>,
1075    struct_name: &str,
1076    select_columns: &str,
1077) -> String {
1078    // Build method name: pluralize enum column names
1079    let method_name = generate_composite_enum_method_name(columns, enum_columns);
1080
1081    // Build params and WHERE clause parts
1082    let mut params_parts = Vec::new();
1083
1084    for col_name in columns {
1085        let col = column_map.get(col_name.as_str()).unwrap();
1086        let rust_type = TypeResolver::resolve(col, &table.name);
1087        let is_enum = enum_columns.contains(col_name.as_str());
1088
1089        if is_enum {
1090            // Enum column uses list parameter
1091            let param_name = pluralize(&escape_field_name(col_name));
1092            let inner_type = rust_type.inner_type().to_type_string();
1093            params_parts.push(format!("{}: &[{}]", param_name, inner_type));
1094        } else {
1095            // Non-enum column uses single value
1096            let param_name = escape_field_name(col_name);
1097            let param_type = rust_type.to_param_type_string();
1098            params_parts.push(format!("{}: {}", param_name, param_type));
1099        }
1100    }
1101
1102    let params = params_parts.join(", ");
1103
1104    // Build WHERE clause with proper placeholders
1105    let where_clause_static: Vec<String> = columns
1106        .iter()
1107        .map(|col_name| {
1108            if enum_columns.contains(col_name.as_str()) {
1109                format!("`{}` IN ({{}})", col_name) // placeholder for IN clause
1110            } else {
1111                format!("`{}` = ?", col_name)
1112            }
1113        })
1114        .collect();
1115
1116    // Build the bind section
1117    let mut bind_code = String::new();
1118
1119    // First bind non-enum (single value) columns
1120    for col_name in columns {
1121        if !enum_columns.contains(col_name.as_str()) {
1122            let param_name = escape_field_name(col_name);
1123            bind_code.push_str(&format!("        .bind({})\n", param_name));
1124        }
1125    }
1126
1127    // Then bind enum (list) columns
1128    for col_name in columns {
1129        if enum_columns.contains(col_name.as_str()) {
1130            let param_name = pluralize(&escape_field_name(col_name));
1131            bind_code.push_str(&format!("        .bind_all({})\n", param_name));
1132        }
1133    }
1134
1135    // Build the column name description for doc comment
1136    let column_desc: Vec<String> = columns
1137        .iter()
1138        .map(|col| {
1139            if enum_columns.contains(col.as_str()) {
1140                pluralize(col)
1141            } else {
1142                col.clone()
1143            }
1144        })
1145        .collect();
1146
1147    // Build dynamic WHERE clause construction
1148    let enum_col_names: Vec<&str> = columns
1149        .iter()
1150        .filter(|c| enum_columns.contains(c.as_str()))
1151        .map(|s| s.as_str())
1152        .collect();
1153
1154    // Generate the IN clause placeholders dynamically
1155    let in_clause_builders: Vec<String> = enum_col_names
1156        .iter()
1157        .map(|col| {
1158            let param_name = pluralize(&escape_field_name(col));
1159            format!(
1160                "{param_name}.iter().map(|_| \"?\").collect::<Vec<_>>().join(\",\")",
1161                param_name = param_name
1162            )
1163        })
1164        .collect();
1165
1166    // Build format args for the WHERE clause
1167    let format_args = in_clause_builders.join(", ");
1168
1169    format!(
1170        r#"/// Find by {column_desc} (composite index with IN clause for enum columns)
1171pub async fn {method_name}<P: Pool>(pool: &P, {params}) -> Result<Vec<{struct_name}>> {{
1172    // Check for empty enum lists
1173{empty_checks}
1174    // Build IN clause placeholders for enum columns
1175    let where_clause = format!("{where_template}", {format_args});
1176    let query = format!(
1177        "SELECT {select_columns} FROM `{table_name}` WHERE {{}}",
1178        where_clause
1179    );
1180    rdbi::DynamicQuery::new(query)
1181{bind_code}        .fetch_all(pool)
1182        .await
1183}}
1184
1185"#,
1186        column_desc = column_desc.join(" and "),
1187        method_name = method_name,
1188        params = params,
1189        struct_name = struct_name,
1190        select_columns = select_columns,
1191        table_name = table.name,
1192        where_template = where_clause_static.join(" AND "),
1193        format_args = format_args,
1194        bind_code = bind_code,
1195        empty_checks = generate_empty_checks(columns, enum_columns),
1196    )
1197}
1198
1199/// Generate empty checks for enum list parameters
1200fn generate_empty_checks(columns: &[String], enum_columns: &HashSet<&str>) -> String {
1201    let mut checks = String::new();
1202    for col_name in columns {
1203        if enum_columns.contains(col_name.as_str()) {
1204            let param_name = pluralize(&escape_field_name(col_name));
1205            checks.push_str(&format!(
1206                "    if {}.is_empty() {{ return Ok(Vec::new()); }}\n",
1207                param_name
1208            ));
1209        }
1210    }
1211    checks
1212}
1213
1214/// Generate method name for composite enum queries
1215/// Example: ["user_id", "device_type"] with enum_columns={"device_type"} -> "find_by_user_id_and_device_types"
1216fn generate_composite_enum_method_name(columns: &[String], enum_columns: &HashSet<&str>) -> String {
1217    let mut parts = Vec::new();
1218    for col in columns {
1219        if enum_columns.contains(col.as_str()) {
1220            parts.push(pluralize(col));
1221        } else {
1222            parts.push(col.clone());
1223        }
1224    }
1225    generate_find_by_method_name(&parts)
1226}
1227
1228/// Generate pagination methods (find_all_paginated, get_paginated_result)
1229fn generate_pagination_methods(
1230    table: &TableMetadata,
1231    struct_name: &str,
1232    select_columns: &str,
1233    models_module: &str,
1234) -> String {
1235    let sort_by_enum = format!("{}SortBy", struct_name);
1236
1237    format!(
1238        r#"/// Find all records with pagination and sorting
1239pub async fn find_all_paginated<P: Pool>(
1240    pool: &P,
1241    limit: i32,
1242    offset: i32,
1243    sort_by: crate::{models_module}::{sort_by_enum},
1244    sort_dir: crate::{models_module}::SortDirection,
1245) -> Result<Vec<{struct_name}>> {{
1246    let order_clause = format!("{{}} {{}}", sort_by.as_sql(), sort_dir.as_sql());
1247    let query = format!(
1248        "SELECT {select_columns} FROM `{table_name}` ORDER BY {{}} LIMIT ? OFFSET ?",
1249        order_clause
1250    );
1251    rdbi::DynamicQuery::new(query)
1252        .bind(limit)
1253        .bind(offset)
1254        .fetch_all(pool)
1255        .await
1256}}
1257
1258/// Get paginated result with total count
1259pub async fn get_paginated_result<P: Pool>(
1260    pool: &P,
1261    page_size: i32,
1262    current_page: i32,
1263    sort_by: crate::{models_module}::{sort_by_enum},
1264    sort_dir: crate::{models_module}::SortDirection,
1265) -> Result<crate::{models_module}::PaginatedResult<{struct_name}>> {{
1266    let page_size = page_size.max(1);
1267    let current_page = current_page.max(1);
1268    let offset = (current_page - 1) * page_size;
1269
1270    let total_count = count_all(pool).await?;
1271    let items = find_all_paginated(pool, page_size, offset, sort_by, sort_dir).await?;
1272
1273    Ok(crate::{models_module}::PaginatedResult::new(
1274        items,
1275        total_count,
1276        current_page,
1277        page_size,
1278    ))
1279}}
1280
1281"#,
1282        struct_name = struct_name,
1283        select_columns = select_columns,
1284        table_name = table.name,
1285        models_module = models_module,
1286        sort_by_enum = sort_by_enum,
1287    )
1288}
1289
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::{IndexMetadata, PrimaryKey};

    /// Build a non-nullable, signed, non-auto-increment column with no default,
    /// enum values or comment.
    fn column_def(name: &str, data_type: &str) -> ColumnMetadata {
        ColumnMetadata {
            name: name.to_string(),
            data_type: data_type.to_string(),
            nullable: false,
            default_value: None,
            is_auto_increment: false,
            is_unsigned: false,
            enum_values: None,
            comment: None,
        }
    }

    /// Build an index over exactly one column.
    fn single_column_index(name: &str, column: &str, unique: bool) -> IndexMetadata {
        IndexMetadata {
            name: name.to_string(),
            columns: vec![column.to_string()],
            unique,
        }
    }

    /// `users` fixture: auto-increment `id` primary key, unique `email`,
    /// and a non-unique enum `status`.
    fn make_table() -> TableMetadata {
        let id = ColumnMetadata {
            is_auto_increment: true,
            ..column_def("id", "BIGINT")
        };
        let email = column_def("email", "VARCHAR(255)");
        let status = ColumnMetadata {
            enum_values: Some(vec!["ACTIVE".to_string(), "INACTIVE".to_string()]),
            ..column_def("status", "ENUM")
        };

        TableMetadata {
            name: "users".to_string(),
            comment: None,
            columns: vec![id, email, status],
            indexes: vec![
                single_column_index("email_unique", "email", true),
                single_column_index("idx_status", "status", false),
            ],
            foreign_keys: vec![],
            primary_key: Some(PrimaryKey {
                columns: vec!["id".to_string()],
            }),
        }
    }

    #[test]
    fn test_collect_method_signatures() {
        let table = make_table();
        let sigs = collect_method_signatures(&table);

        // One signature per key: id (PK), email (unique), status (non-unique).
        assert_eq!(sigs.len(), 3);

        let expected = [
            ("id", true, PRIORITY_PRIMARY_KEY),
            ("email", true, PRIORITY_UNIQUE_INDEX),
            ("status", false, PRIORITY_NON_UNIQUE_INDEX),
        ];
        for &(column, is_unique, priority) in expected.iter() {
            let sig = sigs
                .get(&vec![column.to_string()])
                .unwrap_or_else(|| panic!("no signature for `{}`", column));
            assert_eq!(sig.is_unique, is_unique, "uniqueness of `{}`", column);
            assert_eq!(sig.priority, priority, "priority of `{}`", column);
        }
    }

    #[test]
    fn test_build_select_columns() {
        let table = make_table();
        let cols = build_select_columns(&table);
        for quoted in ["`id`", "`email`", "`status`"] {
            assert!(cols.contains(quoted), "select list is missing {}", quoted);
        }
    }

    #[test]
    fn test_build_where_clause() {
        let single = ["id".to_string()];
        assert_eq!(build_where_clause(&single), "`id` = ?");

        let composite = ["user_id".to_string(), "role_id".to_string()];
        assert_eq!(
            build_where_clause(&composite),
            "`user_id` = ? AND `role_id` = ?"
        );
    }

    #[test]
    fn test_generate_upsert_method() {
        let table = make_table();
        let code = generate_upsert_method(&table, "Users");

        assert!(code.contains("pub async fn upsert"));
        assert!(code.contains("ON DUPLICATE KEY UPDATE"));
        // The primary key must never be reassigned on conflict.
        assert!(!code.contains("`id` = VALUES(`id`)"));
        // Every non-PK column is refreshed from the incoming row.
        for column in ["email", "status"] {
            let assignment = format!("`{}` = VALUES(`{}`)", column, column);
            assert!(code.contains(&assignment), "missing update for `{}`", column);
        }
    }

    #[test]
    fn test_generate_upsert_method_no_pk() {
        let mut table = make_table();
        table.primary_key = None;
        table.indexes.clear();

        // Without a PK or unique index there is no conflict target, so nothing is emitted.
        assert!(generate_upsert_method(&table, "Users").is_empty());
    }

    #[test]
    fn test_generate_insert_all_method() {
        let table = make_table();
        let code = generate_insert_all_method(&table, "Users");

        assert!(code.contains("pub async fn insert_all"));
        assert!(code.contains("rdbi::BatchInsert::new"));
    }

    #[test]
    fn test_generate_pagination_methods() {
        let table = make_table();
        let select_columns = build_select_columns(&table);
        let code = generate_pagination_methods(&table, "Users", &select_columns, "models");

        let expected_fragments = [
            "pub async fn find_all_paginated",
            "limit: i32",
            "offset: i32",
            "UsersSortBy",
            "SortDirection",
            "pub async fn get_paginated_result",
            "PaginatedResult<Users>",
        ];
        for fragment in expected_fragments {
            assert!(code.contains(fragment), "generated code lacks `{}`", fragment);
        }
    }
}
1447}