Skip to main content

qail_core/
codegen.rs

//! Type-safe schema code generation.
//!
//! Generates Rust code from a `schema.qail` file for compile-time type safety.
//!
//! # Usage from build.rs
//! ```ignore
//! qail_core::codegen::generate_to_file("schema.qail", "src/generated/schema.rs")?;
//! ```
//!
//! # Generated code example (abridged)
//! ```ignore
//! pub mod users {
//!     use qail_core::typed::{Table, TypedColumn};
//!
//!     pub struct Users;
//!     impl Table for Users { fn table_name() -> &'static str { "users" } }
//!
//!     pub fn id() -> TypedColumn<uuid::Uuid> { TypedColumn::new("users", "id") }
//!     pub fn age() -> TypedColumn<i64> { TypedColumn::new("users", "age") }
//! }
//! ```
22
23use crate::build::Schema;
24use std::fs;
25
26/// Generate typed Rust code from a schema.qail file and write to output
27pub fn generate_to_file(schema_path: &str, output_path: &str) -> Result<(), String> {
28    let schema = Schema::parse_file(schema_path)?;
29    let code = generate_schema_code(&schema);
30    fs::write(output_path, &code)
31        .map_err(|e| format!("Failed to write output: {}", e))?;
32    Ok(())
33}
34
35/// Generate typed Rust code from a schema.qail file
36pub fn generate_from_file(schema_path: &str) -> Result<String, String> {
37    let schema = Schema::parse_file(schema_path)?;
38    Ok(generate_schema_code(&schema))
39}
40
41/// Generate Rust code for the schema
42pub fn generate_schema_code(schema: &Schema) -> String {
43    let mut code = String::new();
44    
45    // Header
46    code.push_str("//! Auto-generated by `qail types`\n");
47    code.push_str("//! Do not edit manually.\n\n");
48    code.push_str("#![allow(dead_code)]\n\n");
49    code.push_str("use qail_core::typed::{Table, TypedColumn, RequiresRls, DirectBuild};\n\n");
50    
51    // Generate table modules
52    let mut table_names: Vec<_> = schema.tables.keys().collect();
53    table_names.sort();
54    
55    for table_name in &table_names {
56        if let Some(table) = schema.tables.get(*table_name) {
57            code.push_str(&generate_table_module(table_name, table));
58            code.push('\n');
59        }
60    }
61    
62    // Generate tables re-export
63    code.push_str("/// Re-export all table types\n");
64    code.push_str("pub mod tables {\n");
65    
66    for table_name in &table_names {
67        let struct_name = to_pascal_case(table_name);
68        code.push_str(&format!(
69            "    pub use super::{}::{};\n",
70            table_name, struct_name
71        ));
72    }
73    code.push_str("}\n");
74    
75    code
76}
77
78/// Generate a module for a single table
79fn generate_table_module(table_name: &str, table: &crate::build::TableSchema) -> String {
80    let mut code = String::new();
81    let struct_name = to_pascal_case(table_name);
82    
83    code.push_str(&format!("/// Table: {}\n", table_name));
84    code.push_str(&format!("pub mod {} {{\n", table_name));
85    code.push_str("    use super::*;\n\n");
86    
87    // Table struct with Table trait
88    code.push_str(&format!("    /// Type-safe reference to `{}`\n", table_name));
89    code.push_str("    #[derive(Debug, Clone, Copy, Default)]\n");
90    code.push_str(&format!("    pub struct {};\n\n", struct_name));
91    
92    code.push_str(&format!("    impl Table for {} {{\n", struct_name));
93    code.push_str(&format!(
94        "        fn table_name() -> &'static str {{ \"{}\" }}\n",
95        table_name
96    ));
97    code.push_str("    }\n\n");
98    
99    // Implement From<Table> for String to work with Qail::get()
100    code.push_str(&format!("    impl From<{}> for String {{\n", struct_name));
101    code.push_str(&format!("        fn from(_: {}) -> String {{ \"{}\".to_string() }}\n", struct_name, table_name));
102    code.push_str("    }\n\n");
103    
104    // AsRef<str> for TypedQail compatibility
105    code.push_str(&format!("    impl AsRef<str> for {} {{\n", struct_name));
106    code.push_str(&format!("        fn as_ref(&self) -> &str {{ \"{}\" }}\n", table_name));
107    code.push_str("    }\n\n");
108    
109    // RLS trait: RequiresRls for tables with operator_id, DirectBuild for others
110    if table.rls_enabled {
111        code.push_str(&format!("    /// This table has `operator_id` — queries require `.with_rls()` proof\n"));
112        code.push_str(&format!("    impl RequiresRls for {} {{}}\n\n", struct_name));
113    } else {
114        code.push_str(&format!("    impl DirectBuild for {} {{}}\n\n", struct_name));
115    }
116    
117    // Typed column functions
118    let mut col_names: Vec<_> = table.columns.keys().collect();
119    col_names.sort();
120    
121    for col_name in &col_names {
122        if let Some(col_type) = table.columns.get(*col_name) {
123            let rust_type = sql_type_to_rust(col_type);
124            let fn_name = escape_keyword(col_name);
125            code.push_str(&format!(
126                "    /// Column `{}` ({})\n",
127                col_name, col_type
128            ));
129            code.push_str(&format!(
130                "    pub fn {}() -> TypedColumn<{}> {{ TypedColumn::new(\"{}\", \"{}\") }}\n\n",
131                fn_name, rust_type, table_name, col_name
132            ));
133        }
134    }
135    
136    code.push_str("}\n");
137    
138    code
139}
140
/// Map a SQL column type to the name of the Rust type used for its column.
///
/// Matching is case-insensitive. Families are tried in this order: integer,
/// float, boolean, UUID, text, JSON, date/time, binary. Anything that matches
/// none of them falls back to `"String"`.
///
/// Integer names are matched per word rather than as a substring, because
/// `INT` also occurs inside unrelated type names (`INTERVAL`, `POINT`,
/// `INT4RANGE`). Those now reach the later checks or the `String` fallback
/// instead of being reported as `i64`. All integer widths map to `i64`.
pub fn sql_type_to_rust(sql_type: &str) -> &'static str {
    /// True when `word` is an integer type name: an optional size prefix,
    /// then `INTEGER`, `INT` or `SERIAL`, then an optional numeric width
    /// (`INT`, `BIGINT`, `INT8`, `SMALLSERIAL`, ...).
    fn is_integer_word(word: &str) -> bool {
        const SIZE_PREFIXES: [&str; 4] = ["BIG", "SMALL", "TINY", "MEDIUM"];
        const FAMILIES: [&str; 3] = ["INTEGER", "INT", "SERIAL"];

        let base = SIZE_PREFIXES
            .iter()
            .find_map(|prefix| word.strip_prefix(*prefix))
            .unwrap_or(word);
        FAMILIES
            .iter()
            .filter_map(|family| base.strip_prefix(*family))
            .any(|width| width.bytes().all(|b| b.is_ascii_digit()))
    }

    /// True when `haystack` contains at least one of `needles`.
    fn contains_any(haystack: &str, needles: &[&str]) -> bool {
        needles.iter().any(|needle| haystack.contains(*needle))
    }

    let upper = sql_type.to_uppercase();

    // Split on anything that cannot be part of a type name so that
    // `INT(11)`, `INTEGER[]` and `BIGINT UNSIGNED` still expose the bare word.
    let has_integer_word = upper
        .split(|c: char| !(c.is_ascii_alphanumeric() || c == '_'))
        .any(is_integer_word);

    if has_integer_word {
        // Use i64 for all ints for simplicity
        "i64"
    } else if contains_any(&upper, &["FLOAT", "DOUBLE", "DECIMAL", "NUMERIC", "REAL"]) {
        "f64"
    } else if upper.contains("BOOL") {
        "bool"
    } else if upper.contains("UUID") {
        "uuid::Uuid"
    } else if contains_any(&upper, &["TEXT", "VARCHAR", "CHAR", "NAME"]) {
        "String"
    } else if upper.contains("JSON") {
        "serde_json::Value"
    } else if contains_any(&upper, &["TIMESTAMP", "DATE", "TIME"]) {
        "chrono::DateTime<chrono::Utc>"
    } else if contains_any(&upper, &["BYTEA", "BLOB"]) {
        "Vec<u8>"
    } else {
        // Default to String for unknown types
        "String"
    }
}
193
/// Convert a snake_case identifier to PascalCase.
///
/// The input is split on `_`; the first character of every non-empty segment
/// is upper-cased and the remainder of the segment is copied unchanged. Empty
/// segments (from leading, trailing or doubled underscores) contribute nothing.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::new();
    for segment in s.split('_') {
        let mut rest = segment.chars();
        if let Some(first) = rest.next() {
            // `to_uppercase` may expand to more than one char.
            pascal.extend(first.to_uppercase());
            pascal.extend(rest);
        }
    }
    pascal
}
206
/// Turn `name` into a usable Rust identifier when it collides with a keyword.
///
/// Most keywords are escaped with the raw-identifier prefix
/// (`type` becomes `r#type`). `crate`, `self`, `Self` and `super` are not
/// permitted as raw identifiers, so they receive a trailing underscore instead
/// (`self` becomes `self_`). Any other name is returned unchanged.
fn escape_keyword(name: &str) -> String {
    // Keywords the language refuses to accept in `r#…` form.
    const NOT_RAW_ESCAPABLE: &[&str] = &["crate", "self", "Self", "super"];

    // Strict and reserved keywords that are valid as raw identifiers.
    // `gen` is reserved starting with the 2024 edition.
    const KEYWORDS: &[&str] = &[
        "as", "break", "const", "continue", "else", "enum", "extern",
        "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod",
        "move", "mut", "pub", "ref", "return", "static", "struct",
        "trait", "true", "type", "unsafe", "use", "where", "while",
        "async", "await", "dyn", "abstract", "become", "box", "do", "final",
        "gen", "macro", "override", "priv", "try", "typeof", "unsized", "virtual",
        "yield",
    ];

    if NOT_RAW_ESCAPABLE.contains(&name) {
        format!("{}_", name)
    } else if KEYWORDS.contains(&name) {
        format!("r#{}", name)
    } else {
        name.to_string()
    }
}
224
#[cfg(test)]
mod tests {
    use super::*;

    /// Each `(input, expected)` pair is checked against `to_pascal_case`.
    #[test]
    fn test_pascal_case() {
        let cases = [("users", "Users"), ("user_profiles", "UserProfiles")];
        for &(input, expected) in &cases {
            assert_eq!(to_pascal_case(input), expected, "input: {}", input);
        }
    }

    /// Each `(sql_type, expected)` pair is checked against `sql_type_to_rust`.
    #[test]
    fn test_sql_type_mapping() {
        let cases = [
            ("INT", "i64"),
            ("TEXT", "String"),
            ("UUID", "uuid::Uuid"),
            ("BOOLEAN", "bool"),
            ("JSONB", "serde_json::Value"),
        ];
        for &(sql_type, expected) in &cases {
            assert_eq!(sql_type_to_rust(sql_type), expected, "sql type: {}", sql_type);
        }
    }
}
243}