1use nom::{
18 IResult, Parser,
19 branch::alt,
20 bytes::complete::{tag, tag_no_case, take_while1},
21 character::complete::{char, multispace0 as nom_ws0, multispace1, not_line_ending},
22 combinator::{map, opt},
23 multi::{many0, separated_list0},
24 sequence::preceded,
25};
26use serde::{Deserialize, Serialize};
27
28use crate::ast::{Expr, BinaryOp, Value as AstValue};
29use crate::migrate::policy::{RlsPolicy, PolicyTarget, PolicyPermissiveness};
30use crate::transpiler::policy::{create_policy_sql, alter_table_sql};
31use crate::migrate::alter::AlterTable;
32
/// A parsed schema: table definitions, row-level security policies and indexes.
///
/// Built by [`Schema::parse`] from the text DSL or by [`Schema::from_json`], and rendered
/// back to SQL by [`Schema::to_sql`].
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Schema {
    /// Version taken from a `-- qail: version=N` directive, if one is present.
    #[serde(default)]
    pub version: Option<u32>,
    /// Table definitions in declaration order (required when deserializing from JSON).
    pub tables: Vec<TableDef>,
    /// Row-level security policies declared with `policy <name> on <table> ...`.
    #[serde(default)]
    pub policies: Vec<RlsPolicy>,
    /// Indexes declared with `index <name> on <table> (...)`.
    #[serde(default)]
    pub indexes: Vec<IndexDef>,
}
47
/// An index declaration, rendered by [`IndexDef::to_sql`] as `CREATE [UNIQUE] INDEX IF NOT EXISTS`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexDef {
    /// Index name, emitted verbatim.
    pub name: String,
    /// Table the index is created on.
    pub table: String,
    /// Indexed columns, emitted verbatim and joined with `, `.
    pub columns: Vec<String>,
    /// Emit `CREATE UNIQUE INDEX` when set; `false` when absent from JSON.
    #[serde(default)]
    pub unique: bool,
}
57
58impl IndexDef {
59 pub fn to_sql(&self) -> String {
61 let unique = if self.unique { " UNIQUE" } else { "" };
62 format!(
63 "CREATE{} INDEX IF NOT EXISTS {} ON {} ({})",
64 unique,
65 self.name,
66 self.table,
67 self.columns.join(", ")
68 )
69 }
70}
71
/// A table declaration: `table <name> ( <columns> ) [enable_rls]`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TableDef {
    /// Table name, emitted verbatim in DDL.
    pub name: String,
    /// Column definitions in declaration order.
    pub columns: Vec<ColumnDef>,
    /// When set, [`Schema::to_sql`] also emits statements that enable and force
    /// row level security on this table. `false` when absent from JSON.
    #[serde(default)]
    pub enable_rls: bool,
}
80
81#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct ColumnDef {
83 pub name: String,
84 #[serde(rename = "type", alias = "typ")]
85 pub typ: String,
86 #[serde(default)]
88 pub is_array: bool,
89 #[serde(default)]
91 pub type_params: Option<Vec<String>>,
92 #[serde(default)]
93 pub nullable: bool,
94 #[serde(default)]
95 pub primary_key: bool,
96 #[serde(default)]
97 pub unique: bool,
98 #[serde(default)]
99 pub references: Option<String>,
100 #[serde(default)]
101 pub default_value: Option<String>,
102 #[serde(default)]
104 pub check: Option<String>,
105 #[serde(default)]
107 pub is_serial: bool,
108}
109
110impl Default for ColumnDef {
111 fn default() -> Self {
112 Self {
113 name: String::new(),
114 typ: String::new(),
115 is_array: false,
116 type_params: None,
117 nullable: true,
118 primary_key: false,
119 unique: false,
120 references: None,
121 default_value: None,
122 check: None,
123 is_serial: false,
124 }
125 }
126}
127
128impl Schema {
129 pub fn parse(input: &str) -> Result<Self, String> {
131 match parse_schema(input) {
132 Ok(("", schema)) => Ok(schema),
133 Ok((remaining, _)) => Err(format!("Unexpected content: '{}'", remaining.trim())),
134 Err(e) => Err(format!("Parse error: {:?}", e)),
135 }
136 }
137
138 pub fn find_table(&self, name: &str) -> Option<&TableDef> {
140 self.tables
141 .iter()
142 .find(|t| t.name.eq_ignore_ascii_case(name))
143 }
144
145 pub fn to_sql(&self) -> String {
147 let mut parts = Vec::new();
148
149 for table in &self.tables {
150 parts.push(table.to_ddl());
151
152 if table.enable_rls {
153 let alter = AlterTable::new(&table.name).enable_rls().force_rls();
154 for stmt in alter_table_sql(&alter) {
155 parts.push(stmt);
156 }
157 }
158 }
159
160 for idx in &self.indexes {
161 parts.push(idx.to_sql());
162 }
163
164 for policy in &self.policies {
165 parts.push(create_policy_sql(policy));
166 }
167
168 parts.join(";\n\n") + ";"
169 }
170
171 pub fn to_json(&self) -> Result<String, String> {
173 serde_json::to_string_pretty(self).map_err(|e| format!("JSON serialization failed: {}", e))
174 }
175
176 pub fn from_json(json: &str) -> Result<Self, String> {
178 serde_json::from_str(json).map_err(|e| format!("JSON deserialization failed: {}", e))
179 }
180
181 pub fn from_file(path: &std::path::Path) -> Result<Self, String> {
183 let content =
184 std::fs::read_to_string(path).map_err(|e| format!("Failed to read file: {}", e))?;
185
186 if content.trim().starts_with('{') {
187 Self::from_json(&content)
188 } else {
189 Self::parse(&content)
190 }
191 }
192}
193
194impl TableDef {
195 pub fn find_column(&self, name: &str) -> Option<&ColumnDef> {
197 self.columns
198 .iter()
199 .find(|c| c.name.eq_ignore_ascii_case(name))
200 }
201
202 pub fn to_ddl(&self) -> String {
204 let mut sql = format!("CREATE TABLE IF NOT EXISTS {} (\n", self.name);
205
206 let mut col_defs = Vec::new();
207 for col in &self.columns {
208 let mut line = format!(" {}", col.name);
209
210 let mut typ = col.typ.to_uppercase();
212 if let Some(params) = &col.type_params {
213 typ = format!("{}({})", typ, params.join(", "));
214 }
215 if col.is_array {
216 typ.push_str("[]");
217 }
218 line.push_str(&format!(" {}", typ));
219
220 if col.primary_key {
222 line.push_str(" PRIMARY KEY");
223 }
224 if !col.nullable && !col.primary_key && !col.is_serial {
225 line.push_str(" NOT NULL");
226 }
227 if col.unique && !col.primary_key {
228 line.push_str(" UNIQUE");
229 }
230 if let Some(ref default) = col.default_value {
231 line.push_str(&format!(" DEFAULT {}", default));
232 }
233 if let Some(ref refs) = col.references {
234 line.push_str(&format!(" REFERENCES {}", refs));
235 }
236 if let Some(ref check) = col.check {
237 line.push_str(&format!(" CHECK({})", check));
238 }
239
240 col_defs.push(line);
241 }
242
243 sql.push_str(&col_defs.join(",\n"));
244 sql.push_str("\n)");
245 sql
246 }
247}
248
249fn identifier(input: &str) -> IResult<&str, &str> {
255 take_while1(|c: char| c.is_alphanumeric() || c == '_')(input)
256}
257
258fn ws_and_comments(input: &str) -> IResult<&str, ()> {
260 let (input, _) = many0(alt((
261 map(multispace1, |_| ()),
262 map((tag("--"), not_line_ending), |_| ()),
263 map((tag("#"), not_line_ending), |_| ()),
264 )))
265 .parse(input)?;
266 Ok((input, ()))
267}
268
/// Intermediate result of [`parse_type_info`]: a column type as written in the DSL.
struct TypeInfo {
    /// Lower-cased base type name (e.g. `varchar`).
    name: String,
    /// Trimmed, comma-separated parameters from `type(a, b)`; `None` without parentheses.
    params: Option<Vec<String>>,
    /// Whether the type carried a `[]` suffix.
    is_array: bool,
    /// Whether the base type is `serial`, `bigserial` or `smallserial`.
    is_serial: bool,
}
275
276fn parse_type_info(input: &str) -> IResult<&str, TypeInfo> {
279 let (input, type_name) = take_while1(|c: char| c.is_alphanumeric()).parse(input)?;
280
281 let (input, params) = if input.starts_with('(') {
282 let paren_start = 1;
283 let mut paren_end = paren_start;
284 for (i, c) in input[paren_start..].char_indices() {
285 if c == ')' {
286 paren_end = paren_start + i;
287 break;
288 }
289 }
290 let param_str = &input[paren_start..paren_end];
291 let params: Vec<String> = param_str.split(',').map(|s| s.trim().to_string()).collect();
292 (&input[paren_end + 1..], Some(params))
293 } else {
294 (input, None)
295 };
296
297 let (input, is_array) = if let Some(stripped) = input.strip_prefix("[]") {
298 (stripped, true)
299 } else {
300 (input, false)
301 };
302
303 let lower = type_name.to_lowercase();
304 let is_serial = lower == "serial" || lower == "bigserial" || lower == "smallserial";
305
306 Ok((
307 input,
308 TypeInfo {
309 name: lower,
310 params,
311 is_array,
312 is_serial,
313 },
314 ))
315}
316
317fn constraint_text(input: &str) -> IResult<&str, &str> {
319 let mut paren_depth = 0;
320 let mut end = 0;
321
322 for (i, c) in input.char_indices() {
323 match c {
324 '(' => paren_depth += 1,
325 ')' => {
326 if paren_depth == 0 {
327 break; }
329 paren_depth -= 1;
330 }
331 ',' if paren_depth == 0 => break,
332 '\n' | '\r' if paren_depth == 0 => break,
333 _ => {}
334 }
335 end = i + c.len_utf8();
336 }
337
338 if end == 0 {
339 Err(nom::Err::Error(nom::error::Error::new(
340 input,
341 nom::error::ErrorKind::TakeWhile1,
342 )))
343 } else {
344 Ok((&input[end..], &input[..end]))
345 }
346}
347
348fn parse_column(input: &str) -> IResult<&str, ColumnDef> {
350 let (input, _) = ws_and_comments(input)?;
351 let (input, name) = identifier(input)?;
352 let (input, _) = multispace1(input)?;
353 let (input, type_info) = parse_type_info(input)?;
354
355 let (input, constraint_str) = opt(preceded(multispace1, constraint_text)).parse(input)?;
356
357 let mut col = ColumnDef {
358 name: name.to_string(),
359 typ: type_info.name,
360 is_array: type_info.is_array,
361 type_params: type_info.params,
362 is_serial: type_info.is_serial,
363 nullable: !type_info.is_serial, ..Default::default()
365 };
366
367 if let Some(constraints) = constraint_str {
368 let lower = constraints.to_lowercase();
369
370 if lower.contains("primary_key") || lower.contains("primary key") {
371 col.primary_key = true;
372 col.nullable = false;
373 }
374 if lower.contains("not_null") || lower.contains("not null") {
375 col.nullable = false;
376 }
377 if lower.contains("unique") {
378 col.unique = true;
379 }
380
381 if let Some(idx) = lower.find("references ") {
382 let rest = &constraints[idx + 11..];
383 let mut paren_depth = 0;
385 let mut end = rest.len();
386 for (i, c) in rest.char_indices() {
387 match c {
388 '(' => paren_depth += 1,
389 ')' => {
390 if paren_depth == 0 {
391 end = i;
392 break;
393 }
394 paren_depth -= 1;
395 }
396 c if c.is_whitespace() && paren_depth == 0 => {
397 end = i;
398 break;
399 }
400 _ => {}
401 }
402 }
403 col.references = Some(rest[..end].to_string());
404 }
405
406 if let Some(idx) = lower.find("default ") {
407 let rest = &constraints[idx + 8..];
408 let end = rest.find(|c: char| c.is_whitespace()).unwrap_or(rest.len());
409 col.default_value = Some(rest[..end].to_string());
410 }
411
412 if let Some(idx) = lower.find("check(") {
413 let rest = &constraints[idx + 6..];
414 let mut depth = 1;
416 let mut end = rest.len();
417 for (i, c) in rest.char_indices() {
418 match c {
419 '(' => depth += 1,
420 ')' => {
421 depth -= 1;
422 if depth == 0 {
423 end = i;
424 break;
425 }
426 }
427 _ => {}
428 }
429 }
430 col.check = Some(rest[..end].to_string());
431 }
432 }
433
434 Ok((input, col))
435}
436
437fn parse_column_list(input: &str) -> IResult<&str, Vec<ColumnDef>> {
439 let (input, _) = ws_and_comments(input)?;
440 let (input, _) = char('(').parse(input)?;
441 let (input, columns) = separated_list0(char(','), parse_column).parse(input)?;
442 let (input, _) = ws_and_comments(input)?;
443 let (input, _) = char(')').parse(input)?;
444
445 Ok((input, columns))
446}
447
448fn parse_table(input: &str) -> IResult<&str, TableDef> {
450 let (input, _) = ws_and_comments(input)?;
451 let (input, _) = tag_no_case("table").parse(input)?;
452 let (input, _) = multispace1(input)?;
453 let (input, name) = identifier(input)?;
454 let (input, columns) = parse_column_list(input)?;
455
456 let (input, _) = ws_and_comments(input)?;
458 let enable_rls = if let Ok((rest, _)) =
459 tag_no_case::<_, _, nom::error::Error<&str>>("enable_rls").parse(input)
460 {
461 return Ok((
462 rest,
463 TableDef {
464 name: name.to_string(),
465 columns,
466 enable_rls: true,
467 },
468 ));
469 } else {
470 false
471 };
472
473 Ok((
474 input,
475 TableDef {
476 name: name.to_string(),
477 columns,
478 enable_rls,
479 },
480 ))
481}
482
/// One top-level declaration recognised by [`parse_schema_item`].
enum SchemaItem {
    /// A `table ...` declaration.
    Table(TableDef),
    /// A `policy ... on ...` declaration.
    Policy(RlsPolicy),
    /// An `index ... on ...` declaration.
    Index(IndexDef),
}
493
494fn parse_policy(input: &str) -> IResult<&str, RlsPolicy> {
506 let (input, _) = ws_and_comments(input)?;
507 let (input, _) = tag_no_case("policy").parse(input)?;
508 let (input, _) = multispace1(input)?;
509 let (input, name) = identifier(input)?;
510 let (input, _) = multispace1(input)?;
511 let (input, _) = tag_no_case("on").parse(input)?;
512 let (input, _) = multispace1(input)?;
513 let (input, table) = identifier(input)?;
514
515 let mut policy = RlsPolicy::create(name, table);
516
517 let mut remaining = input;
519 loop {
520 let (input, _) = ws_and_comments(remaining)?;
521
522 if let Ok((rest, _)) =
524 tag_no_case::<_, _, nom::error::Error<&str>>("for").parse(input)
525 {
526 let (rest, _) = multispace1(rest)?;
527 let (rest, target) = alt((
528 map(tag_no_case("all"), |_| PolicyTarget::All),
529 map(tag_no_case("select"), |_| PolicyTarget::Select),
530 map(tag_no_case("insert"), |_| PolicyTarget::Insert),
531 map(tag_no_case("update"), |_| PolicyTarget::Update),
532 map(tag_no_case("delete"), |_| PolicyTarget::Delete),
533 ))
534 .parse(rest)?;
535 policy.target = target;
536 remaining = rest;
537 continue;
538 }
539
540 if let Ok((rest, _)) =
542 tag_no_case::<_, _, nom::error::Error<&str>>("restrictive").parse(input)
543 {
544 policy.permissiveness = PolicyPermissiveness::Restrictive;
545 remaining = rest;
546 continue;
547 }
548
549 if let Ok((rest, _)) =
551 tag_no_case::<_, _, nom::error::Error<&str>>("to").parse(input)
552 {
553 if let Ok((rest, _)) = multispace1::<_, nom::error::Error<&str>>(rest) {
555 let (rest, role) = identifier(rest)?;
556 policy.role = Some(role.to_string());
557 remaining = rest;
558 continue;
559 }
560 }
561
562 if let Ok((rest, _)) =
564 tag_no_case::<_, _, nom::error::Error<&str>>("with").parse(input)
565 {
566 let (rest, _) = multispace1(rest)?;
567 if let Ok((rest, _)) =
568 tag_no_case::<_, _, nom::error::Error<&str>>("check").parse(rest)
569 {
570 let (rest, _) = nom_ws0(rest)?;
571 let (rest, _) = char('(').parse(rest)?;
572 let (rest, _) = nom_ws0(rest)?;
573 let (rest, expr) = parse_policy_expr(rest)?;
574 let (rest, _) = nom_ws0(rest)?;
575 let (rest, _) = char(')').parse(rest)?;
576 policy.with_check = Some(expr);
577 remaining = rest;
578 continue;
579 }
580 }
581
582 if let Ok((rest, _)) =
584 tag_no_case::<_, _, nom::error::Error<&str>>("using").parse(input)
585 {
586 let (rest, _) = nom_ws0(rest)?;
587 let (rest, _) = char('(').parse(rest)?;
588 let (rest, _) = nom_ws0(rest)?;
589 let (rest, expr) = parse_policy_expr(rest)?;
590 let (rest, _) = nom_ws0(rest)?;
591 let (rest, _) = char(')').parse(rest)?;
592 policy.using = Some(expr);
593 remaining = rest;
594 continue;
595 }
596
597 remaining = input;
599 break;
600 }
601
602 Ok((remaining, policy))
603}
604
605fn parse_policy_expr(input: &str) -> IResult<&str, Expr> {
614 let (input, first) = parse_policy_comparison(input)?;
615
616 let mut result = first;
618 let mut remaining = input;
619 loop {
620 let (input, _) = nom_ws0(remaining)?;
621
622 if let Ok((rest, _)) =
623 tag_no_case::<_, _, nom::error::Error<&str>>("or").parse(input)
624 {
625 if let Ok((rest, _)) = multispace1::<_, nom::error::Error<&str>>(rest) {
626 let (rest, right) = parse_policy_comparison(rest)?;
627 result = Expr::Binary {
628 left: Box::new(result),
629 op: BinaryOp::Or,
630 right: Box::new(right),
631 alias: None,
632 };
633 remaining = rest;
634 continue;
635 }
636 }
637
638 if let Ok((rest, _)) =
639 tag_no_case::<_, _, nom::error::Error<&str>>("and").parse(input)
640 {
641 if let Ok((rest, _)) = multispace1::<_, nom::error::Error<&str>>(rest) {
642 let (rest, right) = parse_policy_comparison(rest)?;
643 result = Expr::Binary {
644 left: Box::new(result),
645 op: BinaryOp::And,
646 right: Box::new(right),
647 alias: None,
648 };
649 remaining = rest;
650 continue;
651 }
652 }
653
654 remaining = input;
655 break;
656 }
657
658 Ok((remaining, result))
659}
660
661fn parse_policy_comparison(input: &str) -> IResult<&str, Expr> {
663 let (input, left) = parse_policy_atom(input)?;
664 let (input, _) = nom_ws0(input)?;
665
666 if let Ok((rest, op)) = parse_cmp_op(input) {
668 let (rest, _) = nom_ws0(rest)?;
669 let (rest, right) = parse_policy_atom(rest)?;
670 return Ok((
671 rest,
672 Expr::Binary {
673 left: Box::new(left),
674 op,
675 right: Box::new(right),
676 alias: None,
677 },
678 ));
679 }
680
681 Ok((input, left))
683}
684
685fn parse_cmp_op(input: &str) -> IResult<&str, BinaryOp> {
687 alt((
688 map(tag(">="), |_| BinaryOp::Gte),
689 map(tag("<="), |_| BinaryOp::Lte),
690 map(tag("<>"), |_| BinaryOp::Ne),
691 map(tag("!="), |_| BinaryOp::Ne),
692 map(tag("="), |_| BinaryOp::Eq),
693 map(tag(">"), |_| BinaryOp::Gt),
694 map(tag("<"), |_| BinaryOp::Lt),
695 ))
696 .parse(input)
697}
698
699fn parse_policy_atom(input: &str) -> IResult<&str, Expr> {
707 alt((
708 parse_policy_grouped,
709 parse_policy_bool,
710 parse_policy_string,
711 parse_policy_number,
712 parse_policy_func_or_ident, ))
714 .parse(input)
715}
716
717fn parse_policy_grouped(input: &str) -> IResult<&str, Expr> {
719 let (input, _) = char('(').parse(input)?;
720 let (input, _) = nom_ws0(input)?;
721 let (input, expr) = parse_policy_expr(input)?;
722 let (input, _) = nom_ws0(input)?;
723 let (input, _) = char(')').parse(input)?;
724 Ok((input, expr))
725}
726
727fn parse_policy_bool(input: &str) -> IResult<&str, Expr> {
729 alt((
730 map(tag_no_case("true"), |_| {
731 Expr::Literal(AstValue::Bool(true))
732 }),
733 map(tag_no_case("false"), |_| {
734 Expr::Literal(AstValue::Bool(false))
735 }),
736 ))
737 .parse(input)
738}
739
740fn parse_policy_string(input: &str) -> IResult<&str, Expr> {
742 let (input, _) = char('\'').parse(input)?;
743 let mut end = 0;
744 for (i, c) in input.char_indices() {
745 if c == '\'' {
746 end = i;
747 break;
748 }
749 }
750 let content = &input[..end];
751 let rest = &input[end + 1..];
752 Ok((rest, Expr::Literal(AstValue::String(content.to_string()))))
753}
754
755fn parse_policy_number(input: &str) -> IResult<&str, Expr> {
757 let (input, digits) = take_while1(|c: char| c.is_ascii_digit() || c == '.')(input)?;
758 if digits.starts_with('.') || digits.is_empty() {
760 return Err(nom::Err::Error(nom::error::Error::new(
761 input,
762 nom::error::ErrorKind::Digit,
763 )));
764 }
765 if let Ok(n) = digits.parse::<i64>() {
766 Ok((input, Expr::Literal(AstValue::Int(n))))
767 } else if let Ok(f) = digits.parse::<f64>() {
768 Ok((input, Expr::Literal(AstValue::Float(f))))
769 } else {
770 Ok((input, Expr::Named(digits.to_string())))
771 }
772}
773
774fn parse_policy_func_or_ident(input: &str) -> IResult<&str, Expr> {
776 let (input, name) = identifier(input)?;
777
778 let mut expr = if let Ok((rest, _)) = char::<_, nom::error::Error<&str>>('(').parse(input) {
780 let (rest, _) = nom_ws0(rest)?;
782 let (rest, args) = separated_list0(
783 (nom_ws0, char(','), nom_ws0),
784 parse_policy_atom,
785 )
786 .parse(rest)?;
787 let (rest, _) = nom_ws0(rest)?;
788 let (rest, _) = char(')').parse(rest)?;
789 let input = rest;
790 (input, Expr::FunctionCall {
791 name: name.to_string(),
792 args,
793 alias: None,
794 })
795 } else {
796 (input, Expr::Named(name.to_string()))
797 };
798
799 if let Ok((rest, _)) = tag::<_, _, nom::error::Error<&str>>("::").parse(expr.0) {
801 let (rest, cast_type) = identifier(rest)?;
802 expr = (
803 rest,
804 Expr::Cast {
805 expr: Box::new(expr.1),
806 target_type: cast_type.to_string(),
807 alias: None,
808 },
809 );
810 }
811
812 Ok(expr)
813}
814
815fn parse_schema_item(input: &str) -> IResult<&str, SchemaItem> {
817 let (input, _) = ws_and_comments(input)?;
818
819 if let Ok((rest, policy)) = parse_policy(input) {
821 return Ok((rest, SchemaItem::Policy(policy)));
822 }
823
824 if let Ok((rest, idx)) = parse_index(input) {
826 return Ok((rest, SchemaItem::Index(idx)));
827 }
828
829 let (rest, table) = parse_table(input)?;
831 Ok((rest, SchemaItem::Table(table)))
832}
833
834fn parse_index(input: &str) -> IResult<&str, IndexDef> {
836 let (input, _) = tag_no_case("index")(input)?;
837 let (input, _) = multispace1(input)?;
838 let (input, name) = take_while1(|c: char| c.is_alphanumeric() || c == '_')(input)?;
839 let (input, _) = multispace1(input)?;
840 let (input, _) = tag_no_case("on")(input)?;
841 let (input, _) = multispace1(input)?;
842 let (input, table) = take_while1(|c: char| c.is_alphanumeric() || c == '_')(input)?;
843 let (input, _) = nom_ws0(input)?;
844 let (input, _) = char('(')(input)?;
845 let (input, cols_str) = take_while1(|c: char| c != ')')(input)?;
846 let (input, _) = char(')')(input)?;
847 let (input, _) = nom_ws0(input)?;
848 let (input, unique_tag) = opt(tag_no_case("unique")).parse(input)?;
849
850 let columns: Vec<String> = cols_str
851 .split(',')
852 .map(|s| s.trim().to_string())
853 .filter(|s| !s.is_empty())
854 .collect();
855
856 let is_unique = unique_tag.is_some();
857
858 Ok((input, IndexDef {
859 name: name.to_string(),
860 table: table.to_string(),
861 columns,
862 unique: is_unique,
863 }))
864}
865
866fn parse_schema(input: &str) -> IResult<&str, Schema> {
868 let version = extract_version_directive(input);
870
871 let (input, items) = many0(parse_schema_item).parse(input)?;
872 let (input, _) = ws_and_comments(input)?;
873
874 let mut tables = Vec::new();
875 let mut policies = Vec::new();
876 let mut indexes = Vec::new();
877 for item in items {
878 match item {
879 SchemaItem::Table(t) => tables.push(t),
880 SchemaItem::Policy(p) => policies.push(p),
881 SchemaItem::Index(i) => indexes.push(i),
882 }
883 }
884
885 Ok((input, Schema { version, tables, policies, indexes }))
886}
887
/// Finds the first `-- qail: version=N` directive line and returns `N`.
///
/// Only the first line carrying a `version=` directive is considered. If its value does not
/// parse as a `u32`, the result is `None` even when a later directive is valid. `-- qail:`
/// lines without `version=` are skipped.
fn extract_version_directive(input: &str) -> Option<u32> {
    input
        .lines()
        .find_map(|line| {
            line.trim()
                .strip_prefix("-- qail:")?
                .trim()
                .strip_prefix("version=")
                .map(|value| value.trim().parse::<u32>().ok())
        })
        .flatten()
}
901
#[cfg(test)]
mod tests {
    //! Parser and SQL-generation tests for the schema DSL.
    use super::*;

    /// A single table: `primary_key` implies NOT NULL, `not null` is honoured, and plain
    /// columns stay nullable.
    #[test]
    fn test_parse_simple_table() {
        let input = r#"
            table users (
                id uuid primary_key,
                email text not null,
                name text
            )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.tables.len(), 1);

        let users = &schema.tables[0];
        assert_eq!(users.name, "users");
        assert_eq!(users.columns.len(), 3);

        let id = &users.columns[0];
        assert_eq!(id.name, "id");
        assert_eq!(id.typ, "uuid");
        assert!(id.primary_key);
        assert!(!id.nullable);

        let email = &users.columns[1];
        assert_eq!(email.name, "email");
        assert!(!email.nullable);

        let name = &users.columns[2];
        assert!(name.nullable);
    }

    /// Several tables separated by comments; `references` and `default` values are captured.
    #[test]
    fn test_parse_multiple_tables() {
        let input = r#"
            -- Users table
            table users (
                id uuid primary_key,
                email text not null unique
            )

            -- Orders table
            table orders (
                id uuid primary_key,
                user_id uuid references users(id),
                total i64 not null default 0
            )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.tables.len(), 2);

        let orders = schema.find_table("orders").expect("orders not found");
        let user_id = orders.find_column("user_id").expect("user_id not found");
        assert_eq!(user_id.references, Some("users(id)".to_string()));

        let total = orders.find_column("total").expect("total not found");
        assert_eq!(total.default_value, Some("0".to_string()));
    }

    /// A leading `--` comment is skipped.
    #[test]
    fn test_parse_comments() {
        let input = r#"
            -- This is a comment
            table foo (
                bar text
            )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.tables.len(), 1);
    }

    /// `type[]` sets `is_array` and keeps the base type name.
    #[test]
    fn test_array_types() {
        let input = r#"
            table products (
                id uuid primary_key,
                tags text[],
                prices decimal[]
            )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let products = &schema.tables[0];

        let tags = products.find_column("tags").expect("tags not found");
        assert_eq!(tags.typ, "text");
        assert!(tags.is_array);

        let prices = products.find_column("prices").expect("prices not found");
        assert!(prices.is_array);
    }

    /// Type parameters are split on commas; serial columns are non-nullable.
    #[test]
    fn test_type_params() {
        let input = r#"
            table items (
                id serial primary_key,
                name varchar(255) not null,
                price decimal(10,2),
                code varchar(50) unique
            )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let items = &schema.tables[0];

        let id = items.find_column("id").expect("id not found");
        assert!(id.is_serial);
        assert!(!id.nullable);
        let name = items.find_column("name").expect("name not found");
        assert_eq!(name.typ, "varchar");
        assert_eq!(name.type_params, Some(vec!["255".to_string()]));

        let price = items.find_column("price").expect("price not found");
        assert_eq!(
            price.type_params,
            Some(vec!["10".to_string(), "2".to_string()])
        );

        let code = items.find_column("code").expect("code not found");
        assert!(code.unique);
    }

    /// `check(...)` stores the body without the surrounding parentheses.
    #[test]
    fn test_check_constraint() {
        let input = r#"
            table employees (
                id uuid primary_key,
                age i32 check(age >= 18),
                salary decimal check(salary > 0)
            )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let employees = &schema.tables[0];

        let age = employees.find_column("age").expect("age not found");
        assert_eq!(age.check, Some("age >= 18".to_string()));

        let salary = employees.find_column("salary").expect("salary not found");
        assert_eq!(salary.check, Some("salary > 0".to_string()));
    }

    /// The `-- qail: version=N` directive populates `version`; without it `version` is `None`.
    #[test]
    fn test_version_directive() {
        let input = r#"
            -- qail: version=1
            table users (
                id uuid primary_key
            )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.version, Some(1));
        assert_eq!(schema.tables.len(), 1);

        let input_no_version = r#"
            table items (
                id uuid primary_key
            )
        "#;
        let schema2 = Schema::parse(input_no_version).expect("parse failed");
        assert_eq!(schema2.version, None);
    }

    /// A trailing `enable_rls` marker sets `TableDef::enable_rls`.
    #[test]
    fn test_enable_rls_table() {
        let input = r#"
            table orders (
                id uuid primary_key,
                operator_id uuid not null
            ) enable_rls
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.tables.len(), 1);
        assert!(schema.tables[0].enable_rls);
    }

    /// A `for all ... using (...)` policy produces `column = function(...)::type`.
    #[test]
    fn test_parse_policy_basic() {
        let input = r#"
            table orders (
                id uuid primary_key,
                operator_id uuid not null
            ) enable_rls

            policy orders_isolation on orders
                for all
                using (operator_id = current_setting('app.current_operator_id')::uuid)
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.tables.len(), 1);
        assert_eq!(schema.policies.len(), 1);

        let policy = &schema.policies[0];
        assert_eq!(policy.name, "orders_isolation");
        assert_eq!(policy.table, "orders");
        assert_eq!(policy.target, PolicyTarget::All);
        assert!(policy.using.is_some());

        // Expected shape: Binary(Named, Eq, Cast(FunctionCall, "uuid")).
        match policy.using.as_ref().unwrap() {
            Expr::Binary { left, op, right, .. } => {
                assert_eq!(*op, BinaryOp::Eq);
                match left.as_ref() {
                    Expr::Named(n) => assert_eq!(n, "operator_id"),
                    _ => panic!("Expected Named, got {:?}", left),
                }
                match right.as_ref() {
                    Expr::Cast { target_type, expr, .. } => {
                        assert_eq!(target_type, "uuid");
                        match expr.as_ref() {
                            Expr::FunctionCall { name, args, .. } => {
                                assert_eq!(name, "current_setting");
                                assert_eq!(args.len(), 1);
                            }
                            _ => panic!("Expected FunctionCall"),
                        }
                    }
                    _ => panic!("Expected Cast, got {:?}", right),
                }
            }
            _ => panic!("Expected Binary"),
        }
    }

    /// `with check (...)` populates `with_check` and leaves `using` empty.
    #[test]
    fn test_parse_policy_with_check() {
        let input = r#"
            table orders (
                id uuid primary_key
            )

            policy orders_write on orders
                for insert
                with check (operator_id = current_setting('app.current_operator_id')::uuid)
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let policy = &schema.policies[0];
        assert_eq!(policy.target, PolicyTarget::Insert);
        assert!(policy.with_check.is_some());
        assert!(policy.using.is_none());
    }

    /// `restrictive` and `to <role>` clauses are recorded on the policy.
    #[test]
    fn test_parse_policy_restrictive_with_role() {
        let input = r#"
            table secrets (
                id uuid primary_key
            )

            policy admin_only on secrets
                for select
                restrictive
                to app_user
                using (current_setting('app.is_super_admin')::boolean = true)
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let policy = &schema.policies[0];
        assert_eq!(policy.target, PolicyTarget::Select);
        assert_eq!(policy.permissiveness, PolicyPermissiveness::Restrictive);
        assert_eq!(policy.role.as_deref(), Some("app_user"));
        assert!(policy.using.is_some());
    }

    /// Two comparisons joined by `or` produce a top-level `BinaryOp::Or`.
    #[test]
    fn test_parse_policy_or_expr() {
        let input = r#"
            table orders (
                id uuid primary_key
            )

            policy tenant_or_admin on orders
                for all
                using (operator_id = current_setting('app.current_operator_id')::uuid or current_setting('app.is_super_admin')::boolean = true)
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let policy = &schema.policies[0];

        match policy.using.as_ref().unwrap() {
            Expr::Binary { op: BinaryOp::Or, .. } => {}
            e => panic!("Expected Binary OR, got {:?}", e),
        }
    }

    /// `Schema::to_sql` emits table DDL, RLS statements and the policy.
    #[test]
    fn test_schema_to_sql() {
        let input = r#"
            table orders (
                id uuid primary_key,
                operator_id uuid not null
            ) enable_rls

            policy orders_isolation on orders
                for all
                using (operator_id = current_setting('app.current_operator_id')::uuid)
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let sql = schema.to_sql();
        assert!(sql.contains("CREATE TABLE IF NOT EXISTS"));
        assert!(sql.contains("ENABLE ROW LEVEL SECURITY"));
        assert!(sql.contains("FORCE ROW LEVEL SECURITY"));
        assert!(sql.contains("CREATE POLICY"));
        assert!(sql.contains("orders_isolation"));
        assert!(sql.contains("FOR ALL"));
    }

    /// Multiple policies on one table are kept in declaration order.
    #[test]
    fn test_multiple_policies() {
        let input = r#"
            table orders (
                id uuid primary_key,
                operator_id uuid not null
            ) enable_rls

            policy orders_read on orders
                for select
                using (operator_id = current_setting('app.current_operator_id')::uuid)

            policy orders_write on orders
                for insert
                with check (operator_id = current_setting('app.current_operator_id')::uuid)
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.policies.len(), 2);
        assert_eq!(schema.policies[0].name, "orders_read");
        assert_eq!(schema.policies[0].target, PolicyTarget::Select);
        assert_eq!(schema.policies[1].name, "orders_write");
        assert_eq!(schema.policies[1].target, PolicyTarget::Insert);
    }
}
1250}