Skip to main content

qail_core/parser/
schema.rs

1//! Schema file parser for `.qail` format.
2//!
3//! Parses schema definitions like:
4//! ```text
5//! table users (
6//!   id uuid primary_key,
7//!   email text not null,
8//!   name text,
9//!   created_at timestamp
10//! )
11//!
12//! policy users_isolation on users
13//!     for all
14//!     using (operator_id = current_setting('app.operator_id')::uuid)
15//! ```
16
17use nom::{
18    IResult, Parser,
19    branch::alt,
20    bytes::complete::{tag, tag_no_case, take_while1},
21    character::complete::{char, multispace0 as nom_ws0, multispace1, not_line_ending},
22    combinator::{map, opt},
23    multi::{many0, separated_list0},
24    sequence::preceded,
25};
26use serde::{Deserialize, Serialize};
27
28use crate::ast::{BinaryOp, Expr, Value as AstValue};
29use crate::migrate::alter::AlterTable;
30use crate::migrate::policy::{PolicyPermissiveness, PolicyTarget, RlsPolicy};
31use crate::transpiler::policy::{alter_table_sql, create_policy_sql};
32
33/// Schema containing all table definitions
34#[derive(Debug, Clone, Serialize, Deserialize, Default)]
35pub struct Schema {
36    /// Schema format version (extracted from `-- qail: version=N` directive)
37    #[serde(default)]
38    pub version: Option<u32>,
39    /// Table definitions.
40    pub tables: Vec<TableDef>,
41    /// RLS policies declared in the schema
42    #[serde(default)]
43    pub policies: Vec<RlsPolicy>,
44    /// Indexes declared in the schema
45    #[serde(default)]
46    pub indexes: Vec<IndexDef>,
47}
48
49/// Index definition parsed from `index <name> on <table> (<columns>) [unique]`
50#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct IndexDef {
52    /// Index name.
53    pub name: String,
54    /// Table this index belongs to.
55    pub table: String,
56    /// Columns included in the index.
57    pub columns: Vec<String>,
58    /// Whether this is a UNIQUE index.
59    #[serde(default)]
60    pub unique: bool,
61}
62
63impl IndexDef {
64    /// Generate `CREATE INDEX IF NOT EXISTS` SQL.
65    pub fn to_sql(&self) -> String {
66        let unique = if self.unique { " UNIQUE" } else { "" };
67        format!(
68            "CREATE{} INDEX IF NOT EXISTS {} ON {} ({})",
69            unique,
70            self.name,
71            self.table,
72            self.columns.join(", ")
73        )
74    }
75}
76
77/// Table definition parsed from a `.qail` schema file.
78#[derive(Debug, Clone, Serialize, Deserialize)]
79pub struct TableDef {
80    /// Table name.
81    pub name: String,
82    /// Column definitions.
83    pub columns: Vec<ColumnDef>,
84    /// Whether this table has RLS enabled.
85    #[serde(default)]
86    pub enable_rls: bool,
87}
88
89/// Column definition parsed from a `.qail` schema file.
90#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct ColumnDef {
92    /// Column name.
93    pub name: String,
94    /// SQL data type (lowercased).
95    #[serde(rename = "type", alias = "typ")]
96    pub typ: String,
97    /// Type is an array (e.g., text[], uuid[]).
98    #[serde(default)]
99    pub is_array: bool,
100    /// Type parameters (e.g., varchar(255) → Some(vec!["255"]), decimal(10,2) → Some(vec!["10", "2"])).
101    #[serde(default)]
102    pub type_params: Option<Vec<String>>,
103    /// Whether the column accepts NULL.
104    #[serde(default)]
105    pub nullable: bool,
106    /// Whether the column is a primary key.
107    #[serde(default)]
108    pub primary_key: bool,
109    /// Whether the column has a UNIQUE constraint.
110    #[serde(default)]
111    pub unique: bool,
112    #[serde(default)]
113    /// Foreign key reference (e.g. "users(id)").
114    pub references: Option<String>,
115    /// Default value expression.
116    #[serde(default)]
117    pub default_value: Option<String>,
118    /// Check constraint expression
119    #[serde(default)]
120    pub check: Option<String>,
121    /// Is this a serial/auto-increment type
122    #[serde(default)]
123    pub is_serial: bool,
124}
125
126impl Default for ColumnDef {
127    fn default() -> Self {
128        Self {
129            name: String::new(),
130            typ: String::new(),
131            is_array: false,
132            type_params: None,
133            nullable: true,
134            primary_key: false,
135            unique: false,
136            references: None,
137            default_value: None,
138            check: None,
139            is_serial: false,
140        }
141    }
142}
143
144impl Schema {
145    /// Parse a schema from `.qail` format string
146    pub fn parse(input: &str) -> Result<Self, String> {
147        match parse_schema(input) {
148            Ok(("", schema)) => Ok(schema),
149            Ok((remaining, _)) => Err(format!("Unexpected content: '{}'", remaining.trim())),
150            Err(e) => Err(format!("Parse error: {:?}", e)),
151        }
152    }
153
154    /// Find a table by name
155    pub fn find_table(&self, name: &str) -> Option<&TableDef> {
156        self.tables
157            .iter()
158            .find(|t| t.name.eq_ignore_ascii_case(name))
159    }
160
161    /// Generate complete SQL for this schema: tables + RLS + policies + indexes.
162    pub fn to_sql(&self) -> String {
163        let mut parts = Vec::new();
164
165        for table in &self.tables {
166            parts.push(table.to_ddl());
167
168            if table.enable_rls {
169                let alter = AlterTable::new(&table.name).enable_rls().force_rls();
170                for stmt in alter_table_sql(&alter) {
171                    parts.push(stmt);
172                }
173            }
174        }
175
176        for idx in &self.indexes {
177            parts.push(idx.to_sql());
178        }
179
180        for policy in &self.policies {
181            parts.push(create_policy_sql(policy));
182        }
183
184        parts.join(";\n\n") + ";"
185    }
186
187    /// Export schema to JSON string (for qail-macros compatibility)
188    pub fn to_json(&self) -> Result<String, String> {
189        serde_json::to_string_pretty(self).map_err(|e| format!("JSON serialization failed: {}", e))
190    }
191
192    /// Import schema from JSON string
193    pub fn from_json(json: &str) -> Result<Self, String> {
194        serde_json::from_str(json).map_err(|e| format!("JSON deserialization failed: {}", e))
195    }
196
197    /// Load schema from a .qail file
198    pub fn from_file(path: &std::path::Path) -> Result<Self, String> {
199        let content =
200            std::fs::read_to_string(path).map_err(|e| format!("Failed to read file: {}", e))?;
201
202        if content.trim().starts_with('{') {
203            Self::from_json(&content)
204        } else {
205            Self::parse(&content)
206        }
207    }
208}
209
210impl TableDef {
211    /// Find a column by name
212    pub fn find_column(&self, name: &str) -> Option<&ColumnDef> {
213        self.columns
214            .iter()
215            .find(|c| c.name.eq_ignore_ascii_case(name))
216    }
217
218    /// Generate CREATE TABLE IF NOT EXISTS SQL (AST-native DDL).
219    pub fn to_ddl(&self) -> String {
220        let mut sql = format!("CREATE TABLE IF NOT EXISTS {} (\n", self.name);
221
222        let mut col_defs = Vec::new();
223        for col in &self.columns {
224            let mut line = format!("    {}", col.name);
225
226            // Type with params
227            let mut typ = col.typ.to_uppercase();
228            if let Some(params) = &col.type_params {
229                typ = format!("{}({})", typ, params.join(", "));
230            }
231            if col.is_array {
232                typ.push_str("[]");
233            }
234            line.push_str(&format!(" {}", typ));
235
236            // Constraints
237            if col.primary_key {
238                line.push_str(" PRIMARY KEY");
239            }
240            if !col.nullable && !col.primary_key && !col.is_serial {
241                line.push_str(" NOT NULL");
242            }
243            if col.unique && !col.primary_key {
244                line.push_str(" UNIQUE");
245            }
246            if let Some(ref default) = col.default_value {
247                line.push_str(&format!(" DEFAULT {}", default));
248            }
249            if let Some(ref refs) = col.references {
250                line.push_str(&format!(" REFERENCES {}", refs));
251            }
252            if let Some(ref check) = col.check {
253                line.push_str(&format!(" CHECK({})", check));
254            }
255
256            col_defs.push(line);
257        }
258
259        sql.push_str(&col_defs.join(",\n"));
260        sql.push_str("\n)");
261        sql
262    }
263}
264
265// =============================================================================
266// Parsing Combinators
267// =============================================================================
268
269/// Parse identifier (table/column name)
270fn identifier(input: &str) -> IResult<&str, &str> {
271    take_while1(|c: char| c.is_alphanumeric() || c == '_')(input)
272}
273
274/// Skip whitespace and comments (both `--` and `#` styles)
275fn ws_and_comments(input: &str) -> IResult<&str, ()> {
276    let (input, _) = many0(alt((
277        map(multispace1, |_| ()),
278        map((tag("--"), not_line_ending), |_| ()),
279        map((tag("#"), not_line_ending), |_| ()),
280    )))
281    .parse(input)?;
282    Ok((input, ()))
283}
284
/// A column type as split apart by `parse_type_info`.
struct TypeInfo {
    /// Lowercased base type name, e.g. `varchar`.
    name: String,
    /// Trimmed parameters from a `(...)` suffix, e.g. `["10", "2"]`; `None` without parens.
    params: Option<Vec<String>>,
    /// Whether the type carried a `[]` array suffix.
    is_array: bool,
    /// Whether the base type is `serial`, `bigserial` or `smallserial`.
    is_serial: bool,
}
291
292/// Parse column type with optional params and array suffix
293/// Handles: varchar(255), decimal(10,2), text[], serial, bigserial
294fn parse_type_info(input: &str) -> IResult<&str, TypeInfo> {
295    let (input, type_name) = take_while1(|c: char| c.is_alphanumeric()).parse(input)?;
296
297    let (input, params) = if input.starts_with('(') {
298        let paren_start = 1;
299        let mut paren_end = paren_start;
300        for (i, c) in input[paren_start..].char_indices() {
301            if c == ')' {
302                paren_end = paren_start + i;
303                break;
304            }
305        }
306        let param_str = &input[paren_start..paren_end];
307        let params: Vec<String> = param_str.split(',').map(|s| s.trim().to_string()).collect();
308        (&input[paren_end + 1..], Some(params))
309    } else {
310        (input, None)
311    };
312
313    let (input, is_array) = if let Some(stripped) = input.strip_prefix("[]") {
314        (stripped, true)
315    } else {
316        (input, false)
317    };
318
319    let lower = type_name.to_lowercase();
320    let is_serial = lower == "serial" || lower == "bigserial" || lower == "smallserial";
321
322    Ok((
323        input,
324        TypeInfo {
325            name: lower,
326            params,
327            is_array,
328            is_serial,
329        },
330    ))
331}
332
333/// Parse constraint text until comma or closing paren (handling nested parens)
334fn constraint_text(input: &str) -> IResult<&str, &str> {
335    let mut paren_depth = 0;
336    let mut end = 0;
337
338    for (i, c) in input.char_indices() {
339        match c {
340            '(' => paren_depth += 1,
341            ')' => {
342                if paren_depth == 0 {
343                    break; // End at column-level closing paren
344                }
345                paren_depth -= 1;
346            }
347            ',' if paren_depth == 0 => break,
348            '\n' | '\r' if paren_depth == 0 => break,
349            _ => {}
350        }
351        end = i + c.len_utf8();
352    }
353
354    if end == 0 {
355        Err(nom::Err::Error(nom::error::Error::new(
356            input,
357            nom::error::ErrorKind::TakeWhile1,
358        )))
359    } else {
360        Ok((&input[end..], &input[..end]))
361    }
362}
363
364/// Parse a single column definition
365fn parse_column(input: &str) -> IResult<&str, ColumnDef> {
366    let (input, _) = ws_and_comments(input)?;
367    let (input, name) = identifier(input)?;
368    let (input, _) = multispace1(input)?;
369    let (input, type_info) = parse_type_info(input)?;
370
371    let (input, constraint_str) = opt(preceded(multispace1, constraint_text)).parse(input)?;
372
373    let mut col = ColumnDef {
374        name: name.to_string(),
375        typ: type_info.name,
376        is_array: type_info.is_array,
377        type_params: type_info.params,
378        is_serial: type_info.is_serial,
379        nullable: !type_info.is_serial, // Serial types are implicitly not null
380        ..Default::default()
381    };
382
383    if let Some(constraints) = constraint_str {
384        let lower = constraints.to_lowercase();
385
386        if lower.contains("primary_key") || lower.contains("primary key") {
387            col.primary_key = true;
388            col.nullable = false;
389        }
390        if lower.contains("not_null") || lower.contains("not null") {
391            col.nullable = false;
392        }
393        if lower.contains("unique") {
394            col.unique = true;
395        }
396
397        if let Some(idx) = lower.find("references ") {
398            let rest = &constraints[idx + 11..];
399            // Find end (space or end of string), but handle nested parens
400            let mut paren_depth = 0;
401            let mut end = rest.len();
402            for (i, c) in rest.char_indices() {
403                match c {
404                    '(' => paren_depth += 1,
405                    ')' => {
406                        if paren_depth == 0 {
407                            end = i;
408                            break;
409                        }
410                        paren_depth -= 1;
411                    }
412                    c if c.is_whitespace() && paren_depth == 0 => {
413                        end = i;
414                        break;
415                    }
416                    _ => {}
417                }
418            }
419            col.references = Some(rest[..end].to_string());
420        }
421
422        if let Some(idx) = lower.find("default ") {
423            let rest = &constraints[idx + 8..];
424            let end = rest.find(|c: char| c.is_whitespace()).unwrap_or(rest.len());
425            col.default_value = Some(rest[..end].to_string());
426        }
427
428        if let Some(idx) = lower.find("check(") {
429            let rest = &constraints[idx + 6..];
430            // Find matching closing paren
431            let mut depth = 1;
432            let mut end = rest.len();
433            for (i, c) in rest.char_indices() {
434                match c {
435                    '(' => depth += 1,
436                    ')' => {
437                        depth -= 1;
438                        if depth == 0 {
439                            end = i;
440                            break;
441                        }
442                    }
443                    _ => {}
444                }
445            }
446            col.check = Some(rest[..end].to_string());
447        }
448    }
449
450    Ok((input, col))
451}
452
453/// Parse column list: (col1 type, col2 type, ...)
454fn parse_column_list(input: &str) -> IResult<&str, Vec<ColumnDef>> {
455    let (input, _) = ws_and_comments(input)?;
456    let (input, _) = char('(').parse(input)?;
457    let (input, columns) = separated_list0(char(','), parse_column).parse(input)?;
458    let (input, _) = ws_and_comments(input)?;
459    let (input, _) = char(')').parse(input)?;
460
461    Ok((input, columns))
462}
463
464/// Parse a table definition
465fn parse_table(input: &str) -> IResult<&str, TableDef> {
466    let (input, _) = ws_and_comments(input)?;
467    let (input, _) = tag_no_case("table").parse(input)?;
468    let (input, _) = multispace1(input)?;
469    let (input, name) = identifier(input)?;
470    let (input, columns) = parse_column_list(input)?;
471
472    // Optional enable_rls annotation after closing paren
473    let (input, _) = ws_and_comments(input)?;
474    let enable_rls = if let Ok((rest, _)) =
475        tag_no_case::<_, _, nom::error::Error<&str>>("enable_rls").parse(input)
476    {
477        return Ok((
478            rest,
479            TableDef {
480                name: name.to_string(),
481                columns,
482                enable_rls: true,
483            },
484        ));
485    } else {
486        false
487    };
488
489    Ok((
490        input,
491        TableDef {
492            name: name.to_string(),
493            columns,
494            enable_rls,
495        },
496    ))
497}
498
499// =============================================================================
500// Policy Parsing
501// =============================================================================
502
503/// A schema item is either a table, policy, or index.
504enum SchemaItem {
505    Table(TableDef),
506    Policy(Box<RlsPolicy>),
507    Index(IndexDef),
508}
509
510/// Parse a policy definition.
511///
512/// Syntax:
513/// ```text
514/// policy <name> on <table>
515///     [for all|select|insert|update|delete]
516///     [restrictive]
517///     [to <role>]
518///     [using (<expr>)]
519///     [with check (<expr>)]
520/// ```
521fn parse_policy(input: &str) -> IResult<&str, RlsPolicy> {
522    let (input, _) = ws_and_comments(input)?;
523    let (input, _) = tag_no_case("policy").parse(input)?;
524    let (input, _) = multispace1(input)?;
525    let (input, name) = identifier(input)?;
526    let (input, _) = multispace1(input)?;
527    let (input, _) = tag_no_case("on").parse(input)?;
528    let (input, _) = multispace1(input)?;
529    let (input, table) = identifier(input)?;
530
531    let mut policy = RlsPolicy::create(name, table);
532
533    // Parse optional clauses in any order
534    let mut remaining = input;
535    loop {
536        let (input, _) = ws_and_comments(remaining)?;
537
538        // for all|select|insert|update|delete
539        if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("for").parse(input) {
540            let (rest, _) = multispace1(rest)?;
541            let (rest, target) = alt((
542                map(tag_no_case("all"), |_| PolicyTarget::All),
543                map(tag_no_case("select"), |_| PolicyTarget::Select),
544                map(tag_no_case("insert"), |_| PolicyTarget::Insert),
545                map(tag_no_case("update"), |_| PolicyTarget::Update),
546                map(tag_no_case("delete"), |_| PolicyTarget::Delete),
547            ))
548            .parse(rest)?;
549            policy.target = target;
550            remaining = rest;
551            continue;
552        }
553
554        // restrictive
555        if let Ok((rest, _)) =
556            tag_no_case::<_, _, nom::error::Error<&str>>("restrictive").parse(input)
557        {
558            policy.permissiveness = PolicyPermissiveness::Restrictive;
559            remaining = rest;
560            continue;
561        }
562
563        // to <role>
564        if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("to").parse(input) {
565            // Make sure it's not "to_sql" or similar — needs whitespace after
566            if let Ok((rest, _)) = multispace1::<_, nom::error::Error<&str>>(rest) {
567                let (rest, role) = identifier(rest)?;
568                policy.role = Some(role.to_string());
569                remaining = rest;
570                continue;
571            }
572        }
573
574        // with check (<expr>)
575        if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("with").parse(input) {
576            let (rest, _) = multispace1(rest)?;
577            if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("check").parse(rest)
578            {
579                let (rest, _) = nom_ws0(rest)?;
580                let (rest, _) = char('(').parse(rest)?;
581                let (rest, _) = nom_ws0(rest)?;
582                let (rest, expr) = parse_policy_expr(rest)?;
583                let (rest, _) = nom_ws0(rest)?;
584                let (rest, _) = char(')').parse(rest)?;
585                policy.with_check = Some(expr);
586                remaining = rest;
587                continue;
588            }
589        }
590
591        // using (<expr>)
592        if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("using").parse(input) {
593            let (rest, _) = nom_ws0(rest)?;
594            let (rest, _) = char('(').parse(rest)?;
595            let (rest, _) = nom_ws0(rest)?;
596            let (rest, expr) = parse_policy_expr(rest)?;
597            let (rest, _) = nom_ws0(rest)?;
598            let (rest, _) = char(')').parse(rest)?;
599            policy.using = Some(expr);
600            remaining = rest;
601            continue;
602        }
603
604        // No more clauses matched
605        remaining = input;
606        break;
607    }
608
609    Ok((remaining, policy))
610}
611
612/// Parse a policy expression: `left op right [AND/OR left op right ...]`
613///
614/// Produces typed `Expr::Binary` AST nodes — no raw SQL.
615///
616/// Handles:
617/// - `column = value`
618/// - `column = function('arg')::type`   (function call + cast)
619/// - `expr AND expr`, `expr OR expr`
620fn parse_policy_expr(input: &str) -> IResult<&str, Expr> {
621    let (input, first) = parse_policy_comparison(input)?;
622
623    // Check for AND/OR chaining
624    let mut result = first;
625    let mut remaining = input;
626    loop {
627        let (input, _) = nom_ws0(remaining)?;
628
629        if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("or").parse(input)
630            && let Ok((rest, _)) = multispace1::<_, nom::error::Error<&str>>(rest)
631        {
632            let (rest, right) = parse_policy_comparison(rest)?;
633            result = Expr::Binary {
634                left: Box::new(result),
635                op: BinaryOp::Or,
636                right: Box::new(right),
637                alias: None,
638            };
639            remaining = rest;
640            continue;
641        }
642
643        if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("and").parse(input)
644            && let Ok((rest, _)) = multispace1::<_, nom::error::Error<&str>>(rest)
645        {
646            let (rest, right) = parse_policy_comparison(rest)?;
647            result = Expr::Binary {
648                left: Box::new(result),
649                op: BinaryOp::And,
650                right: Box::new(right),
651                alias: None,
652            };
653            remaining = rest;
654            continue;
655        }
656
657        remaining = input;
658        break;
659    }
660
661    Ok((remaining, result))
662}
663
664/// Parse a single comparison: `atom op atom`
665fn parse_policy_comparison(input: &str) -> IResult<&str, Expr> {
666    let (input, left) = parse_policy_atom(input)?;
667    let (input, _) = nom_ws0(input)?;
668
669    // Try to parse comparison operator
670    if let Ok((rest, op)) = parse_cmp_op(input) {
671        let (rest, _) = nom_ws0(rest)?;
672        let (rest, right) = parse_policy_atom(rest)?;
673        return Ok((
674            rest,
675            Expr::Binary {
676                left: Box::new(left),
677                op,
678                right: Box::new(right),
679                alias: None,
680            },
681        ));
682    }
683
684    // No operator — just an atom
685    Ok((input, left))
686}
687
688/// Parse comparison operators: =, !=, <>, >=, <=, >, <
689fn parse_cmp_op(input: &str) -> IResult<&str, BinaryOp> {
690    alt((
691        map(tag(">="), |_| BinaryOp::Gte),
692        map(tag("<="), |_| BinaryOp::Lte),
693        map(tag("<>"), |_| BinaryOp::Ne),
694        map(tag("!="), |_| BinaryOp::Ne),
695        map(tag("="), |_| BinaryOp::Eq),
696        map(tag(">"), |_| BinaryOp::Gt),
697        map(tag("<"), |_| BinaryOp::Lt),
698    ))
699    .parse(input)
700}
701
702/// Parse a policy expression atom:
703/// - identifier  (column name)
704/// - function_call(args)  with optional ::cast
705/// - 'string literal'
706/// - numeric literal
707/// - true/false
708/// - (sub_expr)  grouped
709fn parse_policy_atom(input: &str) -> IResult<&str, Expr> {
710    alt((
711        parse_policy_grouped,
712        parse_policy_bool,
713        parse_policy_string,
714        parse_policy_number,
715        parse_policy_func_or_ident, // function call or plain identifier, with optional ::cast
716    ))
717    .parse(input)
718}
719
720/// Parse grouped expression in parens
721fn parse_policy_grouped(input: &str) -> IResult<&str, Expr> {
722    let (input, _) = char('(').parse(input)?;
723    let (input, _) = nom_ws0(input)?;
724    let (input, expr) = parse_policy_expr(input)?;
725    let (input, _) = nom_ws0(input)?;
726    let (input, _) = char(')').parse(input)?;
727    Ok((input, expr))
728}
729
730/// Parse true / false
731fn parse_policy_bool(input: &str) -> IResult<&str, Expr> {
732    alt((
733        map(tag_no_case("true"), |_| Expr::Literal(AstValue::Bool(true))),
734        map(tag_no_case("false"), |_| {
735            Expr::Literal(AstValue::Bool(false))
736        }),
737    ))
738    .parse(input)
739}
740
741/// Parse a 'string literal'
742fn parse_policy_string(input: &str) -> IResult<&str, Expr> {
743    let (input, _) = char('\'').parse(input)?;
744    let mut end = 0;
745    for (i, c) in input.char_indices() {
746        if c == '\'' {
747            end = i;
748            break;
749        }
750    }
751    let content = &input[..end];
752    let rest = &input[end + 1..];
753    Ok((rest, Expr::Literal(AstValue::String(content.to_string()))))
754}
755
756/// Parse numeric literal
757fn parse_policy_number(input: &str) -> IResult<&str, Expr> {
758    let (input, digits) = take_while1(|c: char| c.is_ascii_digit() || c == '.')(input)?;
759    // Make sure it starts with digit (not just '.')
760    if digits.starts_with('.') || digits.is_empty() {
761        return Err(nom::Err::Error(nom::error::Error::new(
762            input,
763            nom::error::ErrorKind::Digit,
764        )));
765    }
766    if let Ok(n) = digits.parse::<i64>() {
767        Ok((input, Expr::Literal(AstValue::Int(n))))
768    } else if let Ok(f) = digits.parse::<f64>() {
769        Ok((input, Expr::Literal(AstValue::Float(f))))
770    } else {
771        Ok((input, Expr::Named(digits.to_string())))
772    }
773}
774
775/// Parse function call or identifier, with optional ::cast
776fn parse_policy_func_or_ident(input: &str) -> IResult<&str, Expr> {
777    let (input, name) = identifier(input)?;
778
779    // Check for function call: name(
780    let mut expr = if let Ok((rest, _)) = char::<_, nom::error::Error<&str>>('(').parse(input) {
781        // Parse args
782        let (rest, _) = nom_ws0(rest)?;
783        let (rest, args) =
784            separated_list0((nom_ws0, char(','), nom_ws0), parse_policy_atom).parse(rest)?;
785        let (rest, _) = nom_ws0(rest)?;
786        let (rest, _) = char(')').parse(rest)?;
787        let input = rest;
788        (
789            input,
790            Expr::FunctionCall {
791                name: name.to_string(),
792                args,
793                alias: None,
794            },
795        )
796    } else {
797        (input, Expr::Named(name.to_string()))
798    };
799
800    // Check for ::cast
801    if let Ok((rest, _)) = tag::<_, _, nom::error::Error<&str>>("::").parse(expr.0) {
802        let (rest, cast_type) = identifier(rest)?;
803        expr = (
804            rest,
805            Expr::Cast {
806                expr: Box::new(expr.1),
807                target_type: cast_type.to_string(),
808                alias: None,
809            },
810        );
811    }
812
813    Ok(expr)
814}
815
816/// Parse a single schema item: table, policy, or index
817fn parse_schema_item(input: &str) -> IResult<&str, SchemaItem> {
818    let (input, _) = ws_and_comments(input)?;
819
820    // Try policy first (since "policy" is a distinct keyword)
821    if let Ok((rest, policy)) = parse_policy(input) {
822        return Ok((rest, SchemaItem::Policy(Box::new(policy))));
823    }
824
825    // Try index
826    if let Ok((rest, idx)) = parse_index(input) {
827        return Ok((rest, SchemaItem::Index(idx)));
828    }
829
830    // Otherwise parse table
831    let (rest, table) = parse_table(input)?;
832    Ok((rest, SchemaItem::Table(table)))
833}
834
835/// Parse an index line: `index <name> on <table> (<col1>, <col2>) [unique]`
836fn parse_index(input: &str) -> IResult<&str, IndexDef> {
837    let (input, _) = tag_no_case("index")(input)?;
838    let (input, _) = multispace1(input)?;
839    let (input, name) = take_while1(|c: char| c.is_alphanumeric() || c == '_')(input)?;
840    let (input, _) = multispace1(input)?;
841    let (input, _) = tag_no_case("on")(input)?;
842    let (input, _) = multispace1(input)?;
843    let (input, table) = take_while1(|c: char| c.is_alphanumeric() || c == '_')(input)?;
844    let (input, _) = nom_ws0(input)?;
845    let (input, _) = char('(')(input)?;
846    let (input, cols_str) = take_while1(|c: char| c != ')')(input)?;
847    let (input, _) = char(')')(input)?;
848    let (input, _) = nom_ws0(input)?;
849    let (input, unique_tag) = opt(tag_no_case("unique")).parse(input)?;
850
851    let columns: Vec<String> = cols_str
852        .split(',')
853        .map(|s| s.trim().to_string())
854        .filter(|s| !s.is_empty())
855        .collect();
856
857    let is_unique = unique_tag.is_some();
858
859    Ok((
860        input,
861        IndexDef {
862            name: name.to_string(),
863            table: table.to_string(),
864            columns,
865            unique: is_unique,
866        },
867    ))
868}
869
870/// Parse complete schema file
871fn parse_schema(input: &str) -> IResult<&str, Schema> {
872    // Extract version directive before parsing
873    let version = extract_version_directive(input);
874
875    let (input, items) = many0(parse_schema_item).parse(input)?;
876    let (input, _) = ws_and_comments(input)?;
877
878    let mut tables = Vec::new();
879    let mut policies = Vec::new();
880    let mut indexes = Vec::new();
881    for item in items {
882        match item {
883            SchemaItem::Table(t) => tables.push(t),
884            SchemaItem::Policy(p) => policies.push(*p),
885            SchemaItem::Index(i) => indexes.push(i),
886        }
887    }
888
889    Ok((
890        input,
891        Schema {
892            version,
893            tables,
894            policies,
895            indexes,
896        },
897    ))
898}
899
/// Find the schema version in a `-- qail: version=N` directive.
///
/// Lines are trimmed before matching, and whitespace after `-- qail:` and around the number is
/// ignored. The first line carrying a `version=` directive decides the result. If its value is
/// not a valid `u32`, the result is `None` and later directives are not consulted.
fn extract_version_directive(input: &str) -> Option<u32> {
    input
        .lines()
        .map(str::trim)
        .filter_map(|line| line.strip_prefix("-- qail:"))
        .find_map(|directive| directive.trim().strip_prefix("version="))
        .and_then(|value| value.trim().parse().ok())
}
913
#[cfg(test)]
mod tests {
    //! Unit tests for the `.qail` schema parser: table/column parsing, the
    //! `-- qail: version=N` directive, and RLS (`enable_rls` / `policy`) support.

    use super::*;

    #[test]
    fn test_parse_simple_table() {
        let input = r#"
            table users (
                id uuid primary_key,
                email text not null,
                name text
            )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.tables.len(), 1);

        let users = &schema.tables[0];
        assert_eq!(users.name, "users");
        assert_eq!(users.columns.len(), 3);

        // `primary_key` columns come back as non-nullable.
        let id = &users.columns[0];
        assert_eq!(id.name, "id");
        assert_eq!(id.typ, "uuid");
        assert!(id.primary_key);
        assert!(!id.nullable);

        let email = &users.columns[1];
        assert_eq!(email.name, "email");
        assert!(!email.nullable);

        // A column with no constraints is nullable.
        let name = &users.columns[2];
        assert!(name.nullable);
    }

    #[test]
    fn test_parse_multiple_tables() {
        let input = r#"
            -- Users table
            table users (
                id uuid primary_key,
                email text not null unique
            )
            
            -- Orders table
            table orders (
                id uuid primary_key,
                user_id uuid references users(id),
                total i64 not null default 0
            )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.tables.len(), 2);

        let orders = schema.find_table("orders").expect("orders not found");
        let user_id = orders.find_column("user_id").expect("user_id not found");
        assert_eq!(user_id.references, Some("users(id)".to_string()));

        let total = orders.find_column("total").expect("total not found");
        assert_eq!(total.default_value, Some("0".to_string()));
    }

    #[test]
    fn test_parse_comments() {
        let input = r#"
            -- This is a comment
            table foo (
                bar text
            )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.tables.len(), 1);
    }

    #[test]
    fn test_array_types() {
        let input = r#"
            table products (
                id uuid primary_key,
                tags text[],
                prices decimal[]
            )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let products = &schema.tables[0];

        // The `[]` suffix is stripped from `typ` and recorded in `is_array`.
        let tags = products.find_column("tags").expect("tags not found");
        assert_eq!(tags.typ, "text");
        assert!(tags.is_array);

        let prices = products.find_column("prices").expect("prices not found");
        assert!(prices.is_array);
    }

    #[test]
    fn test_type_params() {
        let input = r#"
            table items (
                id serial primary_key,
                name varchar(255) not null,
                price decimal(10,2),
                code varchar(50) unique
            )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let items = &schema.tables[0];

        let id = items.find_column("id").expect("id not found");
        assert!(id.is_serial);
        assert!(!id.nullable); // Serial is implicitly not null

        let name = items.find_column("name").expect("name not found");
        assert_eq!(name.typ, "varchar");
        assert_eq!(name.type_params, Some(vec!["255".to_string()]));

        let price = items.find_column("price").expect("price not found");
        assert_eq!(
            price.type_params,
            Some(vec!["10".to_string(), "2".to_string()])
        );

        let code = items.find_column("code").expect("code not found");
        assert!(code.unique);
    }

    #[test]
    fn test_check_constraint() {
        let input = r#"
            table employees (
                id uuid primary_key,
                age i32 check(age >= 18),
                salary decimal check(salary > 0)
            )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let employees = &schema.tables[0];

        let age = employees.find_column("age").expect("age not found");
        assert_eq!(age.check, Some("age >= 18".to_string()));

        let salary = employees.find_column("salary").expect("salary not found");
        assert_eq!(salary.check, Some("salary > 0".to_string()));
    }

    #[test]
    fn test_version_directive() {
        let input = r#"
            -- qail: version=1
            table users (
                id uuid primary_key
            )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.version, Some(1));
        assert_eq!(schema.tables.len(), 1);

        // Without version directive
        let input_no_version = r#"
            table items (
                id uuid primary_key
            )
        "#;
        let schema2 = Schema::parse(input_no_version).expect("parse failed");
        assert_eq!(schema2.version, None);
    }

    #[test]
    fn test_version_directive_edge_cases() {
        // Whitespace around the line and around the value is tolerated.
        assert_eq!(
            extract_version_directive("   -- qail:   version= 42  \n"),
            Some(42)
        );

        // `-- qail:` directives other than `version=` are skipped.
        assert_eq!(
            extract_version_directive("-- qail: other=1\n-- qail: version=7"),
            Some(7)
        );

        // Values that do not parse as `u32` yield no version.
        assert_eq!(extract_version_directive("-- qail: version=abc"), None);
        assert_eq!(extract_version_directive("-- qail: version=-1"), None);

        // A malformed directive is still just a comment to the schema parser.
        let input = r#"
            -- qail: version=abc
            table users (
                id uuid primary_key
            )
        "#;
        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.version, None);
        assert_eq!(schema.tables.len(), 1);
    }

    // =========================================================================
    // Policy + enable_rls tests
    // =========================================================================

    #[test]
    fn test_enable_rls_table() {
        let input = r#"
            table orders (
                id uuid primary_key,
                operator_id uuid not null
            ) enable_rls
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.tables.len(), 1);
        assert!(schema.tables[0].enable_rls);
    }

    #[test]
    fn test_parse_policy_basic() {
        let input = r#"
            table orders (
                id uuid primary_key,
                operator_id uuid not null
            ) enable_rls

            policy orders_isolation on orders
                for all
                using (operator_id = current_setting('app.current_operator_id')::uuid)
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.tables.len(), 1);
        assert_eq!(schema.policies.len(), 1);

        let policy = &schema.policies[0];
        assert_eq!(policy.name, "orders_isolation");
        assert_eq!(policy.table, "orders");
        assert_eq!(policy.target, PolicyTarget::All);
        assert!(policy.using.is_some());

        // Verify the expression is a typed Binary, not raw SQL
        let using = policy.using.as_ref().expect("USING expression missing");
        let Expr::Binary {
            left, op, right, ..
        } = using
        else {
            panic!("Expected Binary, got {using:?}");
        };
        assert_eq!(*op, BinaryOp::Eq);

        let Expr::Named(n) = left.as_ref() else {
            panic!("Expected Named, got {left:?}");
        };
        assert_eq!(n, "operator_id");

        // `current_setting(...)::uuid` parses as a Cast wrapping a FunctionCall.
        let Expr::Cast {
            target_type,
            expr: cast_expr,
            ..
        } = right.as_ref()
        else {
            panic!("Expected Cast, got {right:?}");
        };
        assert_eq!(target_type, "uuid");

        let Expr::FunctionCall { name, args, .. } = cast_expr.as_ref() else {
            panic!("Expected FunctionCall, got {cast_expr:?}");
        };
        assert_eq!(name, "current_setting");
        assert_eq!(args.len(), 1);
    }

    #[test]
    fn test_parse_policy_with_check() {
        let input = r#"
            table orders (
                id uuid primary_key
            )

            policy orders_write on orders
                for insert
                with check (operator_id = current_setting('app.current_operator_id')::uuid)
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        // Assert the count first so a parse miss fails clearly instead of panicking on `[0]`.
        assert_eq!(schema.policies.len(), 1);
        let policy = &schema.policies[0];
        assert_eq!(policy.target, PolicyTarget::Insert);
        assert!(policy.with_check.is_some());
        assert!(policy.using.is_none());
    }

    #[test]
    fn test_parse_policy_restrictive_with_role() {
        let input = r#"
            table secrets (
                id uuid primary_key
            )

            policy admin_only on secrets
                for select
                restrictive
                to app_user
                using (current_setting('app.is_super_admin')::boolean = true)
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.policies.len(), 1);
        let policy = &schema.policies[0];
        assert_eq!(policy.target, PolicyTarget::Select);
        assert_eq!(policy.permissiveness, PolicyPermissiveness::Restrictive);
        assert_eq!(policy.role.as_deref(), Some("app_user"));
        assert!(policy.using.is_some());
    }

    #[test]
    fn test_parse_policy_or_expr() {
        let input = r#"
            table orders (
                id uuid primary_key
            )

            policy tenant_or_admin on orders
                for all
                using (operator_id = current_setting('app.current_operator_id')::uuid or current_setting('app.is_super_admin')::boolean = true)
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.policies.len(), 1);
        let policy = &schema.policies[0];

        // `or` must bind looser than `=`, so the top-level node is the OR.
        assert!(
            matches!(
                policy.using.as_ref().expect("USING expression missing"),
                Expr::Binary {
                    op: BinaryOp::Or,
                    ..
                }
            ),
            "Expected Binary OR, got {:?}",
            policy.using
        );
    }

    #[test]
    fn test_schema_to_sql() {
        let input = r#"
            table orders (
                id uuid primary_key,
                operator_id uuid not null
            ) enable_rls

            policy orders_isolation on orders
                for all
                using (operator_id = current_setting('app.current_operator_id')::uuid)
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let sql = schema.to_sql();
        assert!(sql.contains("CREATE TABLE IF NOT EXISTS"));
        assert!(sql.contains("ENABLE ROW LEVEL SECURITY"));
        assert!(sql.contains("FORCE ROW LEVEL SECURITY"));
        assert!(sql.contains("CREATE POLICY"));
        assert!(sql.contains("orders_isolation"));
        assert!(sql.contains("FOR ALL"));
    }

    #[test]
    fn test_multiple_policies() {
        let input = r#"
            table orders (
                id uuid primary_key,
                operator_id uuid not null
            ) enable_rls

            policy orders_read on orders
                for select
                using (operator_id = current_setting('app.current_operator_id')::uuid)

            policy orders_write on orders
                for insert
                with check (operator_id = current_setting('app.current_operator_id')::uuid)
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.policies.len(), 2);
        assert_eq!(schema.policies[0].name, "orders_read");
        assert_eq!(schema.policies[0].target, PolicyTarget::Select);
        assert_eq!(schema.policies[1].name, "orders_write");
        assert_eq!(schema.policies[1].target, PolicyTarget::Insert);
    }
}