Skip to main content

qail_core/
codegen.rs

1//! Type-safe schema code generation.
2//!
3//! Generates Rust code from schema.qail for compile-time type safety.
4//!
5//! # Usage from build.rs
6//! ```ignore
7//! qail_core::codegen::generate_to_file("schema.qail", "src/generated/schema.rs")?;
8//! ```
9//!
10//! # Generated code example
11//! ```ignore
12//! pub mod users {
13//!     use qail_core::typed::{Table, TypedColumn};
14//!     
15//!     pub struct Users;
16//!     impl Table for Users { fn table_name() -> &'static str { "users" } }
17//!     
18//!     pub fn id() -> TypedColumn<uuid::Uuid> { TypedColumn::new("users", "id") }
19//!     pub fn age() -> TypedColumn<i64> { TypedColumn::new("users", "age") }
20//! }
21//! ```
22
23use crate::build::Schema;
24use crate::migrate::types::ColumnType;
25use std::fs;
26
27/// Generate typed Rust code from a schema.qail file and write to output
28pub fn generate_to_file(schema_path: &str, output_path: &str) -> Result<(), String> {
29    let schema = Schema::parse_file(schema_path)?;
30    let code = generate_schema_code(&schema);
31    fs::write(output_path, &code).map_err(|e| format!("Failed to write output: {}", e))?;
32    Ok(())
33}
34
35/// Generate typed Rust code from a schema.qail file
36pub fn generate_from_file(schema_path: &str) -> Result<String, String> {
37    let schema = Schema::parse_file(schema_path)?;
38    Ok(generate_schema_code(&schema))
39}
40
41/// Generate Rust code for the schema
42pub fn generate_schema_code(schema: &Schema) -> String {
43    let mut code = String::new();
44
45    // Header
46    code.push_str("//! Auto-generated by `qail types`\n");
47    code.push_str("//! Do not edit manually.\n\n");
48    code.push_str("#![allow(dead_code)]\n\n");
49    code.push_str("use qail_core::typed::{Table, TypedColumn, RequiresRls, DirectBuild, Bucket, Queue, Topic};\n\n");
50
51    // Generate table modules
52    let mut table_names: Vec<_> = schema.tables.keys().collect();
53    table_names.sort();
54
55    for table_name in &table_names {
56        if let Some(table) = schema.tables.get(*table_name) {
57            code.push_str(&generate_table_module(table_name, table));
58            code.push('\n');
59        }
60    }
61
62    // Generate tables re-export
63    code.push_str("/// Re-export all table types\n");
64    code.push_str("pub mod tables {\n");
65
66    for table_name in &table_names {
67        let struct_name = to_pascal_case(table_name);
68        code.push_str(&format!(
69            "    pub use super::{}::{};\n",
70            table_name, struct_name
71        ));
72    }
73    code.push_str("}\n\n");
74
75    // Generate resource modules
76    let mut resource_names: Vec<_> = schema.resources.keys().collect();
77    resource_names.sort();
78
79    for res_name in &resource_names {
80        if let Some(resource) = schema.resources.get(*res_name) {
81            code.push_str(&generate_resource_module(res_name, resource));
82            code.push('\n');
83        }
84    }
85
86    // Generate resources re-export
87    if !resource_names.is_empty() {
88        code.push_str("/// Re-export all resource types\n");
89        code.push_str("pub mod resources {\n");
90        for res_name in &resource_names {
91            let struct_name = to_pascal_case(res_name);
92            code.push_str(&format!(
93                "    pub use super::{}::{};\n",
94                res_name, struct_name
95            ));
96        }
97        code.push_str("}\n");
98    }
99
100    code
101}
102
103/// Generate a module for an infrastructure resource
104fn generate_resource_module(
105    resource_name: &str,
106    resource: &crate::build::ResourceSchema,
107) -> String {
108    let mut code = String::new();
109    let struct_name = to_pascal_case(resource_name);
110    let kind = &resource.kind;
111
112    code.push_str(&format!("/// {} resource: {}\n", kind, resource_name));
113    code.push_str(&format!("pub mod {} {{\n", resource_name));
114    code.push_str("    use super::*;\n\n");
115
116    // Struct
117    code.push_str(&format!(
118        "    /// Type-safe reference to {} `{}`\n",
119        kind, resource_name
120    ));
121    code.push_str("    #[derive(Debug, Clone, Copy, Default)]\n");
122    code.push_str(&format!("    pub struct {};\n\n", struct_name));
123
124    // Implement the appropriate trait
125    let (trait_name, method_name) = match kind.as_str() {
126        "bucket" => ("Bucket", "bucket_name"),
127        "queue" => ("Queue", "queue_name"),
128        "topic" => ("Topic", "topic_name"),
129        _ => ("Bucket", "bucket_name"), // fallback
130    };
131
132    code.push_str(&format!("    impl {} for {} {{\n", trait_name, struct_name));
133    code.push_str(&format!(
134        "        fn {}() -> &'static str {{ \"{}\" }}\n",
135        method_name, resource_name
136    ));
137    code.push_str("    }\n");
138
139    // Add provider constant if specified
140    if let Some(ref provider) = resource.provider {
141        code.push_str(&format!(
142            "\n    pub const PROVIDER: &str = \"{}\";\n",
143            provider
144        ));
145    }
146
147    // Add property constants
148    for (key, value) in &resource.properties {
149        let const_name = key.to_uppercase();
150        code.push_str(&format!(
151            "    pub const {}: &str = \"{}\";\n",
152            const_name, value
153        ));
154    }
155
156    code.push_str("}\n");
157    code
158}
159
160fn generate_table_module(table_name: &str, table: &crate::build::TableSchema) -> String {
161    let mut code = String::new();
162    let struct_name = to_pascal_case(table_name);
163
164    code.push_str(&format!("/// Table: {}\n", table_name));
165    code.push_str(&format!("pub mod {} {{\n", table_name));
166    code.push_str("    use super::*;\n\n");
167
168    // Table struct with Table trait
169    code.push_str(&format!(
170        "    /// Type-safe reference to `{}`\n",
171        table_name
172    ));
173    code.push_str("    #[derive(Debug, Clone, Copy, Default)]\n");
174    code.push_str(&format!("    pub struct {};\n\n", struct_name));
175
176    code.push_str(&format!("    impl Table for {} {{\n", struct_name));
177    code.push_str(&format!(
178        "        fn table_name() -> &'static str {{ \"{}\" }}\n",
179        table_name
180    ));
181    code.push_str("    }\n\n");
182
183    // Implement From<Table> for String to work with Qail::get()
184    code.push_str(&format!("    impl From<{}> for String {{\n", struct_name));
185    code.push_str(&format!(
186        "        fn from(_: {}) -> String {{ \"{}\".to_string() }}\n",
187        struct_name, table_name
188    ));
189    code.push_str("    }\n\n");
190
191    // AsRef<str> for TypedQail compatibility
192    code.push_str(&format!("    impl AsRef<str> for {} {{\n", struct_name));
193    code.push_str(&format!(
194        "        fn as_ref(&self) -> &str {{ \"{}\" }}\n",
195        table_name
196    ));
197    code.push_str("    }\n\n");
198
199    // RLS trait: RequiresRls for tables with tenant_id, DirectBuild for others
200    if table.rls_enabled {
201        code.push_str("    /// This table has `tenant_id` — queries require `.with_rls()` proof\n");
202        code.push_str(&format!(
203            "    impl RequiresRls for {} {{}}\n\n",
204            struct_name
205        ));
206    } else {
207        code.push_str(&format!(
208            "    impl DirectBuild for {} {{}}\n\n",
209            struct_name
210        ));
211    }
212
213    // Typed column functions
214    let mut col_names: Vec<_> = table.columns.keys().collect();
215    col_names.sort();
216
217    for col_name in &col_names {
218        if let Some(col_type) = table.columns.get(*col_name) {
219            let rust_type = column_type_to_rust(col_type);
220            let fn_name = escape_keyword(col_name);
221            code.push_str(&format!(
222                "    /// Column `{}` ({})\n",
223                col_name,
224                col_type.to_pg_type()
225            ));
226            code.push_str(&format!(
227                "    pub fn {}() -> TypedColumn<{}> {{ TypedColumn::new(\"{}\", \"{}\") }}\n\n",
228                fn_name, rust_type, table_name, col_name
229            ));
230        }
231    }
232
233    code.push_str("}\n");
234
235    code
236}
237
238/// Map ColumnType AST enum to Rust types (for codegen).
239/// This is the ONLY place where we map SQL types to Rust types.
240fn column_type_to_rust(col_type: &ColumnType) -> &'static str {
241    match col_type {
242        ColumnType::Uuid => "uuid::Uuid",
243        ColumnType::Text | ColumnType::Varchar(_) => "String",
244        ColumnType::Int | ColumnType::BigInt | ColumnType::Serial | ColumnType::BigSerial => "i64",
245        ColumnType::Bool => "bool",
246        ColumnType::Float | ColumnType::Decimal(_) => "f64",
247        ColumnType::Jsonb => "serde_json::Value",
248        ColumnType::Timestamp | ColumnType::Timestamptz | ColumnType::Date | ColumnType::Time => {
249            "chrono::DateTime<chrono::Utc>"
250        }
251        ColumnType::Bytea => "Vec<u8>",
252        ColumnType::Array(_) => "Vec<serde_json::Value>",
253        ColumnType::Enum { .. } => "String",
254        ColumnType::Range(_) => "String",
255        ColumnType::Interval => "String",
256        ColumnType::Cidr | ColumnType::Inet => "String",
257        ColumnType::MacAddr => "String",
258    }
259}
260
/// Convert a `snake_case` name to `PascalCase`.
///
/// Each `_`-separated segment has its first character upper-cased (using full
/// Unicode case mapping) and the rest kept as-is. Empty segments produced by
/// leading, trailing or repeated underscores are dropped.
fn to_pascal_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    for segment in s.split('_') {
        let mut rest = segment.chars();
        if let Some(first) = rest.next() {
            result.extend(first.to_uppercase());
            result.push_str(rest.as_str());
        }
    }
    result
}
273
/// Turn a column name into a usable Rust function identifier.
///
/// Strict and reserved keywords get the raw-identifier prefix
/// (`type` → `r#type`). `self`, `Self`, `super` and `crate` cannot be raw
/// identifiers (rustc rejects `r#self`), so they get a trailing underscore
/// instead (`self` → `self_`). Every other name is returned unchanged.
fn escape_keyword(name: &str) -> String {
    // Path-segment keywords: the `r#` prefix is not allowed on these.
    const NON_RAW_KEYWORDS: &[&str] = &["self", "Self", "super", "crate"];
    const KEYWORDS: &[&str] = &[
        "as", "break", "const", "continue", "else", "enum", "extern", "false", "fn", "for",
        "if", "impl", "in", "let", "loop", "match", "mod", "move", "mut", "pub", "ref",
        "return", "static", "struct", "trait", "true", "type", "unsafe", "use", "where",
        "while", "async", "await", "dyn", "abstract", "become", "box", "do", "final", "macro",
        "override", "priv", "try", "typeof", "unsized", "virtual", "yield",
        // Reserved in the 2024 edition; `r#gen` is valid in every edition with raw identifiers.
        "gen",
    ];

    if NON_RAW_KEYWORDS.contains(&name) {
        format!("{}_", name)
    } else if KEYWORDS.contains(&name) {
        format!("r#{}", name)
    } else {
        name.to_string()
    }
}
290
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pascal_case() {
        let cases = [("users", "Users"), ("user_profiles", "UserProfiles")];
        for (input, expected) in &cases {
            assert_eq!(to_pascal_case(input), *expected);
        }
    }

    #[test]
    fn test_column_type_mapping() {
        let cases = [
            (ColumnType::Int, "i64"),
            (ColumnType::Text, "String"),
            (ColumnType::Uuid, "uuid::Uuid"),
            (ColumnType::Bool, "bool"),
            (ColumnType::Jsonb, "serde_json::Value"),
            (ColumnType::BigInt, "i64"),
            (ColumnType::Float, "f64"),
            (ColumnType::Timestamp, "chrono::DateTime<chrono::Utc>"),
            (ColumnType::Bytea, "Vec<u8>"),
        ];
        for (col_type, expected) in &cases {
            assert_eq!(column_type_to_rust(col_type), *expected);
        }
    }
}