1use crate::build::Schema;
24use std::fs;
25
26pub fn generate_to_file(schema_path: &str, output_path: &str) -> Result<(), String> {
28 let schema = Schema::parse_file(schema_path)?;
29 let code = generate_schema_code(&schema);
30 fs::write(output_path, &code)
31 .map_err(|e| format!("Failed to write output: {}", e))?;
32 Ok(())
33}
34
35pub fn generate_from_file(schema_path: &str) -> Result<String, String> {
37 let schema = Schema::parse_file(schema_path)?;
38 Ok(generate_schema_code(&schema))
39}
40
41pub fn generate_schema_code(schema: &Schema) -> String {
43 let mut code = String::new();
44
45 code.push_str("//! Auto-generated by `qail types`\n");
47 code.push_str("//! Do not edit manually.\n\n");
48 code.push_str("#![allow(dead_code)]\n\n");
49 code.push_str("use qail_core::typed::{Table, TypedColumn};\n\n");
50
51 let mut table_names: Vec<_> = schema.tables.keys().collect();
53 table_names.sort();
54
55 for table_name in &table_names {
56 if let Some(table) = schema.tables.get(*table_name) {
57 code.push_str(&generate_table_module(table_name, table));
58 code.push('\n');
59 }
60 }
61
62 code.push_str("/// Re-export all table types\n");
64 code.push_str("pub mod tables {\n");
65
66 for table_name in &table_names {
67 let struct_name = to_pascal_case(table_name);
68 code.push_str(&format!(
69 " pub use super::{}::{};\n",
70 table_name, struct_name
71 ));
72 }
73 code.push_str("}\n");
74
75 code
76}
77
78fn generate_table_module(table_name: &str, table: &crate::build::TableSchema) -> String {
80 let mut code = String::new();
81 let struct_name = to_pascal_case(table_name);
82
83 code.push_str(&format!("/// Table: {}\n", table_name));
84 code.push_str(&format!("pub mod {} {{\n", table_name));
85 code.push_str(" use super::*;\n\n");
86
87 code.push_str(&format!(" /// Type-safe reference to `{}`\n", table_name));
89 code.push_str(" #[derive(Debug, Clone, Copy, Default)]\n");
90 code.push_str(&format!(" pub struct {};\n\n", struct_name));
91
92 code.push_str(&format!(" impl Table for {} {{\n", struct_name));
93 code.push_str(&format!(
94 " fn table_name() -> &'static str {{ \"{}\" }}\n",
95 table_name
96 ));
97 code.push_str(" }\n\n");
98
99 code.push_str(&format!(" impl From<{}> for String {{\n", struct_name));
101 code.push_str(&format!(" fn from(_: {}) -> String {{ \"{}\".to_string() }}\n", struct_name, table_name));
102 code.push_str(" }\n\n");
103
104 let mut col_names: Vec<_> = table.columns.keys().collect();
106 col_names.sort();
107
108 for col_name in &col_names {
109 if let Some(col_type) = table.columns.get(*col_name) {
110 let rust_type = sql_type_to_rust(col_type);
111 let fn_name = escape_keyword(col_name);
112 code.push_str(&format!(
113 " /// Column `{}` ({})\n",
114 col_name, col_type
115 ));
116 code.push_str(&format!(
117 " pub fn {}() -> TypedColumn<{}> {{ TypedColumn::new(\"{}\", \"{}\") }}\n\n",
118 fn_name, rust_type, table_name, col_name
119 ));
120 }
121 }
122
123 code.push_str("}\n");
124
125 code
126}
127
/// Map a SQL column type to the Rust type used for its `TypedColumn`.
///
/// Matching is case-insensitive, and the first rule that applies wins:
/// - an integer word (`INT`, `INTEGER`, `INT4`, `INT64`, `SMALLINT`,
///   `BIGINT`, `UINT32`, ...) or anything containing `SERIAL` -> `i64`
/// - `FLOAT`, `DOUBLE`, `DECIMAL`, `NUMERIC`, `REAL` -> `f64`
/// - `BOOL` -> `bool`
/// - `UUID` -> `uuid::Uuid`
/// - `TEXT`, `VARCHAR`, `CHAR`, `NAME` -> `String`
/// - `JSON` -> `serde_json::Value`
/// - `TIMESTAMP`, `DATE`, `TIME` -> `chrono::DateTime<chrono::Utc>`
/// - `BYTEA`, `BLOB` -> `Vec<u8>`
/// - anything else -> `String`
///
/// Integers are recognised per word rather than by substring. A plain
/// `contains("INT")` also matches `POINT`, `INTERVAL` and `INT4RANGE`, which
/// are not integer columns.
pub fn sql_type_to_rust(sql_type: &str) -> &'static str {
    /// True when a single upper-cased word names an integer type: `INT`,
    /// `INTEGER`, `INT<digits>`, the `TINY`/`SMALL`/`MEDIUM`/`BIG` `INT`
    /// variants, or a `U`-prefixed unsigned spelling of any of these.
    fn is_integer_word(word: &str) -> bool {
        let word = word.strip_prefix('U').unwrap_or(word);
        match word.strip_prefix("INT") {
            // `rest` is empty for plain `INT`, digits for `INT4`/`INT64`.
            Some(rest) => rest == "EGER" || rest.bytes().all(|b| b.is_ascii_digit()),
            None => matches!(word, "TINYINT" | "SMALLINT" | "MEDIUMINT" | "BIGINT"),
        }
    }

    let upper = sql_type.to_uppercase();

    // Words are split on any non-alphanumeric char, so `INT(11)`, `INT8[]`
    // and `INT UNSIGNED` still match. `SERIAL` has no separators, so a
    // substring test is equivalent to a per-word one.
    if upper.contains("SERIAL")
        || upper
            .split(|c: char| !c.is_ascii_alphanumeric())
            .any(is_integer_word)
    {
        return "i64";
    }

    // NOTE: DECIMAL/NUMERIC are mapped to f64, which cannot represent every
    // exact decimal value.
    if upper.contains("FLOAT")
        || upper.contains("DOUBLE")
        || upper.contains("DECIMAL")
        || upper.contains("NUMERIC")
        || upper.contains("REAL")
    {
        return "f64";
    }

    if upper.contains("BOOL") {
        return "bool";
    }

    if upper.contains("UUID") {
        return "uuid::Uuid";
    }

    if upper.contains("TEXT")
        || upper.contains("VARCHAR")
        || upper.contains("CHAR")
        || upper.contains("NAME")
    {
        return "String";
    }

    if upper.contains("JSON") {
        return "serde_json::Value";
    }

    // NOTE(review): bare DATE / TIME columns also map to DateTime<Utc>;
    // confirm the database driver decodes them into that type.
    if upper.contains("TIMESTAMP") || upper.contains("DATE") || upper.contains("TIME") {
        return "chrono::DateTime<chrono::Utc>";
    }

    if upper.contains("BYTEA") || upper.contains("BLOB") {
        return "Vec<u8>";
    }

    "String"
}
180
/// Convert a `snake_case` name to `PascalCase` (`user_profiles` -> `UserProfiles`).
///
/// Each `_`-separated segment has its first character upper-cased, with the
/// rest of the segment kept as-is. Empty segments, produced by leading,
/// trailing or doubled underscores, contribute nothing to the result.
fn to_pascal_case(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for segment in s.split('_') {
        let mut rest = segment.chars();
        if let Some(first) = rest.next() {
            // `to_uppercase` may yield several chars (e.g. 'ß' -> "SS").
            out.extend(first.to_uppercase());
            out.push_str(rest.as_str());
        }
    }
    out
}
193
/// Make `name` usable as a Rust identifier in generated code.
///
/// - Reserved words become raw identifiers: `type` -> `r#type`.
/// - `self`, `Self`, `super` and `crate` cannot be raw identifiers (rustc
///   rejects `r#self`), so they get a trailing underscore instead:
///   `self` -> `self_`.
/// - Every other name is returned unchanged.
fn escape_keyword(name: &str) -> String {
    // Path keywords that are not permitted after `r#`.
    const NON_RAW_KEYWORDS: &[&str] = &["self", "Self", "super", "crate"];
    // Strict and reserved keywords that can be written as raw identifiers.
    // `gen` is reserved from edition 2024; `r#gen` is valid on older editions too.
    const KEYWORDS: &[&str] = &[
        "as", "break", "const", "continue", "else", "enum", "extern", "false", "fn", "for",
        "if", "impl", "in", "let", "loop", "match", "mod", "move", "mut", "pub", "ref",
        "return", "static", "struct", "trait", "true", "type", "unsafe", "use", "where",
        "while", "async", "await", "dyn", "abstract", "become", "box", "do", "final", "macro",
        "override", "priv", "try", "typeof", "unsized", "virtual", "yield", "gen",
    ];

    if NON_RAW_KEYWORDS.contains(&name) {
        format!("{}_", name)
    } else if KEYWORDS.contains(&name) {
        format!("r#{}", name)
    } else {
        name.to_string()
    }
}
211
#[cfg(test)]
mod tests {
    use super::*;

    /// snake_case table names become PascalCase struct names.
    #[test]
    fn test_pascal_case() {
        let cases = [("users", "Users"), ("user_profiles", "UserProfiles")];
        for (input, expected) in cases {
            assert_eq!(to_pascal_case(input), expected, "pascal case of {}", input);
        }
    }

    /// Representative SQL types map to the expected Rust types.
    #[test]
    fn test_sql_type_mapping() {
        let cases = [
            ("INT", "i64"),
            ("TEXT", "String"),
            ("UUID", "uuid::Uuid"),
            ("BOOLEAN", "bool"),
            ("JSONB", "serde_json::Value"),
        ];
        for (sql, rust) in cases {
            assert_eq!(sql_type_to_rust(sql), rust, "rust type for {}", sql);
        }
    }
}