1mod grammar;
4
5use std::path::Path;
6
7use pest::Parser;
8use smol_str::SmolStr;
9use tracing::{debug, info};
10
11use crate::ast::*;
12use crate::error::{SchemaError, SchemaResult};
13
14pub use grammar::{PraxParser, Rule};
15
16use crate::ast::{
17 MssqlBlockOperation, Policy, PolicyCommand, PolicyType, Server, ServerGroup, ServerProperty,
18 ServerPropertyValue,
19};
20
21pub fn parse_schema(input: &str) -> SchemaResult<Schema> {
23 debug!(input_len = input.len(), "parse_schema() starting");
24 let pairs = PraxParser::parse(Rule::schema, input)
25 .map_err(|e| SchemaError::syntax(input.to_string(), 0, input.len(), e.to_string()))?;
26
27 let mut schema = Schema::new();
28 let mut current_doc: Option<Documentation> = None;
29
30 let schema_pair = pairs.into_iter().next().unwrap();
32
33 for pair in schema_pair.into_inner() {
34 match pair.as_rule() {
35 Rule::documentation => {
36 let span = pair.as_span();
37 let text = pair
38 .into_inner()
39 .map(|p| p.as_str().trim_start_matches("///").trim())
40 .collect::<Vec<_>>()
41 .join("\n");
42 current_doc = Some(Documentation::new(
43 text,
44 Span::new(span.start(), span.end()),
45 ));
46 }
47 Rule::model_def => {
48 let mut model = parse_model(pair)?;
49 if let Some(doc) = current_doc.take() {
50 model = model.with_documentation(doc);
51 }
52 schema.add_model(model);
53 }
54 Rule::enum_def => {
55 let mut e = parse_enum(pair)?;
56 if let Some(doc) = current_doc.take() {
57 e = e.with_documentation(doc);
58 }
59 schema.add_enum(e);
60 }
61 Rule::type_def => {
62 let mut t = parse_composite_type(pair)?;
63 if let Some(doc) = current_doc.take() {
64 t = t.with_documentation(doc);
65 }
66 schema.add_type(t);
67 }
68 Rule::view_def => {
69 let mut v = parse_view(pair)?;
70 if let Some(doc) = current_doc.take() {
71 v = v.with_documentation(doc);
72 }
73 schema.add_view(v);
74 }
75 Rule::raw_sql_def => {
76 let sql = parse_raw_sql(pair)?;
77 schema.add_raw_sql(sql);
78 }
79 Rule::server_group_def => {
80 let mut sg = parse_server_group(pair)?;
81 if let Some(doc) = current_doc.take() {
82 sg.set_documentation(doc);
83 }
84 schema.add_server_group(sg);
85 }
86 Rule::policy_def => {
87 let mut policy = parse_policy(pair)?;
88 if let Some(doc) = current_doc.take() {
89 policy = policy.with_documentation(doc);
90 }
91 schema.add_policy(policy);
92 }
93 Rule::datasource_def => {
94 let ds = parse_datasource(pair)?;
95 schema.set_datasource(ds);
96 current_doc = None;
97 }
98 Rule::generator_def => {
99 current_doc = None;
102 }
103 Rule::EOI => {}
104 _ => {}
105 }
106 }
107
108 info!(
109 models = schema.models.len(),
110 enums = schema.enums.len(),
111 types = schema.types.len(),
112 views = schema.views.len(),
113 policies = schema.policies.len(),
114 "Schema parsed successfully"
115 );
116 Ok(schema)
117}
118
119pub fn parse_schema_file(path: impl AsRef<Path>) -> SchemaResult<Schema> {
121 let path = path.as_ref();
122 info!(path = %path.display(), "Loading schema file");
123 let content = std::fs::read_to_string(path).map_err(|e| SchemaError::IoError {
124 path: path.display().to_string(),
125 source: e,
126 })?;
127
128 parse_schema(&content)
129}
130
131fn parse_model(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Model> {
133 let span = pair.as_span();
134 let mut inner = pair.into_inner();
135
136 let name_pair = inner.next().unwrap();
137 let name = Ident::new(
138 name_pair.as_str(),
139 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
140 );
141
142 let mut model = Model::new(name, Span::new(span.start(), span.end()));
143
144 for item in inner {
145 match item.as_rule() {
146 Rule::field_def => {
147 let field = parse_field(item)?;
148 model.add_field(field);
149 }
150 Rule::model_attribute => {
151 let attr = parse_attribute(item)?;
152 model.attributes.push(attr);
153 }
154 Rule::model_body_item => {
155 let inner_item = item.into_inner().next().unwrap();
157 match inner_item.as_rule() {
158 Rule::field_def => {
159 let field = parse_field(inner_item)?;
160 model.add_field(field);
161 }
162 Rule::model_attribute => {
163 let attr = parse_attribute(inner_item)?;
164 model.attributes.push(attr);
165 }
166 _ => {}
167 }
168 }
169 _ => {}
170 }
171 }
172
173 Ok(model)
174}
175
176fn parse_enum(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Enum> {
178 let span = pair.as_span();
179 let mut inner = pair.into_inner();
180
181 let name_pair = inner.next().unwrap();
182 let name = Ident::new(
183 name_pair.as_str(),
184 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
185 );
186
187 let mut e = Enum::new(name, Span::new(span.start(), span.end()));
188
189 for item in inner {
190 match item.as_rule() {
191 Rule::enum_variant => {
192 let variant = parse_enum_variant(item)?;
193 e.add_variant(variant);
194 }
195 Rule::model_attribute => {
196 let attr = parse_attribute(item)?;
197 e.attributes.push(attr);
198 }
199 Rule::enum_body_item => {
200 let inner_item = item.into_inner().next().unwrap();
202 match inner_item.as_rule() {
203 Rule::enum_variant => {
204 let variant = parse_enum_variant(inner_item)?;
205 e.add_variant(variant);
206 }
207 Rule::model_attribute => {
208 let attr = parse_attribute(inner_item)?;
209 e.attributes.push(attr);
210 }
211 _ => {}
212 }
213 }
214 _ => {}
215 }
216 }
217
218 Ok(e)
219}
220
221fn parse_enum_variant(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<EnumVariant> {
223 let span = pair.as_span();
224 let mut inner = pair.into_inner();
225
226 let name_pair = inner.next().unwrap();
227 let name = Ident::new(
228 name_pair.as_str(),
229 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
230 );
231
232 let mut variant = EnumVariant::new(name, Span::new(span.start(), span.end()));
233
234 for item in inner {
235 if item.as_rule() == Rule::field_attribute {
236 let attr = parse_attribute(item)?;
237 variant.attributes.push(attr);
238 }
239 }
240
241 Ok(variant)
242}
243
244fn parse_composite_type(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<CompositeType> {
246 let span = pair.as_span();
247 let mut inner = pair.into_inner();
248
249 let name_pair = inner.next().unwrap();
250 let name = Ident::new(
251 name_pair.as_str(),
252 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
253 );
254
255 let mut t = CompositeType::new(name, Span::new(span.start(), span.end()));
256
257 for item in inner {
258 if item.as_rule() == Rule::field_def {
259 let field = parse_field(item)?;
260 t.add_field(field);
261 }
262 }
263
264 Ok(t)
265}
266
267fn parse_view(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<View> {
269 let span = pair.as_span();
270 let mut inner = pair.into_inner();
271
272 let name_pair = inner.next().unwrap();
273 let name = Ident::new(
274 name_pair.as_str(),
275 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
276 );
277
278 let mut v = View::new(name, Span::new(span.start(), span.end()));
279
280 for item in inner {
281 match item.as_rule() {
282 Rule::field_def => {
283 let field = parse_field(item)?;
284 v.add_field(field);
285 }
286 Rule::model_attribute => {
287 let attr = parse_attribute(item)?;
288 v.attributes.push(attr);
289 }
290 Rule::model_body_item => {
291 let inner_item = item.into_inner().next().unwrap();
293 match inner_item.as_rule() {
294 Rule::field_def => {
295 let field = parse_field(inner_item)?;
296 v.add_field(field);
297 }
298 Rule::model_attribute => {
299 let attr = parse_attribute(inner_item)?;
300 v.attributes.push(attr);
301 }
302 _ => {}
303 }
304 }
305 _ => {}
306 }
307 }
308
309 Ok(v)
310}
311
312fn parse_field(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Field> {
314 let span = pair.as_span();
315 let mut inner = pair.into_inner();
316
317 let name_pair = inner.next().unwrap();
318 let name = Ident::new(
319 name_pair.as_str(),
320 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
321 );
322
323 let type_pair = inner.next().unwrap();
324 let (field_type, modifier) = parse_field_type(type_pair)?;
325
326 let mut attributes = vec![];
327 for item in inner {
328 if item.as_rule() == Rule::field_attribute {
329 let attr = parse_attribute(item)?;
330 attributes.push(attr);
331 }
332 }
333
334 Ok(Field::new(
335 name,
336 field_type,
337 modifier,
338 attributes,
339 Span::new(span.start(), span.end()),
340 ))
341}
342
343fn parse_field_type(
345 pair: pest::iterators::Pair<'_, Rule>,
346) -> SchemaResult<(FieldType, TypeModifier)> {
347 let mut type_name = String::new();
348 let mut modifier = TypeModifier::Required;
349
350 for item in pair.into_inner() {
351 match item.as_rule() {
352 Rule::type_name => {
353 type_name = item.as_str().to_string();
354 }
355 Rule::optional_marker => {
356 modifier = if modifier == TypeModifier::List {
357 TypeModifier::OptionalList
358 } else {
359 TypeModifier::Optional
360 };
361 }
362 Rule::list_marker => {
363 modifier = if modifier == TypeModifier::Optional {
364 TypeModifier::OptionalList
365 } else {
366 TypeModifier::List
367 };
368 }
369 _ => {}
370 }
371 }
372
373 let field_type = if let Some(scalar) = ScalarType::from_str(&type_name) {
374 FieldType::Scalar(scalar)
375 } else {
376 FieldType::Model(SmolStr::new(&type_name))
379 };
380
381 Ok((field_type, modifier))
382}
383
384fn parse_attribute(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Attribute> {
386 let span = pair.as_span();
387 let mut inner = pair.into_inner();
388
389 let name_pair = inner.next().unwrap();
390 let name = Ident::new(
391 name_pair.as_str(),
392 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
393 );
394
395 let mut args = vec![];
396 for item in inner {
397 if item.as_rule() == Rule::attribute_args {
398 args = parse_attribute_args(item)?;
399 }
400 }
401
402 Ok(Attribute::new(
403 name,
404 args,
405 Span::new(span.start(), span.end()),
406 ))
407}
408
409fn parse_attribute_args(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Vec<AttributeArg>> {
411 let mut args = vec![];
412
413 for item in pair.into_inner() {
414 if item.as_rule() == Rule::attribute_arg {
415 let arg = parse_attribute_arg(item)?;
416 args.push(arg);
417 }
418 }
419
420 Ok(args)
421}
422
423fn parse_attribute_arg(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<AttributeArg> {
425 let span = pair.as_span();
426 let mut inner = pair.into_inner();
427
428 let first = inner.next().unwrap();
429
430 if let Some(second) = inner.next() {
432 let name = Ident::new(
434 first.as_str(),
435 Span::new(first.as_span().start(), first.as_span().end()),
436 );
437 let value = parse_attribute_value(second)?;
438 Ok(AttributeArg::named(
439 name,
440 value,
441 Span::new(span.start(), span.end()),
442 ))
443 } else {
444 let value = parse_attribute_value(first)?;
446 Ok(AttributeArg::positional(
447 value,
448 Span::new(span.start(), span.end()),
449 ))
450 }
451}
452
453fn parse_attribute_value(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<AttributeValue> {
455 match pair.as_rule() {
456 Rule::string_literal => {
457 let s = pair.as_str();
458 let unquoted = &s[1..s.len() - 1];
460 Ok(AttributeValue::String(unquoted.to_string()))
461 }
462 Rule::number_literal => {
463 let s = pair.as_str();
464 if s.contains('.') {
465 Ok(AttributeValue::Float(s.parse().unwrap()))
466 } else {
467 Ok(AttributeValue::Int(s.parse().unwrap()))
468 }
469 }
470 Rule::boolean_literal => Ok(AttributeValue::Boolean(pair.as_str() == "true")),
471 Rule::identifier => Ok(AttributeValue::Ident(SmolStr::new(pair.as_str()))),
472 Rule::function_call => {
473 let mut inner = pair.into_inner();
474 let name = SmolStr::new(inner.next().unwrap().as_str());
475 let mut args = vec![];
476 for item in inner {
477 args.push(parse_attribute_value(item)?);
478 }
479 Ok(AttributeValue::Function(name, args))
480 }
481 Rule::field_ref_list => {
482 let refs: Vec<SmolStr> = pair
483 .into_inner()
484 .map(|p| SmolStr::new(p.as_str()))
485 .collect();
486 Ok(AttributeValue::FieldRefList(refs))
487 }
488 Rule::array_literal => {
489 let values: Result<Vec<_>, _> = pair.into_inner().map(parse_attribute_value).collect();
490 Ok(AttributeValue::Array(values?))
491 }
492 Rule::attribute_value => {
493 parse_attribute_value(pair.into_inner().next().unwrap())
495 }
496 _ => {
497 Ok(AttributeValue::Ident(SmolStr::new(pair.as_str())))
499 }
500 }
501}
502
503fn parse_raw_sql(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<RawSql> {
505 let mut inner = pair.into_inner();
506
507 let name = inner.next().unwrap().as_str();
508 let sql = inner.next().unwrap().as_str();
509
510 let sql_content = sql
512 .trim_start_matches("\"\"\"")
513 .trim_end_matches("\"\"\"")
514 .trim();
515
516 Ok(RawSql::new(name, sql_content))
517}
518
519fn parse_server_group(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<ServerGroup> {
521 let span = pair.as_span();
522 let mut inner = pair.into_inner();
523
524 let name_pair = inner.next().unwrap();
525 let name = Ident::new(
526 name_pair.as_str(),
527 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
528 );
529
530 let mut server_group = ServerGroup::new(name, Span::new(span.start(), span.end()));
531
532 for item in inner {
533 match item.as_rule() {
534 Rule::server_group_item => {
535 let inner_item = item.into_inner().next().unwrap();
537 match inner_item.as_rule() {
538 Rule::server_def => {
539 let server = parse_server(inner_item)?;
540 server_group.add_server(server);
541 }
542 Rule::model_attribute => {
543 let attr = parse_attribute(inner_item)?;
544 server_group.add_attribute(attr);
545 }
546 _ => {}
547 }
548 }
549 Rule::server_def => {
550 let server = parse_server(item)?;
551 server_group.add_server(server);
552 }
553 Rule::model_attribute => {
554 let attr = parse_attribute(item)?;
555 server_group.add_attribute(attr);
556 }
557 _ => {}
558 }
559 }
560
561 Ok(server_group)
562}
563
564fn parse_server(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Server> {
566 let span = pair.as_span();
567 let mut inner = pair.into_inner();
568
569 let name_pair = inner.next().unwrap();
570 let name = Ident::new(
571 name_pair.as_str(),
572 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
573 );
574
575 let mut server = Server::new(name, Span::new(span.start(), span.end()));
576
577 for item in inner {
578 if item.as_rule() == Rule::server_property {
579 let prop = parse_server_property(item)?;
580 server.add_property(prop);
581 }
582 }
583
584 Ok(server)
585}
586
587fn parse_server_property(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<ServerProperty> {
589 let span = pair.as_span();
590 let mut inner = pair.into_inner();
591
592 let key_pair = inner.next().unwrap();
593 let key = key_pair.as_str();
594
595 let value_pair = inner.next().unwrap();
596 let value = parse_server_property_value(value_pair)?;
597
598 Ok(ServerProperty::new(
599 key,
600 value,
601 Span::new(span.start(), span.end()),
602 ))
603}
604
605fn parse_datasource(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Datasource> {
607 let span = pair.as_span();
608 let mut inner = pair.into_inner();
609
610 let name_pair = inner.next().unwrap();
611 let name = name_pair.as_str();
612
613 let mut datasource = Datasource::new(
614 name,
615 DatabaseProvider::PostgreSQL,
616 Span::new(span.start(), span.end()),
617 );
618
619 for prop in inner {
620 if prop.as_rule() == Rule::datasource_property {
621 let mut prop_inner = prop.into_inner();
622 let key = prop_inner.next().unwrap().as_str();
623 let value_pair = prop_inner.next().unwrap();
624
625 match key {
626 "provider" => {
627 let provider_str = extract_datasource_string(&value_pair);
628 if let Some(provider) = DatabaseProvider::from_str(&provider_str) {
629 datasource.provider = provider;
630 }
631 }
632 "url" => {
633 match value_pair.as_rule() {
634 Rule::env_function => {
635 let env_var = value_pair
637 .into_inner()
638 .next()
639 .map(|p| {
640 let s = p.as_str();
641 s[1..s.len() - 1].to_string()
642 })
643 .unwrap_or_default();
644 datasource.url_env = Some(SmolStr::new(env_var));
645 }
646 Rule::string_literal => {
647 let s = value_pair.as_str();
648 let url = &s[1..s.len() - 1];
649 datasource.url = Some(SmolStr::new(url));
650 }
651 _ => {}
652 }
653 }
654 "extensions" => {
655 if value_pair.as_rule() == Rule::extension_array {
656 for ext_item in value_pair.into_inner() {
657 if ext_item.as_rule() == Rule::extension_item {
658 let ext = parse_extension_item(
659 ext_item,
660 Span::new(span.start(), span.end()),
661 )?;
662 datasource.add_extension(ext);
663 }
664 }
665 }
666 }
667 _ => {
668 let value_str = extract_datasource_string(&value_pair);
670 datasource.add_property(key, value_str);
671 }
672 }
673 }
674 }
675
676 Ok(datasource)
677}
678
679fn parse_extension_item(
681 pair: pest::iterators::Pair<'_, Rule>,
682 span: Span,
683) -> SchemaResult<PostgresExtension> {
684 let mut inner = pair.into_inner();
685 let name = inner.next().unwrap().as_str();
686 let mut ext = PostgresExtension::new(name, span);
687
688 if let Some(args_pair) = inner.next() {
690 if args_pair.as_rule() == Rule::extension_args {
691 for arg in args_pair.into_inner() {
692 if arg.as_rule() == Rule::extension_arg {
693 let mut arg_inner = arg.into_inner();
694 let arg_key = arg_inner.next().unwrap().as_str();
695 let arg_value_pair = arg_inner.next().unwrap();
696 let arg_value = {
697 let s = arg_value_pair.as_str();
698 &s[1..s.len() - 1]
699 };
700
701 match arg_key {
702 "schema" => {
703 ext = ext.with_schema(arg_value);
704 }
705 "version" => {
706 ext = ext.with_version(arg_value);
707 }
708 _ => {}
709 }
710 }
711 }
712 }
713 }
714
715 Ok(ext)
716}
717
718fn extract_datasource_string(pair: &pest::iterators::Pair<'_, Rule>) -> String {
720 match pair.as_rule() {
721 Rule::string_literal => {
722 let s = pair.as_str();
723 s[1..s.len() - 1].to_string()
724 }
725 Rule::identifier => pair.as_str().to_string(),
726 Rule::datasource_value => {
727 if let Some(inner) = pair.clone().into_inner().next() {
728 extract_datasource_string(&inner)
729 } else {
730 pair.as_str().to_string()
731 }
732 }
733 _ => pair.as_str().to_string(),
734 }
735}
736
737fn extract_string_from_arg(pair: pest::iterators::Pair<'_, Rule>) -> String {
739 match pair.as_rule() {
740 Rule::string_literal => {
741 let s = pair.as_str();
742 s[1..s.len() - 1].to_string()
743 }
744 Rule::attribute_value => {
745 if let Some(inner) = pair.into_inner().next() {
747 extract_string_from_arg(inner)
748 } else {
749 String::new()
750 }
751 }
752 _ => pair.as_str().to_string(),
753 }
754}
755
756fn parse_server_property_value(
758 pair: pest::iterators::Pair<'_, Rule>,
759) -> SchemaResult<ServerPropertyValue> {
760 match pair.as_rule() {
761 Rule::string_literal => {
762 let s = pair.as_str();
763 let unquoted = &s[1..s.len() - 1];
765 Ok(ServerPropertyValue::String(unquoted.to_string()))
766 }
767 Rule::number_literal => {
768 let s = pair.as_str();
769 Ok(ServerPropertyValue::Number(s.parse().unwrap_or(0.0)))
770 }
771 Rule::boolean_literal => Ok(ServerPropertyValue::Boolean(pair.as_str() == "true")),
772 Rule::identifier => Ok(ServerPropertyValue::Identifier(pair.as_str().to_string())),
773 Rule::function_call => {
774 let mut inner = pair.into_inner();
776 let func_name = inner.next().unwrap().as_str();
777 if func_name == "env" {
778 if let Some(arg) = inner.next() {
779 let var_name = extract_string_from_arg(arg);
780 return Ok(ServerPropertyValue::EnvVar(var_name));
781 }
782 }
783 Ok(ServerPropertyValue::Identifier(func_name.to_string()))
785 }
786 Rule::array_literal => {
787 let values: Result<Vec<_>, _> =
788 pair.into_inner().map(parse_server_property_value).collect();
789 Ok(ServerPropertyValue::Array(values?))
790 }
791 Rule::attribute_value => {
792 parse_server_property_value(pair.into_inner().next().unwrap())
794 }
795 _ => {
796 Ok(ServerPropertyValue::Identifier(pair.as_str().to_string()))
798 }
799 }
800}
801
802fn parse_policy(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Policy> {
804 let span = pair.as_span();
805 let mut inner = pair.into_inner();
806
807 let name_pair = inner.next().unwrap();
809 let name = Ident::new(
810 name_pair.as_str(),
811 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
812 );
813
814 let table_pair = inner.next().unwrap();
816 let table = Ident::new(
817 table_pair.as_str(),
818 Span::new(table_pair.as_span().start(), table_pair.as_span().end()),
819 );
820
821 let mut policy = Policy::new(name, table, Span::new(span.start(), span.end()));
822 policy.commands = vec![];
824
825 for item in inner {
826 match item.as_rule() {
827 Rule::policy_item => {
828 let inner_item = item.into_inner().next().unwrap();
829 parse_policy_item(&mut policy, inner_item)?;
830 }
831 Rule::policy_for
832 | Rule::policy_to
833 | Rule::policy_as
834 | Rule::policy_using
835 | Rule::policy_check => {
836 parse_policy_item(&mut policy, item)?;
837 }
838 _ => {}
839 }
840 }
841
842 if policy.commands.is_empty() {
844 policy.commands.push(PolicyCommand::All);
845 }
846
847 Ok(policy)
848}
849
850fn parse_policy_item(
852 policy: &mut Policy,
853 pair: pest::iterators::Pair<'_, Rule>,
854) -> SchemaResult<()> {
855 match pair.as_rule() {
856 Rule::policy_for => {
857 let inner = pair.into_inner().next().unwrap();
858 match inner.as_rule() {
859 Rule::policy_command => {
860 if let Some(cmd) = PolicyCommand::from_str(inner.as_str()) {
861 policy.add_command(cmd);
862 }
863 }
864 Rule::policy_command_list => {
865 for cmd_pair in inner.into_inner() {
866 if cmd_pair.as_rule() == Rule::policy_command {
867 if let Some(cmd) = PolicyCommand::from_str(cmd_pair.as_str()) {
868 policy.add_command(cmd);
869 }
870 }
871 }
872 }
873 _ => {}
874 }
875 }
876 Rule::policy_to => {
877 let inner = pair.into_inner().next().unwrap();
878 match inner.as_rule() {
879 Rule::identifier => {
880 policy.add_role(inner.as_str());
881 }
882 Rule::policy_role_list => {
883 for role_pair in inner.into_inner() {
884 if role_pair.as_rule() == Rule::identifier {
885 policy.add_role(role_pair.as_str());
886 }
887 }
888 }
889 _ => {}
890 }
891 }
892 Rule::policy_as => {
893 let inner = pair.into_inner().next().unwrap();
894 if inner.as_rule() == Rule::policy_type {
895 if let Some(policy_type) = PolicyType::from_str(inner.as_str()) {
896 policy.policy_type = policy_type;
897 }
898 }
899 }
900 Rule::policy_using => {
901 let inner = pair.into_inner().next().unwrap();
902 let expr = extract_policy_expression(&inner);
903 policy.using_expr = Some(expr);
904 }
905 Rule::policy_check => {
906 let inner = pair.into_inner().next().unwrap();
907 let expr = extract_policy_expression(&inner);
908 policy.check_expr = Some(expr);
909 }
910 Rule::policy_mssql_schema => {
911 let inner = pair.into_inner().next().unwrap();
912 if inner.as_rule() == Rule::string_literal {
913 let s = inner.as_str();
914 let schema = &s[1..s.len() - 1]; policy.mssql_schema = Some(SmolStr::new(schema));
916 }
917 }
918 Rule::policy_mssql_block => {
919 let inner = pair.into_inner().next().unwrap();
920 match inner.as_rule() {
921 Rule::mssql_block_op => {
922 if let Some(op) = MssqlBlockOperation::from_str(inner.as_str()) {
923 policy.add_mssql_block_operation(op);
924 }
925 }
926 Rule::mssql_block_op_list => {
927 for op_pair in inner.into_inner() {
928 if op_pair.as_rule() == Rule::mssql_block_op {
929 if let Some(op) = MssqlBlockOperation::from_str(op_pair.as_str()) {
930 policy.add_mssql_block_operation(op);
931 }
932 }
933 }
934 }
935 _ => {}
936 }
937 }
938 _ => {}
939 }
940 Ok(())
941}
942
943fn extract_policy_expression(pair: &pest::iterators::Pair<'_, Rule>) -> String {
945 let s = pair.as_str();
946 match pair.as_rule() {
947 Rule::multiline_string => {
948 s.trim_start_matches("\"\"\"")
950 .trim_end_matches("\"\"\"")
951 .trim()
952 .to_string()
953 }
954 Rule::string_literal => {
955 s[1..s.len() - 1].to_string()
957 }
958 _ => s.to_string(),
959 }
960}
961
962#[cfg(test)]
963mod tests {
964 use super::*;
965
966 #[test]
969 fn test_parse_simple_model() {
970 let schema = parse_schema(
971 r#"
972 model User {
973 id Int @id @auto
974 email String @unique
975 name String?
976 }
977 "#,
978 )
979 .unwrap();
980
981 assert_eq!(schema.models.len(), 1);
982 let user = schema.get_model("User").unwrap();
983 assert_eq!(user.fields.len(), 3);
984 assert!(user.get_field("id").unwrap().is_id());
985 assert!(user.get_field("email").unwrap().is_unique());
986 assert!(user.get_field("name").unwrap().is_optional());
987 }
988
989 #[test]
990 fn test_parse_model_name() {
991 let schema = parse_schema(
992 r#"
993 model BlogPost {
994 id Int @id
995 }
996 "#,
997 )
998 .unwrap();
999
1000 assert!(schema.get_model("BlogPost").is_some());
1001 }
1002
1003 #[test]
1004 fn test_parse_multiple_models() {
1005 let schema = parse_schema(
1006 r#"
1007 model User {
1008 id Int @id
1009 }
1010
1011 model Post {
1012 id Int @id
1013 }
1014
1015 model Comment {
1016 id Int @id
1017 }
1018 "#,
1019 )
1020 .unwrap();
1021
1022 assert_eq!(schema.models.len(), 3);
1023 assert!(schema.get_model("User").is_some());
1024 assert!(schema.get_model("Post").is_some());
1025 assert!(schema.get_model("Comment").is_some());
1026 }
1027
1028 #[test]
1031 fn test_parse_all_scalar_types() {
1032 let schema = parse_schema(
1033 r#"
1034 model AllTypes {
1035 id Int @id
1036 big BigInt
1037 float_f Float
1038 decimal Decimal
1039 str String
1040 bool Boolean
1041 datetime DateTime
1042 date Date
1043 time Time
1044 json Json
1045 bytes Bytes
1046 uuid Uuid
1047 cuid Cuid
1048 cuid2 Cuid2
1049 nanoid NanoId
1050 ulid Ulid
1051 }
1052 "#,
1053 )
1054 .unwrap();
1055
1056 let model = schema.get_model("AllTypes").unwrap();
1057 assert_eq!(model.fields.len(), 16);
1058
1059 assert!(matches!(
1060 model.get_field("id").unwrap().field_type,
1061 FieldType::Scalar(ScalarType::Int)
1062 ));
1063 assert!(matches!(
1064 model.get_field("big").unwrap().field_type,
1065 FieldType::Scalar(ScalarType::BigInt)
1066 ));
1067 assert!(matches!(
1068 model.get_field("str").unwrap().field_type,
1069 FieldType::Scalar(ScalarType::String)
1070 ));
1071 assert!(matches!(
1072 model.get_field("bool").unwrap().field_type,
1073 FieldType::Scalar(ScalarType::Boolean)
1074 ));
1075 assert!(matches!(
1076 model.get_field("datetime").unwrap().field_type,
1077 FieldType::Scalar(ScalarType::DateTime)
1078 ));
1079 assert!(matches!(
1080 model.get_field("uuid").unwrap().field_type,
1081 FieldType::Scalar(ScalarType::Uuid)
1082 ));
1083 assert!(matches!(
1084 model.get_field("cuid").unwrap().field_type,
1085 FieldType::Scalar(ScalarType::Cuid)
1086 ));
1087 assert!(matches!(
1088 model.get_field("cuid2").unwrap().field_type,
1089 FieldType::Scalar(ScalarType::Cuid2)
1090 ));
1091 assert!(matches!(
1092 model.get_field("nanoid").unwrap().field_type,
1093 FieldType::Scalar(ScalarType::NanoId)
1094 ));
1095 assert!(matches!(
1096 model.get_field("ulid").unwrap().field_type,
1097 FieldType::Scalar(ScalarType::Ulid)
1098 ));
1099 }
1100
1101 #[test]
1102 fn test_parse_optional_field() {
1103 let schema = parse_schema(
1104 r#"
1105 model User {
1106 id Int @id
1107 bio String?
1108 age Int?
1109 }
1110 "#,
1111 )
1112 .unwrap();
1113
1114 let user = schema.get_model("User").unwrap();
1115 assert!(!user.get_field("id").unwrap().is_optional());
1116 assert!(user.get_field("bio").unwrap().is_optional());
1117 assert!(user.get_field("age").unwrap().is_optional());
1118 }
1119
1120 #[test]
1121 fn test_parse_list_field() {
1122 let schema = parse_schema(
1123 r#"
1124 model User {
1125 id Int @id
1126 tags String[]
1127 posts Post[]
1128 }
1129 "#,
1130 )
1131 .unwrap();
1132
1133 let user = schema.get_model("User").unwrap();
1134 assert!(user.get_field("tags").unwrap().is_list());
1135 assert!(user.get_field("posts").unwrap().is_list());
1136 }
1137
1138 #[test]
1139 fn test_parse_optional_list_field() {
1140 let schema = parse_schema(
1141 r#"
1142 model User {
1143 id Int @id
1144 metadata String[]?
1145 }
1146 "#,
1147 )
1148 .unwrap();
1149
1150 let user = schema.get_model("User").unwrap();
1151 let metadata = user.get_field("metadata").unwrap();
1152 assert!(metadata.is_list());
1153 assert!(metadata.is_optional());
1154 }
1155
1156 #[test]
1159 fn test_parse_id_attribute() {
1160 let schema = parse_schema(
1161 r#"
1162 model User {
1163 id Int @id
1164 }
1165 "#,
1166 )
1167 .unwrap();
1168
1169 let user = schema.get_model("User").unwrap();
1170 assert!(user.get_field("id").unwrap().is_id());
1171 }
1172
1173 #[test]
1174 fn test_parse_unique_attribute() {
1175 let schema = parse_schema(
1176 r#"
1177 model User {
1178 id Int @id
1179 email String @unique
1180 }
1181 "#,
1182 )
1183 .unwrap();
1184
1185 let user = schema.get_model("User").unwrap();
1186 assert!(user.get_field("email").unwrap().is_unique());
1187 }
1188
1189 #[test]
1190 fn test_parse_default_int() {
1191 let schema = parse_schema(
1192 r#"
1193 model Counter {
1194 id Int @id
1195 count Int @default(0)
1196 }
1197 "#,
1198 )
1199 .unwrap();
1200
1201 let counter = schema.get_model("Counter").unwrap();
1202 let count_field = counter.get_field("count").unwrap();
1203 let attrs = count_field.extract_attributes();
1204 assert!(attrs.default.is_some());
1205 assert_eq!(attrs.default.unwrap().as_int(), Some(0));
1206 }
1207
1208 #[test]
1209 fn test_parse_default_string() {
1210 let schema = parse_schema(
1211 r#"
1212 model User {
1213 id Int @id
1214 status String @default("active")
1215 }
1216 "#,
1217 )
1218 .unwrap();
1219
1220 let user = schema.get_model("User").unwrap();
1221 let status = user.get_field("status").unwrap();
1222 let attrs = status.extract_attributes();
1223 assert!(attrs.default.is_some());
1224 assert_eq!(attrs.default.unwrap().as_string(), Some("active"));
1225 }
1226
1227 #[test]
1228 fn test_parse_default_boolean() {
1229 let schema = parse_schema(
1230 r#"
1231 model Post {
1232 id Int @id
1233 published Boolean @default(false)
1234 }
1235 "#,
1236 )
1237 .unwrap();
1238
1239 let post = schema.get_model("Post").unwrap();
1240 let published = post.get_field("published").unwrap();
1241 let attrs = published.extract_attributes();
1242 assert!(attrs.default.is_some());
1243 assert_eq!(attrs.default.unwrap().as_bool(), Some(false));
1244 }
1245
1246 #[test]
1247 fn test_parse_default_function() {
1248 let schema = parse_schema(
1249 r#"
1250 model User {
1251 id Int @id
1252 createdAt DateTime @default(now())
1253 }
1254 "#,
1255 )
1256 .unwrap();
1257
1258 let user = schema.get_model("User").unwrap();
1259 let created_at = user.get_field("createdAt").unwrap();
1260 let attrs = created_at.extract_attributes();
1261 assert!(attrs.default.is_some());
1262 if let Some(AttributeValue::Function(name, _)) = attrs.default {
1263 assert_eq!(name.as_str(), "now");
1264 } else {
1265 panic!("Expected function default");
1266 }
1267 }
1268
1269 #[test]
1270 fn test_parse_updated_at_attribute() {
1271 let schema = parse_schema(
1272 r#"
1273 model User {
1274 id Int @id
1275 updatedAt DateTime @updated_at
1276 }
1277 "#,
1278 )
1279 .unwrap();
1280
1281 let user = schema.get_model("User").unwrap();
1282 let updated_at = user.get_field("updatedAt").unwrap();
1283 let attrs = updated_at.extract_attributes();
1284 assert!(attrs.is_updated_at);
1285 }
1286
1287 #[test]
1288 fn test_parse_map_attribute() {
1289 let schema = parse_schema(
1290 r#"
1291 model User {
1292 id Int @id
1293 email String @map("email_address")
1294 }
1295 "#,
1296 )
1297 .unwrap();
1298
1299 let user = schema.get_model("User").unwrap();
1300 let email = user.get_field("email").unwrap();
1301 let attrs = email.extract_attributes();
1302 assert_eq!(attrs.map, Some("email_address".to_string()));
1303 }
1304
/// Several attributes on one field are all reflected in the extracted attributes.
#[test]
fn test_parse_multiple_attributes() {
    let source = r#"
        model User {
            id Int @id @auto
            email String @unique @index
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    let model = parsed.get_model("User").unwrap();
    let id_field = model.get_field("id").unwrap();
    let email_field = model.get_field("email").unwrap();

    // `id` carries both @id and @auto.
    let id_flags = id_field.extract_attributes();
    assert!(id_flags.is_id);
    assert!(id_flags.is_auto);

    // `email` carries both @unique and @index.
    let email_flags = email_field.extract_attributes();
    assert!(email_flags.is_unique);
    assert!(email_flags.is_indexed);
}
1329
/// A model-level `@@map("...")` determines the value returned by `table_name()`.
#[test]
fn test_parse_model_map_attribute() {
    let source = r#"
        model User {
            id Int @id

            @@map("app_users")
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    let model = parsed.get_model("User").unwrap();
    assert_eq!(model.table_name(), "app_users");
}
1348
/// A model-level `@@index([...])` is recorded as an `index` attribute on the model.
#[test]
fn test_parse_model_index_attribute() {
    let source = r#"
        model User {
            id Int @id
            email String
            name String

            @@index([email, name])
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    let model = parsed.get_model("User").unwrap();
    assert!(model.has_attribute("index"));
}
1367
/// A model-level `@@id([...])` is recorded as an `id` attribute on the model.
#[test]
fn test_parse_composite_primary_key() {
    let source = r#"
        model PostTag {
            postId Int
            tagId Int

            @@id([postId, tagId])
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    let model = parsed.get_model("PostTag").unwrap();
    assert!(model.has_attribute("id"));
}
1385
/// A single enum definition yields one enum with one variant per listed name.
#[test]
fn test_parse_enum() {
    let source = r#"
        enum Role {
            User
            Admin
            Moderator
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    assert_eq!(parsed.enums.len(), 1);
    let role_enum = parsed.get_enum("Role").unwrap();
    assert_eq!(role_enum.variants.len(), 3);
}
1405
/// Every declared enum variant can be looked up by name.
#[test]
fn test_parse_enum_variant_names() {
    let source = r#"
        enum Status {
            Pending
            Active
            Completed
            Cancelled
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    let status_enum = parsed.get_enum("Status").unwrap();
    for variant_name in ["Pending", "Active", "Completed", "Cancelled"] {
        assert!(status_enum.get_variant(variant_name).is_some());
    }
}
1426
/// A variant-level `@map("...")` is what `db_value()` returns for that variant.
#[test]
fn test_parse_enum_with_map() {
    let source = r#"
        enum Role {
            User @map("USER")
            Admin @map("ADMINISTRATOR")
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    let role_enum = parsed.get_enum("Role").unwrap();

    let user = role_enum.get_variant("User").unwrap();
    assert_eq!(user.db_value(), "USER");

    let admin = role_enum.get_variant("Admin").unwrap();
    assert_eq!(admin.db_value(), "ADMINISTRATOR");
}
1446
/// In a User/Post pair, `posts` is a list field and `author` is a relation field.
#[test]
fn test_parse_one_to_many_relation() {
    let source = r#"
        model User {
            id Int @id
            posts Post[]
        }

        model Post {
            id Int @id
            authorId Int
            author User @relation(fields: [authorId], references: [id])
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    let user_model = parsed.get_model("User").unwrap();
    let post_model = parsed.get_model("Post").unwrap();

    let posts_field = user_model.get_field("posts").unwrap();
    assert!(posts_field.is_list());

    let author_field = post_model.get_field("author").unwrap();
    assert!(author_field.is_relation());
}
1473
/// `onDelete` / `onUpdate` arguments of `@relation` map to referential actions.
#[test]
fn test_parse_relation_with_actions() {
    let source = r#"
        model Post {
            id Int @id
            authorId Int
            author User @relation(fields: [authorId], references: [id], onDelete: Cascade, onUpdate: Restrict)
        }

        model User {
            id Int @id
            posts Post[]
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    let post_model = parsed.get_model("Post").unwrap();
    let author_field = post_model.get_field("author").unwrap();
    let field_attrs = author_field.extract_attributes();

    assert!(field_attrs.relation.is_some());
    let relation = field_attrs.relation.unwrap();
    assert_eq!(relation.on_delete, Some(ReferentialAction::Cascade));
    assert_eq!(relation.on_update, Some(ReferentialAction::Restrict));
}
1501
/// A model preceded by a `///` comment still parses; the documentation text is
/// only checked when the parser attached one to the model.
#[test]
fn test_parse_model_documentation() {
    let source = r#"/// Represents a user in the system
model User {
    id Int @id
}"#;
    let parsed = parse_schema(source).unwrap();

    let model = parsed.get_model("User").unwrap();
    // Lenient on purpose: a missing documentation entry does not fail the test.
    if let Some(documentation) = &model.documentation {
        assert!(documentation.text.contains("user"));
    }
}
1521
/// End-to-end check on a schema with four models and one enum: counts, lookups,
/// the mapped table name, the field count and a relation field.
#[test]
fn test_parse_complete_schema() {
    let source = r#"
        /// User model
        model User {
            id Int @id @auto
            email String @unique
            name String?
            role Role @default(User)
            posts Post[]
            profile Profile?
            createdAt DateTime @default(now())
            updatedAt DateTime @updated_at

            @@map("users")
            @@index([email])
        }

        model Post {
            id Int @id @auto
            title String
            content String?
            published Boolean @default(false)
            authorId Int
            author User @relation(fields: [authorId], references: [id])
            tags Tag[]
            createdAt DateTime @default(now())

            @@index([authorId])
        }

        model Profile {
            id Int @id @auto
            bio String?
            userId Int @unique
            user User @relation(fields: [userId], references: [id])
        }

        model Tag {
            id Int @id @auto
            name String @unique
            posts Post[]
        }

        enum Role {
            User
            Admin
            Moderator
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    // Models
    assert_eq!(parsed.models.len(), 4);
    assert!(parsed.get_model("User").is_some());
    assert!(parsed.get_model("Post").is_some());
    assert!(parsed.get_model("Profile").is_some());
    assert!(parsed.get_model("Tag").is_some());

    // Enums
    assert_eq!(parsed.enums.len(), 1);
    assert!(parsed.get_enum("Role").is_some());

    // User model details
    let user_model = parsed.get_model("User").unwrap();
    assert_eq!(user_model.table_name(), "users");
    assert_eq!(user_model.fields.len(), 8);
    assert!(user_model.has_attribute("index"));

    // Post relation
    let post_model = parsed.get_model("Post").unwrap();
    let author_field = post_model.get_field("author").unwrap();
    assert!(author_field.is_relation());
}
1599
/// A model declaration without a name is rejected with an error.
#[test]
fn test_parse_invalid_syntax() {
    let outcome = parse_schema("model { broken }");
    assert!(outcome.is_err());
}
1607
/// Empty input parses to a schema with no models and no enums.
#[test]
fn test_parse_empty_schema() {
    let parsed = parse_schema("").unwrap();
    assert!(parsed.models.is_empty());
    assert!(parsed.enums.is_empty());
}
1614
/// Input made only of spaces, tabs and newlines parses to a schema with no models.
#[test]
fn test_parse_whitespace_only() {
    let blank_input = " \n\t \n ";
    let parsed = parse_schema(blank_input).unwrap();
    assert!(parsed.models.is_empty());
}
1620
/// Input made only of `//` comments parses to a schema with no models.
#[test]
fn test_parse_comments_only() {
    let source = r#"
        // This is a comment
        // Another comment
    "#;
    let parsed = parse_schema(source).unwrap();
    assert!(parsed.models.is_empty());
}
1632
/// Feeds a model with an empty body to the parser. The outcome (`Ok` or `Err`)
/// is intentionally not asserted; the test only requires that parsing returns.
#[test]
fn test_parse_model_with_no_fields() {
    let source = r#"
        model Empty {
        }
    "#;
    let _outcome = parse_schema(source);
}
1647
/// Long model and field identifiers are accepted and the model is retrievable by name.
#[test]
fn test_parse_long_identifier() {
    let source = r#"
        model VeryLongModelNameThatIsStillValid {
            someVeryLongFieldNameThatShouldWork Int @id
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    let lookup = parsed.get_model("VeryLongModelNameThatIsStillValid");
    assert!(lookup.is_some());
}
1665
/// Identifiers containing underscores are accepted for both models and fields.
#[test]
fn test_parse_underscore_identifiers() {
    let source = r#"
        model user_account {
            user_id Int @id
            created_at DateTime
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    let account = parsed.get_model("user_account").unwrap();
    for field_name in ["user_id", "created_at"] {
        assert!(account.get_field(field_name).is_some());
    }
}
1682
/// A negative integer literal in `@default(...)` produces a default value.
#[test]
fn test_parse_negative_default() {
    let source = r#"
        model Config {
            id Int @id
            minValue Int @default(-100)
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    let model = parsed.get_model("Config").unwrap();
    let field = model.get_field("minValue").unwrap();
    let field_attrs = field.extract_attributes();
    assert!(field_attrs.default.is_some());
}
1700
/// A floating-point literal in `@default(...)` produces a default value.
#[test]
fn test_parse_float_default() {
    let source = r#"
        model Product {
            id Int @id
            price Float @default(9.99)
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    let model = parsed.get_model("Product").unwrap();
    let field = model.get_field("price").unwrap();
    let field_attrs = field.extract_attributes();
    assert!(field_attrs.default.is_some());
}
1718
/// A server group with one server is registered under its name with that server.
#[test]
fn test_parse_simple_server_group() {
    let source = r#"
        serverGroup MainCluster {
            server primary {
                url = "postgres://localhost/db"
                role = "primary"
            }
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    assert_eq!(parsed.server_groups.len(), 1);
    let group = parsed.get_server_group("MainCluster").unwrap();
    assert_eq!(group.servers.len(), 1);
    assert!(group.servers.contains_key("primary"));
}
1740
/// Three servers in one group are all parsed; role and weight are readable per server.
#[test]
fn test_parse_server_group_with_multiple_servers() {
    let source = r#"
        serverGroup ReadReplicas {
            server primary {
                url = "postgres://primary.db.com/app"
                role = "primary"
                weight = 1
            }

            server replica1 {
                url = "postgres://replica1.db.com/app"
                role = "replica"
                weight = 2
            }

            server replica2 {
                url = "postgres://replica2.db.com/app"
                role = "replica"
                weight = 2
            }
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    let group = parsed.get_server_group("ReadReplicas").unwrap();
    assert_eq!(group.servers.len(), 3);

    let primary_server = group.servers.get("primary").unwrap();
    assert_eq!(primary_server.role(), Some(ServerRole::Primary));
    assert_eq!(primary_server.weight(), Some(1));

    let first_replica = group.servers.get("replica1").unwrap();
    assert_eq!(first_replica.role(), Some(ServerRole::Replica));
    assert_eq!(first_replica.weight(), Some(2));
}
1779
/// Group-level `@@strategy` and `@@loadBalance` show up in the group's attribute list.
#[test]
fn test_parse_server_group_with_attributes() {
    let source = r#"
        serverGroup ProductionCluster {
            @@strategy(ReadReplica)
            @@loadBalance(RoundRobin)

            server main {
                url = "postgres://main/db"
                role = "primary"
            }
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    let group = parsed.get_server_group("ProductionCluster").unwrap();

    let has_strategy = group
        .attributes
        .iter()
        .any(|attr| attr.name.name == "strategy");
    assert!(has_strategy);

    let has_load_balance = group
        .attributes
        .iter()
        .any(|attr| attr.name.name == "loadBalance");
    assert!(has_load_balance);
}
1806
/// `url = env("...")` is stored as an `EnvVar` property holding the variable name.
#[test]
fn test_parse_server_group_with_env_vars() {
    let source = r#"
        serverGroup EnvCluster {
            server db1 {
                url = env("PRIMARY_DB_URL")
                role = "primary"
            }
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    let group = parsed.get_server_group("EnvCluster").unwrap();
    let db1 = group.servers.get("db1").unwrap();

    match db1.get_property("url") {
        Some(ServerPropertyValue::EnvVar(var)) => assert_eq!(var, "PRIMARY_DB_URL"),
        _ => panic!("Expected env var for url property"),
    }
}
1831
/// `readOnly = true` on a server makes `is_read_only()` return true.
#[test]
fn test_parse_server_group_with_boolean_property() {
    let source = r#"
        serverGroup TestCluster {
            server replica {
                url = "postgres://replica/db"
                role = "replica"
                readOnly = true
            }
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    let group = parsed.get_server_group("TestCluster").unwrap();
    let replica = group.servers.get("replica").unwrap();
    assert!(replica.is_read_only());
}
1851
/// Integer server properties are readable through their typed accessors.
#[test]
fn test_parse_server_group_with_numeric_properties() {
    let source = r#"
        serverGroup NumericCluster {
            server db {
                url = "postgres://localhost/db"
                weight = 5
                priority = 1
                maxConnections = 100
            }
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    let group = parsed.get_server_group("NumericCluster").unwrap();
    let db = group.servers.get("db").unwrap();

    assert_eq!(db.weight(), Some(5));
    assert_eq!(db.priority(), Some(1));
    assert_eq!(db.max_connections(), Some(100));
}
1875
/// Each server reports its `region`, and `servers_in_region` filters by it.
#[test]
fn test_parse_server_group_with_region() {
    let source = r#"
        serverGroup GeoCluster {
            server usEast {
                url = "postgres://us-east.db.com/app"
                role = "replica"
                region = "us-east-1"
            }

            server usWest {
                url = "postgres://us-west.db.com/app"
                role = "replica"
                region = "us-west-2"
            }
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    let group = parsed.get_server_group("GeoCluster").unwrap();

    let east = group.servers.get("usEast").unwrap();
    assert_eq!(east.region(), Some("us-east-1"));

    let west = group.servers.get("usWest").unwrap();
    assert_eq!(west.region(), Some("us-west-2"));

    // Region filtering: only `usEast` is declared in us-east-1.
    let east_matches = group.servers_in_region("us-east-1");
    assert_eq!(east_matches.len(), 1);
}
1909
/// Three sibling server groups are all registered and retrievable by name.
#[test]
fn test_parse_multiple_server_groups() {
    let source = r#"
        serverGroup Cluster1 {
            server db1 {
                url = "postgres://db1/app"
            }
        }

        serverGroup Cluster2 {
            server db2 {
                url = "postgres://db2/app"
            }
        }

        serverGroup Cluster3 {
            server db3 {
                url = "postgres://db3/app"
            }
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    assert_eq!(parsed.server_groups.len(), 3);
    for group_name in ["Cluster1", "Cluster2", "Cluster3"] {
        assert!(parsed.get_server_group(group_name).is_some());
    }
}
1940
/// Models and a server group interleaved in one schema are all collected.
#[test]
fn test_parse_schema_with_models_and_server_groups() {
    let source = r#"
        model User {
            id Int @id @auto
            email String @unique
        }

        serverGroup Database {
            @@strategy(ReadReplica)

            server primary {
                url = env("DATABASE_URL")
                role = "primary"
            }
        }

        model Post {
            id Int @id @auto
            title String
            authorId Int
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    // Both models, declared before and after the server group.
    assert_eq!(parsed.models.len(), 2);
    assert!(parsed.get_model("User").is_some());
    assert!(parsed.get_model("Post").is_some());

    // The single server group between them.
    assert_eq!(parsed.server_groups.len(), 1);
    assert!(parsed.get_server_group("Database").is_some());
}
1975
/// The `healthCheck` string property is returned by `health_check()`.
#[test]
fn test_parse_server_group_with_health_check() {
    let source = r#"
        serverGroup HealthyCluster {
            server monitored {
                url = "postgres://localhost/db"
                healthCheck = "/health"
            }
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    let group = parsed.get_server_group("HealthyCluster").unwrap();
    let monitored = group.servers.get("monitored").unwrap();
    assert_eq!(monitored.health_check(), Some("/health"));
}
1994
/// Servers declared as db3, db1, db2 with priorities 3, 1, 2 come back from
/// `failover_order()` as db1, db2, db3.
#[test]
fn test_server_group_failover_order() {
    let source = r#"
        serverGroup FailoverCluster {
            server db3 {
                url = "postgres://db3/app"
                priority = 3
            }

            server db1 {
                url = "postgres://db1/app"
                priority = 1
            }

            server db2 {
                url = "postgres://db2/app"
                priority = 2
            }
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    let group = parsed.get_server_group("FailoverCluster").unwrap();
    let ordered = group.failover_order();

    for (position, expected_name) in ["db1", "db2", "db3"].iter().enumerate() {
        assert_eq!(ordered[position].name.name.as_str(), *expected_name);
    }
}
2026
/// `server_group_names()` yields the name of every declared server group.
#[test]
fn test_server_group_names() {
    let source = r#"
        serverGroup Alpha {
            server s1 { url = "pg://a" }
        }
        serverGroup Beta {
            server s2 { url = "pg://b" }
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    let group_names: Vec<_> = parsed.server_group_names().collect();
    assert_eq!(group_names.len(), 2);
    assert!(group_names.contains(&"Alpha"));
    assert!(group_names.contains(&"Beta"));
}
2046
/// A minimal `policy ... on ...` block exposes its name, table, command and USING expression.
#[test]
fn test_parse_simple_policy() {
    let source = r#"
        policy UserReadOwn on User {
            for SELECT
            using "id = current_user_id()"
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    assert_eq!(parsed.policies.len(), 1);
    let rule = parsed.get_policy("UserReadOwn").unwrap();
    assert_eq!(rule.name(), "UserReadOwn");
    assert_eq!(rule.table(), "User");
    assert!(rule.applies_to(PolicyCommand::Select));
    assert!(!rule.applies_to(PolicyCommand::Insert));
    assert_eq!(rule.using_expr.as_deref(), Some("id = current_user_id()"));
}
2069
/// A bracketed command list applies to exactly the listed commands.
#[test]
fn test_parse_policy_with_multiple_commands() {
    let source = r#"
        policy UserModify on User {
            for [SELECT, UPDATE, DELETE]
            using "id = auth.uid()"
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    let rule = parsed.get_policy("UserModify").unwrap();
    assert!(rule.applies_to(PolicyCommand::Select));
    assert!(rule.applies_to(PolicyCommand::Update));
    assert!(rule.applies_to(PolicyCommand::Delete));
    // INSERT was not listed.
    assert!(!rule.applies_to(PolicyCommand::Insert));
}
2088
/// `for ALL` makes the policy apply to SELECT, INSERT, UPDATE and DELETE.
#[test]
fn test_parse_policy_with_all_command() {
    let source = r#"
        policy UserAll on User {
            for ALL
            using "true"
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    let rule = parsed.get_policy("UserAll").unwrap();
    let every_command = [
        PolicyCommand::Select,
        PolicyCommand::Insert,
        PolicyCommand::Update,
        PolicyCommand::Delete,
    ];
    for command in every_command {
        assert!(rule.applies_to(command));
    }
}
2107
/// A single `to <role>` clause is included in the policy's effective roles.
#[test]
fn test_parse_policy_with_roles() {
    let source = r#"
        policy AuthenticatedRead on Document {
            for SELECT
            to authenticated
            using "true"
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    let rule = parsed.get_policy("AuthenticatedRead").unwrap();
    let granted = rule.effective_roles();
    assert!(granted.contains(&"authenticated"));
}
2125
/// A bracketed role list contributes every listed role to the effective roles.
#[test]
fn test_parse_policy_with_multiple_roles() {
    let source = r#"
        policy AdminModerator on Post {
            for [UPDATE, DELETE]
            to [admin, moderator]
            using "true"
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    let rule = parsed.get_policy("AdminModerator").unwrap();
    let granted = rule.effective_roles();
    assert!(granted.contains(&"admin"));
    assert!(granted.contains(&"moderator"));
}
2144
/// `as RESTRICTIVE` marks the policy restrictive and not permissive.
#[test]
fn test_parse_policy_restrictive() {
    let source = r#"
        policy OrgRestriction on Document {
            as RESTRICTIVE
            for SELECT
            using "org_id = current_org_id()"
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    let rule = parsed.get_policy("OrgRestriction").unwrap();
    assert!(rule.is_restrictive());
    assert!(!rule.is_permissive());
}
2162
/// An explicit `as PERMISSIVE` marks the policy permissive.
#[test]
fn test_parse_policy_permissive_explicit() {
    let source = r#"
        policy Permissive on User {
            as PERMISSIVE
            for SELECT
            using "true"
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    let rule = parsed.get_policy("Permissive").unwrap();
    assert!(rule.is_permissive());
}
2179
/// A policy with only a `check` clause has a CHECK expression and no USING expression.
#[test]
fn test_parse_policy_with_check() {
    let source = r#"
        policy InsertOwn on Post {
            for INSERT
            to authenticated
            check "author_id = current_user_id()"
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    let rule = parsed.get_policy("InsertOwn").unwrap();
    assert!(rule.applies_to(PolicyCommand::Insert));

    let check_clause = rule.check_expr.as_deref();
    assert_eq!(check_clause, Some("author_id = current_user_id()"));
    assert!(rule.using_expr.is_none());
}
2201
/// A policy may carry both a USING and a CHECK expression at once.
#[test]
fn test_parse_policy_with_both_expressions() {
    let source = r#"
        policy UpdateOwn on Post {
            for UPDATE
            using "author_id = current_user_id()"
            check "author_id = current_user_id()"
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    let rule = parsed.get_policy("UpdateOwn").unwrap();
    assert!(rule.using_expr.is_some());
    assert!(rule.check_expr.is_some());
}
2219
/// A triple-quoted USING expression spanning several lines keeps each of its clauses.
#[test]
fn test_parse_policy_multiline_expression() {
    let source = r#"
        policy ComplexCheck on Document {
            for SELECT
            using """
                (is_public = true)
                OR (owner_id = current_user_id())
                OR (id IN (SELECT document_id FROM shares WHERE user_id = current_user_id()))
            """
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    let rule = parsed.get_policy("ComplexCheck").unwrap();
    assert!(rule.using_expr.is_some());
    let using_clause = rule.using_expr.as_ref().unwrap();
    assert!(using_clause.contains("is_public = true"));
    assert!(using_clause.contains("owner_id = current_user_id()"));
    assert!(using_clause.contains("SELECT document_id FROM shares"));
}
2243
/// Three policies in one schema are all registered and retrievable by name.
#[test]
fn test_parse_multiple_policies() {
    let source = r#"
        policy UserRead on User {
            for SELECT
            using "true"
        }

        policy UserInsert on User {
            for INSERT
            check "id = current_user_id()"
        }

        policy PostRead on Post {
            for SELECT
            using "published = true OR author_id = current_user_id()"
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    assert_eq!(parsed.policies.len(), 3);
    for policy_name in ["UserRead", "UserInsert", "PostRead"] {
        assert!(parsed.get_policy(policy_name).is_some());
    }
}
2271
/// A policy declared next to its model is returned by `policies_for` for that model.
#[test]
fn test_parse_policy_with_model() {
    let source = r#"
        model User {
            id Int @id @auto
            email String @unique
        }

        policy UserReadOwn on User {
            for SELECT
            to authenticated
            using "id = auth.uid()"
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    assert_eq!(parsed.models.len(), 1);
    assert_eq!(parsed.policies.len(), 1);

    let user_policies = parsed.policies_for("User");
    assert_eq!(user_policies.len(), 1);
    assert_eq!(user_policies[0].name(), "UserReadOwn");
}
2297
/// `policies_for` / `has_policies` group policies by the table named after `on`.
#[test]
fn test_parse_policies_for_multiple_models() {
    let source = r#"
        policy UserPolicy1 on User {
            for SELECT
            using "true"
        }

        policy UserPolicy2 on User {
            for INSERT
            check "true"
        }

        policy PostPolicy on Post {
            for SELECT
            using "true"
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    assert_eq!(parsed.policies_for("User").len(), 2);
    assert_eq!(parsed.policies_for("Post").len(), 1);

    assert!(parsed.has_policies("User"));
    assert!(parsed.has_policies("Post"));
    // No policy targets `Comment`.
    assert!(!parsed.has_policies("Comment"));
}
2326
/// A policy without a `for` clause applies to `PolicyCommand::All`.
#[test]
fn test_parse_policy_default_all_command() {
    let source = r#"
        policy DefaultAll on User {
            using "id = current_user_id()"
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    let rule = parsed.get_policy("DefaultAll").unwrap();
    assert!(rule.applies_to(PolicyCommand::All));
}
2342
/// Lowercase `select` and `permissive` are accepted like their uppercase forms.
#[test]
fn test_parse_policy_case_insensitive_keywords() {
    let source = r#"
        policy CaseTest on User {
            for select
            as permissive
            using "true"
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    let rule = parsed.get_policy("CaseTest").unwrap();
    assert!(rule.applies_to(PolicyCommand::Select));
    assert!(rule.is_permissive());
}
2360
/// `to_sql("users")` on a parsed policy emits the expected CREATE POLICY fragments.
#[test]
fn test_parse_policy_sql_generation() {
    let source = r#"
        model User {
            id Int @id

            @@map("users")
        }

        policy ReadOwn on User {
            for SELECT
            to authenticated
            using "id = auth.uid()"
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    let rule = parsed.get_policy("ReadOwn").unwrap();
    let statement = rule.to_sql("users");

    assert!(statement.contains("CREATE POLICY ReadOwn ON users"));
    assert!(statement.contains("FOR SELECT"));
    assert!(statement.contains("TO authenticated"));
    assert!(statement.contains("USING (id = auth.uid())"));
}
2388
/// SQL generated for a restrictive policy contains `AS RESTRICTIVE`.
#[test]
fn test_parse_policy_restrictive_sql() {
    let source = r#"
        policy OrgBoundary on Document {
            as RESTRICTIVE
            for ALL
            using "org_id = current_org_id()"
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    let rule = parsed.get_policy("OrgBoundary").unwrap();
    let statement = rule.to_sql("documents");

    assert!(statement.contains("AS RESTRICTIVE"));
}
2407
/// A policy preceded by a `///` comment still parses; the documentation text is
/// only checked when the parser attached one to the policy.
#[test]
fn test_parse_policy_with_documentation() {
    let source = r#"
        /// Users can only read their own data
        policy UserIsolation on User {
            for SELECT
            using "id = current_user_id()"
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    let rule = parsed.get_policy("UserIsolation").unwrap();
    // Lenient on purpose: a missing documentation entry does not fail the test.
    if let Some(documentation) = &rule.documentation {
        assert!(documentation.text.contains("their own data"));
    }
}
2426
/// Three models plus five documented policies on `Document`: checks the totals,
/// the restrictive flag of `OrgIsolation`, and the per-table policy count.
#[test]
fn test_parse_complex_rls_schema() {
    let source = r#"
        model Organization {
            id Int @id @auto
            name String
        }

        model User {
            id Int @id @auto
            orgId Int
            email String @unique
        }

        model Document {
            id Int @id @auto
            title String
            ownerId Int
            orgId Int
            isPublic Boolean @default(false)
        }

        /// Organization-level isolation
        policy OrgIsolation on Document {
            as RESTRICTIVE
            for ALL
            using "org_id = current_setting('app.current_org')::int"
        }

        /// Users can read public documents
        policy PublicRead on Document {
            for SELECT
            using "is_public = true"
        }

        /// Users can read their own documents
        policy OwnerRead on Document {
            for SELECT
            to authenticated
            using "owner_id = auth.uid()"
        }

        /// Users can only modify their own documents
        policy OwnerModify on Document {
            for [UPDATE, DELETE]
            to authenticated
            using "owner_id = auth.uid()"
            check "owner_id = auth.uid()"
        }

        /// Users can create documents in their org
        policy OrgInsert on Document {
            for INSERT
            to authenticated
            check "org_id = current_setting('app.current_org')::int"
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    assert_eq!(parsed.models.len(), 3);
    assert_eq!(parsed.policies.len(), 5);

    // The organization boundary is the only restrictive policy declared.
    let org_isolation = parsed.get_policy("OrgIsolation").unwrap();
    assert!(org_isolation.is_restrictive());

    // Every policy in this schema targets `Document`.
    let document_policies = parsed.policies_for("Document");
    assert_eq!(document_policies.len(), 5);
}
2499
/// An explicit `mssqlSchema "..."` clause is returned by `mssql_schema()`.
#[test]
fn test_parse_policy_with_mssql_schema() {
    let source = r#"
        policy UserFilter on User {
            for SELECT
            using "UserId = @UserId"
            mssqlSchema "RLS"
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    let rule = parsed.get_policy("UserFilter").unwrap();
    assert_eq!(rule.mssql_schema(), "RLS");
}
2518
/// A single `mssqlBlock AFTER_INSERT` yields exactly one `AfterInsert` block operation.
#[test]
fn test_parse_policy_with_mssql_block_single() {
    let source = r#"
        policy UserInsert on User {
            for INSERT
            check "UserId = @UserId"
            mssqlBlock AFTER_INSERT
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    let rule = parsed.get_policy("UserInsert").unwrap();
    assert_eq!(rule.mssql_block_operations.len(), 1);

    let first_operation = rule.mssql_block_operations[0];
    assert_eq!(first_operation, MssqlBlockOperation::AfterInsert);
}
2539
/// A bracketed `mssqlBlock [...]` list yields one block operation per entry.
#[test]
fn test_parse_policy_with_mssql_block_list() {
    let source = r#"
        policy UserModify on User {
            for [INSERT, UPDATE, DELETE]
            check "UserId = @UserId"
            mssqlBlock [AFTER_INSERT, AFTER_UPDATE, BEFORE_DELETE]
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    let rule = parsed.get_policy("UserModify").unwrap();
    assert_eq!(rule.mssql_block_operations.len(), 3);

    let expected_operations = [
        MssqlBlockOperation::AfterInsert,
        MssqlBlockOperation::AfterUpdate,
        MssqlBlockOperation::BeforeDelete,
    ];
    for operation in expected_operations {
        assert!(rule.mssql_block_operations.contains(&operation));
    }
}
2571
/// A policy with every MSSQL-related clause set: checks the parsed values and
/// that the generated MSSQL SQL mentions the schema and the predicate function.
#[test]
fn test_parse_policy_full_mssql_config() {
    let source = r#"
        policy TenantIsolation on Order {
            for ALL
            using "TenantId = @TenantId"
            check "TenantId = @TenantId"
            mssqlSchema "MultiTenant"
            mssqlBlock [AFTER_INSERT, BEFORE_UPDATE, AFTER_UPDATE, BEFORE_DELETE]
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    let rule = parsed.get_policy("TenantIsolation").unwrap();

    // Settings common to every backend.
    assert!(rule.applies_to(PolicyCommand::All));
    assert!(rule.using_expr.is_some());
    assert!(rule.check_expr.is_some());

    // MSSQL-specific settings.
    assert_eq!(rule.mssql_schema(), "MultiTenant");
    assert_eq!(rule.mssql_block_operations.len(), 4);

    // MSSQL SQL generation.
    let generated = rule.to_mssql_sql("dbo.Orders", "TenantId");
    assert!(generated.schema_sql.contains("MultiTenant"));
    assert!(generated.function_sql.contains("fn_TenantIsolation_predicate"));
}
2603
/// A lowercase `after_insert` block operation is parsed as `AfterInsert`.
#[test]
fn test_parse_policy_mssql_block_case_variants() {
    let source = r#"
        policy Test1 on User {
            for INSERT
            check "true"
            mssqlBlock after_insert
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    let rule = parsed.get_policy("Test1").unwrap();
    assert_eq!(rule.mssql_block_operations.len(), 1);

    let first_operation = rule.mssql_block_operations[0];
    assert_eq!(first_operation, MssqlBlockOperation::AfterInsert);
}
2625
/// One schema mixing a policy without MSSQL clauses and a policy with explicit
/// `mssqlSchema` / `mssqlBlock` clauses; both must generate SQL for their backend.
#[test]
fn test_parse_mixed_postgres_mssql_schema() {
    let source = r#"
        model User {
            id Int @id @auto
            email String @unique
        }

        // PostgreSQL-style policy (works on both, MSSQL uses defaults)
        policy UserReadOwn on User {
            for SELECT
            to authenticated
            using "id = current_user_id()"
        }

        // MSSQL-optimized policy with explicit settings
        policy UserModifyOwn on User {
            for [INSERT, UPDATE, DELETE]
            to authenticated
            using "id = current_user_id()"
            check "id = current_user_id()"
            mssqlSchema "Security"
            mssqlBlock [AFTER_INSERT, BEFORE_UPDATE, AFTER_UPDATE, BEFORE_DELETE]
        }
    "#;
    let parsed = parse_schema(source).unwrap();

    assert_eq!(parsed.policies.len(), 2);

    // The read policy declares no MSSQL clauses, yet reports the "Security"
    // schema and an empty block-operation list.
    let reader = parsed.get_policy("UserReadOwn").unwrap();
    assert_eq!(reader.mssql_schema(), "Security");
    assert!(reader.mssql_block_operations.is_empty());

    // The modify policy declares both MSSQL clauses explicitly.
    let modifier = parsed.get_policy("UserModifyOwn").unwrap();
    assert_eq!(modifier.mssql_schema(), "Security");
    assert_eq!(modifier.mssql_block_operations.len(), 4);

    // PostgreSQL output for the read policy.
    let postgres_sql = reader.to_postgres_sql("users");
    assert!(postgres_sql.contains("CREATE POLICY UserReadOwn ON users"));

    // MSSQL output for the modify policy.
    let mssql_output = modifier.to_mssql_sql("dbo.Users", "id");
    assert!(mssql_output.policy_sql.contains("Security.UserModifyOwn"));
}
2675}