1use std::collections::{HashMap, HashSet};
6use std::fs;
7
8use crate::migrate::types::ColumnType;
9
10use super::schema::Schema;
11
12fn qail_type_to_rust(col_type: &ColumnType) -> &'static str {
13 match col_type {
14 ColumnType::Uuid => "uuid::Uuid",
15 ColumnType::Text | ColumnType::Varchar(_) => "String",
16 ColumnType::Int | ColumnType::Serial => "i32",
17 ColumnType::BigInt | ColumnType::BigSerial => "i64",
18 ColumnType::Bool => "bool",
19 ColumnType::Float => "f64",
20 ColumnType::Decimal(_) => "rust_decimal::Decimal",
21 ColumnType::Jsonb => "serde_json::Value",
22 ColumnType::Timestamp => "chrono::NaiveDateTime",
23 ColumnType::Timestamptz => "chrono::DateTime<chrono::Utc>",
24 ColumnType::Date => "chrono::NaiveDate",
25 ColumnType::Time => "chrono::NaiveTime",
26 ColumnType::Bytea => "Vec<u8>",
27 ColumnType::Array(_) => "Vec<serde_json::Value>",
28 ColumnType::Enum { .. } => "String",
29 ColumnType::Range(_) => "String",
30 ColumnType::Interval => "String",
31 ColumnType::Cidr | ColumnType::Inet => "String",
32 ColumnType::MacAddr => "String",
33 }
34}
35
36fn to_rust_ident(name: &str) -> String {
38 escape_keyword(&sanitize_rust_ident(name))
39}
40
41fn to_struct_name(name: &str) -> String {
43 let mut out = String::new();
44 for part in name
45 .split(|c: char| !c.is_ascii_alphanumeric())
46 .filter(|part| !part.is_empty())
47 {
48 let mut chars = part.chars();
49 if let Some(first) = chars.next() {
50 out.extend(first.to_uppercase());
51 out.push_str(chars.as_str());
52 }
53 }
54
55 if out.is_empty() {
56 out.push_str("QailGenerated");
57 }
58 if out
59 .chars()
60 .next()
61 .is_none_or(|c| !c.is_ascii_alphabetic() && c != '_')
62 {
63 out.insert_str(0, "Qail");
64 }
65 if is_rust_keyword(&out) {
66 out.insert_str(0, "Qail");
67 }
68 out
69}
70
/// Replace every character that is not ASCII alphanumeric or `_` with `_`,
/// and prefix `_` when the result is empty or does not begin with an ASCII
/// letter or underscore (e.g. `1st` -> `_1st`, `""` -> `_`).
fn sanitize_rust_ident(name: &str) -> String {
    let mut sanitized = String::with_capacity(name.len() + 1);
    for ch in name.chars() {
        let keep = ch.is_ascii_alphanumeric() || ch == '_';
        sanitized.push(if keep { ch } else { '_' });
    }

    let needs_prefix = match sanitized.chars().next() {
        Some(first) => !(first.is_ascii_alphabetic() || first == '_'),
        None => true,
    };
    if needs_prefix {
        sanitized.insert(0, '_');
    }
    sanitized
}
96
97fn escape_keyword(name: &str) -> String {
98 if is_rust_keyword(name) {
99 format!("r#{}", name)
100 } else {
101 name.to_string()
102 }
103}
104
/// Return `true` if `name` is a strict or reserved Rust keyword in any
/// current edition, i.e. it cannot be used as a plain identifier.
///
/// Weak/contextual keywords (`union`, `macro_rules`, `raw`, `safe`, ...)
/// are valid identifiers and are deliberately not listed.
fn is_rust_keyword(name: &str) -> bool {
    matches!(
        name,
        // Strict keywords (all editions).
        "as" | "break"
            | "const"
            | "continue"
            | "crate"
            | "else"
            | "enum"
            | "extern"
            | "false"
            | "fn"
            | "for"
            | "if"
            | "impl"
            | "in"
            | "let"
            | "loop"
            | "match"
            | "mod"
            | "move"
            | "mut"
            | "pub"
            | "ref"
            | "return"
            | "self"
            | "Self"
            | "static"
            | "struct"
            | "super"
            | "trait"
            | "true"
            | "type"
            | "unsafe"
            | "use"
            | "where"
            | "while"
            // Strict keywords since the 2018 edition.
            | "async"
            | "await"
            | "dyn"
            // Reserved for future use.
            | "abstract"
            | "become"
            | "box"
            | "do"
            | "final"
            | "macro"
            | "override"
            | "priv"
            | "try"
            | "typeof"
            | "unsized"
            | "virtual"
            | "yield"
            // Reserved since the 2024 edition; `r#gen` is accepted in every edition.
            | "gen"
    )
}
116
/// Render `value` as a quoted Rust string literal that evaluates to exactly
/// `value`. This relies on `str`'s `Debug` output, which escapes quotes,
/// backslashes and control characters using Rust literal escape syntax.
fn rust_string_literal(value: &str) -> String {
    format!("{:?}", value)
}
120
121pub fn generate_typed_schema(schema_path: &str, output_path: &str) -> Result<(), String> {
137 let schema = Schema::parse_file(schema_path)?;
138 let code = generate_schema_code(&schema);
139
140 fs::write(output_path, code)
141 .map_err(|e| format!("Failed to write schema module to '{}': {}", output_path, e))?;
142
143 Ok(())
144}
145
146pub fn generate_schema_code(schema: &Schema) -> String {
148 let mut code = String::new();
149
150 code.push_str("//! Auto-generated typed schema from schema.qail\n");
152 code.push_str("//! Do not edit manually - regenerate with `cargo build`\n\n");
153 code.push_str("#![allow(dead_code, non_upper_case_globals)]\n\n");
154 code.push_str("use qail_core::typed::{Table, TypedColumn, RelatedTo, Public, Protected};\n\n");
155
156 let mut tables: Vec<_> = schema.tables.values().collect();
158 tables.sort_by(|a, b| a.name.cmp(&b.name));
159
160 for table in &tables {
161 let mod_name = to_rust_ident(&table.name);
162 let struct_name = to_struct_name(&table.name);
163
164 code.push_str(&format!("/// Typed schema for `{}` table\n", table.name));
165 code.push_str(&format!("pub mod {} {{\n", mod_name));
166 code.push_str(" use super::*;\n\n");
167
168 code.push_str(&format!(" /// Table marker for `{}`\n", table.name));
170 code.push_str(" #[derive(Debug, Clone, Copy)]\n");
171 code.push_str(&format!(" pub struct {};\n\n", struct_name));
172
173 code.push_str(&format!(" impl Table for {} {{\n", struct_name));
174 code.push_str(&format!(
175 " fn table_name() -> &'static str {{ {} }}\n",
176 rust_string_literal(&table.name)
177 ));
178 code.push_str(" }\n\n");
179
180 code.push_str(&format!(" impl From<{}> for String {{\n", struct_name));
181 code.push_str(&format!(
182 " fn from(_: {}) -> String {{ {}.to_string() }}\n",
183 struct_name,
184 rust_string_literal(&table.name)
185 ));
186 code.push_str(" }\n\n");
187
188 code.push_str(&format!(" impl AsRef<str> for {} {{\n", struct_name));
189 code.push_str(&format!(
190 " fn as_ref(&self) -> &str {{ {} }}\n",
191 rust_string_literal(&table.name)
192 ));
193 code.push_str(" }\n\n");
194
195 code.push_str(&format!(" /// The `{}` table\n", table.name));
197 code.push_str(&format!(
198 " pub const table: {} = {};\n\n",
199 struct_name, struct_name
200 ));
201
202 let mut columns: Vec<_> = table.columns.iter().collect();
204 columns.sort_by(|a, b| a.0.cmp(b.0));
205
206 for (col_name, col_type) in columns {
208 let rust_type = qail_type_to_rust(col_type);
209 let col_ident = to_rust_ident(col_name);
210 let policy = table
211 .policies
212 .get(col_name)
213 .map(|s| s.as_str())
214 .unwrap_or("Public");
215 let rust_policy = if policy == "Protected" {
216 "Protected"
217 } else {
218 "Public"
219 };
220
221 code.push_str(&format!(
222 " /// Column `{}.{}` ({}) - {}\n",
223 table.name,
224 col_name,
225 col_type.to_pg_type(),
226 policy
227 ));
228 code.push_str(&format!(
229 " pub const {}: TypedColumn<{}, {}> = TypedColumn::new({}, {});\n",
230 col_ident,
231 rust_type,
232 rust_policy,
233 rust_string_literal(&table.name),
234 rust_string_literal(col_name)
235 ));
236 }
237
238 code.push_str("}\n\n");
239 }
240
241 code.push_str(
246 "// =============================================================================\n",
247 );
248 code.push_str("// Compile-Time Relationship Safety (RelatedTo impls)\n");
249 code.push_str(
250 "// =============================================================================\n\n",
251 );
252
253 let table_names: HashSet<&str> = tables.iter().map(|table| table.name.as_str()).collect();
254 let mut relation_impl_counts: HashMap<(&str, &str), usize> = HashMap::new();
255 for table in &tables {
256 for fk in &table.foreign_keys {
257 if !table_names.contains(fk.ref_table.as_str()) {
258 continue;
259 }
260 *relation_impl_counts
261 .entry((table.name.as_str(), fk.ref_table.as_str()))
262 .or_default() += 1;
263 *relation_impl_counts
264 .entry((fk.ref_table.as_str(), table.name.as_str()))
265 .or_default() += 1;
266 }
267 }
268
269 for table in &tables {
270 for fk in &table.foreign_keys {
271 if !table_names.contains(fk.ref_table.as_str()) {
272 continue;
273 }
274 let from_mod = to_rust_ident(&table.name);
279 let from_struct = to_struct_name(&table.name);
280 let to_mod = to_rust_ident(&fk.ref_table);
281 let to_struct = to_struct_name(&fk.ref_table);
282
283 if relation_impl_counts
286 .get(&(table.name.as_str(), fk.ref_table.as_str()))
287 .copied()
288 .unwrap_or_default()
289 == 1
290 {
291 code.push_str(&format!(
292 "/// {} has a foreign key to {} via {}.{}\n",
293 table.name, fk.ref_table, table.name, fk.column
294 ));
295 code.push_str(&format!(
296 "impl RelatedTo<{}::{}> for {}::{} {{\n",
297 to_mod, to_struct, from_mod, from_struct
298 ));
299 code.push_str(&format!(
300 " fn join_columns() -> (&'static str, &'static str) {{ ({}, {}) }}\n",
301 rust_string_literal(&fk.column),
302 rust_string_literal(&fk.ref_column)
303 ));
304 code.push_str("}\n\n");
305 }
306
307 if relation_impl_counts
311 .get(&(fk.ref_table.as_str(), table.name.as_str()))
312 .copied()
313 .unwrap_or_default()
314 == 1
315 {
316 code.push_str(&format!(
317 "/// {} is referenced by {} via {}.{}\n",
318 fk.ref_table, table.name, table.name, fk.column
319 ));
320 code.push_str(&format!(
321 "impl RelatedTo<{}::{}> for {}::{} {{\n",
322 from_mod, from_struct, to_mod, to_struct
323 ));
324 code.push_str(&format!(
325 " fn join_columns() -> (&'static str, &'static str) {{ ({}, {}) }}\n",
326 rust_string_literal(&fk.ref_column),
327 rust_string_literal(&fk.column)
328 ));
329 code.push_str("}\n\n");
330 }
331 }
332 }
333
334 code
335}
336
#[cfg(test)]
mod codegen_tests {
    use super::*;

    /// Parse `source` as schema.qail text and render its typed module.
    fn render(source: &str) -> String {
        let schema = Schema::parse(source).unwrap();
        generate_schema_code(&schema)
    }

    /// Assert that every fragment occurs somewhere in `code`.
    fn assert_contains_all(code: &str, fragments: &[&str]) {
        for &fragment in fragments {
            assert!(code.contains(fragment));
        }
    }

    /// Assert that no fragment occurs anywhere in `code`.
    fn assert_contains_none(code: &str, fragments: &[&str]) {
        for &fragment in fragments {
            assert!(!code.contains(fragment));
        }
    }

    #[test]
    fn test_generate_schema_code() {
        let code = render(
            r#"
table users {
 id UUID primary_key
 email TEXT not_null
 age INT
}

table posts {
 id UUID primary_key
 user_id UUID ref:users.id
 title TEXT
}
"#,
        );

        assert_contains_all(
            &code,
            &[
                // One module and one marker struct per table.
                "pub mod users {",
                "pub mod posts {",
                "pub struct Users;",
                "pub struct Posts;",
                // Column constants carry the mapped Rust type.
                "pub const id: TypedColumn<uuid::Uuid, Public>",
                "pub const email: TypedColumn<String, Public>",
                "pub const age: TypedColumn<i32, Public>",
                // A single FK yields impls in both directions.
                "impl RelatedTo<users::Users> for posts::Posts",
                "impl RelatedTo<posts::Posts> for users::Users",
            ],
        );
    }

    #[test]
    fn test_generate_protected_column() {
        let code = render(
            r#"
table secrets {
 id UUID primary_key
 token TEXT protected
}
"#,
        );

        assert_contains_all(&code, &["pub const token: TypedColumn<String, Protected>"]);
    }

    #[test]
    fn test_generate_float_column_uses_double_precision_rust_type() {
        let code = render(
            r#"
table metrics {
 score FLOAT
}
"#,
        );

        assert_contains_all(&code, &["pub const score: TypedColumn<f64, Public>"]);
    }

    #[test]
    fn test_generate_timestamp_columns_preserve_timezone_semantics() {
        let code = render(
            r#"
table events {
 happened_at TIMESTAMP
 recorded_at TIMESTAMPTZ
}
"#,
        );

        assert_contains_all(
            &code,
            &[
                "pub const happened_at: TypedColumn<chrono::NaiveDateTime, Public>",
                "pub const recorded_at: TypedColumn<chrono::DateTime<chrono::Utc>, Public>",
            ],
        );
    }

    #[test]
    fn test_generate_schema_code_skips_ambiguous_related_to_impls() {
        let code = render(
            r#"
table users {
 id UUID primary_key
}

table invoices {
 id UUID primary_key
 buyer_id UUID ref:users.id
 seller_id UUID ref:users.id
}
"#,
        );

        assert_contains_all(
            &code,
            &[
                "pub const buyer_id: TypedColumn<uuid::Uuid, Public>",
                "pub const seller_id: TypedColumn<uuid::Uuid, Public>",
            ],
        );
        // Two FKs to the same table make the join ambiguous: no impls.
        assert_contains_none(
            &code,
            &[
                "impl RelatedTo<users::Users> for invoices::Invoices",
                "impl RelatedTo<invoices::Invoices> for users::Users",
            ],
        );
    }

    #[test]
    fn test_generate_schema_code_skips_missing_target_related_to_impls() {
        let code = render(
            r#"
table posts {
 id UUID primary_key
 user_id UUID ref:users.id
}
"#,
        );

        assert_contains_all(&code, &["pub mod posts {"]);
        // `users` is not defined, so there is nothing to relate to.
        assert_contains_none(
            &code,
            &[
                "impl RelatedTo<users::Users> for posts::Posts",
                "impl RelatedTo<posts::Posts> for users::Users",
            ],
        );
    }

    #[test]
    fn test_generate_schema_code_sanitizes_rust_identifiers() {
        let code = render(
            r#"
table type {
 1st TEXT
 match TEXT
}
"#,
        );

        assert_contains_all(
            &code,
            &[
                "pub mod r#type {",
                "pub struct Type;",
                "pub const _1st: TypedColumn<String, Public>",
                "pub const r#match: TypedColumn<String, Public>",
                // The SQL names passed at runtime stay untouched.
                "TypedColumn::new(\"type\", \"1st\")",
            ],
        );
    }
}
483
#[cfg(test)]
mod migration_parser_tests {
    use super::*;

    /// Feed `sql` to the migration parser of a fresh, default `Schema`.
    fn parse_migration(sql: &str) -> Schema {
        let mut schema = Schema::default();
        schema.parse_sql_migration(sql);
        schema
    }

    #[test]
    fn test_agent_contracts_migration_parses_all_columns() {
        let schema = parse_migration(
            r#"
CREATE TABLE agent_contracts (
 id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
 agent_id UUID NOT NULL REFERENCES agents(id) ON DELETE CASCADE,
 operator_id UUID NOT NULL REFERENCES operators(id) ON DELETE CASCADE,
 pricing_model VARCHAR(20) NOT NULL CHECK (pricing_model IN ('commission', 'static_markup', 'net_rate')),
 commission_percent DECIMAL(5,2),
 static_markup DECIMAL(10,2),
 is_active BOOLEAN DEFAULT true,
 valid_from DATE,
 valid_until DATE,
 approved_by UUID REFERENCES users(id),
 created_at TIMESTAMPTZ DEFAULT NOW() NOT NULL,
 updated_at TIMESTAMPTZ DEFAULT NOW() NOT NULL,
 UNIQUE(agent_id, operator_id)
);
"#,
        );

        let table = schema
            .tables
            .get("agent_contracts")
            .expect("agent_contracts table should exist");

        let expected = [
            "id",
            "agent_id",
            "operator_id",
            "pricing_model",
            "commission_percent",
            "static_markup",
            "is_active",
            "valid_from",
            "valid_until",
            "approved_by",
            "created_at",
            "updated_at",
        ];
        for col in expected {
            assert!(
                table.columns.contains_key(col),
                "Missing column: '{}'. Found: {:?}",
                col,
                table.columns.keys().collect::<Vec<_>>()
            );
        }
    }

    /// Column names that merely start with a SQL keyword must be parsed as
    /// columns, while table-level constraint clauses must not be.
    #[test]
    fn test_keyword_prefixed_column_names_are_not_skipped() {
        let schema = parse_migration(
            r#"
CREATE TABLE edge_cases (
 id UUID PRIMARY KEY,
 created_at TIMESTAMPTZ NOT NULL,
 created_by UUID,
 primary_contact VARCHAR(255),
 check_status VARCHAR(20),
 unique_code VARCHAR(50),
 foreign_ref UUID,
 constraint_name VARCHAR(100),
 PRIMARY KEY (id),
 CHECK (check_status IN ('pending', 'active')),
 UNIQUE (unique_code),
 CONSTRAINT fk_ref FOREIGN KEY (foreign_ref) REFERENCES other(id)
);
"#,
        );

        let table = schema
            .tables
            .get("edge_cases")
            .expect("edge_cases table should exist");

        let keyword_prefixed = [
            "created_at",
            "created_by",
            "primary_contact",
            "check_status",
            "unique_code",
            "foreign_ref",
            "constraint_name",
        ];
        for col in keyword_prefixed {
            assert!(
                table.columns.contains_key(col),
                "Column '{}' should NOT be skipped just because it starts with a SQL keyword. Found: {:?}",
                col,
                table.columns.keys().collect::<Vec<_>>()
            );
        }

        assert!(
            !table.columns.contains_key("primary"),
            "Constraint keyword 'PRIMARY' should not be treated as a column"
        );
    }
}