1use std::collections::{HashMap, HashSet};
6use std::fs;
7
8use crate::migrate::types::ColumnType;
9
10use super::schema::Schema;
11
12fn qail_type_to_rust(col_type: &ColumnType) -> &'static str {
13 match col_type {
14 ColumnType::Uuid => "uuid::Uuid",
15 ColumnType::Text | ColumnType::Varchar(_) => "String",
16 ColumnType::Int | ColumnType::Serial => "i32",
17 ColumnType::BigInt | ColumnType::BigSerial => "i64",
18 ColumnType::Bool => "bool",
19 ColumnType::Float => "f32",
20 ColumnType::Decimal(_) => "rust_decimal::Decimal",
21 ColumnType::Jsonb => "serde_json::Value",
22 ColumnType::Timestamp | ColumnType::Timestamptz => "chrono::DateTime<chrono::Utc>",
23 ColumnType::Date => "chrono::NaiveDate",
24 ColumnType::Time => "chrono::NaiveTime",
25 ColumnType::Bytea => "Vec<u8>",
26 ColumnType::Array(_) => "Vec<serde_json::Value>",
27 ColumnType::Enum { .. } => "String",
28 ColumnType::Range(_) => "String",
29 ColumnType::Interval => "String",
30 ColumnType::Cidr | ColumnType::Inet => "String",
31 ColumnType::MacAddr => "String",
32 }
33}
34
35fn to_rust_ident(name: &str) -> String {
37 escape_keyword(&sanitize_rust_ident(name))
38}
39
40fn to_struct_name(name: &str) -> String {
42 let mut out = String::new();
43 for part in name
44 .split(|c: char| !c.is_ascii_alphanumeric())
45 .filter(|part| !part.is_empty())
46 {
47 let mut chars = part.chars();
48 if let Some(first) = chars.next() {
49 out.extend(first.to_uppercase());
50 out.push_str(chars.as_str());
51 }
52 }
53
54 if out.is_empty() {
55 out.push_str("QailGenerated");
56 }
57 if out
58 .chars()
59 .next()
60 .is_none_or(|c| !c.is_ascii_alphabetic() && c != '_')
61 {
62 out.insert_str(0, "Qail");
63 }
64 if is_rust_keyword(&out) {
65 out.insert_str(0, "Qail");
66 }
67 out
68}
69
/// Rewrites `name` so that it is lexically a Rust identifier.
///
/// * Every character other than an ASCII letter, ASCII digit or `_` is
///   replaced by `_`.
/// * If the result is empty or does not start with an ASCII letter or `_`
///   (for example a leading digit), a `_` is put in front.
/// * A result of exactly `_` is widened to `__`: a lone underscore is the
///   wildcard token, not an identifier, so it cannot name a module.
///
/// Keywords are not handled here.
fn sanitize_rust_ident(name: &str) -> String {
    let mut ident: String = name
        .chars()
        .map(|c| if c.is_ascii_alphanumeric() || c == '_' { c } else { '_' })
        .collect();

    // `starts_with` is false for an empty string, so this also covers "".
    if !ident.starts_with(|c: char| c.is_ascii_alphabetic() || c == '_') {
        ident.insert(0, '_');
    }

    if ident == "_" {
        ident.push('_');
    }

    ident
}
95
/// Returns a spelling of `name` that can be used as an identifier even when
/// `name` is a Rust keyword.
///
/// * `crate`, `self`, `Self` and `super` get a trailing `_`, because the
///   compiler rejects them as raw identifiers (`r#self` does not compile).
/// * Any other keyword is returned in raw-identifier form, `r#name`.
/// * Everything else is returned unchanged.
fn escape_keyword(name: &str) -> String {
    match name {
        "crate" | "self" | "Self" | "super" => format!("{name}_"),
        _ if is_rust_keyword(name) => format!("r#{name}"),
        _ => name.to_string(),
    }
}

/// Reports whether `name` is one of the strict or reserved Rust keywords
/// listed below. The comparison is case-sensitive.
fn is_rust_keyword(name: &str) -> bool {
    const KEYWORDS: &[&str] = &[
        "as", "break", "const", "continue", "crate", "else", "enum", "extern", "false", "fn",
        "for", "if", "impl", "in", "let", "loop", "match", "mod", "move", "mut", "pub", "ref",
        "return", "self", "Self", "static", "struct", "super", "trait", "true", "type", "unsafe",
        "use", "where", "while", "async", "await", "dyn", "abstract", "become", "box", "do",
        "final", "macro", "override", "priv", "try", "typeof", "unsized", "virtual", "yield",
        // Reserved starting with the 2024 edition; `r#gen` is accepted by
        // earlier editions too and names the same item as `gen`.
        "gen",
    ];

    KEYWORDS.contains(&name)
}
115
/// Renders `value` as quoted, escaped string-literal text by formatting it
/// with `str`'s `Debug` implementation.
fn rust_string_literal(value: &str) -> String {
    format!("{:?}", value)
}
119
120pub fn generate_typed_schema(schema_path: &str, output_path: &str) -> Result<(), String> {
136 let schema = Schema::parse_file(schema_path)?;
137 let code = generate_schema_code(&schema);
138
139 fs::write(output_path, code)
140 .map_err(|e| format!("Failed to write schema module to '{}': {}", output_path, e))?;
141
142 Ok(())
143}
144
145pub fn generate_schema_code(schema: &Schema) -> String {
147 let mut code = String::new();
148
149 code.push_str("//! Auto-generated typed schema from schema.qail\n");
151 code.push_str("//! Do not edit manually - regenerate with `cargo build`\n\n");
152 code.push_str("#![allow(dead_code, non_upper_case_globals)]\n\n");
153 code.push_str("use qail_core::typed::{Table, TypedColumn, RelatedTo, Public, Protected};\n\n");
154
155 let mut tables: Vec<_> = schema.tables.values().collect();
157 tables.sort_by(|a, b| a.name.cmp(&b.name));
158
159 for table in &tables {
160 let mod_name = to_rust_ident(&table.name);
161 let struct_name = to_struct_name(&table.name);
162
163 code.push_str(&format!("/// Typed schema for `{}` table\n", table.name));
164 code.push_str(&format!("pub mod {} {{\n", mod_name));
165 code.push_str(" use super::*;\n\n");
166
167 code.push_str(&format!(" /// Table marker for `{}`\n", table.name));
169 code.push_str(" #[derive(Debug, Clone, Copy)]\n");
170 code.push_str(&format!(" pub struct {};\n\n", struct_name));
171
172 code.push_str(&format!(" impl Table for {} {{\n", struct_name));
173 code.push_str(&format!(
174 " fn table_name() -> &'static str {{ {} }}\n",
175 rust_string_literal(&table.name)
176 ));
177 code.push_str(" }\n\n");
178
179 code.push_str(&format!(" impl From<{}> for String {{\n", struct_name));
180 code.push_str(&format!(
181 " fn from(_: {}) -> String {{ {}.to_string() }}\n",
182 struct_name,
183 rust_string_literal(&table.name)
184 ));
185 code.push_str(" }\n\n");
186
187 code.push_str(&format!(" impl AsRef<str> for {} {{\n", struct_name));
188 code.push_str(&format!(
189 " fn as_ref(&self) -> &str {{ {} }}\n",
190 rust_string_literal(&table.name)
191 ));
192 code.push_str(" }\n\n");
193
194 code.push_str(&format!(" /// The `{}` table\n", table.name));
196 code.push_str(&format!(
197 " pub const table: {} = {};\n\n",
198 struct_name, struct_name
199 ));
200
201 let mut columns: Vec<_> = table.columns.iter().collect();
203 columns.sort_by(|a, b| a.0.cmp(b.0));
204
205 for (col_name, col_type) in columns {
207 let rust_type = qail_type_to_rust(col_type);
208 let col_ident = to_rust_ident(col_name);
209 let policy = table
210 .policies
211 .get(col_name)
212 .map(|s| s.as_str())
213 .unwrap_or("Public");
214 let rust_policy = if policy == "Protected" {
215 "Protected"
216 } else {
217 "Public"
218 };
219
220 code.push_str(&format!(
221 " /// Column `{}.{}` ({}) - {}\n",
222 table.name,
223 col_name,
224 col_type.to_pg_type(),
225 policy
226 ));
227 code.push_str(&format!(
228 " pub const {}: TypedColumn<{}, {}> = TypedColumn::new({}, {});\n",
229 col_ident,
230 rust_type,
231 rust_policy,
232 rust_string_literal(&table.name),
233 rust_string_literal(col_name)
234 ));
235 }
236
237 code.push_str("}\n\n");
238 }
239
240 code.push_str(
245 "// =============================================================================\n",
246 );
247 code.push_str("// Compile-Time Relationship Safety (RelatedTo impls)\n");
248 code.push_str(
249 "// =============================================================================\n\n",
250 );
251
252 let table_names: HashSet<&str> = tables.iter().map(|table| table.name.as_str()).collect();
253 let mut relation_impl_counts: HashMap<(&str, &str), usize> = HashMap::new();
254 for table in &tables {
255 for fk in &table.foreign_keys {
256 if !table_names.contains(fk.ref_table.as_str()) {
257 continue;
258 }
259 *relation_impl_counts
260 .entry((table.name.as_str(), fk.ref_table.as_str()))
261 .or_default() += 1;
262 *relation_impl_counts
263 .entry((fk.ref_table.as_str(), table.name.as_str()))
264 .or_default() += 1;
265 }
266 }
267
268 for table in &tables {
269 for fk in &table.foreign_keys {
270 if !table_names.contains(fk.ref_table.as_str()) {
271 continue;
272 }
273 let from_mod = to_rust_ident(&table.name);
278 let from_struct = to_struct_name(&table.name);
279 let to_mod = to_rust_ident(&fk.ref_table);
280 let to_struct = to_struct_name(&fk.ref_table);
281
282 if relation_impl_counts
285 .get(&(table.name.as_str(), fk.ref_table.as_str()))
286 .copied()
287 .unwrap_or_default()
288 == 1
289 {
290 code.push_str(&format!(
291 "/// {} has a foreign key to {} via {}.{}\n",
292 table.name, fk.ref_table, table.name, fk.column
293 ));
294 code.push_str(&format!(
295 "impl RelatedTo<{}::{}> for {}::{} {{\n",
296 to_mod, to_struct, from_mod, from_struct
297 ));
298 code.push_str(&format!(
299 " fn join_columns() -> (&'static str, &'static str) {{ ({}, {}) }}\n",
300 rust_string_literal(&fk.column),
301 rust_string_literal(&fk.ref_column)
302 ));
303 code.push_str("}\n\n");
304 }
305
306 if relation_impl_counts
310 .get(&(fk.ref_table.as_str(), table.name.as_str()))
311 .copied()
312 .unwrap_or_default()
313 == 1
314 {
315 code.push_str(&format!(
316 "/// {} is referenced by {} via {}.{}\n",
317 fk.ref_table, table.name, table.name, fk.column
318 ));
319 code.push_str(&format!(
320 "impl RelatedTo<{}::{}> for {}::{} {{\n",
321 from_mod, from_struct, to_mod, to_struct
322 ));
323 code.push_str(&format!(
324 " fn join_columns() -> (&'static str, &'static str) {{ ({}, {}) }}\n",
325 rust_string_literal(&fk.ref_column),
326 rust_string_literal(&fk.column)
327 ));
328 code.push_str("}\n\n");
329 }
330 }
331 }
332
333 code
334}
335
// Tests for the text produced by `generate_schema_code`. Each test parses an
// inline schema with `Schema::parse` and checks for (or against) specific
// snippets in the generated source.
#[cfg(test)]
mod codegen_tests {
    use super::*;

    /// Two tables joined by a single foreign key: modules, marker structs,
    /// column constants and both `RelatedTo` directions must be present.
    #[test]
    fn test_generate_schema_code() {
        let schema_content = r#"
table users {
 id UUID primary_key
 email TEXT not_null
 age INT
}

table posts {
 id UUID primary_key
 user_id UUID ref:users.id
 title TEXT
}
"#;

        let schema = Schema::parse(schema_content).unwrap();
        let code = generate_schema_code(&schema);

        // One module per table.
        assert!(code.contains("pub mod users {"));
        assert!(code.contains("pub mod posts {"));

        // Marker structs use the capitalised table name.
        assert!(code.contains("pub struct Users;"));
        assert!(code.contains("pub struct Posts;"));

        // Column constants carry the mapped Rust type and the default policy.
        assert!(code.contains("pub const id: TypedColumn<uuid::Uuid, Public>"));
        assert!(code.contains("pub const email: TypedColumn<String, Public>"));
        assert!(code.contains("pub const age: TypedColumn<i32, Public>"));

        // The single foreign key yields an impl in each direction.
        assert!(code.contains("impl RelatedTo<users::Users> for posts::Posts"));
        assert!(code.contains("impl RelatedTo<posts::Posts> for users::Users"));
    }

    /// A column marked `protected` in the schema is typed with `Protected`.
    #[test]
    fn test_generate_protected_column() {
        let schema_content = r#"
table secrets {
 id UUID primary_key
 token TEXT protected
}
"#;
        let schema = Schema::parse(schema_content).unwrap();
        let code = generate_schema_code(&schema);

        assert!(code.contains("pub const token: TypedColumn<String, Protected>"));
    }

    /// Two foreign keys between the same pair of tables: the columns are
    /// still generated, but no `RelatedTo` impl is emitted in either
    /// direction.
    #[test]
    fn test_generate_schema_code_skips_ambiguous_related_to_impls() {
        let schema_content = r#"
table users {
 id UUID primary_key
}

table invoices {
 id UUID primary_key
 buyer_id UUID ref:users.id
 seller_id UUID ref:users.id
}
"#;

        let schema = Schema::parse(schema_content).unwrap();
        let code = generate_schema_code(&schema);

        assert!(code.contains("pub const buyer_id: TypedColumn<uuid::Uuid, Public>"));
        assert!(code.contains("pub const seller_id: TypedColumn<uuid::Uuid, Public>"));
        assert!(!code.contains("impl RelatedTo<users::Users> for invoices::Invoices"));
        assert!(!code.contains("impl RelatedTo<invoices::Invoices> for users::Users"));
    }

    /// A foreign key to a table that is not defined in the schema must not
    /// produce `RelatedTo` impls, while the owning table is still generated.
    #[test]
    fn test_generate_schema_code_skips_missing_target_related_to_impls() {
        let schema_content = r#"
table posts {
 id UUID primary_key
 user_id UUID ref:users.id
}
"#;

        let schema = Schema::parse(schema_content).unwrap();
        let code = generate_schema_code(&schema);

        assert!(code.contains("pub mod posts {"));
        assert!(!code.contains("impl RelatedTo<users::Users> for posts::Posts"));
        assert!(!code.contains("impl RelatedTo<posts::Posts> for users::Users"));
    }

    /// Names that are Rust keywords or start with a digit are rewritten for
    /// the identifiers, while the string arguments keep the original names.
    #[test]
    fn test_generate_schema_code_sanitizes_rust_identifiers() {
        let schema_content = r#"
table type {
 1st TEXT
 match TEXT
}
"#;
        let schema = Schema::parse(schema_content).unwrap();
        let code = generate_schema_code(&schema);

        assert!(code.contains("pub mod r#type {"));
        assert!(code.contains("pub struct Type;"));
        assert!(code.contains("pub const _1st: TypedColumn<String, Public>"));
        assert!(code.contains("pub const r#match: TypedColumn<String, Public>"));
        // Table and column names passed to `TypedColumn::new` are not rewritten.
        assert!(code.contains("TypedColumn::new(\"type\", \"1st\")"));
    }
}
450
// Tests for `Schema::parse_sql_migration`: each one feeds a `CREATE TABLE`
// statement into a default `Schema` and inspects the resulting column map.
#[cfg(test)]
mod migration_parser_tests {
    use super::*;

    /// Every column of a table that mixes defaults, inline `REFERENCES`,
    /// an inline `CHECK` and a trailing `UNIQUE(..)` must be picked up.
    #[test]
    fn test_agent_contracts_migration_parses_all_columns() {
        let sql = r#"
CREATE TABLE agent_contracts (
 id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
 agent_id UUID NOT NULL REFERENCES agents(id) ON DELETE CASCADE,
 operator_id UUID NOT NULL REFERENCES operators(id) ON DELETE CASCADE,
 pricing_model VARCHAR(20) NOT NULL CHECK (pricing_model IN ('commission', 'static_markup', 'net_rate')),
 commission_percent DECIMAL(5,2),
 static_markup DECIMAL(10,2),
 is_active BOOLEAN DEFAULT true,
 valid_from DATE,
 valid_until DATE,
 approved_by UUID REFERENCES users(id),
 created_at TIMESTAMPTZ DEFAULT NOW() NOT NULL,
 updated_at TIMESTAMPTZ DEFAULT NOW() NOT NULL,
 UNIQUE(agent_id, operator_id)
);
"#;

        let mut schema = Schema::default();
        schema.parse_sql_migration(sql);

        let table = schema
            .tables
            .get("agent_contracts")
            .expect("agent_contracts table should exist");

        // The failure message lists the columns that were found, to make a
        // missing one easy to spot.
        for col in &[
            "id",
            "agent_id",
            "operator_id",
            "pricing_model",
            "commission_percent",
            "static_markup",
            "is_active",
            "valid_from",
            "valid_until",
            "approved_by",
            "created_at",
            "updated_at",
        ] {
            assert!(
                table.columns.contains_key(*col),
                "Missing column: '{}'. Found: {:?}",
                col,
                table.columns.keys().collect::<Vec<_>>()
            );
        }
    }

    /// Column names that merely begin with a SQL constraint keyword
    /// (`primary_contact`, `check_status`, ...) must be kept, while the
    /// table-level constraint lines themselves must not become columns.
    #[test]
    fn test_keyword_prefixed_column_names_are_not_skipped() {
        let sql = r#"
CREATE TABLE edge_cases (
 id UUID PRIMARY KEY,
 created_at TIMESTAMPTZ NOT NULL,
 created_by UUID,
 primary_contact VARCHAR(255),
 check_status VARCHAR(20),
 unique_code VARCHAR(50),
 foreign_ref UUID,
 constraint_name VARCHAR(100),
 PRIMARY KEY (id),
 CHECK (check_status IN ('pending', 'active')),
 UNIQUE (unique_code),
 CONSTRAINT fk_ref FOREIGN KEY (foreign_ref) REFERENCES other(id)
);
"#;

        let mut schema = Schema::default();
        schema.parse_sql_migration(sql);

        let table = schema
            .tables
            .get("edge_cases")
            .expect("edge_cases table should exist");

        // `id` is not in this list; only the other columns are checked here.
        for col in &[
            "created_at",
            "created_by",
            "primary_contact",
            "check_status",
            "unique_code",
            "foreign_ref",
            "constraint_name",
        ] {
            assert!(
                table.columns.contains_key(*col),
                "Column '{}' should NOT be skipped just because it starts with a SQL keyword. Found: {:?}",
                col,
                table.columns.keys().collect::<Vec<_>>()
            );
        }

        // The table-level `PRIMARY KEY (id)` line must not yield a column.
        assert!(
            !table.columns.contains_key("primary"),
            "Constraint keyword 'PRIMARY' should not be treated as a column"
        );
    }
}