1use nom::{
18 IResult, Parser,
19 branch::alt,
20 bytes::complete::{tag, tag_no_case, take_while1},
21 character::complete::{char, multispace0 as nom_ws0, multispace1, not_line_ending},
22 combinator::{map, opt},
23 multi::{many0, separated_list0},
24 sequence::preceded,
25};
26use serde::{Deserialize, Serialize};
27
28use crate::ast::{Expr, BinaryOp, Value as AstValue};
29use crate::migrate::policy::{RlsPolicy, PolicyTarget, PolicyPermissiveness};
30use crate::transpiler::policy::{create_policy_sql, alter_table_sql};
31use crate::migrate::alter::AlterTable;
32
/// A complete schema: table definitions plus row-level-security policies.
///
/// Can be parsed from the QAIL DSL ([`Schema::parse`]) or loaded from JSON
/// ([`Schema::from_json`]), and rendered to SQL with [`Schema::to_sql`].
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Schema {
    /// Schema version taken from a `-- qail: version=N` directive, if present.
    #[serde(default)]
    pub version: Option<u32>,
    /// Tables in declaration order.
    pub tables: Vec<TableDef>,
    /// RLS policies; `to_sql` emits them after all table statements.
    #[serde(default)]
    pub policies: Vec<RlsPolicy>,
}
44
/// A single table definition.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TableDef {
    /// Table name, interpolated verbatim into generated SQL.
    pub name: String,
    /// Column definitions in declaration order.
    pub columns: Vec<ColumnDef>,
    /// When true, `Schema::to_sql` also emits statements that enable and force
    /// row-level security on this table. Set by a trailing `enable_rls` in the DSL.
    #[serde(default)]
    pub enable_rls: bool,
}
53
54#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct ColumnDef {
56 pub name: String,
57 #[serde(rename = "type", alias = "typ")]
58 pub typ: String,
59 #[serde(default)]
61 pub is_array: bool,
62 #[serde(default)]
64 pub type_params: Option<Vec<String>>,
65 #[serde(default)]
66 pub nullable: bool,
67 #[serde(default)]
68 pub primary_key: bool,
69 #[serde(default)]
70 pub unique: bool,
71 #[serde(default)]
72 pub references: Option<String>,
73 #[serde(default)]
74 pub default_value: Option<String>,
75 #[serde(default)]
77 pub check: Option<String>,
78 #[serde(default)]
80 pub is_serial: bool,
81}
82
83impl Default for ColumnDef {
84 fn default() -> Self {
85 Self {
86 name: String::new(),
87 typ: String::new(),
88 is_array: false,
89 type_params: None,
90 nullable: true,
91 primary_key: false,
92 unique: false,
93 references: None,
94 default_value: None,
95 check: None,
96 is_serial: false,
97 }
98 }
99}
100
101impl Schema {
102 pub fn parse(input: &str) -> Result<Self, String> {
104 match parse_schema(input) {
105 Ok(("", schema)) => Ok(schema),
106 Ok((remaining, _)) => Err(format!("Unexpected content: '{}'", remaining.trim())),
107 Err(e) => Err(format!("Parse error: {:?}", e)),
108 }
109 }
110
111 pub fn find_table(&self, name: &str) -> Option<&TableDef> {
113 self.tables
114 .iter()
115 .find(|t| t.name.eq_ignore_ascii_case(name))
116 }
117
118 pub fn to_sql(&self) -> String {
120 let mut parts = Vec::new();
121
122 for table in &self.tables {
123 parts.push(table.to_ddl());
124
125 if table.enable_rls {
126 let alter = AlterTable::new(&table.name).enable_rls().force_rls();
127 for stmt in alter_table_sql(&alter) {
128 parts.push(stmt);
129 }
130 }
131 }
132
133 for policy in &self.policies {
134 parts.push(create_policy_sql(policy));
135 }
136
137 parts.join(";\n\n") + ";"
138 }
139
140 pub fn to_json(&self) -> Result<String, String> {
142 serde_json::to_string_pretty(self).map_err(|e| format!("JSON serialization failed: {}", e))
143 }
144
145 pub fn from_json(json: &str) -> Result<Self, String> {
147 serde_json::from_str(json).map_err(|e| format!("JSON deserialization failed: {}", e))
148 }
149
150 pub fn from_file(path: &std::path::Path) -> Result<Self, String> {
152 let content =
153 std::fs::read_to_string(path).map_err(|e| format!("Failed to read file: {}", e))?;
154
155 if content.trim().starts_with('{') {
156 Self::from_json(&content)
157 } else {
158 Self::parse(&content)
159 }
160 }
161}
162
163impl TableDef {
164 pub fn find_column(&self, name: &str) -> Option<&ColumnDef> {
166 self.columns
167 .iter()
168 .find(|c| c.name.eq_ignore_ascii_case(name))
169 }
170
171 pub fn to_ddl(&self) -> String {
173 let mut sql = format!("CREATE TABLE IF NOT EXISTS {} (\n", self.name);
174
175 let mut col_defs = Vec::new();
176 for col in &self.columns {
177 let mut line = format!(" {}", col.name);
178
179 let mut typ = col.typ.to_uppercase();
181 if let Some(params) = &col.type_params {
182 typ = format!("{}({})", typ, params.join(", "));
183 }
184 if col.is_array {
185 typ.push_str("[]");
186 }
187 line.push_str(&format!(" {}", typ));
188
189 if col.primary_key {
191 line.push_str(" PRIMARY KEY");
192 }
193 if !col.nullable && !col.primary_key && !col.is_serial {
194 line.push_str(" NOT NULL");
195 }
196 if col.unique && !col.primary_key {
197 line.push_str(" UNIQUE");
198 }
199 if let Some(ref default) = col.default_value {
200 line.push_str(&format!(" DEFAULT {}", default));
201 }
202 if let Some(ref refs) = col.references {
203 line.push_str(&format!(" REFERENCES {}", refs));
204 }
205 if let Some(ref check) = col.check {
206 line.push_str(&format!(" CHECK({})", check));
207 }
208
209 col_defs.push(line);
210 }
211
212 sql.push_str(&col_defs.join(",\n"));
213 sql.push_str("\n)");
214 sql
215 }
216}
217
218fn identifier(input: &str) -> IResult<&str, &str> {
224 take_while1(|c: char| c.is_alphanumeric() || c == '_')(input)
225}
226
227fn ws_and_comments(input: &str) -> IResult<&str, ()> {
229 let (input, _) = many0(alt((
230 map(multispace1, |_| ()),
231 map((tag("--"), not_line_ending), |_| ()),
232 )))
233 .parse(input)?;
234 Ok((input, ()))
235}
236
/// Intermediate result of `parse_type_info`.
struct TypeInfo {
    /// Lowercased type name, e.g. `varchar`.
    name: String,
    /// Parameters inside `(...)`, split on commas and trimmed; `None` when absent.
    params: Option<Vec<String>>,
    /// True when the type had a `[]` suffix.
    is_array: bool,
    /// True for `serial`, `bigserial` and `smallserial`.
    is_serial: bool,
}
243
244fn parse_type_info(input: &str) -> IResult<&str, TypeInfo> {
247 let (input, type_name) = take_while1(|c: char| c.is_alphanumeric()).parse(input)?;
248
249 let (input, params) = if input.starts_with('(') {
250 let paren_start = 1;
251 let mut paren_end = paren_start;
252 for (i, c) in input[paren_start..].char_indices() {
253 if c == ')' {
254 paren_end = paren_start + i;
255 break;
256 }
257 }
258 let param_str = &input[paren_start..paren_end];
259 let params: Vec<String> = param_str.split(',').map(|s| s.trim().to_string()).collect();
260 (&input[paren_end + 1..], Some(params))
261 } else {
262 (input, None)
263 };
264
265 let (input, is_array) = if let Some(stripped) = input.strip_prefix("[]") {
266 (stripped, true)
267 } else {
268 (input, false)
269 };
270
271 let lower = type_name.to_lowercase();
272 let is_serial = lower == "serial" || lower == "bigserial" || lower == "smallserial";
273
274 Ok((
275 input,
276 TypeInfo {
277 name: lower,
278 params,
279 is_array,
280 is_serial,
281 },
282 ))
283}
284
285fn constraint_text(input: &str) -> IResult<&str, &str> {
287 let mut paren_depth = 0;
288 let mut end = 0;
289
290 for (i, c) in input.char_indices() {
291 match c {
292 '(' => paren_depth += 1,
293 ')' => {
294 if paren_depth == 0 {
295 break; }
297 paren_depth -= 1;
298 }
299 ',' if paren_depth == 0 => break,
300 '\n' | '\r' if paren_depth == 0 => break,
301 _ => {}
302 }
303 end = i + c.len_utf8();
304 }
305
306 if end == 0 {
307 Err(nom::Err::Error(nom::error::Error::new(
308 input,
309 nom::error::ErrorKind::TakeWhile1,
310 )))
311 } else {
312 Ok((&input[end..], &input[..end]))
313 }
314}
315
316fn parse_column(input: &str) -> IResult<&str, ColumnDef> {
318 let (input, _) = ws_and_comments(input)?;
319 let (input, name) = identifier(input)?;
320 let (input, _) = multispace1(input)?;
321 let (input, type_info) = parse_type_info(input)?;
322
323 let (input, constraint_str) = opt(preceded(multispace1, constraint_text)).parse(input)?;
324
325 let mut col = ColumnDef {
326 name: name.to_string(),
327 typ: type_info.name,
328 is_array: type_info.is_array,
329 type_params: type_info.params,
330 is_serial: type_info.is_serial,
331 nullable: !type_info.is_serial, ..Default::default()
333 };
334
335 if let Some(constraints) = constraint_str {
336 let lower = constraints.to_lowercase();
337
338 if lower.contains("primary_key") || lower.contains("primary key") {
339 col.primary_key = true;
340 col.nullable = false;
341 }
342 if lower.contains("not_null") || lower.contains("not null") {
343 col.nullable = false;
344 }
345 if lower.contains("unique") {
346 col.unique = true;
347 }
348
349 if let Some(idx) = lower.find("references ") {
350 let rest = &constraints[idx + 11..];
351 let mut paren_depth = 0;
353 let mut end = rest.len();
354 for (i, c) in rest.char_indices() {
355 match c {
356 '(' => paren_depth += 1,
357 ')' => {
358 if paren_depth == 0 {
359 end = i;
360 break;
361 }
362 paren_depth -= 1;
363 }
364 c if c.is_whitespace() && paren_depth == 0 => {
365 end = i;
366 break;
367 }
368 _ => {}
369 }
370 }
371 col.references = Some(rest[..end].to_string());
372 }
373
374 if let Some(idx) = lower.find("default ") {
375 let rest = &constraints[idx + 8..];
376 let end = rest.find(|c: char| c.is_whitespace()).unwrap_or(rest.len());
377 col.default_value = Some(rest[..end].to_string());
378 }
379
380 if let Some(idx) = lower.find("check(") {
381 let rest = &constraints[idx + 6..];
382 let mut depth = 1;
384 let mut end = rest.len();
385 for (i, c) in rest.char_indices() {
386 match c {
387 '(' => depth += 1,
388 ')' => {
389 depth -= 1;
390 if depth == 0 {
391 end = i;
392 break;
393 }
394 }
395 _ => {}
396 }
397 }
398 col.check = Some(rest[..end].to_string());
399 }
400 }
401
402 Ok((input, col))
403}
404
405fn parse_column_list(input: &str) -> IResult<&str, Vec<ColumnDef>> {
407 let (input, _) = ws_and_comments(input)?;
408 let (input, _) = char('(').parse(input)?;
409 let (input, columns) = separated_list0(char(','), parse_column).parse(input)?;
410 let (input, _) = ws_and_comments(input)?;
411 let (input, _) = char(')').parse(input)?;
412
413 Ok((input, columns))
414}
415
416fn parse_table(input: &str) -> IResult<&str, TableDef> {
418 let (input, _) = ws_and_comments(input)?;
419 let (input, _) = tag_no_case("table").parse(input)?;
420 let (input, _) = multispace1(input)?;
421 let (input, name) = identifier(input)?;
422 let (input, columns) = parse_column_list(input)?;
423
424 let (input, _) = ws_and_comments(input)?;
426 let enable_rls = if let Ok((rest, _)) =
427 tag_no_case::<_, _, nom::error::Error<&str>>("enable_rls").parse(input)
428 {
429 return Ok((
430 rest,
431 TableDef {
432 name: name.to_string(),
433 columns,
434 enable_rls: true,
435 },
436 ));
437 } else {
438 false
439 };
440
441 Ok((
442 input,
443 TableDef {
444 name: name.to_string(),
445 columns,
446 enable_rls,
447 },
448 ))
449}
450
/// One top-level item of a schema file.
enum SchemaItem {
    /// A `table ... ( ... )` definition.
    Table(TableDef),
    /// A `policy ... on ...` declaration.
    Policy(RlsPolicy),
}
460
/// Parses a row-level-security policy declaration:
///
/// ```text
/// policy <name> on <table>
///     [for all|select|insert|update|delete]
///     [restrictive]
///     [to <role>]
///     [with check (<expr>)]
///     [using (<expr>)]
/// ```
///
/// The optional clauses may appear in any order. They are applied as they are
/// encountered, so a repeated clause overwrites the earlier value. Parsing stops
/// at the first token that starts none of the clauses.
fn parse_policy(input: &str) -> IResult<&str, RlsPolicy> {
    let (input, _) = ws_and_comments(input)?;
    let (input, _) = tag_no_case("policy").parse(input)?;
    let (input, _) = multispace1(input)?;
    let (input, name) = identifier(input)?;
    let (input, _) = multispace1(input)?;
    let (input, _) = tag_no_case("on").parse(input)?;
    let (input, _) = multispace1(input)?;
    let (input, table) = identifier(input)?;

    let mut policy = RlsPolicy::create(name, table);

    // Position after the last successfully parsed clause.
    let mut remaining = input;
    loop {
        let (input, _) = ws_and_comments(remaining)?;

        // `for <target>`
        if let Ok((rest, _)) =
            tag_no_case::<_, _, nom::error::Error<&str>>("for").parse(input)
        {
            let (rest, _) = multispace1(rest)?;
            let (rest, target) = alt((
                map(tag_no_case("all"), |_| PolicyTarget::All),
                map(tag_no_case("select"), |_| PolicyTarget::Select),
                map(tag_no_case("insert"), |_| PolicyTarget::Insert),
                map(tag_no_case("update"), |_| PolicyTarget::Update),
                map(tag_no_case("delete"), |_| PolicyTarget::Delete),
            ))
            .parse(rest)?;
            policy.target = target;
            remaining = rest;
            continue;
        }

        // `restrictive`
        if let Ok((rest, _)) =
            tag_no_case::<_, _, nom::error::Error<&str>>("restrictive").parse(input)
        {
            policy.permissiveness = PolicyPermissiveness::Restrictive;
            remaining = rest;
            continue;
        }

        // `to <role>`; a bare `to` not followed by whitespace falls through.
        if let Ok((rest, _)) =
            tag_no_case::<_, _, nom::error::Error<&str>>("to").parse(input)
        {
            if let Ok((rest, _)) = multispace1::<_, nom::error::Error<&str>>(rest) {
                let (rest, role) = identifier(rest)?;
                policy.role = Some(role.to_string());
                remaining = rest;
                continue;
            }
        }

        // `with check (<expr>)`; `with` must be followed by whitespace, but a
        // `with` not followed by `check` falls through to the next clause test.
        if let Ok((rest, _)) =
            tag_no_case::<_, _, nom::error::Error<&str>>("with").parse(input)
        {
            let (rest, _) = multispace1(rest)?;
            if let Ok((rest, _)) =
                tag_no_case::<_, _, nom::error::Error<&str>>("check").parse(rest)
            {
                let (rest, _) = nom_ws0(rest)?;
                let (rest, _) = char('(').parse(rest)?;
                let (rest, _) = nom_ws0(rest)?;
                let (rest, expr) = parse_policy_expr(rest)?;
                let (rest, _) = nom_ws0(rest)?;
                let (rest, _) = char(')').parse(rest)?;
                policy.with_check = Some(expr);
                remaining = rest;
                continue;
            }
        }

        // `using (<expr>)`
        if let Ok((rest, _)) =
            tag_no_case::<_, _, nom::error::Error<&str>>("using").parse(input)
        {
            let (rest, _) = nom_ws0(rest)?;
            let (rest, _) = char('(').parse(rest)?;
            let (rest, _) = nom_ws0(rest)?;
            let (rest, expr) = parse_policy_expr(rest)?;
            let (rest, _) = nom_ws0(rest)?;
            let (rest, _) = char(')').parse(rest)?;
            policy.using = Some(expr);
            remaining = rest;
            continue;
        }

        // No clause matched: stop, leaving the skipped whitespace/comments consumed.
        remaining = input;
        break;
    }

    Ok((remaining, policy))
}
571
572fn parse_policy_expr(input: &str) -> IResult<&str, Expr> {
581 let (input, first) = parse_policy_comparison(input)?;
582
583 let mut result = first;
585 let mut remaining = input;
586 loop {
587 let (input, _) = nom_ws0(remaining)?;
588
589 if let Ok((rest, _)) =
590 tag_no_case::<_, _, nom::error::Error<&str>>("or").parse(input)
591 {
592 if let Ok((rest, _)) = multispace1::<_, nom::error::Error<&str>>(rest) {
593 let (rest, right) = parse_policy_comparison(rest)?;
594 result = Expr::Binary {
595 left: Box::new(result),
596 op: BinaryOp::Or,
597 right: Box::new(right),
598 alias: None,
599 };
600 remaining = rest;
601 continue;
602 }
603 }
604
605 if let Ok((rest, _)) =
606 tag_no_case::<_, _, nom::error::Error<&str>>("and").parse(input)
607 {
608 if let Ok((rest, _)) = multispace1::<_, nom::error::Error<&str>>(rest) {
609 let (rest, right) = parse_policy_comparison(rest)?;
610 result = Expr::Binary {
611 left: Box::new(result),
612 op: BinaryOp::And,
613 right: Box::new(right),
614 alias: None,
615 };
616 remaining = rest;
617 continue;
618 }
619 }
620
621 remaining = input;
622 break;
623 }
624
625 Ok((remaining, result))
626}
627
628fn parse_policy_comparison(input: &str) -> IResult<&str, Expr> {
630 let (input, left) = parse_policy_atom(input)?;
631 let (input, _) = nom_ws0(input)?;
632
633 if let Ok((rest, op)) = parse_cmp_op(input) {
635 let (rest, _) = nom_ws0(rest)?;
636 let (rest, right) = parse_policy_atom(rest)?;
637 return Ok((
638 rest,
639 Expr::Binary {
640 left: Box::new(left),
641 op,
642 right: Box::new(right),
643 alias: None,
644 },
645 ));
646 }
647
648 Ok((input, left))
650}
651
652fn parse_cmp_op(input: &str) -> IResult<&str, BinaryOp> {
654 alt((
655 map(tag(">="), |_| BinaryOp::Gte),
656 map(tag("<="), |_| BinaryOp::Lte),
657 map(tag("<>"), |_| BinaryOp::Ne),
658 map(tag("!="), |_| BinaryOp::Ne),
659 map(tag("="), |_| BinaryOp::Eq),
660 map(tag(">"), |_| BinaryOp::Gt),
661 map(tag("<"), |_| BinaryOp::Lt),
662 ))
663 .parse(input)
664}
665
666fn parse_policy_atom(input: &str) -> IResult<&str, Expr> {
674 alt((
675 parse_policy_grouped,
676 parse_policy_bool,
677 parse_policy_string,
678 parse_policy_number,
679 parse_policy_func_or_ident, ))
681 .parse(input)
682}
683
684fn parse_policy_grouped(input: &str) -> IResult<&str, Expr> {
686 let (input, _) = char('(').parse(input)?;
687 let (input, _) = nom_ws0(input)?;
688 let (input, expr) = parse_policy_expr(input)?;
689 let (input, _) = nom_ws0(input)?;
690 let (input, _) = char(')').parse(input)?;
691 Ok((input, expr))
692}
693
694fn parse_policy_bool(input: &str) -> IResult<&str, Expr> {
696 alt((
697 map(tag_no_case("true"), |_| {
698 Expr::Literal(AstValue::Bool(true))
699 }),
700 map(tag_no_case("false"), |_| {
701 Expr::Literal(AstValue::Bool(false))
702 }),
703 ))
704 .parse(input)
705}
706
707fn parse_policy_string(input: &str) -> IResult<&str, Expr> {
709 let (input, _) = char('\'').parse(input)?;
710 let mut end = 0;
711 for (i, c) in input.char_indices() {
712 if c == '\'' {
713 end = i;
714 break;
715 }
716 }
717 let content = &input[..end];
718 let rest = &input[end + 1..];
719 Ok((rest, Expr::Literal(AstValue::String(content.to_string()))))
720}
721
722fn parse_policy_number(input: &str) -> IResult<&str, Expr> {
724 let (input, digits) = take_while1(|c: char| c.is_ascii_digit() || c == '.')(input)?;
725 if digits.starts_with('.') || digits.is_empty() {
727 return Err(nom::Err::Error(nom::error::Error::new(
728 input,
729 nom::error::ErrorKind::Digit,
730 )));
731 }
732 if let Ok(n) = digits.parse::<i64>() {
733 Ok((input, Expr::Literal(AstValue::Int(n))))
734 } else if let Ok(f) = digits.parse::<f64>() {
735 Ok((input, Expr::Literal(AstValue::Float(f))))
736 } else {
737 Ok((input, Expr::Named(digits.to_string())))
738 }
739}
740
741fn parse_policy_func_or_ident(input: &str) -> IResult<&str, Expr> {
743 let (input, name) = identifier(input)?;
744
745 let mut expr = if let Ok((rest, _)) = char::<_, nom::error::Error<&str>>('(').parse(input) {
747 let (rest, _) = nom_ws0(rest)?;
749 let (rest, args) = separated_list0(
750 (nom_ws0, char(','), nom_ws0),
751 parse_policy_atom,
752 )
753 .parse(rest)?;
754 let (rest, _) = nom_ws0(rest)?;
755 let (rest, _) = char(')').parse(rest)?;
756 let input = rest;
757 (input, Expr::FunctionCall {
758 name: name.to_string(),
759 args,
760 alias: None,
761 })
762 } else {
763 (input, Expr::Named(name.to_string()))
764 };
765
766 if let Ok((rest, _)) = tag::<_, _, nom::error::Error<&str>>("::").parse(expr.0) {
768 let (rest, cast_type) = identifier(rest)?;
769 expr = (
770 rest,
771 Expr::Cast {
772 expr: Box::new(expr.1),
773 target_type: cast_type.to_string(),
774 alias: None,
775 },
776 );
777 }
778
779 Ok(expr)
780}
781
782fn parse_schema_item(input: &str) -> IResult<&str, SchemaItem> {
784 let (input, _) = ws_and_comments(input)?;
785
786 if let Ok((rest, policy)) = parse_policy(input) {
788 return Ok((rest, SchemaItem::Policy(policy)));
789 }
790
791 let (rest, table) = parse_table(input)?;
793 Ok((rest, SchemaItem::Table(table)))
794}
795
796fn parse_schema(input: &str) -> IResult<&str, Schema> {
798 let version = extract_version_directive(input);
800
801 let (input, items) = many0(parse_schema_item).parse(input)?;
802 let (input, _) = ws_and_comments(input)?;
803
804 let mut tables = Vec::new();
805 let mut policies = Vec::new();
806 for item in items {
807 match item {
808 SchemaItem::Table(t) => tables.push(t),
809 SchemaItem::Policy(p) => policies.push(p),
810 }
811 }
812
813 Ok((input, Schema { version, tables, policies }))
814}
815
/// Finds the first `-- qail: ... version=N` directive line and parses `N`.
///
/// Only the first directive line containing `version=` is considered. If its
/// value does not parse as `u32`, the result is `None` and later lines are ignored.
fn extract_version_directive(input: &str) -> Option<u32> {
    input
        .lines()
        .filter_map(|line| line.trim().strip_prefix("-- qail:"))
        .find_map(|directive| directive.trim().strip_prefix("version="))
        .and_then(|value| value.trim().parse().ok())
}
829
// Unit tests for the schema DSL parser and SQL rendering.
#[cfg(test)]
mod tests {
    use super::*;

    // Basic column parsing: types, `primary_key`, `not null`, default nullability.
    #[test]
    fn test_parse_simple_table() {
        let input = r#"
            table users (
                id uuid primary_key,
                email text not null,
                name text
            )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.tables.len(), 1);

        let users = &schema.tables[0];
        assert_eq!(users.name, "users");
        assert_eq!(users.columns.len(), 3);

        let id = &users.columns[0];
        assert_eq!(id.name, "id");
        assert_eq!(id.typ, "uuid");
        assert!(id.primary_key);
        assert!(!id.nullable);

        let email = &users.columns[1];
        assert_eq!(email.name, "email");
        assert!(!email.nullable);

        let name = &users.columns[2];
        assert!(name.nullable);
    }

    // Several tables separated by comments; `references` and `default` extraction.
    #[test]
    fn test_parse_multiple_tables() {
        let input = r#"
            -- Users table
            table users (
                id uuid primary_key,
                email text not null unique
            )

            -- Orders table
            table orders (
                id uuid primary_key,
                user_id uuid references users(id),
                total i64 not null default 0
            )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.tables.len(), 2);

        let orders = schema.find_table("orders").expect("orders not found");
        let user_id = orders.find_column("user_id").expect("user_id not found");
        assert_eq!(user_id.references, Some("users(id)".to_string()));

        let total = orders.find_column("total").expect("total not found");
        assert_eq!(total.default_value, Some("0".to_string()));
    }

    // Leading `--` comments are skipped.
    #[test]
    fn test_parse_comments() {
        let input = r#"
            -- This is a comment
            table foo (
                bar text
            )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.tables.len(), 1);
    }

    // `[]` suffix sets `is_array` and is not part of the type name.
    #[test]
    fn test_array_types() {
        let input = r#"
            table products (
                id uuid primary_key,
                tags text[],
                prices decimal[]
            )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let products = &schema.tables[0];

        let tags = products.find_column("tags").expect("tags not found");
        assert_eq!(tags.typ, "text");
        assert!(tags.is_array);

        let prices = products.find_column("prices").expect("prices not found");
        assert!(prices.is_array);
    }

    // Type parameters, serial detection and `unique`.
    #[test]
    fn test_type_params() {
        let input = r#"
            table items (
                id serial primary_key,
                name varchar(255) not null,
                price decimal(10,2),
                code varchar(50) unique
            )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let items = &schema.tables[0];

        let id = items.find_column("id").expect("id not found");
        assert!(id.is_serial);
        assert!(!id.nullable);
        let name = items.find_column("name").expect("name not found");
        assert_eq!(name.typ, "varchar");
        assert_eq!(name.type_params, Some(vec!["255".to_string()]));

        let price = items.find_column("price").expect("price not found");
        assert_eq!(
            price.type_params,
            Some(vec!["10".to_string(), "2".to_string()])
        );

        let code = items.find_column("code").expect("code not found");
        assert!(code.unique);
    }

    // `check(...)` captures the inner expression without the outer parentheses.
    #[test]
    fn test_check_constraint() {
        let input = r#"
            table employees (
                id uuid primary_key,
                age i32 check(age >= 18),
                salary decimal check(salary > 0)
            )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let employees = &schema.tables[0];

        let age = employees.find_column("age").expect("age not found");
        assert_eq!(age.check, Some("age >= 18".to_string()));

        let salary = employees.find_column("salary").expect("salary not found");
        assert_eq!(salary.check, Some("salary > 0".to_string()));
    }

    // `-- qail: version=N` is read; its absence yields `None`.
    #[test]
    fn test_version_directive() {
        let input = r#"
            -- qail: version=1
            table users (
                id uuid primary_key
            )
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.version, Some(1));
        assert_eq!(schema.tables.len(), 1);

        let input_no_version = r#"
            table items (
                id uuid primary_key
            )
        "#;
        let schema2 = Schema::parse(input_no_version).expect("parse failed");
        assert_eq!(schema2.version, None);
    }

    // Trailing `enable_rls` sets the table flag.
    #[test]
    fn test_enable_rls_table() {
        let input = r#"
            table orders (
                id uuid primary_key,
                operator_id uuid not null
            ) enable_rls
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.tables.len(), 1);
        assert!(schema.tables[0].enable_rls);
    }

    // Policy with `for all` and a `using` clause containing a function call plus cast.
    #[test]
    fn test_parse_policy_basic() {
        let input = r#"
            table orders (
                id uuid primary_key,
                operator_id uuid not null
            ) enable_rls

            policy orders_isolation on orders
                for all
                using (operator_id = current_setting('app.current_operator_id')::uuid)
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.tables.len(), 1);
        assert_eq!(schema.policies.len(), 1);

        let policy = &schema.policies[0];
        assert_eq!(policy.name, "orders_isolation");
        assert_eq!(policy.table, "orders");
        assert_eq!(policy.target, PolicyTarget::All);
        assert!(policy.using.is_some());

        // Expected shape: operator_id = Cast(FunctionCall(current_setting, [..]), uuid)
        match policy.using.as_ref().unwrap() {
            Expr::Binary { left, op, right, .. } => {
                assert_eq!(*op, BinaryOp::Eq);
                match left.as_ref() {
                    Expr::Named(n) => assert_eq!(n, "operator_id"),
                    _ => panic!("Expected Named, got {:?}", left),
                }
                match right.as_ref() {
                    Expr::Cast { target_type, expr, .. } => {
                        assert_eq!(target_type, "uuid");
                        match expr.as_ref() {
                            Expr::FunctionCall { name, args, .. } => {
                                assert_eq!(name, "current_setting");
                                assert_eq!(args.len(), 1);
                            }
                            _ => panic!("Expected FunctionCall"),
                        }
                    }
                    _ => panic!("Expected Cast, got {:?}", right),
                }
            }
            _ => panic!("Expected Binary"),
        }
    }

    // `with check (...)` populates `with_check` and leaves `using` empty.
    #[test]
    fn test_parse_policy_with_check() {
        let input = r#"
            table orders (
                id uuid primary_key
            )

            policy orders_write on orders
                for insert
                with check (operator_id = current_setting('app.current_operator_id')::uuid)
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let policy = &schema.policies[0];
        assert_eq!(policy.target, PolicyTarget::Insert);
        assert!(policy.with_check.is_some());
        assert!(policy.using.is_none());
    }

    // `restrictive` and `to <role>` clauses.
    #[test]
    fn test_parse_policy_restrictive_with_role() {
        let input = r#"
            table secrets (
                id uuid primary_key
            )

            policy admin_only on secrets
                for select
                restrictive
                to app_user
                using (current_setting('app.is_super_admin')::boolean = true)
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let policy = &schema.policies[0];
        assert_eq!(policy.target, PolicyTarget::Select);
        assert_eq!(policy.permissiveness, PolicyPermissiveness::Restrictive);
        assert_eq!(policy.role.as_deref(), Some("app_user"));
        assert!(policy.using.is_some());
    }

    // A top-level `or` produces a Binary OR node.
    #[test]
    fn test_parse_policy_or_expr() {
        let input = r#"
            table orders (
                id uuid primary_key
            )

            policy tenant_or_admin on orders
                for all
                using (operator_id = current_setting('app.current_operator_id')::uuid or current_setting('app.is_super_admin')::boolean = true)
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let policy = &schema.policies[0];

        match policy.using.as_ref().unwrap() {
            Expr::Binary { op: BinaryOp::Or, .. } => {}
            e => panic!("Expected Binary OR, got {:?}", e),
        }
    }

    // SQL rendering includes DDL, RLS enable/force statements and the policy.
    #[test]
    fn test_schema_to_sql() {
        let input = r#"
            table orders (
                id uuid primary_key,
                operator_id uuid not null
            ) enable_rls

            policy orders_isolation on orders
                for all
                using (operator_id = current_setting('app.current_operator_id')::uuid)
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        let sql = schema.to_sql();
        assert!(sql.contains("CREATE TABLE IF NOT EXISTS"));
        assert!(sql.contains("ENABLE ROW LEVEL SECURITY"));
        assert!(sql.contains("FORCE ROW LEVEL SECURITY"));
        assert!(sql.contains("CREATE POLICY"));
        assert!(sql.contains("orders_isolation"));
        assert!(sql.contains("FOR ALL"));
    }

    // Multiple policies are kept in declaration order.
    #[test]
    fn test_multiple_policies() {
        let input = r#"
            table orders (
                id uuid primary_key,
                operator_id uuid not null
            ) enable_rls

            policy orders_read on orders
                for select
                using (operator_id = current_setting('app.current_operator_id')::uuid)

            policy orders_write on orders
                for insert
                with check (operator_id = current_setting('app.current_operator_id')::uuid)
        "#;

        let schema = Schema::parse(input).expect("parse failed");
        assert_eq!(schema.policies.len(), 2);
        assert_eq!(schema.policies[0].name, "orders_read");
        assert_eq!(schema.policies[0].target, PolicyTarget::Select);
        assert_eq!(schema.policies[1].name, "orders_write");
        assert_eq!(schema.policies[1].target, PolicyTarget::Insert);
    }
}
1178}