1use nom::{
18 IResult, Parser,
19 branch::alt,
20 bytes::complete::{tag, tag_no_case, take_while1},
21 character::complete::{char, multispace0 as nom_ws0, multispace1, not_line_ending},
22 combinator::{map, opt},
23 multi::{many0, separated_list0},
24 sequence::preceded,
25};
26
27use crate::ast::{BinaryOp, Expr, Value as AstValue};
28use crate::migrate::alter::AlterTable;
29use crate::migrate::policy::{PolicyPermissiveness, PolicyTarget, RlsPolicy};
30use crate::transpiler::policy::{alter_table_sql, create_policy_sql};
31
/// In-memory result of parsing a schema file.
///
/// Built by `parse_schema`; each collection keeps its items in the order
/// they appeared in the input.
#[derive(Debug, Clone, Default)]
pub struct Schema {
    /// Value of the first `-- qail: version=N` directive, if present and numeric.
    pub version: Option<u32>,
    /// Every `table ...` declaration.
    pub tables: Vec<TableDef>,
    /// Every `policy ...` declaration.
    pub policies: Vec<RlsPolicy>,
    /// Every `index ...` declaration.
    pub indexes: Vec<IndexDef>,
}
44
/// One `index <name> on <table> (<columns>) [unique]` declaration.
#[derive(Debug, Clone)]
pub struct IndexDef {
    /// Index name as written in the schema.
    pub name: String,
    /// Name of the table the index is created on.
    pub table: String,
    /// Indexed column names, in declaration order.
    pub columns: Vec<String>,
    /// Whether the index was declared `unique`.
    pub unique: bool,
}

impl IndexDef {
    /// Renders the index as a `CREATE [UNIQUE] INDEX IF NOT EXISTS` statement
    /// (without a trailing semicolon).
    pub fn to_sql(&self) -> String {
        let column_list = self.columns.join(", ");
        let unique = match self.unique {
            true => " UNIQUE",
            false => "",
        };
        format!(
            "CREATE{} INDEX IF NOT EXISTS {} ON {} ({})",
            unique, self.name, self.table, column_list
        )
    }
}
71
/// One `table <name> ( ... ) [enable_rls]` declaration.
#[derive(Debug, Clone)]
pub struct TableDef {
    /// Table name as written in the schema.
    pub name: String,
    /// Columns in declaration order.
    pub columns: Vec<ColumnDef>,
    /// True when the table body was followed by the `enable_rls` keyword.
    pub enable_rls: bool,
}
82
/// A single column of a [`TableDef`].
#[derive(Debug, Clone)]
pub struct ColumnDef {
    /// Column name.
    pub name: String,
    /// Base type name (the parser stores it lowercased).
    pub typ: String,
    /// True when the type carried a `[]` suffix.
    pub is_array: bool,
    /// Parenthesised type parameters, e.g. `["10", "2"]` for `decimal(10,2)`.
    pub type_params: Option<Vec<String>>,
    /// Whether NULL values are allowed.
    pub nullable: bool,
    /// Whether the column is the primary key.
    pub primary_key: bool,
    /// Whether the column has a uniqueness constraint.
    pub unique: bool,
    /// Target of a `references` clause, e.g. `users(id)`.
    pub references: Option<String>,
    /// Raw text of the `default` value.
    pub default_value: Option<String>,
    /// Raw text inside `check( ... )`.
    pub check: Option<String>,
    /// True for `serial`, `bigserial` and `smallserial` types.
    pub is_serial: bool,
}

impl Default for ColumnDef {
    /// An unnamed, untyped, nullable column with no constraints.
    fn default() -> Self {
        Self {
            // The only non-"empty" default: columns accept NULL unless told otherwise.
            nullable: true,
            name: String::new(),
            typ: String::new(),
            type_params: None,
            references: None,
            default_value: None,
            check: None,
            is_array: false,
            primary_key: false,
            unique: false,
            is_serial: false,
        }
    }
}
127
128impl Schema {
129 pub fn parse(input: &str) -> Result<Self, String> {
131 match parse_schema(input) {
132 Ok(("", schema)) => Ok(schema),
133 Ok((remaining, _)) => Err(format!("Unexpected content: '{}'", remaining.trim())),
134 Err(e) => Err(format!("Parse error: {:?}", e)),
135 }
136 }
137
138 pub fn find_table(&self, name: &str) -> Option<&TableDef> {
140 self.tables
141 .iter()
142 .find(|t| t.name.eq_ignore_ascii_case(name))
143 }
144
145 pub fn to_sql(&self) -> String {
147 let mut parts = Vec::new();
148
149 for table in &self.tables {
150 parts.push(table.to_ddl());
151
152 if table.enable_rls {
153 let alter = AlterTable::new(&table.name).enable_rls().force_rls();
154 for stmt in alter_table_sql(&alter) {
155 parts.push(stmt);
156 }
157 }
158 }
159
160 for idx in &self.indexes {
161 parts.push(idx.to_sql());
162 }
163
164 for policy in &self.policies {
165 parts.push(create_policy_sql(policy));
166 }
167
168 parts.join(";\n\n") + ";"
169 }
170
171 pub fn from_file(path: &std::path::Path) -> Result<Self, String> {
173 let content =
174 std::fs::read_to_string(path).map_err(|e| format!("Failed to read file: {}", e))?;
175 Self::parse(&content)
176 }
177}
178
179impl TableDef {
180 pub fn find_column(&self, name: &str) -> Option<&ColumnDef> {
182 self.columns
183 .iter()
184 .find(|c| c.name.eq_ignore_ascii_case(name))
185 }
186
187 pub fn to_ddl(&self) -> String {
189 let mut sql = format!("CREATE TABLE IF NOT EXISTS {} (\n", self.name);
190
191 let mut col_defs = Vec::new();
192 for col in &self.columns {
193 let mut line = format!(" {}", col.name);
194
195 let mut typ = col.typ.to_uppercase();
197 if let Some(params) = &col.type_params {
198 typ = format!("{}({})", typ, params.join(", "));
199 }
200 if col.is_array {
201 typ.push_str("[]");
202 }
203 line.push_str(&format!(" {}", typ));
204
205 if col.primary_key {
207 line.push_str(" PRIMARY KEY");
208 }
209 if !col.nullable && !col.primary_key && !col.is_serial {
210 line.push_str(" NOT NULL");
211 }
212 if col.unique && !col.primary_key {
213 line.push_str(" UNIQUE");
214 }
215 if let Some(ref default) = col.default_value {
216 line.push_str(&format!(" DEFAULT {}", default));
217 }
218 if let Some(ref refs) = col.references {
219 line.push_str(&format!(" REFERENCES {}", refs));
220 }
221 if let Some(ref check) = col.check {
222 line.push_str(&format!(" CHECK({})", check));
223 }
224
225 col_defs.push(line);
226 }
227
228 sql.push_str(&col_defs.join(",\n"));
229 sql.push_str("\n)");
230 sql
231 }
232}
233
234fn identifier(input: &str) -> IResult<&str, &str> {
240 take_while1(|c: char| c.is_alphanumeric() || c == '_')(input)
241}
242
243fn ws_and_comments(input: &str) -> IResult<&str, ()> {
245 let (input, _) = many0(alt((
246 map(multispace1, |_| ()),
247 map((tag("--"), not_line_ending), |_| ()),
248 map((tag("#"), not_line_ending), |_| ()),
249 )))
250 .parse(input)?;
251 Ok((input, ()))
252}
253
/// Intermediate result of `parse_type_info`, copied into a `ColumnDef`.
struct TypeInfo {
    /// Lowercased base type name.
    name: String,
    /// Comma-separated values found between `(` and `)`, trimmed.
    params: Option<Vec<String>>,
    /// True when the type was followed by `[]`.
    is_array: bool,
    /// True when the lowercased name is `serial`, `bigserial` or `smallserial`.
    is_serial: bool,
}
260
261fn parse_type_info(input: &str) -> IResult<&str, TypeInfo> {
264 let (input, type_name) = take_while1(|c: char| c.is_alphanumeric()).parse(input)?;
265
266 let (input, params) = if input.starts_with('(') {
267 let paren_start = 1;
268 let mut paren_end = paren_start;
269 for (i, c) in input[paren_start..].char_indices() {
270 if c == ')' {
271 paren_end = paren_start + i;
272 break;
273 }
274 }
275 let param_str = &input[paren_start..paren_end];
276 let params: Vec<String> = param_str.split(',').map(|s| s.trim().to_string()).collect();
277 (&input[paren_end + 1..], Some(params))
278 } else {
279 (input, None)
280 };
281
282 let (input, is_array) = if let Some(stripped) = input.strip_prefix("[]") {
283 (stripped, true)
284 } else {
285 (input, false)
286 };
287
288 let lower = type_name.to_lowercase();
289 let is_serial = lower == "serial" || lower == "bigserial" || lower == "smallserial";
290
291 Ok((
292 input,
293 TypeInfo {
294 name: lower,
295 params,
296 is_array,
297 is_serial,
298 },
299 ))
300}
301
302fn constraint_text(input: &str) -> IResult<&str, &str> {
304 let mut paren_depth = 0;
305 let mut end = 0;
306
307 for (i, c) in input.char_indices() {
308 match c {
309 '(' => paren_depth += 1,
310 ')' => {
311 if paren_depth == 0 {
312 break; }
314 paren_depth -= 1;
315 }
316 ',' if paren_depth == 0 => break,
317 '\n' | '\r' if paren_depth == 0 => break,
318 _ => {}
319 }
320 end = i + c.len_utf8();
321 }
322
323 if end == 0 {
324 Err(nom::Err::Error(nom::error::Error::new(
325 input,
326 nom::error::ErrorKind::TakeWhile1,
327 )))
328 } else {
329 Ok((&input[end..], &input[..end]))
330 }
331}
332
333fn parse_column(input: &str) -> IResult<&str, ColumnDef> {
335 let (input, _) = ws_and_comments(input)?;
336 let (input, name) = identifier(input)?;
337 let (input, _) = multispace1(input)?;
338 let (input, type_info) = parse_type_info(input)?;
339
340 let (input, constraint_str) = opt(preceded(multispace1, constraint_text)).parse(input)?;
341
342 let mut col = ColumnDef {
343 name: name.to_string(),
344 typ: type_info.name,
345 is_array: type_info.is_array,
346 type_params: type_info.params,
347 is_serial: type_info.is_serial,
348 nullable: !type_info.is_serial, ..Default::default()
350 };
351
352 if let Some(constraints) = constraint_str {
353 let lower = constraints.to_lowercase();
354
355 if lower.contains("primary_key") || lower.contains("primary key") {
356 col.primary_key = true;
357 col.nullable = false;
358 }
359 if lower.contains("not_null") || lower.contains("not null") {
360 col.nullable = false;
361 }
362 if lower.contains("unique") {
363 col.unique = true;
364 }
365
366 if let Some(idx) = lower.find("references ") {
367 let rest = &constraints[idx + 11..];
368 let mut paren_depth = 0;
370 let mut end = rest.len();
371 for (i, c) in rest.char_indices() {
372 match c {
373 '(' => paren_depth += 1,
374 ')' => {
375 if paren_depth == 0 {
376 end = i;
377 break;
378 }
379 paren_depth -= 1;
380 }
381 c if c.is_whitespace() && paren_depth == 0 => {
382 end = i;
383 break;
384 }
385 _ => {}
386 }
387 }
388 col.references = Some(rest[..end].to_string());
389 }
390
391 if let Some(idx) = lower.find("default ") {
392 let rest = &constraints[idx + 8..];
393 let end = rest.find(|c: char| c.is_whitespace()).unwrap_or(rest.len());
394 col.default_value = Some(rest[..end].to_string());
395 }
396
397 if let Some(idx) = lower.find("check(") {
398 let rest = &constraints[idx + 6..];
399 let mut depth = 1;
401 let mut end = rest.len();
402 for (i, c) in rest.char_indices() {
403 match c {
404 '(' => depth += 1,
405 ')' => {
406 depth -= 1;
407 if depth == 0 {
408 end = i;
409 break;
410 }
411 }
412 _ => {}
413 }
414 }
415 col.check = Some(rest[..end].to_string());
416 }
417 }
418
419 Ok((input, col))
420}
421
422fn parse_column_list(input: &str) -> IResult<&str, Vec<ColumnDef>> {
424 let (input, _) = ws_and_comments(input)?;
425 let (input, _) = char('(').parse(input)?;
426 let (input, columns) = separated_list0(char(','), parse_column).parse(input)?;
427 let (input, _) = ws_and_comments(input)?;
428 let (input, _) = char(')').parse(input)?;
429
430 Ok((input, columns))
431}
432
433fn parse_table(input: &str) -> IResult<&str, TableDef> {
435 let (input, _) = ws_and_comments(input)?;
436 let (input, _) = tag_no_case("table").parse(input)?;
437 let (input, _) = multispace1(input)?;
438 let (input, name) = identifier(input)?;
439 let (input, columns) = parse_column_list(input)?;
440
441 let (input, _) = ws_and_comments(input)?;
443 let enable_rls = if let Ok((rest, _)) =
444 tag_no_case::<_, _, nom::error::Error<&str>>("enable_rls").parse(input)
445 {
446 return Ok((
447 rest,
448 TableDef {
449 name: name.to_string(),
450 columns,
451 enable_rls: true,
452 },
453 ));
454 } else {
455 false
456 };
457
458 Ok((
459 input,
460 TableDef {
461 name: name.to_string(),
462 columns,
463 enable_rls,
464 },
465 ))
466}
467
/// One top-level declaration, as produced by `parse_schema_item`.
enum SchemaItem {
    /// A `table ...` declaration.
    Table(TableDef),
    /// A `policy ...` declaration (boxed).
    Policy(Box<RlsPolicy>),
    /// An `index ...` declaration.
    Index(IndexDef),
}
478
479fn parse_policy(input: &str) -> IResult<&str, RlsPolicy> {
491 let (input, _) = ws_and_comments(input)?;
492 let (input, _) = tag_no_case("policy").parse(input)?;
493 let (input, _) = multispace1(input)?;
494 let (input, name) = identifier(input)?;
495 let (input, _) = multispace1(input)?;
496 let (input, _) = tag_no_case("on").parse(input)?;
497 let (input, _) = multispace1(input)?;
498 let (input, table) = identifier(input)?;
499
500 let mut policy = RlsPolicy::create(name, table);
501
502 let mut remaining = input;
504 loop {
505 let (input, _) = ws_and_comments(remaining)?;
506
507 if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("for").parse(input) {
509 let (rest, _) = multispace1(rest)?;
510 let (rest, target) = alt((
511 map(tag_no_case("all"), |_| PolicyTarget::All),
512 map(tag_no_case("select"), |_| PolicyTarget::Select),
513 map(tag_no_case("insert"), |_| PolicyTarget::Insert),
514 map(tag_no_case("update"), |_| PolicyTarget::Update),
515 map(tag_no_case("delete"), |_| PolicyTarget::Delete),
516 ))
517 .parse(rest)?;
518 policy.target = target;
519 remaining = rest;
520 continue;
521 }
522
523 if let Ok((rest, _)) =
525 tag_no_case::<_, _, nom::error::Error<&str>>("restrictive").parse(input)
526 {
527 policy.permissiveness = PolicyPermissiveness::Restrictive;
528 remaining = rest;
529 continue;
530 }
531
532 if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("to").parse(input) {
534 if let Ok((rest, _)) = multispace1::<_, nom::error::Error<&str>>(rest) {
536 let (rest, role) = identifier(rest)?;
537 policy.role = Some(role.to_string());
538 remaining = rest;
539 continue;
540 }
541 }
542
543 if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("with").parse(input) {
545 let (rest, _) = multispace1(rest)?;
546 if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("check").parse(rest)
547 {
548 let (rest, _) = nom_ws0(rest)?;
549 let (rest, _) = char('(').parse(rest)?;
550 let (rest, _) = nom_ws0(rest)?;
551 let (rest, expr) = parse_policy_expr(rest)?;
552 let (rest, _) = nom_ws0(rest)?;
553 let (rest, _) = char(')').parse(rest)?;
554 policy.with_check = Some(expr);
555 remaining = rest;
556 continue;
557 }
558 }
559
560 if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("using").parse(input) {
562 let (rest, _) = nom_ws0(rest)?;
563 let (rest, _) = char('(').parse(rest)?;
564 let (rest, _) = nom_ws0(rest)?;
565 let (rest, expr) = parse_policy_expr(rest)?;
566 let (rest, _) = nom_ws0(rest)?;
567 let (rest, _) = char(')').parse(rest)?;
568 policy.using = Some(expr);
569 remaining = rest;
570 continue;
571 }
572
573 remaining = input;
575 break;
576 }
577
578 Ok((remaining, policy))
579}
580
581fn parse_policy_expr(input: &str) -> IResult<&str, Expr> {
590 let (input, first) = parse_policy_comparison(input)?;
591
592 let mut result = first;
594 let mut remaining = input;
595 loop {
596 let (input, _) = nom_ws0(remaining)?;
597
598 if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("or").parse(input)
599 && let Ok((rest, _)) = multispace1::<_, nom::error::Error<&str>>(rest)
600 {
601 let (rest, right) = parse_policy_comparison(rest)?;
602 result = Expr::Binary {
603 left: Box::new(result),
604 op: BinaryOp::Or,
605 right: Box::new(right),
606 alias: None,
607 };
608 remaining = rest;
609 continue;
610 }
611
612 if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("and").parse(input)
613 && let Ok((rest, _)) = multispace1::<_, nom::error::Error<&str>>(rest)
614 {
615 let (rest, right) = parse_policy_comparison(rest)?;
616 result = Expr::Binary {
617 left: Box::new(result),
618 op: BinaryOp::And,
619 right: Box::new(right),
620 alias: None,
621 };
622 remaining = rest;
623 continue;
624 }
625
626 remaining = input;
627 break;
628 }
629
630 Ok((remaining, result))
631}
632
633fn parse_policy_comparison(input: &str) -> IResult<&str, Expr> {
635 let (input, left) = parse_policy_atom(input)?;
636 let (input, _) = nom_ws0(input)?;
637
638 if let Ok((rest, op)) = parse_cmp_op(input) {
640 let (rest, _) = nom_ws0(rest)?;
641 let (rest, right) = parse_policy_atom(rest)?;
642 return Ok((
643 rest,
644 Expr::Binary {
645 left: Box::new(left),
646 op,
647 right: Box::new(right),
648 alias: None,
649 },
650 ));
651 }
652
653 Ok((input, left))
655}
656
657fn parse_cmp_op(input: &str) -> IResult<&str, BinaryOp> {
659 alt((
660 map(tag(">="), |_| BinaryOp::Gte),
661 map(tag("<="), |_| BinaryOp::Lte),
662 map(tag("<>"), |_| BinaryOp::Ne),
663 map(tag("!="), |_| BinaryOp::Ne),
664 map(tag("="), |_| BinaryOp::Eq),
665 map(tag(">"), |_| BinaryOp::Gt),
666 map(tag("<"), |_| BinaryOp::Lt),
667 ))
668 .parse(input)
669}
670
671fn parse_policy_atom(input: &str) -> IResult<&str, Expr> {
679 alt((
680 parse_policy_grouped,
681 parse_policy_bool,
682 parse_policy_string,
683 parse_policy_number,
684 parse_policy_func_or_ident, ))
686 .parse(input)
687}
688
689fn parse_policy_grouped(input: &str) -> IResult<&str, Expr> {
691 let (input, _) = char('(').parse(input)?;
692 let (input, _) = nom_ws0(input)?;
693 let (input, expr) = parse_policy_expr(input)?;
694 let (input, _) = nom_ws0(input)?;
695 let (input, _) = char(')').parse(input)?;
696 Ok((input, expr))
697}
698
699fn parse_policy_bool(input: &str) -> IResult<&str, Expr> {
701 alt((
702 map(tag_no_case("true"), |_| Expr::Literal(AstValue::Bool(true))),
703 map(tag_no_case("false"), |_| {
704 Expr::Literal(AstValue::Bool(false))
705 }),
706 ))
707 .parse(input)
708}
709
710fn parse_policy_string(input: &str) -> IResult<&str, Expr> {
712 let (input, _) = char('\'').parse(input)?;
713 let mut end = 0;
714 for (i, c) in input.char_indices() {
715 if c == '\'' {
716 end = i;
717 break;
718 }
719 }
720 let content = &input[..end];
721 let rest = &input[end + 1..];
722 Ok((rest, Expr::Literal(AstValue::String(content.to_string()))))
723}
724
725fn parse_policy_number(input: &str) -> IResult<&str, Expr> {
727 let (input, digits) = take_while1(|c: char| c.is_ascii_digit() || c == '.')(input)?;
728 if digits.starts_with('.') || digits.is_empty() {
730 return Err(nom::Err::Error(nom::error::Error::new(
731 input,
732 nom::error::ErrorKind::Digit,
733 )));
734 }
735 if let Ok(n) = digits.parse::<i64>() {
736 Ok((input, Expr::Literal(AstValue::Int(n))))
737 } else if let Ok(f) = digits.parse::<f64>() {
738 Ok((input, Expr::Literal(AstValue::Float(f))))
739 } else {
740 Ok((input, Expr::Named(digits.to_string())))
741 }
742}
743
744fn parse_policy_func_or_ident(input: &str) -> IResult<&str, Expr> {
746 let (input, name) = identifier(input)?;
747
748 let mut expr = if let Ok((rest, _)) = char::<_, nom::error::Error<&str>>('(').parse(input) {
750 let (rest, _) = nom_ws0(rest)?;
752 let (rest, args) =
753 separated_list0((nom_ws0, char(','), nom_ws0), parse_policy_atom).parse(rest)?;
754 let (rest, _) = nom_ws0(rest)?;
755 let (rest, _) = char(')').parse(rest)?;
756 let input = rest;
757 (
758 input,
759 Expr::FunctionCall {
760 name: name.to_string(),
761 args,
762 alias: None,
763 },
764 )
765 } else {
766 (input, Expr::Named(name.to_string()))
767 };
768
769 if let Ok((rest, _)) = tag::<_, _, nom::error::Error<&str>>("::").parse(expr.0) {
771 let (rest, cast_type) = identifier(rest)?;
772 expr = (
773 rest,
774 Expr::Cast {
775 expr: Box::new(expr.1),
776 target_type: cast_type.to_string(),
777 alias: None,
778 },
779 );
780 }
781
782 Ok(expr)
783}
784
785fn parse_schema_item(input: &str) -> IResult<&str, SchemaItem> {
787 let (input, _) = ws_and_comments(input)?;
788
789 if let Ok((rest, policy)) = parse_policy(input) {
791 return Ok((rest, SchemaItem::Policy(Box::new(policy))));
792 }
793
794 if let Ok((rest, idx)) = parse_index(input) {
796 return Ok((rest, SchemaItem::Index(idx)));
797 }
798
799 let (rest, table) = parse_table(input)?;
801 Ok((rest, SchemaItem::Table(table)))
802}
803
804fn parse_index(input: &str) -> IResult<&str, IndexDef> {
806 let (input, _) = tag_no_case("index")(input)?;
807 let (input, _) = multispace1(input)?;
808 let (input, name) = take_while1(|c: char| c.is_alphanumeric() || c == '_')(input)?;
809 let (input, _) = multispace1(input)?;
810 let (input, _) = tag_no_case("on")(input)?;
811 let (input, _) = multispace1(input)?;
812 let (input, table) = take_while1(|c: char| c.is_alphanumeric() || c == '_')(input)?;
813 let (input, _) = nom_ws0(input)?;
814 let (input, _) = char('(')(input)?;
815 let (input, cols_str) = take_while1(|c: char| c != ')')(input)?;
816 let (input, _) = char(')')(input)?;
817 let (input, _) = nom_ws0(input)?;
818 let (input, unique_tag) = opt(tag_no_case("unique")).parse(input)?;
819
820 let columns: Vec<String> = cols_str
821 .split(',')
822 .map(|s| s.trim().to_string())
823 .filter(|s| !s.is_empty())
824 .collect();
825
826 let is_unique = unique_tag.is_some();
827
828 Ok((
829 input,
830 IndexDef {
831 name: name.to_string(),
832 table: table.to_string(),
833 columns,
834 unique: is_unique,
835 },
836 ))
837}
838
839fn parse_schema(input: &str) -> IResult<&str, Schema> {
841 let version = extract_version_directive(input);
843
844 let (input, items) = many0(parse_schema_item).parse(input)?;
845 let (input, _) = ws_and_comments(input)?;
846
847 let mut tables = Vec::new();
848 let mut policies = Vec::new();
849 let mut indexes = Vec::new();
850 for item in items {
851 match item {
852 SchemaItem::Table(t) => tables.push(t),
853 SchemaItem::Policy(p) => policies.push(*p),
854 SchemaItem::Index(i) => indexes.push(i),
855 }
856 }
857
858 Ok((
859 input,
860 Schema {
861 version,
862 tables,
863 policies,
864 indexes,
865 },
866 ))
867}
868
/// Finds the first `-- qail: version=N` line and returns `N`.
///
/// Lines are trimmed before matching. Only the first line carrying a
/// `version=` directive is considered: if its value is not a valid `u32`,
/// the result is `None` and later lines are not searched.
fn extract_version_directive(input: &str) -> Option<u32> {
    input
        .lines()
        .filter_map(|line| line.trim().strip_prefix("-- qail:"))
        .find_map(|directive| directive.trim().strip_prefix("version="))
        .and_then(|raw| raw.trim().parse().ok())
}
882
// Unit tests for the schema parser; every test goes through the public
// `Schema::parse` entry point.
#[cfg(test)]
mod tests {
    use super::*;

    /// A single table with primary-key, not-null and plain columns.
    #[test]
    fn test_parse_simple_table() {
        let input = r#"
            table users (
                id uuid primary_key,
                email text not null,
                name text
            )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.tables.len(), 1);

        let users = &schema.tables[0];
        assert_eq!(users.name, "users");
        assert_eq!(users.columns.len(), 3);

        let id = &users.columns[0];
        assert_eq!(id.name, "id");
        assert_eq!(id.typ, "uuid");
        assert!(id.primary_key);
        // A primary key is never nullable.
        assert!(!id.nullable);

        let email = &users.columns[1];
        assert_eq!(email.name, "email");
        assert!(!email.nullable);

        // Columns without constraints stay nullable.
        let name = &users.columns[2];
        assert!(name.nullable);
    }

    /// Two tables separated by comments, with `references` and `default`.
    #[test]
    fn test_parse_multiple_tables() {
        let input = r#"
            -- Users table
            table users (
                id uuid primary_key,
                email text not null unique
            )

            -- Orders table
            table orders (
                id uuid primary_key,
                user_id uuid references users(id),
                total i64 not null default 0
            )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.tables.len(), 2);

        let orders = schema.find_table("orders").expect("orders not found");
        let user_id = orders.find_column("user_id").expect("user_id not found");
        assert_eq!(user_id.references, Some("users(id)".to_string()));

        let total = orders.find_column("total").expect("total not found");
        assert_eq!(total.default_value, Some("0".to_string()));
    }

    /// A leading `--` comment is skipped.
    #[test]
    fn test_parse_comments() {
        let input = r#"
            -- This is a comment
            table foo (
                bar text
            )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.tables.len(), 1);
    }

    /// A `[]` suffix marks the column as an array and is not part of `typ`.
    #[test]
    fn test_array_types() {
        let input = r#"
            table products (
                id uuid primary_key,
                tags text[],
                prices decimal[]
            )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let products = &schema.tables[0];

        let tags = products.find_column("tags").expect("tags not found");
        assert_eq!(tags.typ, "text");
        assert!(tags.is_array);

        let prices = products.find_column("prices").expect("prices not found");
        assert!(prices.is_array);
    }

    /// Parenthesised type parameters and serial types.
    #[test]
    fn test_type_params() {
        let input = r#"
            table items (
                id serial primary_key,
                name varchar(255) not null,
                price decimal(10,2),
                code varchar(50) unique
            )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let items = &schema.tables[0];

        let id = items.find_column("id").expect("id not found");
        assert!(id.is_serial);
        assert!(!id.nullable);

        let name = items.find_column("name").expect("name not found");
        assert_eq!(name.typ, "varchar");
        assert_eq!(name.type_params, Some(vec!["255".to_string()]));

        let price = items.find_column("price").expect("price not found");
        assert_eq!(
            price.type_params,
            Some(vec!["10".to_string(), "2".to_string()])
        );

        let code = items.find_column("code").expect("code not found");
        assert!(code.unique);
    }

    /// `check(...)` keeps the raw expression text between the parentheses.
    #[test]
    fn test_check_constraint() {
        let input = r#"
            table employees (
                id uuid primary_key,
                age i32 check(age >= 18),
                salary decimal check(salary > 0)
            )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let employees = &schema.tables[0];

        let age = employees.find_column("age").expect("age not found");
        assert_eq!(age.check, Some("age >= 18".to_string()));

        let salary = employees.find_column("salary").expect("salary not found");
        assert_eq!(salary.check, Some("salary > 0".to_string()));
    }

    /// The `-- qail: version=N` directive is read; its absence yields `None`.
    #[test]
    fn test_version_directive() {
        let input = r#"
            -- qail: version=1
            table users (
                id uuid primary_key
            )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.version, Some(1));
        assert_eq!(schema.tables.len(), 1);

        let input_no_version = r#"
            table items (
                id uuid primary_key
            )
        "#;
        let schema2 = Schema::parse(input_no_version).expect("parse failed");
        assert_eq!(schema2.version, None);
    }

    /// A trailing `enable_rls` keyword sets the table flag.
    #[test]
    fn test_enable_rls_table() {
        let input = r#"
            table orders (
                id uuid primary_key,
                operator_id uuid not null
            ) enable_rls
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.tables.len(), 1);
        assert!(schema.tables[0].enable_rls);
    }

    /// A `using` policy parses into Binary(Named = Cast(FunctionCall)).
    #[test]
    fn test_parse_policy_basic() {
        let input = r#"
            table orders (
                id uuid primary_key,
                operator_id uuid not null
            ) enable_rls

            policy orders_isolation on orders
                for all
                using (operator_id = current_setting('app.current_operator_id')::uuid)
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.tables.len(), 1);
        assert_eq!(schema.policies.len(), 1);

        let policy = &schema.policies[0];
        assert_eq!(policy.name, "orders_isolation");
        assert_eq!(policy.table, "orders");
        assert_eq!(policy.target, PolicyTarget::All);
        assert!(policy.using.is_some());

        // Walk the expression tree from the outside in.
        let using = policy.using.as_ref().unwrap();
        let Expr::Binary {
            left, op, right, ..
        } = using
        else {
            panic!("Expected Binary, got {using:?}");
        };
        assert_eq!(*op, BinaryOp::Eq);

        let Expr::Named(n) = left.as_ref() else {
            panic!("Expected Named, got {left:?}");
        };
        assert_eq!(n, "operator_id");

        let Expr::Cast {
            target_type,
            expr: cast_expr,
            ..
        } = right.as_ref()
        else {
            panic!("Expected Cast, got {right:?}");
        };
        assert_eq!(target_type, "uuid");

        let Expr::FunctionCall { name, args, .. } = cast_expr.as_ref() else {
            panic!("Expected FunctionCall, got {cast_expr:?}");
        };
        assert_eq!(name, "current_setting");
        assert_eq!(args.len(), 1);
    }

    /// `with check (...)` fills `with_check` and leaves `using` empty.
    #[test]
    fn test_parse_policy_with_check() {
        let input = r#"
            table orders (
                id uuid primary_key
            )

            policy orders_write on orders
                for insert
                with check (operator_id = current_setting('app.current_operator_id')::uuid)
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let policy = &schema.policies[0];
        assert_eq!(policy.target, PolicyTarget::Insert);
        assert!(policy.with_check.is_some());
        assert!(policy.using.is_none());
    }

    /// `restrictive` and `to <role>` clauses are recorded on the policy.
    #[test]
    fn test_parse_policy_restrictive_with_role() {
        let input = r#"
            table secrets (
                id uuid primary_key
            )

            policy admin_only on secrets
                for select
                restrictive
                to app_user
                using (current_setting('app.is_super_admin')::boolean = true)
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let policy = &schema.policies[0];
        assert_eq!(policy.target, PolicyTarget::Select);
        assert_eq!(policy.permissiveness, PolicyPermissiveness::Restrictive);
        assert_eq!(policy.role.as_deref(), Some("app_user"));
        assert!(policy.using.is_some());
    }

    /// Two comparisons joined by `or` produce a top-level `BinaryOp::Or`.
    #[test]
    fn test_parse_policy_or_expr() {
        let input = r#"
            table orders (
                id uuid primary_key
            )

            policy tenant_or_admin on orders
                for all
                using (operator_id = current_setting('app.current_operator_id')::uuid or current_setting('app.is_super_admin')::boolean = true)
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let policy = &schema.policies[0];

        assert!(
            matches!(
                policy.using.as_ref().unwrap(),
                Expr::Binary {
                    op: BinaryOp::Or,
                    ..
                }
            ),
            "Expected Binary OR, got {:?}",
            policy.using
        );
    }

    /// `Schema::to_sql` output contains the table, RLS and policy statements.
    #[test]
    fn test_schema_to_sql() {
        let input = r#"
            table orders (
                id uuid primary_key,
                operator_id uuid not null
            ) enable_rls

            policy orders_isolation on orders
                for all
                using (operator_id = current_setting('app.current_operator_id')::uuid)
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let sql = schema.to_sql();
        assert!(sql.contains("CREATE TABLE IF NOT EXISTS"));
        assert!(sql.contains("ENABLE ROW LEVEL SECURITY"));
        assert!(sql.contains("FORCE ROW LEVEL SECURITY"));
        assert!(sql.contains("CREATE POLICY"));
        assert!(sql.contains("orders_isolation"));
        assert!(sql.contains("FOR ALL"));
    }

    /// Consecutive policies are parsed separately and kept in source order.
    #[test]
    fn test_multiple_policies() {
        let input = r#"
            table orders (
                id uuid primary_key,
                operator_id uuid not null
            ) enable_rls

            policy orders_read on orders
                for select
                using (operator_id = current_setting('app.current_operator_id')::uuid)

            policy orders_write on orders
                for insert
                with check (operator_id = current_setting('app.current_operator_id')::uuid)
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.policies.len(), 2);
        assert_eq!(schema.policies[0].name, "orders_read");
        assert_eq!(schema.policies[0].target, PolicyTarget::Select);
        assert_eq!(schema.policies[1].name, "orders_write");
        assert_eq!(schema.policies[1].target, PolicyTarget::Insert);
    }
}