Skip to main content

qail_core/parser/
schema.rs

1//! Schema file parser for `.qail` format.
2//!
3//! Parses schema definitions like:
4//! ```text
5//! table users (
6//!   id uuid primary_key,
7//!   email text not null,
8//!   name text,
9//!   created_at timestamp
10//! )
11//!
12//! policy users_isolation on users
13//!     for all
14//!     using (operator_id = current_setting('app.operator_id')::uuid)
15//! ```
16
17use nom::{
18    IResult, Parser,
19    branch::alt,
20    bytes::complete::{tag, tag_no_case, take_while1},
21    character::complete::{char, multispace0 as nom_ws0, multispace1, not_line_ending},
22    combinator::{map, opt},
23    multi::{many0, separated_list0},
24    sequence::preceded,
25};
26
27use crate::ast::{BinaryOp, Expr, Value as AstValue};
28use crate::migrate::alter::AlterTable;
29use crate::migrate::policy::{PolicyPermissiveness, PolicyTarget, RlsPolicy};
30use crate::transpiler::policy::{alter_table_sql, create_policy_sql};
31
32/// Schema containing all table definitions
33#[derive(Debug, Clone, Default)]
34pub struct Schema {
35    /// Schema format version (extracted from `-- qail: version=N` directive)
36    pub version: Option<u32>,
37    /// Table definitions.
38    pub tables: Vec<TableDef>,
39    /// RLS policies declared in the schema
40    pub policies: Vec<RlsPolicy>,
41    /// Indexes declared in the schema
42    pub indexes: Vec<IndexDef>,
43}
44
/// An index declared in a schema file as `index <name> on <table> (<columns>) [unique]`.
#[derive(Debug, Clone)]
pub struct IndexDef {
    /// Name of the index.
    pub name: String,
    /// Table the index is built on.
    pub table: String,
    /// Indexed column names, in declaration order.
    pub columns: Vec<String>,
    /// `true` when the `unique` suffix was given; renders `CREATE UNIQUE INDEX`.
    pub unique: bool,
}

impl IndexDef {
    /// Render a `CREATE [UNIQUE] INDEX IF NOT EXISTS` statement (no trailing `;`).
    ///
    /// Index, table and column names are emitted verbatim, without quoting.
    pub fn to_sql(&self) -> String {
        let column_list = self.columns.join(", ");
        let unique_kw = self.unique.then_some(" UNIQUE").unwrap_or_default();
        format!(
            "CREATE{} INDEX IF NOT EXISTS {} ON {} ({})",
            unique_kw, self.name, self.table, column_list
        )
    }
}
71
72/// Table definition parsed from a `.qail` schema file.
73#[derive(Debug, Clone)]
74pub struct TableDef {
75    /// Table name.
76    pub name: String,
77    /// Column definitions.
78    pub columns: Vec<ColumnDef>,
79    /// Whether this table has RLS enabled.
80    pub enable_rls: bool,
81}
82
/// A single column of a [`TableDef`], as parsed from a `.qail` schema file.
#[derive(Clone, Debug)]
pub struct ColumnDef {
    /// Column name.
    pub name: String,
    /// Base SQL type, lowercased (e.g. `varchar`, `uuid`).
    pub typ: String,
    /// `true` when the type had a `[]` suffix (e.g. `text[]`).
    pub is_array: bool,
    /// Type parameters: `varchar(255)` → `Some(vec!["255"])`,
    /// `decimal(10,2)` → `Some(vec!["10", "2"])`; `None` without parentheses.
    pub type_params: Option<Vec<String>>,
    /// Whether the column accepts NULL.
    pub nullable: bool,
    /// Whether the column is a primary key.
    pub primary_key: bool,
    /// Whether the column carries a UNIQUE constraint.
    pub unique: bool,
    /// Foreign-key target, e.g. `"users(id)"`.
    pub references: Option<String>,
    /// Default value expression, emitted verbatim after `DEFAULT`.
    pub default_value: Option<String>,
    /// Body of a `CHECK(...)` constraint, without the surrounding parentheses.
    pub check: Option<String>,
    /// Whether the type is `serial`, `bigserial` or `smallserial`.
    pub is_serial: bool,
}

impl Default for ColumnDef {
    /// An unnamed, untyped column: nullable, with no constraints or modifiers.
    fn default() -> Self {
        ColumnDef {
            name: String::default(),
            typ: String::default(),
            nullable: true,
            is_array: false,
            is_serial: false,
            primary_key: false,
            unique: false,
            type_params: None,
            references: None,
            default_value: None,
            check: None,
        }
    }
}
127
128impl Schema {
129    /// Parse a schema from `.qail` format string
130    pub fn parse(input: &str) -> Result<Self, String> {
131        match parse_schema(input) {
132            Ok(("", schema)) => Ok(schema),
133            Ok((remaining, _)) => Err(format!("Unexpected content: '{}'", remaining.trim())),
134            Err(e) => Err(format!("Parse error: {:?}", e)),
135        }
136    }
137
138    /// Find a table by name
139    pub fn find_table(&self, name: &str) -> Option<&TableDef> {
140        self.tables
141            .iter()
142            .find(|t| t.name.eq_ignore_ascii_case(name))
143    }
144
145    /// Generate complete SQL for this schema: tables + RLS + policies + indexes.
146    pub fn to_sql(&self) -> String {
147        let mut parts = Vec::new();
148
149        for table in &self.tables {
150            parts.push(table.to_ddl());
151
152            if table.enable_rls {
153                let alter = AlterTable::new(&table.name).enable_rls().force_rls();
154                for stmt in alter_table_sql(&alter) {
155                    parts.push(stmt);
156                }
157            }
158        }
159
160        for idx in &self.indexes {
161            parts.push(idx.to_sql());
162        }
163
164        for policy in &self.policies {
165            parts.push(create_policy_sql(policy));
166        }
167
168        parts.join(";\n\n") + ";"
169    }
170
171    /// Load schema from a .qail file
172    pub fn from_file(path: &std::path::Path) -> Result<Self, String> {
173        let content =
174            std::fs::read_to_string(path).map_err(|e| format!("Failed to read file: {}", e))?;
175        Self::parse(&content)
176    }
177}
178
179impl TableDef {
180    /// Find a column by name
181    pub fn find_column(&self, name: &str) -> Option<&ColumnDef> {
182        self.columns
183            .iter()
184            .find(|c| c.name.eq_ignore_ascii_case(name))
185    }
186
187    /// Generate CREATE TABLE IF NOT EXISTS SQL (AST-native DDL).
188    pub fn to_ddl(&self) -> String {
189        let mut sql = format!("CREATE TABLE IF NOT EXISTS {} (\n", self.name);
190
191        let mut col_defs = Vec::new();
192        for col in &self.columns {
193            let mut line = format!("    {}", col.name);
194
195            // Type with params
196            let mut typ = col.typ.to_uppercase();
197            if let Some(params) = &col.type_params {
198                typ = format!("{}({})", typ, params.join(", "));
199            }
200            if col.is_array {
201                typ.push_str("[]");
202            }
203            line.push_str(&format!(" {}", typ));
204
205            // Constraints
206            if col.primary_key {
207                line.push_str(" PRIMARY KEY");
208            }
209            if !col.nullable && !col.primary_key && !col.is_serial {
210                line.push_str(" NOT NULL");
211            }
212            if col.unique && !col.primary_key {
213                line.push_str(" UNIQUE");
214            }
215            if let Some(ref default) = col.default_value {
216                line.push_str(&format!(" DEFAULT {}", default));
217            }
218            if let Some(ref refs) = col.references {
219                line.push_str(&format!(" REFERENCES {}", refs));
220            }
221            if let Some(ref check) = col.check {
222                line.push_str(&format!(" CHECK({})", check));
223            }
224
225            col_defs.push(line);
226        }
227
228        sql.push_str(&col_defs.join(",\n"));
229        sql.push_str("\n)");
230        sql
231    }
232}
233
234// =============================================================================
235// Parsing Combinators
236// =============================================================================
237
238/// Parse identifier (table/column name)
239fn identifier(input: &str) -> IResult<&str, &str> {
240    take_while1(|c: char| c.is_alphanumeric() || c == '_')(input)
241}
242
243/// Skip whitespace and comments (both `--` and `#` styles)
244fn ws_and_comments(input: &str) -> IResult<&str, ()> {
245    let (input, _) = many0(alt((
246        map(multispace1, |_| ()),
247        map((tag("--"), not_line_ending), |_| ()),
248        map((tag("#"), not_line_ending), |_| ()),
249    )))
250    .parse(input)?;
251    Ok((input, ()))
252}
253
/// A column type as read by `parse_type_info`: base name plus modifiers.
struct TypeInfo {
    /// Lowercased base type name, e.g. `varchar`.
    name: String,
    /// Trimmed, comma-separated parameters from `type(a, b)`; `None` without parens.
    params: Option<Vec<String>>,
    /// Whether the type carried a `[]` suffix.
    is_array: bool,
    /// Whether the base type is `serial`, `bigserial` or `smallserial`.
    is_serial: bool,
}
260
261/// Parse column type with optional params and array suffix
262/// Handles: varchar(255), decimal(10,2), text[], serial, bigserial
263fn parse_type_info(input: &str) -> IResult<&str, TypeInfo> {
264    let (input, type_name) = take_while1(|c: char| c.is_alphanumeric()).parse(input)?;
265
266    let (input, params) = if input.starts_with('(') {
267        let paren_start = 1;
268        let mut paren_end = paren_start;
269        for (i, c) in input[paren_start..].char_indices() {
270            if c == ')' {
271                paren_end = paren_start + i;
272                break;
273            }
274        }
275        let param_str = &input[paren_start..paren_end];
276        let params: Vec<String> = param_str.split(',').map(|s| s.trim().to_string()).collect();
277        (&input[paren_end + 1..], Some(params))
278    } else {
279        (input, None)
280    };
281
282    let (input, is_array) = if let Some(stripped) = input.strip_prefix("[]") {
283        (stripped, true)
284    } else {
285        (input, false)
286    };
287
288    let lower = type_name.to_lowercase();
289    let is_serial = lower == "serial" || lower == "bigserial" || lower == "smallserial";
290
291    Ok((
292        input,
293        TypeInfo {
294            name: lower,
295            params,
296            is_array,
297            is_serial,
298        },
299    ))
300}
301
302/// Parse constraint text until comma or closing paren (handling nested parens)
303fn constraint_text(input: &str) -> IResult<&str, &str> {
304    let mut paren_depth = 0;
305    let mut end = 0;
306
307    for (i, c) in input.char_indices() {
308        match c {
309            '(' => paren_depth += 1,
310            ')' => {
311                if paren_depth == 0 {
312                    break; // End at column-level closing paren
313                }
314                paren_depth -= 1;
315            }
316            ',' if paren_depth == 0 => break,
317            '\n' | '\r' if paren_depth == 0 => break,
318            _ => {}
319        }
320        end = i + c.len_utf8();
321    }
322
323    if end == 0 {
324        Err(nom::Err::Error(nom::error::Error::new(
325            input,
326            nom::error::ErrorKind::TakeWhile1,
327        )))
328    } else {
329        Ok((&input[end..], &input[..end]))
330    }
331}
332
333/// Parse a single column definition
334fn parse_column(input: &str) -> IResult<&str, ColumnDef> {
335    let (input, _) = ws_and_comments(input)?;
336    let (input, name) = identifier(input)?;
337    let (input, _) = multispace1(input)?;
338    let (input, type_info) = parse_type_info(input)?;
339
340    let (input, constraint_str) = opt(preceded(multispace1, constraint_text)).parse(input)?;
341
342    let mut col = ColumnDef {
343        name: name.to_string(),
344        typ: type_info.name,
345        is_array: type_info.is_array,
346        type_params: type_info.params,
347        is_serial: type_info.is_serial,
348        nullable: !type_info.is_serial, // Serial types are implicitly not null
349        ..Default::default()
350    };
351
352    if let Some(constraints) = constraint_str {
353        let lower = constraints.to_lowercase();
354
355        if lower.contains("primary_key") || lower.contains("primary key") {
356            col.primary_key = true;
357            col.nullable = false;
358        }
359        if lower.contains("not_null") || lower.contains("not null") {
360            col.nullable = false;
361        }
362        if lower.contains("unique") {
363            col.unique = true;
364        }
365
366        if let Some(idx) = lower.find("references ") {
367            let rest = &constraints[idx + 11..];
368            // Find end (space or end of string), but handle nested parens
369            let mut paren_depth = 0;
370            let mut end = rest.len();
371            for (i, c) in rest.char_indices() {
372                match c {
373                    '(' => paren_depth += 1,
374                    ')' => {
375                        if paren_depth == 0 {
376                            end = i;
377                            break;
378                        }
379                        paren_depth -= 1;
380                    }
381                    c if c.is_whitespace() && paren_depth == 0 => {
382                        end = i;
383                        break;
384                    }
385                    _ => {}
386                }
387            }
388            col.references = Some(rest[..end].to_string());
389        }
390
391        if let Some(idx) = lower.find("default ") {
392            let rest = &constraints[idx + 8..];
393            let end = rest.find(|c: char| c.is_whitespace()).unwrap_or(rest.len());
394            col.default_value = Some(rest[..end].to_string());
395        }
396
397        if let Some(idx) = lower.find("check(") {
398            let rest = &constraints[idx + 6..];
399            // Find matching closing paren
400            let mut depth = 1;
401            let mut end = rest.len();
402            for (i, c) in rest.char_indices() {
403                match c {
404                    '(' => depth += 1,
405                    ')' => {
406                        depth -= 1;
407                        if depth == 0 {
408                            end = i;
409                            break;
410                        }
411                    }
412                    _ => {}
413                }
414            }
415            col.check = Some(rest[..end].to_string());
416        }
417    }
418
419    Ok((input, col))
420}
421
422/// Parse column list: (col1 type, col2 type, ...)
423fn parse_column_list(input: &str) -> IResult<&str, Vec<ColumnDef>> {
424    let (input, _) = ws_and_comments(input)?;
425    let (input, _) = char('(').parse(input)?;
426    let (input, columns) = separated_list0(char(','), parse_column).parse(input)?;
427    let (input, _) = ws_and_comments(input)?;
428    let (input, _) = char(')').parse(input)?;
429
430    Ok((input, columns))
431}
432
433/// Parse a table definition
434fn parse_table(input: &str) -> IResult<&str, TableDef> {
435    let (input, _) = ws_and_comments(input)?;
436    let (input, _) = tag_no_case("table").parse(input)?;
437    let (input, _) = multispace1(input)?;
438    let (input, name) = identifier(input)?;
439    let (input, columns) = parse_column_list(input)?;
440
441    // Optional enable_rls annotation after closing paren
442    let (input, _) = ws_and_comments(input)?;
443    let enable_rls = if let Ok((rest, _)) =
444        tag_no_case::<_, _, nom::error::Error<&str>>("enable_rls").parse(input)
445    {
446        return Ok((
447            rest,
448            TableDef {
449                name: name.to_string(),
450                columns,
451                enable_rls: true,
452            },
453        ));
454    } else {
455        false
456    };
457
458    Ok((
459        input,
460        TableDef {
461            name: name.to_string(),
462            columns,
463            enable_rls,
464        },
465    ))
466}
467
468// =============================================================================
469// Policy Parsing
470// =============================================================================
471
472/// A schema item is either a table, policy, or index.
473enum SchemaItem {
474    Table(TableDef),
475    Policy(Box<RlsPolicy>),
476    Index(IndexDef),
477}
478
479/// Parse a policy definition.
480///
481/// Syntax:
482/// ```text
483/// policy <name> on <table>
484///     [for all|select|insert|update|delete]
485///     [restrictive]
486///     [to <role>]
487///     [using (<expr>)]
488///     [with check (<expr>)]
489/// ```
490fn parse_policy(input: &str) -> IResult<&str, RlsPolicy> {
491    let (input, _) = ws_and_comments(input)?;
492    let (input, _) = tag_no_case("policy").parse(input)?;
493    let (input, _) = multispace1(input)?;
494    let (input, name) = identifier(input)?;
495    let (input, _) = multispace1(input)?;
496    let (input, _) = tag_no_case("on").parse(input)?;
497    let (input, _) = multispace1(input)?;
498    let (input, table) = identifier(input)?;
499
500    let mut policy = RlsPolicy::create(name, table);
501
502    // Parse optional clauses in any order
503    let mut remaining = input;
504    loop {
505        let (input, _) = ws_and_comments(remaining)?;
506
507        // for all|select|insert|update|delete
508        if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("for").parse(input) {
509            let (rest, _) = multispace1(rest)?;
510            let (rest, target) = alt((
511                map(tag_no_case("all"), |_| PolicyTarget::All),
512                map(tag_no_case("select"), |_| PolicyTarget::Select),
513                map(tag_no_case("insert"), |_| PolicyTarget::Insert),
514                map(tag_no_case("update"), |_| PolicyTarget::Update),
515                map(tag_no_case("delete"), |_| PolicyTarget::Delete),
516            ))
517            .parse(rest)?;
518            policy.target = target;
519            remaining = rest;
520            continue;
521        }
522
523        // restrictive
524        if let Ok((rest, _)) =
525            tag_no_case::<_, _, nom::error::Error<&str>>("restrictive").parse(input)
526        {
527            policy.permissiveness = PolicyPermissiveness::Restrictive;
528            remaining = rest;
529            continue;
530        }
531
532        // to <role>
533        if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("to").parse(input) {
534            // Make sure it's not "to_sql" or similar — needs whitespace after
535            if let Ok((rest, _)) = multispace1::<_, nom::error::Error<&str>>(rest) {
536                let (rest, role) = identifier(rest)?;
537                policy.role = Some(role.to_string());
538                remaining = rest;
539                continue;
540            }
541        }
542
543        // with check (<expr>)
544        if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("with").parse(input) {
545            let (rest, _) = multispace1(rest)?;
546            if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("check").parse(rest)
547            {
548                let (rest, _) = nom_ws0(rest)?;
549                let (rest, _) = char('(').parse(rest)?;
550                let (rest, _) = nom_ws0(rest)?;
551                let (rest, expr) = parse_policy_expr(rest)?;
552                let (rest, _) = nom_ws0(rest)?;
553                let (rest, _) = char(')').parse(rest)?;
554                policy.with_check = Some(expr);
555                remaining = rest;
556                continue;
557            }
558        }
559
560        // using (<expr>)
561        if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("using").parse(input) {
562            let (rest, _) = nom_ws0(rest)?;
563            let (rest, _) = char('(').parse(rest)?;
564            let (rest, _) = nom_ws0(rest)?;
565            let (rest, expr) = parse_policy_expr(rest)?;
566            let (rest, _) = nom_ws0(rest)?;
567            let (rest, _) = char(')').parse(rest)?;
568            policy.using = Some(expr);
569            remaining = rest;
570            continue;
571        }
572
573        // No more clauses matched
574        remaining = input;
575        break;
576    }
577
578    Ok((remaining, policy))
579}
580
581/// Parse a policy expression: `left op right [AND/OR left op right ...]`
582///
583/// Produces typed `Expr::Binary` AST nodes — no raw SQL.
584///
585/// Handles:
586/// - `column = value`
587/// - `column = function('arg')::type`   (function call + cast)
588/// - `expr AND expr`, `expr OR expr`
589fn parse_policy_expr(input: &str) -> IResult<&str, Expr> {
590    let (input, first) = parse_policy_comparison(input)?;
591
592    // Check for AND/OR chaining
593    let mut result = first;
594    let mut remaining = input;
595    loop {
596        let (input, _) = nom_ws0(remaining)?;
597
598        if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("or").parse(input)
599            && let Ok((rest, _)) = multispace1::<_, nom::error::Error<&str>>(rest)
600        {
601            let (rest, right) = parse_policy_comparison(rest)?;
602            result = Expr::Binary {
603                left: Box::new(result),
604                op: BinaryOp::Or,
605                right: Box::new(right),
606                alias: None,
607            };
608            remaining = rest;
609            continue;
610        }
611
612        if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("and").parse(input)
613            && let Ok((rest, _)) = multispace1::<_, nom::error::Error<&str>>(rest)
614        {
615            let (rest, right) = parse_policy_comparison(rest)?;
616            result = Expr::Binary {
617                left: Box::new(result),
618                op: BinaryOp::And,
619                right: Box::new(right),
620                alias: None,
621            };
622            remaining = rest;
623            continue;
624        }
625
626        remaining = input;
627        break;
628    }
629
630    Ok((remaining, result))
631}
632
633/// Parse a single comparison: `atom op atom`
634fn parse_policy_comparison(input: &str) -> IResult<&str, Expr> {
635    let (input, left) = parse_policy_atom(input)?;
636    let (input, _) = nom_ws0(input)?;
637
638    // Try to parse comparison operator
639    if let Ok((rest, op)) = parse_cmp_op(input) {
640        let (rest, _) = nom_ws0(rest)?;
641        let (rest, right) = parse_policy_atom(rest)?;
642        return Ok((
643            rest,
644            Expr::Binary {
645                left: Box::new(left),
646                op,
647                right: Box::new(right),
648                alias: None,
649            },
650        ));
651    }
652
653    // No operator — just an atom
654    Ok((input, left))
655}
656
657/// Parse comparison operators: =, !=, <>, >=, <=, >, <
658fn parse_cmp_op(input: &str) -> IResult<&str, BinaryOp> {
659    alt((
660        map(tag(">="), |_| BinaryOp::Gte),
661        map(tag("<="), |_| BinaryOp::Lte),
662        map(tag("<>"), |_| BinaryOp::Ne),
663        map(tag("!="), |_| BinaryOp::Ne),
664        map(tag("="), |_| BinaryOp::Eq),
665        map(tag(">"), |_| BinaryOp::Gt),
666        map(tag("<"), |_| BinaryOp::Lt),
667    ))
668    .parse(input)
669}
670
671/// Parse a policy expression atom:
672/// - identifier  (column name)
673/// - function_call(args)  with optional ::cast
674/// - 'string literal'
675/// - numeric literal
676/// - true/false
677/// - (sub_expr)  grouped
678fn parse_policy_atom(input: &str) -> IResult<&str, Expr> {
679    alt((
680        parse_policy_grouped,
681        parse_policy_bool,
682        parse_policy_string,
683        parse_policy_number,
684        parse_policy_func_or_ident, // function call or plain identifier, with optional ::cast
685    ))
686    .parse(input)
687}
688
689/// Parse grouped expression in parens
690fn parse_policy_grouped(input: &str) -> IResult<&str, Expr> {
691    let (input, _) = char('(').parse(input)?;
692    let (input, _) = nom_ws0(input)?;
693    let (input, expr) = parse_policy_expr(input)?;
694    let (input, _) = nom_ws0(input)?;
695    let (input, _) = char(')').parse(input)?;
696    Ok((input, expr))
697}
698
699/// Parse true / false
700fn parse_policy_bool(input: &str) -> IResult<&str, Expr> {
701    alt((
702        map(tag_no_case("true"), |_| Expr::Literal(AstValue::Bool(true))),
703        map(tag_no_case("false"), |_| {
704            Expr::Literal(AstValue::Bool(false))
705        }),
706    ))
707    .parse(input)
708}
709
710/// Parse a 'string literal'
711fn parse_policy_string(input: &str) -> IResult<&str, Expr> {
712    let (input, _) = char('\'').parse(input)?;
713    let mut end = 0;
714    for (i, c) in input.char_indices() {
715        if c == '\'' {
716            end = i;
717            break;
718        }
719    }
720    let content = &input[..end];
721    let rest = &input[end + 1..];
722    Ok((rest, Expr::Literal(AstValue::String(content.to_string()))))
723}
724
725/// Parse numeric literal
726fn parse_policy_number(input: &str) -> IResult<&str, Expr> {
727    let (input, digits) = take_while1(|c: char| c.is_ascii_digit() || c == '.')(input)?;
728    // Make sure it starts with digit (not just '.')
729    if digits.starts_with('.') || digits.is_empty() {
730        return Err(nom::Err::Error(nom::error::Error::new(
731            input,
732            nom::error::ErrorKind::Digit,
733        )));
734    }
735    if let Ok(n) = digits.parse::<i64>() {
736        Ok((input, Expr::Literal(AstValue::Int(n))))
737    } else if let Ok(f) = digits.parse::<f64>() {
738        Ok((input, Expr::Literal(AstValue::Float(f))))
739    } else {
740        Ok((input, Expr::Named(digits.to_string())))
741    }
742}
743
744/// Parse function call or identifier, with optional ::cast
745fn parse_policy_func_or_ident(input: &str) -> IResult<&str, Expr> {
746    let (input, name) = identifier(input)?;
747
748    // Check for function call: name(
749    let mut expr = if let Ok((rest, _)) = char::<_, nom::error::Error<&str>>('(').parse(input) {
750        // Parse args
751        let (rest, _) = nom_ws0(rest)?;
752        let (rest, args) =
753            separated_list0((nom_ws0, char(','), nom_ws0), parse_policy_atom).parse(rest)?;
754        let (rest, _) = nom_ws0(rest)?;
755        let (rest, _) = char(')').parse(rest)?;
756        let input = rest;
757        (
758            input,
759            Expr::FunctionCall {
760                name: name.to_string(),
761                args,
762                alias: None,
763            },
764        )
765    } else {
766        (input, Expr::Named(name.to_string()))
767    };
768
769    // Check for ::cast
770    if let Ok((rest, _)) = tag::<_, _, nom::error::Error<&str>>("::").parse(expr.0) {
771        let (rest, cast_type) = identifier(rest)?;
772        expr = (
773            rest,
774            Expr::Cast {
775                expr: Box::new(expr.1),
776                target_type: cast_type.to_string(),
777                alias: None,
778            },
779        );
780    }
781
782    Ok(expr)
783}
784
785/// Parse a single schema item: table, policy, or index
786fn parse_schema_item(input: &str) -> IResult<&str, SchemaItem> {
787    let (input, _) = ws_and_comments(input)?;
788
789    // Try policy first (since "policy" is a distinct keyword)
790    if let Ok((rest, policy)) = parse_policy(input) {
791        return Ok((rest, SchemaItem::Policy(Box::new(policy))));
792    }
793
794    // Try index
795    if let Ok((rest, idx)) = parse_index(input) {
796        return Ok((rest, SchemaItem::Index(idx)));
797    }
798
799    // Otherwise parse table
800    let (rest, table) = parse_table(input)?;
801    Ok((rest, SchemaItem::Table(table)))
802}
803
804/// Parse an index line: `index <name> on <table> (<col1>, <col2>) [unique]`
805fn parse_index(input: &str) -> IResult<&str, IndexDef> {
806    let (input, _) = tag_no_case("index")(input)?;
807    let (input, _) = multispace1(input)?;
808    let (input, name) = take_while1(|c: char| c.is_alphanumeric() || c == '_')(input)?;
809    let (input, _) = multispace1(input)?;
810    let (input, _) = tag_no_case("on")(input)?;
811    let (input, _) = multispace1(input)?;
812    let (input, table) = take_while1(|c: char| c.is_alphanumeric() || c == '_')(input)?;
813    let (input, _) = nom_ws0(input)?;
814    let (input, _) = char('(')(input)?;
815    let (input, cols_str) = take_while1(|c: char| c != ')')(input)?;
816    let (input, _) = char(')')(input)?;
817    let (input, _) = nom_ws0(input)?;
818    let (input, unique_tag) = opt(tag_no_case("unique")).parse(input)?;
819
820    let columns: Vec<String> = cols_str
821        .split(',')
822        .map(|s| s.trim().to_string())
823        .filter(|s| !s.is_empty())
824        .collect();
825
826    let is_unique = unique_tag.is_some();
827
828    Ok((
829        input,
830        IndexDef {
831            name: name.to_string(),
832            table: table.to_string(),
833            columns,
834            unique: is_unique,
835        },
836    ))
837}
838
839/// Parse complete schema file
840fn parse_schema(input: &str) -> IResult<&str, Schema> {
841    // Extract version directive before parsing
842    let version = extract_version_directive(input);
843
844    let (input, items) = many0(parse_schema_item).parse(input)?;
845    let (input, _) = ws_and_comments(input)?;
846
847    let mut tables = Vec::new();
848    let mut policies = Vec::new();
849    let mut indexes = Vec::new();
850    for item in items {
851        match item {
852            SchemaItem::Table(t) => tables.push(t),
853            SchemaItem::Policy(p) => policies.push(*p),
854            SchemaItem::Index(i) => indexes.push(i),
855        }
856    }
857
858    Ok((
859        input,
860        Schema {
861            version,
862            tables,
863            policies,
864            indexes,
865        },
866    ))
867}
868
/// Extract the schema version from a `-- qail: version=N` comment directive.
///
/// The first comment line of the form `-- qail: version=<N>` decides the result.
/// Whitespace is tolerated after `--`, after `qail:` and around `=` (e.g.
/// `--qail: version = 3`). Only the first whitespace-separated token after `=`
/// is parsed, so trailing notes such as `version=2 strict` are allowed.
///
/// Returns `None` when there is no such directive, or when the first one's value
/// is not a valid `u32`.
fn extract_version_directive(input: &str) -> Option<u32> {
    input
        .lines()
        .find_map(|line| {
            let directive = line
                .trim()
                .strip_prefix("--")?
                .trim_start()
                .strip_prefix("qail:")?;
            let value = directive
                .trim_start()
                .strip_prefix("version")?
                .trim_start()
                .strip_prefix('=')?;
            // The first version directive wins, even if its value fails to parse.
            Some(value.split_whitespace().next().and_then(|v| v.parse().ok()))
        })
        .flatten()
}
882
#[cfg(test)]
mod tests {
    // Unit tests for `Schema::parse` (tables, column modifiers, version
    // directive, `enable_rls`, RLS policies) and `Schema::to_sql`.
    use super::*;

    /// A single table: `primary_key` implies non-nullable, `not null` is
    /// honored, and a column without modifiers is nullable.
    #[test]
    fn test_parse_simple_table() {
        let input = r#"
            table users (
                id uuid primary_key,
                email text not null,
                name text
            )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.tables.len(), 1);

        let users = &schema.tables[0];
        assert_eq!(users.name, "users");
        assert_eq!(users.columns.len(), 3);

        let id = &users.columns[0];
        assert_eq!(id.name, "id");
        assert_eq!(id.typ, "uuid");
        assert!(id.primary_key);
        assert!(!id.nullable);

        let email = &users.columns[1];
        assert_eq!(email.name, "email");
        assert!(!email.nullable);

        let name = &users.columns[2];
        assert!(name.nullable);
    }

    /// Several tables separated by `--` comments, including `references`
    /// and `default` column modifiers.
    #[test]
    fn test_parse_multiple_tables() {
        let input = r#"
            -- Users table
            table users (
                id uuid primary_key,
                email text not null unique
            )
            
            -- Orders table
            table orders (
                id uuid primary_key,
                user_id uuid references users(id),
                total i64 not null default 0
            )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.tables.len(), 2);

        let orders = schema.find_table("orders").expect("orders not found");
        let user_id = orders.find_column("user_id").expect("user_id not found");
        assert_eq!(user_id.references, Some("users(id)".to_string()));

        let total = orders.find_column("total").expect("total not found");
        assert_eq!(total.default_value, Some("0".to_string()));
    }

    /// A leading `--` comment line does not prevent the table from parsing.
    #[test]
    fn test_parse_comments() {
        let input = r#"
            -- This is a comment
            table foo (
                bar text
            )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.tables.len(), 1);
    }

    /// A `[]` suffix sets `is_array` and is stripped from the base type name.
    #[test]
    fn test_array_types() {
        let input = r#"
            table products (
                id uuid primary_key,
                tags text[],
                prices decimal[]
            )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let products = &schema.tables[0];

        let tags = products.find_column("tags").expect("tags not found");
        assert_eq!(tags.typ, "text");
        assert!(tags.is_array);

        let prices = products.find_column("prices").expect("prices not found");
        assert!(prices.is_array);
    }

    /// Parenthesized type parameters (`varchar(255)`, `decimal(10,2)`) are
    /// captured separately from the type name; `serial` columns are non-nullable.
    #[test]
    fn test_type_params() {
        let input = r#"
            table items (
                id serial primary_key,
                name varchar(255) not null,
                price decimal(10,2),
                code varchar(50) unique
            )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let items = &schema.tables[0];

        let id = items.find_column("id").expect("id not found");
        assert!(id.is_serial);
        assert!(!id.nullable); // Serial is implicitly not null

        let name = items.find_column("name").expect("name not found");
        assert_eq!(name.typ, "varchar");
        assert_eq!(name.type_params, Some(vec!["255".to_string()]));

        let price = items.find_column("price").expect("price not found");
        assert_eq!(
            price.type_params,
            Some(vec!["10".to_string(), "2".to_string()])
        );

        let code = items.find_column("code").expect("code not found");
        assert!(code.unique);
    }

    /// `check(...)` stores the inner expression text without the parentheses.
    #[test]
    fn test_check_constraint() {
        let input = r#"
            table employees (
                id uuid primary_key,
                age i32 check(age >= 18),
                salary decimal check(salary > 0)
            )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let employees = &schema.tables[0];

        let age = employees.find_column("age").expect("age not found");
        assert_eq!(age.check, Some("age >= 18".to_string()));

        let salary = employees.find_column("salary").expect("salary not found");
        assert_eq!(salary.check, Some("salary > 0".to_string()));
    }

    /// `-- qail: version=N` populates `Schema::version`; absence yields `None`.
    #[test]
    fn test_version_directive() {
        let input = r#"
            -- qail: version=1
            table users (
                id uuid primary_key
            )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.version, Some(1));
        assert_eq!(schema.tables.len(), 1);

        // Without version directive
        let input_no_version = r#"
            table items (
                id uuid primary_key
            )
        "#;
        let schema2 = Schema::parse(input_no_version).expect("parse failed");
        assert_eq!(schema2.version, None);
    }

    // =========================================================================
    // Policy + enable_rls tests
    // =========================================================================

    /// A trailing `enable_rls` after the closing paren sets the table flag.
    #[test]
    fn test_enable_rls_table() {
        let input = r#"
            table orders (
                id uuid primary_key,
                operator_id uuid not null
            ) enable_rls
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.tables.len(), 1);
        assert!(schema.tables[0].enable_rls);
    }

    /// A `for all ... using (...)` policy is parsed into a typed expression tree
    /// (`Binary(Named = Cast(FunctionCall))`), not kept as raw SQL.
    #[test]
    fn test_parse_policy_basic() {
        let input = r#"
            table orders (
                id uuid primary_key,
                operator_id uuid not null
            ) enable_rls

            policy orders_isolation on orders
                for all
                using (operator_id = current_setting('app.current_operator_id')::uuid)
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.tables.len(), 1);
        assert_eq!(schema.policies.len(), 1);

        let policy = &schema.policies[0];
        assert_eq!(policy.name, "orders_isolation");
        assert_eq!(policy.table, "orders");
        assert_eq!(policy.target, PolicyTarget::All);

        // Verify the expression is a typed Binary, not raw SQL
        let Some(using) = policy.using.as_ref() else {
            panic!("Expected a USING expression, got None");
        };
        let Expr::Binary {
            left, op, right, ..
        } = using
        else {
            panic!("Expected Binary, got {using:?}");
        };
        assert_eq!(*op, BinaryOp::Eq);

        let Expr::Named(n) = left.as_ref() else {
            panic!("Expected Named, got {left:?}");
        };
        assert_eq!(n, "operator_id");

        let Expr::Cast {
            target_type,
            expr: cast_expr,
            ..
        } = right.as_ref()
        else {
            panic!("Expected Cast, got {right:?}");
        };
        assert_eq!(target_type, "uuid");

        let Expr::FunctionCall { name, args, .. } = cast_expr.as_ref() else {
            panic!("Expected FunctionCall, got {cast_expr:?}");
        };
        assert_eq!(name, "current_setting");
        assert_eq!(args.len(), 1);
    }

    /// A `for insert ... with check (...)` policy fills `with_check` only.
    #[test]
    fn test_parse_policy_with_check() {
        let input = r#"
            table orders (
                id uuid primary_key
            )

            policy orders_write on orders
                for insert
                with check (operator_id = current_setting('app.current_operator_id')::uuid)
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        // Check the count first so a parse regression fails with a clear
        // message rather than an index-out-of-bounds panic.
        assert_eq!(schema.policies.len(), 1);

        let policy = &schema.policies[0];
        assert_eq!(policy.target, PolicyTarget::Insert);
        assert!(policy.with_check.is_some());
        assert!(policy.using.is_none());
    }

    /// `restrictive` and `to <role>` clauses are captured on the policy.
    #[test]
    fn test_parse_policy_restrictive_with_role() {
        let input = r#"
            table secrets (
                id uuid primary_key
            )

            policy admin_only on secrets
                for select
                restrictive
                to app_user
                using (current_setting('app.is_super_admin')::boolean = true)
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.policies.len(), 1);

        let policy = &schema.policies[0];
        assert_eq!(policy.target, PolicyTarget::Select);
        assert_eq!(policy.permissiveness, PolicyPermissiveness::Restrictive);
        assert_eq!(policy.role.as_deref(), Some("app_user"));
        assert!(policy.using.is_some());
    }

    /// A top-level `or` in a USING clause becomes a `Binary { op: Or }` root.
    #[test]
    fn test_parse_policy_or_expr() {
        let input = r#"
            table orders (
                id uuid primary_key
            )

            policy tenant_or_admin on orders
                for all
                using (operator_id = current_setting('app.current_operator_id')::uuid or current_setting('app.is_super_admin')::boolean = true)
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.policies.len(), 1);

        let policy = &schema.policies[0];

        // Match through the Option directly so a missing USING clause reports
        // the diagnostic below instead of an opaque `unwrap` panic.
        assert!(
            matches!(
                policy.using.as_ref(),
                Some(Expr::Binary {
                    op: BinaryOp::Or,
                    ..
                })
            ),
            "Expected Binary OR, got {:?}",
            policy.using
        );
    }

    /// `Schema::to_sql` emits the table DDL, RLS enable/force statements, and
    /// the `CREATE POLICY` statement.
    #[test]
    fn test_schema_to_sql() {
        let input = r#"
            table orders (
                id uuid primary_key,
                operator_id uuid not null
            ) enable_rls

            policy orders_isolation on orders
                for all
                using (operator_id = current_setting('app.current_operator_id')::uuid)
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let sql = schema.to_sql();
        assert!(sql.contains("CREATE TABLE IF NOT EXISTS"));
        assert!(sql.contains("ENABLE ROW LEVEL SECURITY"));
        assert!(sql.contains("FORCE ROW LEVEL SECURITY"));
        assert!(sql.contains("CREATE POLICY"));
        assert!(sql.contains("orders_isolation"));
        assert!(sql.contains("FOR ALL"));
    }

    /// Multiple policies on one table are kept in declaration order, each with
    /// its own target and clause.
    #[test]
    fn test_multiple_policies() {
        let input = r#"
            table orders (
                id uuid primary_key,
                operator_id uuid not null
            ) enable_rls

            policy orders_read on orders
                for select
                using (operator_id = current_setting('app.current_operator_id')::uuid)

            policy orders_write on orders
                for insert
                with check (operator_id = current_setting('app.current_operator_id')::uuid)
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.policies.len(), 2);

        assert_eq!(schema.policies[0].name, "orders_read");
        assert_eq!(schema.policies[0].target, PolicyTarget::Select);
        assert!(schema.policies[0].using.is_some());

        assert_eq!(schema.policies[1].name, "orders_write");
        assert_eq!(schema.policies[1].target, PolicyTarget::Insert);
        assert!(schema.policies[1].with_check.is_some());
    }
}