1use nom::{
18 IResult, Parser,
19 branch::alt,
20 bytes::complete::{tag, tag_no_case, take_while1},
21 character::complete::{char, multispace0 as nom_ws0, multispace1, not_line_ending},
22 combinator::{map, opt},
23 multi::{many0, separated_list0},
24 sequence::preceded,
25};
26use std::collections::HashSet;
27
28use crate::ast::{BinaryOp, Expr, Value as AstValue};
29use crate::migrate::alter::AlterTable;
30use crate::migrate::policy::{PolicyPermissiveness, PolicyTarget, RlsPolicy};
31use crate::transpiler::policy::{alter_table_sql, create_policy_sql};
32
/// A parsed schema file: its table, RLS policy and index definitions.
#[derive(Debug, Clone, Default)]
pub struct Schema {
    /// Version taken from a `-- qail: version=N` directive line, when one is present and numeric.
    pub version: Option<u32>,
    /// `table` definitions, in source order.
    pub tables: Vec<TableDef>,
    /// `policy` definitions, in source order.
    pub policies: Vec<RlsPolicy>,
    /// `index` definitions, in source order.
    pub indexes: Vec<IndexDef>,
}
45
/// An `index <name> on <table> (<columns>) [unique]` definition.
#[derive(Debug, Clone)]
pub struct IndexDef {
    /// Index name.
    pub name: String,
    /// Table the index is built on.
    pub table: String,
    /// Column-list entries; emitted verbatim (comma-separated) inside the parentheses.
    pub columns: Vec<String>,
    /// Whether a `UNIQUE` index is emitted.
    pub unique: bool,
}

impl IndexDef {
    /// Renders an idempotent `CREATE [UNIQUE] INDEX IF NOT EXISTS ...` statement
    /// (without a trailing `;`).
    pub fn to_sql(&self) -> String {
        let unique_kw = self.unique.then_some(" UNIQUE").unwrap_or_default();
        let IndexDef {
            name,
            table,
            columns,
            ..
        } = self;
        format!(
            "CREATE{} INDEX IF NOT EXISTS {} ON {} ({})",
            unique_kw,
            name,
            table,
            columns.join(", ")
        )
    }
}
72
/// A `table <name> ( ... ) [enable_rls]` definition.
#[derive(Debug, Clone)]
pub struct TableDef {
    /// Table name.
    pub name: String,
    /// Column definitions, in source order.
    pub columns: Vec<ColumnDef>,
    /// Whether the table was declared with the `enable_rls` flag.
    pub enable_rls: bool,
}

/// One column of a [`TableDef`].
#[derive(Debug, Clone)]
pub struct ColumnDef {
    /// Column name.
    pub name: String,
    /// Lowercased type name (possibly schema-qualified).
    pub typ: String,
    /// Whether the type carried a `[]` suffix.
    pub is_array: bool,
    /// Type parameters from a `(...)` suffix, e.g. `["10", "2"]` for `numeric(10, 2)`.
    pub type_params: Option<Vec<String>>,
    /// Whether the column accepts NULL.
    pub nullable: bool,
    /// Whether the column is the primary key.
    pub primary_key: bool,
    /// Whether the column has a UNIQUE constraint.
    pub unique: bool,
    /// Raw `REFERENCES` target text, e.g. `users(id)`.
    pub references: Option<String>,
    /// Raw DEFAULT expression text.
    pub default_value: Option<String>,
    /// Raw CHECK expression text (without the surrounding parentheses).
    pub check: Option<String>,
    /// Whether the type is one of the serial types.
    pub is_serial: bool,
}

impl Default for ColumnDef {
    /// An unnamed, untyped column with no constraints that is nullable.
    fn default() -> Self {
        ColumnDef {
            name: String::default(),
            typ: String::default(),
            type_params: None,
            references: None,
            default_value: None,
            check: None,
            is_array: false,
            primary_key: false,
            unique: false,
            is_serial: false,
            // Columns are nullable unless a constraint (or a serial type) says otherwise.
            nullable: true,
        }
    }
}
128
129impl Schema {
130 pub fn parse(input: &str) -> Result<Self, String> {
132 match parse_schema(input) {
133 Ok(("", schema)) => Ok(schema),
134 Ok((remaining, _)) => Err(format!("Unexpected content: '{}'", remaining.trim())),
135 Err(e) => Err(format!("Parse error: {:?}", e)),
136 }
137 }
138
139 pub fn find_table(&self, name: &str) -> Option<&TableDef> {
141 self.tables
142 .iter()
143 .find(|t| t.name.eq_ignore_ascii_case(name))
144 }
145
146 pub fn to_sql(&self) -> String {
148 let mut parts = Vec::new();
149
150 for table in &self.tables {
151 parts.push(table.to_ddl());
152
153 if table.enable_rls {
154 let alter = AlterTable::new(&table.name).enable_rls().force_rls();
155 for stmt in alter_table_sql(&alter) {
156 parts.push(stmt);
157 }
158 }
159 }
160
161 for idx in &self.indexes {
162 parts.push(idx.to_sql());
163 }
164
165 for policy in &self.policies {
166 parts.push(create_policy_sql(policy));
167 }
168
169 parts.join(";\n\n") + ";"
170 }
171
172 pub fn from_file(path: &std::path::Path) -> Result<Self, String> {
174 let content =
175 std::fs::read_to_string(path).map_err(|e| format!("Failed to read file: {}", e))?;
176 Self::parse(&content)
177 }
178}
179
180impl TableDef {
181 pub fn find_column(&self, name: &str) -> Option<&ColumnDef> {
183 self.columns
184 .iter()
185 .find(|c| c.name.eq_ignore_ascii_case(name))
186 }
187
188 pub fn to_ddl(&self) -> String {
190 let mut sql = format!("CREATE TABLE IF NOT EXISTS {} (\n", self.name);
191
192 let mut col_defs = Vec::new();
193 for col in &self.columns {
194 let mut line = format!(" {}", col.name);
195
196 let mut typ = col.typ.to_uppercase();
198 if let Some(params) = &col.type_params {
199 typ = format!("{}({})", typ, params.join(", "));
200 }
201 if col.is_array {
202 typ.push_str("[]");
203 }
204 line.push_str(&format!(" {}", typ));
205
206 if col.primary_key {
208 line.push_str(" PRIMARY KEY");
209 }
210 if !col.nullable && !col.primary_key && !col.is_serial {
211 line.push_str(" NOT NULL");
212 }
213 if col.unique && !col.primary_key {
214 line.push_str(" UNIQUE");
215 }
216 if let Some(ref default) = col.default_value {
217 line.push_str(&format!(" DEFAULT {}", default));
218 }
219 if let Some(ref refs) = col.references {
220 line.push_str(&format!(" REFERENCES {}", refs));
221 }
222 if let Some(ref check) = col.check {
223 line.push_str(&format!(" CHECK({})", check));
224 }
225
226 col_defs.push(line);
227 }
228
229 sql.push_str(&col_defs.join(",\n"));
230 sql.push_str("\n)");
231 sql
232 }
233}
234
235fn identifier(input: &str) -> IResult<&str, &str> {
241 let (remaining, ident) =
242 take_while1(|c: char| c.is_ascii_alphanumeric() || c == '_').parse(input)?;
243 if ident
244 .chars()
245 .next()
246 .is_some_and(|c| c.is_ascii_alphabetic() || c == '_')
247 {
248 Ok((remaining, ident))
249 } else {
250 Err(nom::Err::Error(nom::error::Error::new(
251 input,
252 nom::error::ErrorKind::Alpha,
253 )))
254 }
255}
256
257fn ws_and_comments(input: &str) -> IResult<&str, ()> {
259 let (input, _) = many0(alt((
260 map(multispace1, |_| ()),
261 map((tag("--"), not_line_ending), |_| ()),
262 map((tag("#"), not_line_ending), |_| ()),
263 )))
264 .parse(input)?;
265 Ok((input, ()))
266}
267
/// Column type details produced by `parse_type_info`.
struct TypeInfo {
    /// Lowercased type name, possibly dot-qualified (e.g. `public.my_type`).
    name: String,
    /// Parameters from a `(...)` suffix, split on top-level commas.
    params: Option<Vec<String>>,
    /// Whether the type had a `[]` suffix.
    is_array: bool,
    /// Whether the (lowercased) name is `serial`, `bigserial` or `smallserial`.
    is_serial: bool,
}
274
275fn parse_type_info(input: &str) -> IResult<&str, TypeInfo> {
278 let (input, type_name) =
279 take_while1(|c: char| c.is_alphanumeric() || c == '_' || c == '.').parse(input)?;
280 if !is_identifier_path(type_name) {
281 return Err(nom::Err::Error(nom::error::Error::new(
282 input,
283 nom::error::ErrorKind::Alpha,
284 )));
285 }
286
287 let (input, params) = if let Some(after_open) = input.strip_prefix('(') {
288 let Some(paren_end) = after_open.find(')') else {
289 return Err(nom::Err::Error(nom::error::Error::new(
290 input,
291 nom::error::ErrorKind::Char,
292 )));
293 };
294 let param_str = &after_open[..paren_end];
295 let Ok(params) = split_top_level_csv(param_str) else {
296 return Err(nom::Err::Error(nom::error::Error::new(
297 input,
298 nom::error::ErrorKind::SeparatedList,
299 )));
300 };
301 if params.is_empty() {
302 return Err(nom::Err::Error(nom::error::Error::new(
303 input,
304 nom::error::ErrorKind::SeparatedList,
305 )));
306 }
307 (&after_open[paren_end + 1..], Some(params))
308 } else {
309 (input, None)
310 };
311
312 let (input, is_array) = if let Some(stripped) = input.strip_prefix("[]") {
313 (stripped, true)
314 } else {
315 (input, false)
316 };
317
318 let lower = type_name.to_lowercase();
319 let is_serial = lower == "serial" || lower == "bigserial" || lower == "smallserial";
320
321 Ok((
322 input,
323 TypeInfo {
324 name: lower,
325 params,
326 is_array,
327 is_serial,
328 },
329 ))
330}
331
/// Returns whether `path` is one or more dot-separated ASCII identifiers
/// (`[A-Za-z_][A-Za-z0-9_]*`), e.g. `users` or `public.users`.
///
/// Empty input and empty segments (`a..b`, `a.`) are rejected.
fn is_identifier_path(path: &str) -> bool {
    // `split` always yields at least one segment, so `all` never passes vacuously.
    path.split('.').all(|segment| {
        let mut chars = segment.chars();
        chars
            .next()
            .is_some_and(|first| first.is_ascii_alphabetic() || first == '_')
            && chars.all(|rest| rest.is_ascii_alphanumeric() || rest == '_')
    })
}
347
348fn constraint_text(input: &str) -> IResult<&str, &str> {
350 let mut paren_depth = 0;
351 let mut in_single = false;
352 let mut in_double = false;
353 let mut end = 0;
354 let mut iter = input.char_indices().peekable();
355
356 while let Some((i, c)) = iter.next() {
357 match c {
358 '\'' if !in_double => {
359 if in_single && matches!(iter.peek(), Some((_, '\''))) {
360 iter.next();
361 } else {
362 in_single = !in_single;
363 }
364 }
365 '"' if !in_single => {
366 if in_double && matches!(iter.peek(), Some((_, '"'))) {
367 iter.next();
368 } else {
369 in_double = !in_double;
370 }
371 }
372 '(' if !in_single && !in_double => paren_depth += 1,
373 ')' if !in_single && !in_double => {
374 if paren_depth == 0 {
375 break; }
377 paren_depth -= 1;
378 }
379 ',' if !in_single && !in_double && paren_depth == 0 => break,
380 '\n' | '\r' if !in_single && !in_double && paren_depth == 0 => break,
381 _ => {}
382 }
383 end = i + c.len_utf8();
384 }
385
386 if end == 0 {
387 Err(nom::Err::Error(nom::error::Error::new(
388 input,
389 nom::error::ErrorKind::TakeWhile1,
390 )))
391 } else {
392 Ok((&input[end..], &input[..end]))
393 }
394}
395
/// Returns the byte index of the `)` that closes an already-opened `(`.
///
/// `rest` is the text just after the opening parenthesis. Parentheses inside single- or
/// double-quoted text (with doubled-quote escapes) are ignored. Returns `rest.len()` when
/// no matching `)` exists.
fn check_expr_end(rest: &str) -> usize {
    let mut open_parens = 1usize;
    let (mut single_quoted, mut double_quoted) = (false, false);
    let mut chars = rest.char_indices().peekable();

    while let Some((pos, c)) = chars.next() {
        let quoted = single_quoted || double_quoted;
        match c {
            '\'' if !double_quoted => {
                let escaped = single_quoted && chars.next_if(|&(_, n)| n == '\'').is_some();
                if !escaped {
                    single_quoted = !single_quoted;
                }
            }
            '"' if !single_quoted => {
                let escaped = double_quoted && chars.next_if(|&(_, n)| n == '"').is_some();
                if !escaped {
                    double_quoted = !double_quoted;
                }
            }
            '(' if !quoted => open_parens += 1,
            ')' if !quoted => {
                open_parens -= 1;
                if open_parens == 0 {
                    return pos;
                }
            }
            _ => {}
        }
    }

    rest.len()
}
431
432fn checked_check_expr_end(rest: &str) -> Option<usize> {
433 let end = check_expr_end(rest);
434 (end < rest.len()).then_some(end)
435}
436
437fn parenthesized_content(input: &str) -> IResult<&str, &str> {
438 let mut paren_depth = 0usize;
439 let mut in_single = false;
440 let mut in_double = false;
441 let mut iter = input.char_indices().peekable();
442
443 while let Some((idx, ch)) = iter.next() {
444 match ch {
445 '\'' if !in_double => {
446 if in_single && matches!(iter.peek(), Some((_, '\''))) {
447 iter.next();
448 } else {
449 in_single = !in_single;
450 }
451 }
452 '"' if !in_single => {
453 if in_double && matches!(iter.peek(), Some((_, '"'))) {
454 iter.next();
455 } else {
456 in_double = !in_double;
457 }
458 }
459 '(' if !in_single && !in_double => paren_depth += 1,
460 ')' if !in_single && !in_double => {
461 if paren_depth == 0 {
462 return Ok((&input[idx + ch.len_utf8()..], &input[..idx]));
463 }
464 paren_depth -= 1;
465 }
466 _ => {}
467 }
468 }
469
470 Err(nom::Err::Error(nom::error::Error::new(
471 input,
472 nom::error::ErrorKind::Char,
473 )))
474}
475
/// Splits `input` on commas that are outside quotes and parentheses, trimming each piece.
///
/// Blank input yields an empty list. Fails with `Err(())` in any of these cases:
/// - a piece is empty (leading, trailing or doubled comma);
/// - a quote is left open;
/// - parentheses are unbalanced.
fn split_top_level_csv(input: &str) -> Result<Vec<String>, ()> {
    // Pass 1: validate quoting/nesting and record the byte offsets of top-level commas.
    let mut comma_offsets = Vec::new();
    let mut depth = 0usize;
    let (mut single_quoted, mut double_quoted) = (false, false);
    let mut chars = input.char_indices().peekable();

    while let Some((pos, c)) = chars.next() {
        let quoted = single_quoted || double_quoted;
        match c {
            '\'' if !double_quoted => {
                if !(single_quoted && chars.next_if(|&(_, n)| n == '\'').is_some()) {
                    single_quoted = !single_quoted;
                }
            }
            '"' if !single_quoted => {
                if !(double_quoted && chars.next_if(|&(_, n)| n == '"').is_some()) {
                    double_quoted = !double_quoted;
                }
            }
            '(' if !quoted => depth += 1,
            ')' if !quoted => depth = depth.checked_sub(1).ok_or(())?,
            ',' if !quoted && depth == 0 => comma_offsets.push(pos),
            _ => {}
        }
    }

    if single_quoted || double_quoted || depth != 0 {
        return Err(());
    }
    // Commas are not whitespace, so blank input can only be comma-free.
    if input.trim().is_empty() {
        return Ok(Vec::new());
    }

    // Pass 2: slice between the recorded commas; every piece must be non-empty.
    let mut parts = Vec::with_capacity(comma_offsets.len() + 1);
    let mut start = 0;
    for end in comma_offsets.into_iter().chain(std::iter::once(input.len())) {
        let piece = input[start..end].trim();
        if piece.is_empty() {
            return Err(());
        }
        parts.push(piece.to_string());
        start = end + 1;
    }
    Ok(parts)
}
533
/// Returns whether `input` begins with a column-constraint keyword.
///
/// Used to find where a DEFAULT expression ends. Keywords are matched ASCII
/// case-insensitively and only as whole words. Two-word forms allow any whitespace between
/// the words. Recognised starts:
/// - `primary_key`, or `primary key`;
/// - `not_null`, or `not null`;
/// - `unique`;
/// - `references` followed by whitespace;
/// - `check` followed by optional whitespace and `(`;
/// - `default`.
///
/// Compared with the previous prefix check:
/// - identifiers such as `uniqueness` or `check_fn` no longer end a DEFAULT expression;
/// - `check (` with a space is now recognised, as the constraint parser accepts it;
/// - a second `default` now ends the first DEFAULT expression instead of being absorbed into it.
fn starts_constraint_keyword(input: &str) -> bool {
    // Strips `kw` (ASCII case-insensitive) from the front of `s` when it is a whole word.
    // Uses `str::get`, so non-ASCII input can never cause a char-boundary panic.
    fn word<'a>(s: &'a str, kw: &str) -> Option<&'a str> {
        let head = s.get(..kw.len())?;
        let tail = &s[kw.len()..];
        let at_boundary = !tail.starts_with(|c: char| c.is_ascii_alphanumeric() || c == '_');
        (head.eq_ignore_ascii_case(kw) && at_boundary).then_some(tail)
    }
    let two_words = |first: &str, second: &str| {
        word(input, first).is_some_and(|rest| word(rest.trim_start(), second).is_some())
    };

    word(input, "primary_key").is_some()
        || two_words("primary", "key")
        || word(input, "not_null").is_some()
        || two_words("not", "null")
        || word(input, "unique").is_some()
        || word(input, "references").is_some_and(|rest| rest.starts_with(char::is_whitespace))
        || word(input, "check").is_some_and(|rest| rest.trim_start().starts_with('('))
        || word(input, "default").is_some()
}
547
548fn default_expr_end(rest: &str) -> usize {
549 let mut in_single = false;
550 let mut in_double = false;
551 let mut paren_depth = 0usize;
552 let mut iter = rest.char_indices().peekable();
553
554 while let Some((idx, ch)) = iter.next() {
555 match ch {
556 '\'' if !in_double => {
557 if in_single && matches!(iter.peek(), Some((_, '\''))) {
558 iter.next();
559 } else {
560 in_single = !in_single;
561 }
562 }
563 '"' if !in_single => {
564 if in_double && matches!(iter.peek(), Some((_, '"'))) {
565 iter.next();
566 } else {
567 in_double = !in_double;
568 }
569 }
570 '(' if !in_single && !in_double => paren_depth += 1,
571 ')' if !in_single && !in_double && paren_depth > 0 => paren_depth -= 1,
572 c if c.is_whitespace()
573 && !in_single
574 && !in_double
575 && paren_depth == 0
576 && starts_constraint_keyword(rest[idx..].trim_start()) =>
577 {
578 return idx;
579 }
580 _ => {}
581 }
582 }
583
584 rest.len()
585}
586
587fn parse_column(input: &str) -> IResult<&str, ColumnDef> {
589 let (input, _) = ws_and_comments(input)?;
590 let (input, name) = identifier(input)?;
591 let (input, _) = multispace1(input)?;
592 let (input, type_info) = parse_type_info(input)?;
593
594 let (input, constraint_str) = opt(preceded(multispace1, constraint_text)).parse(input)?;
595
596 let mut col = ColumnDef {
597 name: name.to_string(),
598 typ: type_info.name,
599 is_array: type_info.is_array,
600 type_params: type_info.params,
601 is_serial: type_info.is_serial,
602 nullable: !type_info.is_serial, ..Default::default()
604 };
605
606 if let Some(constraints) = constraint_str
607 && parse_column_constraints(&mut col, constraints).is_err()
608 {
609 return Err(nom::Err::Error(nom::error::Error::new(
610 constraints,
611 nom::error::ErrorKind::Verify,
612 )));
613 }
614
615 Ok((input, col))
616}
617
618fn parse_column_constraints(col: &mut ColumnDef, constraints: &str) -> Result<(), ()> {
619 let mut rest = constraints.trim();
620 while !rest.is_empty() {
621 if let Some(next) = strip_keyword_ci(rest, "primary_key") {
622 if col.primary_key {
623 return Err(());
624 }
625 col.primary_key = true;
626 col.nullable = false;
627 rest = next.trim_start();
628 continue;
629 }
630 if let Some(next) = strip_keyword_pair_ci(rest, "primary", "key") {
631 if col.primary_key {
632 return Err(());
633 }
634 col.primary_key = true;
635 col.nullable = false;
636 rest = next.trim_start();
637 continue;
638 }
639 if let Some(next) = strip_keyword_ci(rest, "not_null") {
640 col.nullable = false;
641 rest = next.trim_start();
642 continue;
643 }
644 if let Some(next) = strip_keyword_pair_ci(rest, "not", "null") {
645 col.nullable = false;
646 rest = next.trim_start();
647 continue;
648 }
649 if let Some(next) = strip_keyword_ci(rest, "unique") {
650 if col.unique {
651 return Err(());
652 }
653 col.unique = true;
654 rest = next.trim_start();
655 continue;
656 }
657 if let Some(next) = strip_keyword_ci(rest, "references") {
658 if col.references.is_some() {
659 return Err(());
660 }
661 let next = next.trim_start();
662 let (target, tail) = parse_reference_constraint_target(next)?;
663 if target.is_empty() {
664 return Err(());
665 }
666 col.references = Some(target.to_string());
667 rest = tail.trim_start();
668 continue;
669 }
670 if let Some(next) = strip_keyword_ci(rest, "default") {
671 if col.default_value.is_some() {
672 return Err(());
673 }
674 let next = next.trim_start();
675 if next.is_empty() {
676 return Err(());
677 }
678 let end = default_expr_end(next);
679 if end == 0 {
680 return Err(());
681 }
682 col.default_value = Some(next[..end].trim_end().to_string());
683 rest = next[end..].trim_start();
684 continue;
685 }
686 if let Some(next) = strip_keyword_ci(rest, "check") {
687 if col.check.is_some() {
688 return Err(());
689 }
690 let next = next.trim_start();
691 let Some(after_open) = next.strip_prefix('(') else {
692 return Err(());
693 };
694 let Some(end) = checked_check_expr_end(after_open) else {
695 return Err(());
696 };
697 let expr = after_open[..end].trim();
698 if expr.is_empty() {
699 return Err(());
700 }
701 col.check = Some(expr.to_string());
702 rest = after_open[end + 1..].trim_start();
703 continue;
704 }
705
706 return Err(());
707 }
708
709 Ok(())
710}
711
712fn strip_keyword_pair_ci<'a>(input: &'a str, first: &str, second: &str) -> Option<&'a str> {
713 let rest = strip_keyword_ci(input, first)?.trim_start();
714 strip_keyword_ci(rest, second)
715}
716
/// Strips `keyword` from the start of `input` if it is present as a whole word.
///
/// The match is ASCII case-insensitive and must not be followed by an ASCII alphanumeric
/// character or `_`. Returns the text after the keyword, or `None` if there is no match.
///
/// `keyword` is expected to be ASCII. The prefix is taken with `str::get`, which returns
/// `None` when `keyword.len()` does not fall on a char boundary of `input`. The previous
/// `split_at` panicked in that case, so constraint text such as `ééééééé` crashed the parser.
fn strip_keyword_ci<'a>(input: &'a str, keyword: &str) -> Option<&'a str> {
    let head = input.get(..keyword.len())?;
    if !head.eq_ignore_ascii_case(keyword) {
        return None;
    }
    let tail = &input[keyword.len()..];
    if tail.starts_with(|ch: char| ch.is_ascii_alphanumeric() || ch == '_') {
        return None;
    }
    Some(tail)
}
734
/// Splits a `REFERENCES` target (e.g. `users(id)`) from the text that follows it.
///
/// The target ends at the first whitespace or unmatched `)` outside parentheses, or at the
/// end of input. Returns `Err(())` if a `(` in the target is never closed. Quotes are not
/// interpreted.
fn parse_reference_constraint_target(input: &str) -> Result<(&str, &str), ()> {
    let mut nesting = 0usize;
    for (pos, c) in input.char_indices() {
        if nesting == 0 && (c == ')' || c.is_whitespace()) {
            return Ok(input.split_at(pos));
        }
        match c {
            '(' => nesting += 1,
            ')' => nesting -= 1,
            _ => {}
        }
    }
    if nesting == 0 {
        Ok((input, ""))
    } else {
        Err(())
    }
}
760
761fn parse_column_list(input: &str) -> IResult<&str, Vec<ColumnDef>> {
763 let (input, _) = ws_and_comments(input)?;
764 let (input, _) = char('(').parse(input)?;
765 let (input, columns) = separated_list0(char(','), parse_column).parse(input)?;
766 let (input, _) = ws_and_comments(input)?;
767 let (input, _) = char(')').parse(input)?;
768
769 Ok((input, columns))
770}
771
772fn parse_table(input: &str) -> IResult<&str, TableDef> {
774 let (input, _) = ws_and_comments(input)?;
775 let (input, _) = tag_no_case("table").parse(input)?;
776 let (input, _) = multispace1(input)?;
777 let (input, name) = identifier(input)?;
778 let (input, columns) = parse_column_list(input)?;
779 if columns.is_empty() || has_duplicate_column_names(&columns) {
780 return Err(nom::Err::Error(nom::error::Error::new(
781 input,
782 nom::error::ErrorKind::Verify,
783 )));
784 }
785
786 let (input, _) = ws_and_comments(input)?;
788 let enable_rls = if let Ok((rest, _)) =
789 tag_no_case::<_, _, nom::error::Error<&str>>("enable_rls").parse(input)
790 {
791 return Ok((
792 rest,
793 TableDef {
794 name: name.to_string(),
795 columns,
796 enable_rls: true,
797 },
798 ));
799 } else {
800 false
801 };
802
803 Ok((
804 input,
805 TableDef {
806 name: name.to_string(),
807 columns,
808 enable_rls,
809 },
810 ))
811}
812
813fn has_duplicate_column_names(columns: &[ColumnDef]) -> bool {
814 let mut seen = HashSet::new();
815 columns
816 .iter()
817 .any(|column| !seen.insert(column.name.to_ascii_lowercase()))
818}
819
/// One top-level item of a schema file, as returned by `parse_schema_item`.
enum SchemaItem {
    /// A `table` definition.
    Table(TableDef),
    /// A `policy` definition. NOTE(review): presumably boxed to keep the enum small; confirm.
    Policy(Box<RlsPolicy>),
    /// An `index` definition.
    Index(IndexDef),
}
830
/// Parses `policy <name> on <table>` followed by any number of clauses, in any order:
/// - `for <all|select|insert|update|delete>`;
/// - `restrictive`;
/// - `to <role>`;
/// - `with check (<expr>)`;
/// - `using (<expr>)`.
///
/// Each clause may appear at most once. A repeated clause fails with `ErrorKind::Verify`.
/// Clause parsing stops at the first token that does not start a recognised clause. The
/// returned remainder begins at that token, after any whitespace and comments.
///
/// NOTE(review): clause keywords are matched with `tag_no_case` as prefixes, without a
/// word-boundary check.
fn parse_policy(input: &str) -> IResult<&str, RlsPolicy> {
    let (input, _) = ws_and_comments(input)?;
    let (input, _) = tag_no_case("policy").parse(input)?;
    let (input, _) = multispace1(input)?;
    let (input, name) = identifier(input)?;
    let (input, _) = multispace1(input)?;
    let (input, _) = tag_no_case("on").parse(input)?;
    let (input, _) = multispace1(input)?;
    let (input, table) = identifier(input)?;

    let mut policy = RlsPolicy::create(name, table);

    let mut remaining = input;
    // One flag per clause kind, used to reject duplicates.
    let mut seen_for = false;
    let mut seen_restrictive = false;
    let mut seen_role = false;
    let mut seen_using = false;
    let mut seen_with_check = false;
    loop {
        let (input, _) = ws_and_comments(remaining)?;

        // `for <command>`: the command(s) the policy targets.
        if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("for").parse(input) {
            if seen_for {
                return Err(nom::Err::Error(nom::error::Error::new(
                    input,
                    nom::error::ErrorKind::Verify,
                )));
            }
            seen_for = true;
            let (rest, _) = multispace1(rest)?;
            let (rest, target) = alt((
                map(tag_no_case("all"), |_| PolicyTarget::All),
                map(tag_no_case("select"), |_| PolicyTarget::Select),
                map(tag_no_case("insert"), |_| PolicyTarget::Insert),
                map(tag_no_case("update"), |_| PolicyTarget::Update),
                map(tag_no_case("delete"), |_| PolicyTarget::Delete),
            ))
            .parse(rest)?;
            policy.target = target;
            remaining = rest;
            continue;
        }

        // `restrictive`: switches the policy from the default permissiveness.
        if let Ok((rest, _)) =
            tag_no_case::<_, _, nom::error::Error<&str>>("restrictive").parse(input)
        {
            if seen_restrictive {
                return Err(nom::Err::Error(nom::error::Error::new(
                    input,
                    nom::error::ErrorKind::Verify,
                )));
            }
            seen_restrictive = true;
            policy.permissiveness = PolicyPermissiveness::Restrictive;
            remaining = rest;
            continue;
        }

        // `to <role>`: only counts as a clause when `to` is followed by whitespace; otherwise
        // fall through to the remaining checks.
        if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("to").parse(input) {
            if let Ok((rest, _)) = multispace1::<_, nom::error::Error<&str>>(rest) {
                if seen_role {
                    return Err(nom::Err::Error(nom::error::Error::new(
                        input,
                        nom::error::ErrorKind::Verify,
                    )));
                }
                seen_role = true;
                let (rest, role) = identifier(rest)?;
                policy.role = Some(role.to_string());
                remaining = rest;
                continue;
            }
        }

        // `with check (<expr>)`. A `with` that is not followed by whitespace fails the whole
        // policy (via `?`). A `with` followed by something other than `check` falls through.
        if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("with").parse(input) {
            let (rest, _) = multispace1(rest)?;
            if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("check").parse(rest)
            {
                if seen_with_check {
                    return Err(nom::Err::Error(nom::error::Error::new(
                        input,
                        nom::error::ErrorKind::Verify,
                    )));
                }
                seen_with_check = true;
                let (rest, _) = nom_ws0(rest)?;
                let (rest, _) = char('(').parse(rest)?;
                let (rest, _) = nom_ws0(rest)?;
                let (rest, expr) = parse_policy_expr(rest)?;
                let (rest, _) = nom_ws0(rest)?;
                let (rest, _) = char(')').parse(rest)?;
                policy.with_check = Some(expr);
                remaining = rest;
                continue;
            }
        }

        // `using (<expr>)`.
        if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("using").parse(input) {
            if seen_using {
                return Err(nom::Err::Error(nom::error::Error::new(
                    input,
                    nom::error::ErrorKind::Verify,
                )));
            }
            seen_using = true;
            let (rest, _) = nom_ws0(rest)?;
            let (rest, _) = char('(').parse(rest)?;
            let (rest, _) = nom_ws0(rest)?;
            let (rest, expr) = parse_policy_expr(rest)?;
            let (rest, _) = nom_ws0(rest)?;
            let (rest, _) = char(')').parse(rest)?;
            policy.using = Some(expr);
            remaining = rest;
            continue;
        }

        // No clause matched: stop here (after the skipped whitespace/comments).
        remaining = input;
        break;
    }

    Ok((remaining, policy))
}
972
973fn parse_policy_expr(input: &str) -> IResult<&str, Expr> {
982 parse_policy_or_expr(input)
983}
984
985fn parse_policy_or_expr(input: &str) -> IResult<&str, Expr> {
986 let (input, first) = parse_policy_and_expr(input)?;
987
988 let mut result = first;
989 let mut remaining = input;
990 loop {
991 let (input, _) = nom_ws0(remaining)?;
992
993 if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("or").parse(input)
994 && let Ok((rest, _)) = multispace1::<_, nom::error::Error<&str>>(rest)
995 {
996 let (rest, right) = parse_policy_and_expr(rest)?;
997 result = Expr::Binary {
998 left: Box::new(result),
999 op: BinaryOp::Or,
1000 right: Box::new(right),
1001 alias: None,
1002 };
1003 remaining = rest;
1004 continue;
1005 }
1006
1007 remaining = input;
1008 break;
1009 }
1010
1011 Ok((remaining, result))
1012}
1013
1014fn parse_policy_and_expr(input: &str) -> IResult<&str, Expr> {
1015 let (input, first) = parse_policy_comparison(input)?;
1016
1017 let mut result = first;
1018 let mut remaining = input;
1019 loop {
1020 let (input, _) = nom_ws0(remaining)?;
1021
1022 if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("and").parse(input)
1023 && let Ok((rest, _)) = multispace1::<_, nom::error::Error<&str>>(rest)
1024 {
1025 let (rest, right) = parse_policy_comparison(rest)?;
1026 result = Expr::Binary {
1027 left: Box::new(result),
1028 op: BinaryOp::And,
1029 right: Box::new(right),
1030 alias: None,
1031 };
1032 remaining = rest;
1033 continue;
1034 }
1035
1036 remaining = input;
1037 break;
1038 }
1039
1040 Ok((remaining, result))
1041}
1042
1043fn parse_policy_comparison(input: &str) -> IResult<&str, Expr> {
1045 let (input, left) = parse_policy_atom(input)?;
1046 let (input, _) = nom_ws0(input)?;
1047
1048 if let Ok((rest, op)) = parse_cmp_op(input) {
1050 let (rest, _) = nom_ws0(rest)?;
1051 let (rest, right) = parse_policy_atom(rest)?;
1052 return Ok((
1053 rest,
1054 Expr::Binary {
1055 left: Box::new(left),
1056 op,
1057 right: Box::new(right),
1058 alias: None,
1059 },
1060 ));
1061 }
1062
1063 Ok((input, left))
1065}
1066
1067fn parse_cmp_op(input: &str) -> IResult<&str, BinaryOp> {
1069 alt((
1070 map(tag(">="), |_| BinaryOp::Gte),
1071 map(tag("<="), |_| BinaryOp::Lte),
1072 map(tag("<>"), |_| BinaryOp::Ne),
1073 map(tag("!="), |_| BinaryOp::Ne),
1074 map(tag("="), |_| BinaryOp::Eq),
1075 map(tag(">"), |_| BinaryOp::Gt),
1076 map(tag("<"), |_| BinaryOp::Lt),
1077 ))
1078 .parse(input)
1079}
1080
1081fn parse_policy_atom(input: &str) -> IResult<&str, Expr> {
1089 alt((
1090 parse_policy_grouped,
1091 parse_policy_bool,
1092 parse_policy_string,
1093 parse_policy_number,
1094 parse_policy_func_or_ident, ))
1096 .parse(input)
1097}
1098
1099fn parse_policy_grouped(input: &str) -> IResult<&str, Expr> {
1101 let (input, _) = char('(').parse(input)?;
1102 let (input, _) = nom_ws0(input)?;
1103 let (input, expr) = parse_policy_expr(input)?;
1104 let (input, _) = nom_ws0(input)?;
1105 let (input, _) = char(')').parse(input)?;
1106 Ok((input, expr))
1107}
1108
1109fn parse_policy_bool(input: &str) -> IResult<&str, Expr> {
1111 alt((
1112 map(tag_no_case("true"), |_| Expr::Literal(AstValue::Bool(true))),
1113 map(tag_no_case("false"), |_| {
1114 Expr::Literal(AstValue::Bool(false))
1115 }),
1116 ))
1117 .parse(input)
1118}
1119
1120fn parse_policy_string(input: &str) -> IResult<&str, Expr> {
1122 let (input, _) = char('\'').parse(input)?;
1123 let mut content = String::new();
1124 let mut iter = input.char_indices().peekable();
1125 while let Some((idx, ch)) = iter.next() {
1126 if ch == '\'' {
1127 if iter.peek().is_some_and(|(_, next)| *next == '\'') {
1128 content.push('\'');
1129 iter.next();
1130 continue;
1131 }
1132 let rest = &input[idx + ch.len_utf8()..];
1133 return Ok((rest, Expr::Literal(AstValue::String(content))));
1134 }
1135 content.push(ch);
1136 }
1137
1138 Err(nom::Err::Error(nom::error::Error::new(
1139 input,
1140 nom::error::ErrorKind::Char,
1141 )))
1142}
1143
1144fn parse_policy_number(input: &str) -> IResult<&str, Expr> {
1146 let original = input;
1147 let (input, digits) = take_while1(|c: char| c.is_ascii_digit() || c == '.')(input)?;
1148 if digits.starts_with('.') || digits.is_empty() {
1150 return Err(nom::Err::Error(nom::error::Error::new(
1151 original,
1152 nom::error::ErrorKind::Digit,
1153 )));
1154 }
1155
1156 if !digits.contains('.') {
1157 return digits
1158 .parse::<i64>()
1159 .map(|n| (input, Expr::Literal(AstValue::Int(n))))
1160 .map_err(|_| {
1161 nom::Err::Error(nom::error::Error::new(
1162 original,
1163 nom::error::ErrorKind::Digit,
1164 ))
1165 });
1166 }
1167
1168 if digits.matches('.').count() > 1 || policy_number_significant_digits(digits) > 15 {
1169 return Err(nom::Err::Error(nom::error::Error::new(
1170 original,
1171 nom::error::ErrorKind::Float,
1172 )));
1173 }
1174
1175 if let Ok(f) = digits.parse::<f64>() {
1176 if f.is_finite() {
1177 Ok((input, Expr::Literal(AstValue::Float(f))))
1178 } else {
1179 Err(nom::Err::Error(nom::error::Error::new(
1180 original,
1181 nom::error::ErrorKind::Float,
1182 )))
1183 }
1184 } else {
1185 Err(nom::Err::Error(nom::error::Error::new(
1186 original,
1187 nom::error::ErrorKind::Float,
1188 )))
1189 }
1190}
1191
/// Counts the significant decimal digits in `value`.
///
/// This is every ASCII digit from the first non-zero digit onward. Non-digit characters such
/// as `.` are ignored.
fn policy_number_significant_digits(value: &str) -> usize {
    value
        .bytes()
        .filter(u8::is_ascii_digit)
        .skip_while(|&digit| digit == b'0')
        .count()
}
1210
1211fn parse_policy_func_or_ident(input: &str) -> IResult<&str, Expr> {
1213 let original = input;
1214 let (input, name) = identifier(input)?;
1215 if name
1216 .bytes()
1217 .next()
1218 .is_some_and(|byte| byte.is_ascii_digit())
1219 {
1220 return Err(nom::Err::Error(nom::error::Error::new(
1221 original,
1222 nom::error::ErrorKind::Alpha,
1223 )));
1224 }
1225
1226 let mut expr = if let Ok((rest, _)) = char::<_, nom::error::Error<&str>>('(').parse(input) {
1228 let (rest, _) = nom_ws0(rest)?;
1230 let (rest, args) =
1231 separated_list0((nom_ws0, char(','), nom_ws0), parse_policy_atom).parse(rest)?;
1232 let (rest, _) = nom_ws0(rest)?;
1233 let (rest, _) = char(')').parse(rest)?;
1234 let input = rest;
1235 (
1236 input,
1237 Expr::FunctionCall {
1238 name: name.to_string(),
1239 args,
1240 alias: None,
1241 },
1242 )
1243 } else {
1244 (input, Expr::Named(name.to_string()))
1245 };
1246
1247 if let Ok((rest, _)) = tag::<_, _, nom::error::Error<&str>>("::").parse(expr.0) {
1249 let (rest, cast_type) = identifier(rest)?;
1250 expr = (
1251 rest,
1252 Expr::Cast {
1253 expr: Box::new(expr.1),
1254 target_type: cast_type.to_string(),
1255 alias: None,
1256 },
1257 );
1258 }
1259
1260 Ok(expr)
1261}
1262
1263fn parse_schema_item(input: &str) -> IResult<&str, SchemaItem> {
1265 let (input, _) = ws_and_comments(input)?;
1266
1267 if let Ok((rest, policy)) = parse_policy(input) {
1269 return Ok((rest, SchemaItem::Policy(Box::new(policy))));
1270 }
1271
1272 if let Ok((rest, idx)) = parse_index(input) {
1274 return Ok((rest, SchemaItem::Index(idx)));
1275 }
1276
1277 let (rest, table) = parse_table(input)?;
1279 Ok((rest, SchemaItem::Table(table)))
1280}
1281
1282fn parse_index(input: &str) -> IResult<&str, IndexDef> {
1284 let (input, _) = tag_no_case("index")(input)?;
1285 let (input, _) = multispace1(input)?;
1286 let (input, name) = identifier(input)?;
1287 let (input, _) = multispace1(input)?;
1288 let (input, _) = tag_no_case("on")(input)?;
1289 let (input, _) = multispace1(input)?;
1290 let (input, table) = identifier(input)?;
1291 let (input, _) = nom_ws0(input)?;
1292 let (input, _) = char('(')(input)?;
1293 let (input, cols_str) = parenthesized_content(input)?;
1294 let (input, _) = nom_ws0(input)?;
1295 let (input, unique_tag) = opt(tag_no_case("unique")).parse(input)?;
1296
1297 let columns = split_top_level_csv(cols_str).map_err(|_| {
1298 nom::Err::Error(nom::error::Error::new(
1299 cols_str,
1300 nom::error::ErrorKind::SeparatedList,
1301 ))
1302 })?;
1303 if columns.is_empty() {
1304 return Err(nom::Err::Error(nom::error::Error::new(
1305 cols_str,
1306 nom::error::ErrorKind::SeparatedList,
1307 )));
1308 }
1309
1310 let is_unique = unique_tag.is_some();
1311
1312 Ok((
1313 input,
1314 IndexDef {
1315 name: name.to_string(),
1316 table: table.to_string(),
1317 columns,
1318 unique: is_unique,
1319 },
1320 ))
1321}
1322
1323fn parse_schema(input: &str) -> IResult<&str, Schema> {
1325 let version = extract_version_directive(input);
1327
1328 let (input, items) = many0(parse_schema_item).parse(input)?;
1329 let (input, _) = ws_and_comments(input)?;
1330
1331 let mut tables = Vec::new();
1332 let mut policies = Vec::new();
1333 let mut indexes = Vec::new();
1334 for item in items {
1335 match item {
1336 SchemaItem::Table(t) => tables.push(t),
1337 SchemaItem::Policy(p) => policies.push(*p),
1338 SchemaItem::Index(i) => indexes.push(i),
1339 }
1340 }
1341 if !schema_names_are_unique(&tables, &policies, &indexes) {
1342 return Err(nom::Err::Error(nom::error::Error::new(
1343 input,
1344 nom::error::ErrorKind::Verify,
1345 )));
1346 }
1347
1348 Ok((
1349 input,
1350 Schema {
1351 version,
1352 tables,
1353 policies,
1354 indexes,
1355 },
1356 ))
1357}
1358
1359fn schema_names_are_unique(
1360 tables: &[TableDef],
1361 policies: &[RlsPolicy],
1362 indexes: &[IndexDef],
1363) -> bool {
1364 let mut table_names = HashSet::new();
1365 if tables
1366 .iter()
1367 .any(|table| !table_names.insert(table.name.to_ascii_lowercase()))
1368 {
1369 return false;
1370 }
1371
1372 let mut index_names = HashSet::new();
1373 if indexes
1374 .iter()
1375 .any(|index| !index_names.insert(index.name.to_ascii_lowercase()))
1376 {
1377 return false;
1378 }
1379
1380 let mut policy_names = HashSet::new();
1381 !policies.iter().any(|policy| {
1382 !policy_names.insert((
1383 policy.table.to_ascii_lowercase(),
1384 policy.name.to_ascii_lowercase(),
1385 ))
1386 })
1387}
1388
/// Scans `input` for a `-- qail: version=N` comment directive and returns `N`.
///
/// Each line is trimmed before matching, so the directive may be indented.
/// Whitespace after `-- qail:` and around `=` is tolerated (`version = 2`).
/// Only the first `-- qail:` line carrying a `version` assignment is considered.
///
/// Returns `None` when no such directive exists, or when the first one's value
/// is not a valid `u32`.
fn extract_version_directive(input: &str) -> Option<u32> {
    input
        .lines()
        .filter_map(|line| line.trim().strip_prefix("-- qail:"))
        // `version` must be followed (after optional spaces) by `=`, so keys such
        // as `versionx=1` are not mistaken for the directive.
        .find_map(|rest| {
            rest.trim()
                .strip_prefix("version")?
                .trim_start()
                .strip_prefix('=')
        })
        .and_then(|value| value.trim().parse().ok())
}
1402
// Unit tests for the schema DSL parser: tables and column constraints, indexes,
// RLS policies, the version directive, and SQL generation.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_simple_table() {
        let input = r#"
        table users (
            id uuid primary_key,
            email text not null,
            name text
        )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.tables.len(), 1);

        let users = &schema.tables[0];
        assert_eq!(users.name, "users");
        assert_eq!(users.columns.len(), 3);

        let id = &users.columns[0];
        assert_eq!(id.name, "id");
        assert_eq!(id.typ, "uuid");
        assert!(id.primary_key);
        assert!(!id.nullable);

        let email = &users.columns[1];
        assert_eq!(email.name, "email");
        assert!(!email.nullable);

        let name = &users.columns[2];
        assert!(name.nullable);
    }

    #[test]
    fn test_parse_multiple_tables() {
        let input = r#"
        -- Users table
        table users (
            id uuid primary_key,
            email text not null unique
        )

        -- Orders table
        table orders (
            id uuid primary_key,
            user_id uuid references users(id),
            total i64 not null default 0
        )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.tables.len(), 2);

        let orders = schema.find_table("orders").expect("orders not found");
        let user_id = orders.find_column("user_id").expect("user_id not found");
        assert_eq!(user_id.references, Some("users(id)".to_string()));

        let total = orders.find_column("total").expect("total not found");
        assert_eq!(total.default_value, Some("0".to_string()));
    }

    #[test]
    fn test_parse_comments() {
        let input = r#"
        -- This is a comment
        table foo (
            bar text
        )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.tables.len(), 1);
    }

    #[test]
    fn test_array_types() {
        let input = r#"
        table products (
            id uuid primary_key,
            tags text[],
            prices decimal[]
        )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let products = &schema.tables[0];

        let tags = products.find_column("tags").expect("tags not found");
        assert_eq!(tags.typ, "text");
        assert!(tags.is_array);

        let prices = products.find_column("prices").expect("prices not found");
        assert!(prices.is_array);
    }

    #[test]
    fn test_type_params() {
        let input = r#"
        table items (
            id serial primary_key,
            name varchar(255) not null,
            price decimal(10,2),
            code varchar(50) unique
        )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let items = &schema.tables[0];

        let id = items.find_column("id").expect("id not found");
        assert!(id.is_serial);
        assert!(!id.nullable);

        let name = items.find_column("name").expect("name not found");
        assert_eq!(name.typ, "varchar");
        assert_eq!(name.type_params, Some(vec!["255".to_string()]));

        let price = items.find_column("price").expect("price not found");
        assert_eq!(
            price.type_params,
            Some(vec!["10".to_string(), "2".to_string()])
        );

        let code = items.find_column("code").expect("code not found");
        assert!(code.unique);
    }

    #[test]
    fn test_rejects_invalid_identifiers_in_schema_shapes() {
        for input in [
            "table 1users (id uuid)",
            "table users (1id uuid)",
            "table users (id 1uuid)",
            "index 1idx on users (id)",
            "index idx on 1users (id)",
            "policy 1policy on users using (id = 1)",
            "policy users_policy on 1users using (id = 1)",
        ] {
            Schema::parse(input).expect_err("invalid identifier must fail");
        }
    }

    #[test]
    fn test_rejects_empty_tables_and_duplicate_schema_objects() {
        for input in [
            "table empty ()",
            "table users (id uuid, id text)",
            "table users (id uuid)\ntable users (id text)",
            "index idx_users on users (id)\nindex idx_users on users (email)",
            "policy users_filter on users using (id = 1)\npolicy users_filter on users using (id = 2)",
        ] {
            Schema::parse(input).expect_err("duplicate or empty schema object must fail");
        }
    }

    #[test]
    fn test_rejects_empty_type_parameters() {
        for input in [
            "table invoices (amount decimal())",
            "table invoices (amount decimal(10,))",
            "table invoices (amount decimal(,2))",
            "table invoices (amount decimal(10,,2))",
        ] {
            Schema::parse(input).expect_err("empty type parameter must fail");
        }
    }

    #[test]
    fn test_custom_type_names_with_underscores_and_schema() {
        let input = r#"
        table bookings (
            id uuid primary_key,
            status booking_status not null,
            gateway_state integrations.payment_state[]
        )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let bookings = &schema.tables[0];

        let status = bookings.find_column("status").expect("status not found");
        assert_eq!(status.typ, "booking_status");
        assert!(!status.nullable);

        let gateway_state = bookings
            .find_column("gateway_state")
            .expect("gateway_state not found");
        assert_eq!(gateway_state.typ, "integrations.payment_state");
        assert!(gateway_state.is_array);
    }

    #[test]
    fn test_malformed_type_params_return_parse_error_without_panic() {
        let input = "table invoices ( amount decimal(";

        // `catch_unwind` turns a panic (e.g. an out-of-bounds slice on the
        // truncated input) into a test failure instead of aborting the run.
        let result = std::panic::catch_unwind(|| Schema::parse(input));

        let parsed = result.expect("Schema::parse must not panic on truncated input");
        assert!(parsed.is_err());
    }

    #[test]
    fn test_check_constraint() {
        let input = r#"
        table employees (
            id uuid primary_key,
            age i32 check(age >= 18),
            salary decimal check(salary > 0)
        )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let employees = &schema.tables[0];

        let age = employees.find_column("age").expect("age not found");
        assert_eq!(age.check, Some("age >= 18".to_string()));

        let salary = employees.find_column("salary").expect("salary not found");
        assert_eq!(salary.check, Some("salary > 0".to_string()));
    }

    #[test]
    fn test_default_expression_with_spaces() {
        let input = r#"
        table messages (
            id uuid primary_key,
            title text default 'new user' not null,
            expires_at timestamp default (now() + interval '1 day')
        )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let messages = &schema.tables[0];

        let title = messages.find_column("title").expect("title not found");
        assert_eq!(title.default_value, Some("'new user'".to_string()));
        assert!(!title.nullable);

        let expires_at = messages
            .find_column("expires_at")
            .expect("expires_at not found");
        assert_eq!(
            expires_at.default_value,
            Some("(now() + interval '1 day')".to_string())
        );
    }

    #[test]
    fn test_constraints_handle_quoted_commas_and_parens() {
        let input = r#"
        table messages (
            id uuid primary_key,
            title text default 'hello, world' not null,
            tag text check(tag in ('a,b', 'c)')),
            note text default 'paren ) and comma, still literal'
        )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let messages = &schema.tables[0];
        assert_eq!(messages.columns.len(), 4);

        let title = messages.find_column("title").expect("title not found");
        assert_eq!(title.default_value, Some("'hello, world'".to_string()));
        assert!(!title.nullable);

        let tag = messages.find_column("tag").expect("tag not found");
        assert_eq!(tag.check, Some("tag in ('a,b', 'c)')".to_string()));

        let note = messages.find_column("note").expect("note not found");
        assert_eq!(
            note.default_value,
            Some("'paren ) and comma, still literal'".to_string())
        );
    }

    #[test]
    fn test_constraint_keywords_inside_literals_do_not_become_constraints() {
        let input = r#"
        table messages (
            plain text default 'unique not null primary key references users(id) check(x)',
            fn_default text default unique_label(),
            guarded text check(note = 'unique not null primary key')
        )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let messages = &schema.tables[0];

        let plain = messages.find_column("plain").expect("plain not found");
        assert_eq!(
            plain.default_value.as_deref(),
            Some("'unique not null primary key references users(id) check(x)'")
        );
        assert!(!plain.unique);
        assert!(plain.nullable);
        assert!(!plain.primary_key);
        assert!(plain.references.is_none());
        assert!(plain.check.is_none());

        let fn_default = messages
            .find_column("fn_default")
            .expect("fn_default not found");
        assert_eq!(fn_default.default_value.as_deref(), Some("unique_label()"));
        assert!(!fn_default.unique);

        let guarded = messages.find_column("guarded").expect("guarded not found");
        assert_eq!(
            guarded.check.as_deref(),
            Some("note = 'unique not null primary key'")
        );
        assert!(!guarded.unique);
        assert!(guarded.nullable);
        assert!(!guarded.primary_key);
    }

    #[test]
    fn test_rejects_malformed_column_constraints() {
        for input in [
            "table bad (name text default)",
            "table bad (user_id uuid references)",
            "table bad (age int check())",
            "table bad (age int check(age > 0)",
            "table bad (name text unique unique)",
            "table bad (id uuid primary_key primary key)",
            "table bad (user_id uuid references users(id) references accounts(id))",
        ] {
            Schema::parse(input).expect_err("malformed column constraint must fail");
        }
    }

    #[test]
    fn test_index_columns_handle_nested_expression_commas() {
        // The quoted `')'` and the commas inside `regexp_replace(...)` must not
        // end or split the column list.
        let input = r#"
        table docs (
            id uuid primary_key,
            title text,
            slug text
        )

        index idx_docs_search on docs (regexp_replace(title, ')', '', 'g'), lower(slug)) unique
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.indexes.len(), 1);
        let index = &schema.indexes[0];
        assert_eq!(index.name, "idx_docs_search");
        assert_eq!(
            index.columns,
            vec![
                "regexp_replace(title, ')', '', 'g')".to_string(),
                "lower(slug)".to_string()
            ]
        );
        assert!(index.unique);
        assert_eq!(
            index.to_sql(),
            "CREATE UNIQUE INDEX IF NOT EXISTS idx_docs_search ON docs (regexp_replace(title, ')', '', 'g'), lower(slug))"
        );
    }

    #[test]
    fn test_index_rejects_empty_columns() {
        for input in [
            "index idx_docs_search on docs ()",
            "index idx_docs_search on docs (title,)",
            "index idx_docs_search on docs (,title)",
            "index idx_docs_search on docs (title,,slug)",
        ] {
            let err = Schema::parse(input).expect_err("empty index columns should fail");
            assert!(
                err.contains("Parse error") || err.contains("Unexpected content"),
                "{err}"
            );
        }
    }

    #[test]
    fn test_version_directive() {
        let input = r#"
        -- qail: version=1
        table users (
            id uuid primary_key
        )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.version, Some(1));
        assert_eq!(schema.tables.len(), 1);

        // Without a directive the version stays unset.
        let input_no_version = r#"
        table items (
            id uuid primary_key
        )
        "#;
        let schema2 = Schema::parse(input_no_version).expect("parse failed");
        assert_eq!(schema2.version, None);
    }

    #[test]
    fn test_enable_rls_table() {
        let input = r#"
        table orders (
            id uuid primary_key,
            tenant_id uuid not null
        ) enable_rls
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.tables.len(), 1);
        assert!(schema.tables[0].enable_rls);
    }

    #[test]
    fn test_parse_policy_basic() {
        let input = r#"
        table orders (
            id uuid primary_key,
            tenant_id uuid not null
        ) enable_rls

        policy orders_isolation on orders
            for all
            using (tenant_id = current_setting('app.current_tenant_id')::uuid)
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.tables.len(), 1);
        assert_eq!(schema.policies.len(), 1);

        let policy = &schema.policies[0];
        assert_eq!(policy.name, "orders_isolation");
        assert_eq!(policy.table, "orders");
        assert_eq!(policy.target, PolicyTarget::All);

        // Expected shape: Binary(Eq, Named("tenant_id"), Cast(FunctionCall, "uuid")).
        let using = policy
            .using
            .as_ref()
            .expect("policy should have a USING expression");
        let Expr::Binary {
            left, op, right, ..
        } = using
        else {
            panic!("Expected Binary, got {using:?}");
        };
        assert_eq!(*op, BinaryOp::Eq);

        let Expr::Named(n) = left.as_ref() else {
            panic!("Expected Named, got {left:?}");
        };
        assert_eq!(n, "tenant_id");

        let Expr::Cast {
            target_type,
            expr: cast_expr,
            ..
        } = right.as_ref()
        else {
            panic!("Expected Cast, got {right:?}");
        };
        assert_eq!(target_type, "uuid");

        let Expr::FunctionCall { name, args, .. } = cast_expr.as_ref() else {
            panic!("Expected FunctionCall, got {cast_expr:?}");
        };
        assert_eq!(name, "current_setting");
        assert_eq!(args.len(), 1);
    }

    #[test]
    fn test_parse_policy_with_check() {
        let input = r#"
        table orders (
            id uuid primary_key
        )

        policy orders_write on orders
            for insert
            with check (tenant_id = current_setting('app.current_tenant_id')::uuid)
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let policy = &schema.policies[0];
        assert_eq!(policy.target, PolicyTarget::Insert);
        assert!(policy.with_check.is_some());
        assert!(policy.using.is_none());
    }

    #[test]
    fn test_parse_policy_restrictive_with_role() {
        let input = r#"
        table secrets (
            id uuid primary_key
        )

        policy admin_only on secrets
            for select
            restrictive
            to app_user
            using (current_setting('app.is_super_admin')::boolean = true)
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let policy = &schema.policies[0];
        assert_eq!(policy.target, PolicyTarget::Select);
        assert_eq!(policy.permissiveness, PolicyPermissiveness::Restrictive);
        assert_eq!(policy.role.as_deref(), Some("app_user"));
        assert!(policy.using.is_some());
    }

    #[test]
    fn test_parse_policy_or_expr() {
        let input = r#"
        table orders (
            id uuid primary_key
        )

        policy tenant_or_admin on orders
            for all
            using (tenant_id = current_setting('app.current_tenant_id')::uuid or current_setting('app.is_super_admin')::boolean = true)
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let policy = &schema.policies[0];

        assert!(
            matches!(
                policy.using.as_ref().unwrap(),
                Expr::Binary {
                    op: BinaryOp::Or,
                    ..
                }
            ),
            "Expected Binary OR, got {:?}",
            policy.using
        );
    }

    #[test]
    fn test_parse_policy_string_literals_escape_and_fail_closed() {
        // A doubled single quote is the SQL escape for a literal quote.
        let input = r#"
        table users (
            id uuid primary_key,
            name text
        )

        policy users_name on users
            for select
            using (name = 'Bob''s account')
        "#;
        let schema = Schema::parse(input).expect("escaped quote string should parse");
        let Expr::Binary { right, .. } = schema.policies[0].using.as_ref().unwrap() else {
            panic!("expected binary expression");
        };
        assert!(matches!(
            right.as_ref(),
            Expr::Literal(AstValue::String(value)) if value == "Bob's account"
        ));

        let input = r#"
        table users (
            id uuid primary_key,
            name text
        )

        policy users_name on users
            for select
            using (name = 'unterminated)
        "#;
        Schema::parse(input).expect_err("unterminated policy string must fail");
    }

    #[test]
    fn test_parse_policy_rejects_duplicate_clauses() {
        for input in [
            r#"
            table orders (id uuid primary_key)
            policy p on orders for select for update using (id = 1)
            "#,
            r#"
            table orders (id uuid primary_key)
            policy p on orders to app_user to app_admin using (id = 1)
            "#,
            r#"
            table orders (id uuid primary_key)
            policy p on orders restrictive restrictive using (id = 1)
            "#,
            r#"
            table orders (id uuid primary_key)
            policy p on orders using (id = 1) using (id = 2)
            "#,
            r#"
            table orders (id uuid primary_key)
            policy p on orders with check (id = 1) with check (id = 2)
            "#,
        ] {
            Schema::parse(input).expect_err("duplicate policy clause must fail");
        }
    }

    #[test]
    fn test_parse_policy_and_has_higher_precedence_than_or() {
        let input = r#"
        table orders (
            id uuid primary_key,
            tenant_id uuid,
            active bool,
            public bool
        )

        policy mixed on orders
            for select
            using (public = true or tenant_id = 7 and active = true)
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let Expr::Binary {
            op: BinaryOp::Or,
            right,
            ..
        } = schema.policies[0].using.as_ref().unwrap()
        else {
            panic!("expected top-level OR");
        };
        assert!(matches!(
            right.as_ref(),
            Expr::Binary {
                op: BinaryOp::And,
                ..
            }
        ));
    }

    #[test]
    fn test_parse_policy_rejects_invalid_numeric_literals() {
        // Far beyond any integer type the literal parser could use.
        let huge = "999999999999999999999999999999999999999999999999999999999999999999";
        let input = format!(
            r#"
        table orders (
            id uuid primary_key,
            amount numeric
        )

        policy amount_guard on orders
            for select
            using (amount = {huge})
        "#
        );
        assert!(Schema::parse(&input).is_err());

        let input = r#"
        table orders (
            id uuid primary_key,
            amount numeric
        )

        policy amount_guard on orders
            for select
            using (amount = 1.2.3)
        "#;
        assert!(Schema::parse(input).is_err());

        // 9007199254740993 is 2^53 + 1, which an f64 cannot represent exactly;
        // presumably rejected to avoid silent precision loss.
        let input = r#"
        table orders (
            id uuid primary_key,
            amount numeric
        )

        policy amount_guard on orders
            for select
            using (amount = 9007199254740993.25)
        "#;
        assert!(Schema::parse(input).is_err());
    }

    #[test]
    fn test_schema_to_sql() {
        let input = r#"
        table orders (
            id uuid primary_key,
            tenant_id uuid not null
        ) enable_rls

        policy orders_isolation on orders
            for all
            using (tenant_id = current_setting('app.current_tenant_id')::uuid)
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let sql = schema.to_sql();
        assert!(sql.contains("CREATE TABLE IF NOT EXISTS"));
        assert!(sql.contains("ENABLE ROW LEVEL SECURITY"));
        assert!(sql.contains("FORCE ROW LEVEL SECURITY"));
        assert!(sql.contains("CREATE POLICY"));
        assert!(sql.contains("orders_isolation"));
        assert!(sql.contains("FOR ALL"));
    }

    #[test]
    fn test_multiple_policies() {
        let input = r#"
        table orders (
            id uuid primary_key,
            tenant_id uuid not null
        ) enable_rls

        policy orders_read on orders
            for select
            using (tenant_id = current_setting('app.current_tenant_id')::uuid)

        policy orders_write on orders
            for insert
            with check (tenant_id = current_setting('app.current_tenant_id')::uuid)
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.policies.len(), 2);
        assert_eq!(schema.policies[0].name, "orders_read");
        assert_eq!(schema.policies[0].target, PolicyTarget::Select);
        assert_eq!(schema.policies[1].name, "orders_write");
        assert_eq!(schema.policies[1].target, PolicyTarget::Insert);
    }
}