1use crate::build::Schema;
24use std::fs;
25
26pub fn generate_to_file(schema_path: &str, output_path: &str) -> Result<(), String> {
28 let schema = Schema::parse_file(schema_path)?;
29 let code = generate_schema_code(&schema);
30 fs::write(output_path, &code)
31 .map_err(|e| format!("Failed to write output: {}", e))?;
32 Ok(())
33}
34
35pub fn generate_from_file(schema_path: &str) -> Result<String, String> {
37 let schema = Schema::parse_file(schema_path)?;
38 Ok(generate_schema_code(&schema))
39}
40
41pub fn generate_schema_code(schema: &Schema) -> String {
43 let mut code = String::new();
44
45 code.push_str("//! Auto-generated by `qail types`\n");
47 code.push_str("//! Do not edit manually.\n\n");
48 code.push_str("#![allow(dead_code)]\n\n");
49 code.push_str("use qail_core::typed::{Table, TypedColumn, RequiresRls, DirectBuild};\n\n");
50
51 let mut table_names: Vec<_> = schema.tables.keys().collect();
53 table_names.sort();
54
55 for table_name in &table_names {
56 if let Some(table) = schema.tables.get(*table_name) {
57 code.push_str(&generate_table_module(table_name, table));
58 code.push('\n');
59 }
60 }
61
62 code.push_str("/// Re-export all table types\n");
64 code.push_str("pub mod tables {\n");
65
66 for table_name in &table_names {
67 let struct_name = to_pascal_case(table_name);
68 code.push_str(&format!(
69 " pub use super::{}::{};\n",
70 table_name, struct_name
71 ));
72 }
73 code.push_str("}\n");
74
75 code
76}
77
78fn generate_table_module(table_name: &str, table: &crate::build::TableSchema) -> String {
80 let mut code = String::new();
81 let struct_name = to_pascal_case(table_name);
82
83 code.push_str(&format!("/// Table: {}\n", table_name));
84 code.push_str(&format!("pub mod {} {{\n", table_name));
85 code.push_str(" use super::*;\n\n");
86
87 code.push_str(&format!(" /// Type-safe reference to `{}`\n", table_name));
89 code.push_str(" #[derive(Debug, Clone, Copy, Default)]\n");
90 code.push_str(&format!(" pub struct {};\n\n", struct_name));
91
92 code.push_str(&format!(" impl Table for {} {{\n", struct_name));
93 code.push_str(&format!(
94 " fn table_name() -> &'static str {{ \"{}\" }}\n",
95 table_name
96 ));
97 code.push_str(" }\n\n");
98
99 code.push_str(&format!(" impl From<{}> for String {{\n", struct_name));
101 code.push_str(&format!(" fn from(_: {}) -> String {{ \"{}\".to_string() }}\n", struct_name, table_name));
102 code.push_str(" }\n\n");
103
104 code.push_str(&format!(" impl AsRef<str> for {} {{\n", struct_name));
106 code.push_str(&format!(" fn as_ref(&self) -> &str {{ \"{}\" }}\n", table_name));
107 code.push_str(" }\n\n");
108
109 if table.rls_enabled {
111 code.push_str(&format!(" /// This table has `operator_id` — queries require `.with_rls()` proof\n"));
112 code.push_str(&format!(" impl RequiresRls for {} {{}}\n\n", struct_name));
113 } else {
114 code.push_str(&format!(" impl DirectBuild for {} {{}}\n\n", struct_name));
115 }
116
117 let mut col_names: Vec<_> = table.columns.keys().collect();
119 col_names.sort();
120
121 for col_name in &col_names {
122 if let Some(col_type) = table.columns.get(*col_name) {
123 let rust_type = sql_type_to_rust(col_type);
124 let fn_name = escape_keyword(col_name);
125 code.push_str(&format!(
126 " /// Column `{}` ({})\n",
127 col_name, col_type
128 ));
129 code.push_str(&format!(
130 " pub fn {}() -> TypedColumn<{}> {{ TypedColumn::new(\"{}\", \"{}\") }}\n\n",
131 fn_name, rust_type, table_name, col_name
132 ));
133 }
134 }
135
136 code.push_str("}\n");
137
138 code
139}
140
/// Maps a SQL column type name to the Rust type path used in `TypedColumn<T>`.
///
/// Matching is case-insensitive and substring-based, so spellings such as
/// `VARCHAR(255)`, `int4` or `TIMESTAMP WITH TIME ZONE` map like their base
/// type. Rules are checked in order and the first match wins. Anything
/// unrecognised (including `INTERVAL` and geometric types like `POINT`)
/// maps to `"String"`.
///
/// All integer widths, including `SERIAL`/`BIGSERIAL`, map to `i64`. All
/// floating and exact numerics map to `f64`.
pub fn sql_type_to_rust(sql_type: &str) -> &'static str {
    /// True if `haystack` contains any of `needles`.
    fn contains_any(haystack: &str, needles: &[&str]) -> bool {
        needles.iter().any(|&needle| haystack.contains(needle))
    }

    // Ordered (needles, rust type) rules applied after the integer check.
    const RULES: &[(&[&str], &str)] = &[
        (&["FLOAT", "DOUBLE", "DECIMAL", "NUMERIC", "REAL"], "f64"),
        (&["BOOL"], "bool"),
        (&["UUID"], "uuid::Uuid"),
        (&["TEXT", "VARCHAR", "CHAR", "NAME"], "String"),
        (&["JSON"], "serde_json::Value"),
        (&["TIMESTAMP", "DATE", "TIME"], "chrono::DateTime<chrono::Utc>"),
        (&["BYTEA", "BLOB"], "Vec<u8>"),
    ];

    let upper = sql_type.to_uppercase();

    // `INTERVAL` and `POINT` (and `MULTIPOINT`, …) contain the substring "INT"
    // but are not integers, so they must not hit the integer rule.
    let looks_like_int = contains_any(&upper, &["INT", "SERIAL"]);
    let false_int = contains_any(&upper, &["INTERVAL", "POINT"]);
    if looks_like_int && !false_int {
        return "i64";
    }

    RULES
        .iter()
        .find(|(needles, _)| contains_any(&upper, needles))
        .map_or("String", |&(_, rust_type)| rust_type)
}
193
/// Converts a `snake_case` name to `PascalCase` by upper-casing the first
/// character of every `_`-separated segment and concatenating the segments.
///
/// Empty segments (from leading, trailing or doubled underscores) contribute
/// nothing, so `"order__items"` becomes `"OrderItems"`.
fn to_pascal_case(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for segment in s.split('_') {
        let mut rest = segment.chars();
        if let Some(first) = rest.next() {
            out.extend(first.to_uppercase());
            out.push_str(rest.as_str());
        }
    }
    out
}
206
/// Turns a table or column name into a usable Rust identifier.
///
/// Escaping rules:
/// - Strict and reserved Rust keywords become raw identifiers (`type` -> `r#type`).
/// - `self`, `Self`, `super` and `crate` cannot be written as raw identifiers
///   (rustc rejects `r#self`), so they get a trailing underscore instead
///   (`self` -> `self_`).
/// - Every other name is returned unchanged.
fn escape_keyword(name: &str) -> String {
    // Path keywords that are not permitted as raw identifiers.
    const NON_RAW: &[&str] = &["self", "Self", "super", "crate"];
    // Keywords that are valid once prefixed with `r#`.
    const KEYWORDS: &[&str] = &[
        "as", "break", "const", "continue", "else", "enum", "extern",
        "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod",
        "move", "mut", "pub", "ref", "return", "static", "struct",
        "trait", "true", "type", "unsafe", "use", "where", "while",
        "async", "await", "dyn", "abstract", "become", "box", "do", "final",
        "macro", "override", "priv", "try", "typeof", "unsized", "virtual", "yield",
        // Reserved starting with the 2024 edition.
        "gen",
    ];

    if NON_RAW.contains(&name) {
        format!("{}_", name)
    } else if KEYWORDS.contains(&name) {
        format!("r#{}", name)
    } else {
        name.to_string()
    }
}
224
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pascal_case() {
        assert_eq!(to_pascal_case("users"), "Users");
        assert_eq!(to_pascal_case("user_profiles"), "UserProfiles");
        // Doubled / edge underscores produce empty segments that are dropped.
        assert_eq!(to_pascal_case("order__items"), "OrderItems");
        assert_eq!(to_pascal_case(""), "");
    }

    #[test]
    fn test_sql_type_mapping() {
        assert_eq!(sql_type_to_rust("INT"), "i64");
        assert_eq!(sql_type_to_rust("TEXT"), "String");
        assert_eq!(sql_type_to_rust("UUID"), "uuid::Uuid");
        assert_eq!(sql_type_to_rust("BOOLEAN"), "bool");
        assert_eq!(sql_type_to_rust("JSONB"), "serde_json::Value");
    }

    #[test]
    fn test_sql_type_mapping_case_insensitive_and_parameterised() {
        assert_eq!(sql_type_to_rust("bigint"), "i64");
        assert_eq!(sql_type_to_rust("BIGSERIAL"), "i64");
        assert_eq!(sql_type_to_rust("double precision"), "f64");
        assert_eq!(sql_type_to_rust("varchar(255)"), "String");
        assert_eq!(sql_type_to_rust("timestamptz"), "chrono::DateTime<chrono::Utc>");
        assert_eq!(sql_type_to_rust("BYTEA"), "Vec<u8>");
        // Unknown types fall back to String.
        assert_eq!(sql_type_to_rust("money"), "String");
    }

    #[test]
    fn test_escape_keyword() {
        assert_eq!(escape_keyword("type"), "r#type");
        assert_eq!(escape_keyword("match"), "r#match");
        assert_eq!(escape_keyword("email"), "email");
    }
}