Skip to main content

qail_core/parser/
schema.rs

//! Schema file parser for `.qail` format.
//!
//! Parses schema definitions (tables, RLS policies and indexes) like:
//! ```text
//! table users (
//!   id uuid primary_key,
//!   email text not null,
//!   name text,
//!   created_at timestamp
//! ) enable_rls
//!
//! policy users_isolation on users
//!     for all
//!     using (operator_id = current_setting('app.operator_id')::uuid)
//!
//! index idx_users_email on users (email) unique
//! ```
16
17use nom::{
18    IResult, Parser,
19    branch::alt,
20    bytes::complete::{tag, tag_no_case, take_while1},
21    character::complete::{char, multispace0 as nom_ws0, multispace1, not_line_ending},
22    combinator::{map, opt},
23    multi::{many0, separated_list0},
24    sequence::preceded,
25};
26use serde::{Deserialize, Serialize};
27
28use crate::ast::{Expr, BinaryOp, Value as AstValue};
29use crate::migrate::policy::{RlsPolicy, PolicyTarget, PolicyPermissiveness};
30use crate::transpiler::policy::{create_policy_sql, alter_table_sql};
31use crate::migrate::alter::AlterTable;
32
/// Schema containing all table definitions, together with the RLS policies
/// and indexes declared alongside them.
///
/// Built either by parsing `.qail` text or by deserializing the JSON form.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Schema {
    /// Schema format version (extracted from `-- qail: version=N` directive)
    #[serde(default)]
    pub version: Option<u32>,
    /// Table definitions, in the order they appear in the source
    pub tables: Vec<TableDef>,
    /// RLS policies declared in the schema
    #[serde(default)]
    pub policies: Vec<RlsPolicy>,
    /// Indexes declared in the schema
    #[serde(default)]
    pub indexes: Vec<IndexDef>,
}
47
/// Index definition parsed from `index <name> on <table> (<columns>) [unique]`
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexDef {
    /// Index name, emitted verbatim into the generated SQL
    pub name: String,
    /// Name of the table the index is created on
    pub table: String,
    /// Indexed columns, in declaration order
    pub columns: Vec<String>,
    /// Whether the trailing `unique` keyword was present
    #[serde(default)]
    pub unique: bool,
}
57
58impl IndexDef {
59    /// Generate `CREATE INDEX IF NOT EXISTS` SQL.
60    pub fn to_sql(&self) -> String {
61        let unique = if self.unique { " UNIQUE" } else { "" };
62        format!(
63            "CREATE{} INDEX IF NOT EXISTS {} ON {} ({})",
64            unique,
65            self.name,
66            self.table,
67            self.columns.join(", ")
68        )
69    }
70}
71
/// A single table parsed from a `table <name> ( ... ) [enable_rls]` item.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TableDef {
    /// Table name, emitted verbatim into the generated DDL
    pub name: String,
    /// Column definitions, in declaration order
    pub columns: Vec<ColumnDef>,
    /// Whether this table has RLS enabled
    #[serde(default)]
    pub enable_rls: bool,
}
80
/// A single column of a [`TableDef`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ColumnDef {
    /// Column name
    pub name: String,
    /// Base type name; serialized as `type`, with `typ` also accepted on input
    #[serde(rename = "type", alias = "typ")]
    pub typ: String,
    /// Type is an array (e.g., text[], uuid[])
    #[serde(default)]
    pub is_array: bool,
    /// Type parameters (e.g., varchar(255) -> Some(vec!["255"]), decimal(10,2) -> Some(vec!["10", "2"]))
    #[serde(default)]
    pub type_params: Option<Vec<String>>,
    /// Whether the column accepts NULL.
    // NOTE(review): `#[serde(default)]` yields `false` when the key is missing
    // from JSON, whereas `ColumnDef::default()` sets `true` — confirm that the
    // difference is intended for externally produced JSON.
    #[serde(default)]
    pub nullable: bool,
    /// Column is the primary key
    #[serde(default)]
    pub primary_key: bool,
    /// Column carries a UNIQUE constraint
    #[serde(default)]
    pub unique: bool,
    /// Foreign-key target as written after `references` (e.g. `users(id)`)
    #[serde(default)]
    pub references: Option<String>,
    /// Default value expression, emitted verbatim after `DEFAULT`
    #[serde(default)]
    pub default_value: Option<String>,
    /// Check constraint expression
    #[serde(default)]
    pub check: Option<String>,
    /// Is this a serial/auto-increment type
    #[serde(default)]
    pub is_serial: bool,
}
109
110impl Default for ColumnDef {
111    fn default() -> Self {
112        Self {
113            name: String::new(),
114            typ: String::new(),
115            is_array: false,
116            type_params: None,
117            nullable: true,
118            primary_key: false,
119            unique: false,
120            references: None,
121            default_value: None,
122            check: None,
123            is_serial: false,
124        }
125    }
126}
127
128impl Schema {
129    /// Parse a schema from `.qail` format string
130    pub fn parse(input: &str) -> Result<Self, String> {
131        match parse_schema(input) {
132            Ok(("", schema)) => Ok(schema),
133            Ok((remaining, _)) => Err(format!("Unexpected content: '{}'", remaining.trim())),
134            Err(e) => Err(format!("Parse error: {:?}", e)),
135        }
136    }
137
138    /// Find a table by name
139    pub fn find_table(&self, name: &str) -> Option<&TableDef> {
140        self.tables
141            .iter()
142            .find(|t| t.name.eq_ignore_ascii_case(name))
143    }
144
145    /// Generate complete SQL for this schema: tables + RLS + policies + indexes.
146    pub fn to_sql(&self) -> String {
147        let mut parts = Vec::new();
148
149        for table in &self.tables {
150            parts.push(table.to_ddl());
151
152            if table.enable_rls {
153                let alter = AlterTable::new(&table.name).enable_rls().force_rls();
154                for stmt in alter_table_sql(&alter) {
155                    parts.push(stmt);
156                }
157            }
158        }
159
160        for idx in &self.indexes {
161            parts.push(idx.to_sql());
162        }
163
164        for policy in &self.policies {
165            parts.push(create_policy_sql(policy));
166        }
167
168        parts.join(";\n\n") + ";"
169    }
170
171    /// Export schema to JSON string (for qail-macros compatibility)
172    pub fn to_json(&self) -> Result<String, String> {
173        serde_json::to_string_pretty(self).map_err(|e| format!("JSON serialization failed: {}", e))
174    }
175
176    /// Import schema from JSON string
177    pub fn from_json(json: &str) -> Result<Self, String> {
178        serde_json::from_str(json).map_err(|e| format!("JSON deserialization failed: {}", e))
179    }
180
181    /// Load schema from a .qail file
182    pub fn from_file(path: &std::path::Path) -> Result<Self, String> {
183        let content =
184            std::fs::read_to_string(path).map_err(|e| format!("Failed to read file: {}", e))?;
185
186        if content.trim().starts_with('{') {
187            Self::from_json(&content)
188        } else {
189            Self::parse(&content)
190        }
191    }
192}
193
194impl TableDef {
195    /// Find a column by name
196    pub fn find_column(&self, name: &str) -> Option<&ColumnDef> {
197        self.columns
198            .iter()
199            .find(|c| c.name.eq_ignore_ascii_case(name))
200    }
201
202    /// Generate CREATE TABLE IF NOT EXISTS SQL (AST-native DDL).
203    pub fn to_ddl(&self) -> String {
204        let mut sql = format!("CREATE TABLE IF NOT EXISTS {} (\n", self.name);
205
206        let mut col_defs = Vec::new();
207        for col in &self.columns {
208            let mut line = format!("    {}", col.name);
209
210            // Type with params
211            let mut typ = col.typ.to_uppercase();
212            if let Some(params) = &col.type_params {
213                typ = format!("{}({})", typ, params.join(", "));
214            }
215            if col.is_array {
216                typ.push_str("[]");
217            }
218            line.push_str(&format!(" {}", typ));
219
220            // Constraints
221            if col.primary_key {
222                line.push_str(" PRIMARY KEY");
223            }
224            if !col.nullable && !col.primary_key && !col.is_serial {
225                line.push_str(" NOT NULL");
226            }
227            if col.unique && !col.primary_key {
228                line.push_str(" UNIQUE");
229            }
230            if let Some(ref default) = col.default_value {
231                line.push_str(&format!(" DEFAULT {}", default));
232            }
233            if let Some(ref refs) = col.references {
234                line.push_str(&format!(" REFERENCES {}", refs));
235            }
236            if let Some(ref check) = col.check {
237                line.push_str(&format!(" CHECK({})", check));
238            }
239
240            col_defs.push(line);
241        }
242
243        sql.push_str(&col_defs.join(",\n"));
244        sql.push_str("\n)");
245        sql
246    }
247}
248
249// =============================================================================
250// Parsing Combinators
251// =============================================================================
252
253/// Parse identifier (table/column name)
254fn identifier(input: &str) -> IResult<&str, &str> {
255    take_while1(|c: char| c.is_alphanumeric() || c == '_')(input)
256}
257
258/// Skip whitespace and comments (both `--` and `#` styles)
259fn ws_and_comments(input: &str) -> IResult<&str, ()> {
260    let (input, _) = many0(alt((
261        map(multispace1, |_| ()),
262        map((tag("--"), not_line_ending), |_| ()),
263        map((tag("#"), not_line_ending), |_| ()),
264    )))
265    .parse(input)?;
266    Ok((input, ()))
267}
268
/// Intermediate result of parsing a column type, before it is copied into a
/// `ColumnDef`.
struct TypeInfo {
    /// Lower-cased base type name (e.g. `varchar`)
    name: String,
    /// Comma-separated parameters found between parentheses, if any
    params: Option<Vec<String>>,
    /// Whether the type carried a `[]` suffix
    is_array: bool,
    /// Whether the type is `serial`, `bigserial` or `smallserial`
    is_serial: bool,
}
275
276/// Parse column type with optional params and array suffix
277/// Handles: varchar(255), decimal(10,2), text[], serial, bigserial
278fn parse_type_info(input: &str) -> IResult<&str, TypeInfo> {
279    let (input, type_name) = take_while1(|c: char| c.is_alphanumeric()).parse(input)?;
280
281    let (input, params) = if input.starts_with('(') {
282        let paren_start = 1;
283        let mut paren_end = paren_start;
284        for (i, c) in input[paren_start..].char_indices() {
285            if c == ')' {
286                paren_end = paren_start + i;
287                break;
288            }
289        }
290        let param_str = &input[paren_start..paren_end];
291        let params: Vec<String> = param_str.split(',').map(|s| s.trim().to_string()).collect();
292        (&input[paren_end + 1..], Some(params))
293    } else {
294        (input, None)
295    };
296
297    let (input, is_array) = if let Some(stripped) = input.strip_prefix("[]") {
298        (stripped, true)
299    } else {
300        (input, false)
301    };
302
303    let lower = type_name.to_lowercase();
304    let is_serial = lower == "serial" || lower == "bigserial" || lower == "smallserial";
305
306    Ok((
307        input,
308        TypeInfo {
309            name: lower,
310            params,
311            is_array,
312            is_serial,
313        },
314    ))
315}
316
317/// Parse constraint text until comma or closing paren (handling nested parens)
318fn constraint_text(input: &str) -> IResult<&str, &str> {
319    let mut paren_depth = 0;
320    let mut end = 0;
321
322    for (i, c) in input.char_indices() {
323        match c {
324            '(' => paren_depth += 1,
325            ')' => {
326                if paren_depth == 0 {
327                    break; // End at column-level closing paren
328                }
329                paren_depth -= 1;
330            }
331            ',' if paren_depth == 0 => break,
332            '\n' | '\r' if paren_depth == 0 => break,
333            _ => {}
334        }
335        end = i + c.len_utf8();
336    }
337
338    if end == 0 {
339        Err(nom::Err::Error(nom::error::Error::new(
340            input,
341            nom::error::ErrorKind::TakeWhile1,
342        )))
343    } else {
344        Ok((&input[end..], &input[..end]))
345    }
346}
347
348/// Parse a single column definition
349fn parse_column(input: &str) -> IResult<&str, ColumnDef> {
350    let (input, _) = ws_and_comments(input)?;
351    let (input, name) = identifier(input)?;
352    let (input, _) = multispace1(input)?;
353    let (input, type_info) = parse_type_info(input)?;
354
355    let (input, constraint_str) = opt(preceded(multispace1, constraint_text)).parse(input)?;
356
357    let mut col = ColumnDef {
358        name: name.to_string(),
359        typ: type_info.name,
360        is_array: type_info.is_array,
361        type_params: type_info.params,
362        is_serial: type_info.is_serial,
363        nullable: !type_info.is_serial, // Serial types are implicitly not null
364        ..Default::default()
365    };
366
367    if let Some(constraints) = constraint_str {
368        let lower = constraints.to_lowercase();
369
370        if lower.contains("primary_key") || lower.contains("primary key") {
371            col.primary_key = true;
372            col.nullable = false;
373        }
374        if lower.contains("not_null") || lower.contains("not null") {
375            col.nullable = false;
376        }
377        if lower.contains("unique") {
378            col.unique = true;
379        }
380
381        if let Some(idx) = lower.find("references ") {
382            let rest = &constraints[idx + 11..];
383            // Find end (space or end of string), but handle nested parens
384            let mut paren_depth = 0;
385            let mut end = rest.len();
386            for (i, c) in rest.char_indices() {
387                match c {
388                    '(' => paren_depth += 1,
389                    ')' => {
390                        if paren_depth == 0 {
391                            end = i;
392                            break;
393                        }
394                        paren_depth -= 1;
395                    }
396                    c if c.is_whitespace() && paren_depth == 0 => {
397                        end = i;
398                        break;
399                    }
400                    _ => {}
401                }
402            }
403            col.references = Some(rest[..end].to_string());
404        }
405
406        if let Some(idx) = lower.find("default ") {
407            let rest = &constraints[idx + 8..];
408            let end = rest.find(|c: char| c.is_whitespace()).unwrap_or(rest.len());
409            col.default_value = Some(rest[..end].to_string());
410        }
411
412        if let Some(idx) = lower.find("check(") {
413            let rest = &constraints[idx + 6..];
414            // Find matching closing paren
415            let mut depth = 1;
416            let mut end = rest.len();
417            for (i, c) in rest.char_indices() {
418                match c {
419                    '(' => depth += 1,
420                    ')' => {
421                        depth -= 1;
422                        if depth == 0 {
423                            end = i;
424                            break;
425                        }
426                    }
427                    _ => {}
428                }
429            }
430            col.check = Some(rest[..end].to_string());
431        }
432    }
433
434    Ok((input, col))
435}
436
437/// Parse column list: (col1 type, col2 type, ...)
438fn parse_column_list(input: &str) -> IResult<&str, Vec<ColumnDef>> {
439    let (input, _) = ws_and_comments(input)?;
440    let (input, _) = char('(').parse(input)?;
441    let (input, columns) = separated_list0(char(','), parse_column).parse(input)?;
442    let (input, _) = ws_and_comments(input)?;
443    let (input, _) = char(')').parse(input)?;
444
445    Ok((input, columns))
446}
447
448/// Parse a table definition
449fn parse_table(input: &str) -> IResult<&str, TableDef> {
450    let (input, _) = ws_and_comments(input)?;
451    let (input, _) = tag_no_case("table").parse(input)?;
452    let (input, _) = multispace1(input)?;
453    let (input, name) = identifier(input)?;
454    let (input, columns) = parse_column_list(input)?;
455
456    // Optional enable_rls annotation after closing paren
457    let (input, _) = ws_and_comments(input)?;
458    let enable_rls = if let Ok((rest, _)) =
459        tag_no_case::<_, _, nom::error::Error<&str>>("enable_rls").parse(input)
460    {
461        return Ok((
462            rest,
463            TableDef {
464                name: name.to_string(),
465                columns,
466                enable_rls: true,
467            },
468        ));
469    } else {
470        false
471    };
472
473    Ok((
474        input,
475        TableDef {
476            name: name.to_string(),
477            columns,
478            enable_rls,
479        },
480    ))
481}
482
483// =============================================================================
484// Policy Parsing
485// =============================================================================
486
/// A schema item is either a table, policy, or index.
///
/// Used only as the intermediate result of parsing one top-level item; the
/// variants are sorted into the matching `Schema` vectors afterwards.
enum SchemaItem {
    /// A `table ...` definition
    Table(TableDef),
    /// A `policy ...` definition
    Policy(RlsPolicy),
    /// An `index ...` definition
    Index(IndexDef),
}
493
494/// Parse a policy definition.
495///
496/// Syntax:
497/// ```text
498/// policy <name> on <table>
499///     [for all|select|insert|update|delete]
500///     [restrictive]
501///     [to <role>]
502///     [using (<expr>)]
503///     [with check (<expr>)]
504/// ```
505fn parse_policy(input: &str) -> IResult<&str, RlsPolicy> {
506    let (input, _) = ws_and_comments(input)?;
507    let (input, _) = tag_no_case("policy").parse(input)?;
508    let (input, _) = multispace1(input)?;
509    let (input, name) = identifier(input)?;
510    let (input, _) = multispace1(input)?;
511    let (input, _) = tag_no_case("on").parse(input)?;
512    let (input, _) = multispace1(input)?;
513    let (input, table) = identifier(input)?;
514
515    let mut policy = RlsPolicy::create(name, table);
516
517    // Parse optional clauses in any order
518    let mut remaining = input;
519    loop {
520        let (input, _) = ws_and_comments(remaining)?;
521
522        // for all|select|insert|update|delete
523        if let Ok((rest, _)) =
524            tag_no_case::<_, _, nom::error::Error<&str>>("for").parse(input)
525        {
526            let (rest, _) = multispace1(rest)?;
527            let (rest, target) = alt((
528                map(tag_no_case("all"), |_| PolicyTarget::All),
529                map(tag_no_case("select"), |_| PolicyTarget::Select),
530                map(tag_no_case("insert"), |_| PolicyTarget::Insert),
531                map(tag_no_case("update"), |_| PolicyTarget::Update),
532                map(tag_no_case("delete"), |_| PolicyTarget::Delete),
533            ))
534            .parse(rest)?;
535            policy.target = target;
536            remaining = rest;
537            continue;
538        }
539
540        // restrictive
541        if let Ok((rest, _)) =
542            tag_no_case::<_, _, nom::error::Error<&str>>("restrictive").parse(input)
543        {
544            policy.permissiveness = PolicyPermissiveness::Restrictive;
545            remaining = rest;
546            continue;
547        }
548
549        // to <role>
550        if let Ok((rest, _)) =
551            tag_no_case::<_, _, nom::error::Error<&str>>("to").parse(input)
552        {
553            // Make sure it's not "to_sql" or similar — needs whitespace after
554            if let Ok((rest, _)) = multispace1::<_, nom::error::Error<&str>>(rest) {
555                let (rest, role) = identifier(rest)?;
556                policy.role = Some(role.to_string());
557                remaining = rest;
558                continue;
559            }
560        }
561
562        // with check (<expr>)
563        if let Ok((rest, _)) =
564            tag_no_case::<_, _, nom::error::Error<&str>>("with").parse(input)
565        {
566            let (rest, _) = multispace1(rest)?;
567            if let Ok((rest, _)) =
568                tag_no_case::<_, _, nom::error::Error<&str>>("check").parse(rest)
569            {
570                let (rest, _) = nom_ws0(rest)?;
571                let (rest, _) = char('(').parse(rest)?;
572                let (rest, _) = nom_ws0(rest)?;
573                let (rest, expr) = parse_policy_expr(rest)?;
574                let (rest, _) = nom_ws0(rest)?;
575                let (rest, _) = char(')').parse(rest)?;
576                policy.with_check = Some(expr);
577                remaining = rest;
578                continue;
579            }
580        }
581
582        // using (<expr>)
583        if let Ok((rest, _)) =
584            tag_no_case::<_, _, nom::error::Error<&str>>("using").parse(input)
585        {
586            let (rest, _) = nom_ws0(rest)?;
587            let (rest, _) = char('(').parse(rest)?;
588            let (rest, _) = nom_ws0(rest)?;
589            let (rest, expr) = parse_policy_expr(rest)?;
590            let (rest, _) = nom_ws0(rest)?;
591            let (rest, _) = char(')').parse(rest)?;
592            policy.using = Some(expr);
593            remaining = rest;
594            continue;
595        }
596
597        // No more clauses matched
598        remaining = input;
599        break;
600    }
601
602    Ok((remaining, policy))
603}
604
605/// Parse a policy expression: `left op right [AND/OR left op right ...]`
606///
607/// Produces typed `Expr::Binary` AST nodes — no raw SQL.
608///
609/// Handles:
610/// - `column = value`
611/// - `column = function('arg')::type`   (function call + cast)
612/// - `expr AND expr`, `expr OR expr`
613fn parse_policy_expr(input: &str) -> IResult<&str, Expr> {
614    let (input, first) = parse_policy_comparison(input)?;
615
616    // Check for AND/OR chaining
617    let mut result = first;
618    let mut remaining = input;
619    loop {
620        let (input, _) = nom_ws0(remaining)?;
621
622        if let Ok((rest, _)) =
623            tag_no_case::<_, _, nom::error::Error<&str>>("or").parse(input)
624        {
625            if let Ok((rest, _)) = multispace1::<_, nom::error::Error<&str>>(rest) {
626                let (rest, right) = parse_policy_comparison(rest)?;
627                result = Expr::Binary {
628                    left: Box::new(result),
629                    op: BinaryOp::Or,
630                    right: Box::new(right),
631                    alias: None,
632                };
633                remaining = rest;
634                continue;
635            }
636        }
637
638        if let Ok((rest, _)) =
639            tag_no_case::<_, _, nom::error::Error<&str>>("and").parse(input)
640        {
641            if let Ok((rest, _)) = multispace1::<_, nom::error::Error<&str>>(rest) {
642                let (rest, right) = parse_policy_comparison(rest)?;
643                result = Expr::Binary {
644                    left: Box::new(result),
645                    op: BinaryOp::And,
646                    right: Box::new(right),
647                    alias: None,
648                };
649                remaining = rest;
650                continue;
651            }
652        }
653
654        remaining = input;
655        break;
656    }
657
658    Ok((remaining, result))
659}
660
661/// Parse a single comparison: `atom op atom`
662fn parse_policy_comparison(input: &str) -> IResult<&str, Expr> {
663    let (input, left) = parse_policy_atom(input)?;
664    let (input, _) = nom_ws0(input)?;
665
666    // Try to parse comparison operator
667    if let Ok((rest, op)) = parse_cmp_op(input) {
668        let (rest, _) = nom_ws0(rest)?;
669        let (rest, right) = parse_policy_atom(rest)?;
670        return Ok((
671            rest,
672            Expr::Binary {
673                left: Box::new(left),
674                op,
675                right: Box::new(right),
676                alias: None,
677            },
678        ));
679    }
680
681    // No operator — just an atom
682    Ok((input, left))
683}
684
685/// Parse comparison operators: =, !=, <>, >=, <=, >, <
686fn parse_cmp_op(input: &str) -> IResult<&str, BinaryOp> {
687    alt((
688        map(tag(">="), |_| BinaryOp::Gte),
689        map(tag("<="), |_| BinaryOp::Lte),
690        map(tag("<>"), |_| BinaryOp::Ne),
691        map(tag("!="), |_| BinaryOp::Ne),
692        map(tag("="), |_| BinaryOp::Eq),
693        map(tag(">"), |_| BinaryOp::Gt),
694        map(tag("<"), |_| BinaryOp::Lt),
695    ))
696    .parse(input)
697}
698
699/// Parse a policy expression atom:
700/// - identifier  (column name)
701/// - function_call(args)  with optional ::cast
702/// - 'string literal'
703/// - numeric literal
704/// - true/false
705/// - (sub_expr)  grouped
706fn parse_policy_atom(input: &str) -> IResult<&str, Expr> {
707    alt((
708        parse_policy_grouped,
709        parse_policy_bool,
710        parse_policy_string,
711        parse_policy_number,
712        parse_policy_func_or_ident, // function call or plain identifier, with optional ::cast
713    ))
714    .parse(input)
715}
716
717/// Parse grouped expression in parens
718fn parse_policy_grouped(input: &str) -> IResult<&str, Expr> {
719    let (input, _) = char('(').parse(input)?;
720    let (input, _) = nom_ws0(input)?;
721    let (input, expr) = parse_policy_expr(input)?;
722    let (input, _) = nom_ws0(input)?;
723    let (input, _) = char(')').parse(input)?;
724    Ok((input, expr))
725}
726
727/// Parse true / false
728fn parse_policy_bool(input: &str) -> IResult<&str, Expr> {
729    alt((
730        map(tag_no_case("true"), |_| {
731            Expr::Literal(AstValue::Bool(true))
732        }),
733        map(tag_no_case("false"), |_| {
734            Expr::Literal(AstValue::Bool(false))
735        }),
736    ))
737    .parse(input)
738}
739
740/// Parse a 'string literal'
741fn parse_policy_string(input: &str) -> IResult<&str, Expr> {
742    let (input, _) = char('\'').parse(input)?;
743    let mut end = 0;
744    for (i, c) in input.char_indices() {
745        if c == '\'' {
746            end = i;
747            break;
748        }
749    }
750    let content = &input[..end];
751    let rest = &input[end + 1..];
752    Ok((rest, Expr::Literal(AstValue::String(content.to_string()))))
753}
754
755/// Parse numeric literal
756fn parse_policy_number(input: &str) -> IResult<&str, Expr> {
757    let (input, digits) = take_while1(|c: char| c.is_ascii_digit() || c == '.')(input)?;
758    // Make sure it starts with digit (not just '.')
759    if digits.starts_with('.') || digits.is_empty() {
760        return Err(nom::Err::Error(nom::error::Error::new(
761            input,
762            nom::error::ErrorKind::Digit,
763        )));
764    }
765    if let Ok(n) = digits.parse::<i64>() {
766        Ok((input, Expr::Literal(AstValue::Int(n))))
767    } else if let Ok(f) = digits.parse::<f64>() {
768        Ok((input, Expr::Literal(AstValue::Float(f))))
769    } else {
770        Ok((input, Expr::Named(digits.to_string())))
771    }
772}
773
774/// Parse function call or identifier, with optional ::cast
775fn parse_policy_func_or_ident(input: &str) -> IResult<&str, Expr> {
776    let (input, name) = identifier(input)?;
777
778    // Check for function call: name(
779    let mut expr = if let Ok((rest, _)) = char::<_, nom::error::Error<&str>>('(').parse(input) {
780        // Parse args
781        let (rest, _) = nom_ws0(rest)?;
782        let (rest, args) = separated_list0(
783            (nom_ws0, char(','), nom_ws0),
784            parse_policy_atom,
785        )
786        .parse(rest)?;
787        let (rest, _) = nom_ws0(rest)?;
788        let (rest, _) = char(')').parse(rest)?;
789        let input = rest;
790        (input, Expr::FunctionCall {
791            name: name.to_string(),
792            args,
793            alias: None,
794        })
795    } else {
796        (input, Expr::Named(name.to_string()))
797    };
798
799    // Check for ::cast
800    if let Ok((rest, _)) = tag::<_, _, nom::error::Error<&str>>("::").parse(expr.0) {
801        let (rest, cast_type) = identifier(rest)?;
802        expr = (
803            rest,
804            Expr::Cast {
805                expr: Box::new(expr.1),
806                target_type: cast_type.to_string(),
807                alias: None,
808            },
809        );
810    }
811
812    Ok(expr)
813}
814
815/// Parse a single schema item: table, policy, or index
816fn parse_schema_item(input: &str) -> IResult<&str, SchemaItem> {
817    let (input, _) = ws_and_comments(input)?;
818
819    // Try policy first (since "policy" is a distinct keyword)
820    if let Ok((rest, policy)) = parse_policy(input) {
821        return Ok((rest, SchemaItem::Policy(policy)));
822    }
823
824    // Try index
825    if let Ok((rest, idx)) = parse_index(input) {
826        return Ok((rest, SchemaItem::Index(idx)));
827    }
828
829    // Otherwise parse table
830    let (rest, table) = parse_table(input)?;
831    Ok((rest, SchemaItem::Table(table)))
832}
833
834/// Parse an index line: `index <name> on <table> (<col1>, <col2>) [unique]`
835fn parse_index(input: &str) -> IResult<&str, IndexDef> {
836    let (input, _) = tag_no_case("index")(input)?;
837    let (input, _) = multispace1(input)?;
838    let (input, name) = take_while1(|c: char| c.is_alphanumeric() || c == '_')(input)?;
839    let (input, _) = multispace1(input)?;
840    let (input, _) = tag_no_case("on")(input)?;
841    let (input, _) = multispace1(input)?;
842    let (input, table) = take_while1(|c: char| c.is_alphanumeric() || c == '_')(input)?;
843    let (input, _) = nom_ws0(input)?;
844    let (input, _) = char('(')(input)?;
845    let (input, cols_str) = take_while1(|c: char| c != ')')(input)?;
846    let (input, _) = char(')')(input)?;
847    let (input, _) = nom_ws0(input)?;
848    let (input, unique_tag) = opt(tag_no_case("unique")).parse(input)?;
849
850    let columns: Vec<String> = cols_str
851        .split(',')
852        .map(|s| s.trim().to_string())
853        .filter(|s| !s.is_empty())
854        .collect();
855
856    let is_unique = unique_tag.is_some();
857
858    Ok((input, IndexDef {
859        name: name.to_string(),
860        table: table.to_string(),
861        columns,
862        unique: is_unique,
863    }))
864}
865
866/// Parse complete schema file
867fn parse_schema(input: &str) -> IResult<&str, Schema> {
868    // Extract version directive before parsing
869    let version = extract_version_directive(input);
870
871    let (input, items) = many0(parse_schema_item).parse(input)?;
872    let (input, _) = ws_and_comments(input)?;
873
874    let mut tables = Vec::new();
875    let mut policies = Vec::new();
876    let mut indexes = Vec::new();
877    for item in items {
878        match item {
879            SchemaItem::Table(t) => tables.push(t),
880            SchemaItem::Policy(p) => policies.push(p),
881            SchemaItem::Index(i) => indexes.push(i),
882        }
883    }
884
885    Ok((input, Schema { version, tables, policies, indexes }))
886}
887
/// Extract the version from a `-- qail: version=N` directive.
///
/// Lines are trimmed before matching. The first line carrying a
/// `-- qail:` directive whose payload starts with `version=` decides the
/// result: `Some(N)` when `N` parses as `u32`, `None` otherwise. Later
/// directives are not consulted. Returns `None` when no such line exists.
fn extract_version_directive(input: &str) -> Option<u32> {
    input
        .lines()
        .filter_map(|line| line.trim().strip_prefix("-- qail:"))
        .find_map(|directive| directive.trim().strip_prefix("version="))
        .and_then(|raw| raw.trim().parse().ok())
}
901
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `source` into a [`Schema`], panicking with "parse failed" on error.
    ///
    /// `#[track_caller]` keeps the reported panic location inside the calling test.
    #[track_caller]
    fn must_parse(source: &str) -> Schema {
        Schema::parse(source).expect("parse failed")
    }

    /// One table with a primary-key column, a `not null` column and a bare column.
    #[test]
    fn test_parse_simple_table() {
        let source = r#"
            table users (
                id uuid primary_key,
                email text not null,
                name text
            )
        "#;

        let parsed = must_parse(source);
        assert_eq!(parsed.tables.len(), 1);

        let table = &parsed.tables[0];
        assert_eq!(table.name, "users");
        assert_eq!(table.columns.len(), 3);

        // `id uuid primary_key` is expected to come out non-nullable.
        let id_col = &table.columns[0];
        assert_eq!(id_col.name, "id");
        assert_eq!(id_col.typ, "uuid");
        assert!(id_col.primary_key);
        assert!(!id_col.nullable);

        // `email text not null`
        let email_col = &table.columns[1];
        assert_eq!(email_col.name, "email");
        assert!(!email_col.nullable);

        // `name text` carries no constraint and is expected to be nullable.
        let name_col = &table.columns[2];
        assert!(name_col.nullable);
    }

    /// Two tables separated by comments; checks `references` and `default` capture.
    #[test]
    fn test_parse_multiple_tables() {
        let source = r#"
            -- Users table
            table users (
                id uuid primary_key,
                email text not null unique
            )
            
            -- Orders table
            table orders (
                id uuid primary_key,
                user_id uuid references users(id),
                total i64 not null default 0
            )
        "#;

        let parsed = must_parse(source);
        assert_eq!(parsed.tables.len(), 2);

        let orders_table = parsed.find_table("orders").expect("orders not found");

        let fk_col = orders_table
            .find_column("user_id")
            .expect("user_id not found");
        assert_eq!(fk_col.references, Some(String::from("users(id)")));

        let total_col = orders_table.find_column("total").expect("total not found");
        assert_eq!(total_col.default_value, Some(String::from("0")));
    }

    /// A leading `--` comment line must not prevent the table from being parsed.
    #[test]
    fn test_parse_comments() {
        let source = r#"
            -- This is a comment
            table foo (
                bar text
            )
        "#;

        let parsed = must_parse(source);
        assert_eq!(parsed.tables.len(), 1);
    }

    /// A `[]` suffix sets `is_array` and is not kept as part of the type name.
    #[test]
    fn test_array_types() {
        let source = r#"
            table products (
                id uuid primary_key,
                tags text[],
                prices decimal[]
            )
        "#;

        let parsed = must_parse(source);
        let table = &parsed.tables[0];

        let tags_col = table.find_column("tags").expect("tags not found");
        assert_eq!(tags_col.typ, "text");
        assert!(tags_col.is_array);

        let prices_col = table.find_column("prices").expect("prices not found");
        assert!(prices_col.is_array);
    }

    /// Parenthesised type parameters, `serial` handling and a trailing `unique`.
    #[test]
    fn test_type_params() {
        let source = r#"
            table items (
                id serial primary_key,
                name varchar(255) not null,
                price decimal(10,2),
                code varchar(50) unique
            )
        "#;

        let parsed = must_parse(source);
        let table = &parsed.tables[0];

        let id_col = table.find_column("id").expect("id not found");
        assert!(id_col.is_serial);
        // A serial column is treated as NOT NULL without saying so explicitly.
        assert!(!id_col.nullable);

        let name_col = table.find_column("name").expect("name not found");
        assert_eq!(name_col.typ, "varchar");
        assert_eq!(name_col.type_params, Some(vec![String::from("255")]));

        let price_col = table.find_column("price").expect("price not found");
        assert_eq!(
            price_col.type_params,
            Some(vec![String::from("10"), String::from("2")])
        );

        let code_col = table.find_column("code").expect("code not found");
        assert!(code_col.unique);
    }

    /// `check(...)` bodies are stored as text without the surrounding parentheses.
    #[test]
    fn test_check_constraint() {
        let source = r#"
            table employees (
                id uuid primary_key,
                age i32 check(age >= 18),
                salary decimal check(salary > 0)
            )
        "#;

        let parsed = must_parse(source);
        let table = &parsed.tables[0];

        let age_col = table.find_column("age").expect("age not found");
        assert_eq!(age_col.check, Some(String::from("age >= 18")));

        let salary_col = table.find_column("salary").expect("salary not found");
        assert_eq!(salary_col.check, Some(String::from("salary > 0")));
    }

    /// `-- qail: version=N` populates `Schema::version`; without it the field is `None`.
    #[test]
    fn test_version_directive() {
        let versioned = r#"
            -- qail: version=1
            table users (
                id uuid primary_key
            )
        "#;

        let with_directive = must_parse(versioned);
        assert_eq!(with_directive.version, Some(1));
        assert_eq!(with_directive.tables.len(), 1);

        // Same shape of schema, but no directive line at all.
        let unversioned = r#"
            table items (
                id uuid primary_key
            )
        "#;
        let without_directive = must_parse(unversioned);
        assert_eq!(without_directive.version, None);
    }

    // -------------------------------------------------------------------------
    // Row-level security: `enable_rls` table suffix and `policy` declarations
    // -------------------------------------------------------------------------

    /// A trailing `enable_rls` after the column list sets the table's flag.
    #[test]
    fn test_enable_rls_table() {
        let source = r#"
            table orders (
                id uuid primary_key,
                operator_id uuid not null
            ) enable_rls
        "#;

        let parsed = must_parse(source);
        assert_eq!(parsed.tables.len(), 1);
        assert!(parsed.tables[0].enable_rls);
    }

    /// A `for all ... using (...)` policy whose predicate is parsed into a typed
    /// expression tree: `Named = Cast(FunctionCall)`.
    #[test]
    fn test_parse_policy_basic() {
        let source = r#"
            table orders (
                id uuid primary_key,
                operator_id uuid not null
            ) enable_rls

            policy orders_isolation on orders
                for all
                using (operator_id = current_setting('app.current_operator_id')::uuid)
        "#;

        let parsed = must_parse(source);
        assert_eq!(parsed.tables.len(), 1);
        assert_eq!(parsed.policies.len(), 1);

        let rls = &parsed.policies[0];
        assert_eq!(rls.name, "orders_isolation");
        assert_eq!(rls.table, "orders");
        assert_eq!(rls.target, PolicyTarget::All);
        assert!(rls.using.is_some());

        // The predicate must be a structured comparison rather than raw SQL text.
        let Expr::Binary { left, op, right, .. } = rls.using.as_ref().unwrap() else {
            panic!("Expected Binary");
        };
        assert_eq!(*op, BinaryOp::Eq);

        // Left-hand side: the bare column name.
        let Expr::Named(column) = left.as_ref() else {
            panic!("Expected Named, got {:?}", left);
        };
        assert_eq!(column, "operator_id");

        // Right-hand side: `current_setting(...)` wrapped in a `::uuid` cast.
        let Expr::Cast { target_type, expr, .. } = right.as_ref() else {
            panic!("Expected Cast, got {:?}", right);
        };
        assert_eq!(target_type, "uuid");

        let Expr::FunctionCall { name, args, .. } = expr.as_ref() else {
            panic!("Expected FunctionCall");
        };
        assert_eq!(name, "current_setting");
        assert_eq!(args.len(), 1);
    }

    /// A `for insert` policy with only `with check (...)` leaves `using` unset.
    #[test]
    fn test_parse_policy_with_check() {
        let source = r#"
            table orders (
                id uuid primary_key
            )

            policy orders_write on orders
                for insert
                with check (operator_id = current_setting('app.current_operator_id')::uuid)
        "#;

        let parsed = must_parse(source);
        let rls = &parsed.policies[0];
        assert_eq!(rls.target, PolicyTarget::Insert);
        assert!(rls.with_check.is_some());
        assert!(rls.using.is_none());
    }

    /// `restrictive` and `to <role>` clauses are recorded on the policy.
    #[test]
    fn test_parse_policy_restrictive_with_role() {
        let source = r#"
            table secrets (
                id uuid primary_key
            )

            policy admin_only on secrets
                for select
                restrictive
                to app_user
                using (current_setting('app.is_super_admin')::boolean = true)
        "#;

        let parsed = must_parse(source);
        let rls = &parsed.policies[0];
        assert_eq!(rls.target, PolicyTarget::Select);
        assert_eq!(rls.permissiveness, PolicyPermissiveness::Restrictive);
        assert_eq!(rls.role.as_deref(), Some("app_user"));
        assert!(rls.using.is_some());
    }

    /// Two comparisons joined by `or` produce a top-level `BinaryOp::Or` node.
    #[test]
    fn test_parse_policy_or_expr() {
        let source = r#"
            table orders (
                id uuid primary_key
            )

            policy tenant_or_admin on orders
                for all
                using (operator_id = current_setting('app.current_operator_id')::uuid or current_setting('app.is_super_admin')::boolean = true)
        "#;

        let parsed = must_parse(source);
        let rls = &parsed.policies[0];

        let predicate = rls.using.as_ref().unwrap();
        assert!(
            matches!(predicate, Expr::Binary { op: BinaryOp::Or, .. }),
            "Expected Binary OR, got {:?}",
            predicate
        );
    }

    /// `Schema::to_sql` output for an RLS-enabled table plus one policy contains
    /// the table DDL, both row-level-security statements and the policy itself.
    #[test]
    fn test_schema_to_sql() {
        let source = r#"
            table orders (
                id uuid primary_key,
                operator_id uuid not null
            ) enable_rls

            policy orders_isolation on orders
                for all
                using (operator_id = current_setting('app.current_operator_id')::uuid)
        "#;

        let sql = must_parse(source).to_sql();

        let expected_fragments = [
            "CREATE TABLE IF NOT EXISTS",
            "ENABLE ROW LEVEL SECURITY",
            "FORCE ROW LEVEL SECURITY",
            "CREATE POLICY",
            "orders_isolation",
            "FOR ALL",
        ];
        for fragment in expected_fragments {
            assert!(
                sql.contains(fragment),
                "generated SQL is missing {fragment:?}:\n{sql}"
            );
        }
    }

    /// Several policies on one table are all collected, in declaration order.
    #[test]
    fn test_multiple_policies() {
        let source = r#"
            table orders (
                id uuid primary_key,
                operator_id uuid not null
            ) enable_rls

            policy orders_read on orders
                for select
                using (operator_id = current_setting('app.current_operator_id')::uuid)

            policy orders_write on orders
                for insert
                with check (operator_id = current_setting('app.current_operator_id')::uuid)
        "#;

        let parsed = must_parse(source);

        let expected = [
            ("orders_read", PolicyTarget::Select),
            ("orders_write", PolicyTarget::Insert),
        ];
        assert_eq!(parsed.policies.len(), expected.len());

        for (rls, (name, target)) in parsed.policies.iter().zip(expected) {
            assert_eq!(rls.name, name);
            assert_eq!(rls.target, target);
        }
    }
}