use std::collections::{HashMap, HashSet};
use std::fs;
use crate::migrate::types::ColumnType;
use super::schema::Schema;
/// Maps a QAIL column type to the Rust type path emitted as the value type of its
/// `TypedColumn` constant.
///
/// Enums, ranges, intervals and network/MAC address types have no dedicated Rust mapping
/// here and are surfaced as `String`; arrays are surfaced as `Vec<serde_json::Value>`
/// regardless of their element type.
fn qail_type_to_rust(col_type: &ColumnType) -> &'static str {
    match col_type {
        ColumnType::Uuid => "uuid::Uuid",
        ColumnType::Int | ColumnType::Serial => "i32",
        ColumnType::BigInt | ColumnType::BigSerial => "i64",
        // NOTE(review): confirm `Float` renders as PostgreSQL `REAL`; if it is
        // `DOUBLE PRECISION`, `f64` would be the matching Rust type.
        ColumnType::Float => "f32",
        ColumnType::Bool => "bool",
        ColumnType::Decimal(_) => "rust_decimal::Decimal",
        ColumnType::Jsonb => "serde_json::Value",
        ColumnType::Timestamp | ColumnType::Timestamptz => "chrono::DateTime<chrono::Utc>",
        ColumnType::Date => "chrono::NaiveDate",
        ColumnType::Time => "chrono::NaiveTime",
        ColumnType::Bytea => "Vec<u8>",
        ColumnType::Array(_) => "Vec<serde_json::Value>",
        // Every variant that is exposed to Rust as plain text.
        ColumnType::Text
        | ColumnType::Varchar(_)
        | ColumnType::Enum { .. }
        | ColumnType::Range(_)
        | ColumnType::Interval
        | ColumnType::Cidr
        | ColumnType::Inet
        | ColumnType::MacAddr => "String",
    }
}
/// Converts an arbitrary schema name (table or column) into an identifier usable in the
/// generated Rust code: first sanitized into a lexically valid identifier, then escaped
/// if it collides with a Rust keyword.
fn to_rust_ident(name: &str) -> String {
    let sanitized = sanitize_rust_ident(name);
    escape_keyword(&sanitized)
}
/// Converts a table name into the PascalCase name of its generated marker struct.
///
/// The name is split on every character that is not ASCII alphanumeric, and each non-empty
/// segment has its first character upper-cased (`user_posts` -> `UserPosts`). The following
/// results are prefixed with `Qail` so the struct is valid and usable:
/// - an empty result becomes `QailGenerated`;
/// - a result starting with a digit is prefixed (`1st` -> `Qail1st`);
/// - keywords (`Self`) and names that would shadow items the generated module refers to
///   unqualified (`Table`, `TypedColumn`, `RelatedTo`, `Public`, `Protected`, `String`,
///   `Vec`) are prefixed (`table` -> `QailTable`).
fn to_struct_name(name: &str) -> String {
    // Names used unqualified inside every generated `pub mod`. The first five come from the
    // `qail_core::typed` import via `use super::*`; `String` and `Vec` come from the prelude.
    // A locally declared struct shadows them, so e.g. `impl Table for Table` or
    // `impl From<String> for String` would no longer compile.
    const SHADOWED_NAMES: &[&str] = &[
        "Table",
        "TypedColumn",
        "RelatedTo",
        "Public",
        "Protected",
        "String",
        "Vec",
    ];

    let mut out = String::new();
    for part in name
        .split(|c: char| !c.is_ascii_alphanumeric())
        .filter(|part| !part.is_empty())
    {
        let mut chars = part.chars();
        if let Some(first) = chars.next() {
            out.extend(first.to_uppercase());
            out.push_str(chars.as_str());
        }
    }
    if out.is_empty() {
        out.push_str("QailGenerated");
    }
    if !out.starts_with(|c: char| c.is_ascii_alphabetic() || c == '_') {
        out.insert_str(0, "Qail");
    }
    if is_rust_keyword(&out) || SHADOWED_NAMES.contains(&out.as_str()) {
        out.insert_str(0, "Qail");
    }
    out
}
/// Converts `name` into a string that is lexically a valid Rust identifier. Keywords are
/// not handled here; `to_rust_ident` escapes them afterwards.
///
/// - Every character other than an ASCII alphanumeric or `_` (non-ASCII included) is
///   replaced with `_` (`first-name` -> `first_name`).
/// - A leading digit is prefixed with `_` (`1st` -> `_1st`).
/// - A result that would be empty or a lone `_` becomes `__`. On its own, `_` is the
///   wildcard rather than an identifier: `pub mod _` does not compile and `pub const _`
///   would be unnamed and unreachable.
fn sanitize_rust_ident(name: &str) -> String {
    let mut ident: String = name
        .chars()
        .map(|c| if c.is_ascii_alphanumeric() || c == '_' { c } else { '_' })
        .collect();
    // After the mapping every character is ASCII alphanumeric or `_`, so a digit is the
    // only invalid leading character that can remain.
    if ident.starts_with(|c: char| c.is_ascii_digit()) {
        ident.insert(0, '_');
    }
    if ident.is_empty() || ident == "_" {
        ident = String::from("__");
    }
    ident
}
/// Makes `name` usable as a Rust identifier when it collides with a keyword.
///
/// Ordinary keywords are emitted as raw identifiers (`match` -> `r#match`). The path
/// keywords `crate`, `self`, `Self` and `super` cannot be written as raw identifiers, so
/// they get a trailing underscore instead (`self` -> `self_`). Non-keywords are returned
/// unchanged.
fn escape_keyword(name: &str) -> String {
    match name {
        // `r#crate`, `r#self`, `r#Self` and `r#super` are rejected by the compiler.
        "crate" | "self" | "Self" | "super" => format!("{name}_"),
        _ if is_rust_keyword(name) => format!("r#{name}"),
        _ => name.to_string(),
    }
}

/// Returns `true` if `name` is a strict or reserved Rust keyword, as of edition 2024.
///
/// Weak or contextual keywords (`union`, `macro_rules`, `raw`, ...) are valid identifiers
/// and are not listed.
fn is_rust_keyword(name: &str) -> bool {
    const KEYWORDS: &[&str] = &[
        "as", "break", "const", "continue", "crate", "else", "enum", "extern", "false", "fn",
        "for", "if", "impl", "in", "let", "loop", "match", "mod", "move", "mut", "pub", "ref",
        "return", "self", "Self", "static", "struct", "super", "trait", "true", "type", "unsafe",
        "use", "where", "while", "async", "await", "dyn", "abstract", "become", "box", "do",
        "final", "macro", "override", "priv", "try", "typeof", "unsized", "virtual", "yield",
        // Reserved in edition 2024; `r#gen` is accepted in every edition.
        "gen",
    ];
    KEYWORDS.contains(&name)
}
/// Renders `value` as a double-quoted Rust string literal for embedding in generated
/// code. This relies on `str`'s `Debug` formatting, which escapes quotes, backslashes and
/// control characters.
fn rust_string_literal(value: &str) -> String {
    format!("{:?}", value)
}
/// Parses the schema file at `schema_path` and writes the generated typed-schema module
/// (the output of `generate_schema_code`) to `output_path`, creating or overwriting it.
///
/// # Errors
///
/// Propagates, via `?`, the error from `Schema::parse_file`. If writing the module fails,
/// returns a message naming `output_path` and the underlying I/O error.
pub fn generate_typed_schema(schema_path: &str, output_path: &str) -> Result<(), String> {
    let schema = Schema::parse_file(schema_path)?;
    let rendered = generate_schema_code(&schema);
    fs::write(output_path, rendered).map_err(|io_err| {
        format!("Failed to write schema module to '{}': {}", output_path, io_err)
    })
}
/// Renders the typed-schema Rust module for every table in `schema`.
///
/// Output layout:
/// - a header with `#![allow(...)]` and the `qail_core::typed` imports;
/// - one `pub mod` per table, sorted by table name. Each contains a unit marker struct
///   implementing `Table`, `From<_> for String` and `AsRef<str>`, a `table` const, and one
///   `TypedColumn` const per column, sorted by column name;
/// - `RelatedTo` impls for foreign keys whose target table exists in `schema`.
///
/// A `RelatedTo` impl is only emitted when exactly one foreign key links an ordered table
/// pair. Multiple links would produce conflicting impls, so they are skipped. This covers
/// several FKs between the same tables, FKs in both directions, and self-references.
///
/// Column constants are unique within their module. An identifier that clashes with the
/// `table` const, the marker struct (a unit struct also defines a value), or an earlier
/// column is suffixed with `_` until it is free.
pub fn generate_schema_code(schema: &Schema) -> String {
    let mut code = String::new();
    code.push_str("//! Auto-generated typed schema from schema.qail\n");
    code.push_str("//! Do not edit manually - regenerate with `cargo build`\n\n");
    code.push_str("#![allow(dead_code, non_upper_case_globals)]\n\n");
    code.push_str("use qail_core::typed::{Table, TypedColumn, RelatedTo, Public, Protected};\n\n");

    // Sort tables (and columns below) so the output does not depend on map iteration order.
    let mut tables: Vec<_> = schema.tables.values().collect();
    tables.sort_by(|a, b| a.name.cmp(&b.name));

    for table in &tables {
        let mod_name = to_rust_ident(&table.name);
        let struct_name = to_struct_name(&table.name);
        code.push_str(&format!("/// Typed schema for `{}` table\n", table.name));
        code.push_str(&format!("pub mod {} {{\n", mod_name));
        code.push_str(" use super::*;\n\n");
        code.push_str(&format!(" /// Table marker for `{}`\n", table.name));
        code.push_str(" #[derive(Debug, Clone, Copy)]\n");
        code.push_str(&format!(" pub struct {};\n\n", struct_name));
        code.push_str(&format!(" impl Table for {} {{\n", struct_name));
        code.push_str(&format!(
            " fn table_name() -> &'static str {{ {} }}\n",
            rust_string_literal(&table.name)
        ));
        code.push_str(" }\n\n");
        code.push_str(&format!(" impl From<{}> for String {{\n", struct_name));
        code.push_str(&format!(
            " fn from(_: {}) -> String {{ {}.to_string() }}\n",
            struct_name,
            rust_string_literal(&table.name)
        ));
        code.push_str(" }\n\n");
        code.push_str(&format!(" impl AsRef<str> for {} {{\n", struct_name));
        code.push_str(&format!(
            " fn as_ref(&self) -> &str {{ {} }}\n",
            rust_string_literal(&table.name)
        ));
        code.push_str(" }\n\n");
        code.push_str(&format!(" /// The `{}` table\n", table.name));
        code.push_str(&format!(
            " pub const table: {} = {};\n\n",
            struct_name, struct_name
        ));

        // Identifiers already bound in this module's value namespace: the `table` const
        // emitted above and the unit marker struct. Without this set, a column named
        // `table`, a column whose identifier equals the struct name, or two columns that
        // sanitize to the same identifier (`first-name` / `first_name`) would emit
        // duplicate definitions that fail to compile.
        let mut used_idents: HashSet<String> = HashSet::new();
        used_idents.insert("table".to_string());
        used_idents.insert(struct_name.clone());

        let mut columns: Vec<_> = table.columns.iter().collect();
        columns.sort_by(|a, b| a.0.cmp(b.0));
        for (col_name, col_type) in columns {
            let rust_type = qail_type_to_rust(col_type);
            let mut col_ident = to_rust_ident(col_name);
            while used_idents.contains(&col_ident) {
                col_ident.push('_');
            }
            used_idents.insert(col_ident.clone());
            let policy = table
                .policies
                .get(col_name)
                .map(|s| s.as_str())
                .unwrap_or("Public");
            // Any policy other than "Protected" is emitted with the `Public` marker type.
            let rust_policy = if policy == "Protected" {
                "Protected"
            } else {
                "Public"
            };
            code.push_str(&format!(
                " /// Column `{}.{}` ({}) - {}\n",
                table.name,
                col_name,
                col_type.to_pg_type(),
                policy
            ));
            code.push_str(&format!(
                " pub const {}: TypedColumn<{}, {}> = TypedColumn::new({}, {});\n",
                col_ident,
                rust_type,
                rust_policy,
                rust_string_literal(&table.name),
                rust_string_literal(col_name)
            ));
        }
        code.push_str("}\n\n");
    }

    code.push_str(
        "// =============================================================================\n",
    );
    code.push_str("// Compile-Time Relationship Safety (RelatedTo impls)\n");
    code.push_str(
        "// =============================================================================\n\n",
    );

    // Count, per ordered (from, to) pair, how many FK-derived impls would be emitted. Both
    // directions of every FK are counted so that any pair linked more than once, including
    // a self-reference that counts (t, t) twice, is skipped instead of producing
    // conflicting `RelatedTo` impls.
    let table_names: HashSet<&str> = tables.iter().map(|table| table.name.as_str()).collect();
    let mut relation_impl_counts: HashMap<(&str, &str), usize> = HashMap::new();
    for table in &tables {
        for fk in &table.foreign_keys {
            if !table_names.contains(fk.ref_table.as_str()) {
                continue;
            }
            *relation_impl_counts
                .entry((table.name.as_str(), fk.ref_table.as_str()))
                .or_default() += 1;
            *relation_impl_counts
                .entry((fk.ref_table.as_str(), table.name.as_str()))
                .or_default() += 1;
        }
    }

    for table in &tables {
        for fk in &table.foreign_keys {
            // FKs to tables missing from the schema have no module to reference.
            if !table_names.contains(fk.ref_table.as_str()) {
                continue;
            }
            let from_mod = to_rust_ident(&table.name);
            let from_struct = to_struct_name(&table.name);
            let to_mod = to_rust_ident(&fk.ref_table);
            let to_struct = to_struct_name(&fk.ref_table);
            if relation_impl_counts
                .get(&(table.name.as_str(), fk.ref_table.as_str()))
                .copied()
                .unwrap_or_default()
                == 1
            {
                code.push_str(&format!(
                    "/// {} has a foreign key to {} via {}.{}\n",
                    table.name, fk.ref_table, table.name, fk.column
                ));
                code.push_str(&format!(
                    "impl RelatedTo<{}::{}> for {}::{} {{\n",
                    to_mod, to_struct, from_mod, from_struct
                ));
                code.push_str(&format!(
                    " fn join_columns() -> (&'static str, &'static str) {{ ({}, {}) }}\n",
                    rust_string_literal(&fk.column),
                    rust_string_literal(&fk.ref_column)
                ));
                code.push_str("}\n\n");
            }
            if relation_impl_counts
                .get(&(fk.ref_table.as_str(), table.name.as_str()))
                .copied()
                .unwrap_or_default()
                == 1
            {
                code.push_str(&format!(
                    "/// {} is referenced by {} via {}.{}\n",
                    fk.ref_table, table.name, table.name, fk.column
                ));
                code.push_str(&format!(
                    "impl RelatedTo<{}::{}> for {}::{} {{\n",
                    from_mod, from_struct, to_mod, to_struct
                ));
                code.push_str(&format!(
                    " fn join_columns() -> (&'static str, &'static str) {{ ({}, {}) }}\n",
                    rust_string_literal(&fk.ref_column),
                    rust_string_literal(&fk.column)
                ));
                code.push_str("}\n\n");
            }
        }
    }
    code
}
/// Tests for `generate_schema_code`: per-table module/struct/column emission, column
/// policies, `RelatedTo` impl selection, and identifier sanitization.
#[cfg(test)]
mod codegen_tests {
use super::*;
/// Two tables with a single FK (posts.user_id -> users.id): checks the per-table modules,
/// marker structs and typed column consts, and that the single FK yields `RelatedTo`
/// impls in both directions.
#[test]
fn test_generate_schema_code() {
let schema_content = r#"
table users {
id UUID primary_key
email TEXT not_null
age INT
}
table posts {
id UUID primary_key
user_id UUID ref:users.id
title TEXT
}
"#;
let schema = Schema::parse(schema_content).unwrap();
let code = generate_schema_code(&schema);
assert!(code.contains("pub mod users {"));
assert!(code.contains("pub mod posts {"));
assert!(code.contains("pub struct Users;"));
assert!(code.contains("pub struct Posts;"));
assert!(code.contains("pub const id: TypedColumn<uuid::Uuid, Public>"));
assert!(code.contains("pub const email: TypedColumn<String, Public>"));
assert!(code.contains("pub const age: TypedColumn<i32, Public>"));
assert!(code.contains("impl RelatedTo<users::Users> for posts::Posts"));
assert!(code.contains("impl RelatedTo<posts::Posts> for users::Users"));
}
/// A column marked `protected` in the schema is emitted with the `Protected` policy type.
#[test]
fn test_generate_protected_column() {
let schema_content = r#"
table secrets {
id UUID primary_key
token TEXT protected
}
"#;
let schema = Schema::parse(schema_content).unwrap();
let code = generate_schema_code(&schema);
assert!(code.contains("pub const token: TypedColumn<String, Protected>"));
}
/// Two FKs from invoices to users make the pair ambiguous, so neither direction gets a
/// `RelatedTo` impl. Both FK columns are still emitted as typed consts.
#[test]
fn test_generate_schema_code_skips_ambiguous_related_to_impls() {
let schema_content = r#"
table users {
id UUID primary_key
}
table invoices {
id UUID primary_key
buyer_id UUID ref:users.id
seller_id UUID ref:users.id
}
"#;
let schema = Schema::parse(schema_content).unwrap();
let code = generate_schema_code(&schema);
assert!(code.contains("pub const buyer_id: TypedColumn<uuid::Uuid, Public>"));
assert!(code.contains("pub const seller_id: TypedColumn<uuid::Uuid, Public>"));
assert!(!code.contains("impl RelatedTo<users::Users> for invoices::Invoices"));
assert!(!code.contains("impl RelatedTo<invoices::Invoices> for users::Users"));
}
/// An FK to a table that is not in the schema (`users`) emits no `RelatedTo` impls. The
/// referencing table's module is still generated.
#[test]
fn test_generate_schema_code_skips_missing_target_related_to_impls() {
let schema_content = r#"
table posts {
id UUID primary_key
user_id UUID ref:users.id
}
"#;
let schema = Schema::parse(schema_content).unwrap();
let code = generate_schema_code(&schema);
assert!(code.contains("pub mod posts {"));
assert!(!code.contains("impl RelatedTo<users::Users> for posts::Posts"));
assert!(!code.contains("impl RelatedTo<posts::Posts> for users::Users"));
}
/// Keyword and digit-leading names: table `type` becomes module `r#type` with struct `Type`,
/// column `1st` becomes `_1st`, and column `match` becomes `r#match`. The emitted string
/// literals keep the original names.
#[test]
fn test_generate_schema_code_sanitizes_rust_identifiers() {
let schema_content = r#"
table type {
1st TEXT
match TEXT
}
"#;
let schema = Schema::parse(schema_content).unwrap();
let code = generate_schema_code(&schema);
assert!(code.contains("pub mod r#type {"));
assert!(code.contains("pub struct Type;"));
assert!(code.contains("pub const _1st: TypedColumn<String, Public>"));
assert!(code.contains("pub const r#match: TypedColumn<String, Public>"));
assert!(code.contains("TypedColumn::new(\"type\", \"1st\")"));
}
}
/// Tests for `Schema::parse_sql_migration` column extraction from `CREATE TABLE` statements.
#[cfg(test)]
mod migration_parser_tests {
use super::*;
/// Every column of a realistic `CREATE TABLE` must be present in the parsed table. The
/// statement includes DEFAULTs, REFERENCES ... ON DELETE, an inline CHECK, parameterized
/// types and a table-level UNIQUE.
#[test]
fn test_agent_contracts_migration_parses_all_columns() {
let sql = r#"
CREATE TABLE agent_contracts (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
agent_id UUID NOT NULL REFERENCES agents(id) ON DELETE CASCADE,
operator_id UUID NOT NULL REFERENCES operators(id) ON DELETE CASCADE,
pricing_model VARCHAR(20) NOT NULL CHECK (pricing_model IN ('commission', 'static_markup', 'net_rate')),
commission_percent DECIMAL(5,2),
static_markup DECIMAL(10,2),
is_active BOOLEAN DEFAULT true,
valid_from DATE,
valid_until DATE,
approved_by UUID REFERENCES users(id),
created_at TIMESTAMPTZ DEFAULT NOW() NOT NULL,
updated_at TIMESTAMPTZ DEFAULT NOW() NOT NULL,
UNIQUE(agent_id, operator_id)
);
"#;
let mut schema = Schema::default();
schema.parse_sql_migration(sql);
let table = schema
.tables
.get("agent_contracts")
.expect("agent_contracts table should exist");
for col in &[
"id",
"agent_id",
"operator_id",
"pricing_model",
"commission_percent",
"static_markup",
"is_active",
"valid_from",
"valid_until",
"approved_by",
"created_at",
"updated_at",
] {
assert!(
table.columns.contains_key(*col),
"Missing column: '{}'. Found: {:?}",
col,
table.columns.keys().collect::<Vec<_>>()
);
}
}
/// Columns whose names begin with SQL keywords (`created_`, `primary_`, `check_`,
/// `unique_`, `foreign_`, `constraint_`) must still be parsed as columns. Table-level
/// constraint lines such as `PRIMARY KEY (id)` must not produce a `primary` column.
#[test]
fn test_keyword_prefixed_column_names_are_not_skipped() {
let sql = r#"
CREATE TABLE edge_cases (
id UUID PRIMARY KEY,
created_at TIMESTAMPTZ NOT NULL,
created_by UUID,
primary_contact VARCHAR(255),
check_status VARCHAR(20),
unique_code VARCHAR(50),
foreign_ref UUID,
constraint_name VARCHAR(100),
PRIMARY KEY (id),
CHECK (check_status IN ('pending', 'active')),
UNIQUE (unique_code),
CONSTRAINT fk_ref FOREIGN KEY (foreign_ref) REFERENCES other(id)
);
"#;
let mut schema = Schema::default();
schema.parse_sql_migration(sql);
let table = schema
.tables
.get("edge_cases")
.expect("edge_cases table should exist");
for col in &[
"created_at",
"created_by",
"primary_contact",
"check_status",
"unique_code",
"foreign_ref",
"constraint_name",
] {
assert!(
table.columns.contains_key(*col),
"Column '{}' should NOT be skipped just because it starts with a SQL keyword. Found: {:?}",
col,
table.columns.keys().collect::<Vec<_>>()
);
}
// The table-level `PRIMARY KEY (id)` line must not be mistaken for a column definition.
assert!(
!table.columns.contains_key("primary"),
"Constraint keyword 'PRIMARY' should not be treated as a column"
);
}
}