Skip to main content

qail_core/
codegen.rs

//! Type-safe schema code generation.
//!
//! Generates Rust code from a `schema.qail` file so that table and column
//! references are checked at compile time.
//!
//! # Usage from build.rs
//! ```ignore
//! qail_core::codegen::generate_to_file("schema.qail", "src/generated/schema.rs")?;
//! ```
//!
//! # Generated code example (abridged)
//! ```ignore
//! pub mod users {
//!     use qail_core::typed::{Table, TypedColumn};
//!
//!     pub struct Users;
//!     impl Table for Users { fn table_name() -> &'static str { "users" } }
//!
//!     pub fn id() -> TypedColumn<uuid::Uuid> { TypedColumn::new("users", "id") }
//!     pub fn age() -> TypedColumn<i64> { TypedColumn::new("users", "age") }
//! }
//! ```
22
23use crate::build::Schema;
24use std::fs;
25
26/// Generate typed Rust code from a schema.qail file and write to output
27pub fn generate_to_file(schema_path: &str, output_path: &str) -> Result<(), String> {
28    let schema = Schema::parse_file(schema_path)?;
29    let code = generate_schema_code(&schema);
30    fs::write(output_path, &code)
31        .map_err(|e| format!("Failed to write output: {}", e))?;
32    Ok(())
33}
34
35/// Generate typed Rust code from a schema.qail file
36pub fn generate_from_file(schema_path: &str) -> Result<String, String> {
37    let schema = Schema::parse_file(schema_path)?;
38    Ok(generate_schema_code(&schema))
39}
40
41/// Generate Rust code for the schema
42pub fn generate_schema_code(schema: &Schema) -> String {
43    let mut code = String::new();
44    
45    // Header
46    code.push_str("//! Auto-generated by `qail types`\n");
47    code.push_str("//! Do not edit manually.\n\n");
48    code.push_str("#![allow(dead_code)]\n\n");
49    code.push_str("use qail_core::typed::{Table, TypedColumn, RequiresRls, DirectBuild, Bucket, Queue, Topic};\n\n");
50    
51    // Generate table modules
52    let mut table_names: Vec<_> = schema.tables.keys().collect();
53    table_names.sort();
54    
55    for table_name in &table_names {
56        if let Some(table) = schema.tables.get(*table_name) {
57            code.push_str(&generate_table_module(table_name, table));
58            code.push('\n');
59        }
60    }
61    
62    // Generate tables re-export
63    code.push_str("/// Re-export all table types\n");
64    code.push_str("pub mod tables {\n");
65    
66    for table_name in &table_names {
67        let struct_name = to_pascal_case(table_name);
68        code.push_str(&format!(
69            "    pub use super::{}::{};\n",
70            table_name, struct_name
71        ));
72    }
73    code.push_str("}\n\n");
74    
75    // Generate resource modules
76    let mut resource_names: Vec<_> = schema.resources.keys().collect();
77    resource_names.sort();
78    
79    for res_name in &resource_names {
80        if let Some(resource) = schema.resources.get(*res_name) {
81            code.push_str(&generate_resource_module(res_name, resource));
82            code.push('\n');
83        }
84    }
85    
86    // Generate resources re-export
87    if !resource_names.is_empty() {
88        code.push_str("/// Re-export all resource types\n");
89        code.push_str("pub mod resources {\n");
90        for res_name in &resource_names {
91            let struct_name = to_pascal_case(res_name);
92            code.push_str(&format!(
93                "    pub use super::{}::{};\n",
94                res_name, struct_name
95            ));
96        }
97        code.push_str("}\n");
98    }
99    
100    code
101}
102
103/// Generate a module for an infrastructure resource
104fn generate_resource_module(resource_name: &str, resource: &crate::build::ResourceSchema) -> String {
105    let mut code = String::new();
106    let struct_name = to_pascal_case(resource_name);
107    let kind = &resource.kind;
108    
109    code.push_str(&format!("/// {} resource: {}\n", kind, resource_name));
110    code.push_str(&format!("pub mod {} {{\n", resource_name));
111    code.push_str("    use super::*;\n\n");
112    
113    // Struct
114    code.push_str(&format!("    /// Type-safe reference to {} `{}`\n", kind, resource_name));
115    code.push_str("    #[derive(Debug, Clone, Copy, Default)]\n");
116    code.push_str(&format!("    pub struct {};\n\n", struct_name));
117    
118    // Implement the appropriate trait
119    let (trait_name, method_name) = match kind.as_str() {
120        "bucket" => ("Bucket", "bucket_name"),
121        "queue" => ("Queue", "queue_name"),
122        "topic" => ("Topic", "topic_name"),
123        _ => ("Bucket", "bucket_name"), // fallback
124    };
125    
126    code.push_str(&format!("    impl {} for {} {{\n", trait_name, struct_name));
127    code.push_str(&format!(
128        "        fn {}() -> &'static str {{ \"{}\" }}\n",
129        method_name, resource_name
130    ));
131    code.push_str("    }\n");
132    
133    // Add provider constant if specified
134    if let Some(ref provider) = resource.provider {
135        code.push_str(&format!("\n    pub const PROVIDER: &str = \"{}\";\n", provider));
136    }
137    
138    // Add property constants
139    for (key, value) in &resource.properties {
140        let const_name = key.to_uppercase();
141        code.push_str(&format!("    pub const {}: &str = \"{}\";\n", const_name, value));
142    }
143    
144    code.push_str("}\n");
145    code
146}
147
148fn generate_table_module(table_name: &str, table: &crate::build::TableSchema) -> String {
149    let mut code = String::new();
150    let struct_name = to_pascal_case(table_name);
151    
152    code.push_str(&format!("/// Table: {}\n", table_name));
153    code.push_str(&format!("pub mod {} {{\n", table_name));
154    code.push_str("    use super::*;\n\n");
155    
156    // Table struct with Table trait
157    code.push_str(&format!("    /// Type-safe reference to `{}`\n", table_name));
158    code.push_str("    #[derive(Debug, Clone, Copy, Default)]\n");
159    code.push_str(&format!("    pub struct {};\n\n", struct_name));
160    
161    code.push_str(&format!("    impl Table for {} {{\n", struct_name));
162    code.push_str(&format!(
163        "        fn table_name() -> &'static str {{ \"{}\" }}\n",
164        table_name
165    ));
166    code.push_str("    }\n\n");
167    
168    // Implement From<Table> for String to work with Qail::get()
169    code.push_str(&format!("    impl From<{}> for String {{\n", struct_name));
170    code.push_str(&format!("        fn from(_: {}) -> String {{ \"{}\".to_string() }}\n", struct_name, table_name));
171    code.push_str("    }\n\n");
172    
173    // AsRef<str> for TypedQail compatibility
174    code.push_str(&format!("    impl AsRef<str> for {} {{\n", struct_name));
175    code.push_str(&format!("        fn as_ref(&self) -> &str {{ \"{}\" }}\n", table_name));
176    code.push_str("    }\n\n");
177    
178    // RLS trait: RequiresRls for tables with operator_id, DirectBuild for others
179    if table.rls_enabled {
180        code.push_str("    /// This table has `operator_id` — queries require `.with_rls()` proof\n");
181        code.push_str(&format!("    impl RequiresRls for {} {{}}\n\n", struct_name));
182    } else {
183        code.push_str(&format!("    impl DirectBuild for {} {{}}\n\n", struct_name));
184    }
185    
186    // Typed column functions
187    let mut col_names: Vec<_> = table.columns.keys().collect();
188    col_names.sort();
189    
190    for col_name in &col_names {
191        if let Some(col_type) = table.columns.get(*col_name) {
192            let rust_type = sql_type_to_rust(col_type);
193            let fn_name = escape_keyword(col_name);
194            code.push_str(&format!(
195                "    /// Column `{}` ({})\n",
196                col_name, col_type
197            ));
198            code.push_str(&format!(
199                "    pub fn {}() -> TypedColumn<{}> {{ TypedColumn::new(\"{}\", \"{}\") }}\n\n",
200                fn_name, rust_type, table_name, col_name
201            ));
202        }
203    }
204    
205    code.push_str("}\n");
206    
207    code
208}
209
/// Map a SQL column type to the Rust type used for its `TypedColumn`.
///
/// Matching is case-insensitive and substring-based, so forms such as
/// `VARCHAR(255)`, `BIGINT UNSIGNED` or `TIMESTAMP WITH TIME ZONE` are
/// recognised. Checks run in a fixed order, the first match wins, and anything
/// unrecognised maps to `String`.
///
/// Notes on the mapping:
/// - every integer width (including the `SERIAL` variants) maps to `i64`;
/// - `DECIMAL`/`NUMERIC` map to `f64`, which is lossy for exact decimals;
/// - `DATE` and `TIME` map to `chrono::DateTime<chrono::Utc>`, like `TIMESTAMP`;
/// - `INTERVAL` and the `POINT` family use the `String` fallback. Both names
///   contain the substring `INT`, and the integer check would otherwise
///   misclassify them as `i64`.
pub fn sql_type_to_rust(sql_type: &str) -> &'static str {
    let upper = sql_type.to_uppercase();

    // Non-integer types whose names contain "INT": rule them out before the
    // integer check. There is no dedicated mapping for them, so they get the
    // same fallback as any other unknown type.
    if upper.contains("INTERVAL") || upper.contains("POINT") {
        return "String";
    }

    // Integer family: all widths (INT2/4/8, SMALLINT, BIGINT, *SERIAL, ...) use i64
    if upper.contains("INT") || upper.contains("SERIAL") {
        return "i64";
    }

    // Float family
    if upper.contains("FLOAT")
        || upper.contains("DOUBLE")
        || upper.contains("DECIMAL")
        || upper.contains("NUMERIC")
        || upper.contains("REAL")
    {
        return "f64";
    }

    // Boolean
    if upper.contains("BOOL") {
        return "bool";
    }

    // UUID
    if upper.contains("UUID") {
        return "uuid::Uuid";
    }

    // Text family
    if upper.contains("TEXT")
        || upper.contains("VARCHAR")
        || upper.contains("CHAR")
        || upper.contains("NAME")
    {
        return "String";
    }

    // JSON
    if upper.contains("JSON") {
        return "serde_json::Value";
    }

    // Timestamp / date / time
    if upper.contains("TIMESTAMP") || upper.contains("DATE") || upper.contains("TIME") {
        return "chrono::DateTime<chrono::Utc>";
    }

    // Binary
    if upper.contains("BYTEA") || upper.contains("BLOB") {
        return "Vec<u8>";
    }

    // Default to String for unknown types
    "String"
}
262
/// Convert a `snake_case` name to `PascalCase`.
///
/// Each `_`-separated segment has its first character upper-cased (this may
/// expand to several characters, e.g. `ß` → `SS`) and keeps the rest as-is.
/// Empty segments, such as those from doubled, leading or trailing
/// underscores, contribute nothing.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());
    for segment in s.split('_') {
        let mut tail = segment.chars();
        if let Some(head) = tail.next() {
            pascal.extend(head.to_uppercase());
            pascal.push_str(tail.as_str());
        }
    }
    pascal
}
275
/// Turn a column name into a usable Rust identifier for its accessor function.
///
/// - Reserved words get the raw-identifier prefix: `type` → `r#type`.
/// - `self`, `Self`, `super` and `crate` cannot be raw identifiers (the
///   compiler rejects `r#self`), so they get a trailing underscore instead:
///   `self` → `self_`.
/// - Any other name is returned unchanged.
fn escape_keyword(name: &str) -> String {
    // Keywords that are not permitted after `r#`.
    const NON_RAW_KEYWORDS: &[&str] = &["self", "Self", "super", "crate"];
    // Strict and reserved keywords across editions; `gen` is reserved in Rust 2024.
    const KEYWORDS: &[&str] = &[
        "as", "break", "const", "continue", "else", "enum", "extern",
        "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod",
        "move", "mut", "pub", "ref", "return", "static", "struct",
        "trait", "true", "type", "unsafe", "use", "where", "while",
        "async", "await", "dyn", "abstract", "become", "box", "do", "final",
        "macro", "override", "priv", "try", "typeof", "unsized", "virtual", "yield",
        "gen",
    ];

    if NON_RAW_KEYWORDS.contains(&name) {
        format!("{}_", name)
    } else if KEYWORDS.contains(&name) {
        format!("r#{}", name)
    } else {
        name.to_string()
    }
}
293
#[cfg(test)]
mod tests {
    use super::*;

    /// `to_pascal_case` joins `_`-separated words with each word capitalised.
    #[test]
    fn test_pascal_case() {
        let cases = [("users", "Users"), ("user_profiles", "UserProfiles")];
        for &(snake, pascal) in &cases {
            assert_eq!(to_pascal_case(snake), pascal, "to_pascal_case({:?})", snake);
        }
    }

    /// `sql_type_to_rust` maps common SQL types to their Rust counterparts.
    #[test]
    fn test_sql_type_mapping() {
        let cases = [
            ("INT", "i64"),
            ("TEXT", "String"),
            ("UUID", "uuid::Uuid"),
            ("BOOLEAN", "bool"),
            ("JSONB", "serde_json::Value"),
        ];
        for &(sql, rust) in &cases {
            assert_eq!(sql_type_to_rust(sql), rust, "sql_type_to_rust({:?})", sql);
        }
    }
}