1use nom::{
18 IResult, Parser,
19 branch::alt,
20 bytes::complete::{tag, tag_no_case, take_while1},
21 character::complete::{char, multispace0 as nom_ws0, multispace1, not_line_ending},
22 combinator::{map, opt},
23 multi::{many0, separated_list0},
24 sequence::preceded,
25};
26use serde::{Deserialize, Serialize};
27
28use crate::ast::{BinaryOp, Expr, Value as AstValue};
29use crate::migrate::alter::AlterTable;
30use crate::migrate::policy::{PolicyPermissiveness, PolicyTarget, RlsPolicy};
31use crate::transpiler::policy::{alter_table_sql, create_policy_sql};
32
/// A complete schema: table definitions plus optional RLS policies and indexes.
///
/// Built either from the text DSL ([`Schema::parse`]) or from JSON
/// ([`Schema::from_json`]).
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Schema {
    /// Version from a `-- qail: version=N` directive in DSL input, if any.
    #[serde(default)]
    pub version: Option<u32>,
    /// Tables in declaration order.
    pub tables: Vec<TableDef>,
    /// Row-level-security policies in declaration order (may be omitted in JSON).
    #[serde(default)]
    pub policies: Vec<RlsPolicy>,
    /// Index definitions in declaration order (may be omitted in JSON).
    #[serde(default)]
    pub indexes: Vec<IndexDef>,
}
48
/// An index declared in the DSL as `index <name> on <table> (<columns>) [unique]`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexDef {
    /// Index name, emitted verbatim.
    pub name: String,
    /// Table the index is created on, emitted verbatim.
    pub table: String,
    /// Indexed column names (or expressions), emitted verbatim and comma-joined.
    pub columns: Vec<String>,
    /// Whether to emit `CREATE UNIQUE INDEX` (defaults to `false` in JSON).
    #[serde(default)]
    pub unique: bool,
}
62
63impl IndexDef {
64 pub fn to_sql(&self) -> String {
66 let unique = if self.unique { " UNIQUE" } else { "" };
67 format!(
68 "CREATE{} INDEX IF NOT EXISTS {} ON {} ({})",
69 unique,
70 self.name,
71 self.table,
72 self.columns.join(", ")
73 )
74 }
75}
76
/// A table declared in the DSL as `table <name> ( <columns> ) [enable_rls]`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TableDef {
    /// Table name, emitted verbatim.
    pub name: String,
    /// Columns in declaration order.
    pub columns: Vec<ColumnDef>,
    /// When set, `Schema::to_sql` also emits statements that enable and force
    /// row-level security on this table (defaults to `false` in JSON).
    #[serde(default)]
    pub enable_rls: bool,
}
88
89#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct ColumnDef {
92 pub name: String,
94 #[serde(rename = "type", alias = "typ")]
96 pub typ: String,
97 #[serde(default)]
99 pub is_array: bool,
100 #[serde(default)]
102 pub type_params: Option<Vec<String>>,
103 #[serde(default)]
105 pub nullable: bool,
106 #[serde(default)]
108 pub primary_key: bool,
109 #[serde(default)]
111 pub unique: bool,
112 #[serde(default)]
113 pub references: Option<String>,
115 #[serde(default)]
117 pub default_value: Option<String>,
118 #[serde(default)]
120 pub check: Option<String>,
121 #[serde(default)]
123 pub is_serial: bool,
124}
125
126impl Default for ColumnDef {
127 fn default() -> Self {
128 Self {
129 name: String::new(),
130 typ: String::new(),
131 is_array: false,
132 type_params: None,
133 nullable: true,
134 primary_key: false,
135 unique: false,
136 references: None,
137 default_value: None,
138 check: None,
139 is_serial: false,
140 }
141 }
142}
143
144impl Schema {
145 pub fn parse(input: &str) -> Result<Self, String> {
147 match parse_schema(input) {
148 Ok(("", schema)) => Ok(schema),
149 Ok((remaining, _)) => Err(format!("Unexpected content: '{}'", remaining.trim())),
150 Err(e) => Err(format!("Parse error: {:?}", e)),
151 }
152 }
153
154 pub fn find_table(&self, name: &str) -> Option<&TableDef> {
156 self.tables
157 .iter()
158 .find(|t| t.name.eq_ignore_ascii_case(name))
159 }
160
161 pub fn to_sql(&self) -> String {
163 let mut parts = Vec::new();
164
165 for table in &self.tables {
166 parts.push(table.to_ddl());
167
168 if table.enable_rls {
169 let alter = AlterTable::new(&table.name).enable_rls().force_rls();
170 for stmt in alter_table_sql(&alter) {
171 parts.push(stmt);
172 }
173 }
174 }
175
176 for idx in &self.indexes {
177 parts.push(idx.to_sql());
178 }
179
180 for policy in &self.policies {
181 parts.push(create_policy_sql(policy));
182 }
183
184 parts.join(";\n\n") + ";"
185 }
186
187 pub fn to_json(&self) -> Result<String, String> {
189 serde_json::to_string_pretty(self).map_err(|e| format!("JSON serialization failed: {}", e))
190 }
191
192 pub fn from_json(json: &str) -> Result<Self, String> {
194 serde_json::from_str(json).map_err(|e| format!("JSON deserialization failed: {}", e))
195 }
196
197 pub fn from_file(path: &std::path::Path) -> Result<Self, String> {
199 let content =
200 std::fs::read_to_string(path).map_err(|e| format!("Failed to read file: {}", e))?;
201
202 if content.trim().starts_with('{') {
203 Self::from_json(&content)
204 } else {
205 Self::parse(&content)
206 }
207 }
208}
209
210impl TableDef {
211 pub fn find_column(&self, name: &str) -> Option<&ColumnDef> {
213 self.columns
214 .iter()
215 .find(|c| c.name.eq_ignore_ascii_case(name))
216 }
217
218 pub fn to_ddl(&self) -> String {
220 let mut sql = format!("CREATE TABLE IF NOT EXISTS {} (\n", self.name);
221
222 let mut col_defs = Vec::new();
223 for col in &self.columns {
224 let mut line = format!(" {}", col.name);
225
226 let mut typ = col.typ.to_uppercase();
228 if let Some(params) = &col.type_params {
229 typ = format!("{}({})", typ, params.join(", "));
230 }
231 if col.is_array {
232 typ.push_str("[]");
233 }
234 line.push_str(&format!(" {}", typ));
235
236 if col.primary_key {
238 line.push_str(" PRIMARY KEY");
239 }
240 if !col.nullable && !col.primary_key && !col.is_serial {
241 line.push_str(" NOT NULL");
242 }
243 if col.unique && !col.primary_key {
244 line.push_str(" UNIQUE");
245 }
246 if let Some(ref default) = col.default_value {
247 line.push_str(&format!(" DEFAULT {}", default));
248 }
249 if let Some(ref refs) = col.references {
250 line.push_str(&format!(" REFERENCES {}", refs));
251 }
252 if let Some(ref check) = col.check {
253 line.push_str(&format!(" CHECK({})", check));
254 }
255
256 col_defs.push(line);
257 }
258
259 sql.push_str(&col_defs.join(",\n"));
260 sql.push_str("\n)");
261 sql
262 }
263}
264
265fn identifier(input: &str) -> IResult<&str, &str> {
271 take_while1(|c: char| c.is_alphanumeric() || c == '_')(input)
272}
273
274fn ws_and_comments(input: &str) -> IResult<&str, ()> {
276 let (input, _) = many0(alt((
277 map(multispace1, |_| ()),
278 map((tag("--"), not_line_ending), |_| ()),
279 map((tag("#"), not_line_ending), |_| ()),
280 )))
281 .parse(input)?;
282 Ok((input, ()))
283}
284
/// Intermediate result of [`parse_type_info`] for a column type.
struct TypeInfo {
    /// Lower-cased base type name, e.g. `varchar`.
    name: String,
    /// Trimmed, comma-separated parameters from `(...)`, if present.
    params: Option<Vec<String>>,
    /// Whether a `[]` suffix followed the type.
    is_array: bool,
    /// Whether the base type is `serial`, `bigserial`, or `smallserial`.
    is_serial: bool,
}
291
292fn parse_type_info(input: &str) -> IResult<&str, TypeInfo> {
295 let (input, type_name) = take_while1(|c: char| c.is_alphanumeric()).parse(input)?;
296
297 let (input, params) = if input.starts_with('(') {
298 let paren_start = 1;
299 let mut paren_end = paren_start;
300 for (i, c) in input[paren_start..].char_indices() {
301 if c == ')' {
302 paren_end = paren_start + i;
303 break;
304 }
305 }
306 let param_str = &input[paren_start..paren_end];
307 let params: Vec<String> = param_str.split(',').map(|s| s.trim().to_string()).collect();
308 (&input[paren_end + 1..], Some(params))
309 } else {
310 (input, None)
311 };
312
313 let (input, is_array) = if let Some(stripped) = input.strip_prefix("[]") {
314 (stripped, true)
315 } else {
316 (input, false)
317 };
318
319 let lower = type_name.to_lowercase();
320 let is_serial = lower == "serial" || lower == "bigserial" || lower == "smallserial";
321
322 Ok((
323 input,
324 TypeInfo {
325 name: lower,
326 params,
327 is_array,
328 is_serial,
329 },
330 ))
331}
332
333fn constraint_text(input: &str) -> IResult<&str, &str> {
335 let mut paren_depth = 0;
336 let mut end = 0;
337
338 for (i, c) in input.char_indices() {
339 match c {
340 '(' => paren_depth += 1,
341 ')' => {
342 if paren_depth == 0 {
343 break; }
345 paren_depth -= 1;
346 }
347 ',' if paren_depth == 0 => break,
348 '\n' | '\r' if paren_depth == 0 => break,
349 _ => {}
350 }
351 end = i + c.len_utf8();
352 }
353
354 if end == 0 {
355 Err(nom::Err::Error(nom::error::Error::new(
356 input,
357 nom::error::ErrorKind::TakeWhile1,
358 )))
359 } else {
360 Ok((&input[end..], &input[..end]))
361 }
362}
363
364fn parse_column(input: &str) -> IResult<&str, ColumnDef> {
366 let (input, _) = ws_and_comments(input)?;
367 let (input, name) = identifier(input)?;
368 let (input, _) = multispace1(input)?;
369 let (input, type_info) = parse_type_info(input)?;
370
371 let (input, constraint_str) = opt(preceded(multispace1, constraint_text)).parse(input)?;
372
373 let mut col = ColumnDef {
374 name: name.to_string(),
375 typ: type_info.name,
376 is_array: type_info.is_array,
377 type_params: type_info.params,
378 is_serial: type_info.is_serial,
379 nullable: !type_info.is_serial, ..Default::default()
381 };
382
383 if let Some(constraints) = constraint_str {
384 let lower = constraints.to_lowercase();
385
386 if lower.contains("primary_key") || lower.contains("primary key") {
387 col.primary_key = true;
388 col.nullable = false;
389 }
390 if lower.contains("not_null") || lower.contains("not null") {
391 col.nullable = false;
392 }
393 if lower.contains("unique") {
394 col.unique = true;
395 }
396
397 if let Some(idx) = lower.find("references ") {
398 let rest = &constraints[idx + 11..];
399 let mut paren_depth = 0;
401 let mut end = rest.len();
402 for (i, c) in rest.char_indices() {
403 match c {
404 '(' => paren_depth += 1,
405 ')' => {
406 if paren_depth == 0 {
407 end = i;
408 break;
409 }
410 paren_depth -= 1;
411 }
412 c if c.is_whitespace() && paren_depth == 0 => {
413 end = i;
414 break;
415 }
416 _ => {}
417 }
418 }
419 col.references = Some(rest[..end].to_string());
420 }
421
422 if let Some(idx) = lower.find("default ") {
423 let rest = &constraints[idx + 8..];
424 let end = rest.find(|c: char| c.is_whitespace()).unwrap_or(rest.len());
425 col.default_value = Some(rest[..end].to_string());
426 }
427
428 if let Some(idx) = lower.find("check(") {
429 let rest = &constraints[idx + 6..];
430 let mut depth = 1;
432 let mut end = rest.len();
433 for (i, c) in rest.char_indices() {
434 match c {
435 '(' => depth += 1,
436 ')' => {
437 depth -= 1;
438 if depth == 0 {
439 end = i;
440 break;
441 }
442 }
443 _ => {}
444 }
445 }
446 col.check = Some(rest[..end].to_string());
447 }
448 }
449
450 Ok((input, col))
451}
452
453fn parse_column_list(input: &str) -> IResult<&str, Vec<ColumnDef>> {
455 let (input, _) = ws_and_comments(input)?;
456 let (input, _) = char('(').parse(input)?;
457 let (input, columns) = separated_list0(char(','), parse_column).parse(input)?;
458 let (input, _) = ws_and_comments(input)?;
459 let (input, _) = char(')').parse(input)?;
460
461 Ok((input, columns))
462}
463
464fn parse_table(input: &str) -> IResult<&str, TableDef> {
466 let (input, _) = ws_and_comments(input)?;
467 let (input, _) = tag_no_case("table").parse(input)?;
468 let (input, _) = multispace1(input)?;
469 let (input, name) = identifier(input)?;
470 let (input, columns) = parse_column_list(input)?;
471
472 let (input, _) = ws_and_comments(input)?;
474 let enable_rls = if let Ok((rest, _)) =
475 tag_no_case::<_, _, nom::error::Error<&str>>("enable_rls").parse(input)
476 {
477 return Ok((
478 rest,
479 TableDef {
480 name: name.to_string(),
481 columns,
482 enable_rls: true,
483 },
484 ));
485 } else {
486 false
487 };
488
489 Ok((
490 input,
491 TableDef {
492 name: name.to_string(),
493 columns,
494 enable_rls,
495 },
496 ))
497}
498
/// One top-level declaration recognised by `parse_schema_item`; `parse_schema`
/// sorts these into the matching `Schema` collections.
enum SchemaItem {
    /// A `table ...` declaration.
    Table(TableDef),
    /// A `policy ...` declaration. Boxed, presumably to keep the enum small
    /// (NOTE(review): confirm `RlsPolicy` is large enough to warrant it).
    Policy(Box<RlsPolicy>),
    /// An `index ...` declaration.
    Index(IndexDef),
}
509
/// Parses a row-level-security policy declaration:
///
/// ```text
/// policy <name> on <table>
///     [for all | select | insert | update | delete]
///     [restrictive]
///     [to <role>]
///     [with check (<expr>)]
///     [using (<expr>)]
/// ```
///
/// The policy starts from `RlsPolicy::create(name, table)`. The optional
/// clauses may appear in any order (keywords are case-insensitive), and a
/// repeated clause overwrites the earlier value. Clause parsing stops at the
/// first token that does not begin a clause. Whitespace and comments before
/// that token are consumed; the token itself is left in the remainder.
///
/// # Errors
///
/// Fails if the header is malformed, if `for` is not followed by whitespace
/// and a known target, if `to <ws>` is not followed by an identifier, if
/// `with` is not followed by whitespace, or if a `with check` / `using` body
/// is not a parenthesised policy expression.
fn parse_policy(input: &str) -> IResult<&str, RlsPolicy> {
    // Mandatory header: `policy <name> on <table>`.
    let (input, _) = ws_and_comments(input)?;
    let (input, _) = tag_no_case("policy").parse(input)?;
    let (input, _) = multispace1(input)?;
    let (input, name) = identifier(input)?;
    let (input, _) = multispace1(input)?;
    let (input, _) = tag_no_case("on").parse(input)?;
    let (input, _) = multispace1(input)?;
    let (input, table) = identifier(input)?;

    let mut policy = RlsPolicy::create(name, table);

    // `remaining` always points just past the last successfully parsed clause.
    let mut remaining = input;
    loop {
        // Whitespace/comments between clauses.
        let (input, _) = ws_and_comments(remaining)?;

        // `for <target>`: once `for` matches, whitespace and a known target are required.
        if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("for").parse(input) {
            let (rest, _) = multispace1(rest)?;
            let (rest, target) = alt((
                map(tag_no_case("all"), |_| PolicyTarget::All),
                map(tag_no_case("select"), |_| PolicyTarget::Select),
                map(tag_no_case("insert"), |_| PolicyTarget::Insert),
                map(tag_no_case("update"), |_| PolicyTarget::Update),
                map(tag_no_case("delete"), |_| PolicyTarget::Delete),
            ))
            .parse(rest)?;
            policy.target = target;
            remaining = rest;
            continue;
        }

        // `restrictive`: marks the policy as restrictive.
        if let Ok((rest, _)) =
            tag_no_case::<_, _, nom::error::Error<&str>>("restrictive").parse(input)
        {
            policy.permissiveness = PolicyPermissiveness::Restrictive;
            remaining = rest;
            continue;
        }

        // `to <role>`: only treated as a clause when `to` is followed by
        // whitespace; otherwise the remaining checks are tried.
        if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("to").parse(input) {
            if let Ok((rest, _)) = multispace1::<_, nom::error::Error<&str>>(rest) {
                let (rest, role) = identifier(rest)?;
                policy.role = Some(role.to_string());
                remaining = rest;
                continue;
            }
        }

        // `with check (<expr>)`: `with` must be followed by whitespace; if
        // `check` does not follow, parsing falls through to the `using` check.
        if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("with").parse(input) {
            let (rest, _) = multispace1(rest)?;
            if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("check").parse(rest)
            {
                let (rest, _) = nom_ws0(rest)?;
                let (rest, _) = char('(').parse(rest)?;
                let (rest, _) = nom_ws0(rest)?;
                let (rest, expr) = parse_policy_expr(rest)?;
                let (rest, _) = nom_ws0(rest)?;
                let (rest, _) = char(')').parse(rest)?;
                policy.with_check = Some(expr);
                remaining = rest;
                continue;
            }
        }

        // `using (<expr>)`.
        if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("using").parse(input) {
            let (rest, _) = nom_ws0(rest)?;
            let (rest, _) = char('(').parse(rest)?;
            let (rest, _) = nom_ws0(rest)?;
            let (rest, expr) = parse_policy_expr(rest)?;
            let (rest, _) = nom_ws0(rest)?;
            let (rest, _) = char(')').parse(rest)?;
            policy.using = Some(expr);
            remaining = rest;
            continue;
        }

        // No clause matched: stop after the skipped whitespace/comments.
        remaining = input;
        break;
    }

    Ok((remaining, policy))
}
611
612fn parse_policy_expr(input: &str) -> IResult<&str, Expr> {
621 let (input, first) = parse_policy_comparison(input)?;
622
623 let mut result = first;
625 let mut remaining = input;
626 loop {
627 let (input, _) = nom_ws0(remaining)?;
628
629 if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("or").parse(input)
630 && let Ok((rest, _)) = multispace1::<_, nom::error::Error<&str>>(rest)
631 {
632 let (rest, right) = parse_policy_comparison(rest)?;
633 result = Expr::Binary {
634 left: Box::new(result),
635 op: BinaryOp::Or,
636 right: Box::new(right),
637 alias: None,
638 };
639 remaining = rest;
640 continue;
641 }
642
643 if let Ok((rest, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("and").parse(input)
644 && let Ok((rest, _)) = multispace1::<_, nom::error::Error<&str>>(rest)
645 {
646 let (rest, right) = parse_policy_comparison(rest)?;
647 result = Expr::Binary {
648 left: Box::new(result),
649 op: BinaryOp::And,
650 right: Box::new(right),
651 alias: None,
652 };
653 remaining = rest;
654 continue;
655 }
656
657 remaining = input;
658 break;
659 }
660
661 Ok((remaining, result))
662}
663
664fn parse_policy_comparison(input: &str) -> IResult<&str, Expr> {
666 let (input, left) = parse_policy_atom(input)?;
667 let (input, _) = nom_ws0(input)?;
668
669 if let Ok((rest, op)) = parse_cmp_op(input) {
671 let (rest, _) = nom_ws0(rest)?;
672 let (rest, right) = parse_policy_atom(rest)?;
673 return Ok((
674 rest,
675 Expr::Binary {
676 left: Box::new(left),
677 op,
678 right: Box::new(right),
679 alias: None,
680 },
681 ));
682 }
683
684 Ok((input, left))
686}
687
688fn parse_cmp_op(input: &str) -> IResult<&str, BinaryOp> {
690 alt((
691 map(tag(">="), |_| BinaryOp::Gte),
692 map(tag("<="), |_| BinaryOp::Lte),
693 map(tag("<>"), |_| BinaryOp::Ne),
694 map(tag("!="), |_| BinaryOp::Ne),
695 map(tag("="), |_| BinaryOp::Eq),
696 map(tag(">"), |_| BinaryOp::Gt),
697 map(tag("<"), |_| BinaryOp::Lt),
698 ))
699 .parse(input)
700}
701
702fn parse_policy_atom(input: &str) -> IResult<&str, Expr> {
710 alt((
711 parse_policy_grouped,
712 parse_policy_bool,
713 parse_policy_string,
714 parse_policy_number,
715 parse_policy_func_or_ident, ))
717 .parse(input)
718}
719
720fn parse_policy_grouped(input: &str) -> IResult<&str, Expr> {
722 let (input, _) = char('(').parse(input)?;
723 let (input, _) = nom_ws0(input)?;
724 let (input, expr) = parse_policy_expr(input)?;
725 let (input, _) = nom_ws0(input)?;
726 let (input, _) = char(')').parse(input)?;
727 Ok((input, expr))
728}
729
730fn parse_policy_bool(input: &str) -> IResult<&str, Expr> {
732 alt((
733 map(tag_no_case("true"), |_| Expr::Literal(AstValue::Bool(true))),
734 map(tag_no_case("false"), |_| {
735 Expr::Literal(AstValue::Bool(false))
736 }),
737 ))
738 .parse(input)
739}
740
741fn parse_policy_string(input: &str) -> IResult<&str, Expr> {
743 let (input, _) = char('\'').parse(input)?;
744 let mut end = 0;
745 for (i, c) in input.char_indices() {
746 if c == '\'' {
747 end = i;
748 break;
749 }
750 }
751 let content = &input[..end];
752 let rest = &input[end + 1..];
753 Ok((rest, Expr::Literal(AstValue::String(content.to_string()))))
754}
755
756fn parse_policy_number(input: &str) -> IResult<&str, Expr> {
758 let (input, digits) = take_while1(|c: char| c.is_ascii_digit() || c == '.')(input)?;
759 if digits.starts_with('.') || digits.is_empty() {
761 return Err(nom::Err::Error(nom::error::Error::new(
762 input,
763 nom::error::ErrorKind::Digit,
764 )));
765 }
766 if let Ok(n) = digits.parse::<i64>() {
767 Ok((input, Expr::Literal(AstValue::Int(n))))
768 } else if let Ok(f) = digits.parse::<f64>() {
769 Ok((input, Expr::Literal(AstValue::Float(f))))
770 } else {
771 Ok((input, Expr::Named(digits.to_string())))
772 }
773}
774
775fn parse_policy_func_or_ident(input: &str) -> IResult<&str, Expr> {
777 let (input, name) = identifier(input)?;
778
779 let mut expr = if let Ok((rest, _)) = char::<_, nom::error::Error<&str>>('(').parse(input) {
781 let (rest, _) = nom_ws0(rest)?;
783 let (rest, args) =
784 separated_list0((nom_ws0, char(','), nom_ws0), parse_policy_atom).parse(rest)?;
785 let (rest, _) = nom_ws0(rest)?;
786 let (rest, _) = char(')').parse(rest)?;
787 let input = rest;
788 (
789 input,
790 Expr::FunctionCall {
791 name: name.to_string(),
792 args,
793 alias: None,
794 },
795 )
796 } else {
797 (input, Expr::Named(name.to_string()))
798 };
799
800 if let Ok((rest, _)) = tag::<_, _, nom::error::Error<&str>>("::").parse(expr.0) {
802 let (rest, cast_type) = identifier(rest)?;
803 expr = (
804 rest,
805 Expr::Cast {
806 expr: Box::new(expr.1),
807 target_type: cast_type.to_string(),
808 alias: None,
809 },
810 );
811 }
812
813 Ok(expr)
814}
815
816fn parse_schema_item(input: &str) -> IResult<&str, SchemaItem> {
818 let (input, _) = ws_and_comments(input)?;
819
820 if let Ok((rest, policy)) = parse_policy(input) {
822 return Ok((rest, SchemaItem::Policy(Box::new(policy))));
823 }
824
825 if let Ok((rest, idx)) = parse_index(input) {
827 return Ok((rest, SchemaItem::Index(idx)));
828 }
829
830 let (rest, table) = parse_table(input)?;
832 Ok((rest, SchemaItem::Table(table)))
833}
834
835fn parse_index(input: &str) -> IResult<&str, IndexDef> {
837 let (input, _) = tag_no_case("index")(input)?;
838 let (input, _) = multispace1(input)?;
839 let (input, name) = take_while1(|c: char| c.is_alphanumeric() || c == '_')(input)?;
840 let (input, _) = multispace1(input)?;
841 let (input, _) = tag_no_case("on")(input)?;
842 let (input, _) = multispace1(input)?;
843 let (input, table) = take_while1(|c: char| c.is_alphanumeric() || c == '_')(input)?;
844 let (input, _) = nom_ws0(input)?;
845 let (input, _) = char('(')(input)?;
846 let (input, cols_str) = take_while1(|c: char| c != ')')(input)?;
847 let (input, _) = char(')')(input)?;
848 let (input, _) = nom_ws0(input)?;
849 let (input, unique_tag) = opt(tag_no_case("unique")).parse(input)?;
850
851 let columns: Vec<String> = cols_str
852 .split(',')
853 .map(|s| s.trim().to_string())
854 .filter(|s| !s.is_empty())
855 .collect();
856
857 let is_unique = unique_tag.is_some();
858
859 Ok((
860 input,
861 IndexDef {
862 name: name.to_string(),
863 table: table.to_string(),
864 columns,
865 unique: is_unique,
866 },
867 ))
868}
869
870fn parse_schema(input: &str) -> IResult<&str, Schema> {
872 let version = extract_version_directive(input);
874
875 let (input, items) = many0(parse_schema_item).parse(input)?;
876 let (input, _) = ws_and_comments(input)?;
877
878 let mut tables = Vec::new();
879 let mut policies = Vec::new();
880 let mut indexes = Vec::new();
881 for item in items {
882 match item {
883 SchemaItem::Table(t) => tables.push(t),
884 SchemaItem::Policy(p) => policies.push(*p),
885 SchemaItem::Index(i) => indexes.push(i),
886 }
887 }
888
889 Ok((
890 input,
891 Schema {
892 version,
893 tables,
894 policies,
895 indexes,
896 },
897 ))
898}
899
/// Looks for a `-- qail: version=N` directive anywhere in `input`.
///
/// The first line that, once trimmed, starts with `-- qail:` and whose
/// remainder starts with `version=` decides the result. It is that value
/// parsed as `u32`, or `None` if the value is not a valid number. Returns
/// `None` when no such line exists.
fn extract_version_directive(input: &str) -> Option<u32> {
    input
        .lines()
        .find_map(|line| {
            let directive = line.trim().strip_prefix("-- qail:")?;
            let value = directive.trim().strip_prefix("version=")?;
            Some(value.trim().parse::<u32>().ok())
        })
        .flatten()
}
913
// Unit tests for the schema DSL: table/column parsing, version directives,
// RLS policies, and SQL rendering via `Schema::to_sql`.
#[cfg(test)]
mod tests {
    use super::*;

    // Basic table: primary key, `not null`, and a column with no constraints.
    #[test]
    fn test_parse_simple_table() {
        let input = r#"
        table users (
            id uuid primary_key,
            email text not null,
            name text
        )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.tables.len(), 1);

        let users = &schema.tables[0];
        assert_eq!(users.name, "users");
        assert_eq!(users.columns.len(), 3);

        let id = &users.columns[0];
        assert_eq!(id.name, "id");
        assert_eq!(id.typ, "uuid");
        assert!(id.primary_key);
        assert!(!id.nullable);

        let email = &users.columns[1];
        assert_eq!(email.name, "email");
        assert!(!email.nullable);

        let name = &users.columns[2];
        assert!(name.nullable);
    }

    // Multiple tables separated by comments; `references` and `default` extraction.
    #[test]
    fn test_parse_multiple_tables() {
        let input = r#"
        -- Users table
        table users (
            id uuid primary_key,
            email text not null unique
        )

        -- Orders table
        table orders (
            id uuid primary_key,
            user_id uuid references users(id),
            total i64 not null default 0
        )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.tables.len(), 2);

        let orders = schema.find_table("orders").expect("orders not found");
        let user_id = orders.find_column("user_id").expect("user_id not found");
        assert_eq!(user_id.references, Some("users(id)".to_string()));

        let total = orders.find_column("total").expect("total not found");
        assert_eq!(total.default_value, Some("0".to_string()));
    }

    // A leading `--` comment is skipped.
    #[test]
    fn test_parse_comments() {
        let input = r#"
        -- This is a comment
        table foo (
            bar text
        )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.tables.len(), 1);
    }

    // `[]` suffix sets `is_array` and is not part of the type name.
    #[test]
    fn test_array_types() {
        let input = r#"
        table products (
            id uuid primary_key,
            tags text[],
            prices decimal[]
        )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let products = &schema.tables[0];

        let tags = products.find_column("tags").expect("tags not found");
        assert_eq!(tags.typ, "text");
        assert!(tags.is_array);

        let prices = products.find_column("prices").expect("prices not found");
        assert!(prices.is_array);
    }

    // Type parameters (including multi-parameter `decimal(10,2)`) and serial detection.
    #[test]
    fn test_type_params() {
        let input = r#"
        table items (
            id serial primary_key,
            name varchar(255) not null,
            price decimal(10,2),
            code varchar(50) unique
        )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let items = &schema.tables[0];

        let id = items.find_column("id").expect("id not found");
        assert!(id.is_serial);
        assert!(!id.nullable);

        let name = items.find_column("name").expect("name not found");
        assert_eq!(name.typ, "varchar");
        assert_eq!(name.type_params, Some(vec!["255".to_string()]));

        let price = items.find_column("price").expect("price not found");
        assert_eq!(
            price.type_params,
            Some(vec!["10".to_string(), "2".to_string()])
        );

        let code = items.find_column("code").expect("code not found");
        assert!(code.unique);
    }

    // `check(...)` stores the expression between the outer parentheses.
    #[test]
    fn test_check_constraint() {
        let input = r#"
        table employees (
            id uuid primary_key,
            age i32 check(age >= 18),
            salary decimal check(salary > 0)
        )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let employees = &schema.tables[0];

        let age = employees.find_column("age").expect("age not found");
        assert_eq!(age.check, Some("age >= 18".to_string()));

        let salary = employees.find_column("salary").expect("salary not found");
        assert_eq!(salary.check, Some("salary > 0".to_string()));
    }

    // `-- qail: version=N` is picked up; its absence yields `None`.
    #[test]
    fn test_version_directive() {
        let input = r#"
        -- qail: version=1
        table users (
            id uuid primary_key
        )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.version, Some(1));
        assert_eq!(schema.tables.len(), 1);

        let input_no_version = r#"
        table items (
            id uuid primary_key
        )
        "#;
        let schema2 = Schema::parse(input_no_version).expect("parse failed");
        assert_eq!(schema2.version, None);
    }

    // The `enable_rls` suffix after a table's column list sets the flag.
    #[test]
    fn test_enable_rls_table() {
        let input = r#"
        table orders (
            id uuid primary_key,
            operator_id uuid not null
        ) enable_rls
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.tables.len(), 1);
        assert!(schema.tables[0].enable_rls);
    }

    // Full structure of a `using` expression: Binary(Eq, Named, Cast(FunctionCall)).
    #[test]
    fn test_parse_policy_basic() {
        let input = r#"
        table orders (
            id uuid primary_key,
            operator_id uuid not null
        ) enable_rls

        policy orders_isolation on orders
            for all
            using (operator_id = current_setting('app.current_operator_id')::uuid)
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.tables.len(), 1);
        assert_eq!(schema.policies.len(), 1);

        let policy = &schema.policies[0];
        assert_eq!(policy.name, "orders_isolation");
        assert_eq!(policy.table, "orders");
        assert_eq!(policy.target, PolicyTarget::All);
        assert!(policy.using.is_some());

        let using = policy.using.as_ref().unwrap();
        let Expr::Binary {
            left, op, right, ..
        } = using
        else {
            panic!("Expected Binary, got {using:?}");
        };
        assert_eq!(*op, BinaryOp::Eq);

        let Expr::Named(n) = left.as_ref() else {
            panic!("Expected Named, got {left:?}");
        };
        assert_eq!(n, "operator_id");

        let Expr::Cast {
            target_type,
            expr: cast_expr,
            ..
        } = right.as_ref()
        else {
            panic!("Expected Cast, got {right:?}");
        };
        assert_eq!(target_type, "uuid");

        let Expr::FunctionCall { name, args, .. } = cast_expr.as_ref() else {
            panic!("Expected FunctionCall, got {cast_expr:?}");
        };
        assert_eq!(name, "current_setting");
        assert_eq!(args.len(), 1);
    }

    // `with check (...)` populates `with_check` and leaves `using` empty.
    #[test]
    fn test_parse_policy_with_check() {
        let input = r#"
        table orders (
            id uuid primary_key
        )

        policy orders_write on orders
            for insert
            with check (operator_id = current_setting('app.current_operator_id')::uuid)
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let policy = &schema.policies[0];
        assert_eq!(policy.target, PolicyTarget::Insert);
        assert!(policy.with_check.is_some());
        assert!(policy.using.is_none());
    }

    // `restrictive` and `to <role>` clauses.
    #[test]
    fn test_parse_policy_restrictive_with_role() {
        let input = r#"
        table secrets (
            id uuid primary_key
        )

        policy admin_only on secrets
            for select
            restrictive
            to app_user
            using (current_setting('app.is_super_admin')::boolean = true)
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let policy = &schema.policies[0];
        assert_eq!(policy.target, PolicyTarget::Select);
        assert_eq!(policy.permissiveness, PolicyPermissiveness::Restrictive);
        assert_eq!(policy.role.as_deref(), Some("app_user"));
        assert!(policy.using.is_some());
    }

    // An `or` between two comparisons produces a top-level Binary OR.
    #[test]
    fn test_parse_policy_or_expr() {
        let input = r#"
        table orders (
            id uuid primary_key
        )

        policy tenant_or_admin on orders
            for all
            using (operator_id = current_setting('app.current_operator_id')::uuid or current_setting('app.is_super_admin')::boolean = true)
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let policy = &schema.policies[0];

        assert!(
            matches!(
                policy.using.as_ref().unwrap(),
                Expr::Binary {
                    op: BinaryOp::Or,
                    ..
                }
            ),
            "Expected Binary OR, got {:?}",
            policy.using
        );
    }

    // Rendering: table DDL, RLS enable/force, and the policy all appear in the SQL.
    #[test]
    fn test_schema_to_sql() {
        let input = r#"
        table orders (
            id uuid primary_key,
            operator_id uuid not null
        ) enable_rls

        policy orders_isolation on orders
            for all
            using (operator_id = current_setting('app.current_operator_id')::uuid)
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let sql = schema.to_sql();
        assert!(sql.contains("CREATE TABLE IF NOT EXISTS"));
        assert!(sql.contains("ENABLE ROW LEVEL SECURITY"));
        assert!(sql.contains("FORCE ROW LEVEL SECURITY"));
        assert!(sql.contains("CREATE POLICY"));
        assert!(sql.contains("orders_isolation"));
        assert!(sql.contains("FOR ALL"));
    }

    // Several policies on one table are kept in declaration order.
    #[test]
    fn test_multiple_policies() {
        let input = r#"
        table orders (
            id uuid primary_key,
            operator_id uuid not null
        ) enable_rls

        policy orders_read on orders
            for select
            using (operator_id = current_setting('app.current_operator_id')::uuid)

        policy orders_write on orders
            for insert
            with check (operator_id = current_setting('app.current_operator_id')::uuid)
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.policies.len(), 2);
        assert_eq!(schema.policies[0].name, "orders_read");
        assert_eq!(schema.policies[0].target, PolicyTarget::Select);
        assert_eq!(schema.policies[1].name, "orders_write");
        assert_eq!(schema.policies[1].target, PolicyTarget::Insert);
    }
}