1use nom::{
18 IResult, Parser,
19 branch::alt,
20 bytes::complete::{tag, tag_no_case, take_while1},
21 character::complete::{char, multispace0 as nom_ws0, multispace1, not_line_ending},
22 combinator::{map, opt},
23 multi::{many0, separated_list0},
24 sequence::preceded,
25};
26
27use crate::ast::{BinaryOp, Expr, Value as AstValue};
28use crate::migrate::alter::AlterTable;
29use crate::migrate::policy::{PolicyPermissiveness, PolicyTarget, RlsPolicy};
30use crate::transpiler::policy::{alter_table_sql, create_policy_sql};
31
32#[derive(Debug, Clone, Default)]
34pub struct Schema {
35 pub version: Option<u32>,
37 pub tables: Vec<TableDef>,
39 pub policies: Vec<RlsPolicy>,
41 pub indexes: Vec<IndexDef>,
43}
44
/// A secondary index declared as `index <name> on <table> (<columns>) [unique]`.
#[derive(Clone, Debug)]
pub struct IndexDef {
    /// Index name.
    pub name: String,
    /// Table the index is created on.
    pub table: String,
    /// Column names or expressions, emitted verbatim and comma-joined.
    pub columns: Vec<String>,
    /// Whether the index is emitted as `CREATE UNIQUE INDEX`.
    pub unique: bool,
}

impl IndexDef {
    /// Renders this index as a `CREATE [UNIQUE] INDEX IF NOT EXISTS` statement
    /// (without a trailing `;`).
    pub fn to_sql(&self) -> String {
        let unique_kw = self.unique.then_some(" UNIQUE").unwrap_or_default();
        let column_list = self.columns.join(", ");
        format!(
            "CREATE{} INDEX IF NOT EXISTS {} ON {} ({})",
            unique_kw, self.name, self.table, column_list
        )
    }
}
71
72#[derive(Debug, Clone)]
74pub struct TableDef {
75 pub name: String,
77 pub columns: Vec<ColumnDef>,
79 pub enable_rls: bool,
81}
82
/// One column of a [`TableDef`], parsed from a `name type [constraints]` entry.
#[derive(Clone, Debug)]
pub struct ColumnDef {
    /// Column name as written in the source.
    pub name: String,
    /// Base type name, lowercased by the parser (may be schema-qualified).
    pub typ: String,
    /// `true` when the type carried a `[]` suffix.
    pub is_array: bool,
    /// Trimmed, comma-separated parameters from `type(...)`, e.g. `["10", "2"]`.
    pub type_params: Option<Vec<String>>,
    /// Whether the column accepts NULL (defaults to `true`).
    pub nullable: bool,
    /// Set by `primary_key` / `primary key`.
    pub primary_key: bool,
    /// Set by `unique`.
    pub unique: bool,
    /// Raw `references` target text, e.g. `users(id)`.
    pub references: Option<String>,
    /// Raw SQL text following `default`.
    pub default_value: Option<String>,
    /// Raw expression text inside `check(...)`.
    pub check: Option<String>,
    /// `true` for `serial`, `bigserial` and `smallserial` types.
    pub is_serial: bool,
}

impl Default for ColumnDef {
    /// An unnamed, untyped, nullable column with every flag and constraint unset.
    fn default() -> Self {
        Self {
            nullable: true,
            name: String::default(),
            typ: String::default(),
            type_params: None,
            references: None,
            default_value: None,
            check: None,
            is_array: false,
            primary_key: false,
            unique: false,
            is_serial: false,
        }
    }
}
127
128impl Schema {
129 pub fn parse(input: &str) -> Result<Self, String> {
131 match parse_schema(input) {
132 Ok(("", schema)) => Ok(schema),
133 Ok((remaining, _)) => Err(format!("Unexpected content: '{}'", remaining.trim())),
134 Err(e) => Err(format!("Parse error: {:?}", e)),
135 }
136 }
137
138 pub fn find_table(&self, name: &str) -> Option<&TableDef> {
140 self.tables
141 .iter()
142 .find(|t| t.name.eq_ignore_ascii_case(name))
143 }
144
145 pub fn to_sql(&self) -> String {
147 let mut parts = Vec::new();
148
149 for table in &self.tables {
150 parts.push(table.to_ddl());
151
152 if table.enable_rls {
153 let alter = AlterTable::new(&table.name).enable_rls().force_rls();
154 for stmt in alter_table_sql(&alter) {
155 parts.push(stmt);
156 }
157 }
158 }
159
160 for idx in &self.indexes {
161 parts.push(idx.to_sql());
162 }
163
164 for policy in &self.policies {
165 parts.push(create_policy_sql(policy));
166 }
167
168 parts.join(";\n\n") + ";"
169 }
170
171 pub fn from_file(path: &std::path::Path) -> Result<Self, String> {
173 let content =
174 std::fs::read_to_string(path).map_err(|e| format!("Failed to read file: {}", e))?;
175 Self::parse(&content)
176 }
177}
178
179impl TableDef {
180 pub fn find_column(&self, name: &str) -> Option<&ColumnDef> {
182 self.columns
183 .iter()
184 .find(|c| c.name.eq_ignore_ascii_case(name))
185 }
186
187 pub fn to_ddl(&self) -> String {
189 let mut sql = format!("CREATE TABLE IF NOT EXISTS {} (\n", self.name);
190
191 let mut col_defs = Vec::new();
192 for col in &self.columns {
193 let mut line = format!(" {}", col.name);
194
195 let mut typ = col.typ.to_uppercase();
197 if let Some(params) = &col.type_params {
198 typ = format!("{}({})", typ, params.join(", "));
199 }
200 if col.is_array {
201 typ.push_str("[]");
202 }
203 line.push_str(&format!(" {}", typ));
204
205 if col.primary_key {
207 line.push_str(" PRIMARY KEY");
208 }
209 if !col.nullable && !col.primary_key && !col.is_serial {
210 line.push_str(" NOT NULL");
211 }
212 if col.unique && !col.primary_key {
213 line.push_str(" UNIQUE");
214 }
215 if let Some(ref default) = col.default_value {
216 line.push_str(&format!(" DEFAULT {}", default));
217 }
218 if let Some(ref refs) = col.references {
219 line.push_str(&format!(" REFERENCES {}", refs));
220 }
221 if let Some(ref check) = col.check {
222 line.push_str(&format!(" CHECK({})", check));
223 }
224
225 col_defs.push(line);
226 }
227
228 sql.push_str(&col_defs.join(",\n"));
229 sql.push_str("\n)");
230 sql
231 }
232}
233
234fn identifier(input: &str) -> IResult<&str, &str> {
240 take_while1(|c: char| c.is_alphanumeric() || c == '_')(input)
241}
242
243fn ws_and_comments(input: &str) -> IResult<&str, ()> {
245 let (input, _) = many0(alt((
246 map(multispace1, |_| ()),
247 map((tag("--"), not_line_ending), |_| ()),
248 map((tag("#"), not_line_ending), |_| ()),
249 )))
250 .parse(input)?;
251 Ok((input, ()))
252}
253
/// A column type split into its parts by `parse_type_info`.
struct TypeInfo {
    /// Lowercased type name; may contain `.` for schema-qualified types.
    name: String,
    /// Trimmed, comma-separated parameters from `type(...)`, if present.
    params: Option<Vec<String>>,
    /// Whether the type carried a `[]` suffix.
    is_array: bool,
    /// Whether the lowercased name is `serial`, `bigserial` or `smallserial`.
    is_serial: bool,
}
260
261fn parse_type_info(input: &str) -> IResult<&str, TypeInfo> {
264 let (input, type_name) =
265 take_while1(|c: char| c.is_alphanumeric() || c == '_' || c == '.').parse(input)?;
266
267 let (input, params) = if let Some(after_open) = input.strip_prefix('(') {
268 let Some(paren_end) = after_open.find(')') else {
269 return Err(nom::Err::Error(nom::error::Error::new(
270 input,
271 nom::error::ErrorKind::Char,
272 )));
273 };
274 let param_str = &after_open[..paren_end];
275 let params: Vec<String> = param_str.split(',').map(|s| s.trim().to_string()).collect();
276 (&after_open[paren_end + 1..], Some(params))
277 } else {
278 (input, None)
279 };
280
281 let (input, is_array) = if let Some(stripped) = input.strip_prefix("[]") {
282 (stripped, true)
283 } else {
284 (input, false)
285 };
286
287 let lower = type_name.to_lowercase();
288 let is_serial = lower == "serial" || lower == "bigserial" || lower == "smallserial";
289
290 Ok((
291 input,
292 TypeInfo {
293 name: lower,
294 params,
295 is_array,
296 is_serial,
297 },
298 ))
299}
300
301fn constraint_text(input: &str) -> IResult<&str, &str> {
303 let mut paren_depth = 0;
304 let mut in_single = false;
305 let mut in_double = false;
306 let mut end = 0;
307 let mut iter = input.char_indices().peekable();
308
309 while let Some((i, c)) = iter.next() {
310 match c {
311 '\'' if !in_double => {
312 if in_single && matches!(iter.peek(), Some((_, '\''))) {
313 iter.next();
314 } else {
315 in_single = !in_single;
316 }
317 }
318 '"' if !in_single => {
319 if in_double && matches!(iter.peek(), Some((_, '"'))) {
320 iter.next();
321 } else {
322 in_double = !in_double;
323 }
324 }
325 '(' if !in_single && !in_double => paren_depth += 1,
326 ')' if !in_single && !in_double => {
327 if paren_depth == 0 {
328 break; }
330 paren_depth -= 1;
331 }
332 ',' if !in_single && !in_double && paren_depth == 0 => break,
333 '\n' | '\r' if !in_single && !in_double && paren_depth == 0 => break,
334 _ => {}
335 }
336 end = i + c.len_utf8();
337 }
338
339 if end == 0 {
340 Err(nom::Err::Error(nom::error::Error::new(
341 input,
342 nom::error::ErrorKind::TakeWhile1,
343 )))
344 } else {
345 Ok((&input[end..], &input[..end]))
346 }
347}
348
/// Given the text just after `check(`, returns the byte index of the matching `)`.
///
/// Parentheses inside `'...'` or `"..."` (with doubled-quote escapes) are ignored. If no
/// matching `)` exists, returns `rest.len()`.
fn check_expr_end(rest: &str) -> usize {
    let mut open_parens = 1usize;
    let (mut single, mut double) = (false, false);
    let mut chars = rest.char_indices().peekable();

    while let Some((pos, c)) = chars.next() {
        match c {
            '\'' if !double => {
                let escaped = single && chars.next_if(|&(_, next)| next == '\'').is_some();
                if !escaped {
                    single = !single;
                }
            }
            '"' if !single => {
                let escaped = double && chars.next_if(|&(_, next)| next == '"').is_some();
                if !escaped {
                    double = !double;
                }
            }
            _ if single || double => {}
            '(' => open_parens += 1,
            ')' => {
                open_parens -= 1;
                if open_parens == 0 {
                    return pos;
                }
            }
            _ => {}
        }
    }

    rest.len()
}
384
385fn parenthesized_content(input: &str) -> IResult<&str, &str> {
386 let mut paren_depth = 0usize;
387 let mut in_single = false;
388 let mut in_double = false;
389 let mut iter = input.char_indices().peekable();
390
391 while let Some((idx, ch)) = iter.next() {
392 match ch {
393 '\'' if !in_double => {
394 if in_single && matches!(iter.peek(), Some((_, '\''))) {
395 iter.next();
396 } else {
397 in_single = !in_single;
398 }
399 }
400 '"' if !in_single => {
401 if in_double && matches!(iter.peek(), Some((_, '"'))) {
402 iter.next();
403 } else {
404 in_double = !in_double;
405 }
406 }
407 '(' if !in_single && !in_double => paren_depth += 1,
408 ')' if !in_single && !in_double => {
409 if paren_depth == 0 {
410 return Ok((&input[idx + ch.len_utf8()..], &input[..idx]));
411 }
412 paren_depth -= 1;
413 }
414 _ => {}
415 }
416 }
417
418 Err(nom::Err::Error(nom::error::Error::new(
419 input,
420 nom::error::ErrorKind::Char,
421 )))
422}
423
/// Splits `input` on commas that are outside quotes and parentheses, trimming each item.
///
/// Returns `Err(())` if any of these hold:
/// - an item is empty (leading, trailing or doubled comma);
/// - a quote or parenthesis is left unbalanced;
/// - a `)` appears without a matching `(`.
///
/// Blank input yields an empty list.
fn split_top_level_csv(input: &str) -> Result<Vec<String>, ()> {
    let mut items = Vec::new();
    let mut segment_start = 0usize;
    let mut depth = 0usize;
    let (mut single, mut double) = (false, false);
    let mut chars = input.char_indices().peekable();

    while let Some((pos, c)) = chars.next() {
        let quoted = single || double;
        match c {
            '\'' if !double => {
                let escaped = single && chars.next_if(|&(_, next)| next == '\'').is_some();
                if !escaped {
                    single = !single;
                }
            }
            '"' if !single => {
                let escaped = double && chars.next_if(|&(_, next)| next == '"').is_some();
                if !escaped {
                    double = !double;
                }
            }
            '(' if !quoted => depth += 1,
            ')' if !quoted => depth = depth.checked_sub(1).ok_or(())?,
            ',' if !quoted && depth == 0 => {
                let item = input[segment_start..pos].trim();
                if item.is_empty() {
                    return Err(());
                }
                items.push(item.to_owned());
                segment_start = pos + 1;
            }
            _ => {}
        }
    }

    if single || double || depth > 0 {
        return Err(());
    }

    match input[segment_start..].trim() {
        "" if input.trim().is_empty() => {}
        "" => return Err(()),
        tail => items.push(tail.to_owned()),
    }

    Ok(items)
}
481
/// Returns `true` if `input` begins (ASCII case-insensitively) with a column-constraint keyword:
/// `primary_key`, `primary key`, `not_null`, `not null`, `unique`, `references ` or `check(`.
///
/// Only a keyword-length prefix is compared, with no allocation. `default_expr_end` calls this
/// at every top-level whitespace character, and the previous version lowercased the entire
/// remaining string on each call, which made that scan quadratic.
fn starts_constraint_keyword(input: &str) -> bool {
    const KEYWORDS: [&str; 7] = [
        "primary_key",
        "primary key",
        "not_null",
        "not null",
        "unique",
        "references ",
        "check(",
    ];

    let bytes = input.as_bytes();
    KEYWORDS.iter().any(|keyword| {
        bytes
            .get(..keyword.len())
            .is_some_and(|prefix| prefix.eq_ignore_ascii_case(keyword.as_bytes()))
    })
}
495
496fn default_expr_end(rest: &str) -> usize {
497 let mut in_single = false;
498 let mut in_double = false;
499 let mut paren_depth = 0usize;
500 let mut iter = rest.char_indices().peekable();
501
502 while let Some((idx, ch)) = iter.next() {
503 match ch {
504 '\'' if !in_double => {
505 if in_single && matches!(iter.peek(), Some((_, '\''))) {
506 iter.next();
507 } else {
508 in_single = !in_single;
509 }
510 }
511 '"' if !in_single => {
512 if in_double && matches!(iter.peek(), Some((_, '"'))) {
513 iter.next();
514 } else {
515 in_double = !in_double;
516 }
517 }
518 '(' if !in_single && !in_double => paren_depth += 1,
519 ')' if !in_single && !in_double && paren_depth > 0 => paren_depth -= 1,
520 c if c.is_whitespace()
521 && !in_single
522 && !in_double
523 && paren_depth == 0
524 && starts_constraint_keyword(rest[idx..].trim_start()) =>
525 {
526 return idx;
527 }
528 _ => {}
529 }
530 }
531
532 rest.len()
533}
534
535fn parse_column(input: &str) -> IResult<&str, ColumnDef> {
537 let (input, _) = ws_and_comments(input)?;
538 let (input, name) = identifier(input)?;
539 let (input, _) = multispace1(input)?;
540 let (input, type_info) = parse_type_info(input)?;
541
542 let (input, constraint_str) = opt(preceded(multispace1, constraint_text)).parse(input)?;
543
544 let mut col = ColumnDef {
545 name: name.to_string(),
546 typ: type_info.name,
547 is_array: type_info.is_array,
548 type_params: type_info.params,
549 is_serial: type_info.is_serial,
550 nullable: !type_info.is_serial, ..Default::default()
552 };
553
554 if let Some(constraints) = constraint_str {
555 let lower = constraints.to_lowercase();
556
557 if lower.contains("primary_key") || lower.contains("primary key") {
558 col.primary_key = true;
559 col.nullable = false;
560 }
561 if lower.contains("not_null") || lower.contains("not null") {
562 col.nullable = false;
563 }
564 if lower.contains("unique") {
565 col.unique = true;
566 }
567
568 if let Some(idx) = lower.find("references ") {
569 let rest = &constraints[idx + 11..];
570 let mut paren_depth = 0;
572 let mut end = rest.len();
573 for (i, c) in rest.char_indices() {
574 match c {
575 '(' => paren_depth += 1,
576 ')' => {
577 if paren_depth == 0 {
578 end = i;
579 break;
580 }
581 paren_depth -= 1;
582 }
583 c if c.is_whitespace() && paren_depth == 0 => {
584 end = i;
585 break;
586 }
587 _ => {}
588 }
589 }
590 col.references = Some(rest[..end].to_string());
591 }
592
593 if let Some(idx) = lower.find("default ") {
594 let rest = &constraints[idx + 8..];
595 let end = default_expr_end(rest);
596 col.default_value = Some(rest[..end].to_string());
597 }
598
599 if let Some(idx) = lower.find("check(") {
600 let rest = &constraints[idx + 6..];
601 let end = check_expr_end(rest);
602 col.check = Some(rest[..end].to_string());
603 }
604 }
605
606 Ok((input, col))
607}
608
609fn parse_column_list(input: &str) -> IResult<&str, Vec<ColumnDef>> {
611 let (input, _) = ws_and_comments(input)?;
612 let (input, _) = char('(').parse(input)?;
613 let (input, columns) = separated_list0(char(','), parse_column).parse(input)?;
614 let (input, _) = ws_and_comments(input)?;
615 let (input, _) = char(')').parse(input)?;
616
617 Ok((input, columns))
618}
619
620fn parse_table(input: &str) -> IResult<&str, TableDef> {
622 let (input, _) = ws_and_comments(input)?;
623 let (input, _) = tag_no_case("table").parse(input)?;
624 let (input, _) = multispace1(input)?;
625 let (input, name) = identifier(input)?;
626 let (input, columns) = parse_column_list(input)?;
627
628 let (input, _) = ws_and_comments(input)?;
630 let enable_rls = if let Ok((rest, _)) =
631 tag_no_case::<_, _, nom::error::Error<&str>>("enable_rls").parse(input)
632 {
633 return Ok((
634 rest,
635 TableDef {
636 name: name.to_string(),
637 columns,
638 enable_rls: true,
639 },
640 ));
641 } else {
642 false
643 };
644
645 Ok((
646 input,
647 TableDef {
648 name: name.to_string(),
649 columns,
650 enable_rls,
651 },
652 ))
653}
654
655enum SchemaItem {
661 Table(TableDef),
662 Policy(Box<RlsPolicy>),
663 Index(IndexDef),
664}
665
666fn parse_policy(input: &str) -> IResult<&str, RlsPolicy> {
678 let (input, _) = ws_and_comments(input)?;
679 let (input, _) = tag_no_case("policy").parse(input)?;
680 let (input, _) = multispace1(input)?;
681 let (input, name) = identifier(input)?;
682 let (input, _) = multispace1(input)?;
683 let (input, _) = tag_no_case("on").parse(input)?;
684 let (input, _) = multispace1(input)?;
685 let (input, table) = identifier(input)?;
686
687 let mut policy = RlsPolicy::create(name, table);
688
689 let mut remaining = input;
691 loop {
692 let (input, _) = ws_and_comments(remaining)?;
693
694 if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("for").parse(input) {
696 let (rest, _) = multispace1(rest)?;
697 let (rest, target) = alt((
698 map(tag_no_case("all"), |_| PolicyTarget::All),
699 map(tag_no_case("select"), |_| PolicyTarget::Select),
700 map(tag_no_case("insert"), |_| PolicyTarget::Insert),
701 map(tag_no_case("update"), |_| PolicyTarget::Update),
702 map(tag_no_case("delete"), |_| PolicyTarget::Delete),
703 ))
704 .parse(rest)?;
705 policy.target = target;
706 remaining = rest;
707 continue;
708 }
709
710 if let Ok((rest, _)) =
712 tag_no_case::<_, _, nom::error::Error<&str>>("restrictive").parse(input)
713 {
714 policy.permissiveness = PolicyPermissiveness::Restrictive;
715 remaining = rest;
716 continue;
717 }
718
719 if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("to").parse(input) {
721 if let Ok((rest, _)) = multispace1::<_, nom::error::Error<&str>>(rest) {
723 let (rest, role) = identifier(rest)?;
724 policy.role = Some(role.to_string());
725 remaining = rest;
726 continue;
727 }
728 }
729
730 if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("with").parse(input) {
732 let (rest, _) = multispace1(rest)?;
733 if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("check").parse(rest)
734 {
735 let (rest, _) = nom_ws0(rest)?;
736 let (rest, _) = char('(').parse(rest)?;
737 let (rest, _) = nom_ws0(rest)?;
738 let (rest, expr) = parse_policy_expr(rest)?;
739 let (rest, _) = nom_ws0(rest)?;
740 let (rest, _) = char(')').parse(rest)?;
741 policy.with_check = Some(expr);
742 remaining = rest;
743 continue;
744 }
745 }
746
747 if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("using").parse(input) {
749 let (rest, _) = nom_ws0(rest)?;
750 let (rest, _) = char('(').parse(rest)?;
751 let (rest, _) = nom_ws0(rest)?;
752 let (rest, expr) = parse_policy_expr(rest)?;
753 let (rest, _) = nom_ws0(rest)?;
754 let (rest, _) = char(')').parse(rest)?;
755 policy.using = Some(expr);
756 remaining = rest;
757 continue;
758 }
759
760 remaining = input;
762 break;
763 }
764
765 Ok((remaining, policy))
766}
767
768fn parse_policy_expr(input: &str) -> IResult<&str, Expr> {
777 let (input, first) = parse_policy_comparison(input)?;
778
779 let mut result = first;
781 let mut remaining = input;
782 loop {
783 let (input, _) = nom_ws0(remaining)?;
784
785 if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("or").parse(input)
786 && let Ok((rest, _)) = multispace1::<_, nom::error::Error<&str>>(rest)
787 {
788 let (rest, right) = parse_policy_comparison(rest)?;
789 result = Expr::Binary {
790 left: Box::new(result),
791 op: BinaryOp::Or,
792 right: Box::new(right),
793 alias: None,
794 };
795 remaining = rest;
796 continue;
797 }
798
799 if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("and").parse(input)
800 && let Ok((rest, _)) = multispace1::<_, nom::error::Error<&str>>(rest)
801 {
802 let (rest, right) = parse_policy_comparison(rest)?;
803 result = Expr::Binary {
804 left: Box::new(result),
805 op: BinaryOp::And,
806 right: Box::new(right),
807 alias: None,
808 };
809 remaining = rest;
810 continue;
811 }
812
813 remaining = input;
814 break;
815 }
816
817 Ok((remaining, result))
818}
819
820fn parse_policy_comparison(input: &str) -> IResult<&str, Expr> {
822 let (input, left) = parse_policy_atom(input)?;
823 let (input, _) = nom_ws0(input)?;
824
825 if let Ok((rest, op)) = parse_cmp_op(input) {
827 let (rest, _) = nom_ws0(rest)?;
828 let (rest, right) = parse_policy_atom(rest)?;
829 return Ok((
830 rest,
831 Expr::Binary {
832 left: Box::new(left),
833 op,
834 right: Box::new(right),
835 alias: None,
836 },
837 ));
838 }
839
840 Ok((input, left))
842}
843
844fn parse_cmp_op(input: &str) -> IResult<&str, BinaryOp> {
846 alt((
847 map(tag(">="), |_| BinaryOp::Gte),
848 map(tag("<="), |_| BinaryOp::Lte),
849 map(tag("<>"), |_| BinaryOp::Ne),
850 map(tag("!="), |_| BinaryOp::Ne),
851 map(tag("="), |_| BinaryOp::Eq),
852 map(tag(">"), |_| BinaryOp::Gt),
853 map(tag("<"), |_| BinaryOp::Lt),
854 ))
855 .parse(input)
856}
857
858fn parse_policy_atom(input: &str) -> IResult<&str, Expr> {
866 alt((
867 parse_policy_grouped,
868 parse_policy_bool,
869 parse_policy_string,
870 parse_policy_number,
871 parse_policy_func_or_ident, ))
873 .parse(input)
874}
875
876fn parse_policy_grouped(input: &str) -> IResult<&str, Expr> {
878 let (input, _) = char('(').parse(input)?;
879 let (input, _) = nom_ws0(input)?;
880 let (input, expr) = parse_policy_expr(input)?;
881 let (input, _) = nom_ws0(input)?;
882 let (input, _) = char(')').parse(input)?;
883 Ok((input, expr))
884}
885
886fn parse_policy_bool(input: &str) -> IResult<&str, Expr> {
888 alt((
889 map(tag_no_case("true"), |_| Expr::Literal(AstValue::Bool(true))),
890 map(tag_no_case("false"), |_| {
891 Expr::Literal(AstValue::Bool(false))
892 }),
893 ))
894 .parse(input)
895}
896
897fn parse_policy_string(input: &str) -> IResult<&str, Expr> {
899 let (input, _) = char('\'').parse(input)?;
900 let mut end = 0;
901 for (i, c) in input.char_indices() {
902 if c == '\'' {
903 end = i;
904 break;
905 }
906 }
907 let content = &input[..end];
908 let rest = &input[end + 1..];
909 Ok((rest, Expr::Literal(AstValue::String(content.to_string()))))
910}
911
912fn parse_policy_number(input: &str) -> IResult<&str, Expr> {
914 let original = input;
915 let (input, digits) = take_while1(|c: char| c.is_ascii_digit() || c == '.')(input)?;
916 if digits.starts_with('.') || digits.is_empty() {
918 return Err(nom::Err::Error(nom::error::Error::new(
919 original,
920 nom::error::ErrorKind::Digit,
921 )));
922 }
923
924 if !digits.contains('.') {
925 return digits
926 .parse::<i64>()
927 .map(|n| (input, Expr::Literal(AstValue::Int(n))))
928 .map_err(|_| {
929 nom::Err::Error(nom::error::Error::new(
930 original,
931 nom::error::ErrorKind::Digit,
932 ))
933 });
934 }
935
936 if digits.matches('.').count() > 1 || policy_number_significant_digits(digits) > 15 {
937 return Err(nom::Err::Error(nom::error::Error::new(
938 original,
939 nom::error::ErrorKind::Float,
940 )));
941 }
942
943 if let Ok(f) = digits.parse::<f64>() {
944 if f.is_finite() {
945 Ok((input, Expr::Literal(AstValue::Float(f))))
946 } else {
947 Err(nom::Err::Error(nom::error::Error::new(
948 original,
949 nom::error::ErrorKind::Float,
950 )))
951 }
952 } else {
953 Err(nom::Err::Error(nom::error::Error::new(
954 original,
955 nom::error::ErrorKind::Float,
956 )))
957 }
958}
959
/// Counts significant digits in `value`.
///
/// Only ASCII digits are counted, starting from the first non-zero digit; leading zeros are
/// skipped and later zeros are counted. Non-digit bytes such as `.` are ignored.
fn policy_number_significant_digits(value: &str) -> usize {
    value
        .bytes()
        .filter(u8::is_ascii_digit)
        .skip_while(|&digit| digit == b'0')
        .count()
}
978
979fn parse_policy_func_or_ident(input: &str) -> IResult<&str, Expr> {
981 let original = input;
982 let (input, name) = identifier(input)?;
983 if name
984 .bytes()
985 .next()
986 .is_some_and(|byte| byte.is_ascii_digit())
987 {
988 return Err(nom::Err::Error(nom::error::Error::new(
989 original,
990 nom::error::ErrorKind::Alpha,
991 )));
992 }
993
994 let mut expr = if let Ok((rest, _)) = char::<_, nom::error::Error<&str>>('(').parse(input) {
996 let (rest, _) = nom_ws0(rest)?;
998 let (rest, args) =
999 separated_list0((nom_ws0, char(','), nom_ws0), parse_policy_atom).parse(rest)?;
1000 let (rest, _) = nom_ws0(rest)?;
1001 let (rest, _) = char(')').parse(rest)?;
1002 let input = rest;
1003 (
1004 input,
1005 Expr::FunctionCall {
1006 name: name.to_string(),
1007 args,
1008 alias: None,
1009 },
1010 )
1011 } else {
1012 (input, Expr::Named(name.to_string()))
1013 };
1014
1015 if let Ok((rest, _)) = tag::<_, _, nom::error::Error<&str>>("::").parse(expr.0) {
1017 let (rest, cast_type) = identifier(rest)?;
1018 expr = (
1019 rest,
1020 Expr::Cast {
1021 expr: Box::new(expr.1),
1022 target_type: cast_type.to_string(),
1023 alias: None,
1024 },
1025 );
1026 }
1027
1028 Ok(expr)
1029}
1030
1031fn parse_schema_item(input: &str) -> IResult<&str, SchemaItem> {
1033 let (input, _) = ws_and_comments(input)?;
1034
1035 if let Ok((rest, policy)) = parse_policy(input) {
1037 return Ok((rest, SchemaItem::Policy(Box::new(policy))));
1038 }
1039
1040 if let Ok((rest, idx)) = parse_index(input) {
1042 return Ok((rest, SchemaItem::Index(idx)));
1043 }
1044
1045 let (rest, table) = parse_table(input)?;
1047 Ok((rest, SchemaItem::Table(table)))
1048}
1049
1050fn parse_index(input: &str) -> IResult<&str, IndexDef> {
1052 let (input, _) = tag_no_case("index")(input)?;
1053 let (input, _) = multispace1(input)?;
1054 let (input, name) = take_while1(|c: char| c.is_alphanumeric() || c == '_')(input)?;
1055 let (input, _) = multispace1(input)?;
1056 let (input, _) = tag_no_case("on")(input)?;
1057 let (input, _) = multispace1(input)?;
1058 let (input, table) = take_while1(|c: char| c.is_alphanumeric() || c == '_')(input)?;
1059 let (input, _) = nom_ws0(input)?;
1060 let (input, _) = char('(')(input)?;
1061 let (input, cols_str) = parenthesized_content(input)?;
1062 let (input, _) = nom_ws0(input)?;
1063 let (input, unique_tag) = opt(tag_no_case("unique")).parse(input)?;
1064
1065 let columns = split_top_level_csv(cols_str).map_err(|_| {
1066 nom::Err::Error(nom::error::Error::new(
1067 cols_str,
1068 nom::error::ErrorKind::SeparatedList,
1069 ))
1070 })?;
1071 if columns.is_empty() {
1072 return Err(nom::Err::Error(nom::error::Error::new(
1073 cols_str,
1074 nom::error::ErrorKind::SeparatedList,
1075 )));
1076 }
1077
1078 let is_unique = unique_tag.is_some();
1079
1080 Ok((
1081 input,
1082 IndexDef {
1083 name: name.to_string(),
1084 table: table.to_string(),
1085 columns,
1086 unique: is_unique,
1087 },
1088 ))
1089}
1090
1091fn parse_schema(input: &str) -> IResult<&str, Schema> {
1093 let version = extract_version_directive(input);
1095
1096 let (input, items) = many0(parse_schema_item).parse(input)?;
1097 let (input, _) = ws_and_comments(input)?;
1098
1099 let mut tables = Vec::new();
1100 let mut policies = Vec::new();
1101 let mut indexes = Vec::new();
1102 for item in items {
1103 match item {
1104 SchemaItem::Table(t) => tables.push(t),
1105 SchemaItem::Policy(p) => policies.push(*p),
1106 SchemaItem::Index(i) => indexes.push(i),
1107 }
1108 }
1109
1110 Ok((
1111 input,
1112 Schema {
1113 version,
1114 tables,
1115 policies,
1116 indexes,
1117 },
1118 ))
1119}
1120
/// Finds the first line of the form `-- qail: version=<N>` (surrounding whitespace allowed)
/// and parses `N`.
///
/// Returns `None` if there is no such line, or if the first such line has an unparsable value.
/// Later lines are not consulted in that case.
fn extract_version_directive(input: &str) -> Option<u32> {
    input
        .lines()
        .map(str::trim)
        .filter_map(|line| line.strip_prefix("-- qail:"))
        .find_map(|directive| directive.trim().strip_prefix("version="))
        .and_then(|value| value.trim().parse().ok())
}
1134
1135#[cfg(test)]
1136mod tests {
1137 use super::*;
1138
1139 #[test]
1140 fn test_parse_simple_table() {
1141 let input = r#"
1142 table users (
1143 id uuid primary_key,
1144 email text not null,
1145 name text
1146 )
1147 "#;
1148
1149 let schema = Schema::parse(input).expect("parse failed");
1150 assert_eq!(schema.tables.len(), 1);
1151
1152 let users = &schema.tables[0];
1153 assert_eq!(users.name, "users");
1154 assert_eq!(users.columns.len(), 3);
1155
1156 let id = &users.columns[0];
1157 assert_eq!(id.name, "id");
1158 assert_eq!(id.typ, "uuid");
1159 assert!(id.primary_key);
1160 assert!(!id.nullable);
1161
1162 let email = &users.columns[1];
1163 assert_eq!(email.name, "email");
1164 assert!(!email.nullable);
1165
1166 let name = &users.columns[2];
1167 assert!(name.nullable);
1168 }
1169
1170 #[test]
1171 fn test_parse_multiple_tables() {
1172 let input = r#"
1173 -- Users table
1174 table users (
1175 id uuid primary_key,
1176 email text not null unique
1177 )
1178
1179 -- Orders table
1180 table orders (
1181 id uuid primary_key,
1182 user_id uuid references users(id),
1183 total i64 not null default 0
1184 )
1185 "#;
1186
1187 let schema = Schema::parse(input).expect("parse failed");
1188 assert_eq!(schema.tables.len(), 2);
1189
1190 let orders = schema.find_table("orders").expect("orders not found");
1191 let user_id = orders.find_column("user_id").expect("user_id not found");
1192 assert_eq!(user_id.references, Some("users(id)".to_string()));
1193
1194 let total = orders.find_column("total").expect("total not found");
1195 assert_eq!(total.default_value, Some("0".to_string()));
1196 }
1197
1198 #[test]
1199 fn test_parse_comments() {
1200 let input = r#"
1201 -- This is a comment
1202 table foo (
1203 bar text
1204 )
1205 "#;
1206
1207 let schema = Schema::parse(input).expect("parse failed");
1208 assert_eq!(schema.tables.len(), 1);
1209 }
1210
1211 #[test]
1212 fn test_array_types() {
1213 let input = r#"
1214 table products (
1215 id uuid primary_key,
1216 tags text[],
1217 prices decimal[]
1218 )
1219 "#;
1220
1221 let schema = Schema::parse(input).expect("parse failed");
1222 let products = &schema.tables[0];
1223
1224 let tags = products.find_column("tags").expect("tags not found");
1225 assert_eq!(tags.typ, "text");
1226 assert!(tags.is_array);
1227
1228 let prices = products.find_column("prices").expect("prices not found");
1229 assert!(prices.is_array);
1230 }
1231
1232 #[test]
1233 fn test_type_params() {
1234 let input = r#"
1235 table items (
1236 id serial primary_key,
1237 name varchar(255) not null,
1238 price decimal(10,2),
1239 code varchar(50) unique
1240 )
1241 "#;
1242
1243 let schema = Schema::parse(input).expect("parse failed");
1244 let items = &schema.tables[0];
1245
1246 let id = items.find_column("id").expect("id not found");
1247 assert!(id.is_serial);
1248 assert!(!id.nullable); let name = items.find_column("name").expect("name not found");
1251 assert_eq!(name.typ, "varchar");
1252 assert_eq!(name.type_params, Some(vec!["255".to_string()]));
1253
1254 let price = items.find_column("price").expect("price not found");
1255 assert_eq!(
1256 price.type_params,
1257 Some(vec!["10".to_string(), "2".to_string()])
1258 );
1259
1260 let code = items.find_column("code").expect("code not found");
1261 assert!(code.unique);
1262 }
1263
1264 #[test]
1265 fn test_custom_type_names_with_underscores_and_schema() {
1266 let input = r#"
1267 table bookings (
1268 id uuid primary_key,
1269 status booking_status not null,
1270 gateway_state integrations.payment_state[]
1271 )
1272 "#;
1273
1274 let schema = Schema::parse(input).expect("parse failed");
1275 let bookings = &schema.tables[0];
1276
1277 let status = bookings.find_column("status").expect("status not found");
1278 assert_eq!(status.typ, "booking_status");
1279 assert!(!status.nullable);
1280
1281 let gateway_state = bookings
1282 .find_column("gateway_state")
1283 .expect("gateway_state not found");
1284 assert_eq!(gateway_state.typ, "integrations.payment_state");
1285 assert!(gateway_state.is_array);
1286 }
1287
1288 #[test]
1289 fn test_malformed_type_params_return_parse_error_without_panic() {
1290 let input = "table invoices ( amount decimal(";
1291
1292 let result = std::panic::catch_unwind(|| Schema::parse(input));
1293
1294 assert!(result.is_ok());
1295 assert!(result.unwrap().is_err());
1296 }
1297
1298 #[test]
1299 fn test_check_constraint() {
1300 let input = r#"
1301 table employees (
1302 id uuid primary_key,
1303 age i32 check(age >= 18),
1304 salary decimal check(salary > 0)
1305 )
1306 "#;
1307
1308 let schema = Schema::parse(input).expect("parse failed");
1309 let employees = &schema.tables[0];
1310
1311 let age = employees.find_column("age").expect("age not found");
1312 assert_eq!(age.check, Some("age >= 18".to_string()));
1313
1314 let salary = employees.find_column("salary").expect("salary not found");
1315 assert_eq!(salary.check, Some("salary > 0".to_string()));
1316 }
1317
1318 #[test]
1319 fn test_default_expression_with_spaces() {
1320 let input = r#"
1321 table messages (
1322 id uuid primary_key,
1323 title text default 'new user' not null,
1324 expires_at timestamp default (now() + interval '1 day')
1325 )
1326 "#;
1327
1328 let schema = Schema::parse(input).expect("parse failed");
1329 let messages = &schema.tables[0];
1330
1331 let title = messages.find_column("title").expect("title not found");
1332 assert_eq!(title.default_value, Some("'new user'".to_string()));
1333 assert!(!title.nullable);
1334
1335 let expires_at = messages
1336 .find_column("expires_at")
1337 .expect("expires_at not found");
1338 assert_eq!(
1339 expires_at.default_value,
1340 Some("(now() + interval '1 day')".to_string())
1341 );
1342 }
1343
/// Commas and parentheses inside single-quoted literals must neither split a
/// column definition nor close a constraint early.
#[test]
fn test_constraints_handle_quoted_commas_and_parens() {
    let parsed = Schema::parse(
        r#"
table messages (
id uuid primary_key,
title text default 'hello, world' not null,
tag text check(tag in ('a,b', 'c)')),
note text default 'paren ) and comma, still literal'
)
"#,
    )
    .expect("parse failed");
    let table = &parsed.tables[0];

    // A count of four shows the quoted commas were not treated as column separators.
    assert_eq!(table.columns.len(), 4);

    let title_col = table.find_column("title").expect("title not found");
    assert_eq!(title_col.default_value.as_deref(), Some("'hello, world'"));
    assert!(!title_col.nullable);

    let tag_col = table.find_column("tag").expect("tag not found");
    assert_eq!(tag_col.check.as_deref(), Some("tag in ('a,b', 'c)')"));

    let note_col = table.find_column("note").expect("note not found");
    assert_eq!(
        note_col.default_value.as_deref(),
        Some("'paren ) and comma, still literal'")
    );
}
1372
/// Expression index columns containing nested calls, commas and quoted `)`
/// are split only on top-level commas and round-trip through `to_sql`.
#[test]
fn test_index_columns_handle_nested_expression_commas() {
    let input = r#"
table docs (
id uuid primary_key,
title text,
slug text
)

index idx_docs_search on docs (regexp_replace(title, ')', '', 'g'), lower(slug)) unique
"#;
    let expected_sql = "CREATE UNIQUE INDEX IF NOT EXISTS idx_docs_search ON docs (regexp_replace(title, ')', '', 'g'), lower(slug))";

    let parsed = Schema::parse(input).expect("parse failed");
    assert_eq!(parsed.indexes.len(), 1);

    let idx = &parsed.indexes[0];
    assert_eq!(idx.name, "idx_docs_search");
    // The quoted `)` and the commas inside `regexp_replace(...)` stay in column one.
    assert_eq!(
        idx.columns,
        ["regexp_replace(title, ')', '', 'g')", "lower(slug)"]
    );
    assert!(idx.unique);
    assert_eq!(idx.to_sql(), expected_sql);
}
1402
/// Index column lists with empty entries (`()`, or leading, trailing or doubled
/// commas) must be rejected with a parse error.
#[test]
fn test_index_rejects_empty_columns() {
    // Declare the indexed table so that the column list is the only defect in
    // each failing input. Otherwise an error could just as well stem from the
    // index pointing at an undeclared table.
    const TABLE: &str = r#"
table docs (
id uuid primary_key,
title text,
slug text
)
"#;

    // Control: the same table plus a well-formed column list parses.
    let control = format!("{TABLE}\nindex idx_docs_search on docs (title)");
    let schema = Schema::parse(&control)
        .unwrap_or_else(|e| panic!("well-formed index should parse: {e}"));
    assert_eq!(schema.indexes.len(), 1);

    for columns in ["()", "(title,)", "(,title)", "(title,,slug)"] {
        let input = format!("{TABLE}\nindex idx_docs_search on docs {columns}");
        let err = match Schema::parse(&input) {
            Ok(parsed) => {
                panic!("empty index columns should fail for {columns}, got {parsed:?}")
            }
            Err(err) => err,
        };
        assert!(
            err.contains("Parse error") || err.contains("Unexpected content"),
            "{columns}: {err}"
        );
    }
}
1418
/// A `-- qail: version=N` header sets `Schema::version`; without it the
/// version stays `None`.
#[test]
fn test_version_directive() {
    let versioned = Schema::parse(
        r#"
-- qail: version=1
table users (
id uuid primary_key
)
"#,
    )
    .expect("parse failed");
    assert_eq!(versioned.version, Some(1));
    // The directive must not swallow the table that follows it.
    assert_eq!(versioned.tables.len(), 1);

    let unversioned = Schema::parse(
        r#"
table items (
id uuid primary_key
)
"#,
    )
    .expect("parse failed");
    assert_eq!(unversioned.version, None);
}
1441
/// A trailing `enable_rls` after the column list flags the table for RLS.
#[test]
fn test_enable_rls_table() {
    let parsed = Schema::parse(
        r#"
table orders (
id uuid primary_key,
tenant_id uuid not null
) enable_rls
"#,
    )
    .expect("parse failed");

    let tables = &parsed.tables;
    assert_eq!(tables.len(), 1);
    assert!(tables[0].enable_rls);
}
1459
/// A `for all` policy parses its metadata. Its `using` clause becomes
/// `Binary(Eq)` with a `Named` left side and a `Cast` wrapping a
/// `FunctionCall` on the right.
#[test]
fn test_parse_policy_basic() {
    let input = r#"
table orders (
id uuid primary_key,
tenant_id uuid not null
) enable_rls

policy orders_isolation on orders
for all
using (tenant_id = current_setting('app.current_tenant_id')::uuid)
"#;

    let schema = Schema::parse(input).expect("parse failed");
    assert_eq!(schema.tables.len(), 1);
    assert_eq!(schema.policies.len(), 1);

    let policy = &schema.policies[0];
    assert_eq!(policy.name, "orders_isolation");
    assert_eq!(policy.table, "orders");
    assert_eq!(policy.target, PolicyTarget::All);
    assert!(policy.using.is_some());

    // Walk the expression tree top-down, checking each node's shape in turn.
    let using = policy.using.as_ref().unwrap();
    let (left, right) = match using {
        Expr::Binary {
            left, op, right, ..
        } => {
            assert_eq!(*op, BinaryOp::Eq);
            (left, right)
        }
        other => panic!("Expected Binary, got {other:?}"),
    };

    match left.as_ref() {
        Expr::Named(n) => assert_eq!(n, "tenant_id"),
        other => panic!("Expected Named, got {other:?}"),
    }

    let cast_inner = match right.as_ref() {
        Expr::Cast {
            target_type, expr, ..
        } => {
            assert_eq!(target_type, "uuid");
            expr
        }
        other => panic!("Expected Cast, got {other:?}"),
    };

    match cast_inner.as_ref() {
        Expr::FunctionCall { name, args, .. } => {
            assert_eq!(name, "current_setting");
            assert_eq!(args.len(), 1);
        }
        other => panic!("Expected FunctionCall, got {other:?}"),
    }
}
1514
/// An insert policy declared with `with check` carries that clause and no `using`.
#[test]
fn test_parse_policy_with_check() {
    const SOURCE: &str = r#"
table orders (
id uuid primary_key
)

policy orders_write on orders
for insert
with check (tenant_id = current_setting('app.current_tenant_id')::uuid)
"#;

    let parsed = Schema::parse(SOURCE).expect("parse failed");
    let write_policy = &parsed.policies[0];

    assert_eq!(write_policy.target, PolicyTarget::Insert);
    assert!(write_policy.with_check.is_some());
    assert!(write_policy.using.is_none());
}
1533
/// The `restrictive` and `to <role>` modifiers are recorded alongside the
/// policy target and `using` clause.
#[test]
fn test_parse_policy_restrictive_with_role() {
    let parsed = Schema::parse(
        r#"
table secrets (
id uuid primary_key
)

policy admin_only on secrets
for select
restrictive
to app_user
using (current_setting('app.is_super_admin')::boolean = true)
"#,
    )
    .expect("parse failed");

    let admin_policy = &parsed.policies[0];
    assert_eq!(admin_policy.target, PolicyTarget::Select);
    assert_eq!(
        admin_policy.permissiveness,
        PolicyPermissiveness::Restrictive
    );
    assert_eq!(admin_policy.role.as_deref(), Some("app_user"));
    assert!(admin_policy.using.is_some());
}
1555
/// `or` binds looser than `=`, so the whole `using` clause parses as one
/// top-level `Binary(Or)` node.
#[test]
fn test_parse_policy_or_expr() {
    let input = r#"
table orders (
id uuid primary_key
)

policy tenant_or_admin on orders
for all
using (tenant_id = current_setting('app.current_tenant_id')::uuid or current_setting('app.is_super_admin')::boolean = true)
"#;

    let parsed = Schema::parse(input).expect("parse failed");
    let policy = &parsed.policies[0];

    let top_level_is_or = matches!(
        policy.using.as_ref(),
        Some(Expr::Binary {
            op: BinaryOp::Or,
            ..
        })
    );
    assert!(top_level_is_or, "Expected Binary OR, got {:?}", policy.using);
}
1583
/// Numeric literals in policy expressions that cannot be represented are
/// rejected. The cases are an integer far beyond any machine width, a
/// malformed `1.2.3`, and a decimal above 2^53 that `f64` cannot hold exactly.
#[test]
fn test_parse_policy_rejects_invalid_numeric_literals() {
    // Every case shares this schema; only the literal on the right of `=` varies.
    let schema_with_literal = |literal: &str| {
        format!(
            r#"
table orders (
id uuid primary_key,
amount numeric
)

policy amount_guard on orders
for select
using (amount = {literal})
"#
        )
    };

    // Control: an ordinary literal in the same template parses. This shows the
    // failures below are caused by the literal and not by the surrounding schema.
    Schema::parse(&schema_with_literal("100"))
        .unwrap_or_else(|e| panic!("control literal should parse: {e}"));

    for literal in [
        "999999999999999999999999999999999999999999999999999999999999999999",
        "1.2.3",
        "9007199254740993.25",
    ] {
        assert!(
            Schema::parse(&schema_with_literal(literal)).is_err(),
            "numeric literal {literal} should be rejected"
        );
    }
}
1625
/// Rendering a schema with an RLS-enabled table and one policy emits the table
/// DDL, the ENABLE/FORCE row-level-security statements and the policy DDL.
#[test]
fn test_schema_to_sql() {
    let input = r#"
table orders (
id uuid primary_key,
tenant_id uuid not null
) enable_rls

policy orders_isolation on orders
for all
using (tenant_id = current_setting('app.current_tenant_id')::uuid)
"#;

    let schema = Schema::parse(input).expect("parse failed");
    let rendered = schema.to_sql();

    let required_fragments = [
        "CREATE TABLE IF NOT EXISTS",
        "ENABLE ROW LEVEL SECURITY",
        "FORCE ROW LEVEL SECURITY",
        "CREATE POLICY",
        "orders_isolation",
        "FOR ALL",
    ];
    for fragment in required_fragments {
        assert!(
            rendered.contains(fragment),
            "expected {fragment:?} in generated SQL:\n{rendered}"
        );
    }
}
1648
/// Several policies on the same table are all collected, in declaration order.
#[test]
fn test_multiple_policies() {
    let input = r#"
table orders (
id uuid primary_key,
tenant_id uuid not null
) enable_rls

policy orders_read on orders
for select
using (tenant_id = current_setting('app.current_tenant_id')::uuid)

policy orders_write on orders
for insert
with check (tenant_id = current_setting('app.current_tenant_id')::uuid)
"#;

    let schema = Schema::parse(input).expect("parse failed");

    // Destructuring enforces that exactly two policies were produced.
    let [read, write] = schema.policies.as_slice() else {
        panic!(
            "expected exactly two policies, got {}",
            schema.policies.len()
        );
    };
    assert_eq!(read.name, "orders_read");
    assert_eq!(read.target, PolicyTarget::Select);
    assert_eq!(write.name, "orders_write");
    assert_eq!(write.target, PolicyTarget::Insert);
}
1673}