Skip to main content

rdbi_codegen/codegen/
struct_generator.rs

1//! Struct generator - generates Rust structs from table metadata
2
3use crate::error::Result;
4use std::collections::HashSet;
5use std::fs;
6use std::path::Path;
7use tracing::debug;
8
9use crate::config::CodegenConfig;
10use crate::parser::{ColumnMetadata, TableMetadata};
11
12use super::naming::{escape_field_name, to_enum_name, to_enum_variant, to_struct_name};
13use super::type_resolver::TypeResolver;
14
15/// Generate struct files for all tables
16pub fn generate_structs(tables: &[TableMetadata], config: &CodegenConfig) -> Result<()> {
17    let output_dir = &config.output_structs_dir;
18    fs::create_dir_all(output_dir)?;
19
20    // Generate mod.rs with shared pagination types
21    let mut mod_content = String::new();
22    mod_content.push_str("// Generated model structs\n\n");
23
24    for table in tables {
25        let file_name = heck::AsSnakeCase(&table.name).to_string();
26        mod_content.push_str(&format!("mod {};\n", file_name));
27        mod_content.push_str(&format!("pub use {}::*;\n", file_name));
28    }
29
30    // Add shared pagination types
31    mod_content.push('\n');
32    mod_content.push_str(&generate_shared_pagination_types());
33
34    let mod_path = output_dir.join("mod.rs");
35    fs::write(&mod_path, mod_content)?;
36
37    // Generate each struct file
38    for table in tables {
39        generate_struct_file(table, output_dir)?;
40    }
41
42    Ok(())
43}
44
/// Return the Rust source for the pagination helpers shared by every
/// generated model: the `SortDirection` enum (with `as_sql`) and the
/// `PaginatedResult<T>` container (with its `new` constructor).
///
/// The text is a fixed template; nothing is interpolated into it.
fn generate_shared_pagination_types() -> String {
    const PAGINATION_TYPES: &str = r#"/// Sort direction for pagination
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SortDirection {
    Asc,
    Desc,
}

impl SortDirection {
    pub fn as_sql(&self) -> &'static str {
        match self {
            Self::Asc => "ASC",
            Self::Desc => "DESC",
        }
    }
}

/// Paginated result container
#[derive(Debug, Clone)]
pub struct PaginatedResult<T> {
    pub items: Vec<T>,
    pub total_count: i64,
    pub current_page: i32,
    pub total_pages: i32,
    pub page_size: i32,
    pub has_next: bool,
}

impl<T> PaginatedResult<T> {
    pub fn new(
        items: Vec<T>,
        total_count: i64,
        current_page: i32,
        page_size: i32,
    ) -> Self {
        let total_pages = ((total_count as f64) / (page_size as f64)).ceil() as i32;
        let has_next = current_page < total_pages;
        Self {
            items,
            total_count,
            current_page,
            total_pages,
            page_size,
            has_next,
        }
    }
}
"#;

    String::from(PAGINATION_TYPES)
}
96
97/// Generate a single struct file for a table
98fn generate_struct_file(table: &TableMetadata, output_dir: &Path) -> Result<()> {
99    let struct_name = to_struct_name(&table.name);
100    let file_name = format!("{}.rs", heck::AsSnakeCase(&table.name));
101    debug!("Generating struct {} -> {}", struct_name, file_name);
102
103    let mut code = String::new();
104
105    // Collect enum columns
106    let mut enum_columns: Vec<&ColumnMetadata> = Vec::new();
107    for col in &table.columns {
108        if col.is_enum() {
109            enum_columns.push(col);
110        }
111    }
112
113    // Generate imports (serde and rdbi)
114    code.push_str("use serde::{Deserialize, Serialize};\n");
115
116    code.push('\n');
117
118    // Generate enum types first
119    for col in &enum_columns {
120        if let Some(values) = &col.enum_values {
121            code.push_str(&generate_enum(&table.name, col, values));
122            code.push('\n');
123        }
124    }
125
126    // Generate struct documentation
127    code.push_str(&format!("/// Database table: `{}`\n", table.name));
128    if let Some(comment) = &table.comment {
129        if !comment.is_empty() {
130            code.push_str(&format!("///\n/// {}\n", comment));
131        }
132    }
133
134    // Generate struct with derives (rdbi::FromRow and rdbi::ToParams)
135    code.push_str("#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, rdbi::FromRow, rdbi::ToParams)]\n");
136    code.push_str(&format!("pub struct {} {{\n", struct_name));
137
138    // Generate fields
139    for col in &table.columns {
140        let field_name = escape_field_name(&col.name);
141        let rust_type = TypeResolver::resolve(col, &table.name);
142
143        // Add field documentation
144        code.push_str(&format!("    /// Column: `{}`", col.name));
145
146        // Add index info
147        let index_info = get_index_info(table, &col.name);
148        if !index_info.is_empty() {
149            code.push_str(&format!(" ({})", index_info.join(", ")));
150        }
151
152        if let Some(comment) = &col.comment {
153            if !comment.is_empty() {
154                code.push_str(&format!(" - {}", comment));
155            }
156        }
157        code.push('\n');
158
159        // Add rdbi attributes
160        let mut attrs = Vec::new();
161
162        // Add rename attribute if field name differs from column name
163        if field_name != col.name {
164            attrs.push(format!("rename = \"{}\"", col.name));
165        }
166
167        // Add skip_insert for auto-increment columns
168        if col.is_auto_increment {
169            attrs.push("skip_insert".to_string());
170        }
171
172        if !attrs.is_empty() {
173            code.push_str(&format!("    #[rdbi({})]\n", attrs.join(", ")));
174        }
175
176        // Add serde rename attribute if field name differs from column name
177        // This is especially important for raw identifiers (r#type -> "type")
178        if field_name != col.name {
179            code.push_str(&format!("    #[serde(rename = \"{}\")]\n", col.name));
180        }
181
182        code.push_str(&format!(
183            "    pub {}: {},\n",
184            field_name,
185            rust_type.to_type_string()
186        ));
187    }
188
189    code.push_str("}\n");
190
191    // Generate SortBy enum for pagination
192    code.push('\n');
193    code.push_str(&generate_sort_by_enum(table));
194
195    let file_path = output_dir.join(&file_name);
196    fs::write(&file_path, code)?;
197    Ok(())
198}
199
200/// Generate SortBy enum for a table (used in pagination)
201fn generate_sort_by_enum(table: &TableMetadata) -> String {
202    let struct_name = to_struct_name(&table.name);
203    let enum_name = format!("{}SortBy", struct_name);
204
205    let mut code = String::new();
206
207    code.push_str(&format!("/// Sort columns for `{}`\n", table.name));
208    code.push_str("#[derive(Debug, Clone, Copy, PartialEq, Eq)]\n");
209    code.push_str(&format!("pub enum {} {{\n", enum_name));
210
211    for col in &table.columns {
212        let variant = heck::AsPascalCase(&col.name).to_string();
213        code.push_str(&format!("    {},\n", variant));
214    }
215
216    code.push_str("}\n\n");
217
218    // Generate as_sql impl
219    code.push_str(&format!("impl {} {{\n", enum_name));
220    code.push_str("    pub fn as_sql(&self) -> &'static str {\n");
221    code.push_str("        match self {\n");
222
223    for col in &table.columns {
224        let variant = heck::AsPascalCase(&col.name).to_string();
225        code.push_str(&format!(
226            "            Self::{} => \"`{}`\",\n",
227            variant, col.name
228        ));
229    }
230
231    code.push_str("        }\n");
232    code.push_str("    }\n");
233    code.push_str("}\n");
234
235    code
236}
237
238/// Generate an enum type for a column
239fn generate_enum(table_name: &str, column: &ColumnMetadata, values: &[String]) -> String {
240    let enum_name = to_enum_name(table_name, &column.name);
241    let mut code = String::new();
242
243    // Add documentation
244    code.push_str(&format!("/// Enum for `{}.{}`\n", table_name, column.name));
245
246    // Generate enum with derives (for rdbi, we need to implement FromValue and ToValue)
247    code.push_str("#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]\n");
248    code.push_str(&format!("pub enum {} {{\n", enum_name));
249
250    // Track used variant names to avoid duplicates
251    let mut used_variants: HashSet<String> = HashSet::new();
252    let mut variant_mappings: Vec<(String, String)> = Vec::new();
253
254    for value in values {
255        let variant = to_enum_variant(value);
256
257        // Handle duplicate variants (shouldn't happen but be safe)
258        let final_variant = if used_variants.contains(&variant) {
259            let mut counter = 2;
260            loop {
261                let new_variant = format!("{}{}", variant, counter);
262                if !used_variants.contains(&new_variant) {
263                    break new_variant;
264                }
265                counter += 1;
266            }
267        } else {
268            variant
269        };
270
271        used_variants.insert(final_variant.clone());
272
273        // Clean the value (remove quotes)
274        let clean_value = value.trim_matches('\'').trim_matches('"');
275
276        // Add serde rename attribute if variant differs from original value
277        if final_variant != clean_value {
278            code.push_str(&format!("    #[serde(rename = \"{}\")]\n", clean_value));
279        }
280
281        code.push_str(&format!("    {},\n", final_variant));
282        variant_mappings.push((final_variant, clean_value.to_string()));
283    }
284
285    code.push_str("}\n\n");
286
287    // Generate FromValue implementation for rdbi
288    code.push_str(&format!("impl rdbi::FromValue for {} {{\n", enum_name));
289    code.push_str("    fn from_value(value: rdbi::Value) -> rdbi::Result<Self> {\n");
290    code.push_str("        match value {\n");
291    code.push_str("            rdbi::Value::String(s) => match s.as_str() {\n");
292    for (variant, db_value) in &variant_mappings {
293        code.push_str(&format!(
294            "                \"{}\" => Ok(Self::{}),\n",
295            db_value, variant
296        ));
297    }
298    code.push_str(&format!(
299        "                _ => Err(rdbi::Error::TypeConversion {{ expected: \"{}\", actual: s }}),\n",
300        enum_name
301    ));
302    code.push_str("            },\n");
303    code.push_str(&format!(
304        "            _ => Err(rdbi::Error::TypeConversion {{ expected: \"{}\", actual: value.type_name().to_string() }}),\n",
305        enum_name
306    ));
307    code.push_str("        }\n");
308    code.push_str("    }\n");
309    code.push_str("}\n\n");
310
311    // Generate ToValue implementation for rdbi
312    code.push_str(&format!("impl rdbi::ToValue for {} {{\n", enum_name));
313    code.push_str("    fn to_value(&self) -> rdbi::Value {\n");
314    code.push_str("        rdbi::Value::String(match self {\n");
315    for (variant, db_value) in &variant_mappings {
316        code.push_str(&format!(
317            "            Self::{} => \"{}\".to_string(),\n",
318            variant, db_value
319        ));
320    }
321    code.push_str("        })\n");
322    code.push_str("    }\n");
323    code.push_str("}\n");
324
325    code
326}
327
328/// Get index information for a column
329fn get_index_info(table: &TableMetadata, column_name: &str) -> Vec<String> {
330    let mut info = Vec::new();
331
332    // Check primary key
333    if let Some(pk) = &table.primary_key {
334        if pk.columns.contains(&column_name.to_string()) {
335            info.push("PRIMARY KEY".to_string());
336        }
337    }
338
339    // Check indexes
340    for index in &table.indexes {
341        if index.columns.contains(&column_name.to_string()) {
342            let label = if index.unique {
343                format!("UNIQUE: {}", index.name)
344            } else {
345                format!("INDEX: {}", index.name)
346            };
347            info.push(label);
348        }
349    }
350
351    info
352}
353
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::{IndexMetadata, PrimaryKey};

    /// Build a plain column fixture: non-nullable, signed, not
    /// auto-increment, with no default, enum values or comment.
    fn column(name: &str, data_type: &str) -> ColumnMetadata {
        ColumnMetadata {
            name: name.to_string(),
            data_type: data_type.to_string(),
            nullable: false,
            default_value: None,
            is_auto_increment: false,
            is_unsigned: false,
            enum_values: None,
            comment: None,
        }
    }

    /// Build an `ENUM` column fixture with the given allowed values.
    fn enum_column(name: &str, values: &[&str]) -> ColumnMetadata {
        ColumnMetadata {
            enum_values: Some(values.iter().map(|v| v.to_string()).collect()),
            ..column(name, "ENUM")
        }
    }

    /// A `users` table with an auto-increment primary key `id`, a uniquely
    /// indexed `username`, and an enum `status` column.
    fn make_table() -> TableMetadata {
        let id = ColumnMetadata {
            is_auto_increment: true,
            ..column("id", "BIGINT")
        };
        let username = column("username", "VARCHAR(255)");
        let status = enum_column("status", &["ACTIVE", "INACTIVE", "PENDING"]);

        TableMetadata {
            name: "users".to_string(),
            comment: None,
            columns: vec![id, username, status],
            indexes: vec![IndexMetadata {
                name: "idx_username".to_string(),
                columns: vec!["username".to_string()],
                unique: true,
            }],
            foreign_keys: vec![],
            primary_key: Some(PrimaryKey {
                columns: vec!["id".to_string()],
            }),
        }
    }

    #[test]
    fn test_get_index_info() {
        let table = make_table();

        // Primary key column is labelled as such
        let id_labels = get_index_info(&table, "id");
        assert!(id_labels.iter().any(|label| label == "PRIMARY KEY"));

        // Uniquely indexed column gets a UNIQUE label
        let username_labels = get_index_info(&table, "username");
        assert!(username_labels.iter().any(|label| label.contains("UNIQUE")));
    }

    #[test]
    fn test_generate_enum() {
        let col = enum_column("status", &["ACTIVE", "INACTIVE"]);
        let values = col.enum_values.as_ref().unwrap();

        let code = generate_enum("users", &col, values);

        // Enum definition and variants
        assert!(code.contains("pub enum UsersStatus"));
        assert!(code.contains("Active"));
        assert!(code.contains("Inactive"));
        // Check for rdbi trait implementations
        assert!(code.contains("impl rdbi::FromValue for UsersStatus"));
        assert!(code.contains("impl rdbi::ToValue for UsersStatus"));
    }
}