Skip to main content

rdbi_codegen/codegen/
struct_generator.rs

1//! Struct generator - generates Rust structs from table metadata
2
3use crate::error::Result;
4use std::collections::HashSet;
5use std::fs;
6use std::path::Path;
7use tracing::debug;
8
9use crate::config::CodegenConfig;
10use crate::parser::{ColumnMetadata, TableMetadata};
11
12use super::naming::{escape_field_name, to_enum_name, to_enum_variant, to_struct_name};
13use super::type_resolver::TypeResolver;
14
15/// Generate struct files for all tables
16pub fn generate_structs(tables: &[TableMetadata], config: &CodegenConfig) -> Result<()> {
17    let output_dir = &config.output_structs_dir;
18    fs::create_dir_all(output_dir)?;
19
20    // Generate mod.rs with shared pagination types
21    let mut mod_content = String::new();
22    mod_content.push_str("// Generated model structs\n\n");
23
24    for table in tables {
25        let file_name = heck::AsSnakeCase(&table.name).to_string();
26        mod_content.push_str(&format!("mod {};\n", file_name));
27        mod_content.push_str(&format!("pub use {}::*;\n", file_name));
28    }
29
30    // Add shared pagination types
31    mod_content.push('\n');
32    mod_content.push_str(&generate_shared_pagination_types());
33
34    fs::write(output_dir.join("mod.rs"), mod_content)?;
35
36    // Generate each struct file
37    for table in tables {
38        generate_struct_file(table, output_dir)?;
39    }
40
41    Ok(())
42}
43
/// Generate the pagination types shared by every generated table module.
///
/// The returned source defines `SortDirection` (ASC/DESC with `as_sql()`) and
/// `PaginatedResult<T>`; `generate_structs` appends it verbatim to `mod.rs`.
///
/// `PaginatedResult::new` computes `total_pages` with integer ceiling division
/// and treats a non-positive `page_size` as "zero pages". This means
/// `page_size == 0` can never produce a division by zero (which, done in
/// floating point, would saturate `total_pages` to `i32::MAX` and report
/// `has_next == true`).
fn generate_shared_pagination_types() -> String {
    r#"/// Sort direction for pagination
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SortDirection {
    Asc,
    Desc,
}

impl SortDirection {
    pub fn as_sql(&self) -> &'static str {
        match self {
            Self::Asc => "ASC",
            Self::Desc => "DESC",
        }
    }
}

/// Paginated result container
#[derive(Debug, Clone)]
pub struct PaginatedResult<T> {
    pub items: Vec<T>,
    pub total_count: i64,
    pub current_page: i32,
    pub total_pages: i32,
    pub page_size: i32,
    pub has_next: bool,
}

impl<T> PaginatedResult<T> {
    /// Build a page of results.
    ///
    /// `total_pages` is the ceiling of `total_count / page_size`, saturated at
    /// `i32::MAX`; a non-positive `page_size` or negative `total_count` yields
    /// zero pages.
    pub fn new(
        items: Vec<T>,
        total_count: i64,
        current_page: i32,
        page_size: i32,
    ) -> Self {
        // Integer ceiling division; never divides by zero.
        let total_pages = if page_size > 0 {
            let total = total_count.max(0);
            let size = i64::from(page_size);
            let pages = total / size + i64::from(total % size != 0);
            pages.min(i64::from(i32::MAX)) as i32
        } else {
            0
        };
        let has_next = current_page < total_pages;
        Self {
            items,
            total_count,
            current_page,
            total_pages,
            page_size,
            has_next,
        }
    }
}
"#
    .to_string()
}
95
96/// Generate a single struct file for a table
97fn generate_struct_file(table: &TableMetadata, output_dir: &Path) -> Result<()> {
98    let struct_name = to_struct_name(&table.name);
99    let file_name = format!("{}.rs", heck::AsSnakeCase(&table.name));
100    debug!("Generating struct {} -> {}", struct_name, file_name);
101
102    let mut code = String::new();
103
104    // Collect enum columns
105    let mut enum_columns: Vec<&ColumnMetadata> = Vec::new();
106    for col in &table.columns {
107        if col.is_enum() {
108            enum_columns.push(col);
109        }
110    }
111
112    // Generate imports (serde and rdbi)
113    code.push_str("use serde::{Deserialize, Serialize};\n");
114
115    code.push('\n');
116
117    // Generate enum types first
118    for col in &enum_columns {
119        if let Some(values) = &col.enum_values {
120            code.push_str(&generate_enum(&table.name, col, values));
121            code.push('\n');
122        }
123    }
124
125    // Generate struct documentation
126    code.push_str(&format!("/// Database table: `{}`\n", table.name));
127    if let Some(comment) = &table.comment {
128        if !comment.is_empty() {
129            code.push_str(&format!("///\n/// {}\n", comment));
130        }
131    }
132
133    // Generate struct with derives (rdbi::FromRow and rdbi::ToParams)
134    code.push_str("#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, rdbi::FromRow, rdbi::ToParams)]\n");
135    code.push_str(&format!("pub struct {} {{\n", struct_name));
136
137    // Generate fields
138    for col in &table.columns {
139        let field_name = escape_field_name(&col.name);
140        let rust_type = TypeResolver::resolve(col, &table.name);
141
142        // Add field documentation
143        code.push_str(&format!("    /// Column: `{}`", col.name));
144
145        // Add index info
146        let index_info = get_index_info(table, &col.name);
147        if !index_info.is_empty() {
148            code.push_str(&format!(" ({})", index_info.join(", ")));
149        }
150
151        if let Some(comment) = &col.comment {
152            if !comment.is_empty() {
153                code.push_str(&format!(" - {}", comment));
154            }
155        }
156        code.push('\n');
157
158        // Add rdbi attributes
159        let mut attrs = Vec::new();
160
161        // Add rename attribute if field name differs from column name
162        if field_name != col.name {
163            attrs.push(format!("rename = \"{}\"", col.name));
164        }
165
166        // Add skip_insert for auto-increment columns
167        if col.is_auto_increment {
168            attrs.push("skip_insert".to_string());
169        }
170
171        if !attrs.is_empty() {
172            code.push_str(&format!("    #[rdbi({})]\n", attrs.join(", ")));
173        }
174
175        // Add serde rename attribute if field name differs from column name
176        // This is especially important for raw identifiers (r#type -> "type")
177        if field_name != col.name {
178            code.push_str(&format!("    #[serde(rename = \"{}\")]\n", col.name));
179        }
180
181        code.push_str(&format!(
182            "    pub {}: {},\n",
183            field_name,
184            rust_type.to_type_string()
185        ));
186    }
187
188    code.push_str("}\n");
189
190    // Generate SortBy enum for pagination
191    code.push('\n');
192    code.push_str(&generate_sort_by_enum(table));
193
194    fs::write(output_dir.join(&file_name), code)?;
195    Ok(())
196}
197
198/// Generate SortBy enum for a table (used in pagination)
199fn generate_sort_by_enum(table: &TableMetadata) -> String {
200    let struct_name = to_struct_name(&table.name);
201    let enum_name = format!("{}SortBy", struct_name);
202
203    let mut code = String::new();
204
205    code.push_str(&format!("/// Sort columns for `{}`\n", table.name));
206    code.push_str("#[derive(Debug, Clone, Copy, PartialEq, Eq)]\n");
207    code.push_str(&format!("pub enum {} {{\n", enum_name));
208
209    for col in &table.columns {
210        let variant = heck::AsPascalCase(&col.name).to_string();
211        code.push_str(&format!("    {},\n", variant));
212    }
213
214    code.push_str("}\n\n");
215
216    // Generate as_sql impl
217    code.push_str(&format!("impl {} {{\n", enum_name));
218    code.push_str("    pub fn as_sql(&self) -> &'static str {\n");
219    code.push_str("        match self {\n");
220
221    for col in &table.columns {
222        let variant = heck::AsPascalCase(&col.name).to_string();
223        code.push_str(&format!(
224            "            Self::{} => \"`{}`\",\n",
225            variant, col.name
226        ));
227    }
228
229    code.push_str("        }\n");
230    code.push_str("    }\n");
231    code.push_str("}\n");
232
233    code
234}
235
236/// Generate an enum type for a column
237fn generate_enum(table_name: &str, column: &ColumnMetadata, values: &[String]) -> String {
238    let enum_name = to_enum_name(table_name, &column.name);
239    let mut code = String::new();
240
241    // Add documentation
242    code.push_str(&format!("/// Enum for `{}.{}`\n", table_name, column.name));
243
244    // Generate enum with derives (for rdbi, we need to implement FromValue and ToValue)
245    code.push_str("#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]\n");
246    code.push_str(&format!("pub enum {} {{\n", enum_name));
247
248    // Track used variant names to avoid duplicates
249    let mut used_variants: HashSet<String> = HashSet::new();
250    let mut variant_mappings: Vec<(String, String)> = Vec::new();
251
252    for value in values {
253        let variant = to_enum_variant(value);
254
255        // Handle duplicate variants (shouldn't happen but be safe)
256        let final_variant = if used_variants.contains(&variant) {
257            let mut counter = 2;
258            loop {
259                let new_variant = format!("{}{}", variant, counter);
260                if !used_variants.contains(&new_variant) {
261                    break new_variant;
262                }
263                counter += 1;
264            }
265        } else {
266            variant
267        };
268
269        used_variants.insert(final_variant.clone());
270
271        // Clean the value (remove quotes)
272        let clean_value = value.trim_matches('\'').trim_matches('"');
273
274        // Add serde rename attribute if variant differs from original value
275        if final_variant != clean_value {
276            code.push_str(&format!("    #[serde(rename = \"{}\")]\n", clean_value));
277        }
278
279        code.push_str(&format!("    {},\n", final_variant));
280        variant_mappings.push((final_variant, clean_value.to_string()));
281    }
282
283    code.push_str("}\n\n");
284
285    // Generate FromValue implementation for rdbi
286    code.push_str(&format!("impl rdbi::FromValue for {} {{\n", enum_name));
287    code.push_str("    fn from_value(value: rdbi::Value) -> rdbi::Result<Self> {\n");
288    code.push_str("        match value {\n");
289    code.push_str("            rdbi::Value::String(s) => match s.as_str() {\n");
290    for (variant, db_value) in &variant_mappings {
291        code.push_str(&format!(
292            "                \"{}\" => Ok(Self::{}),\n",
293            db_value, variant
294        ));
295    }
296    code.push_str(&format!(
297        "                _ => Err(rdbi::Error::TypeConversion {{ expected: \"{}\", actual: s }}),\n",
298        enum_name
299    ));
300    code.push_str("            },\n");
301    code.push_str(&format!(
302        "            _ => Err(rdbi::Error::TypeConversion {{ expected: \"{}\", actual: value.type_name().to_string() }}),\n",
303        enum_name
304    ));
305    code.push_str("        }\n");
306    code.push_str("    }\n");
307    code.push_str("}\n\n");
308
309    // Generate ToValue implementation for rdbi
310    code.push_str(&format!("impl rdbi::ToValue for {} {{\n", enum_name));
311    code.push_str("    fn to_value(&self) -> rdbi::Value {\n");
312    code.push_str("        rdbi::Value::String(match self {\n");
313    for (variant, db_value) in &variant_mappings {
314        code.push_str(&format!(
315            "            Self::{} => \"{}\".to_string(),\n",
316            variant, db_value
317        ));
318    }
319    code.push_str("        })\n");
320    code.push_str("    }\n");
321    code.push_str("}\n");
322
323    code
324}
325
326/// Get index information for a column
327fn get_index_info(table: &TableMetadata, column_name: &str) -> Vec<String> {
328    let mut info = Vec::new();
329
330    // Check primary key
331    if let Some(pk) = &table.primary_key {
332        if pk.columns.contains(&column_name.to_string()) {
333            info.push("PRIMARY KEY".to_string());
334        }
335    }
336
337    // Check indexes
338    for index in &table.indexes {
339        if index.columns.contains(&column_name.to_string()) {
340            let label = if index.unique {
341                format!("UNIQUE: {}", index.name)
342            } else {
343                format!("INDEX: {}", index.name)
344            };
345            info.push(label);
346        }
347    }
348
349    info
350}
351
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::{IndexMetadata, PrimaryKey};

    /// Build a NOT NULL, signed, uncommented column with no default.
    fn column(
        name: &str,
        data_type: &str,
        is_auto_increment: bool,
        enum_values: Option<Vec<String>>,
    ) -> ColumnMetadata {
        ColumnMetadata {
            name: name.to_string(),
            data_type: data_type.to_string(),
            nullable: false,
            default_value: None,
            is_auto_increment,
            is_unsigned: false,
            enum_values,
            comment: None,
        }
    }

    /// Convert string literals into owned `String`s.
    fn strings(values: &[&str]) -> Vec<String> {
        values.iter().map(|v| v.to_string()).collect()
    }

    /// `users` fixture: auto-increment `id` primary key, uniquely indexed
    /// `username`, and a three-valued `status` ENUM.
    fn make_table() -> TableMetadata {
        TableMetadata {
            name: "users".to_string(),
            comment: None,
            columns: vec![
                column("id", "BIGINT", true, None),
                column("username", "VARCHAR(255)", false, None),
                column(
                    "status",
                    "ENUM",
                    false,
                    Some(strings(&["ACTIVE", "INACTIVE", "PENDING"])),
                ),
            ],
            indexes: vec![IndexMetadata {
                name: "idx_username".to_string(),
                columns: strings(&["username"]),
                unique: true,
            }],
            foreign_keys: vec![],
            primary_key: Some(PrimaryKey {
                columns: strings(&["id"]),
            }),
        }
    }

    #[test]
    fn test_get_index_info() {
        let table = make_table();

        let id_info = get_index_info(&table, "id");
        assert!(id_info.iter().any(|entry| entry == "PRIMARY KEY"));

        let username_info = get_index_info(&table, "username");
        assert!(username_info.iter().any(|entry| entry.contains("UNIQUE")));
    }

    #[test]
    fn test_generate_enum() {
        let col = column("status", "ENUM", false, Some(strings(&["ACTIVE", "INACTIVE"])));
        let values = col
            .enum_values
            .as_deref()
            .expect("fixture column has enum values");
        let code = generate_enum("users", &col, values);

        for needle in [
            "pub enum UsersStatus",
            "Active",
            "Inactive",
            // rdbi trait implementations
            "impl rdbi::FromValue for UsersStatus",
            "impl rdbi::ToValue for UsersStatus",
        ] {
            assert!(code.contains(needle), "generated code is missing `{}`", needle);
        }
    }
}