Skip to main content

qail_core/parser/
schema.rs

1//! Schema file parser for `.qail` format.
2//!
3//! Parses schema definitions like:
4//! ```text
5//! table users (
6//!   id uuid primary_key,
7//!   email text not null,
8//!   name text,
9//!   created_at timestamp
10//! )
11//!
12//! policy users_isolation on users
13//!     for all
14//!     using (tenant_id = current_setting('app.tenant_id')::uuid)
15//! ```
16
17use nom::{
18    IResult, Parser,
19    branch::alt,
20    bytes::complete::{tag, tag_no_case, take_while1},
21    character::complete::{char, multispace0 as nom_ws0, multispace1, not_line_ending},
22    combinator::{map, opt},
23    multi::{many0, separated_list0},
24    sequence::preceded,
25};
26use std::collections::HashSet;
27
28use crate::ast::{BinaryOp, Expr, Value as AstValue};
29use crate::migrate::alter::AlterTable;
30use crate::migrate::policy::{PolicyPermissiveness, PolicyTarget, RlsPolicy};
31use crate::transpiler::policy::{alter_table_sql, create_policy_sql};
32
33/// Schema containing all table definitions
34#[derive(Debug, Clone, Default)]
35pub struct Schema {
36    /// Schema format version (extracted from `-- qail: version=N` directive)
37    pub version: Option<u32>,
38    /// Table definitions.
39    pub tables: Vec<TableDef>,
40    /// RLS policies declared in the schema
41    pub policies: Vec<RlsPolicy>,
42    /// Indexes declared in the schema
43    pub indexes: Vec<IndexDef>,
44}
45
/// A secondary index declared as `index <name> on <table> (<columns>) [unique]`.
#[derive(Debug, Clone)]
pub struct IndexDef {
    /// Name of the index.
    pub name: String,
    /// Table the index is built on.
    pub table: String,
    /// Indexed columns, in key order.
    pub columns: Vec<String>,
    /// `true` for a `UNIQUE` index.
    pub unique: bool,
}

impl IndexDef {
    /// Render this index as `CREATE [UNIQUE] INDEX IF NOT EXISTS <name> ON <table> (<cols>)`.
    ///
    /// The statement carries no trailing semicolon; columns are joined with `", "`.
    pub fn to_sql(&self) -> String {
        let index_kind = if self.unique { "UNIQUE INDEX" } else { "INDEX" };
        let column_list = self.columns.join(", ");
        format!(
            "CREATE {index_kind} IF NOT EXISTS {} ON {} ({column_list})",
            self.name, self.table
        )
    }
}
72
73/// Table definition parsed from a `.qail` schema file.
74#[derive(Debug, Clone)]
75pub struct TableDef {
76    /// Table name.
77    pub name: String,
78    /// Column definitions.
79    pub columns: Vec<ColumnDef>,
80    /// Whether this table has RLS enabled.
81    pub enable_rls: bool,
82}
83
/// One column inside a `table ( ... )` declaration of a `.qail` schema.
#[derive(Debug, Clone)]
pub struct ColumnDef {
    /// Column name.
    pub name: String,
    /// SQL type name, lower-cased by the parser (may be schema-qualified).
    pub typ: String,
    /// `true` when the type had a `[]` suffix (e.g. `text[]`, `uuid[]`).
    pub is_array: bool,
    /// Type modifiers inside parentheses: `varchar(255)` → `Some(vec!["255"])`,
    /// `decimal(10,2)` → `Some(vec!["10", "2"])`; `None` when there are none.
    pub type_params: Option<Vec<String>>,
    /// `false` when the column was declared NOT NULL, is a primary key, or is serial.
    pub nullable: bool,
    /// `true` when the column is the table's primary key.
    pub primary_key: bool,
    /// `true` when the column carries a UNIQUE constraint.
    pub unique: bool,
    /// Foreign-key target text such as `users(id)`.
    pub references: Option<String>,
    /// Default expression, kept as raw SQL text.
    pub default_value: Option<String>,
    /// Expression inside `check(...)`, kept as raw SQL text.
    pub check: Option<String>,
    /// `true` for `serial`, `bigserial` and `smallserial` columns.
    pub is_serial: bool,
}

impl Default for ColumnDef {
    /// An unnamed, untyped, nullable column with no constraints.
    fn default() -> Self {
        ColumnDef {
            name: String::default(),
            typ: String::default(),
            is_array: false,
            type_params: None,
            nullable: true,
            primary_key: false,
            unique: false,
            references: None,
            default_value: None,
            check: None,
            is_serial: false,
        }
    }
}
128
129impl Schema {
130    /// Parse a schema from `.qail` format string
131    pub fn parse(input: &str) -> Result<Self, String> {
132        match parse_schema(input) {
133            Ok(("", schema)) => Ok(schema),
134            Ok((remaining, _)) => Err(format!("Unexpected content: '{}'", remaining.trim())),
135            Err(e) => Err(format!("Parse error: {:?}", e)),
136        }
137    }
138
139    /// Find a table by name
140    pub fn find_table(&self, name: &str) -> Option<&TableDef> {
141        self.tables
142            .iter()
143            .find(|t| t.name.eq_ignore_ascii_case(name))
144    }
145
146    /// Generate complete SQL for this schema: tables + RLS + policies + indexes.
147    pub fn to_sql(&self) -> String {
148        let mut parts = Vec::new();
149
150        for table in &self.tables {
151            parts.push(table.to_ddl());
152
153            if table.enable_rls {
154                let alter = AlterTable::new(&table.name).enable_rls().force_rls();
155                for stmt in alter_table_sql(&alter) {
156                    parts.push(stmt);
157                }
158            }
159        }
160
161        for idx in &self.indexes {
162            parts.push(idx.to_sql());
163        }
164
165        for policy in &self.policies {
166            parts.push(create_policy_sql(policy));
167        }
168
169        parts.join(";\n\n") + ";"
170    }
171
172    /// Load schema from a .qail file
173    pub fn from_file(path: &std::path::Path) -> Result<Self, String> {
174        let content =
175            std::fs::read_to_string(path).map_err(|e| format!("Failed to read file: {}", e))?;
176        Self::parse(&content)
177    }
178}
179
180impl TableDef {
181    /// Find a column by name
182    pub fn find_column(&self, name: &str) -> Option<&ColumnDef> {
183        self.columns
184            .iter()
185            .find(|c| c.name.eq_ignore_ascii_case(name))
186    }
187
188    /// Generate CREATE TABLE IF NOT EXISTS SQL (AST-native DDL).
189    pub fn to_ddl(&self) -> String {
190        let mut sql = format!("CREATE TABLE IF NOT EXISTS {} (\n", self.name);
191
192        let mut col_defs = Vec::new();
193        for col in &self.columns {
194            let mut line = format!("    {}", col.name);
195
196            // Type with params
197            let mut typ = col.typ.to_uppercase();
198            if let Some(params) = &col.type_params {
199                typ = format!("{}({})", typ, params.join(", "));
200            }
201            if col.is_array {
202                typ.push_str("[]");
203            }
204            line.push_str(&format!(" {}", typ));
205
206            // Constraints
207            if col.primary_key {
208                line.push_str(" PRIMARY KEY");
209            }
210            if !col.nullable && !col.primary_key && !col.is_serial {
211                line.push_str(" NOT NULL");
212            }
213            if col.unique && !col.primary_key {
214                line.push_str(" UNIQUE");
215            }
216            if let Some(ref default) = col.default_value {
217                line.push_str(&format!(" DEFAULT {}", default));
218            }
219            if let Some(ref refs) = col.references {
220                line.push_str(&format!(" REFERENCES {}", refs));
221            }
222            if let Some(ref check) = col.check {
223                line.push_str(&format!(" CHECK({})", check));
224            }
225
226            col_defs.push(line);
227        }
228
229        sql.push_str(&col_defs.join(",\n"));
230        sql.push_str("\n)");
231        sql
232    }
233}
234
235// =============================================================================
236// Parsing Combinators
237// =============================================================================
238
239/// Parse identifier (table/column name)
240fn identifier(input: &str) -> IResult<&str, &str> {
241    let (remaining, ident) =
242        take_while1(|c: char| c.is_ascii_alphanumeric() || c == '_').parse(input)?;
243    if ident
244        .chars()
245        .next()
246        .is_some_and(|c| c.is_ascii_alphabetic() || c == '_')
247    {
248        Ok((remaining, ident))
249    } else {
250        Err(nom::Err::Error(nom::error::Error::new(
251            input,
252            nom::error::ErrorKind::Alpha,
253        )))
254    }
255}
256
257/// Skip whitespace and comments (both `--` and `#` styles)
258fn ws_and_comments(input: &str) -> IResult<&str, ()> {
259    let (input, _) = many0(alt((
260        map(multispace1, |_| ()),
261        map((tag("--"), not_line_ending), |_| ()),
262        map((tag("#"), not_line_ending), |_| ()),
263    )))
264    .parse(input)?;
265    Ok((input, ()))
266}
267
/// A column type split into its parts by `parse_type_info`.
struct TypeInfo {
    /// Lower-cased type name; may be dot-qualified (e.g. `public.my_enum`).
    name: String,
    /// Comma-separated modifiers inside `( … )`, e.g. `["10", "2"]` for `decimal(10,2)`.
    params: Option<Vec<String>>,
    /// `true` when the type ended in `[]`.
    is_array: bool,
    /// `true` for `serial`, `bigserial` and `smallserial`.
    is_serial: bool,
}
274
275/// Parse column type with optional params and array suffix
276/// Handles: varchar(255), decimal(10,2), text[], serial, bigserial
277fn parse_type_info(input: &str) -> IResult<&str, TypeInfo> {
278    let (input, type_name) =
279        take_while1(|c: char| c.is_alphanumeric() || c == '_' || c == '.').parse(input)?;
280    if !is_identifier_path(type_name) {
281        return Err(nom::Err::Error(nom::error::Error::new(
282            input,
283            nom::error::ErrorKind::Alpha,
284        )));
285    }
286
287    let (input, params) = if let Some(after_open) = input.strip_prefix('(') {
288        let Some(paren_end) = after_open.find(')') else {
289            return Err(nom::Err::Error(nom::error::Error::new(
290                input,
291                nom::error::ErrorKind::Char,
292            )));
293        };
294        let param_str = &after_open[..paren_end];
295        let Ok(params) = split_top_level_csv(param_str) else {
296            return Err(nom::Err::Error(nom::error::Error::new(
297                input,
298                nom::error::ErrorKind::SeparatedList,
299            )));
300        };
301        if params.is_empty() {
302            return Err(nom::Err::Error(nom::error::Error::new(
303                input,
304                nom::error::ErrorKind::SeparatedList,
305            )));
306        }
307        (&after_open[paren_end + 1..], Some(params))
308    } else {
309        (input, None)
310    };
311
312    let (input, is_array) = if let Some(stripped) = input.strip_prefix("[]") {
313        (stripped, true)
314    } else {
315        (input, false)
316    };
317
318    let lower = type_name.to_lowercase();
319    let is_serial = lower == "serial" || lower == "bigserial" || lower == "smallserial";
320
321    Ok((
322        input,
323        TypeInfo {
324            name: lower,
325            params,
326            is_array,
327            is_serial,
328        },
329    ))
330}
331
/// Returns `true` when `path` is one or more `.`-separated identifiers, each starting
/// with an ASCII letter or `_` and continuing with ASCII alphanumerics or `_`
/// (e.g. `int4`, `public.my_type`). Empty segments are rejected.
fn is_identifier_path(path: &str) -> bool {
    path.split('.').all(|segment| {
        let mut chars = segment.chars();
        chars
            .next()
            .is_some_and(|first| first == '_' || first.is_ascii_alphabetic())
            && chars.all(|c| c == '_' || c.is_ascii_alphanumeric())
    })
}
347
348/// Parse constraint text until comma or closing paren (handling nested parens)
349fn constraint_text(input: &str) -> IResult<&str, &str> {
350    let mut paren_depth = 0;
351    let mut in_single = false;
352    let mut in_double = false;
353    let mut end = 0;
354    let mut iter = input.char_indices().peekable();
355
356    while let Some((i, c)) = iter.next() {
357        match c {
358            '\'' if !in_double => {
359                if in_single && matches!(iter.peek(), Some((_, '\''))) {
360                    iter.next();
361                } else {
362                    in_single = !in_single;
363                }
364            }
365            '"' if !in_single => {
366                if in_double && matches!(iter.peek(), Some((_, '"'))) {
367                    iter.next();
368                } else {
369                    in_double = !in_double;
370                }
371            }
372            '(' if !in_single && !in_double => paren_depth += 1,
373            ')' if !in_single && !in_double => {
374                if paren_depth == 0 {
375                    break; // End at column-level closing paren
376                }
377                paren_depth -= 1;
378            }
379            ',' if !in_single && !in_double && paren_depth == 0 => break,
380            '\n' | '\r' if !in_single && !in_double && paren_depth == 0 => break,
381            _ => {}
382        }
383        end = i + c.len_utf8();
384    }
385
386    if end == 0 {
387        Err(nom::Err::Error(nom::error::Error::new(
388            input,
389            nom::error::ErrorKind::TakeWhile1,
390        )))
391    } else {
392        Ok((&input[end..], &input[..end]))
393    }
394}
395
/// Given the text just after a `check(`, return the byte offset of the `)` that closes
/// it. Nested parentheses are counted, and anything inside single- or double-quoted
/// sections (with doubled-quote escapes) is ignored.
///
/// Returns `rest.len()` when the expression is never closed.
fn check_expr_end(rest: &str) -> usize {
    let mut open = 1usize;
    let mut quote: Option<char> = None;
    let mut chars = rest.char_indices().peekable();

    while let Some((pos, ch)) = chars.next() {
        match (quote, ch) {
            (Some(q), c) if c == q => {
                if chars.peek().is_some_and(|&(_, next)| next == q) {
                    chars.next(); // doubled quote = escaped quote, stay inside
                } else {
                    quote = None;
                }
            }
            (Some(_), _) => {}
            (None, '\'' | '"') => quote = Some(ch),
            (None, '(') => open += 1,
            (None, ')') => {
                open -= 1;
                if open == 0 {
                    return pos;
                }
            }
            (None, _) => {}
        }
    }

    rest.len()
}
431
432fn checked_check_expr_end(rest: &str) -> Option<usize> {
433    let end = check_expr_end(rest);
434    (end < rest.len()).then_some(end)
435}
436
437fn parenthesized_content(input: &str) -> IResult<&str, &str> {
438    let mut paren_depth = 0usize;
439    let mut in_single = false;
440    let mut in_double = false;
441    let mut iter = input.char_indices().peekable();
442
443    while let Some((idx, ch)) = iter.next() {
444        match ch {
445            '\'' if !in_double => {
446                if in_single && matches!(iter.peek(), Some((_, '\''))) {
447                    iter.next();
448                } else {
449                    in_single = !in_single;
450                }
451            }
452            '"' if !in_single => {
453                if in_double && matches!(iter.peek(), Some((_, '"'))) {
454                    iter.next();
455                } else {
456                    in_double = !in_double;
457                }
458            }
459            '(' if !in_single && !in_double => paren_depth += 1,
460            ')' if !in_single && !in_double => {
461                if paren_depth == 0 {
462                    return Ok((&input[idx + ch.len_utf8()..], &input[..idx]));
463                }
464                paren_depth -= 1;
465            }
466            _ => {}
467        }
468    }
469
470    Err(nom::Err::Error(nom::error::Error::new(
471        input,
472        nom::error::ErrorKind::Char,
473    )))
474}
475
/// Split `input` on commas that sit outside quotes and parentheses, trimming each piece.
///
/// Blank input yields an empty vector. Returns `Err(())` for:
/// - an empty piece (`a,,b`, `,a`, `a,`);
/// - a stray `)`;
/// - an unclosed `(`;
/// - an unterminated quote.
fn split_top_level_csv(input: &str) -> Result<Vec<String>, ()> {
    let mut pieces = Vec::new();
    let mut piece_start = 0usize;
    let mut depth = 0usize;
    let mut quote: Option<char> = None;
    let mut chars = input.char_indices().peekable();

    while let Some((pos, ch)) = chars.next() {
        match (quote, ch) {
            (Some(q), c) if c == q => {
                if chars.peek().is_some_and(|&(_, next)| next == q) {
                    chars.next();
                } else {
                    quote = None;
                }
            }
            (Some(_), _) => {}
            (None, '\'' | '"') => quote = Some(ch),
            (None, '(') => depth += 1,
            (None, ')') => depth = depth.checked_sub(1).ok_or(())?,
            (None, ',') if depth == 0 => {
                let piece = input[piece_start..pos].trim();
                if piece.is_empty() {
                    return Err(());
                }
                pieces.push(piece.to_owned());
                piece_start = pos + 1;
            }
            (None, _) => {}
        }
    }

    if quote.is_some() || depth > 0 {
        return Err(());
    }

    let last = input[piece_start..].trim();
    if !last.is_empty() {
        pieces.push(last.to_owned());
    } else if !input.trim().is_empty() {
        return Err(()); // trailing comma
    }

    Ok(pieces)
}
533
/// Returns `true` when `input` begins with a column-constraint keyword that
/// `parse_column_constraints` understands.
///
/// The keywords are `primary_key`, `primary key`, `not_null`, `not null`, `unique`,
/// `references`, `default` and `check (`, compared ASCII case-insensitively.
///
/// `default_expr_end` uses this to decide where a `default` expression stops, so the
/// matching rules mirror how `parse_column_constraints` strips keywords:
/// - every keyword must end at a word boundary (`uniqueness` is not `unique`);
/// - two-word forms accept any run of whitespace between the words (`not   null`);
/// - `check` may be separated from its `(` by whitespace (`check (x > 0)`).
///
/// Including `default` makes a repeated `default` clause surface as a duplicate
/// instead of being absorbed into the first default's expression.
fn starts_constraint_keyword(input: &str) -> bool {
    /// Strip `keyword` from the front of `s` when it matches ASCII case-insensitively
    /// and is not immediately followed by an identifier character.
    fn word<'a>(s: &'a str, keyword: &str) -> Option<&'a str> {
        // `get` yields `None` instead of panicking when the length splits a UTF-8 char.
        let head = s.get(..keyword.len())?;
        if !head.eq_ignore_ascii_case(keyword) {
            return None;
        }
        let tail = &s[keyword.len()..];
        let continues_ident =
            tail.starts_with(|c: char| c.is_ascii_alphanumeric() || c == '_');
        (!continues_ident).then_some(tail)
    }

    let pair = |first: &str, second: &str| {
        word(input, first)
            .and_then(|rest| word(rest.trim_start(), second))
            .is_some()
    };

    word(input, "primary_key").is_some()
        || pair("primary", "key")
        || word(input, "not_null").is_some()
        || pair("not", "null")
        || word(input, "unique").is_some()
        || word(input, "references").is_some()
        || word(input, "default").is_some()
        || word(input, "check").is_some_and(|rest| rest.trim_start().starts_with('('))
}
547
548fn default_expr_end(rest: &str) -> usize {
549    let mut in_single = false;
550    let mut in_double = false;
551    let mut paren_depth = 0usize;
552    let mut iter = rest.char_indices().peekable();
553
554    while let Some((idx, ch)) = iter.next() {
555        match ch {
556            '\'' if !in_double => {
557                if in_single && matches!(iter.peek(), Some((_, '\''))) {
558                    iter.next();
559                } else {
560                    in_single = !in_single;
561                }
562            }
563            '"' if !in_single => {
564                if in_double && matches!(iter.peek(), Some((_, '"'))) {
565                    iter.next();
566                } else {
567                    in_double = !in_double;
568                }
569            }
570            '(' if !in_single && !in_double => paren_depth += 1,
571            ')' if !in_single && !in_double && paren_depth > 0 => paren_depth -= 1,
572            c if c.is_whitespace()
573                && !in_single
574                && !in_double
575                && paren_depth == 0
576                && starts_constraint_keyword(rest[idx..].trim_start()) =>
577            {
578                return idx;
579            }
580            _ => {}
581        }
582    }
583
584    rest.len()
585}
586
587/// Parse a single column definition
588fn parse_column(input: &str) -> IResult<&str, ColumnDef> {
589    let (input, _) = ws_and_comments(input)?;
590    let (input, name) = identifier(input)?;
591    let (input, _) = multispace1(input)?;
592    let (input, type_info) = parse_type_info(input)?;
593
594    let (input, constraint_str) = opt(preceded(multispace1, constraint_text)).parse(input)?;
595
596    let mut col = ColumnDef {
597        name: name.to_string(),
598        typ: type_info.name,
599        is_array: type_info.is_array,
600        type_params: type_info.params,
601        is_serial: type_info.is_serial,
602        nullable: !type_info.is_serial, // Serial types are implicitly not null
603        ..Default::default()
604    };
605
606    if let Some(constraints) = constraint_str
607        && parse_column_constraints(&mut col, constraints).is_err()
608    {
609        return Err(nom::Err::Error(nom::error::Error::new(
610            constraints,
611            nom::error::ErrorKind::Verify,
612        )));
613    }
614
615    Ok((input, col))
616}
617
618fn parse_column_constraints(col: &mut ColumnDef, constraints: &str) -> Result<(), ()> {
619    let mut rest = constraints.trim();
620    while !rest.is_empty() {
621        if let Some(next) = strip_keyword_ci(rest, "primary_key") {
622            if col.primary_key {
623                return Err(());
624            }
625            col.primary_key = true;
626            col.nullable = false;
627            rest = next.trim_start();
628            continue;
629        }
630        if let Some(next) = strip_keyword_pair_ci(rest, "primary", "key") {
631            if col.primary_key {
632                return Err(());
633            }
634            col.primary_key = true;
635            col.nullable = false;
636            rest = next.trim_start();
637            continue;
638        }
639        if let Some(next) = strip_keyword_ci(rest, "not_null") {
640            col.nullable = false;
641            rest = next.trim_start();
642            continue;
643        }
644        if let Some(next) = strip_keyword_pair_ci(rest, "not", "null") {
645            col.nullable = false;
646            rest = next.trim_start();
647            continue;
648        }
649        if let Some(next) = strip_keyword_ci(rest, "unique") {
650            if col.unique {
651                return Err(());
652            }
653            col.unique = true;
654            rest = next.trim_start();
655            continue;
656        }
657        if let Some(next) = strip_keyword_ci(rest, "references") {
658            if col.references.is_some() {
659                return Err(());
660            }
661            let next = next.trim_start();
662            let (target, tail) = parse_reference_constraint_target(next)?;
663            if target.is_empty() {
664                return Err(());
665            }
666            col.references = Some(target.to_string());
667            rest = tail.trim_start();
668            continue;
669        }
670        if let Some(next) = strip_keyword_ci(rest, "default") {
671            if col.default_value.is_some() {
672                return Err(());
673            }
674            let next = next.trim_start();
675            if next.is_empty() {
676                return Err(());
677            }
678            let end = default_expr_end(next);
679            if end == 0 {
680                return Err(());
681            }
682            col.default_value = Some(next[..end].trim_end().to_string());
683            rest = next[end..].trim_start();
684            continue;
685        }
686        if let Some(next) = strip_keyword_ci(rest, "check") {
687            if col.check.is_some() {
688                return Err(());
689            }
690            let next = next.trim_start();
691            let Some(after_open) = next.strip_prefix('(') else {
692                return Err(());
693            };
694            let Some(end) = checked_check_expr_end(after_open) else {
695                return Err(());
696            };
697            let expr = after_open[..end].trim();
698            if expr.is_empty() {
699                return Err(());
700            }
701            col.check = Some(expr.to_string());
702            rest = after_open[end + 1..].trim_start();
703            continue;
704        }
705
706        return Err(());
707    }
708
709    Ok(())
710}
711
712fn strip_keyword_pair_ci<'a>(input: &'a str, first: &str, second: &str) -> Option<&'a str> {
713    let rest = strip_keyword_ci(input, first)?.trim_start();
714    strip_keyword_ci(rest, second)
715}
716
/// Strip `keyword` from the front of `input` when it matches ASCII case-insensitively
/// and is not immediately followed by an identifier character (`[A-Za-z0-9_]`).
///
/// Returns the remainder after the keyword, or `None` when there is no match.
///
/// Safe for arbitrary UTF-8 input. If `keyword.len()` lands inside a multi-byte
/// character of `input`, the result is a non-match rather than a panic. Previously
/// `split_at` panicked there, e.g. on column constraint text made of accented letters.
fn strip_keyword_ci<'a>(input: &'a str, keyword: &str) -> Option<&'a str> {
    // `str::get` returns `None` for an out-of-range or non-char-boundary end index.
    let head = input.get(..keyword.len())?;
    if !head.eq_ignore_ascii_case(keyword) {
        return None;
    }
    let tail = &input[keyword.len()..];
    if tail.starts_with(|ch: char| ch.is_ascii_alphanumeric() || ch == '_') {
        return None;
    }
    Some(tail)
}
734
/// Split the text after `references` into `(target, remainder)`.
///
/// The target is a table name with an optional parenthesised column list, e.g.
/// `users(id)`. Whitespace between a non-empty table name and its column list is
/// accepted, so `users (id)` is captured as `users (id)`. That form is valid SQL and
/// was previously rejected.
///
/// The target ends at the first top-level whitespace not followed by that column list,
/// or at an unmatched `)`. Returns `Err(())` when parentheses are left unclosed.
fn parse_reference_constraint_target(input: &str) -> Result<(&str, &str), ()> {
    let mut paren_depth = 0usize;
    let mut seen_column_list = false;
    let mut end = input.len();

    for (idx, ch) in input.char_indices() {
        match ch {
            '(' => paren_depth += 1,
            ')' => {
                if paren_depth == 0 {
                    end = idx;
                    break;
                }
                paren_depth -= 1;
                if paren_depth == 0 {
                    seen_column_list = true;
                }
            }
            c if c.is_whitespace() && paren_depth == 0 => {
                // Keep going through `table (cols)`; stop before anything else.
                let continues_with_columns =
                    idx > 0 && !seen_column_list && input[idx..].trim_start().starts_with('(');
                if !continues_with_columns {
                    end = idx;
                    break;
                }
            }
            _ => {}
        }
    }

    if paren_depth != 0 {
        return Err(());
    }
    Ok((&input[..end], &input[end..]))
}
760
761/// Parse column list: (col1 type, col2 type, ...)
762fn parse_column_list(input: &str) -> IResult<&str, Vec<ColumnDef>> {
763    let (input, _) = ws_and_comments(input)?;
764    let (input, _) = char('(').parse(input)?;
765    let (input, columns) = separated_list0(char(','), parse_column).parse(input)?;
766    let (input, _) = ws_and_comments(input)?;
767    let (input, _) = char(')').parse(input)?;
768
769    Ok((input, columns))
770}
771
772/// Parse a table definition
773fn parse_table(input: &str) -> IResult<&str, TableDef> {
774    let (input, _) = ws_and_comments(input)?;
775    let (input, _) = tag_no_case("table").parse(input)?;
776    let (input, _) = multispace1(input)?;
777    let (input, name) = identifier(input)?;
778    let (input, columns) = parse_column_list(input)?;
779    if columns.is_empty() || has_duplicate_column_names(&columns) {
780        return Err(nom::Err::Error(nom::error::Error::new(
781            input,
782            nom::error::ErrorKind::Verify,
783        )));
784    }
785
786    // Optional enable_rls annotation after closing paren
787    let (input, _) = ws_and_comments(input)?;
788    let enable_rls = if let Ok((rest, _)) =
789        tag_no_case::<_, _, nom::error::Error<&str>>("enable_rls").parse(input)
790    {
791        return Ok((
792            rest,
793            TableDef {
794                name: name.to_string(),
795                columns,
796                enable_rls: true,
797            },
798        ));
799    } else {
800        false
801    };
802
803    Ok((
804        input,
805        TableDef {
806            name: name.to_string(),
807            columns,
808            enable_rls,
809        },
810    ))
811}
812
813fn has_duplicate_column_names(columns: &[ColumnDef]) -> bool {
814    let mut seen = HashSet::new();
815    columns
816        .iter()
817        .any(|column| !seen.insert(column.name.to_ascii_lowercase()))
818}
819
820// =============================================================================
821// Policy Parsing
822// =============================================================================
823
824/// A schema item is either a table, policy, or index.
825enum SchemaItem {
826    Table(TableDef),
827    Policy(Box<RlsPolicy>),
828    Index(IndexDef),
829}
830
831/// Parse a policy definition.
832///
833/// Syntax:
834/// ```text
835/// policy <name> on <table>
836///     [for all|select|insert|update|delete]
837///     [restrictive]
838///     [to <role>]
839///     [using (<expr>)]
840///     [with check (<expr>)]
841/// ```
842fn parse_policy(input: &str) -> IResult<&str, RlsPolicy> {
843    let (input, _) = ws_and_comments(input)?;
844    let (input, _) = tag_no_case("policy").parse(input)?;
845    let (input, _) = multispace1(input)?;
846    let (input, name) = identifier(input)?;
847    let (input, _) = multispace1(input)?;
848    let (input, _) = tag_no_case("on").parse(input)?;
849    let (input, _) = multispace1(input)?;
850    let (input, table) = identifier(input)?;
851
852    let mut policy = RlsPolicy::create(name, table);
853
854    // Parse optional clauses in any order
855    let mut remaining = input;
856    let mut seen_for = false;
857    let mut seen_restrictive = false;
858    let mut seen_role = false;
859    let mut seen_using = false;
860    let mut seen_with_check = false;
861    loop {
862        let (input, _) = ws_and_comments(remaining)?;
863
864        // for all|select|insert|update|delete
865        if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("for").parse(input) {
866            if seen_for {
867                return Err(nom::Err::Error(nom::error::Error::new(
868                    input,
869                    nom::error::ErrorKind::Verify,
870                )));
871            }
872            seen_for = true;
873            let (rest, _) = multispace1(rest)?;
874            let (rest, target) = alt((
875                map(tag_no_case("all"), |_| PolicyTarget::All),
876                map(tag_no_case("select"), |_| PolicyTarget::Select),
877                map(tag_no_case("insert"), |_| PolicyTarget::Insert),
878                map(tag_no_case("update"), |_| PolicyTarget::Update),
879                map(tag_no_case("delete"), |_| PolicyTarget::Delete),
880            ))
881            .parse(rest)?;
882            policy.target = target;
883            remaining = rest;
884            continue;
885        }
886
887        // restrictive
888        if let Ok((rest, _)) =
889            tag_no_case::<_, _, nom::error::Error<&str>>("restrictive").parse(input)
890        {
891            if seen_restrictive {
892                return Err(nom::Err::Error(nom::error::Error::new(
893                    input,
894                    nom::error::ErrorKind::Verify,
895                )));
896            }
897            seen_restrictive = true;
898            policy.permissiveness = PolicyPermissiveness::Restrictive;
899            remaining = rest;
900            continue;
901        }
902
903        // to <role>
904        if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("to").parse(input) {
905            // Make sure it's not "to_sql" or similar — needs whitespace after
906            if let Ok((rest, _)) = multispace1::<_, nom::error::Error<&str>>(rest) {
907                if seen_role {
908                    return Err(nom::Err::Error(nom::error::Error::new(
909                        input,
910                        nom::error::ErrorKind::Verify,
911                    )));
912                }
913                seen_role = true;
914                let (rest, role) = identifier(rest)?;
915                policy.role = Some(role.to_string());
916                remaining = rest;
917                continue;
918            }
919        }
920
921        // with check (<expr>)
922        if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("with").parse(input) {
923            let (rest, _) = multispace1(rest)?;
924            if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("check").parse(rest)
925            {
926                if seen_with_check {
927                    return Err(nom::Err::Error(nom::error::Error::new(
928                        input,
929                        nom::error::ErrorKind::Verify,
930                    )));
931                }
932                seen_with_check = true;
933                let (rest, _) = nom_ws0(rest)?;
934                let (rest, _) = char('(').parse(rest)?;
935                let (rest, _) = nom_ws0(rest)?;
936                let (rest, expr) = parse_policy_expr(rest)?;
937                let (rest, _) = nom_ws0(rest)?;
938                let (rest, _) = char(')').parse(rest)?;
939                policy.with_check = Some(expr);
940                remaining = rest;
941                continue;
942            }
943        }
944
945        // using (<expr>)
946        if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("using").parse(input) {
947            if seen_using {
948                return Err(nom::Err::Error(nom::error::Error::new(
949                    input,
950                    nom::error::ErrorKind::Verify,
951                )));
952            }
953            seen_using = true;
954            let (rest, _) = nom_ws0(rest)?;
955            let (rest, _) = char('(').parse(rest)?;
956            let (rest, _) = nom_ws0(rest)?;
957            let (rest, expr) = parse_policy_expr(rest)?;
958            let (rest, _) = nom_ws0(rest)?;
959            let (rest, _) = char(')').parse(rest)?;
960            policy.using = Some(expr);
961            remaining = rest;
962            continue;
963        }
964
965        // No more clauses matched
966        remaining = input;
967        break;
968    }
969
970    Ok((remaining, policy))
971}
972
973/// Parse a policy expression: `left op right [AND/OR left op right ...]`
974///
975/// Produces typed `Expr::Binary` AST nodes — no raw SQL.
976///
977/// Handles:
978/// - `column = value`
979/// - `column = function('arg')::type`   (function call + cast)
980/// - `expr AND expr`, `expr OR expr`
981fn parse_policy_expr(input: &str) -> IResult<&str, Expr> {
982    parse_policy_or_expr(input)
983}
984
985fn parse_policy_or_expr(input: &str) -> IResult<&str, Expr> {
986    let (input, first) = parse_policy_and_expr(input)?;
987
988    let mut result = first;
989    let mut remaining = input;
990    loop {
991        let (input, _) = nom_ws0(remaining)?;
992
993        if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("or").parse(input)
994            && let Ok((rest, _)) = multispace1::<_, nom::error::Error<&str>>(rest)
995        {
996            let (rest, right) = parse_policy_and_expr(rest)?;
997            result = Expr::Binary {
998                left: Box::new(result),
999                op: BinaryOp::Or,
1000                right: Box::new(right),
1001                alias: None,
1002            };
1003            remaining = rest;
1004            continue;
1005        }
1006
1007        remaining = input;
1008        break;
1009    }
1010
1011    Ok((remaining, result))
1012}
1013
1014fn parse_policy_and_expr(input: &str) -> IResult<&str, Expr> {
1015    let (input, first) = parse_policy_comparison(input)?;
1016
1017    let mut result = first;
1018    let mut remaining = input;
1019    loop {
1020        let (input, _) = nom_ws0(remaining)?;
1021
1022        if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("and").parse(input)
1023            && let Ok((rest, _)) = multispace1::<_, nom::error::Error<&str>>(rest)
1024        {
1025            let (rest, right) = parse_policy_comparison(rest)?;
1026            result = Expr::Binary {
1027                left: Box::new(result),
1028                op: BinaryOp::And,
1029                right: Box::new(right),
1030                alias: None,
1031            };
1032            remaining = rest;
1033            continue;
1034        }
1035
1036        remaining = input;
1037        break;
1038    }
1039
1040    Ok((remaining, result))
1041}
1042
1043/// Parse a single comparison: `atom op atom`
1044fn parse_policy_comparison(input: &str) -> IResult<&str, Expr> {
1045    let (input, left) = parse_policy_atom(input)?;
1046    let (input, _) = nom_ws0(input)?;
1047
1048    // Try to parse comparison operator
1049    if let Ok((rest, op)) = parse_cmp_op(input) {
1050        let (rest, _) = nom_ws0(rest)?;
1051        let (rest, right) = parse_policy_atom(rest)?;
1052        return Ok((
1053            rest,
1054            Expr::Binary {
1055                left: Box::new(left),
1056                op,
1057                right: Box::new(right),
1058                alias: None,
1059            },
1060        ));
1061    }
1062
1063    // No operator — just an atom
1064    Ok((input, left))
1065}
1066
1067/// Parse comparison operators: =, !=, <>, >=, <=, >, <
1068fn parse_cmp_op(input: &str) -> IResult<&str, BinaryOp> {
1069    alt((
1070        map(tag(">="), |_| BinaryOp::Gte),
1071        map(tag("<="), |_| BinaryOp::Lte),
1072        map(tag("<>"), |_| BinaryOp::Ne),
1073        map(tag("!="), |_| BinaryOp::Ne),
1074        map(tag("="), |_| BinaryOp::Eq),
1075        map(tag(">"), |_| BinaryOp::Gt),
1076        map(tag("<"), |_| BinaryOp::Lt),
1077    ))
1078    .parse(input)
1079}
1080
1081/// Parse a policy expression atom:
1082/// - identifier  (column name)
1083/// - function_call(args)  with optional ::cast
1084/// - 'string literal'
1085/// - numeric literal
1086/// - true/false
1087/// - (sub_expr)  grouped
1088fn parse_policy_atom(input: &str) -> IResult<&str, Expr> {
1089    alt((
1090        parse_policy_grouped,
1091        parse_policy_bool,
1092        parse_policy_string,
1093        parse_policy_number,
1094        parse_policy_func_or_ident, // function call or plain identifier, with optional ::cast
1095    ))
1096    .parse(input)
1097}
1098
1099/// Parse grouped expression in parens
1100fn parse_policy_grouped(input: &str) -> IResult<&str, Expr> {
1101    let (input, _) = char('(').parse(input)?;
1102    let (input, _) = nom_ws0(input)?;
1103    let (input, expr) = parse_policy_expr(input)?;
1104    let (input, _) = nom_ws0(input)?;
1105    let (input, _) = char(')').parse(input)?;
1106    Ok((input, expr))
1107}
1108
1109/// Parse true / false
1110fn parse_policy_bool(input: &str) -> IResult<&str, Expr> {
1111    alt((
1112        map(tag_no_case("true"), |_| Expr::Literal(AstValue::Bool(true))),
1113        map(tag_no_case("false"), |_| {
1114            Expr::Literal(AstValue::Bool(false))
1115        }),
1116    ))
1117    .parse(input)
1118}
1119
1120/// Parse a 'string literal'
1121fn parse_policy_string(input: &str) -> IResult<&str, Expr> {
1122    let (input, _) = char('\'').parse(input)?;
1123    let mut content = String::new();
1124    let mut iter = input.char_indices().peekable();
1125    while let Some((idx, ch)) = iter.next() {
1126        if ch == '\'' {
1127            if iter.peek().is_some_and(|(_, next)| *next == '\'') {
1128                content.push('\'');
1129                iter.next();
1130                continue;
1131            }
1132            let rest = &input[idx + ch.len_utf8()..];
1133            return Ok((rest, Expr::Literal(AstValue::String(content))));
1134        }
1135        content.push(ch);
1136    }
1137
1138    Err(nom::Err::Error(nom::error::Error::new(
1139        input,
1140        nom::error::ErrorKind::Char,
1141    )))
1142}
1143
1144/// Parse numeric literal
1145fn parse_policy_number(input: &str) -> IResult<&str, Expr> {
1146    let original = input;
1147    let (input, digits) = take_while1(|c: char| c.is_ascii_digit() || c == '.')(input)?;
1148    // Make sure it starts with digit (not just '.')
1149    if digits.starts_with('.') || digits.is_empty() {
1150        return Err(nom::Err::Error(nom::error::Error::new(
1151            original,
1152            nom::error::ErrorKind::Digit,
1153        )));
1154    }
1155
1156    if !digits.contains('.') {
1157        return digits
1158            .parse::<i64>()
1159            .map(|n| (input, Expr::Literal(AstValue::Int(n))))
1160            .map_err(|_| {
1161                nom::Err::Error(nom::error::Error::new(
1162                    original,
1163                    nom::error::ErrorKind::Digit,
1164                ))
1165            });
1166    }
1167
1168    if digits.matches('.').count() > 1 || policy_number_significant_digits(digits) > 15 {
1169        return Err(nom::Err::Error(nom::error::Error::new(
1170            original,
1171            nom::error::ErrorKind::Float,
1172        )));
1173    }
1174
1175    if let Ok(f) = digits.parse::<f64>() {
1176        if f.is_finite() {
1177            Ok((input, Expr::Literal(AstValue::Float(f))))
1178        } else {
1179            Err(nom::Err::Error(nom::error::Error::new(
1180                original,
1181                nom::error::ErrorKind::Float,
1182            )))
1183        }
1184    } else {
1185        Err(nom::Err::Error(nom::error::Error::new(
1186            original,
1187            nom::error::ErrorKind::Float,
1188        )))
1189    }
1190}
1191
/// Count the significant decimal digits in a numeric lexeme.
///
/// Non-digit bytes (such as `.`) are ignored, and leading zeros are not counted.
/// Every digit from the first non-zero one onward is counted, trailing zeros included.
/// E.g. `"0.00120"` → 3, `"100.5"` → 4, `"0.0"` → 0.
fn policy_number_significant_digits(value: &str) -> usize {
    value
        .bytes()
        .filter(u8::is_ascii_digit)
        .skip_while(|&digit| digit == b'0')
        .count()
}
1210
1211/// Parse function call or identifier, with optional ::cast
1212fn parse_policy_func_or_ident(input: &str) -> IResult<&str, Expr> {
1213    let original = input;
1214    let (input, name) = identifier(input)?;
1215    if name
1216        .bytes()
1217        .next()
1218        .is_some_and(|byte| byte.is_ascii_digit())
1219    {
1220        return Err(nom::Err::Error(nom::error::Error::new(
1221            original,
1222            nom::error::ErrorKind::Alpha,
1223        )));
1224    }
1225
1226    // Check for function call: name(
1227    let mut expr = if let Ok((rest, _)) = char::<_, nom::error::Error<&str>>('(').parse(input) {
1228        // Parse args
1229        let (rest, _) = nom_ws0(rest)?;
1230        let (rest, args) =
1231            separated_list0((nom_ws0, char(','), nom_ws0), parse_policy_atom).parse(rest)?;
1232        let (rest, _) = nom_ws0(rest)?;
1233        let (rest, _) = char(')').parse(rest)?;
1234        let input = rest;
1235        (
1236            input,
1237            Expr::FunctionCall {
1238                name: name.to_string(),
1239                args,
1240                alias: None,
1241            },
1242        )
1243    } else {
1244        (input, Expr::Named(name.to_string()))
1245    };
1246
1247    // Check for ::cast
1248    if let Ok((rest, _)) = tag::<_, _, nom::error::Error<&str>>("::").parse(expr.0) {
1249        let (rest, cast_type) = identifier(rest)?;
1250        expr = (
1251            rest,
1252            Expr::Cast {
1253                expr: Box::new(expr.1),
1254                target_type: cast_type.to_string(),
1255                alias: None,
1256            },
1257        );
1258    }
1259
1260    Ok(expr)
1261}
1262
1263/// Parse a single schema item: table, policy, or index
1264fn parse_schema_item(input: &str) -> IResult<&str, SchemaItem> {
1265    let (input, _) = ws_and_comments(input)?;
1266
1267    // Try policy first (since "policy" is a distinct keyword)
1268    if let Ok((rest, policy)) = parse_policy(input) {
1269        return Ok((rest, SchemaItem::Policy(Box::new(policy))));
1270    }
1271
1272    // Try index
1273    if let Ok((rest, idx)) = parse_index(input) {
1274        return Ok((rest, SchemaItem::Index(idx)));
1275    }
1276
1277    // Otherwise parse table
1278    let (rest, table) = parse_table(input)?;
1279    Ok((rest, SchemaItem::Table(table)))
1280}
1281
1282/// Parse an index line: `index <name> on <table> (<col1>, <col2>) [unique]`
1283fn parse_index(input: &str) -> IResult<&str, IndexDef> {
1284    let (input, _) = tag_no_case("index")(input)?;
1285    let (input, _) = multispace1(input)?;
1286    let (input, name) = identifier(input)?;
1287    let (input, _) = multispace1(input)?;
1288    let (input, _) = tag_no_case("on")(input)?;
1289    let (input, _) = multispace1(input)?;
1290    let (input, table) = identifier(input)?;
1291    let (input, _) = nom_ws0(input)?;
1292    let (input, _) = char('(')(input)?;
1293    let (input, cols_str) = parenthesized_content(input)?;
1294    let (input, _) = nom_ws0(input)?;
1295    let (input, unique_tag) = opt(tag_no_case("unique")).parse(input)?;
1296
1297    let columns = split_top_level_csv(cols_str).map_err(|_| {
1298        nom::Err::Error(nom::error::Error::new(
1299            cols_str,
1300            nom::error::ErrorKind::SeparatedList,
1301        ))
1302    })?;
1303    if columns.is_empty() {
1304        return Err(nom::Err::Error(nom::error::Error::new(
1305            cols_str,
1306            nom::error::ErrorKind::SeparatedList,
1307        )));
1308    }
1309
1310    let is_unique = unique_tag.is_some();
1311
1312    Ok((
1313        input,
1314        IndexDef {
1315            name: name.to_string(),
1316            table: table.to_string(),
1317            columns,
1318            unique: is_unique,
1319        },
1320    ))
1321}
1322
1323/// Parse complete schema file
1324fn parse_schema(input: &str) -> IResult<&str, Schema> {
1325    // Extract version directive before parsing
1326    let version = extract_version_directive(input);
1327
1328    let (input, items) = many0(parse_schema_item).parse(input)?;
1329    let (input, _) = ws_and_comments(input)?;
1330
1331    let mut tables = Vec::new();
1332    let mut policies = Vec::new();
1333    let mut indexes = Vec::new();
1334    for item in items {
1335        match item {
1336            SchemaItem::Table(t) => tables.push(t),
1337            SchemaItem::Policy(p) => policies.push(*p),
1338            SchemaItem::Index(i) => indexes.push(i),
1339        }
1340    }
1341    if !schema_names_are_unique(&tables, &policies, &indexes) {
1342        return Err(nom::Err::Error(nom::error::Error::new(
1343            input,
1344            nom::error::ErrorKind::Verify,
1345        )));
1346    }
1347
1348    Ok((
1349        input,
1350        Schema {
1351            version,
1352            tables,
1353            policies,
1354            indexes,
1355        },
1356    ))
1357}
1358
1359fn schema_names_are_unique(
1360    tables: &[TableDef],
1361    policies: &[RlsPolicy],
1362    indexes: &[IndexDef],
1363) -> bool {
1364    let mut table_names = HashSet::new();
1365    if tables
1366        .iter()
1367        .any(|table| !table_names.insert(table.name.to_ascii_lowercase()))
1368    {
1369        return false;
1370    }
1371
1372    let mut index_names = HashSet::new();
1373    if indexes
1374        .iter()
1375        .any(|index| !index_names.insert(index.name.to_ascii_lowercase()))
1376    {
1377        return false;
1378    }
1379
1380    let mut policy_names = HashSet::new();
1381    !policies.iter().any(|policy| {
1382        !policy_names.insert((
1383            policy.table.to_ascii_lowercase(),
1384            policy.name.to_ascii_lowercase(),
1385        ))
1386    })
1387}
1388
/// Extract the schema version from a `-- qail: version=N` directive.
///
/// Lines are trimmed before matching, and whitespace around `version=` and the
/// number is tolerated. Only the first directive whose body starts with `version=`
/// counts. If its value is not a valid `u32`, the result is `None` and later
/// directives are not consulted.
fn extract_version_directive(input: &str) -> Option<u32> {
    let raw = input.lines().find_map(|line| {
        line.trim()
            .strip_prefix("-- qail:")?
            .trim()
            .strip_prefix("version=")
    })?;
    raw.trim().parse().ok()
}
1402
1403#[cfg(test)]
1404mod tests {
1405    use super::*;
1406
1407    #[test]
1408    fn test_parse_simple_table() {
1409        let input = r#"
1410            table users (
1411                id uuid primary_key,
1412                email text not null,
1413                name text
1414            )
1415        "#;
1416
1417        let schema = Schema::parse(input).expect("parse failed");
1418        assert_eq!(schema.tables.len(), 1);
1419
1420        let users = &schema.tables[0];
1421        assert_eq!(users.name, "users");
1422        assert_eq!(users.columns.len(), 3);
1423
1424        let id = &users.columns[0];
1425        assert_eq!(id.name, "id");
1426        assert_eq!(id.typ, "uuid");
1427        assert!(id.primary_key);
1428        assert!(!id.nullable);
1429
1430        let email = &users.columns[1];
1431        assert_eq!(email.name, "email");
1432        assert!(!email.nullable);
1433
1434        let name = &users.columns[2];
1435        assert!(name.nullable);
1436    }
1437
1438    #[test]
1439    fn test_parse_multiple_tables() {
1440        let input = r#"
1441            -- Users table
1442            table users (
1443                id uuid primary_key,
1444                email text not null unique
1445            )
1446            
1447            -- Orders table
1448            table orders (
1449                id uuid primary_key,
1450                user_id uuid references users(id),
1451                total i64 not null default 0
1452            )
1453        "#;
1454
1455        let schema = Schema::parse(input).expect("parse failed");
1456        assert_eq!(schema.tables.len(), 2);
1457
1458        let orders = schema.find_table("orders").expect("orders not found");
1459        let user_id = orders.find_column("user_id").expect("user_id not found");
1460        assert_eq!(user_id.references, Some("users(id)".to_string()));
1461
1462        let total = orders.find_column("total").expect("total not found");
1463        assert_eq!(total.default_value, Some("0".to_string()));
1464    }
1465
1466    #[test]
1467    fn test_parse_comments() {
1468        let input = r#"
1469            -- This is a comment
1470            table foo (
1471                bar text
1472            )
1473        "#;
1474
1475        let schema = Schema::parse(input).expect("parse failed");
1476        assert_eq!(schema.tables.len(), 1);
1477    }
1478
1479    #[test]
1480    fn test_array_types() {
1481        let input = r#"
1482            table products (
1483                id uuid primary_key,
1484                tags text[],
1485                prices decimal[]
1486            )
1487        "#;
1488
1489        let schema = Schema::parse(input).expect("parse failed");
1490        let products = &schema.tables[0];
1491
1492        let tags = products.find_column("tags").expect("tags not found");
1493        assert_eq!(tags.typ, "text");
1494        assert!(tags.is_array);
1495
1496        let prices = products.find_column("prices").expect("prices not found");
1497        assert!(prices.is_array);
1498    }
1499
1500    #[test]
1501    fn test_type_params() {
1502        let input = r#"
1503            table items (
1504                id serial primary_key,
1505                name varchar(255) not null,
1506                price decimal(10,2),
1507                code varchar(50) unique
1508            )
1509        "#;
1510
1511        let schema = Schema::parse(input).expect("parse failed");
1512        let items = &schema.tables[0];
1513
1514        let id = items.find_column("id").expect("id not found");
1515        assert!(id.is_serial);
1516        assert!(!id.nullable); // Serial is implicitly not null
1517
1518        let name = items.find_column("name").expect("name not found");
1519        assert_eq!(name.typ, "varchar");
1520        assert_eq!(name.type_params, Some(vec!["255".to_string()]));
1521
1522        let price = items.find_column("price").expect("price not found");
1523        assert_eq!(
1524            price.type_params,
1525            Some(vec!["10".to_string(), "2".to_string()])
1526        );
1527
1528        let code = items.find_column("code").expect("code not found");
1529        assert!(code.unique);
1530    }
1531
1532    #[test]
1533    fn test_rejects_invalid_identifiers_in_schema_shapes() {
1534        for input in [
1535            "table 1users (id uuid)",
1536            "table users (1id uuid)",
1537            "table users (id 1uuid)",
1538            "index 1idx on users (id)",
1539            "index idx on 1users (id)",
1540            "policy 1policy on users using (id = 1)",
1541            "policy users_policy on 1users using (id = 1)",
1542        ] {
1543            Schema::parse(input).expect_err("invalid identifier must fail");
1544        }
1545    }
1546
1547    #[test]
1548    fn test_rejects_empty_tables_and_duplicate_schema_objects() {
1549        for input in [
1550            "table empty ()",
1551            "table users (id uuid, id text)",
1552            "table users (id uuid)\ntable users (id text)",
1553            "index idx_users on users (id)\nindex idx_users on users (email)",
1554            "policy users_filter on users using (id = 1)\npolicy users_filter on users using (id = 2)",
1555        ] {
1556            Schema::parse(input).expect_err("duplicate or empty schema object must fail");
1557        }
1558    }
1559
1560    #[test]
1561    fn test_rejects_empty_type_parameters() {
1562        for input in [
1563            "table invoices (amount decimal())",
1564            "table invoices (amount decimal(10,))",
1565            "table invoices (amount decimal(,2))",
1566            "table invoices (amount decimal(10,,2))",
1567        ] {
1568            Schema::parse(input).expect_err("empty type parameter must fail");
1569        }
1570    }
1571
1572    #[test]
1573    fn test_custom_type_names_with_underscores_and_schema() {
1574        let input = r#"
1575            table bookings (
1576                id uuid primary_key,
1577                status booking_status not null,
1578                gateway_state integrations.payment_state[]
1579            )
1580        "#;
1581
1582        let schema = Schema::parse(input).expect("parse failed");
1583        let bookings = &schema.tables[0];
1584
1585        let status = bookings.find_column("status").expect("status not found");
1586        assert_eq!(status.typ, "booking_status");
1587        assert!(!status.nullable);
1588
1589        let gateway_state = bookings
1590            .find_column("gateway_state")
1591            .expect("gateway_state not found");
1592        assert_eq!(gateway_state.typ, "integrations.payment_state");
1593        assert!(gateway_state.is_array);
1594    }
1595
1596    #[test]
1597    fn test_malformed_type_params_return_parse_error_without_panic() {
1598        let input = "table invoices ( amount decimal(";
1599
1600        let result = std::panic::catch_unwind(|| Schema::parse(input));
1601
1602        assert!(result.is_ok());
1603        assert!(result.unwrap().is_err());
1604    }
1605
1606    #[test]
1607    fn test_check_constraint() {
1608        let input = r#"
1609            table employees (
1610                id uuid primary_key,
1611                age i32 check(age >= 18),
1612                salary decimal check(salary > 0)
1613            )
1614        "#;
1615
1616        let schema = Schema::parse(input).expect("parse failed");
1617        let employees = &schema.tables[0];
1618
1619        let age = employees.find_column("age").expect("age not found");
1620        assert_eq!(age.check, Some("age >= 18".to_string()));
1621
1622        let salary = employees.find_column("salary").expect("salary not found");
1623        assert_eq!(salary.check, Some("salary > 0".to_string()));
1624    }
1625
1626    #[test]
1627    fn test_default_expression_with_spaces() {
1628        let input = r#"
1629            table messages (
1630                id uuid primary_key,
1631                title text default 'new user' not null,
1632                expires_at timestamp default (now() + interval '1 day')
1633            )
1634        "#;
1635
1636        let schema = Schema::parse(input).expect("parse failed");
1637        let messages = &schema.tables[0];
1638
1639        let title = messages.find_column("title").expect("title not found");
1640        assert_eq!(title.default_value, Some("'new user'".to_string()));
1641        assert!(!title.nullable);
1642
1643        let expires_at = messages
1644            .find_column("expires_at")
1645            .expect("expires_at not found");
1646        assert_eq!(
1647            expires_at.default_value,
1648            Some("(now() + interval '1 day')".to_string())
1649        );
1650    }
1651
1652    #[test]
1653    fn test_constraints_handle_quoted_commas_and_parens() {
1654        let input = r#"
1655            table messages (
1656                id uuid primary_key,
1657                title text default 'hello, world' not null,
1658                tag text check(tag in ('a,b', 'c)')),
1659                note text default 'paren ) and comma, still literal'
1660            )
1661        "#;
1662
1663        let schema = Schema::parse(input).expect("parse failed");
1664        let messages = &schema.tables[0];
1665        assert_eq!(messages.columns.len(), 4);
1666
1667        let title = messages.find_column("title").expect("title not found");
1668        assert_eq!(title.default_value, Some("'hello, world'".to_string()));
1669        assert!(!title.nullable);
1670
1671        let tag = messages.find_column("tag").expect("tag not found");
1672        assert_eq!(tag.check, Some("tag in ('a,b', 'c)')".to_string()));
1673
1674        let note = messages.find_column("note").expect("note not found");
1675        assert_eq!(
1676            note.default_value,
1677            Some("'paren ) and comma, still literal'".to_string())
1678        );
1679    }
1680
1681    #[test]
1682    fn test_constraint_keywords_inside_literals_do_not_become_constraints() {
1683        let input = r#"
1684            table messages (
1685                plain text default 'unique not null primary key references users(id) check(x)',
1686                fn_default text default unique_label(),
1687                guarded text check(note = 'unique not null primary key')
1688            )
1689        "#;
1690
1691        let schema = Schema::parse(input).expect("parse failed");
1692        let messages = &schema.tables[0];
1693
1694        let plain = messages.find_column("plain").expect("plain not found");
1695        assert_eq!(
1696            plain.default_value.as_deref(),
1697            Some("'unique not null primary key references users(id) check(x)'")
1698        );
1699        assert!(!plain.unique);
1700        assert!(plain.nullable);
1701        assert!(!plain.primary_key);
1702        assert!(plain.references.is_none());
1703        assert!(plain.check.is_none());
1704
1705        let fn_default = messages
1706            .find_column("fn_default")
1707            .expect("fn_default not found");
1708        assert_eq!(fn_default.default_value.as_deref(), Some("unique_label()"));
1709        assert!(!fn_default.unique);
1710
1711        let guarded = messages.find_column("guarded").expect("guarded not found");
1712        assert_eq!(
1713            guarded.check.as_deref(),
1714            Some("note = 'unique not null primary key'")
1715        );
1716        assert!(!guarded.unique);
1717        assert!(guarded.nullable);
1718        assert!(!guarded.primary_key);
1719    }
1720
1721    #[test]
1722    fn test_rejects_malformed_column_constraints() {
1723        for input in [
1724            "table bad (name text default)",
1725            "table bad (user_id uuid references)",
1726            "table bad (age int check())",
1727            "table bad (age int check(age > 0)",
1728            "table bad (name text unique unique)",
1729            "table bad (id uuid primary_key primary key)",
1730            "table bad (user_id uuid references users(id) references accounts(id))",
1731        ] {
1732            Schema::parse(input).expect_err("malformed column constraint must fail");
1733        }
1734    }
1735
1736    #[test]
1737    fn test_index_columns_handle_nested_expression_commas() {
1738        let input = r#"
1739            table docs (
1740                id uuid primary_key,
1741                title text,
1742                slug text
1743            )
1744
1745            index idx_docs_search on docs (regexp_replace(title, ')', '', 'g'), lower(slug)) unique
1746        "#;
1747
1748        let schema = Schema::parse(input).expect("parse failed");
1749        assert_eq!(schema.indexes.len(), 1);
1750        let index = &schema.indexes[0];
1751        assert_eq!(index.name, "idx_docs_search");
1752        assert_eq!(
1753            index.columns,
1754            vec![
1755                "regexp_replace(title, ')', '', 'g')".to_string(),
1756                "lower(slug)".to_string()
1757            ]
1758        );
1759        assert!(index.unique);
1760        assert_eq!(
1761            index.to_sql(),
1762            "CREATE UNIQUE INDEX IF NOT EXISTS idx_docs_search ON docs (regexp_replace(title, ')', '', 'g'), lower(slug))"
1763        );
1764    }
1765
1766    #[test]
1767    fn test_index_rejects_empty_columns() {
1768        for input in [
1769            "index idx_docs_search on docs ()",
1770            "index idx_docs_search on docs (title,)",
1771            "index idx_docs_search on docs (,title)",
1772            "index idx_docs_search on docs (title,,slug)",
1773        ] {
1774            let err = Schema::parse(input).expect_err("empty index columns should fail");
1775            assert!(
1776                err.contains("Parse error") || err.contains("Unexpected content"),
1777                "{err}"
1778            );
1779        }
1780    }
1781
1782    #[test]
1783    fn test_version_directive() {
1784        let input = r#"
1785            -- qail: version=1
1786            table users (
1787                id uuid primary_key
1788            )
1789        "#;
1790
1791        let schema = Schema::parse(input).expect("parse failed");
1792        assert_eq!(schema.version, Some(1));
1793        assert_eq!(schema.tables.len(), 1);
1794
1795        // Without version directive
1796        let input_no_version = r#"
1797            table items (
1798                id uuid primary_key
1799            )
1800        "#;
1801        let schema2 = Schema::parse(input_no_version).expect("parse failed");
1802        assert_eq!(schema2.version, None);
1803    }
1804
1805    // =========================================================================
1806    // Policy + enable_rls tests
1807    // =========================================================================
1808
1809    #[test]
1810    fn test_enable_rls_table() {
1811        let input = r#"
1812            table orders (
1813                id uuid primary_key,
1814                tenant_id uuid not null
1815            ) enable_rls
1816        "#;
1817
1818        let schema = Schema::parse(input).expect("parse failed");
1819        assert_eq!(schema.tables.len(), 1);
1820        assert!(schema.tables[0].enable_rls);
1821    }
1822
1823    #[test]
1824    fn test_parse_policy_basic() {
1825        let input = r#"
1826            table orders (
1827                id uuid primary_key,
1828                tenant_id uuid not null
1829            ) enable_rls
1830
1831            policy orders_isolation on orders
1832                for all
1833                using (tenant_id = current_setting('app.current_tenant_id')::uuid)
1834        "#;
1835
1836        let schema = Schema::parse(input).expect("parse failed");
1837        assert_eq!(schema.tables.len(), 1);
1838        assert_eq!(schema.policies.len(), 1);
1839
1840        let policy = &schema.policies[0];
1841        assert_eq!(policy.name, "orders_isolation");
1842        assert_eq!(policy.table, "orders");
1843        assert_eq!(policy.target, PolicyTarget::All);
1844        assert!(policy.using.is_some());
1845
1846        // Verify the expression is a typed Binary, not raw SQL
1847        let using = policy.using.as_ref().unwrap();
1848        let Expr::Binary {
1849            left, op, right, ..
1850        } = using
1851        else {
1852            panic!("Expected Binary, got {using:?}");
1853        };
1854        assert_eq!(*op, BinaryOp::Eq);
1855
1856        let Expr::Named(n) = left.as_ref() else {
1857            panic!("Expected Named, got {left:?}");
1858        };
1859        assert_eq!(n, "tenant_id");
1860
1861        let Expr::Cast {
1862            target_type,
1863            expr: cast_expr,
1864            ..
1865        } = right.as_ref()
1866        else {
1867            panic!("Expected Cast, got {right:?}");
1868        };
1869        assert_eq!(target_type, "uuid");
1870
1871        let Expr::FunctionCall { name, args, .. } = cast_expr.as_ref() else {
1872            panic!("Expected FunctionCall, got {cast_expr:?}");
1873        };
1874        assert_eq!(name, "current_setting");
1875        assert_eq!(args.len(), 1);
1876    }
1877
1878    #[test]
1879    fn test_parse_policy_with_check() {
1880        let input = r#"
1881            table orders (
1882                id uuid primary_key
1883            )
1884
1885            policy orders_write on orders
1886                for insert
1887                with check (tenant_id = current_setting('app.current_tenant_id')::uuid)
1888        "#;
1889
1890        let schema = Schema::parse(input).expect("parse failed");
1891        let policy = &schema.policies[0];
1892        assert_eq!(policy.target, PolicyTarget::Insert);
1893        assert!(policy.with_check.is_some());
1894        assert!(policy.using.is_none());
1895    }
1896
1897    #[test]
1898    fn test_parse_policy_restrictive_with_role() {
1899        let input = r#"
1900            table secrets (
1901                id uuid primary_key
1902            )
1903
1904            policy admin_only on secrets
1905                for select
1906                restrictive
1907                to app_user
1908                using (current_setting('app.is_super_admin')::boolean = true)
1909        "#;
1910
1911        let schema = Schema::parse(input).expect("parse failed");
1912        let policy = &schema.policies[0];
1913        assert_eq!(policy.target, PolicyTarget::Select);
1914        assert_eq!(policy.permissiveness, PolicyPermissiveness::Restrictive);
1915        assert_eq!(policy.role.as_deref(), Some("app_user"));
1916        assert!(policy.using.is_some());
1917    }
1918
1919    #[test]
1920    fn test_parse_policy_or_expr() {
1921        let input = r#"
1922            table orders (
1923                id uuid primary_key
1924            )
1925
1926            policy tenant_or_admin on orders
1927                for all
1928                using (tenant_id = current_setting('app.current_tenant_id')::uuid or current_setting('app.is_super_admin')::boolean = true)
1929        "#;
1930
1931        let schema = Schema::parse(input).expect("parse failed");
1932        let policy = &schema.policies[0];
1933
1934        assert!(
1935            matches!(
1936                policy.using.as_ref().unwrap(),
1937                Expr::Binary {
1938                    op: BinaryOp::Or,
1939                    ..
1940                }
1941            ),
1942            "Expected Binary OR, got {:?}",
1943            policy.using
1944        );
1945    }
1946
1947    #[test]
1948    fn test_parse_policy_string_literals_escape_and_fail_closed() {
1949        let input = r#"
1950            table users (
1951                id uuid primary_key,
1952                name text
1953            )
1954
1955            policy users_name on users
1956                for select
1957                using (name = 'Bob''s account')
1958        "#;
1959        let schema = Schema::parse(input).expect("escaped quote string should parse");
1960        let Expr::Binary { right, .. } = schema.policies[0].using.as_ref().unwrap() else {
1961            panic!("expected binary expression");
1962        };
1963        assert!(matches!(
1964            right.as_ref(),
1965            Expr::Literal(AstValue::String(value)) if value == "Bob's account"
1966        ));
1967
1968        let input = r#"
1969            table users (
1970                id uuid primary_key,
1971                name text
1972            )
1973
1974            policy users_name on users
1975                for select
1976                using (name = 'unterminated)
1977        "#;
1978        Schema::parse(input).expect_err("unterminated policy string must fail");
1979    }
1980
1981    #[test]
1982    fn test_parse_policy_rejects_duplicate_clauses() {
1983        for input in [
1984            r#"
1985            table orders (id uuid primary_key)
1986            policy p on orders for select for update using (id = 1)
1987            "#,
1988            r#"
1989            table orders (id uuid primary_key)
1990            policy p on orders to app_user to app_admin using (id = 1)
1991            "#,
1992            r#"
1993            table orders (id uuid primary_key)
1994            policy p on orders restrictive restrictive using (id = 1)
1995            "#,
1996            r#"
1997            table orders (id uuid primary_key)
1998            policy p on orders using (id = 1) using (id = 2)
1999            "#,
2000            r#"
2001            table orders (id uuid primary_key)
2002            policy p on orders with check (id = 1) with check (id = 2)
2003            "#,
2004        ] {
2005            Schema::parse(input).expect_err("duplicate policy clause must fail");
2006        }
2007    }
2008
2009    #[test]
2010    fn test_parse_policy_and_has_higher_precedence_than_or() {
2011        let input = r#"
2012            table orders (
2013                id uuid primary_key,
2014                tenant_id uuid,
2015                active bool,
2016                public bool
2017            )
2018
2019            policy mixed on orders
2020                for select
2021                using (public = true or tenant_id = 7 and active = true)
2022        "#;
2023
2024        let schema = Schema::parse(input).expect("parse failed");
2025        let Expr::Binary {
2026            op: BinaryOp::Or,
2027            right,
2028            ..
2029        } = schema.policies[0].using.as_ref().unwrap()
2030        else {
2031            panic!("expected top-level OR");
2032        };
2033        assert!(matches!(
2034            right.as_ref(),
2035            Expr::Binary {
2036                op: BinaryOp::And,
2037                ..
2038            }
2039        ));
2040    }
2041
2042    #[test]
2043    fn test_parse_policy_rejects_invalid_numeric_literals() {
2044        let huge = "999999999999999999999999999999999999999999999999999999999999999999";
2045        let input = format!(
2046            r#"
2047            table orders (
2048                id uuid primary_key,
2049                amount numeric
2050            )
2051
2052            policy amount_guard on orders
2053                for select
2054                using (amount = {huge})
2055        "#
2056        );
2057        assert!(Schema::parse(&input).is_err());
2058
2059        let input = r#"
2060            table orders (
2061                id uuid primary_key,
2062                amount numeric
2063            )
2064
2065            policy amount_guard on orders
2066                for select
2067                using (amount = 1.2.3)
2068        "#;
2069        assert!(Schema::parse(input).is_err());
2070
2071        let input = r#"
2072            table orders (
2073                id uuid primary_key,
2074                amount numeric
2075            )
2076
2077            policy amount_guard on orders
2078                for select
2079                using (amount = 9007199254740993.25)
2080        "#;
2081        assert!(Schema::parse(input).is_err());
2082    }
2083
2084    #[test]
2085    fn test_schema_to_sql() {
2086        let input = r#"
2087            table orders (
2088                id uuid primary_key,
2089                tenant_id uuid not null
2090            ) enable_rls
2091
2092            policy orders_isolation on orders
2093                for all
2094                using (tenant_id = current_setting('app.current_tenant_id')::uuid)
2095        "#;
2096
2097        let schema = Schema::parse(input).expect("parse failed");
2098        let sql = schema.to_sql();
2099        assert!(sql.contains("CREATE TABLE IF NOT EXISTS"));
2100        assert!(sql.contains("ENABLE ROW LEVEL SECURITY"));
2101        assert!(sql.contains("FORCE ROW LEVEL SECURITY"));
2102        assert!(sql.contains("CREATE POLICY"));
2103        assert!(sql.contains("orders_isolation"));
2104        assert!(sql.contains("FOR ALL"));
2105    }
2106
2107    #[test]
2108    fn test_multiple_policies() {
2109        let input = r#"
2110            table orders (
2111                id uuid primary_key,
2112                tenant_id uuid not null
2113            ) enable_rls
2114
2115            policy orders_read on orders
2116                for select
2117                using (tenant_id = current_setting('app.current_tenant_id')::uuid)
2118
2119            policy orders_write on orders
2120                for insert
2121                with check (tenant_id = current_setting('app.current_tenant_id')::uuid)
2122        "#;
2123
2124        let schema = Schema::parse(input).expect("parse failed");
2125        assert_eq!(schema.policies.len(), 2);
2126        assert_eq!(schema.policies[0].name, "orders_read");
2127        assert_eq!(schema.policies[0].target, PolicyTarget::Select);
2128        assert_eq!(schema.policies[1].name, "orders_write");
2129        assert_eq!(schema.policies[1].target, PolicyTarget::Insert);
2130    }
2131}