1use nom::{
18 IResult, Parser,
19 branch::alt,
20 bytes::complete::{tag, tag_no_case, take_while1},
21 character::complete::{char, multispace0 as nom_ws0, multispace1, not_line_ending},
22 combinator::{map, opt},
23 multi::{many0, separated_list0},
24 sequence::preceded,
25};
26
27use crate::ast::{BinaryOp, Expr, Value as AstValue};
28use crate::migrate::alter::AlterTable;
29use crate::migrate::policy::{PolicyPermissiveness, PolicyTarget, RlsPolicy};
30use crate::transpiler::policy::{alter_table_sql, create_policy_sql};
31
32#[derive(Debug, Clone, Default)]
34pub struct Schema {
35 pub version: Option<u32>,
37 pub tables: Vec<TableDef>,
39 pub policies: Vec<RlsPolicy>,
41 pub indexes: Vec<IndexDef>,
43}
44
/// An index declared with `index <name> on <table> (<columns>) [unique]`.
#[derive(Debug, Clone)]
pub struct IndexDef {
    /// Index name.
    pub name: String,
    /// Table the index is created on.
    pub table: String,
    /// Column names or expressions, emitted verbatim.
    pub columns: Vec<String>,
    /// Whether the index enforces uniqueness.
    pub unique: bool,
}

impl IndexDef {
    /// Renders an idempotent `CREATE [UNIQUE] INDEX IF NOT EXISTS` statement
    /// (no trailing semicolon).
    pub fn to_sql(&self) -> String {
        let modifier = match self.unique {
            true => " UNIQUE",
            false => "",
        };
        let column_list = self.columns.join(", ");
        format!(
            "CREATE{modifier} INDEX IF NOT EXISTS {name} ON {table} ({column_list})",
            name = self.name,
            table = self.table,
        )
    }
}
71
72#[derive(Debug, Clone)]
74pub struct TableDef {
75 pub name: String,
77 pub columns: Vec<ColumnDef>,
79 pub enable_rls: bool,
81}
82
/// A single column of a [`TableDef`].
#[derive(Debug, Clone)]
pub struct ColumnDef {
    /// Column name as written in the schema.
    pub name: String,
    /// Base type name (the parser stores it lowercased), e.g. `varchar`.
    pub typ: String,
    /// `true` when the type carried a `[]` suffix.
    pub is_array: bool,
    /// Type parameters, e.g. `["10", "2"]` for `decimal(10,2)`.
    pub type_params: Option<Vec<String>>,
    /// Whether the column accepts NULL; `true` unless a constraint says otherwise.
    pub nullable: bool,
    /// Column has a `primary_key` / `primary key` constraint.
    pub primary_key: bool,
    /// Column has a `unique` constraint.
    pub unique: bool,
    /// Raw `REFERENCES` target, e.g. `users(id)`.
    pub references: Option<String>,
    /// Raw SQL text of the `DEFAULT` expression.
    pub default_value: Option<String>,
    /// Raw SQL text inside `CHECK(...)`.
    pub check: Option<String>,
    /// Type is `serial`, `bigserial` or `smallserial`.
    pub is_serial: bool,
}

impl Default for ColumnDef {
    /// An unnamed, untyped, nullable column with no constraints.
    fn default() -> Self {
        Self {
            nullable: true,
            name: String::new(),
            typ: String::new(),
            type_params: None,
            is_array: false,
            is_serial: false,
            primary_key: false,
            unique: false,
            references: None,
            default_value: None,
            check: None,
        }
    }
}
127
128impl Schema {
129 pub fn parse(input: &str) -> Result<Self, String> {
131 match parse_schema(input) {
132 Ok(("", schema)) => Ok(schema),
133 Ok((remaining, _)) => Err(format!("Unexpected content: '{}'", remaining.trim())),
134 Err(e) => Err(format!("Parse error: {:?}", e)),
135 }
136 }
137
138 pub fn find_table(&self, name: &str) -> Option<&TableDef> {
140 self.tables
141 .iter()
142 .find(|t| t.name.eq_ignore_ascii_case(name))
143 }
144
145 pub fn to_sql(&self) -> String {
147 let mut parts = Vec::new();
148
149 for table in &self.tables {
150 parts.push(table.to_ddl());
151
152 if table.enable_rls {
153 let alter = AlterTable::new(&table.name).enable_rls().force_rls();
154 for stmt in alter_table_sql(&alter) {
155 parts.push(stmt);
156 }
157 }
158 }
159
160 for idx in &self.indexes {
161 parts.push(idx.to_sql());
162 }
163
164 for policy in &self.policies {
165 parts.push(create_policy_sql(policy));
166 }
167
168 parts.join(";\n\n") + ";"
169 }
170
171 pub fn from_file(path: &std::path::Path) -> Result<Self, String> {
173 let content =
174 std::fs::read_to_string(path).map_err(|e| format!("Failed to read file: {}", e))?;
175 Self::parse(&content)
176 }
177}
178
179impl TableDef {
180 pub fn find_column(&self, name: &str) -> Option<&ColumnDef> {
182 self.columns
183 .iter()
184 .find(|c| c.name.eq_ignore_ascii_case(name))
185 }
186
187 pub fn to_ddl(&self) -> String {
189 let mut sql = format!("CREATE TABLE IF NOT EXISTS {} (\n", self.name);
190
191 let mut col_defs = Vec::new();
192 for col in &self.columns {
193 let mut line = format!(" {}", col.name);
194
195 let mut typ = col.typ.to_uppercase();
197 if let Some(params) = &col.type_params {
198 typ = format!("{}({})", typ, params.join(", "));
199 }
200 if col.is_array {
201 typ.push_str("[]");
202 }
203 line.push_str(&format!(" {}", typ));
204
205 if col.primary_key {
207 line.push_str(" PRIMARY KEY");
208 }
209 if !col.nullable && !col.primary_key && !col.is_serial {
210 line.push_str(" NOT NULL");
211 }
212 if col.unique && !col.primary_key {
213 line.push_str(" UNIQUE");
214 }
215 if let Some(ref default) = col.default_value {
216 line.push_str(&format!(" DEFAULT {}", default));
217 }
218 if let Some(ref refs) = col.references {
219 line.push_str(&format!(" REFERENCES {}", refs));
220 }
221 if let Some(ref check) = col.check {
222 line.push_str(&format!(" CHECK({})", check));
223 }
224
225 col_defs.push(line);
226 }
227
228 sql.push_str(&col_defs.join(",\n"));
229 sql.push_str("\n)");
230 sql
231 }
232}
233
234fn identifier(input: &str) -> IResult<&str, &str> {
240 take_while1(|c: char| c.is_alphanumeric() || c == '_')(input)
241}
242
243fn ws_and_comments(input: &str) -> IResult<&str, ()> {
245 let (input, _) = many0(alt((
246 map(multispace1, |_| ()),
247 map((tag("--"), not_line_ending), |_| ()),
248 map((tag("#"), not_line_ending), |_| ()),
249 )))
250 .parse(input)?;
251 Ok((input, ()))
252}
253
/// Column type information extracted by `parse_type_info`.
struct TypeInfo {
    /// Lowercased type name, possibly schema-qualified (`integrations.payment_state`).
    name: String,
    /// Comma-separated parameters from `type(a, b)`, each trimmed.
    params: Option<Vec<String>>,
    /// Whether a `[]` suffix followed the type.
    is_array: bool,
    /// Whether the type is one of the `serial` family.
    is_serial: bool,
}
260
261fn parse_type_info(input: &str) -> IResult<&str, TypeInfo> {
264 let (input, type_name) =
265 take_while1(|c: char| c.is_alphanumeric() || c == '_' || c == '.').parse(input)?;
266
267 let (input, params) = if let Some(after_open) = input.strip_prefix('(') {
268 let Some(paren_end) = after_open.find(')') else {
269 return Err(nom::Err::Error(nom::error::Error::new(
270 input,
271 nom::error::ErrorKind::Char,
272 )));
273 };
274 let param_str = &after_open[..paren_end];
275 let params: Vec<String> = param_str.split(',').map(|s| s.trim().to_string()).collect();
276 (&after_open[paren_end + 1..], Some(params))
277 } else {
278 (input, None)
279 };
280
281 let (input, is_array) = if let Some(stripped) = input.strip_prefix("[]") {
282 (stripped, true)
283 } else {
284 (input, false)
285 };
286
287 let lower = type_name.to_lowercase();
288 let is_serial = lower == "serial" || lower == "bigserial" || lower == "smallserial";
289
290 Ok((
291 input,
292 TypeInfo {
293 name: lower,
294 params,
295 is_array,
296 is_serial,
297 },
298 ))
299}
300
301fn constraint_text(input: &str) -> IResult<&str, &str> {
303 let mut paren_depth = 0;
304 let mut in_single = false;
305 let mut in_double = false;
306 let mut end = 0;
307 let mut iter = input.char_indices().peekable();
308
309 while let Some((i, c)) = iter.next() {
310 match c {
311 '\'' if !in_double => {
312 if in_single && matches!(iter.peek(), Some((_, '\''))) {
313 iter.next();
314 } else {
315 in_single = !in_single;
316 }
317 }
318 '"' if !in_single => {
319 if in_double && matches!(iter.peek(), Some((_, '"'))) {
320 iter.next();
321 } else {
322 in_double = !in_double;
323 }
324 }
325 '(' if !in_single && !in_double => paren_depth += 1,
326 ')' if !in_single && !in_double => {
327 if paren_depth == 0 {
328 break; }
330 paren_depth -= 1;
331 }
332 ',' if !in_single && !in_double && paren_depth == 0 => break,
333 '\n' | '\r' if !in_single && !in_double && paren_depth == 0 => break,
334 _ => {}
335 }
336 end = i + c.len_utf8();
337 }
338
339 if end == 0 {
340 Err(nom::Err::Error(nom::error::Error::new(
341 input,
342 nom::error::ErrorKind::TakeWhile1,
343 )))
344 } else {
345 Ok((&input[end..], &input[..end]))
346 }
347}
348
/// Given the text right after `check(`, returns the byte offset of the `)`
/// that closes it.
///
/// Nested parentheses and quoted literals are skipped; a doubled quote inside
/// a literal is an escaped quote. Returns `rest.len()` when the expression is
/// never closed.
fn check_expr_end(rest: &str) -> usize {
    let mut depth = 1usize;
    let mut quote: Option<char> = None;
    let mut chars = rest.char_indices().peekable();

    while let Some((pos, ch)) = chars.next() {
        if let Some(open) = quote {
            if ch == open {
                if chars.peek().is_some_and(|&(_, next)| next == open) {
                    chars.next();
                } else {
                    quote = None;
                }
            }
            continue;
        }

        match ch {
            '\'' | '"' => quote = Some(ch),
            '(' => depth += 1,
            ')' => {
                depth -= 1;
                if depth == 0 {
                    return pos;
                }
            }
            _ => {}
        }
    }

    rest.len()
}
384
385fn parenthesized_content(input: &str) -> IResult<&str, &str> {
386 let mut paren_depth = 0usize;
387 let mut in_single = false;
388 let mut in_double = false;
389 let mut iter = input.char_indices().peekable();
390
391 while let Some((idx, ch)) = iter.next() {
392 match ch {
393 '\'' if !in_double => {
394 if in_single && matches!(iter.peek(), Some((_, '\''))) {
395 iter.next();
396 } else {
397 in_single = !in_single;
398 }
399 }
400 '"' if !in_single => {
401 if in_double && matches!(iter.peek(), Some((_, '"'))) {
402 iter.next();
403 } else {
404 in_double = !in_double;
405 }
406 }
407 '(' if !in_single && !in_double => paren_depth += 1,
408 ')' if !in_single && !in_double => {
409 if paren_depth == 0 {
410 return Ok((&input[idx + ch.len_utf8()..], &input[..idx]));
411 }
412 paren_depth -= 1;
413 }
414 _ => {}
415 }
416 }
417
418 Err(nom::Err::Error(nom::error::Error::new(
419 input,
420 nom::error::ErrorKind::Char,
421 )))
422}
423
/// Splits `input` on commas that are outside quotes and parentheses, trimming
/// each item.
///
/// Returns `Err(())` in any of these cases:
/// - an item is empty (leading, trailing or doubled comma);
/// - a stray `)` appears at top level;
/// - a quote or parenthesis is left open.
///
/// Blank input yields an empty list.
fn split_top_level_csv(input: &str) -> Result<Vec<String>, ()> {
    let mut fields = Vec::new();
    let mut field_start = 0usize;
    let mut depth = 0usize;
    let mut quote: Option<char> = None;
    let mut chars = input.char_indices().peekable();

    while let Some((pos, ch)) = chars.next() {
        if let Some(open) = quote {
            if ch == open {
                if chars.peek().is_some_and(|&(_, next)| next == open) {
                    chars.next();
                } else {
                    quote = None;
                }
            }
            continue;
        }

        match ch {
            '\'' | '"' => quote = Some(ch),
            '(' => depth += 1,
            ')' => depth = depth.checked_sub(1).ok_or(())?,
            ',' if depth == 0 => {
                let field = input[field_start..pos].trim();
                if field.is_empty() {
                    return Err(());
                }
                fields.push(field.to_owned());
                field_start = pos + ch.len_utf8();
            }
            _ => {}
        }
    }

    if quote.is_some() || depth != 0 {
        return Err(());
    }

    let last = input[field_start..].trim();
    match (last.is_empty(), input.trim().is_empty()) {
        (false, _) => fields.push(last.to_owned()),
        (true, true) => {}
        (true, false) => return Err(()),
    }

    Ok(fields)
}
481
/// Reports whether `input` begins (ASCII case-insensitively) with a column
/// constraint keyword.
///
/// This is a pure prefix test, so `uniqueness` also counts. `default_expr_end`
/// uses it to find where a DEFAULT expression stops.
///
/// The prefix is compared in place rather than lowercasing the whole
/// remaining input on every call. The caller invokes this at every top-level
/// whitespace, so lowercasing each time would be quadratic.
fn starts_constraint_keyword(input: &str) -> bool {
    const KEYWORDS: [&str; 7] = [
        "primary_key",
        "primary key",
        "not_null",
        "not null",
        "unique",
        "references ",
        "check(",
    ];

    KEYWORDS.iter().any(|keyword| {
        // `get` returns None when the cut is not a char boundary, in which case
        // the prefix cannot equal an ASCII keyword anyway.
        input
            .get(..keyword.len())
            .is_some_and(|head| head.eq_ignore_ascii_case(keyword))
    })
}
495
496fn default_expr_end(rest: &str) -> usize {
497 let mut in_single = false;
498 let mut in_double = false;
499 let mut paren_depth = 0usize;
500 let mut iter = rest.char_indices().peekable();
501
502 while let Some((idx, ch)) = iter.next() {
503 match ch {
504 '\'' if !in_double => {
505 if in_single && matches!(iter.peek(), Some((_, '\''))) {
506 iter.next();
507 } else {
508 in_single = !in_single;
509 }
510 }
511 '"' if !in_single => {
512 if in_double && matches!(iter.peek(), Some((_, '"'))) {
513 iter.next();
514 } else {
515 in_double = !in_double;
516 }
517 }
518 '(' if !in_single && !in_double => paren_depth += 1,
519 ')' if !in_single && !in_double && paren_depth > 0 => paren_depth -= 1,
520 c if c.is_whitespace()
521 && !in_single
522 && !in_double
523 && paren_depth == 0
524 && starts_constraint_keyword(rest[idx..].trim_start()) =>
525 {
526 return idx;
527 }
528 _ => {}
529 }
530 }
531
532 rest.len()
533}
534
535fn parse_column(input: &str) -> IResult<&str, ColumnDef> {
537 let (input, _) = ws_and_comments(input)?;
538 let (input, name) = identifier(input)?;
539 let (input, _) = multispace1(input)?;
540 let (input, type_info) = parse_type_info(input)?;
541
542 let (input, constraint_str) = opt(preceded(multispace1, constraint_text)).parse(input)?;
543
544 let mut col = ColumnDef {
545 name: name.to_string(),
546 typ: type_info.name,
547 is_array: type_info.is_array,
548 type_params: type_info.params,
549 is_serial: type_info.is_serial,
550 nullable: !type_info.is_serial, ..Default::default()
552 };
553
554 if let Some(constraints) = constraint_str {
555 let lower = constraints.to_lowercase();
556
557 if lower.contains("primary_key") || lower.contains("primary key") {
558 col.primary_key = true;
559 col.nullable = false;
560 }
561 if lower.contains("not_null") || lower.contains("not null") {
562 col.nullable = false;
563 }
564 if lower.contains("unique") {
565 col.unique = true;
566 }
567
568 if let Some(idx) = lower.find("references ") {
569 let rest = &constraints[idx + 11..];
570 let mut paren_depth = 0;
572 let mut end = rest.len();
573 for (i, c) in rest.char_indices() {
574 match c {
575 '(' => paren_depth += 1,
576 ')' => {
577 if paren_depth == 0 {
578 end = i;
579 break;
580 }
581 paren_depth -= 1;
582 }
583 c if c.is_whitespace() && paren_depth == 0 => {
584 end = i;
585 break;
586 }
587 _ => {}
588 }
589 }
590 col.references = Some(rest[..end].to_string());
591 }
592
593 if let Some(idx) = lower.find("default ") {
594 let rest = &constraints[idx + 8..];
595 let end = default_expr_end(rest);
596 col.default_value = Some(rest[..end].to_string());
597 }
598
599 if let Some(idx) = lower.find("check(") {
600 let rest = &constraints[idx + 6..];
601 let end = check_expr_end(rest);
602 col.check = Some(rest[..end].to_string());
603 }
604 }
605
606 Ok((input, col))
607}
608
609fn parse_column_list(input: &str) -> IResult<&str, Vec<ColumnDef>> {
611 let (input, _) = ws_and_comments(input)?;
612 let (input, _) = char('(').parse(input)?;
613 let (input, columns) = separated_list0(char(','), parse_column).parse(input)?;
614 let (input, _) = ws_and_comments(input)?;
615 let (input, _) = char(')').parse(input)?;
616
617 Ok((input, columns))
618}
619
620fn parse_table(input: &str) -> IResult<&str, TableDef> {
622 let (input, _) = ws_and_comments(input)?;
623 let (input, _) = tag_no_case("table").parse(input)?;
624 let (input, _) = multispace1(input)?;
625 let (input, name) = identifier(input)?;
626 let (input, columns) = parse_column_list(input)?;
627
628 let (input, _) = ws_and_comments(input)?;
630 let enable_rls = if let Ok((rest, _)) =
631 tag_no_case::<_, _, nom::error::Error<&str>>("enable_rls").parse(input)
632 {
633 return Ok((
634 rest,
635 TableDef {
636 name: name.to_string(),
637 columns,
638 enable_rls: true,
639 },
640 ));
641 } else {
642 false
643 };
644
645 Ok((
646 input,
647 TableDef {
648 name: name.to_string(),
649 columns,
650 enable_rls,
651 },
652 ))
653}
654
655enum SchemaItem {
661 Table(TableDef),
662 Policy(Box<RlsPolicy>),
663 Index(IndexDef),
664}
665
666fn parse_policy(input: &str) -> IResult<&str, RlsPolicy> {
678 let (input, _) = ws_and_comments(input)?;
679 let (input, _) = tag_no_case("policy").parse(input)?;
680 let (input, _) = multispace1(input)?;
681 let (input, name) = identifier(input)?;
682 let (input, _) = multispace1(input)?;
683 let (input, _) = tag_no_case("on").parse(input)?;
684 let (input, _) = multispace1(input)?;
685 let (input, table) = identifier(input)?;
686
687 let mut policy = RlsPolicy::create(name, table);
688
689 let mut remaining = input;
691 loop {
692 let (input, _) = ws_and_comments(remaining)?;
693
694 if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("for").parse(input) {
696 let (rest, _) = multispace1(rest)?;
697 let (rest, target) = alt((
698 map(tag_no_case("all"), |_| PolicyTarget::All),
699 map(tag_no_case("select"), |_| PolicyTarget::Select),
700 map(tag_no_case("insert"), |_| PolicyTarget::Insert),
701 map(tag_no_case("update"), |_| PolicyTarget::Update),
702 map(tag_no_case("delete"), |_| PolicyTarget::Delete),
703 ))
704 .parse(rest)?;
705 policy.target = target;
706 remaining = rest;
707 continue;
708 }
709
710 if let Ok((rest, _)) =
712 tag_no_case::<_, _, nom::error::Error<&str>>("restrictive").parse(input)
713 {
714 policy.permissiveness = PolicyPermissiveness::Restrictive;
715 remaining = rest;
716 continue;
717 }
718
719 if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("to").parse(input) {
721 if let Ok((rest, _)) = multispace1::<_, nom::error::Error<&str>>(rest) {
723 let (rest, role) = identifier(rest)?;
724 policy.role = Some(role.to_string());
725 remaining = rest;
726 continue;
727 }
728 }
729
730 if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("with").parse(input) {
732 let (rest, _) = multispace1(rest)?;
733 if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("check").parse(rest)
734 {
735 let (rest, _) = nom_ws0(rest)?;
736 let (rest, _) = char('(').parse(rest)?;
737 let (rest, _) = nom_ws0(rest)?;
738 let (rest, expr) = parse_policy_expr(rest)?;
739 let (rest, _) = nom_ws0(rest)?;
740 let (rest, _) = char(')').parse(rest)?;
741 policy.with_check = Some(expr);
742 remaining = rest;
743 continue;
744 }
745 }
746
747 if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("using").parse(input) {
749 let (rest, _) = nom_ws0(rest)?;
750 let (rest, _) = char('(').parse(rest)?;
751 let (rest, _) = nom_ws0(rest)?;
752 let (rest, expr) = parse_policy_expr(rest)?;
753 let (rest, _) = nom_ws0(rest)?;
754 let (rest, _) = char(')').parse(rest)?;
755 policy.using = Some(expr);
756 remaining = rest;
757 continue;
758 }
759
760 remaining = input;
762 break;
763 }
764
765 Ok((remaining, policy))
766}
767
768fn parse_policy_expr(input: &str) -> IResult<&str, Expr> {
777 let (input, first) = parse_policy_comparison(input)?;
778
779 let mut result = first;
781 let mut remaining = input;
782 loop {
783 let (input, _) = nom_ws0(remaining)?;
784
785 if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("or").parse(input)
786 && let Ok((rest, _)) = multispace1::<_, nom::error::Error<&str>>(rest)
787 {
788 let (rest, right) = parse_policy_comparison(rest)?;
789 result = Expr::Binary {
790 left: Box::new(result),
791 op: BinaryOp::Or,
792 right: Box::new(right),
793 alias: None,
794 };
795 remaining = rest;
796 continue;
797 }
798
799 if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("and").parse(input)
800 && let Ok((rest, _)) = multispace1::<_, nom::error::Error<&str>>(rest)
801 {
802 let (rest, right) = parse_policy_comparison(rest)?;
803 result = Expr::Binary {
804 left: Box::new(result),
805 op: BinaryOp::And,
806 right: Box::new(right),
807 alias: None,
808 };
809 remaining = rest;
810 continue;
811 }
812
813 remaining = input;
814 break;
815 }
816
817 Ok((remaining, result))
818}
819
820fn parse_policy_comparison(input: &str) -> IResult<&str, Expr> {
822 let (input, left) = parse_policy_atom(input)?;
823 let (input, _) = nom_ws0(input)?;
824
825 if let Ok((rest, op)) = parse_cmp_op(input) {
827 let (rest, _) = nom_ws0(rest)?;
828 let (rest, right) = parse_policy_atom(rest)?;
829 return Ok((
830 rest,
831 Expr::Binary {
832 left: Box::new(left),
833 op,
834 right: Box::new(right),
835 alias: None,
836 },
837 ));
838 }
839
840 Ok((input, left))
842}
843
844fn parse_cmp_op(input: &str) -> IResult<&str, BinaryOp> {
846 alt((
847 map(tag(">="), |_| BinaryOp::Gte),
848 map(tag("<="), |_| BinaryOp::Lte),
849 map(tag("<>"), |_| BinaryOp::Ne),
850 map(tag("!="), |_| BinaryOp::Ne),
851 map(tag("="), |_| BinaryOp::Eq),
852 map(tag(">"), |_| BinaryOp::Gt),
853 map(tag("<"), |_| BinaryOp::Lt),
854 ))
855 .parse(input)
856}
857
858fn parse_policy_atom(input: &str) -> IResult<&str, Expr> {
866 alt((
867 parse_policy_grouped,
868 parse_policy_bool,
869 parse_policy_string,
870 parse_policy_number,
871 parse_policy_func_or_ident, ))
873 .parse(input)
874}
875
876fn parse_policy_grouped(input: &str) -> IResult<&str, Expr> {
878 let (input, _) = char('(').parse(input)?;
879 let (input, _) = nom_ws0(input)?;
880 let (input, expr) = parse_policy_expr(input)?;
881 let (input, _) = nom_ws0(input)?;
882 let (input, _) = char(')').parse(input)?;
883 Ok((input, expr))
884}
885
886fn parse_policy_bool(input: &str) -> IResult<&str, Expr> {
888 alt((
889 map(tag_no_case("true"), |_| Expr::Literal(AstValue::Bool(true))),
890 map(tag_no_case("false"), |_| {
891 Expr::Literal(AstValue::Bool(false))
892 }),
893 ))
894 .parse(input)
895}
896
897fn parse_policy_string(input: &str) -> IResult<&str, Expr> {
899 let (input, _) = char('\'').parse(input)?;
900 let mut end = 0;
901 for (i, c) in input.char_indices() {
902 if c == '\'' {
903 end = i;
904 break;
905 }
906 }
907 let content = &input[..end];
908 let rest = &input[end + 1..];
909 Ok((rest, Expr::Literal(AstValue::String(content.to_string()))))
910}
911
912fn parse_policy_number(input: &str) -> IResult<&str, Expr> {
914 let (input, digits) = take_while1(|c: char| c.is_ascii_digit() || c == '.')(input)?;
915 if digits.starts_with('.') || digits.is_empty() {
917 return Err(nom::Err::Error(nom::error::Error::new(
918 input,
919 nom::error::ErrorKind::Digit,
920 )));
921 }
922 if let Ok(n) = digits.parse::<i64>() {
923 Ok((input, Expr::Literal(AstValue::Int(n))))
924 } else if let Ok(f) = digits.parse::<f64>() {
925 Ok((input, Expr::Literal(AstValue::Float(f))))
926 } else {
927 Ok((input, Expr::Named(digits.to_string())))
928 }
929}
930
931fn parse_policy_func_or_ident(input: &str) -> IResult<&str, Expr> {
933 let (input, name) = identifier(input)?;
934
935 let mut expr = if let Ok((rest, _)) = char::<_, nom::error::Error<&str>>('(').parse(input) {
937 let (rest, _) = nom_ws0(rest)?;
939 let (rest, args) =
940 separated_list0((nom_ws0, char(','), nom_ws0), parse_policy_atom).parse(rest)?;
941 let (rest, _) = nom_ws0(rest)?;
942 let (rest, _) = char(')').parse(rest)?;
943 let input = rest;
944 (
945 input,
946 Expr::FunctionCall {
947 name: name.to_string(),
948 args,
949 alias: None,
950 },
951 )
952 } else {
953 (input, Expr::Named(name.to_string()))
954 };
955
956 if let Ok((rest, _)) = tag::<_, _, nom::error::Error<&str>>("::").parse(expr.0) {
958 let (rest, cast_type) = identifier(rest)?;
959 expr = (
960 rest,
961 Expr::Cast {
962 expr: Box::new(expr.1),
963 target_type: cast_type.to_string(),
964 alias: None,
965 },
966 );
967 }
968
969 Ok(expr)
970}
971
972fn parse_schema_item(input: &str) -> IResult<&str, SchemaItem> {
974 let (input, _) = ws_and_comments(input)?;
975
976 if let Ok((rest, policy)) = parse_policy(input) {
978 return Ok((rest, SchemaItem::Policy(Box::new(policy))));
979 }
980
981 if let Ok((rest, idx)) = parse_index(input) {
983 return Ok((rest, SchemaItem::Index(idx)));
984 }
985
986 let (rest, table) = parse_table(input)?;
988 Ok((rest, SchemaItem::Table(table)))
989}
990
991fn parse_index(input: &str) -> IResult<&str, IndexDef> {
993 let (input, _) = tag_no_case("index")(input)?;
994 let (input, _) = multispace1(input)?;
995 let (input, name) = take_while1(|c: char| c.is_alphanumeric() || c == '_')(input)?;
996 let (input, _) = multispace1(input)?;
997 let (input, _) = tag_no_case("on")(input)?;
998 let (input, _) = multispace1(input)?;
999 let (input, table) = take_while1(|c: char| c.is_alphanumeric() || c == '_')(input)?;
1000 let (input, _) = nom_ws0(input)?;
1001 let (input, _) = char('(')(input)?;
1002 let (input, cols_str) = parenthesized_content(input)?;
1003 let (input, _) = nom_ws0(input)?;
1004 let (input, unique_tag) = opt(tag_no_case("unique")).parse(input)?;
1005
1006 let columns = split_top_level_csv(cols_str).map_err(|_| {
1007 nom::Err::Error(nom::error::Error::new(
1008 cols_str,
1009 nom::error::ErrorKind::SeparatedList,
1010 ))
1011 })?;
1012 if columns.is_empty() {
1013 return Err(nom::Err::Error(nom::error::Error::new(
1014 cols_str,
1015 nom::error::ErrorKind::SeparatedList,
1016 )));
1017 }
1018
1019 let is_unique = unique_tag.is_some();
1020
1021 Ok((
1022 input,
1023 IndexDef {
1024 name: name.to_string(),
1025 table: table.to_string(),
1026 columns,
1027 unique: is_unique,
1028 },
1029 ))
1030}
1031
1032fn parse_schema(input: &str) -> IResult<&str, Schema> {
1034 let version = extract_version_directive(input);
1036
1037 let (input, items) = many0(parse_schema_item).parse(input)?;
1038 let (input, _) = ws_and_comments(input)?;
1039
1040 let mut tables = Vec::new();
1041 let mut policies = Vec::new();
1042 let mut indexes = Vec::new();
1043 for item in items {
1044 match item {
1045 SchemaItem::Table(t) => tables.push(t),
1046 SchemaItem::Policy(p) => policies.push(*p),
1047 SchemaItem::Index(i) => indexes.push(i),
1048 }
1049 }
1050
1051 Ok((
1052 input,
1053 Schema {
1054 version,
1055 tables,
1056 policies,
1057 indexes,
1058 },
1059 ))
1060}
1061
/// Returns `N` from the first `-- qail: version=N` line in `input`.
///
/// Surrounding whitespace is ignored. Only the first line carrying
/// `version=` counts; if its value is not a valid `u32`, the result is `None`.
fn extract_version_directive(input: &str) -> Option<u32> {
    input
        .lines()
        .filter_map(|line| line.trim().strip_prefix("-- qail:"))
        .find_map(|directive| directive.trim().strip_prefix("version="))
        .and_then(|value| value.trim().parse().ok())
}
1075
1076#[cfg(test)]
1077mod tests {
1078 use super::*;
1079
1080 #[test]
1081 fn test_parse_simple_table() {
1082 let input = r#"
1083 table users (
1084 id uuid primary_key,
1085 email text not null,
1086 name text
1087 )
1088 "#;
1089
1090 let schema = Schema::parse(input).expect("parse failed");
1091 assert_eq!(schema.tables.len(), 1);
1092
1093 let users = &schema.tables[0];
1094 assert_eq!(users.name, "users");
1095 assert_eq!(users.columns.len(), 3);
1096
1097 let id = &users.columns[0];
1098 assert_eq!(id.name, "id");
1099 assert_eq!(id.typ, "uuid");
1100 assert!(id.primary_key);
1101 assert!(!id.nullable);
1102
1103 let email = &users.columns[1];
1104 assert_eq!(email.name, "email");
1105 assert!(!email.nullable);
1106
1107 let name = &users.columns[2];
1108 assert!(name.nullable);
1109 }
1110
1111 #[test]
1112 fn test_parse_multiple_tables() {
1113 let input = r#"
1114 -- Users table
1115 table users (
1116 id uuid primary_key,
1117 email text not null unique
1118 )
1119
1120 -- Orders table
1121 table orders (
1122 id uuid primary_key,
1123 user_id uuid references users(id),
1124 total i64 not null default 0
1125 )
1126 "#;
1127
1128 let schema = Schema::parse(input).expect("parse failed");
1129 assert_eq!(schema.tables.len(), 2);
1130
1131 let orders = schema.find_table("orders").expect("orders not found");
1132 let user_id = orders.find_column("user_id").expect("user_id not found");
1133 assert_eq!(user_id.references, Some("users(id)".to_string()));
1134
1135 let total = orders.find_column("total").expect("total not found");
1136 assert_eq!(total.default_value, Some("0".to_string()));
1137 }
1138
1139 #[test]
1140 fn test_parse_comments() {
1141 let input = r#"
1142 -- This is a comment
1143 table foo (
1144 bar text
1145 )
1146 "#;
1147
1148 let schema = Schema::parse(input).expect("parse failed");
1149 assert_eq!(schema.tables.len(), 1);
1150 }
1151
1152 #[test]
1153 fn test_array_types() {
1154 let input = r#"
1155 table products (
1156 id uuid primary_key,
1157 tags text[],
1158 prices decimal[]
1159 )
1160 "#;
1161
1162 let schema = Schema::parse(input).expect("parse failed");
1163 let products = &schema.tables[0];
1164
1165 let tags = products.find_column("tags").expect("tags not found");
1166 assert_eq!(tags.typ, "text");
1167 assert!(tags.is_array);
1168
1169 let prices = products.find_column("prices").expect("prices not found");
1170 assert!(prices.is_array);
1171 }
1172
1173 #[test]
1174 fn test_type_params() {
1175 let input = r#"
1176 table items (
1177 id serial primary_key,
1178 name varchar(255) not null,
1179 price decimal(10,2),
1180 code varchar(50) unique
1181 )
1182 "#;
1183
1184 let schema = Schema::parse(input).expect("parse failed");
1185 let items = &schema.tables[0];
1186
1187 let id = items.find_column("id").expect("id not found");
1188 assert!(id.is_serial);
1189 assert!(!id.nullable); let name = items.find_column("name").expect("name not found");
1192 assert_eq!(name.typ, "varchar");
1193 assert_eq!(name.type_params, Some(vec!["255".to_string()]));
1194
1195 let price = items.find_column("price").expect("price not found");
1196 assert_eq!(
1197 price.type_params,
1198 Some(vec!["10".to_string(), "2".to_string()])
1199 );
1200
1201 let code = items.find_column("code").expect("code not found");
1202 assert!(code.unique);
1203 }
1204
1205 #[test]
1206 fn test_custom_type_names_with_underscores_and_schema() {
1207 let input = r#"
1208 table bookings (
1209 id uuid primary_key,
1210 status booking_status not null,
1211 gateway_state integrations.payment_state[]
1212 )
1213 "#;
1214
1215 let schema = Schema::parse(input).expect("parse failed");
1216 let bookings = &schema.tables[0];
1217
1218 let status = bookings.find_column("status").expect("status not found");
1219 assert_eq!(status.typ, "booking_status");
1220 assert!(!status.nullable);
1221
1222 let gateway_state = bookings
1223 .find_column("gateway_state")
1224 .expect("gateway_state not found");
1225 assert_eq!(gateway_state.typ, "integrations.payment_state");
1226 assert!(gateway_state.is_array);
1227 }
1228
1229 #[test]
1230 fn test_malformed_type_params_return_parse_error_without_panic() {
1231 let input = "table invoices ( amount decimal(";
1232
1233 let result = std::panic::catch_unwind(|| Schema::parse(input));
1234
1235 assert!(result.is_ok());
1236 assert!(result.unwrap().is_err());
1237 }
1238
1239 #[test]
1240 fn test_check_constraint() {
1241 let input = r#"
1242 table employees (
1243 id uuid primary_key,
1244 age i32 check(age >= 18),
1245 salary decimal check(salary > 0)
1246 )
1247 "#;
1248
1249 let schema = Schema::parse(input).expect("parse failed");
1250 let employees = &schema.tables[0];
1251
1252 let age = employees.find_column("age").expect("age not found");
1253 assert_eq!(age.check, Some("age >= 18".to_string()));
1254
1255 let salary = employees.find_column("salary").expect("salary not found");
1256 assert_eq!(salary.check, Some("salary > 0".to_string()));
1257 }
1258
1259 #[test]
1260 fn test_default_expression_with_spaces() {
1261 let input = r#"
1262 table messages (
1263 id uuid primary_key,
1264 title text default 'new user' not null,
1265 expires_at timestamp default (now() + interval '1 day')
1266 )
1267 "#;
1268
1269 let schema = Schema::parse(input).expect("parse failed");
1270 let messages = &schema.tables[0];
1271
1272 let title = messages.find_column("title").expect("title not found");
1273 assert_eq!(title.default_value, Some("'new user'".to_string()));
1274 assert!(!title.nullable);
1275
1276 let expires_at = messages
1277 .find_column("expires_at")
1278 .expect("expires_at not found");
1279 assert_eq!(
1280 expires_at.default_value,
1281 Some("(now() + interval '1 day')".to_string())
1282 );
1283 }
1284
1285 #[test]
1286 fn test_constraints_handle_quoted_commas_and_parens() {
1287 let input = r#"
1288 table messages (
1289 id uuid primary_key,
1290 title text default 'hello, world' not null,
1291 tag text check(tag in ('a,b', 'c)')),
1292 note text default 'paren ) and comma, still literal'
1293 )
1294 "#;
1295
1296 let schema = Schema::parse(input).expect("parse failed");
1297 let messages = &schema.tables[0];
1298 assert_eq!(messages.columns.len(), 4);
1299
1300 let title = messages.find_column("title").expect("title not found");
1301 assert_eq!(title.default_value, Some("'hello, world'".to_string()));
1302 assert!(!title.nullable);
1303
1304 let tag = messages.find_column("tag").expect("tag not found");
1305 assert_eq!(tag.check, Some("tag in ('a,b', 'c)')".to_string()));
1306
1307 let note = messages.find_column("note").expect("note not found");
1308 assert_eq!(
1309 note.default_value,
1310 Some("'paren ) and comma, still literal'".to_string())
1311 );
1312 }
1313
#[test]
fn test_index_columns_handle_nested_expression_commas() {
    // Index column lists may hold expressions whose own argument lists contain
    // commas and quoted `)` characters; only top-level commas separate columns.
    let input = r#"
table docs (
id uuid primary_key,
title text,
slug text
)

index idx_docs_search on docs (regexp_replace(title, ')', '', 'g'), lower(slug)) unique
"#;
    let expected_columns = ["regexp_replace(title, ')', '', 'g')", "lower(slug)"];
    let expected_sql = "CREATE UNIQUE INDEX IF NOT EXISTS idx_docs_search ON docs (regexp_replace(title, ')', '', 'g'), lower(slug))";

    let parsed = Schema::parse(input).expect("parse failed");
    assert_eq!(parsed.indexes.len(), 1);

    let idx = &parsed.indexes[0];
    assert_eq!(idx.name, "idx_docs_search");
    assert_eq!(idx.columns, expected_columns);
    assert!(idx.unique);
    // Columns are emitted verbatim, so the round-trip SQL preserves them.
    assert_eq!(idx.to_sql(), expected_sql);
}
1343
#[test]
fn test_index_rejects_empty_columns() {
    // An empty column list, or empty entries produced by leading, trailing or
    // doubled commas, must be rejected rather than yielding blank column names.
    let cases = [
        "index idx_docs_search on docs ()",
        "index idx_docs_search on docs (title,)",
        "index idx_docs_search on docs (,title)",
        "index idx_docs_search on docs (title,,slug)",
    ];

    for input in cases {
        // All four cases share these assertions, so every failure message names
        // the offending input; otherwise a regression does not say which case broke.
        let err = match Schema::parse(input) {
            Ok(schema) => panic!(
                "empty index columns should fail for {input:?}, but parsed indexes {:?}",
                schema.indexes
            ),
            Err(err) => err,
        };
        assert!(
            err.contains("Parse error") || err.contains("Unexpected content"),
            "unexpected error for {input:?}: {err}"
        );
    }
}
1359
#[test]
fn test_version_directive() {
    // A `-- qail: version=N` directive line populates `Schema::version`; an
    // input without the directive leaves it `None`.
    let versioned = r#"
-- qail: version=1
table users (
id uuid primary_key
)
"#;
    let unversioned = r#"
table items (
id uuid primary_key
)
"#;

    let with_version = Schema::parse(versioned).expect("parse failed");
    assert_eq!(with_version.version, Some(1));
    // The directive must not swallow the table that follows it.
    assert_eq!(with_version.tables.len(), 1);

    let without_version = Schema::parse(unversioned).expect("parse failed");
    assert_eq!(without_version.version, None);
}
1382
#[test]
fn test_enable_rls_table() {
    // An `enable_rls` suffix after a table's closing parenthesis sets
    // `TableDef::enable_rls`.
    let input = r#"
table orders (
id uuid primary_key,
operator_id uuid not null
) enable_rls
"#;

    let parsed = Schema::parse(input).expect("parse failed");
    let tables = &parsed.tables;
    assert_eq!(tables.len(), 1);
    assert!(tables[0].enable_rls);
}
1400
#[test]
fn test_parse_policy_basic() {
    // A `policy <name> on <table> for all using (...)` declaration parses into
    // an `RlsPolicy` whose USING clause is the expression tree
    // `operator_id = current_setting(...)::uuid`.
    let input = r#"
table orders (
id uuid primary_key,
operator_id uuid not null
) enable_rls

policy orders_isolation on orders
for all
using (operator_id = current_setting('app.current_operator_id')::uuid)
"#;

    let schema = Schema::parse(input).expect("parse failed");
    assert_eq!(schema.tables.len(), 1);
    assert_eq!(schema.policies.len(), 1);

    let policy = &schema.policies[0];
    assert_eq!(policy.name, "orders_isolation");
    assert_eq!(policy.table, "orders");
    assert_eq!(policy.target, PolicyTarget::All);

    // Bind the clause with `let else` instead of `is_some()` followed by
    // `unwrap()`: a single check, and a descriptive panic when it is absent.
    let Some(using) = policy.using.as_ref() else {
        panic!("Expected a USING expression, got None");
    };

    // Top level: `<column> = <cast expression>`.
    let Expr::Binary {
        left, op, right, ..
    } = using
    else {
        panic!("Expected Binary, got {using:?}");
    };
    assert_eq!(*op, BinaryOp::Eq);

    // Left operand: the bare column name.
    let Expr::Named(n) = left.as_ref() else {
        panic!("Expected Named, got {left:?}");
    };
    assert_eq!(n, "operator_id");

    // Right operand: a `::uuid` cast wrapping the `current_setting(...)` call.
    let Expr::Cast {
        target_type,
        expr: cast_expr,
        ..
    } = right.as_ref()
    else {
        panic!("Expected Cast, got {right:?}");
    };
    assert_eq!(target_type, "uuid");

    let Expr::FunctionCall { name, args, .. } = cast_expr.as_ref() else {
        panic!("Expected FunctionCall, got {cast_expr:?}");
    };
    assert_eq!(name, "current_setting");
    assert_eq!(args.len(), 1);
}
1455
#[test]
fn test_parse_policy_with_check() {
    // An INSERT policy declared with only `with check (...)` populates
    // `with_check` and leaves `using` empty.
    let input = r#"
table orders (
id uuid primary_key
)

policy orders_write on orders
for insert
with check (operator_id = current_setting('app.current_operator_id')::uuid)
"#;

    let schema = Schema::parse(input).expect("parse failed");
    // Pin the count before indexing so a spuriously split policy is caught.
    assert_eq!(schema.policies.len(), 1);

    let policy = &schema.policies[0];
    assert_eq!(policy.name, "orders_write");
    assert_eq!(policy.table, "orders");
    assert_eq!(policy.target, PolicyTarget::Insert);
    assert!(policy.with_check.is_some());
    assert!(policy.using.is_none());
}
1474
#[test]
fn test_parse_policy_restrictive_with_role() {
    // The optional `restrictive` and `to <role>` clauses, placed between the
    // target and the USING clause, set permissiveness and role.
    let input = r#"
table secrets (
id uuid primary_key
)

policy admin_only on secrets
for select
restrictive
to app_user
using (current_setting('app.is_super_admin')::boolean = true)
"#;

    let schema = Schema::parse(input).expect("parse failed");
    // Pin the count before indexing so a spuriously split policy is caught.
    assert_eq!(schema.policies.len(), 1);

    let policy = &schema.policies[0];
    assert_eq!(policy.name, "admin_only");
    assert_eq!(policy.table, "secrets");
    assert_eq!(policy.target, PolicyTarget::Select);
    assert_eq!(policy.permissiveness, PolicyPermissiveness::Restrictive);
    assert_eq!(policy.role.as_deref(), Some("app_user"));
    assert!(policy.using.is_some());
}
1496
#[test]
fn test_parse_policy_or_expr() {
    // `a = b or c = d` in a USING clause must parse with OR at the root of the
    // expression tree, i.e. OR binds looser than `=`.
    let input = r#"
table orders (
id uuid primary_key
)

policy tenant_or_admin on orders
for all
using (operator_id = current_setting('app.current_operator_id')::uuid or current_setting('app.is_super_admin')::boolean = true)
"#;

    let schema = Schema::parse(input).expect("parse failed");
    assert_eq!(schema.policies.len(), 1);
    let policy = &schema.policies[0];

    // Match on the `Option` itself: a missing USING clause then fails this
    // assertion with its debug output instead of panicking inside `unwrap()`.
    assert!(
        matches!(
            policy.using.as_ref(),
            Some(Expr::Binary {
                op: BinaryOp::Or,
                ..
            })
        ),
        "Expected Binary OR, got {:?}",
        policy.using
    );
}
1524
#[test]
fn test_schema_to_sql() {
    // Rendering a schema with an RLS-enabled table and a policy must emit the
    // table DDL, the row-level-security toggles and the CREATE POLICY statement.
    let input = r#"
table orders (
id uuid primary_key,
operator_id uuid not null
) enable_rls

policy orders_isolation on orders
for all
using (operator_id = current_setting('app.current_operator_id')::uuid)
"#;

    let schema = Schema::parse(input).expect("parse failed");
    let sql = schema.to_sql();

    // Check each fragment on its own and print the generated SQL on failure.
    // A bare `assert!(sql.contains(..))` reports only "assertion failed",
    // which names neither the missing fragment nor the actual output.
    for expected in [
        "CREATE TABLE IF NOT EXISTS",
        "ENABLE ROW LEVEL SECURITY",
        "FORCE ROW LEVEL SECURITY",
        "CREATE POLICY",
        "orders_isolation",
        "FOR ALL",
    ] {
        assert!(
            sql.contains(expected),
            "expected generated SQL to contain {expected:?}, got:\n{sql}"
        );
    }
}
1547
#[test]
fn test_multiple_policies() {
    // Several policies on the same table are all collected, in declaration order.
    let input = r#"
table orders (
id uuid primary_key,
operator_id uuid not null
) enable_rls

policy orders_read on orders
for select
using (operator_id = current_setting('app.current_operator_id')::uuid)

policy orders_write on orders
for insert
with check (operator_id = current_setting('app.current_operator_id')::uuid)
"#;

    let parsed = Schema::parse(input).expect("parse failed");
    let policies = &parsed.policies;
    assert_eq!(policies.len(), 2);

    let (read, write) = (&policies[0], &policies[1]);
    assert_eq!(read.name, "orders_read");
    assert_eq!(read.target, PolicyTarget::Select);
    assert_eq!(write.name, "orders_write");
    assert_eq!(write.target, PolicyTarget::Insert);
}
1572}