Skip to main content

rdbi_codegen/codegen/
struct_generator.rs

1//! Struct generator - generates Rust structs from table metadata
2
3use crate::error::Result;
4use std::collections::HashSet;
5use std::fs;
6use std::path::Path;
7use tracing::debug;
8
9use crate::config::CodegenConfig;
10use crate::parser::{ColumnMetadata, TableMetadata};
11
12use super::naming::{escape_field_name, to_enum_name, to_enum_variant, to_struct_name};
13use super::type_resolver::TypeResolver;
14
15/// Generate struct files for all tables
16pub fn generate_structs(tables: &[TableMetadata], config: &CodegenConfig) -> Result<()> {
17    let output_dir = &config.output_structs_dir;
18    fs::create_dir_all(output_dir)?;
19
20    // Generate mod.rs with shared pagination types
21    let mut mod_content = String::new();
22    mod_content.push_str("// Generated model structs\n\n");
23
24    for table in tables {
25        let file_name = heck::AsSnakeCase(&table.name).to_string();
26        mod_content.push_str(&format!("mod {};\n", file_name));
27        mod_content.push_str(&format!("pub use {}::*;\n", file_name));
28    }
29
30    // Add shared pagination types
31    mod_content.push('\n');
32    mod_content.push_str(&generate_shared_pagination_types());
33
34    let mod_path = output_dir.join("mod.rs");
35    fs::write(&mod_path, mod_content)?;
36    super::format_file(&mod_path);
37
38    // Generate each struct file
39    for table in tables {
40        generate_struct_file(table, output_dir)?;
41    }
42
43    Ok(())
44}
45
/// Generate shared pagination types (`SortDirection`, `PaginatedResult`).
///
/// The returned source is appended verbatim to the generated `mod.rs`.
/// `PaginatedResult::new` computes the page count with integer ceiling
/// division, and reports zero pages when `page_size` or `total_count` is not
/// positive. It therefore never divides by zero, and it never reports
/// `i32::MAX` pages with `has_next == true`, which would make callers
/// paginate forever. Only prelude items are used so the output compiles on
/// every Rust edition.
fn generate_shared_pagination_types() -> String {
    r#"/// Sort direction for pagination
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SortDirection {
    Asc,
    Desc,
}

impl SortDirection {
    pub fn as_sql(&self) -> &'static str {
        match self {
            Self::Asc => "ASC",
            Self::Desc => "DESC",
        }
    }
}

/// Paginated result container
#[derive(Debug, Clone)]
pub struct PaginatedResult<T> {
    pub items: Vec<T>,
    pub total_count: i64,
    pub current_page: i32,
    pub total_pages: i32,
    pub page_size: i32,
    pub has_next: bool,
}

impl<T> PaginatedResult<T> {
    /// Build a page of results.
    ///
    /// `total_pages` is `total_count / page_size` rounded up (saturating at
    /// `i32::MAX`); a non-positive `page_size` or `total_count` yields zero
    /// pages, so `has_next` is `false`.
    pub fn new(
        items: Vec<T>,
        total_count: i64,
        current_page: i32,
        page_size: i32,
    ) -> Self {
        let total_pages = if page_size <= 0 || total_count <= 0 {
            0
        } else {
            let size = i64::from(page_size);
            let pages = total_count / size + i64::from(total_count % size != 0);
            pages.min(i64::from(i32::MAX)) as i32
        };
        let has_next = current_page < total_pages;
        Self {
            items,
            total_count,
            current_page,
            total_pages,
            page_size,
            has_next,
        }
    }
}
"#
    .to_string()
}
97
98/// Generate a single struct file for a table
99fn generate_struct_file(table: &TableMetadata, output_dir: &Path) -> Result<()> {
100    let struct_name = to_struct_name(&table.name);
101    let file_name = format!("{}.rs", heck::AsSnakeCase(&table.name));
102    debug!("Generating struct {} -> {}", struct_name, file_name);
103
104    let mut code = String::new();
105
106    // Collect enum columns
107    let mut enum_columns: Vec<&ColumnMetadata> = Vec::new();
108    for col in &table.columns {
109        if col.is_enum() {
110            enum_columns.push(col);
111        }
112    }
113
114    // Generate imports (serde and rdbi)
115    code.push_str("use serde::{Deserialize, Serialize};\n");
116
117    code.push('\n');
118
119    // Generate enum types first
120    for col in &enum_columns {
121        if let Some(values) = &col.enum_values {
122            code.push_str(&generate_enum(&table.name, col, values));
123            code.push('\n');
124        }
125    }
126
127    // Generate struct documentation
128    code.push_str(&format!("/// Database table: `{}`\n", table.name));
129    if let Some(comment) = &table.comment {
130        if !comment.is_empty() {
131            code.push_str(&format!("///\n/// {}\n", comment));
132        }
133    }
134
135    // Generate struct with derives (rdbi::FromRow and rdbi::ToParams)
136    code.push_str("#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, rdbi::FromRow, rdbi::ToParams)]\n");
137    code.push_str(&format!("pub struct {} {{\n", struct_name));
138
139    // Generate fields
140    for col in &table.columns {
141        let field_name = escape_field_name(&col.name);
142        let rust_type = TypeResolver::resolve(col, &table.name);
143
144        // Add field documentation
145        code.push_str(&format!("    /// Column: `{}`", col.name));
146
147        // Add index info
148        let index_info = get_index_info(table, &col.name);
149        if !index_info.is_empty() {
150            code.push_str(&format!(" ({})", index_info.join(", ")));
151        }
152
153        if let Some(comment) = &col.comment {
154            if !comment.is_empty() {
155                code.push_str(&format!(" - {}", comment));
156            }
157        }
158        code.push('\n');
159
160        // Add rdbi attributes
161        let mut attrs = Vec::new();
162
163        // Add rename attribute if field name differs from column name
164        if field_name != col.name {
165            attrs.push(format!("rename = \"{}\"", col.name));
166        }
167
168        // Add skip_insert for auto-increment columns
169        if col.is_auto_increment {
170            attrs.push("skip_insert".to_string());
171        }
172
173        if !attrs.is_empty() {
174            code.push_str(&format!("    #[rdbi({})]\n", attrs.join(", ")));
175        }
176
177        // Add serde rename attribute if field name differs from column name
178        // This is especially important for raw identifiers (r#type -> "type")
179        if field_name != col.name {
180            code.push_str(&format!("    #[serde(rename = \"{}\")]\n", col.name));
181        }
182
183        code.push_str(&format!(
184            "    pub {}: {},\n",
185            field_name,
186            rust_type.to_type_string()
187        ));
188    }
189
190    code.push_str("}\n");
191
192    // Generate SortBy enum for pagination
193    code.push('\n');
194    code.push_str(&generate_sort_by_enum(table));
195
196    let file_path = output_dir.join(&file_name);
197    fs::write(&file_path, code)?;
198    super::format_file(&file_path);
199    Ok(())
200}
201
202/// Generate SortBy enum for a table (used in pagination)
203fn generate_sort_by_enum(table: &TableMetadata) -> String {
204    let struct_name = to_struct_name(&table.name);
205    let enum_name = format!("{}SortBy", struct_name);
206
207    let mut code = String::new();
208
209    code.push_str(&format!("/// Sort columns for `{}`\n", table.name));
210    code.push_str("#[derive(Debug, Clone, Copy, PartialEq, Eq)]\n");
211    code.push_str(&format!("pub enum {} {{\n", enum_name));
212
213    for col in &table.columns {
214        let variant = heck::AsPascalCase(&col.name).to_string();
215        code.push_str(&format!("    {},\n", variant));
216    }
217
218    code.push_str("}\n\n");
219
220    // Generate as_sql impl
221    code.push_str(&format!("impl {} {{\n", enum_name));
222    code.push_str("    pub fn as_sql(&self) -> &'static str {\n");
223    code.push_str("        match self {\n");
224
225    for col in &table.columns {
226        let variant = heck::AsPascalCase(&col.name).to_string();
227        code.push_str(&format!(
228            "            Self::{} => \"`{}`\",\n",
229            variant, col.name
230        ));
231    }
232
233    code.push_str("        }\n");
234    code.push_str("    }\n");
235    code.push_str("}\n");
236
237    code
238}
239
240/// Generate an enum type for a column
241fn generate_enum(table_name: &str, column: &ColumnMetadata, values: &[String]) -> String {
242    let enum_name = to_enum_name(table_name, &column.name);
243    let mut code = String::new();
244
245    // Add documentation
246    code.push_str(&format!("/// Enum for `{}.{}`\n", table_name, column.name));
247
248    // Generate enum with derives (for rdbi, we need to implement FromValue and ToValue)
249    code.push_str("#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]\n");
250    code.push_str(&format!("pub enum {} {{\n", enum_name));
251
252    // Track used variant names to avoid duplicates
253    let mut used_variants: HashSet<String> = HashSet::new();
254    let mut variant_mappings: Vec<(String, String)> = Vec::new();
255
256    for value in values {
257        let variant = to_enum_variant(value);
258
259        // Handle duplicate variants (shouldn't happen but be safe)
260        let final_variant = if used_variants.contains(&variant) {
261            let mut counter = 2;
262            loop {
263                let new_variant = format!("{}{}", variant, counter);
264                if !used_variants.contains(&new_variant) {
265                    break new_variant;
266                }
267                counter += 1;
268            }
269        } else {
270            variant
271        };
272
273        used_variants.insert(final_variant.clone());
274
275        // Clean the value (remove quotes)
276        let clean_value = value.trim_matches('\'').trim_matches('"');
277
278        // Add serde rename attribute if variant differs from original value
279        if final_variant != clean_value {
280            code.push_str(&format!("    #[serde(rename = \"{}\")]\n", clean_value));
281        }
282
283        code.push_str(&format!("    {},\n", final_variant));
284        variant_mappings.push((final_variant, clean_value.to_string()));
285    }
286
287    code.push_str("}\n\n");
288
289    // Generate FromValue implementation for rdbi
290    code.push_str(&format!("impl rdbi::FromValue for {} {{\n", enum_name));
291    code.push_str("    fn from_value(value: rdbi::Value) -> rdbi::Result<Self> {\n");
292    code.push_str("        match value {\n");
293    code.push_str("            rdbi::Value::String(s) => match s.as_str() {\n");
294    for (variant, db_value) in &variant_mappings {
295        code.push_str(&format!(
296            "                \"{}\" => Ok(Self::{}),\n",
297            db_value, variant
298        ));
299    }
300    code.push_str(&format!(
301        "                _ => Err(rdbi::Error::TypeConversion {{ expected: \"{}\", actual: s }}),\n",
302        enum_name
303    ));
304    code.push_str("            },\n");
305    code.push_str(&format!(
306        "            _ => Err(rdbi::Error::TypeConversion {{ expected: \"{}\", actual: value.type_name().to_string() }}),\n",
307        enum_name
308    ));
309    code.push_str("        }\n");
310    code.push_str("    }\n");
311    code.push_str("}\n\n");
312
313    // Generate ToValue implementation for rdbi
314    code.push_str(&format!("impl rdbi::ToValue for {} {{\n", enum_name));
315    code.push_str("    fn to_value(&self) -> rdbi::Value {\n");
316    code.push_str("        rdbi::Value::String(match self {\n");
317    for (variant, db_value) in &variant_mappings {
318        code.push_str(&format!(
319            "            Self::{} => \"{}\".to_string(),\n",
320            variant, db_value
321        ));
322    }
323    code.push_str("        })\n");
324    code.push_str("    }\n");
325    code.push_str("}\n");
326
327    code
328}
329
330/// Get index information for a column
331fn get_index_info(table: &TableMetadata, column_name: &str) -> Vec<String> {
332    let mut info = Vec::new();
333
334    // Check primary key
335    if let Some(pk) = &table.primary_key {
336        if pk.columns.contains(&column_name.to_string()) {
337            info.push("PRIMARY KEY".to_string());
338        }
339    }
340
341    // Check indexes
342    for index in &table.indexes {
343        if index.columns.contains(&column_name.to_string()) {
344            let label = if index.unique {
345                format!("UNIQUE: {}", index.name)
346            } else {
347                format!("INDEX: {}", index.name)
348            };
349            info.push(label);
350        }
351    }
352
353    info
354}
355
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::{IndexMetadata, PrimaryKey};

    /// Build a `users` table: auto-increment `id` primary key, `username` with
    /// a unique index, and an unindexed `status` ENUM column.
    fn make_table() -> TableMetadata {
        TableMetadata {
            name: "users".to_string(),
            comment: None,
            columns: vec![
                ColumnMetadata {
                    name: "id".to_string(),
                    data_type: "BIGINT".to_string(),
                    nullable: false,
                    default_value: None,
                    is_auto_increment: true,
                    is_unsigned: false,
                    enum_values: None,
                    comment: None,
                },
                ColumnMetadata {
                    name: "username".to_string(),
                    data_type: "VARCHAR(255)".to_string(),
                    nullable: false,
                    default_value: None,
                    is_auto_increment: false,
                    is_unsigned: false,
                    enum_values: None,
                    comment: None,
                },
                ColumnMetadata {
                    name: "status".to_string(),
                    data_type: "ENUM".to_string(),
                    nullable: false,
                    default_value: None,
                    is_auto_increment: false,
                    is_unsigned: false,
                    enum_values: Some(vec![
                        "ACTIVE".to_string(),
                        "INACTIVE".to_string(),
                        "PENDING".to_string(),
                    ]),
                    comment: None,
                },
            ],
            indexes: vec![IndexMetadata {
                name: "idx_username".to_string(),
                columns: vec!["username".to_string()],
                unique: true,
            }],
            foreign_keys: vec![],
            primary_key: Some(PrimaryKey {
                columns: vec!["id".to_string()],
            }),
        }
    }

    #[test]
    fn test_get_index_info() {
        let table = make_table();
        assert_eq!(get_index_info(&table, "id"), vec!["PRIMARY KEY".to_string()]);
        assert_eq!(
            get_index_info(&table, "username"),
            vec!["UNIQUE: idx_username".to_string()]
        );
        // Columns outside every key produce no annotation.
        assert!(get_index_info(&table, "status").is_empty());
    }

    #[test]
    fn test_generate_sort_by_enum() {
        let code = generate_sort_by_enum(&make_table());
        assert!(code.contains("SortBy {"));
        assert!(code.contains("pub fn as_sql(&self) -> &'static str"));
        assert!(code.contains("Self::Id => \"`id`\""));
        assert!(code.contains("Self::Username => \"`username`\""));
        assert!(code.contains("Self::Status => \"`status`\""));
    }

    #[test]
    fn test_generate_enum() {
        let col = ColumnMetadata {
            name: "status".to_string(),
            data_type: "ENUM".to_string(),
            nullable: false,
            default_value: None,
            is_auto_increment: false,
            is_unsigned: false,
            enum_values: Some(vec!["ACTIVE".to_string(), "INACTIVE".to_string()]),
            comment: None,
        };

        let code = generate_enum("users", &col, col.enum_values.as_ref().unwrap());
        assert!(code.contains("pub enum UsersStatus"));
        assert!(code.contains("Active"));
        assert!(code.contains("Inactive"));
        // Check for rdbi trait implementations
        assert!(code.contains("impl rdbi::FromValue for UsersStatus"));
        assert!(code.contains("impl rdbi::ToValue for UsersStatus"));
    }
}