Skip to main content

qail_core/build/
codegen.rs

1//! Typed schema code generation.
2//!
3//! Generates Rust modules from `schema.qail` for compile-time type safety.
4
5use std::collections::{HashMap, HashSet};
6use std::fs;
7
8use crate::migrate::types::ColumnType;
9
10use super::schema::Schema;
11
12fn qail_type_to_rust(col_type: &ColumnType) -> &'static str {
13    match col_type {
14        ColumnType::Uuid => "uuid::Uuid",
15        ColumnType::Text | ColumnType::Varchar(_) => "String",
16        ColumnType::Int | ColumnType::Serial => "i32",
17        ColumnType::BigInt | ColumnType::BigSerial => "i64",
18        ColumnType::Bool => "bool",
19        ColumnType::Float => "f64",
20        ColumnType::Decimal(_) => "rust_decimal::Decimal",
21        ColumnType::Jsonb => "serde_json::Value",
22        ColumnType::Timestamp => "chrono::NaiveDateTime",
23        ColumnType::Timestamptz => "chrono::DateTime<chrono::Utc>",
24        ColumnType::Date => "chrono::NaiveDate",
25        ColumnType::Time => "chrono::NaiveTime",
26        ColumnType::Bytea => "Vec<u8>",
27        ColumnType::Array(_) => "Vec<serde_json::Value>",
28        ColumnType::Enum { .. } => "String",
29        ColumnType::Range(_) => "String",
30        ColumnType::Interval => "String",
31        ColumnType::Cidr | ColumnType::Inet => "String",
32        ColumnType::MacAddr => "String",
33    }
34}
35
36/// Convert table/column names to valid Rust identifiers
37fn to_rust_ident(name: &str) -> String {
38    escape_keyword(&sanitize_rust_ident(name))
39}
40
41/// Convert table name to PascalCase struct name
42fn to_struct_name(name: &str) -> String {
43    let mut out = String::new();
44    for part in name
45        .split(|c: char| !c.is_ascii_alphanumeric())
46        .filter(|part| !part.is_empty())
47    {
48        let mut chars = part.chars();
49        if let Some(first) = chars.next() {
50            out.extend(first.to_uppercase());
51            out.push_str(chars.as_str());
52        }
53    }
54
55    if out.is_empty() {
56        out.push_str("QailGenerated");
57    }
58    if out
59        .chars()
60        .next()
61        .is_none_or(|c| !c.is_ascii_alphabetic() && c != '_')
62    {
63        out.insert_str(0, "Qail");
64    }
65    if is_rust_keyword(&out) {
66        out.insert_str(0, "Qail");
67    }
68    out
69}
70
/// Replace every character that cannot appear in a Rust identifier with `_`.
///
/// Only ASCII letters, digits and `_` are kept; every other `char` becomes a single
/// `_`. A leading digit gets a `_` prefix (`1st` -> `_1st`). A lone `_` is not a
/// usable identifier (`pub mod _;` does not compile), so a name that sanitizes to
/// nothing or to just `_` becomes `__`.
///
/// Keywords are not handled here; `escape_keyword` deals with them.
fn sanitize_rust_ident(name: &str) -> String {
    let mut ident: String = name
        .chars()
        .map(|c| {
            if c.is_ascii_alphanumeric() || c == '_' {
                c
            } else {
                '_'
            }
        })
        .collect();

    // Identifiers may not start with a digit.
    if ident.starts_with(|c: char| c.is_ascii_digit()) {
        ident.insert(0, '_');
    }
    // `_` alone is the wildcard, not an identifier.
    if ident.is_empty() || ident == "_" {
        ident = String::from("__");
    }

    ident
}
96
97fn escape_keyword(name: &str) -> String {
98    if is_rust_keyword(name) {
99        format!("r#{}", name)
100    } else {
101        name.to_string()
102    }
103}
104
/// Whether `name` is a Rust strict or reserved keyword in any edition up to 2024,
/// i.e. it cannot be emitted as a plain identifier in generated code.
///
/// The generated module is `include!`d into the consumer crate, whose edition is
/// unknown. The list therefore covers every edition's keywords; escaping a word
/// that is only reserved in newer editions (e.g. `r#gen`) is still valid in older ones.
fn is_rust_keyword(name: &str) -> bool {
    const KEYWORDS: &[&str] = &[
        // Strict keywords (including 2018+ `async`, `await`, `dyn`).
        "as", "break", "const", "continue", "crate", "else", "enum", "extern", "false", "fn",
        "for", "if", "impl", "in", "let", "loop", "match", "mod", "move", "mut", "pub", "ref",
        "return", "self", "Self", "static", "struct", "super", "trait", "true", "type", "unsafe",
        "use", "where", "while", "async", "await", "dyn",
        // Reserved keywords (`try` since 2018, `gen` since 2024).
        "abstract", "become", "box", "do", "final", "macro", "override", "priv", "try",
        "typeof", "unsized", "virtual", "yield", "gen",
    ];

    KEYWORDS.contains(&name)
}
116
/// Render `value` as a double-quoted Rust string literal.
///
/// Uses the `Debug` representation of `str`, which quotes the text and escapes
/// quotes, backslashes and control characters.
fn rust_string_literal(value: &str) -> String {
    format!("{:?}", value)
}
120
121/// Generate typed Rust module from schema.
122///
123/// # Usage in consumer's build.rs:
124/// ```ignore
125/// fn main() {
126///     let out_dir = std::env::var("OUT_DIR").unwrap();
127///     qail_core::build::generate_typed_schema("schema.qail", &format!("{}/schema.rs", out_dir)).unwrap();
128///     println!("cargo:rerun-if-changed=schema.qail");
129/// }
130/// ```
131///
132/// Then in the consumer's lib.rs:
133/// ```ignore
134/// include!(concat!(env!("OUT_DIR"), "/schema.rs"));
135/// ```
136pub fn generate_typed_schema(schema_path: &str, output_path: &str) -> Result<(), String> {
137    let schema = Schema::parse_file(schema_path)?;
138    let code = generate_schema_code(&schema);
139
140    fs::write(output_path, code)
141        .map_err(|e| format!("Failed to write schema module to '{}': {}", output_path, e))?;
142
143    Ok(())
144}
145
146/// Generate typed Rust code from schema (does not write to file)
147pub fn generate_schema_code(schema: &Schema) -> String {
148    let mut code = String::new();
149
150    // Header
151    code.push_str("//! Auto-generated typed schema from schema.qail\n");
152    code.push_str("//! Do not edit manually - regenerate with `cargo build`\n\n");
153    code.push_str("#![allow(dead_code, non_upper_case_globals)]\n\n");
154    code.push_str("use qail_core::typed::{Table, TypedColumn, RelatedTo, Public, Protected};\n\n");
155
156    // Sort tables for deterministic output
157    let mut tables: Vec<_> = schema.tables.values().collect();
158    tables.sort_by(|a, b| a.name.cmp(&b.name));
159
160    for table in &tables {
161        let mod_name = to_rust_ident(&table.name);
162        let struct_name = to_struct_name(&table.name);
163
164        code.push_str(&format!("/// Typed schema for `{}` table\n", table.name));
165        code.push_str(&format!("pub mod {} {{\n", mod_name));
166        code.push_str("    use super::*;\n\n");
167
168        // Table struct implementing Table trait
169        code.push_str(&format!("    /// Table marker for `{}`\n", table.name));
170        code.push_str("    #[derive(Debug, Clone, Copy)]\n");
171        code.push_str(&format!("    pub struct {};\n\n", struct_name));
172
173        code.push_str(&format!("    impl Table for {} {{\n", struct_name));
174        code.push_str(&format!(
175            "        fn table_name() -> &'static str {{ {} }}\n",
176            rust_string_literal(&table.name)
177        ));
178        code.push_str("    }\n\n");
179
180        code.push_str(&format!("    impl From<{}> for String {{\n", struct_name));
181        code.push_str(&format!(
182            "        fn from(_: {}) -> String {{ {}.to_string() }}\n",
183            struct_name,
184            rust_string_literal(&table.name)
185        ));
186        code.push_str("    }\n\n");
187
188        code.push_str(&format!("    impl AsRef<str> for {} {{\n", struct_name));
189        code.push_str(&format!(
190            "        fn as_ref(&self) -> &str {{ {} }}\n",
191            rust_string_literal(&table.name)
192        ));
193        code.push_str("    }\n\n");
194
195        // Table constant for convenience
196        code.push_str(&format!("    /// The `{}` table\n", table.name));
197        code.push_str(&format!(
198            "    pub const table: {} = {};\n\n",
199            struct_name, struct_name
200        ));
201
202        // Sort columns for deterministic output
203        let mut columns: Vec<_> = table.columns.iter().collect();
204        columns.sort_by(|a, b| a.0.cmp(b.0));
205
206        // Column constants
207        for (col_name, col_type) in columns {
208            let rust_type = qail_type_to_rust(col_type);
209            let col_ident = to_rust_ident(col_name);
210            let policy = table
211                .policies
212                .get(col_name)
213                .map(|s| s.as_str())
214                .unwrap_or("Public");
215            let rust_policy = if policy == "Protected" {
216                "Protected"
217            } else {
218                "Public"
219            };
220
221            code.push_str(&format!(
222                "    /// Column `{}.{}` ({}) - {}\n",
223                table.name,
224                col_name,
225                col_type.to_pg_type(),
226                policy
227            ));
228            code.push_str(&format!(
229                "    pub const {}: TypedColumn<{}, {}> = TypedColumn::new({}, {});\n",
230                col_ident,
231                rust_type,
232                rust_policy,
233                rust_string_literal(&table.name),
234                rust_string_literal(col_name)
235            ));
236        }
237
238        code.push_str("}\n\n");
239    }
240
241    // ==========================================================================
242    // Generate RelatedTo impls for compile-time relationship checking
243    // ==========================================================================
244
245    code.push_str(
246        "// =============================================================================\n",
247    );
248    code.push_str("// Compile-Time Relationship Safety (RelatedTo impls)\n");
249    code.push_str(
250        "// =============================================================================\n\n",
251    );
252
253    let table_names: HashSet<&str> = tables.iter().map(|table| table.name.as_str()).collect();
254    let mut relation_impl_counts: HashMap<(&str, &str), usize> = HashMap::new();
255    for table in &tables {
256        for fk in &table.foreign_keys {
257            if !table_names.contains(fk.ref_table.as_str()) {
258                continue;
259            }
260            *relation_impl_counts
261                .entry((table.name.as_str(), fk.ref_table.as_str()))
262                .or_default() += 1;
263            *relation_impl_counts
264                .entry((fk.ref_table.as_str(), table.name.as_str()))
265                .or_default() += 1;
266        }
267    }
268
269    for table in &tables {
270        for fk in &table.foreign_keys {
271            if !table_names.contains(fk.ref_table.as_str()) {
272                continue;
273            }
274            // table.column refs ref_table.ref_column
275            // This means: table is related TO ref_table (forward)
276            // AND: ref_table is related FROM table (reverse - parent has many children)
277
278            let from_mod = to_rust_ident(&table.name);
279            let from_struct = to_struct_name(&table.name);
280            let to_mod = to_rust_ident(&fk.ref_table);
281            let to_struct = to_struct_name(&fk.ref_table);
282
283            // Forward: From table (child) -> Referenced table (parent)
284            // Example: posts -> users (posts.user_id -> users.id)
285            if relation_impl_counts
286                .get(&(table.name.as_str(), fk.ref_table.as_str()))
287                .copied()
288                .unwrap_or_default()
289                == 1
290            {
291                code.push_str(&format!(
292                    "/// {} has a foreign key to {} via {}.{}\n",
293                    table.name, fk.ref_table, table.name, fk.column
294                ));
295                code.push_str(&format!(
296                    "impl RelatedTo<{}::{}> for {}::{} {{\n",
297                    to_mod, to_struct, from_mod, from_struct
298                ));
299                code.push_str(&format!(
300                    "    fn join_columns() -> (&'static str, &'static str) {{ ({}, {}) }}\n",
301                    rust_string_literal(&fk.column),
302                    rust_string_literal(&fk.ref_column)
303                ));
304                code.push_str("}\n\n");
305            }
306
307            // Reverse: Referenced table (parent) -> From table (child)
308            // Example: users -> posts (users.id -> posts.user_id)
309            // This allows: Qail::get(users::table).join_related(posts::table)
310            if relation_impl_counts
311                .get(&(fk.ref_table.as_str(), table.name.as_str()))
312                .copied()
313                .unwrap_or_default()
314                == 1
315            {
316                code.push_str(&format!(
317                    "/// {} is referenced by {} via {}.{}\n",
318                    fk.ref_table, table.name, table.name, fk.column
319                ));
320                code.push_str(&format!(
321                    "impl RelatedTo<{}::{}> for {}::{} {{\n",
322                    from_mod, from_struct, to_mod, to_struct
323                ));
324                code.push_str(&format!(
325                    "    fn join_columns() -> (&'static str, &'static str) {{ ({}, {}) }}\n",
326                    rust_string_literal(&fk.ref_column),
327                    rust_string_literal(&fk.column)
328                ));
329                code.push_str("}\n\n");
330            }
331        }
332    }
333
334    code
335}
336
#[cfg(test)]
mod codegen_tests {
    use super::*;

    /// Parse a `.qail` schema and return the Rust code generated for it.
    fn generate(schema_content: &str) -> String {
        let schema = Schema::parse(schema_content).unwrap();
        generate_schema_code(&schema)
    }

    #[test]
    fn test_generate_schema_code() {
        let schema_content = r#"
table users {
    id UUID primary_key
    email TEXT not_null
    age INT
}

table posts {
    id UUID primary_key
    user_id UUID ref:users.id
    title TEXT
}
"#;
        let code = generate(schema_content);

        for expected in [
            // Module structure
            "pub mod users {",
            "pub mod posts {",
            // Table marker structs
            "pub struct Users;",
            "pub struct Posts;",
            // Column constants
            "pub const id: TypedColumn<uuid::Uuid, Public>",
            "pub const email: TypedColumn<String, Public>",
            "pub const age: TypedColumn<i32, Public>",
            // RelatedTo impls in both directions for posts.user_id -> users.id
            "impl RelatedTo<users::Users> for posts::Posts",
            "impl RelatedTo<posts::Posts> for users::Users",
        ] {
            assert!(code.contains(expected));
        }
    }

    #[test]
    fn test_generate_protected_column() {
        let schema_content = r#"
table secrets {
    id UUID primary_key
    token TEXT protected
}
"#;
        let code = generate(schema_content);

        assert!(code.contains("pub const token: TypedColumn<String, Protected>"));
    }

    #[test]
    fn test_generate_float_column_uses_double_precision_rust_type() {
        let schema_content = r#"
table metrics {
    score FLOAT
}
"#;
        let code = generate(schema_content);

        assert!(code.contains("pub const score: TypedColumn<f64, Public>"));
    }

    #[test]
    fn test_generate_timestamp_columns_preserve_timezone_semantics() {
        let schema_content = r#"
table events {
    happened_at TIMESTAMP
    recorded_at TIMESTAMPTZ
}
"#;
        let code = generate(schema_content);

        for expected in [
            "pub const happened_at: TypedColumn<chrono::NaiveDateTime, Public>",
            "pub const recorded_at: TypedColumn<chrono::DateTime<chrono::Utc>, Public>",
        ] {
            assert!(code.contains(expected));
        }
    }

    #[test]
    fn test_generate_schema_code_skips_ambiguous_related_to_impls() {
        let schema_content = r#"
table users {
    id UUID primary_key
}

table invoices {
    id UUID primary_key
    buyer_id UUID ref:users.id
    seller_id UUID ref:users.id
}
"#;
        let code = generate(schema_content);

        // Both FK columns are still generated...
        for expected in [
            "pub const buyer_id: TypedColumn<uuid::Uuid, Public>",
            "pub const seller_id: TypedColumn<uuid::Uuid, Public>",
        ] {
            assert!(code.contains(expected));
        }
        // ...but two FKs between the same tables make the join ambiguous.
        for unexpected in [
            "impl RelatedTo<users::Users> for invoices::Invoices",
            "impl RelatedTo<invoices::Invoices> for users::Users",
        ] {
            assert!(!code.contains(unexpected));
        }
    }

    #[test]
    fn test_generate_schema_code_skips_missing_target_related_to_impls() {
        let schema_content = r#"
table posts {
    id UUID primary_key
    user_id UUID ref:users.id
}
"#;
        let code = generate(schema_content);

        assert!(code.contains("pub mod posts {"));
        for unexpected in [
            "impl RelatedTo<users::Users> for posts::Posts",
            "impl RelatedTo<posts::Posts> for users::Users",
        ] {
            assert!(!code.contains(unexpected));
        }
    }

    #[test]
    fn test_generate_schema_code_sanitizes_rust_identifiers() {
        let schema_content = r#"
table type {
    1st TEXT
    match TEXT
}
"#;
        let code = generate(schema_content);

        for expected in [
            "pub mod r#type {",
            "pub struct Type;",
            "pub const _1st: TypedColumn<String, Public>",
            "pub const r#match: TypedColumn<String, Public>",
            // String literals keep the original (unsanitized) names.
            "TypedColumn::new(\"type\", \"1st\")",
        ] {
            assert!(code.contains(expected));
        }
    }
}
483
#[cfg(test)]
mod migration_parser_tests {
    use super::*;

    #[test]
    fn test_agent_contracts_migration_parses_all_columns() {
        let sql = r#"
CREATE TABLE agent_contracts (
    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    agent_id UUID NOT NULL REFERENCES agents(id) ON DELETE CASCADE,
    operator_id UUID NOT NULL REFERENCES operators(id) ON DELETE CASCADE,
    pricing_model VARCHAR(20) NOT NULL CHECK (pricing_model IN ('commission', 'static_markup', 'net_rate')),
    commission_percent DECIMAL(5,2),
    static_markup DECIMAL(10,2),
    is_active BOOLEAN DEFAULT true,
    valid_from DATE,
    valid_until DATE,
    approved_by UUID REFERENCES users(id),
    created_at TIMESTAMPTZ DEFAULT NOW() NOT NULL,
    updated_at TIMESTAMPTZ DEFAULT NOW() NOT NULL,
    UNIQUE(agent_id, operator_id)
);
"#;
        const EXPECTED_COLUMNS: [&str; 12] = [
            "id",
            "agent_id",
            "operator_id",
            "pricing_model",
            "commission_percent",
            "static_markup",
            "is_active",
            "valid_from",
            "valid_until",
            "approved_by",
            "created_at",
            "updated_at",
        ];

        let mut schema = Schema::default();
        schema.parse_sql_migration(sql);

        let table = schema
            .tables
            .get("agent_contracts")
            .expect("agent_contracts table should exist");

        for col in EXPECTED_COLUMNS {
            assert!(
                table.columns.contains_key(col),
                "Missing column: '{}'. Found: {:?}",
                col,
                table.columns.keys().collect::<Vec<_>>()
            );
        }
    }

    /// Regression test: column names that START with SQL keywords must parse correctly.
    /// For example, `created_at` starts with CREATE and `primary_contact` starts with PRIMARY.
    #[test]
    fn test_keyword_prefixed_column_names_are_not_skipped() {
        let sql = r#"
CREATE TABLE edge_cases (
    id UUID PRIMARY KEY,
    created_at TIMESTAMPTZ NOT NULL,
    created_by UUID,
    primary_contact VARCHAR(255),
    check_status VARCHAR(20),
    unique_code VARCHAR(50),
    foreign_ref UUID,
    constraint_name VARCHAR(100),
    PRIMARY KEY (id),
    CHECK (check_status IN ('pending', 'active')),
    UNIQUE (unique_code),
    CONSTRAINT fk_ref FOREIGN KEY (foreign_ref) REFERENCES other(id)
);
"#;
        // Every one of these starts with a SQL keyword but is a real column.
        const KEYWORD_PREFIXED_COLUMNS: [&str; 7] = [
            "created_at",
            "created_by",
            "primary_contact",
            "check_status",
            "unique_code",
            "foreign_ref",
            "constraint_name",
        ];

        let mut schema = Schema::default();
        schema.parse_sql_migration(sql);

        let table = schema
            .tables
            .get("edge_cases")
            .expect("edge_cases table should exist");

        for col in KEYWORD_PREFIXED_COLUMNS {
            assert!(
                table.columns.contains_key(col),
                "Column '{}' should NOT be skipped just because it starts with a SQL keyword. Found: {:?}",
                col,
                table.columns.keys().collect::<Vec<_>>()
            );
        }

        // Table-level constraint lines (PRIMARY KEY, CHECK, UNIQUE, CONSTRAINT) are
        // not columns and must not leak in under their leading keyword.
        assert!(
            !table.columns.contains_key("primary"),
            "Constraint keyword 'PRIMARY' should not be treated as a column"
        );
    }
}
589            !table.columns.contains_key("primary"),
590            "Constraint keyword 'PRIMARY' should not be treated as a column"
591        );
592    }
593}