Skip to main content

qail_core/build/
codegen.rs

1//! Typed schema code generation.
2//!
3//! Generates Rust modules from `schema.qail` for compile-time type safety.
4
5use std::collections::{HashMap, HashSet};
6use std::fs;
7
8use crate::migrate::types::ColumnType;
9
10use super::schema::Schema;
11
12fn qail_type_to_rust(col_type: &ColumnType) -> &'static str {
13    match col_type {
14        ColumnType::Uuid => "uuid::Uuid",
15        ColumnType::Text | ColumnType::Varchar(_) => "String",
16        ColumnType::Int | ColumnType::Serial => "i32",
17        ColumnType::BigInt | ColumnType::BigSerial => "i64",
18        ColumnType::Bool => "bool",
19        ColumnType::Float => "f32",
20        ColumnType::Decimal(_) => "rust_decimal::Decimal",
21        ColumnType::Jsonb => "serde_json::Value",
22        ColumnType::Timestamp | ColumnType::Timestamptz => "chrono::DateTime<chrono::Utc>",
23        ColumnType::Date => "chrono::NaiveDate",
24        ColumnType::Time => "chrono::NaiveTime",
25        ColumnType::Bytea => "Vec<u8>",
26        ColumnType::Array(_) => "Vec<serde_json::Value>",
27        ColumnType::Enum { .. } => "String",
28        ColumnType::Range(_) => "String",
29        ColumnType::Interval => "String",
30        ColumnType::Cidr | ColumnType::Inet => "String",
31        ColumnType::MacAddr => "String",
32    }
33}
34
35/// Convert table/column names to valid Rust identifiers
36fn to_rust_ident(name: &str) -> String {
37    escape_keyword(&sanitize_rust_ident(name))
38}
39
40/// Convert table name to PascalCase struct name
41fn to_struct_name(name: &str) -> String {
42    let mut out = String::new();
43    for part in name
44        .split(|c: char| !c.is_ascii_alphanumeric())
45        .filter(|part| !part.is_empty())
46    {
47        let mut chars = part.chars();
48        if let Some(first) = chars.next() {
49            out.extend(first.to_uppercase());
50            out.push_str(chars.as_str());
51        }
52    }
53
54    if out.is_empty() {
55        out.push_str("QailGenerated");
56    }
57    if out
58        .chars()
59        .next()
60        .is_none_or(|c| !c.is_ascii_alphabetic() && c != '_')
61    {
62        out.insert_str(0, "Qail");
63    }
64    if is_rust_keyword(&out) {
65        out.insert_str(0, "Qail");
66    }
67    out
68}
69
/// Maps an arbitrary name onto the character set of a Rust identifier.
///
/// - Every character other than an ASCII letter, ASCII digit or `_` is
///   replaced by `_` (one `_` per character).
/// - If the result would start with a digit, a `_` is prepended.
/// - If the result would be empty or a lone `_`, `__` is returned instead:
///   a single underscore is the reserved wildcard, not an identifier, so it
///   cannot name a module or an addressable constant.
///
/// Keyword handling is not done here; the returned text may still be a keyword.
fn sanitize_rust_ident(name: &str) -> String {
    let mut ident: String = name
        .chars()
        .map(|c| {
            if c.is_ascii_alphanumeric() || c == '_' {
                c
            } else {
                '_'
            }
        })
        .collect();

    // After the mapping above the only possible illegal leading character is a digit.
    if ident.starts_with(|c: char| c.is_ascii_digit()) {
        ident.insert(0, '_');
    }

    match ident.as_str() {
        "" | "_" => String::from("__"),
        _ => ident,
    }
}
95
96fn escape_keyword(name: &str) -> String {
97    if is_rust_keyword(name) {
98        format!("r#{}", name)
99    } else {
100        name.to_string()
101    }
102}
103
/// Reports whether `name` is a strict or reserved Rust keyword.
///
/// The comparison is case-sensitive (`Self` and `self` are both keywords,
/// `Type` is not). `gen` is included because it is reserved starting with the
/// 2024 edition; treating it as a keyword in older editions is harmless since
/// `r#gen` and `gen` name the same identifier there.
fn is_rust_keyword(name: &str) -> bool {
    matches!(
        name,
        // Strict keywords.
        "as" | "break" | "const" | "continue" | "crate" | "else" | "enum" | "extern"
            | "false" | "fn" | "for" | "if" | "impl" | "in" | "let" | "loop" | "match"
            | "mod" | "move" | "mut" | "pub" | "ref" | "return" | "self" | "Self"
            | "static" | "struct" | "super" | "trait" | "true" | "type" | "unsafe"
            | "use" | "where" | "while" | "async" | "await" | "dyn"
            // Reserved for future use.
            | "abstract" | "become" | "box" | "do" | "final" | "gen" | "macro"
            | "override" | "priv" | "try" | "typeof" | "unsized" | "virtual" | "yield"
    )
}
115
/// Renders `value` as a double-quoted, escaped string literal by using the
/// `Debug` formatting of `str`.
fn rust_string_literal(value: &str) -> String {
    format!("{:?}", value)
}
119
120/// Generate typed Rust module from schema.
121///
122/// # Usage in consumer's build.rs:
123/// ```ignore
124/// fn main() {
125///     let out_dir = std::env::var("OUT_DIR").unwrap();
126///     qail_core::build::generate_typed_schema("schema.qail", &format!("{}/schema.rs", out_dir)).unwrap();
127///     println!("cargo:rerun-if-changed=schema.qail");
128/// }
129/// ```
130///
131/// Then in the consumer's lib.rs:
132/// ```ignore
133/// include!(concat!(env!("OUT_DIR"), "/schema.rs"));
134/// ```
135pub fn generate_typed_schema(schema_path: &str, output_path: &str) -> Result<(), String> {
136    let schema = Schema::parse_file(schema_path)?;
137    let code = generate_schema_code(&schema);
138
139    fs::write(output_path, code)
140        .map_err(|e| format!("Failed to write schema module to '{}': {}", output_path, e))?;
141
142    Ok(())
143}
144
145/// Generate typed Rust code from schema (does not write to file)
146pub fn generate_schema_code(schema: &Schema) -> String {
147    let mut code = String::new();
148
149    // Header
150    code.push_str("//! Auto-generated typed schema from schema.qail\n");
151    code.push_str("//! Do not edit manually - regenerate with `cargo build`\n\n");
152    code.push_str("#![allow(dead_code, non_upper_case_globals)]\n\n");
153    code.push_str("use qail_core::typed::{Table, TypedColumn, RelatedTo, Public, Protected};\n\n");
154
155    // Sort tables for deterministic output
156    let mut tables: Vec<_> = schema.tables.values().collect();
157    tables.sort_by(|a, b| a.name.cmp(&b.name));
158
159    for table in &tables {
160        let mod_name = to_rust_ident(&table.name);
161        let struct_name = to_struct_name(&table.name);
162
163        code.push_str(&format!("/// Typed schema for `{}` table\n", table.name));
164        code.push_str(&format!("pub mod {} {{\n", mod_name));
165        code.push_str("    use super::*;\n\n");
166
167        // Table struct implementing Table trait
168        code.push_str(&format!("    /// Table marker for `{}`\n", table.name));
169        code.push_str("    #[derive(Debug, Clone, Copy)]\n");
170        code.push_str(&format!("    pub struct {};\n\n", struct_name));
171
172        code.push_str(&format!("    impl Table for {} {{\n", struct_name));
173        code.push_str(&format!(
174            "        fn table_name() -> &'static str {{ {} }}\n",
175            rust_string_literal(&table.name)
176        ));
177        code.push_str("    }\n\n");
178
179        code.push_str(&format!("    impl From<{}> for String {{\n", struct_name));
180        code.push_str(&format!(
181            "        fn from(_: {}) -> String {{ {}.to_string() }}\n",
182            struct_name,
183            rust_string_literal(&table.name)
184        ));
185        code.push_str("    }\n\n");
186
187        code.push_str(&format!("    impl AsRef<str> for {} {{\n", struct_name));
188        code.push_str(&format!(
189            "        fn as_ref(&self) -> &str {{ {} }}\n",
190            rust_string_literal(&table.name)
191        ));
192        code.push_str("    }\n\n");
193
194        // Table constant for convenience
195        code.push_str(&format!("    /// The `{}` table\n", table.name));
196        code.push_str(&format!(
197            "    pub const table: {} = {};\n\n",
198            struct_name, struct_name
199        ));
200
201        // Sort columns for deterministic output
202        let mut columns: Vec<_> = table.columns.iter().collect();
203        columns.sort_by(|a, b| a.0.cmp(b.0));
204
205        // Column constants
206        for (col_name, col_type) in columns {
207            let rust_type = qail_type_to_rust(col_type);
208            let col_ident = to_rust_ident(col_name);
209            let policy = table
210                .policies
211                .get(col_name)
212                .map(|s| s.as_str())
213                .unwrap_or("Public");
214            let rust_policy = if policy == "Protected" {
215                "Protected"
216            } else {
217                "Public"
218            };
219
220            code.push_str(&format!(
221                "    /// Column `{}.{}` ({}) - {}\n",
222                table.name,
223                col_name,
224                col_type.to_pg_type(),
225                policy
226            ));
227            code.push_str(&format!(
228                "    pub const {}: TypedColumn<{}, {}> = TypedColumn::new({}, {});\n",
229                col_ident,
230                rust_type,
231                rust_policy,
232                rust_string_literal(&table.name),
233                rust_string_literal(col_name)
234            ));
235        }
236
237        code.push_str("}\n\n");
238    }
239
240    // ==========================================================================
241    // Generate RelatedTo impls for compile-time relationship checking
242    // ==========================================================================
243
244    code.push_str(
245        "// =============================================================================\n",
246    );
247    code.push_str("// Compile-Time Relationship Safety (RelatedTo impls)\n");
248    code.push_str(
249        "// =============================================================================\n\n",
250    );
251
252    let table_names: HashSet<&str> = tables.iter().map(|table| table.name.as_str()).collect();
253    let mut relation_impl_counts: HashMap<(&str, &str), usize> = HashMap::new();
254    for table in &tables {
255        for fk in &table.foreign_keys {
256            if !table_names.contains(fk.ref_table.as_str()) {
257                continue;
258            }
259            *relation_impl_counts
260                .entry((table.name.as_str(), fk.ref_table.as_str()))
261                .or_default() += 1;
262            *relation_impl_counts
263                .entry((fk.ref_table.as_str(), table.name.as_str()))
264                .or_default() += 1;
265        }
266    }
267
268    for table in &tables {
269        for fk in &table.foreign_keys {
270            if !table_names.contains(fk.ref_table.as_str()) {
271                continue;
272            }
273            // table.column refs ref_table.ref_column
274            // This means: table is related TO ref_table (forward)
275            // AND: ref_table is related FROM table (reverse - parent has many children)
276
277            let from_mod = to_rust_ident(&table.name);
278            let from_struct = to_struct_name(&table.name);
279            let to_mod = to_rust_ident(&fk.ref_table);
280            let to_struct = to_struct_name(&fk.ref_table);
281
282            // Forward: From table (child) -> Referenced table (parent)
283            // Example: posts -> users (posts.user_id -> users.id)
284            if relation_impl_counts
285                .get(&(table.name.as_str(), fk.ref_table.as_str()))
286                .copied()
287                .unwrap_or_default()
288                == 1
289            {
290                code.push_str(&format!(
291                    "/// {} has a foreign key to {} via {}.{}\n",
292                    table.name, fk.ref_table, table.name, fk.column
293                ));
294                code.push_str(&format!(
295                    "impl RelatedTo<{}::{}> for {}::{} {{\n",
296                    to_mod, to_struct, from_mod, from_struct
297                ));
298                code.push_str(&format!(
299                    "    fn join_columns() -> (&'static str, &'static str) {{ ({}, {}) }}\n",
300                    rust_string_literal(&fk.column),
301                    rust_string_literal(&fk.ref_column)
302                ));
303                code.push_str("}\n\n");
304            }
305
306            // Reverse: Referenced table (parent) -> From table (child)
307            // Example: users -> posts (users.id -> posts.user_id)
308            // This allows: Qail::get(users::table).join_related(posts::table)
309            if relation_impl_counts
310                .get(&(fk.ref_table.as_str(), table.name.as_str()))
311                .copied()
312                .unwrap_or_default()
313                == 1
314            {
315                code.push_str(&format!(
316                    "/// {} is referenced by {} via {}.{}\n",
317                    fk.ref_table, table.name, table.name, fk.column
318                ));
319                code.push_str(&format!(
320                    "impl RelatedTo<{}::{}> for {}::{} {{\n",
321                    from_mod, from_struct, to_mod, to_struct
322                ));
323                code.push_str(&format!(
324                    "    fn join_columns() -> (&'static str, &'static str) {{ ({}, {}) }}\n",
325                    rust_string_literal(&fk.ref_column),
326                    rust_string_literal(&fk.column)
327                ));
328                code.push_str("}\n\n");
329            }
330        }
331    }
332
333    code
334}
335
#[cfg(test)]
mod codegen_tests {
    use super::*;

    /// Parses `source` as a `.qail` schema and returns the generated module text.
    fn render(source: &str) -> String {
        let schema = Schema::parse(source).unwrap();
        generate_schema_code(&schema)
    }

    /// Asserts that every fragment occurs somewhere in `code`.
    fn assert_emits(code: &str, fragments: &[&str]) {
        for fragment in fragments {
            assert!(
                code.contains(*fragment),
                "generated code is missing `{fragment}`"
            );
        }
    }

    /// Asserts that none of the fragments occurs in `code`.
    fn assert_omits(code: &str, fragments: &[&str]) {
        for fragment in fragments {
            assert!(
                !code.contains(*fragment),
                "generated code unexpectedly contains `{fragment}`"
            );
        }
    }

    #[test]
    fn test_generate_schema_code() {
        let code = render(
            r#"
table users {
    id UUID primary_key
    email TEXT not_null
    age INT
}

table posts {
    id UUID primary_key
    user_id UUID ref:users.id
    title TEXT
}
"#,
        );

        assert_emits(
            &code,
            &[
                // Module structure
                "pub mod users {",
                "pub mod posts {",
                // Table marker structs
                "pub struct Users;",
                "pub struct Posts;",
                // Column constants
                "pub const id: TypedColumn<uuid::Uuid, Public>",
                "pub const email: TypedColumn<String, Public>",
                "pub const age: TypedColumn<i32, Public>",
                // Relationship impls in both directions
                "impl RelatedTo<users::Users> for posts::Posts",
                "impl RelatedTo<posts::Posts> for users::Users",
            ],
        );
    }

    #[test]
    fn test_generate_protected_column() {
        let code = render(
            r#"
table secrets {
    id UUID primary_key
    token TEXT protected
}
"#,
        );

        // A `protected` column must carry the Protected policy marker.
        assert_emits(&code, &["pub const token: TypedColumn<String, Protected>"]);
    }

    #[test]
    fn test_generate_schema_code_skips_ambiguous_related_to_impls() {
        let code = render(
            r#"
table users {
    id UUID primary_key
}

table invoices {
    id UUID primary_key
    buyer_id UUID ref:users.id
    seller_id UUID ref:users.id
}
"#,
        );

        assert_emits(
            &code,
            &[
                "pub const buyer_id: TypedColumn<uuid::Uuid, Public>",
                "pub const seller_id: TypedColumn<uuid::Uuid, Public>",
            ],
        );
        // Two foreign keys to the same table: neither direction gets an impl.
        assert_omits(
            &code,
            &[
                "impl RelatedTo<users::Users> for invoices::Invoices",
                "impl RelatedTo<invoices::Invoices> for users::Users",
            ],
        );
    }

    #[test]
    fn test_generate_schema_code_skips_missing_target_related_to_impls() {
        let code = render(
            r#"
table posts {
    id UUID primary_key
    user_id UUID ref:users.id
}
"#,
        );

        assert_emits(&code, &["pub mod posts {"]);
        // `users` is not declared in this schema, so no impl may mention it.
        assert_omits(
            &code,
            &[
                "impl RelatedTo<users::Users> for posts::Posts",
                "impl RelatedTo<posts::Posts> for users::Users",
            ],
        );
    }

    #[test]
    fn test_generate_schema_code_sanitizes_rust_identifiers() {
        let code = render(
            r#"
table type {
    1st TEXT
    match TEXT
}
"#,
        );

        assert_emits(
            &code,
            &[
                "pub mod r#type {",
                "pub struct Type;",
                "pub const _1st: TypedColumn<String, Public>",
                "pub const r#match: TypedColumn<String, Public>",
                // The original names are still what reaches the runtime API.
                "TypedColumn::new(\"type\", \"1st\")",
            ],
        );
    }
}
450
#[cfg(test)]
mod migration_parser_tests {
    use super::*;

    /// Feeds `sql` through the SQL-migration parser of a fresh, empty schema.
    fn parse_migration(sql: &str) -> Schema {
        let mut schema = Schema::default();
        schema.parse_sql_migration(sql);
        schema
    }

    #[test]
    fn test_agent_contracts_migration_parses_all_columns() {
        const EXPECTED_COLUMNS: [&str; 12] = [
            "id",
            "agent_id",
            "operator_id",
            "pricing_model",
            "commission_percent",
            "static_markup",
            "is_active",
            "valid_from",
            "valid_until",
            "approved_by",
            "created_at",
            "updated_at",
        ];

        let schema = parse_migration(
            r#"
CREATE TABLE agent_contracts (
    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    agent_id UUID NOT NULL REFERENCES agents(id) ON DELETE CASCADE,
    operator_id UUID NOT NULL REFERENCES operators(id) ON DELETE CASCADE,
    pricing_model VARCHAR(20) NOT NULL CHECK (pricing_model IN ('commission', 'static_markup', 'net_rate')),
    commission_percent DECIMAL(5,2),
    static_markup DECIMAL(10,2),
    is_active BOOLEAN DEFAULT true,
    valid_from DATE,
    valid_until DATE,
    approved_by UUID REFERENCES users(id),
    created_at TIMESTAMPTZ DEFAULT NOW() NOT NULL,
    updated_at TIMESTAMPTZ DEFAULT NOW() NOT NULL,
    UNIQUE(agent_id, operator_id)
);
"#,
        );

        let table = schema
            .tables
            .get("agent_contracts")
            .expect("agent_contracts table should exist");

        for column in EXPECTED_COLUMNS {
            assert!(
                table.columns.contains_key(column),
                "Missing column: '{}'. Found: {:?}",
                column,
                table.columns.keys().collect::<Vec<_>>()
            );
        }
    }

    /// Regression test: a column whose name merely begins with an SQL keyword
    /// (`created_at` / CREATE, `primary_contact` / PRIMARY, ...) must still be
    /// parsed as a column.
    #[test]
    fn test_keyword_prefixed_column_names_are_not_skipped() {
        const KEYWORD_PREFIXED: [&str; 7] = [
            "created_at",
            "created_by",
            "primary_contact",
            "check_status",
            "unique_code",
            "foreign_ref",
            "constraint_name",
        ];

        let schema = parse_migration(
            r#"
CREATE TABLE edge_cases (
    id UUID PRIMARY KEY,
    created_at TIMESTAMPTZ NOT NULL,
    created_by UUID,
    primary_contact VARCHAR(255),
    check_status VARCHAR(20),
    unique_code VARCHAR(50),
    foreign_ref UUID,
    constraint_name VARCHAR(100),
    PRIMARY KEY (id),
    CHECK (check_status IN ('pending', 'active')),
    UNIQUE (unique_code),
    CONSTRAINT fk_ref FOREIGN KEY (foreign_ref) REFERENCES other(id)
);
"#,
        );

        let table = schema
            .tables
            .get("edge_cases")
            .expect("edge_cases table should exist");

        // Every keyword-prefixed column has to be present.
        for column in KEYWORD_PREFIXED {
            assert!(
                table.columns.contains_key(column),
                "Column '{}' should NOT be skipped just because it starts with a SQL keyword. Found: {:?}",
                column,
                table.columns.keys().collect::<Vec<_>>()
            );
        }

        // Table-level constraint lines (PRIMARY KEY, CHECK, UNIQUE, CONSTRAINT)
        // are not columns and must not show up as such.
        assert!(
            !table.columns.contains_key("primary"),
            "Constraint keyword 'PRIMARY' should not be treated as a column"
        );
    }
}