Skip to main content

qail_core/parser/
schema.rs

1//! Schema file parser for `.qail` format.
2//!
3//! Parses schema definitions like:
4//! ```text
5//! table users (
6//!   id uuid primary_key,
7//!   email text not null,
8//!   name text,
9//!   created_at timestamp
10//! )
11//!
12//! policy users_isolation on users
13//!     for all
14//!     using (operator_id = current_setting('app.operator_id')::uuid)
15//! ```
16
17use nom::{
18    IResult, Parser,
19    branch::alt,
20    bytes::complete::{tag, tag_no_case, take_while1},
21    character::complete::{char, multispace0 as nom_ws0, multispace1, not_line_ending},
22    combinator::{map, opt},
23    multi::{many0, separated_list0},
24    sequence::preceded,
25};
26use serde::{Deserialize, Serialize};
27
28use crate::ast::{Expr, BinaryOp, Value as AstValue};
29use crate::migrate::policy::{RlsPolicy, PolicyTarget, PolicyPermissiveness};
30use crate::transpiler::policy::{create_policy_sql, alter_table_sql};
31use crate::migrate::alter::AlterTable;
32
/// Schema containing all table definitions
/// together with the row-level security (RLS) policies declared next to them.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Schema {
    /// Schema format version (extracted from `-- qail: version=N` directive)
    #[serde(default)]
    pub version: Option<u32>,
    /// Table definitions; when produced by `Schema::parse` they appear in declaration order.
    pub tables: Vec<TableDef>,
    /// RLS policies declared in the schema
    #[serde(default)]
    pub policies: Vec<RlsPolicy>,
}
44
/// A single table: its name, its columns and whether RLS is switched on for it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TableDef {
    /// Table name as written in the schema.
    pub name: String,
    /// Column definitions belonging to this table.
    pub columns: Vec<ColumnDef>,
    /// Whether this table has RLS enabled
    #[serde(default)]
    pub enable_rls: bool,
}
53
/// A single column: name, type information and the constraints attached to it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ColumnDef {
    /// Column name as written in the schema.
    pub name: String,
    /// Base type name. Serialized as `"type"`; `"typ"` is also accepted when deserializing.
    #[serde(rename = "type", alias = "typ")]
    pub typ: String,
    /// Type is an array (e.g., text[], uuid[])
    #[serde(default)]
    pub is_array: bool,
    /// Type parameters (e.g., varchar(255) -> Some(vec!["255"]), decimal(10,2) -> Some(vec!["10", "2"]))
    #[serde(default)]
    pub type_params: Option<Vec<String>>,
    /// Whether the column accepts NULL.
    // NOTE(review): `#[serde(default)]` falls back to `bool::default()` (false) when the
    // key is missing from JSON, whereas `ColumnDef::default()` sets `nullable: true`.
    // Confirm which of the two defaults is intended.
    #[serde(default)]
    pub nullable: bool,
    /// Column is (part of) the primary key.
    #[serde(default)]
    pub primary_key: bool,
    /// Column carries a UNIQUE constraint.
    #[serde(default)]
    pub unique: bool,
    /// Foreign-key target text (e.g. `users(id)`), if any.
    #[serde(default)]
    pub references: Option<String>,
    /// Raw DEFAULT expression text, if any.
    #[serde(default)]
    pub default_value: Option<String>,
    /// Check constraint expression
    #[serde(default)]
    pub check: Option<String>,
    /// Is this a serial/auto-increment type
    #[serde(default)]
    pub is_serial: bool,
}
82
83impl Default for ColumnDef {
84    fn default() -> Self {
85        Self {
86            name: String::new(),
87            typ: String::new(),
88            is_array: false,
89            type_params: None,
90            nullable: true,
91            primary_key: false,
92            unique: false,
93            references: None,
94            default_value: None,
95            check: None,
96            is_serial: false,
97        }
98    }
99}
100
101impl Schema {
102    /// Parse a schema from `.qail` format string
103    pub fn parse(input: &str) -> Result<Self, String> {
104        match parse_schema(input) {
105            Ok(("", schema)) => Ok(schema),
106            Ok((remaining, _)) => Err(format!("Unexpected content: '{}'", remaining.trim())),
107            Err(e) => Err(format!("Parse error: {:?}", e)),
108        }
109    }
110
111    /// Find a table by name
112    pub fn find_table(&self, name: &str) -> Option<&TableDef> {
113        self.tables
114            .iter()
115            .find(|t| t.name.eq_ignore_ascii_case(name))
116    }
117
118    /// Generate complete SQL for this schema: tables + RLS + policies.
119    pub fn to_sql(&self) -> String {
120        let mut parts = Vec::new();
121
122        for table in &self.tables {
123            parts.push(table.to_ddl());
124
125            if table.enable_rls {
126                let alter = AlterTable::new(&table.name).enable_rls().force_rls();
127                for stmt in alter_table_sql(&alter) {
128                    parts.push(stmt);
129                }
130            }
131        }
132
133        for policy in &self.policies {
134            parts.push(create_policy_sql(policy));
135        }
136
137        parts.join(";\n\n") + ";"
138    }
139
140    /// Export schema to JSON string (for qail-macros compatibility)
141    pub fn to_json(&self) -> Result<String, String> {
142        serde_json::to_string_pretty(self).map_err(|e| format!("JSON serialization failed: {}", e))
143    }
144
145    /// Import schema from JSON string
146    pub fn from_json(json: &str) -> Result<Self, String> {
147        serde_json::from_str(json).map_err(|e| format!("JSON deserialization failed: {}", e))
148    }
149
150    /// Load schema from a .qail file
151    pub fn from_file(path: &std::path::Path) -> Result<Self, String> {
152        let content =
153            std::fs::read_to_string(path).map_err(|e| format!("Failed to read file: {}", e))?;
154
155        if content.trim().starts_with('{') {
156            Self::from_json(&content)
157        } else {
158            Self::parse(&content)
159        }
160    }
161}
162
163impl TableDef {
164    /// Find a column by name
165    pub fn find_column(&self, name: &str) -> Option<&ColumnDef> {
166        self.columns
167            .iter()
168            .find(|c| c.name.eq_ignore_ascii_case(name))
169    }
170
171    /// Generate CREATE TABLE IF NOT EXISTS SQL (AST-native DDL).
172    pub fn to_ddl(&self) -> String {
173        let mut sql = format!("CREATE TABLE IF NOT EXISTS {} (\n", self.name);
174
175        let mut col_defs = Vec::new();
176        for col in &self.columns {
177            let mut line = format!("    {}", col.name);
178
179            // Type with params
180            let mut typ = col.typ.to_uppercase();
181            if let Some(params) = &col.type_params {
182                typ = format!("{}({})", typ, params.join(", "));
183            }
184            if col.is_array {
185                typ.push_str("[]");
186            }
187            line.push_str(&format!(" {}", typ));
188
189            // Constraints
190            if col.primary_key {
191                line.push_str(" PRIMARY KEY");
192            }
193            if !col.nullable && !col.primary_key && !col.is_serial {
194                line.push_str(" NOT NULL");
195            }
196            if col.unique && !col.primary_key {
197                line.push_str(" UNIQUE");
198            }
199            if let Some(ref default) = col.default_value {
200                line.push_str(&format!(" DEFAULT {}", default));
201            }
202            if let Some(ref refs) = col.references {
203                line.push_str(&format!(" REFERENCES {}", refs));
204            }
205            if let Some(ref check) = col.check {
206                line.push_str(&format!(" CHECK({})", check));
207            }
208
209            col_defs.push(line);
210        }
211
212        sql.push_str(&col_defs.join(",\n"));
213        sql.push_str("\n)");
214        sql
215    }
216}
217
218// =============================================================================
219// Parsing Combinators
220// =============================================================================
221
222/// Parse identifier (table/column name)
223fn identifier(input: &str) -> IResult<&str, &str> {
224    take_while1(|c: char| c.is_alphanumeric() || c == '_')(input)
225}
226
227/// Skip whitespace and comments
228fn ws_and_comments(input: &str) -> IResult<&str, ()> {
229    let (input, _) = many0(alt((
230        map(multispace1, |_| ()),
231        map((tag("--"), not_line_ending), |_| ()),
232    )))
233    .parse(input)?;
234    Ok((input, ()))
235}
236
/// Intermediate result of `parse_type_info`, copied into a `ColumnDef` by `parse_column`.
struct TypeInfo {
    /// Lower-cased base type name (e.g. `varchar`).
    name: String,
    /// Comma-separated parameters found inside `(...)`, if the type had any.
    params: Option<Vec<String>>,
    /// True when the type was followed by `[]`.
    is_array: bool,
    /// True for `serial`, `bigserial` and `smallserial`.
    is_serial: bool,
}
243
244/// Parse column type with optional params and array suffix
245/// Handles: varchar(255), decimal(10,2), text[], serial, bigserial
246fn parse_type_info(input: &str) -> IResult<&str, TypeInfo> {
247    let (input, type_name) = take_while1(|c: char| c.is_alphanumeric()).parse(input)?;
248
249    let (input, params) = if input.starts_with('(') {
250        let paren_start = 1;
251        let mut paren_end = paren_start;
252        for (i, c) in input[paren_start..].char_indices() {
253            if c == ')' {
254                paren_end = paren_start + i;
255                break;
256            }
257        }
258        let param_str = &input[paren_start..paren_end];
259        let params: Vec<String> = param_str.split(',').map(|s| s.trim().to_string()).collect();
260        (&input[paren_end + 1..], Some(params))
261    } else {
262        (input, None)
263    };
264
265    let (input, is_array) = if let Some(stripped) = input.strip_prefix("[]") {
266        (stripped, true)
267    } else {
268        (input, false)
269    };
270
271    let lower = type_name.to_lowercase();
272    let is_serial = lower == "serial" || lower == "bigserial" || lower == "smallserial";
273
274    Ok((
275        input,
276        TypeInfo {
277            name: lower,
278            params,
279            is_array,
280            is_serial,
281        },
282    ))
283}
284
285/// Parse constraint text until comma or closing paren (handling nested parens)
286fn constraint_text(input: &str) -> IResult<&str, &str> {
287    let mut paren_depth = 0;
288    let mut end = 0;
289
290    for (i, c) in input.char_indices() {
291        match c {
292            '(' => paren_depth += 1,
293            ')' => {
294                if paren_depth == 0 {
295                    break; // End at column-level closing paren
296                }
297                paren_depth -= 1;
298            }
299            ',' if paren_depth == 0 => break,
300            '\n' | '\r' if paren_depth == 0 => break,
301            _ => {}
302        }
303        end = i + c.len_utf8();
304    }
305
306    if end == 0 {
307        Err(nom::Err::Error(nom::error::Error::new(
308            input,
309            nom::error::ErrorKind::TakeWhile1,
310        )))
311    } else {
312        Ok((&input[end..], &input[..end]))
313    }
314}
315
316/// Parse a single column definition
317fn parse_column(input: &str) -> IResult<&str, ColumnDef> {
318    let (input, _) = ws_and_comments(input)?;
319    let (input, name) = identifier(input)?;
320    let (input, _) = multispace1(input)?;
321    let (input, type_info) = parse_type_info(input)?;
322
323    let (input, constraint_str) = opt(preceded(multispace1, constraint_text)).parse(input)?;
324
325    let mut col = ColumnDef {
326        name: name.to_string(),
327        typ: type_info.name,
328        is_array: type_info.is_array,
329        type_params: type_info.params,
330        is_serial: type_info.is_serial,
331        nullable: !type_info.is_serial, // Serial types are implicitly not null
332        ..Default::default()
333    };
334
335    if let Some(constraints) = constraint_str {
336        let lower = constraints.to_lowercase();
337
338        if lower.contains("primary_key") || lower.contains("primary key") {
339            col.primary_key = true;
340            col.nullable = false;
341        }
342        if lower.contains("not_null") || lower.contains("not null") {
343            col.nullable = false;
344        }
345        if lower.contains("unique") {
346            col.unique = true;
347        }
348
349        if let Some(idx) = lower.find("references ") {
350            let rest = &constraints[idx + 11..];
351            // Find end (space or end of string), but handle nested parens
352            let mut paren_depth = 0;
353            let mut end = rest.len();
354            for (i, c) in rest.char_indices() {
355                match c {
356                    '(' => paren_depth += 1,
357                    ')' => {
358                        if paren_depth == 0 {
359                            end = i;
360                            break;
361                        }
362                        paren_depth -= 1;
363                    }
364                    c if c.is_whitespace() && paren_depth == 0 => {
365                        end = i;
366                        break;
367                    }
368                    _ => {}
369                }
370            }
371            col.references = Some(rest[..end].to_string());
372        }
373
374        if let Some(idx) = lower.find("default ") {
375            let rest = &constraints[idx + 8..];
376            let end = rest.find(|c: char| c.is_whitespace()).unwrap_or(rest.len());
377            col.default_value = Some(rest[..end].to_string());
378        }
379
380        if let Some(idx) = lower.find("check(") {
381            let rest = &constraints[idx + 6..];
382            // Find matching closing paren
383            let mut depth = 1;
384            let mut end = rest.len();
385            for (i, c) in rest.char_indices() {
386                match c {
387                    '(' => depth += 1,
388                    ')' => {
389                        depth -= 1;
390                        if depth == 0 {
391                            end = i;
392                            break;
393                        }
394                    }
395                    _ => {}
396                }
397            }
398            col.check = Some(rest[..end].to_string());
399        }
400    }
401
402    Ok((input, col))
403}
404
405/// Parse column list: (col1 type, col2 type, ...)
406fn parse_column_list(input: &str) -> IResult<&str, Vec<ColumnDef>> {
407    let (input, _) = ws_and_comments(input)?;
408    let (input, _) = char('(').parse(input)?;
409    let (input, columns) = separated_list0(char(','), parse_column).parse(input)?;
410    let (input, _) = ws_and_comments(input)?;
411    let (input, _) = char(')').parse(input)?;
412
413    Ok((input, columns))
414}
415
416/// Parse a table definition
417fn parse_table(input: &str) -> IResult<&str, TableDef> {
418    let (input, _) = ws_and_comments(input)?;
419    let (input, _) = tag_no_case("table").parse(input)?;
420    let (input, _) = multispace1(input)?;
421    let (input, name) = identifier(input)?;
422    let (input, columns) = parse_column_list(input)?;
423
424    // Optional enable_rls annotation after closing paren
425    let (input, _) = ws_and_comments(input)?;
426    let enable_rls = if let Ok((rest, _)) =
427        tag_no_case::<_, _, nom::error::Error<&str>>("enable_rls").parse(input)
428    {
429        return Ok((
430            rest,
431            TableDef {
432                name: name.to_string(),
433                columns,
434                enable_rls: true,
435            },
436        ));
437    } else {
438        false
439    };
440
441    Ok((
442        input,
443        TableDef {
444            name: name.to_string(),
445            columns,
446            enable_rls,
447        },
448    ))
449}
450
451// =============================================================================
452// Policy Parsing
453// =============================================================================
454
/// A schema item is either a table or a policy.
///
/// Produced by `parse_schema_item` and split into `Schema::tables` /
/// `Schema::policies` by `parse_schema`.
enum SchemaItem {
    /// A `table ...` definition.
    Table(TableDef),
    /// A `policy ...` definition.
    Policy(RlsPolicy),
}
460
461/// Parse a policy definition.
462///
463/// Syntax:
464/// ```text
465/// policy <name> on <table>
466///     [for all|select|insert|update|delete]
467///     [restrictive]
468///     [to <role>]
469///     [using (<expr>)]
470///     [with check (<expr>)]
471/// ```
472fn parse_policy(input: &str) -> IResult<&str, RlsPolicy> {
473    let (input, _) = ws_and_comments(input)?;
474    let (input, _) = tag_no_case("policy").parse(input)?;
475    let (input, _) = multispace1(input)?;
476    let (input, name) = identifier(input)?;
477    let (input, _) = multispace1(input)?;
478    let (input, _) = tag_no_case("on").parse(input)?;
479    let (input, _) = multispace1(input)?;
480    let (input, table) = identifier(input)?;
481
482    let mut policy = RlsPolicy::create(name, table);
483
484    // Parse optional clauses in any order
485    let mut remaining = input;
486    loop {
487        let (input, _) = ws_and_comments(remaining)?;
488
489        // for all|select|insert|update|delete
490        if let Ok((rest, _)) =
491            tag_no_case::<_, _, nom::error::Error<&str>>("for").parse(input)
492        {
493            let (rest, _) = multispace1(rest)?;
494            let (rest, target) = alt((
495                map(tag_no_case("all"), |_| PolicyTarget::All),
496                map(tag_no_case("select"), |_| PolicyTarget::Select),
497                map(tag_no_case("insert"), |_| PolicyTarget::Insert),
498                map(tag_no_case("update"), |_| PolicyTarget::Update),
499                map(tag_no_case("delete"), |_| PolicyTarget::Delete),
500            ))
501            .parse(rest)?;
502            policy.target = target;
503            remaining = rest;
504            continue;
505        }
506
507        // restrictive
508        if let Ok((rest, _)) =
509            tag_no_case::<_, _, nom::error::Error<&str>>("restrictive").parse(input)
510        {
511            policy.permissiveness = PolicyPermissiveness::Restrictive;
512            remaining = rest;
513            continue;
514        }
515
516        // to <role>
517        if let Ok((rest, _)) =
518            tag_no_case::<_, _, nom::error::Error<&str>>("to").parse(input)
519        {
520            // Make sure it's not "to_sql" or similar — needs whitespace after
521            if let Ok((rest, _)) = multispace1::<_, nom::error::Error<&str>>(rest) {
522                let (rest, role) = identifier(rest)?;
523                policy.role = Some(role.to_string());
524                remaining = rest;
525                continue;
526            }
527        }
528
529        // with check (<expr>)
530        if let Ok((rest, _)) =
531            tag_no_case::<_, _, nom::error::Error<&str>>("with").parse(input)
532        {
533            let (rest, _) = multispace1(rest)?;
534            if let Ok((rest, _)) =
535                tag_no_case::<_, _, nom::error::Error<&str>>("check").parse(rest)
536            {
537                let (rest, _) = nom_ws0(rest)?;
538                let (rest, _) = char('(').parse(rest)?;
539                let (rest, _) = nom_ws0(rest)?;
540                let (rest, expr) = parse_policy_expr(rest)?;
541                let (rest, _) = nom_ws0(rest)?;
542                let (rest, _) = char(')').parse(rest)?;
543                policy.with_check = Some(expr);
544                remaining = rest;
545                continue;
546            }
547        }
548
549        // using (<expr>)
550        if let Ok((rest, _)) =
551            tag_no_case::<_, _, nom::error::Error<&str>>("using").parse(input)
552        {
553            let (rest, _) = nom_ws0(rest)?;
554            let (rest, _) = char('(').parse(rest)?;
555            let (rest, _) = nom_ws0(rest)?;
556            let (rest, expr) = parse_policy_expr(rest)?;
557            let (rest, _) = nom_ws0(rest)?;
558            let (rest, _) = char(')').parse(rest)?;
559            policy.using = Some(expr);
560            remaining = rest;
561            continue;
562        }
563
564        // No more clauses matched
565        remaining = input;
566        break;
567    }
568
569    Ok((remaining, policy))
570}
571
572/// Parse a policy expression: `left op right [AND/OR left op right ...]`
573///
574/// Produces typed `Expr::Binary` AST nodes — no raw SQL.
575///
576/// Handles:
577/// - `column = value`
578/// - `column = function('arg')::type`   (function call + cast)
579/// - `expr AND expr`, `expr OR expr`
580fn parse_policy_expr(input: &str) -> IResult<&str, Expr> {
581    let (input, first) = parse_policy_comparison(input)?;
582
583    // Check for AND/OR chaining
584    let mut result = first;
585    let mut remaining = input;
586    loop {
587        let (input, _) = nom_ws0(remaining)?;
588
589        if let Ok((rest, _)) =
590            tag_no_case::<_, _, nom::error::Error<&str>>("or").parse(input)
591        {
592            if let Ok((rest, _)) = multispace1::<_, nom::error::Error<&str>>(rest) {
593                let (rest, right) = parse_policy_comparison(rest)?;
594                result = Expr::Binary {
595                    left: Box::new(result),
596                    op: BinaryOp::Or,
597                    right: Box::new(right),
598                    alias: None,
599                };
600                remaining = rest;
601                continue;
602            }
603        }
604
605        if let Ok((rest, _)) =
606            tag_no_case::<_, _, nom::error::Error<&str>>("and").parse(input)
607        {
608            if let Ok((rest, _)) = multispace1::<_, nom::error::Error<&str>>(rest) {
609                let (rest, right) = parse_policy_comparison(rest)?;
610                result = Expr::Binary {
611                    left: Box::new(result),
612                    op: BinaryOp::And,
613                    right: Box::new(right),
614                    alias: None,
615                };
616                remaining = rest;
617                continue;
618            }
619        }
620
621        remaining = input;
622        break;
623    }
624
625    Ok((remaining, result))
626}
627
628/// Parse a single comparison: `atom op atom`
629fn parse_policy_comparison(input: &str) -> IResult<&str, Expr> {
630    let (input, left) = parse_policy_atom(input)?;
631    let (input, _) = nom_ws0(input)?;
632
633    // Try to parse comparison operator
634    if let Ok((rest, op)) = parse_cmp_op(input) {
635        let (rest, _) = nom_ws0(rest)?;
636        let (rest, right) = parse_policy_atom(rest)?;
637        return Ok((
638            rest,
639            Expr::Binary {
640                left: Box::new(left),
641                op,
642                right: Box::new(right),
643                alias: None,
644            },
645        ));
646    }
647
648    // No operator — just an atom
649    Ok((input, left))
650}
651
652/// Parse comparison operators: =, !=, <>, >=, <=, >, <
653fn parse_cmp_op(input: &str) -> IResult<&str, BinaryOp> {
654    alt((
655        map(tag(">="), |_| BinaryOp::Gte),
656        map(tag("<="), |_| BinaryOp::Lte),
657        map(tag("<>"), |_| BinaryOp::Ne),
658        map(tag("!="), |_| BinaryOp::Ne),
659        map(tag("="), |_| BinaryOp::Eq),
660        map(tag(">"), |_| BinaryOp::Gt),
661        map(tag("<"), |_| BinaryOp::Lt),
662    ))
663    .parse(input)
664}
665
/// Parse a policy expression atom:
/// - identifier  (column name)
/// - function_call(args)  with optional ::cast
/// - 'string literal'
/// - numeric literal
/// - true/false
/// - (sub_expr)  grouped
///
/// The alternatives are tried in the order listed in the body; the first one
/// that succeeds wins, so literals take priority over identifiers.
fn parse_policy_atom(input: &str) -> IResult<&str, Expr> {
    alt((
        parse_policy_grouped,
        parse_policy_bool,
        parse_policy_string,
        parse_policy_number,
        parse_policy_func_or_ident, // function call or plain identifier, with optional ::cast
    ))
    .parse(input)
}
683
684/// Parse grouped expression in parens
685fn parse_policy_grouped(input: &str) -> IResult<&str, Expr> {
686    let (input, _) = char('(').parse(input)?;
687    let (input, _) = nom_ws0(input)?;
688    let (input, expr) = parse_policy_expr(input)?;
689    let (input, _) = nom_ws0(input)?;
690    let (input, _) = char(')').parse(input)?;
691    Ok((input, expr))
692}
693
694/// Parse true / false
695fn parse_policy_bool(input: &str) -> IResult<&str, Expr> {
696    alt((
697        map(tag_no_case("true"), |_| {
698            Expr::Literal(AstValue::Bool(true))
699        }),
700        map(tag_no_case("false"), |_| {
701            Expr::Literal(AstValue::Bool(false))
702        }),
703    ))
704    .parse(input)
705}
706
707/// Parse a 'string literal'
708fn parse_policy_string(input: &str) -> IResult<&str, Expr> {
709    let (input, _) = char('\'').parse(input)?;
710    let mut end = 0;
711    for (i, c) in input.char_indices() {
712        if c == '\'' {
713            end = i;
714            break;
715        }
716    }
717    let content = &input[..end];
718    let rest = &input[end + 1..];
719    Ok((rest, Expr::Literal(AstValue::String(content.to_string()))))
720}
721
722/// Parse numeric literal
723fn parse_policy_number(input: &str) -> IResult<&str, Expr> {
724    let (input, digits) = take_while1(|c: char| c.is_ascii_digit() || c == '.')(input)?;
725    // Make sure it starts with digit (not just '.')
726    if digits.starts_with('.') || digits.is_empty() {
727        return Err(nom::Err::Error(nom::error::Error::new(
728            input,
729            nom::error::ErrorKind::Digit,
730        )));
731    }
732    if let Ok(n) = digits.parse::<i64>() {
733        Ok((input, Expr::Literal(AstValue::Int(n))))
734    } else if let Ok(f) = digits.parse::<f64>() {
735        Ok((input, Expr::Literal(AstValue::Float(f))))
736    } else {
737        Ok((input, Expr::Named(digits.to_string())))
738    }
739}
740
741/// Parse function call or identifier, with optional ::cast
742fn parse_policy_func_or_ident(input: &str) -> IResult<&str, Expr> {
743    let (input, name) = identifier(input)?;
744
745    // Check for function call: name(
746    let mut expr = if let Ok((rest, _)) = char::<_, nom::error::Error<&str>>('(').parse(input) {
747        // Parse args
748        let (rest, _) = nom_ws0(rest)?;
749        let (rest, args) = separated_list0(
750            (nom_ws0, char(','), nom_ws0),
751            parse_policy_atom,
752        )
753        .parse(rest)?;
754        let (rest, _) = nom_ws0(rest)?;
755        let (rest, _) = char(')').parse(rest)?;
756        let input = rest;
757        (input, Expr::FunctionCall {
758            name: name.to_string(),
759            args,
760            alias: None,
761        })
762    } else {
763        (input, Expr::Named(name.to_string()))
764    };
765
766    // Check for ::cast
767    if let Ok((rest, _)) = tag::<_, _, nom::error::Error<&str>>("::").parse(expr.0) {
768        let (rest, cast_type) = identifier(rest)?;
769        expr = (
770            rest,
771            Expr::Cast {
772                expr: Box::new(expr.1),
773                target_type: cast_type.to_string(),
774                alias: None,
775            },
776        );
777    }
778
779    Ok(expr)
780}
781
782/// Parse a single schema item: either a table or a policy
783fn parse_schema_item(input: &str) -> IResult<&str, SchemaItem> {
784    let (input, _) = ws_and_comments(input)?;
785
786    // Try policy first (since "policy" is a distinct keyword)
787    if let Ok((rest, policy)) = parse_policy(input) {
788        return Ok((rest, SchemaItem::Policy(policy)));
789    }
790
791    // Otherwise parse table
792    let (rest, table) = parse_table(input)?;
793    Ok((rest, SchemaItem::Table(table)))
794}
795
796/// Parse complete schema file
797fn parse_schema(input: &str) -> IResult<&str, Schema> {
798    // Extract version directive before parsing
799    let version = extract_version_directive(input);
800
801    let (input, items) = many0(parse_schema_item).parse(input)?;
802    let (input, _) = ws_and_comments(input)?;
803
804    let mut tables = Vec::new();
805    let mut policies = Vec::new();
806    for item in items {
807        match item {
808            SchemaItem::Table(t) => tables.push(t),
809            SchemaItem::Policy(p) => policies.push(p),
810        }
811    }
812
813    Ok((input, Schema { version, tables, policies }))
814}
815
/// Extract the version from a `-- qail: version=N` directive line.
///
/// Lines are trimmed before matching. The first line that carries both the
/// `-- qail:` prefix and a `version=` key decides the result: `Some(N)` if the
/// value parses as `u32`, `None` otherwise. Later lines are not consulted.
/// Returns `None` when no such line exists.
fn extract_version_directive(input: &str) -> Option<u32> {
    input
        .lines()
        .filter_map(|line| line.trim().strip_prefix("-- qail:"))
        .find_map(|directive| directive.trim().strip_prefix("version="))
        .and_then(|raw| raw.trim().parse().ok())
}
829
830#[cfg(test)]
831mod tests {
832    use super::*;
833
    /// One table with three columns: checks the name, the column count, and
    /// that `primary_key` / `not null` clear `nullable` while a bare column
    /// stays nullable.
    #[test]
    fn test_parse_simple_table() {
        let input = r#"
            table users (
                id uuid primary_key,
                email text not null,
                name text
            )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.tables.len(), 1);

        let users = &schema.tables[0];
        assert_eq!(users.name, "users");
        assert_eq!(users.columns.len(), 3);

        // Primary key implies not-null.
        let id = &users.columns[0];
        assert_eq!(id.name, "id");
        assert_eq!(id.typ, "uuid");
        assert!(id.primary_key);
        assert!(!id.nullable);

        let email = &users.columns[1];
        assert_eq!(email.name, "email");
        assert!(!email.nullable);

        // No constraints: nullable by default.
        let name = &users.columns[2];
        assert!(name.nullable);
    }
864
    /// Two tables separated by comments: checks lookup by name plus the
    /// `references` and `default` constraint extraction.
    #[test]
    fn test_parse_multiple_tables() {
        let input = r#"
            -- Users table
            table users (
                id uuid primary_key,
                email text not null unique
            )
            
            -- Orders table
            table orders (
                id uuid primary_key,
                user_id uuid references users(id),
                total i64 not null default 0
            )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.tables.len(), 2);

        // The reference target keeps its parenthesised column list.
        let orders = schema.find_table("orders").expect("orders not found");
        let user_id = orders.find_column("user_id").expect("user_id not found");
        assert_eq!(user_id.references, Some("users(id)".to_string()));

        let total = orders.find_column("total").expect("total not found");
        assert_eq!(total.default_value, Some("0".to_string()));
    }
892
    /// A `--` comment ahead of a table must be skipped, not treated as content.
    #[test]
    fn test_parse_comments() {
        let input = r#"
            -- This is a comment
            table foo (
                bar text
            )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.tables.len(), 1);
    }
905
    /// A `[]` suffix sets `is_array` and is not kept in the type name.
    #[test]
    fn test_array_types() {
        let input = r#"
            table products (
                id uuid primary_key,
                tags text[],
                prices decimal[]
            )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let products = &schema.tables[0];

        let tags = products.find_column("tags").expect("tags not found");
        assert_eq!(tags.typ, "text");
        assert!(tags.is_array);

        let prices = products.find_column("prices").expect("prices not found");
        assert!(prices.is_array);
    }
926
    /// Parenthesised type parameters are split on commas, and `serial` is
    /// flagged as serial and non-nullable.
    #[test]
    fn test_type_params() {
        let input = r#"
            table items (
                id serial primary_key,
                name varchar(255) not null,
                price decimal(10,2),
                code varchar(50) unique
            )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let items = &schema.tables[0];

        let id = items.find_column("id").expect("id not found");
        assert!(id.is_serial);
        assert!(!id.nullable); // Serial is implicitly not null

        // Single parameter.
        let name = items.find_column("name").expect("name not found");
        assert_eq!(name.typ, "varchar");
        assert_eq!(name.type_params, Some(vec!["255".to_string()]));

        // Two parameters, split on the comma.
        let price = items.find_column("price").expect("price not found");
        assert_eq!(
            price.type_params,
            Some(vec!["10".to_string(), "2".to_string()])
        );

        let code = items.find_column("code").expect("code not found");
        assert!(code.unique);
    }
958
959    #[test]
960    fn test_check_constraint() {
961        let input = r#"
962            table employees (
963                id uuid primary_key,
964                age i32 check(age >= 18),
965                salary decimal check(salary > 0)
966            )
967        "#;
968
969        let schema = Schema::parse(input).expect("parse failed");
970        let employees = &schema.tables[0];
971
972        let age = employees.find_column("age").expect("age not found");
973        assert_eq!(age.check, Some("age >= 18".to_string()));
974
975        let salary = employees.find_column("salary").expect("salary not found");
976        assert_eq!(salary.check, Some("salary > 0".to_string()));
977    }
978
979    #[test]
980    fn test_version_directive() {
981        let input = r#"
982            -- qail: version=1
983            table users (
984                id uuid primary_key
985            )
986        "#;
987
988        let schema = Schema::parse(input).expect("parse failed");
989        assert_eq!(schema.version, Some(1));
990        assert_eq!(schema.tables.len(), 1);
991
992        // Without version directive
993        let input_no_version = r#"
994            table items (
995                id uuid primary_key
996            )
997        "#;
998        let schema2 = Schema::parse(input_no_version).expect("parse failed");
999        assert_eq!(schema2.version, None);
1000    }
1001
1002    // =========================================================================
1003    // Policy + enable_rls tests
1004    // =========================================================================
1005
1006    #[test]
1007    fn test_enable_rls_table() {
1008        let input = r#"
1009            table orders (
1010                id uuid primary_key,
1011                operator_id uuid not null
1012            ) enable_rls
1013        "#;
1014
1015        let schema = Schema::parse(input).expect("parse failed");
1016        assert_eq!(schema.tables.len(), 1);
1017        assert!(schema.tables[0].enable_rls);
1018    }
1019
1020    #[test]
1021    fn test_parse_policy_basic() {
1022        let input = r#"
1023            table orders (
1024                id uuid primary_key,
1025                operator_id uuid not null
1026            ) enable_rls
1027
1028            policy orders_isolation on orders
1029                for all
1030                using (operator_id = current_setting('app.current_operator_id')::uuid)
1031        "#;
1032
1033        let schema = Schema::parse(input).expect("parse failed");
1034        assert_eq!(schema.tables.len(), 1);
1035        assert_eq!(schema.policies.len(), 1);
1036
1037        let policy = &schema.policies[0];
1038        assert_eq!(policy.name, "orders_isolation");
1039        assert_eq!(policy.table, "orders");
1040        assert_eq!(policy.target, PolicyTarget::All);
1041        assert!(policy.using.is_some());
1042
1043        // Verify the expression is a typed Binary, not raw SQL
1044        match policy.using.as_ref().unwrap() {
1045            Expr::Binary { left, op, right, .. } => {
1046                assert_eq!(*op, BinaryOp::Eq);
1047                match left.as_ref() {
1048                    Expr::Named(n) => assert_eq!(n, "operator_id"),
1049                    _ => panic!("Expected Named, got {:?}", left),
1050                }
1051                match right.as_ref() {
1052                    Expr::Cast { target_type, expr, .. } => {
1053                        assert_eq!(target_type, "uuid");
1054                        match expr.as_ref() {
1055                            Expr::FunctionCall { name, args, .. } => {
1056                                assert_eq!(name, "current_setting");
1057                                assert_eq!(args.len(), 1);
1058                            }
1059                            _ => panic!("Expected FunctionCall"),
1060                        }
1061                    }
1062                    _ => panic!("Expected Cast, got {:?}", right),
1063                }
1064            }
1065            _ => panic!("Expected Binary"),
1066        }
1067    }
1068
1069    #[test]
1070    fn test_parse_policy_with_check() {
1071        let input = r#"
1072            table orders (
1073                id uuid primary_key
1074            )
1075
1076            policy orders_write on orders
1077                for insert
1078                with check (operator_id = current_setting('app.current_operator_id')::uuid)
1079        "#;
1080
1081        let schema = Schema::parse(input).expect("parse failed");
1082        let policy = &schema.policies[0];
1083        assert_eq!(policy.target, PolicyTarget::Insert);
1084        assert!(policy.with_check.is_some());
1085        assert!(policy.using.is_none());
1086    }
1087
1088    #[test]
1089    fn test_parse_policy_restrictive_with_role() {
1090        let input = r#"
1091            table secrets (
1092                id uuid primary_key
1093            )
1094
1095            policy admin_only on secrets
1096                for select
1097                restrictive
1098                to app_user
1099                using (current_setting('app.is_super_admin')::boolean = true)
1100        "#;
1101
1102        let schema = Schema::parse(input).expect("parse failed");
1103        let policy = &schema.policies[0];
1104        assert_eq!(policy.target, PolicyTarget::Select);
1105        assert_eq!(policy.permissiveness, PolicyPermissiveness::Restrictive);
1106        assert_eq!(policy.role.as_deref(), Some("app_user"));
1107        assert!(policy.using.is_some());
1108    }
1109
1110    #[test]
1111    fn test_parse_policy_or_expr() {
1112        let input = r#"
1113            table orders (
1114                id uuid primary_key
1115            )
1116
1117            policy tenant_or_admin on orders
1118                for all
1119                using (operator_id = current_setting('app.current_operator_id')::uuid or current_setting('app.is_super_admin')::boolean = true)
1120        "#;
1121
1122        let schema = Schema::parse(input).expect("parse failed");
1123        let policy = &schema.policies[0];
1124
1125        match policy.using.as_ref().unwrap() {
1126            Expr::Binary { op: BinaryOp::Or, .. } => {}
1127            e => panic!("Expected Binary OR, got {:?}", e),
1128        }
1129    }
1130
1131    #[test]
1132    fn test_schema_to_sql() {
1133        let input = r#"
1134            table orders (
1135                id uuid primary_key,
1136                operator_id uuid not null
1137            ) enable_rls
1138
1139            policy orders_isolation on orders
1140                for all
1141                using (operator_id = current_setting('app.current_operator_id')::uuid)
1142        "#;
1143
1144        let schema = Schema::parse(input).expect("parse failed");
1145        let sql = schema.to_sql();
1146        assert!(sql.contains("CREATE TABLE IF NOT EXISTS"));
1147        assert!(sql.contains("ENABLE ROW LEVEL SECURITY"));
1148        assert!(sql.contains("FORCE ROW LEVEL SECURITY"));
1149        assert!(sql.contains("CREATE POLICY"));
1150        assert!(sql.contains("orders_isolation"));
1151        assert!(sql.contains("FOR ALL"));
1152    }
1153
1154    #[test]
1155    fn test_multiple_policies() {
1156        let input = r#"
1157            table orders (
1158                id uuid primary_key,
1159                operator_id uuid not null
1160            ) enable_rls
1161
1162            policy orders_read on orders
1163                for select
1164                using (operator_id = current_setting('app.current_operator_id')::uuid)
1165
1166            policy orders_write on orders
1167                for insert
1168                with check (operator_id = current_setting('app.current_operator_id')::uuid)
1169        "#;
1170
1171        let schema = Schema::parse(input).expect("parse failed");
1172        assert_eq!(schema.policies.len(), 2);
1173        assert_eq!(schema.policies[0].name, "orders_read");
1174        assert_eq!(schema.policies[0].target, PolicyTarget::Select);
1175        assert_eq!(schema.policies[1].name, "orders_write");
1176        assert_eq!(schema.policies[1].target, PolicyTarget::Insert);
1177    }
1178}