1mod grammar;
4
5use std::path::Path;
6
7use pest::Parser;
8use smol_str::SmolStr;
9use tracing::{debug, info};
10
11use crate::ast::*;
12use crate::error::{SchemaError, SchemaResult};
13
14pub use grammar::{PraxParser, Rule};
15
16use crate::ast::{
17 MssqlBlockOperation, Policy, PolicyCommand, PolicyType, Server, ServerGroup, ServerProperty,
18 ServerPropertyValue,
19};
20
21pub fn parse_schema(input: &str) -> SchemaResult<Schema> {
23 debug!(input_len = input.len(), "parse_schema() starting");
24 let pairs = PraxParser::parse(Rule::schema, input)
25 .map_err(|e| SchemaError::syntax(input.to_string(), 0, input.len(), e.to_string()))?;
26
27 let mut schema = Schema::new();
28 let mut current_doc: Option<Documentation> = None;
29
30 let schema_pair = pairs.into_iter().next().unwrap();
32
33 for pair in schema_pair.into_inner() {
34 match pair.as_rule() {
35 Rule::documentation => {
36 let span = pair.as_span();
37 let text = pair
38 .into_inner()
39 .map(|p| p.as_str().trim_start_matches("///").trim())
40 .collect::<Vec<_>>()
41 .join("\n");
42 current_doc = Some(Documentation::new(
43 text,
44 Span::new(span.start(), span.end()),
45 ));
46 }
47 Rule::model_def => {
48 let mut model = parse_model(pair)?;
49 if let Some(doc) = current_doc.take() {
50 model = model.with_documentation(doc);
51 }
52 schema.add_model(model);
53 }
54 Rule::enum_def => {
55 let mut e = parse_enum(pair)?;
56 if let Some(doc) = current_doc.take() {
57 e = e.with_documentation(doc);
58 }
59 schema.add_enum(e);
60 }
61 Rule::type_def => {
62 let mut t = parse_composite_type(pair)?;
63 if let Some(doc) = current_doc.take() {
64 t = t.with_documentation(doc);
65 }
66 schema.add_type(t);
67 }
68 Rule::view_def => {
69 let mut v = parse_view(pair)?;
70 if let Some(doc) = current_doc.take() {
71 v = v.with_documentation(doc);
72 }
73 schema.add_view(v);
74 }
75 Rule::raw_sql_def => {
76 let sql = parse_raw_sql(pair)?;
77 schema.add_raw_sql(sql);
78 }
79 Rule::server_group_def => {
80 let mut sg = parse_server_group(pair)?;
81 if let Some(doc) = current_doc.take() {
82 sg.set_documentation(doc);
83 }
84 schema.add_server_group(sg);
85 }
86 Rule::policy_def => {
87 let mut policy = parse_policy(pair)?;
88 if let Some(doc) = current_doc.take() {
89 policy = policy.with_documentation(doc);
90 }
91 schema.add_policy(policy);
92 }
93 Rule::datasource_def => {
94 let ds = parse_datasource(pair)?;
95 schema.set_datasource(ds);
96 current_doc = None;
97 }
98 Rule::generator_def => {
99 let generator = parse_generator(pair)?;
100 schema.add_generator(generator);
101 current_doc = None;
102 }
103 Rule::EOI => {}
104 _ => {}
105 }
106 }
107
108 info!(
109 models = schema.models.len(),
110 enums = schema.enums.len(),
111 types = schema.types.len(),
112 views = schema.views.len(),
113 generators = schema.generators.len(),
114 policies = schema.policies.len(),
115 "Schema parsed successfully"
116 );
117 Ok(schema)
118}
119
120pub fn parse_schema_file(path: impl AsRef<Path>) -> SchemaResult<Schema> {
122 let path = path.as_ref();
123 info!(path = %path.display(), "Loading schema file");
124 let content = std::fs::read_to_string(path).map_err(|e| SchemaError::IoError {
125 path: path.display().to_string(),
126 source: e,
127 })?;
128
129 parse_schema(&content)
130}
131
132fn parse_model(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Model> {
134 let span = pair.as_span();
135 let mut inner = pair.into_inner();
136
137 let name_pair = inner.next().unwrap();
138 let name = Ident::new(
139 name_pair.as_str(),
140 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
141 );
142
143 let mut model = Model::new(name, Span::new(span.start(), span.end()));
144
145 for item in inner {
146 match item.as_rule() {
147 Rule::field_def => {
148 let field = parse_field(item)?;
149 model.add_field(field);
150 }
151 Rule::model_attribute => {
152 let attr = parse_attribute(item)?;
153 model.attributes.push(attr);
154 }
155 Rule::model_body_item => {
156 let inner_item = item.into_inner().next().unwrap();
158 match inner_item.as_rule() {
159 Rule::field_def => {
160 let field = parse_field(inner_item)?;
161 model.add_field(field);
162 }
163 Rule::model_attribute => {
164 let attr = parse_attribute(inner_item)?;
165 model.attributes.push(attr);
166 }
167 _ => {}
168 }
169 }
170 _ => {}
171 }
172 }
173
174 Ok(model)
175}
176
177fn parse_enum(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Enum> {
179 let span = pair.as_span();
180 let mut inner = pair.into_inner();
181
182 let name_pair = inner.next().unwrap();
183 let name = Ident::new(
184 name_pair.as_str(),
185 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
186 );
187
188 let mut e = Enum::new(name, Span::new(span.start(), span.end()));
189
190 for item in inner {
191 match item.as_rule() {
192 Rule::enum_variant => {
193 let variant = parse_enum_variant(item)?;
194 e.add_variant(variant);
195 }
196 Rule::model_attribute => {
197 let attr = parse_attribute(item)?;
198 e.attributes.push(attr);
199 }
200 Rule::enum_body_item => {
201 let inner_item = item.into_inner().next().unwrap();
203 match inner_item.as_rule() {
204 Rule::enum_variant => {
205 let variant = parse_enum_variant(inner_item)?;
206 e.add_variant(variant);
207 }
208 Rule::model_attribute => {
209 let attr = parse_attribute(inner_item)?;
210 e.attributes.push(attr);
211 }
212 _ => {}
213 }
214 }
215 _ => {}
216 }
217 }
218
219 Ok(e)
220}
221
222fn parse_enum_variant(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<EnumVariant> {
224 let span = pair.as_span();
225 let mut inner = pair.into_inner();
226
227 let name_pair = inner.next().unwrap();
228 let name = Ident::new(
229 name_pair.as_str(),
230 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
231 );
232
233 let mut variant = EnumVariant::new(name, Span::new(span.start(), span.end()));
234
235 for item in inner {
236 if item.as_rule() == Rule::field_attribute {
237 let attr = parse_attribute(item)?;
238 variant.attributes.push(attr);
239 }
240 }
241
242 Ok(variant)
243}
244
245fn parse_composite_type(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<CompositeType> {
247 let span = pair.as_span();
248 let mut inner = pair.into_inner();
249
250 let name_pair = inner.next().unwrap();
251 let name = Ident::new(
252 name_pair.as_str(),
253 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
254 );
255
256 let mut t = CompositeType::new(name, Span::new(span.start(), span.end()));
257
258 for item in inner {
259 if item.as_rule() == Rule::field_def {
260 let field = parse_field(item)?;
261 t.add_field(field);
262 }
263 }
264
265 Ok(t)
266}
267
268fn parse_view(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<View> {
270 let span = pair.as_span();
271 let mut inner = pair.into_inner();
272
273 let name_pair = inner.next().unwrap();
274 let name = Ident::new(
275 name_pair.as_str(),
276 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
277 );
278
279 let mut v = View::new(name, Span::new(span.start(), span.end()));
280
281 for item in inner {
282 match item.as_rule() {
283 Rule::field_def => {
284 let field = parse_field(item)?;
285 v.add_field(field);
286 }
287 Rule::model_attribute => {
288 let attr = parse_attribute(item)?;
289 v.attributes.push(attr);
290 }
291 Rule::model_body_item => {
292 let inner_item = item.into_inner().next().unwrap();
294 match inner_item.as_rule() {
295 Rule::field_def => {
296 let field = parse_field(inner_item)?;
297 v.add_field(field);
298 }
299 Rule::model_attribute => {
300 let attr = parse_attribute(inner_item)?;
301 v.attributes.push(attr);
302 }
303 _ => {}
304 }
305 }
306 _ => {}
307 }
308 }
309
310 Ok(v)
311}
312
313fn parse_field(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Field> {
315 let span = pair.as_span();
316 let mut inner = pair.into_inner();
317
318 let name_pair = inner.next().unwrap();
319 let name = Ident::new(
320 name_pair.as_str(),
321 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
322 );
323
324 let type_pair = inner.next().unwrap();
325 let (field_type, modifier) = parse_field_type(type_pair)?;
326
327 let mut attributes = vec![];
328 for item in inner {
329 if item.as_rule() == Rule::field_attribute {
330 let attr = parse_attribute(item)?;
331 attributes.push(attr);
332 }
333 }
334
335 Ok(Field::new(
336 name,
337 field_type,
338 modifier,
339 attributes,
340 Span::new(span.start(), span.end()),
341 ))
342}
343
344fn parse_field_type(
346 pair: pest::iterators::Pair<'_, Rule>,
347) -> SchemaResult<(FieldType, TypeModifier)> {
348 let mut type_name = String::new();
349 let mut modifier = TypeModifier::Required;
350
351 for item in pair.into_inner() {
352 match item.as_rule() {
353 Rule::type_name => {
354 type_name = item.as_str().to_string();
355 }
356 Rule::optional_marker => {
357 modifier = if modifier == TypeModifier::List {
358 TypeModifier::OptionalList
359 } else {
360 TypeModifier::Optional
361 };
362 }
363 Rule::list_marker => {
364 modifier = if modifier == TypeModifier::Optional {
365 TypeModifier::OptionalList
366 } else {
367 TypeModifier::List
368 };
369 }
370 _ => {}
371 }
372 }
373
374 let field_type = if let Some(scalar) = ScalarType::from_str(&type_name) {
375 FieldType::Scalar(scalar)
376 } else {
377 FieldType::Model(SmolStr::new(&type_name))
380 };
381
382 Ok((field_type, modifier))
383}
384
385fn parse_attribute(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Attribute> {
387 let span = pair.as_span();
388 let mut inner = pair.into_inner();
389
390 let name_pair = inner.next().unwrap();
391 let name = Ident::new(
392 name_pair.as_str(),
393 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
394 );
395
396 let mut args = vec![];
397 for item in inner {
398 if item.as_rule() == Rule::attribute_args {
399 args = parse_attribute_args(item)?;
400 }
401 }
402
403 Ok(Attribute::new(
404 name,
405 args,
406 Span::new(span.start(), span.end()),
407 ))
408}
409
410fn parse_attribute_args(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Vec<AttributeArg>> {
412 let mut args = vec![];
413
414 for item in pair.into_inner() {
415 if item.as_rule() == Rule::attribute_arg {
416 let arg = parse_attribute_arg(item)?;
417 args.push(arg);
418 }
419 }
420
421 Ok(args)
422}
423
424fn parse_attribute_arg(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<AttributeArg> {
426 let span = pair.as_span();
427 let mut inner = pair.into_inner();
428
429 let first = inner.next().unwrap();
430
431 if let Some(second) = inner.next() {
433 let name = Ident::new(
435 first.as_str(),
436 Span::new(first.as_span().start(), first.as_span().end()),
437 );
438 let value = parse_attribute_value(second)?;
439 Ok(AttributeArg::named(
440 name,
441 value,
442 Span::new(span.start(), span.end()),
443 ))
444 } else {
445 let value = parse_attribute_value(first)?;
447 Ok(AttributeArg::positional(
448 value,
449 Span::new(span.start(), span.end()),
450 ))
451 }
452}
453
454fn parse_attribute_value(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<AttributeValue> {
456 match pair.as_rule() {
457 Rule::string_literal => {
458 let s = pair.as_str();
459 let unquoted = &s[1..s.len() - 1];
461 Ok(AttributeValue::String(unquoted.to_string()))
462 }
463 Rule::number_literal => {
464 let s = pair.as_str();
465 if s.contains('.') {
466 Ok(AttributeValue::Float(s.parse().unwrap()))
467 } else {
468 Ok(AttributeValue::Int(s.parse().unwrap()))
469 }
470 }
471 Rule::boolean_literal => Ok(AttributeValue::Boolean(pair.as_str() == "true")),
472 Rule::identifier => Ok(AttributeValue::Ident(SmolStr::new(pair.as_str()))),
473 Rule::dotted_identifier => {
474 Ok(AttributeValue::String(pair.as_str().to_string()))
476 }
477 Rule::function_call => {
478 let mut inner = pair.into_inner();
479 let name = SmolStr::new(inner.next().unwrap().as_str());
480 let mut args = vec![];
481 for item in inner {
482 args.push(parse_attribute_value(item)?);
483 }
484 Ok(AttributeValue::Function(name, args))
485 }
486 Rule::field_ref_list => {
487 let refs: Vec<SmolStr> = pair
488 .into_inner()
489 .map(|p| SmolStr::new(p.as_str()))
490 .collect();
491 Ok(AttributeValue::FieldRefList(refs))
492 }
493 Rule::array_literal => {
494 let values: Result<Vec<_>, _> = pair.into_inner().map(parse_attribute_value).collect();
495 Ok(AttributeValue::Array(values?))
496 }
497 Rule::attribute_value => {
498 parse_attribute_value(pair.into_inner().next().unwrap())
500 }
501 _ => {
502 Ok(AttributeValue::Ident(SmolStr::new(pair.as_str())))
504 }
505 }
506}
507
508fn parse_raw_sql(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<RawSql> {
510 let mut inner = pair.into_inner();
511
512 let name = inner.next().unwrap().as_str();
513 let sql = inner.next().unwrap().as_str();
514
515 let sql_content = sql
517 .trim_start_matches("\"\"\"")
518 .trim_end_matches("\"\"\"")
519 .trim();
520
521 Ok(RawSql::new(name, sql_content))
522}
523
524fn parse_server_group(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<ServerGroup> {
526 let span = pair.as_span();
527 let mut inner = pair.into_inner();
528
529 let name_pair = inner.next().unwrap();
530 let name = Ident::new(
531 name_pair.as_str(),
532 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
533 );
534
535 let mut server_group = ServerGroup::new(name, Span::new(span.start(), span.end()));
536
537 for item in inner {
538 match item.as_rule() {
539 Rule::server_group_item => {
540 let inner_item = item.into_inner().next().unwrap();
542 match inner_item.as_rule() {
543 Rule::server_def => {
544 let server = parse_server(inner_item)?;
545 server_group.add_server(server);
546 }
547 Rule::model_attribute => {
548 let attr = parse_attribute(inner_item)?;
549 server_group.add_attribute(attr);
550 }
551 _ => {}
552 }
553 }
554 Rule::server_def => {
555 let server = parse_server(item)?;
556 server_group.add_server(server);
557 }
558 Rule::model_attribute => {
559 let attr = parse_attribute(item)?;
560 server_group.add_attribute(attr);
561 }
562 _ => {}
563 }
564 }
565
566 Ok(server_group)
567}
568
569fn parse_server(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Server> {
571 let span = pair.as_span();
572 let mut inner = pair.into_inner();
573
574 let name_pair = inner.next().unwrap();
575 let name = Ident::new(
576 name_pair.as_str(),
577 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
578 );
579
580 let mut server = Server::new(name, Span::new(span.start(), span.end()));
581
582 for item in inner {
583 if item.as_rule() == Rule::server_property {
584 let prop = parse_server_property(item)?;
585 server.add_property(prop);
586 }
587 }
588
589 Ok(server)
590}
591
592fn parse_server_property(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<ServerProperty> {
594 let span = pair.as_span();
595 let mut inner = pair.into_inner();
596
597 let key_pair = inner.next().unwrap();
598 let key = key_pair.as_str();
599
600 let value_pair = inner.next().unwrap();
601 let value = parse_server_property_value(value_pair)?;
602
603 Ok(ServerProperty::new(
604 key,
605 value,
606 Span::new(span.start(), span.end()),
607 ))
608}
609
610fn parse_generator(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Generator> {
612 let span = pair.as_span();
613 let mut inner = pair.into_inner();
614
615 let name = inner.next().unwrap().as_str();
616 let mut generator = Generator::new(name, Span::new(span.start(), span.end()));
617
618 for prop in inner {
619 if prop.as_rule() == Rule::datasource_property {
620 let mut prop_inner = prop.into_inner();
621 let key = prop_inner.next().unwrap().as_str();
622 let value_pair = prop_inner.next().unwrap();
623
624 match key {
625 "provider" => {
626 let s = extract_datasource_string(&value_pair);
627 generator.provider = Some(SmolStr::new(s));
628 }
629 "output" => {
630 let s = extract_datasource_string(&value_pair);
631 generator.output = Some(SmolStr::new(s));
632 }
633 "generate" => {
634 generator.generate = parse_generator_toggle(&value_pair);
635 }
636 _ => {
637 let val = parse_generator_value(&value_pair);
638 generator.properties.insert(SmolStr::new(key), val);
639 }
640 }
641 }
642 }
643
644 Ok(generator)
645}
646
647fn parse_generator_toggle(pair: &pest::iterators::Pair<'_, Rule>) -> GeneratorToggle {
649 match pair.as_rule() {
650 Rule::env_function => {
651 let env_var = pair
652 .clone()
653 .into_inner()
654 .next()
655 .map(|p| {
656 let s = p.as_str();
657 SmolStr::new(&s[1..s.len() - 1])
658 })
659 .unwrap_or_default();
660 GeneratorToggle::Env(env_var)
661 }
662 Rule::datasource_value => {
663 let inner = pair.clone().into_inner().next().unwrap();
664 parse_generator_toggle(&inner)
665 }
666 _ => {
667 let s = pair.as_str().trim().trim_matches('"');
668 match s {
669 "true" => GeneratorToggle::Literal(true),
670 "false" => GeneratorToggle::Literal(false),
671 _ => GeneratorToggle::Literal(false),
672 }
673 }
674 }
675}
676
677fn parse_generator_value(pair: &pest::iterators::Pair<'_, Rule>) -> GeneratorValue {
679 match pair.as_rule() {
680 Rule::env_function => {
681 let env_var = pair
682 .clone()
683 .into_inner()
684 .next()
685 .map(|p| {
686 let s = p.as_str();
687 SmolStr::new(&s[1..s.len() - 1])
688 })
689 .unwrap_or_default();
690 GeneratorValue::Env(env_var)
691 }
692 Rule::datasource_value => {
693 let inner = pair.clone().into_inner().next().unwrap();
694 parse_generator_value(&inner)
695 }
696 Rule::string_literal => {
697 let s = pair.as_str();
698 GeneratorValue::String(SmolStr::new(&s[1..s.len() - 1]))
699 }
700 _ => {
701 let s = pair.as_str().trim().trim_matches('"');
702 match s {
703 "true" => GeneratorValue::Bool(true),
704 "false" => GeneratorValue::Bool(false),
705 _ => GeneratorValue::Ident(SmolStr::new(s)),
706 }
707 }
708 }
709}
710
711fn parse_datasource(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Datasource> {
713 let span = pair.as_span();
714 let mut inner = pair.into_inner();
715
716 let name_pair = inner.next().unwrap();
717 let name = name_pair.as_str();
718
719 let mut datasource = Datasource::new(
720 name,
721 DatabaseProvider::PostgreSQL,
722 Span::new(span.start(), span.end()),
723 );
724
725 for prop in inner {
726 if prop.as_rule() == Rule::datasource_property {
727 let mut prop_inner = prop.into_inner();
728 let key = prop_inner.next().unwrap().as_str();
729 let value_pair = prop_inner.next().unwrap();
730
731 match key {
732 "provider" => {
733 let provider_str = extract_datasource_string(&value_pair);
734 if let Some(provider) = DatabaseProvider::from_str(&provider_str) {
735 datasource.provider = provider;
736 }
737 }
738 "url" => {
739 match value_pair.as_rule() {
740 Rule::env_function => {
741 let env_var = value_pair
743 .into_inner()
744 .next()
745 .map(|p| {
746 let s = p.as_str();
747 s[1..s.len() - 1].to_string()
748 })
749 .unwrap_or_default();
750 datasource.url_env = Some(SmolStr::new(env_var));
751 }
752 Rule::string_literal => {
753 let s = value_pair.as_str();
754 let url = &s[1..s.len() - 1];
755 datasource.url = Some(SmolStr::new(url));
756 }
757 _ => {}
758 }
759 }
760 "extensions" => {
761 if value_pair.as_rule() == Rule::extension_array {
762 for ext_item in value_pair.into_inner() {
763 if ext_item.as_rule() == Rule::extension_item {
764 let ext = parse_extension_item(
765 ext_item,
766 Span::new(span.start(), span.end()),
767 )?;
768 datasource.add_extension(ext);
769 }
770 }
771 }
772 }
773 _ => {
774 let value_str = extract_datasource_string(&value_pair);
776 datasource.add_property(key, value_str);
777 }
778 }
779 }
780 }
781
782 Ok(datasource)
783}
784
785fn parse_extension_item(
787 pair: pest::iterators::Pair<'_, Rule>,
788 span: Span,
789) -> SchemaResult<PostgresExtension> {
790 let mut inner = pair.into_inner();
791 let name = inner.next().unwrap().as_str();
792 let mut ext = PostgresExtension::new(name, span);
793
794 if let Some(args_pair) = inner.next()
796 && args_pair.as_rule() == Rule::extension_args
797 {
798 for arg in args_pair.into_inner() {
799 if arg.as_rule() == Rule::extension_arg {
800 let mut arg_inner = arg.into_inner();
801 let arg_key = arg_inner.next().unwrap().as_str();
802 let arg_value_pair = arg_inner.next().unwrap();
803 let arg_value = {
804 let s = arg_value_pair.as_str();
805 &s[1..s.len() - 1]
806 };
807
808 match arg_key {
809 "schema" => {
810 ext = ext.with_schema(arg_value);
811 }
812 "version" => {
813 ext = ext.with_version(arg_value);
814 }
815 _ => {}
816 }
817 }
818 }
819 }
820
821 Ok(ext)
822}
823
824fn extract_datasource_string(pair: &pest::iterators::Pair<'_, Rule>) -> String {
826 match pair.as_rule() {
827 Rule::string_literal => {
828 let s = pair.as_str();
829 s[1..s.len() - 1].to_string()
830 }
831 Rule::identifier => pair.as_str().to_string(),
832 Rule::datasource_value => {
833 if let Some(inner) = pair.clone().into_inner().next() {
834 extract_datasource_string(&inner)
835 } else {
836 pair.as_str().to_string()
837 }
838 }
839 _ => pair.as_str().to_string(),
840 }
841}
842
843fn extract_string_from_arg(pair: pest::iterators::Pair<'_, Rule>) -> String {
845 match pair.as_rule() {
846 Rule::string_literal => {
847 let s = pair.as_str();
848 s[1..s.len() - 1].to_string()
849 }
850 Rule::attribute_value => {
851 if let Some(inner) = pair.into_inner().next() {
853 extract_string_from_arg(inner)
854 } else {
855 String::new()
856 }
857 }
858 _ => pair.as_str().to_string(),
859 }
860}
861
862fn parse_server_property_value(
864 pair: pest::iterators::Pair<'_, Rule>,
865) -> SchemaResult<ServerPropertyValue> {
866 match pair.as_rule() {
867 Rule::string_literal => {
868 let s = pair.as_str();
869 let unquoted = &s[1..s.len() - 1];
871 Ok(ServerPropertyValue::String(unquoted.to_string()))
872 }
873 Rule::number_literal => {
874 let s = pair.as_str();
875 Ok(ServerPropertyValue::Number(s.parse().unwrap_or(0.0)))
876 }
877 Rule::boolean_literal => Ok(ServerPropertyValue::Boolean(pair.as_str() == "true")),
878 Rule::identifier => Ok(ServerPropertyValue::Identifier(pair.as_str().to_string())),
879 Rule::function_call => {
880 let mut inner = pair.into_inner();
882 let func_name = inner.next().unwrap().as_str();
883 if func_name == "env"
884 && let Some(arg) = inner.next()
885 {
886 let var_name = extract_string_from_arg(arg);
887 return Ok(ServerPropertyValue::EnvVar(var_name));
888 }
889 Ok(ServerPropertyValue::Identifier(func_name.to_string()))
891 }
892 Rule::array_literal => {
893 let values: Result<Vec<_>, _> =
894 pair.into_inner().map(parse_server_property_value).collect();
895 Ok(ServerPropertyValue::Array(values?))
896 }
897 Rule::attribute_value => {
898 parse_server_property_value(pair.into_inner().next().unwrap())
900 }
901 _ => {
902 Ok(ServerPropertyValue::Identifier(pair.as_str().to_string()))
904 }
905 }
906}
907
908fn parse_policy(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Policy> {
910 let span = pair.as_span();
911 let mut inner = pair.into_inner();
912
913 let name_pair = inner.next().unwrap();
915 let name = Ident::new(
916 name_pair.as_str(),
917 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
918 );
919
920 let table_pair = inner.next().unwrap();
922 let table = Ident::new(
923 table_pair.as_str(),
924 Span::new(table_pair.as_span().start(), table_pair.as_span().end()),
925 );
926
927 let mut policy = Policy::new(name, table, Span::new(span.start(), span.end()));
928 policy.commands = vec![];
930
931 for item in inner {
932 match item.as_rule() {
933 Rule::policy_item => {
934 let inner_item = item.into_inner().next().unwrap();
935 parse_policy_item(&mut policy, inner_item)?;
936 }
937 Rule::policy_for
938 | Rule::policy_to
939 | Rule::policy_as
940 | Rule::policy_using
941 | Rule::policy_check => {
942 parse_policy_item(&mut policy, item)?;
943 }
944 _ => {}
945 }
946 }
947
948 if policy.commands.is_empty() {
950 policy.commands.push(PolicyCommand::All);
951 }
952
953 Ok(policy)
954}
955
956fn parse_policy_item(
958 policy: &mut Policy,
959 pair: pest::iterators::Pair<'_, Rule>,
960) -> SchemaResult<()> {
961 match pair.as_rule() {
962 Rule::policy_for => {
963 let inner = pair.into_inner().next().unwrap();
964 match inner.as_rule() {
965 Rule::policy_command => {
966 if let Some(cmd) = PolicyCommand::from_str(inner.as_str()) {
967 policy.add_command(cmd);
968 }
969 }
970 Rule::policy_command_list => {
971 for cmd_pair in inner.into_inner() {
972 if cmd_pair.as_rule() == Rule::policy_command
973 && let Some(cmd) = PolicyCommand::from_str(cmd_pair.as_str())
974 {
975 policy.add_command(cmd);
976 }
977 }
978 }
979 _ => {}
980 }
981 }
982 Rule::policy_to => {
983 let inner = pair.into_inner().next().unwrap();
984 match inner.as_rule() {
985 Rule::identifier => {
986 policy.add_role(inner.as_str());
987 }
988 Rule::policy_role_list => {
989 for role_pair in inner.into_inner() {
990 if role_pair.as_rule() == Rule::identifier {
991 policy.add_role(role_pair.as_str());
992 }
993 }
994 }
995 _ => {}
996 }
997 }
998 Rule::policy_as => {
999 let inner = pair.into_inner().next().unwrap();
1000 if inner.as_rule() == Rule::policy_type
1001 && let Some(policy_type) = PolicyType::from_str(inner.as_str())
1002 {
1003 policy.policy_type = policy_type;
1004 }
1005 }
1006 Rule::policy_using => {
1007 let inner = pair.into_inner().next().unwrap();
1008 let expr = extract_policy_expression(&inner);
1009 policy.using_expr = Some(expr);
1010 }
1011 Rule::policy_check => {
1012 let inner = pair.into_inner().next().unwrap();
1013 let expr = extract_policy_expression(&inner);
1014 policy.check_expr = Some(expr);
1015 }
1016 Rule::policy_mssql_schema => {
1017 let inner = pair.into_inner().next().unwrap();
1018 if inner.as_rule() == Rule::string_literal {
1019 let s = inner.as_str();
1020 let schema = &s[1..s.len() - 1]; policy.mssql_schema = Some(SmolStr::new(schema));
1022 }
1023 }
1024 Rule::policy_mssql_block => {
1025 let inner = pair.into_inner().next().unwrap();
1026 match inner.as_rule() {
1027 Rule::mssql_block_op => {
1028 if let Some(op) = MssqlBlockOperation::from_str(inner.as_str()) {
1029 policy.add_mssql_block_operation(op);
1030 }
1031 }
1032 Rule::mssql_block_op_list => {
1033 for op_pair in inner.into_inner() {
1034 if op_pair.as_rule() == Rule::mssql_block_op
1035 && let Some(op) = MssqlBlockOperation::from_str(op_pair.as_str())
1036 {
1037 policy.add_mssql_block_operation(op);
1038 }
1039 }
1040 }
1041 _ => {}
1042 }
1043 }
1044 _ => {}
1045 }
1046 Ok(())
1047}
1048
1049fn extract_policy_expression(pair: &pest::iterators::Pair<'_, Rule>) -> String {
1051 let s = pair.as_str();
1052 match pair.as_rule() {
1053 Rule::multiline_string => {
1054 s.trim_start_matches("\"\"\"")
1056 .trim_end_matches("\"\"\"")
1057 .trim()
1058 .to_string()
1059 }
1060 Rule::string_literal => {
1061 s[1..s.len() - 1].to_string()
1063 }
1064 _ => s.to_string(),
1065 }
1066}
1067
1068#[cfg(test)]
1069mod tests {
1070 use super::*;
1071
1072 #[test]
1075 fn test_parse_simple_model() {
1076 let schema = parse_schema(
1077 r#"
1078 model User {
1079 id Int @id @auto
1080 email String @unique
1081 name String?
1082 }
1083 "#,
1084 )
1085 .unwrap();
1086
1087 assert_eq!(schema.models.len(), 1);
1088 let user = schema.get_model("User").unwrap();
1089 assert_eq!(user.fields.len(), 3);
1090 assert!(user.get_field("id").unwrap().is_id());
1091 assert!(user.get_field("email").unwrap().is_unique());
1092 assert!(user.get_field("name").unwrap().is_optional());
1093 }
1094
1095 #[test]
1096 fn test_parse_model_name() {
1097 let schema = parse_schema(
1098 r#"
1099 model BlogPost {
1100 id Int @id
1101 }
1102 "#,
1103 )
1104 .unwrap();
1105
1106 assert!(schema.get_model("BlogPost").is_some());
1107 }
1108
1109 #[test]
1110 fn test_parse_multiple_models() {
1111 let schema = parse_schema(
1112 r#"
1113 model User {
1114 id Int @id
1115 }
1116
1117 model Post {
1118 id Int @id
1119 }
1120
1121 model Comment {
1122 id Int @id
1123 }
1124 "#,
1125 )
1126 .unwrap();
1127
1128 assert_eq!(schema.models.len(), 3);
1129 assert!(schema.get_model("User").is_some());
1130 assert!(schema.get_model("Post").is_some());
1131 assert!(schema.get_model("Comment").is_some());
1132 }
1133
1134 #[test]
1137 fn test_parse_all_scalar_types() {
1138 let schema = parse_schema(
1139 r#"
1140 model AllTypes {
1141 id Int @id
1142 big BigInt
1143 float_f Float
1144 decimal Decimal
1145 str String
1146 bool Boolean
1147 datetime DateTime
1148 date Date
1149 time Time
1150 json Json
1151 bytes Bytes
1152 uuid Uuid
1153 cuid Cuid
1154 cuid2 Cuid2
1155 nanoid NanoId
1156 ulid Ulid
1157 }
1158 "#,
1159 )
1160 .unwrap();
1161
1162 let model = schema.get_model("AllTypes").unwrap();
1163 assert_eq!(model.fields.len(), 16);
1164
1165 assert!(matches!(
1166 model.get_field("id").unwrap().field_type,
1167 FieldType::Scalar(ScalarType::Int)
1168 ));
1169 assert!(matches!(
1170 model.get_field("big").unwrap().field_type,
1171 FieldType::Scalar(ScalarType::BigInt)
1172 ));
1173 assert!(matches!(
1174 model.get_field("str").unwrap().field_type,
1175 FieldType::Scalar(ScalarType::String)
1176 ));
1177 assert!(matches!(
1178 model.get_field("bool").unwrap().field_type,
1179 FieldType::Scalar(ScalarType::Boolean)
1180 ));
1181 assert!(matches!(
1182 model.get_field("datetime").unwrap().field_type,
1183 FieldType::Scalar(ScalarType::DateTime)
1184 ));
1185 assert!(matches!(
1186 model.get_field("uuid").unwrap().field_type,
1187 FieldType::Scalar(ScalarType::Uuid)
1188 ));
1189 assert!(matches!(
1190 model.get_field("cuid").unwrap().field_type,
1191 FieldType::Scalar(ScalarType::Cuid)
1192 ));
1193 assert!(matches!(
1194 model.get_field("cuid2").unwrap().field_type,
1195 FieldType::Scalar(ScalarType::Cuid2)
1196 ));
1197 assert!(matches!(
1198 model.get_field("nanoid").unwrap().field_type,
1199 FieldType::Scalar(ScalarType::NanoId)
1200 ));
1201 assert!(matches!(
1202 model.get_field("ulid").unwrap().field_type,
1203 FieldType::Scalar(ScalarType::Ulid)
1204 ));
1205 }
1206
1207 #[test]
1208 fn test_parse_optional_field() {
1209 let schema = parse_schema(
1210 r#"
1211 model User {
1212 id Int @id
1213 bio String?
1214 age Int?
1215 }
1216 "#,
1217 )
1218 .unwrap();
1219
1220 let user = schema.get_model("User").unwrap();
1221 assert!(!user.get_field("id").unwrap().is_optional());
1222 assert!(user.get_field("bio").unwrap().is_optional());
1223 assert!(user.get_field("age").unwrap().is_optional());
1224 }
1225
1226 #[test]
1227 fn test_parse_list_field() {
1228 let schema = parse_schema(
1229 r#"
1230 model User {
1231 id Int @id
1232 tags String[]
1233 posts Post[]
1234 }
1235 "#,
1236 )
1237 .unwrap();
1238
1239 let user = schema.get_model("User").unwrap();
1240 assert!(user.get_field("tags").unwrap().is_list());
1241 assert!(user.get_field("posts").unwrap().is_list());
1242 }
1243
1244 #[test]
1245 fn test_parse_optional_list_field() {
1246 let schema = parse_schema(
1247 r#"
1248 model User {
1249 id Int @id
1250 metadata String[]?
1251 }
1252 "#,
1253 )
1254 .unwrap();
1255
1256 let user = schema.get_model("User").unwrap();
1257 let metadata = user.get_field("metadata").unwrap();
1258 assert!(metadata.is_list());
1259 assert!(metadata.is_optional());
1260 }
1261
1262 #[test]
1265 fn test_parse_id_attribute() {
1266 let schema = parse_schema(
1267 r#"
1268 model User {
1269 id Int @id
1270 }
1271 "#,
1272 )
1273 .unwrap();
1274
1275 let user = schema.get_model("User").unwrap();
1276 assert!(user.get_field("id").unwrap().is_id());
1277 }
1278
1279 #[test]
1280 fn test_parse_unique_attribute() {
1281 let schema = parse_schema(
1282 r#"
1283 model User {
1284 id Int @id
1285 email String @unique
1286 }
1287 "#,
1288 )
1289 .unwrap();
1290
1291 let user = schema.get_model("User").unwrap();
1292 assert!(user.get_field("email").unwrap().is_unique());
1293 }
1294
1295 #[test]
1296 fn test_parse_default_int() {
1297 let schema = parse_schema(
1298 r#"
1299 model Counter {
1300 id Int @id
1301 count Int @default(0)
1302 }
1303 "#,
1304 )
1305 .unwrap();
1306
1307 let counter = schema.get_model("Counter").unwrap();
1308 let count_field = counter.get_field("count").unwrap();
1309 let attrs = count_field.extract_attributes();
1310 assert!(attrs.default.is_some());
1311 assert_eq!(attrs.default.unwrap().as_int(), Some(0));
1312 }
1313
1314 #[test]
1315 fn test_parse_default_string() {
1316 let schema = parse_schema(
1317 r#"
1318 model User {
1319 id Int @id
1320 status String @default("active")
1321 }
1322 "#,
1323 )
1324 .unwrap();
1325
1326 let user = schema.get_model("User").unwrap();
1327 let status = user.get_field("status").unwrap();
1328 let attrs = status.extract_attributes();
1329 assert!(attrs.default.is_some());
1330 assert_eq!(attrs.default.unwrap().as_string(), Some("active"));
1331 }
1332
1333 #[test]
1334 fn test_parse_default_boolean() {
1335 let schema = parse_schema(
1336 r#"
1337 model Post {
1338 id Int @id
1339 published Boolean @default(false)
1340 }
1341 "#,
1342 )
1343 .unwrap();
1344
1345 let post = schema.get_model("Post").unwrap();
1346 let published = post.get_field("published").unwrap();
1347 let attrs = published.extract_attributes();
1348 assert!(attrs.default.is_some());
1349 assert_eq!(attrs.default.unwrap().as_bool(), Some(false));
1350 }
1351
1352 #[test]
1353 fn test_parse_default_function() {
1354 let schema = parse_schema(
1355 r#"
1356 model User {
1357 id Int @id
1358 createdAt DateTime @default(now())
1359 }
1360 "#,
1361 )
1362 .unwrap();
1363
1364 let user = schema.get_model("User").unwrap();
1365 let created_at = user.get_field("createdAt").unwrap();
1366 let attrs = created_at.extract_attributes();
1367 assert!(attrs.default.is_some());
1368 if let Some(AttributeValue::Function(name, _)) = attrs.default {
1369 assert_eq!(name.as_str(), "now");
1370 } else {
1371 panic!("Expected function default");
1372 }
1373 }
1374
1375 #[test]
1376 fn test_parse_updated_at_attribute() {
1377 let schema = parse_schema(
1378 r#"
1379 model User {
1380 id Int @id
1381 updatedAt DateTime @updated_at
1382 }
1383 "#,
1384 )
1385 .unwrap();
1386
1387 let user = schema.get_model("User").unwrap();
1388 let updated_at = user.get_field("updatedAt").unwrap();
1389 let attrs = updated_at.extract_attributes();
1390 assert!(attrs.is_updated_at);
1391 }
1392
1393 #[test]
1394 fn test_parse_map_attribute() {
1395 let schema = parse_schema(
1396 r#"
1397 model User {
1398 id Int @id
1399 email String @map("email_address")
1400 }
1401 "#,
1402 )
1403 .unwrap();
1404
1405 let user = schema.get_model("User").unwrap();
1406 let email = user.get_field("email").unwrap();
1407 let attrs = email.extract_attributes();
1408 assert_eq!(attrs.map, Some("email_address".to_string()));
1409 }
1410
1411 #[test]
1412 fn test_parse_multiple_attributes() {
1413 let schema = parse_schema(
1414 r#"
1415 model User {
1416 id Int @id @auto
1417 email String @unique @index
1418 }
1419 "#,
1420 )
1421 .unwrap();
1422
1423 let user = schema.get_model("User").unwrap();
1424 let id = user.get_field("id").unwrap();
1425 let email = user.get_field("email").unwrap();
1426
1427 let id_attrs = id.extract_attributes();
1428 assert!(id_attrs.is_id);
1429 assert!(id_attrs.is_auto);
1430
1431 let email_attrs = email.extract_attributes();
1432 assert!(email_attrs.is_unique);
1433 assert!(email_attrs.is_indexed);
1434 }
1435
1436 #[test]
1439 fn test_parse_model_map_attribute() {
1440 let schema = parse_schema(
1441 r#"
1442 model User {
1443 id Int @id
1444
1445 @@map("app_users")
1446 }
1447 "#,
1448 )
1449 .unwrap();
1450
1451 let user = schema.get_model("User").unwrap();
1452 assert_eq!(user.table_name(), "app_users");
1453 }
1454
1455 #[test]
1456 fn test_parse_model_index_attribute() {
1457 let schema = parse_schema(
1458 r#"
1459 model User {
1460 id Int @id
1461 email String
1462 name String
1463
1464 @@index([email, name])
1465 }
1466 "#,
1467 )
1468 .unwrap();
1469
1470 let user = schema.get_model("User").unwrap();
1471 assert!(user.has_attribute("index"));
1472 }
1473
1474 #[test]
1475 fn test_parse_composite_primary_key() {
1476 let schema = parse_schema(
1477 r#"
1478 model PostTag {
1479 postId Int
1480 tagId Int
1481
1482 @@id([postId, tagId])
1483 }
1484 "#,
1485 )
1486 .unwrap();
1487
1488 let post_tag = schema.get_model("PostTag").unwrap();
1489 assert!(post_tag.has_attribute("id"));
1490 }
1491
1492 #[test]
1495 fn test_parse_enum() {
1496 let schema = parse_schema(
1497 r#"
1498 enum Role {
1499 User
1500 Admin
1501 Moderator
1502 }
1503 "#,
1504 )
1505 .unwrap();
1506
1507 assert_eq!(schema.enums.len(), 1);
1508 let role = schema.get_enum("Role").unwrap();
1509 assert_eq!(role.variants.len(), 3);
1510 }
1511
1512 #[test]
1513 fn test_parse_enum_variant_names() {
1514 let schema = parse_schema(
1515 r#"
1516 enum Status {
1517 Pending
1518 Active
1519 Completed
1520 Cancelled
1521 }
1522 "#,
1523 )
1524 .unwrap();
1525
1526 let status = schema.get_enum("Status").unwrap();
1527 assert!(status.get_variant("Pending").is_some());
1528 assert!(status.get_variant("Active").is_some());
1529 assert!(status.get_variant("Completed").is_some());
1530 assert!(status.get_variant("Cancelled").is_some());
1531 }
1532
1533 #[test]
1534 fn test_parse_enum_with_map() {
1535 let schema = parse_schema(
1536 r#"
1537 enum Role {
1538 User @map("USER")
1539 Admin @map("ADMINISTRATOR")
1540 }
1541 "#,
1542 )
1543 .unwrap();
1544
1545 let role = schema.get_enum("Role").unwrap();
1546 let user_variant = role.get_variant("User").unwrap();
1547 assert_eq!(user_variant.db_value(), "USER");
1548
1549 let admin_variant = role.get_variant("Admin").unwrap();
1550 assert_eq!(admin_variant.db_value(), "ADMINISTRATOR");
1551 }
1552
1553 #[test]
1556 fn test_parse_one_to_many_relation() {
1557 let schema = parse_schema(
1558 r#"
1559 model User {
1560 id Int @id
1561 posts Post[]
1562 }
1563
1564 model Post {
1565 id Int @id
1566 authorId Int
1567 author User @relation(fields: [authorId], references: [id])
1568 }
1569 "#,
1570 )
1571 .unwrap();
1572
1573 let user = schema.get_model("User").unwrap();
1574 let post = schema.get_model("Post").unwrap();
1575
1576 assert!(user.get_field("posts").unwrap().is_list());
1577 assert!(post.get_field("author").unwrap().is_relation());
1578 }
1579
1580 #[test]
1581 fn test_parse_relation_with_actions() {
1582 let schema = parse_schema(
1583 r#"
1584 model Post {
1585 id Int @id
1586 authorId Int
1587 author User @relation(fields: [authorId], references: [id], onDelete: Cascade, onUpdate: Restrict)
1588 }
1589
1590 model User {
1591 id Int @id
1592 posts Post[]
1593 }
1594 "#,
1595 )
1596 .unwrap();
1597
1598 let post = schema.get_model("Post").unwrap();
1599 let author = post.get_field("author").unwrap();
1600 let attrs = author.extract_attributes();
1601
1602 assert!(attrs.relation.is_some());
1603 let rel = attrs.relation.unwrap();
1604 assert_eq!(rel.on_delete, Some(ReferentialAction::Cascade));
1605 assert_eq!(rel.on_update, Some(ReferentialAction::Restrict));
1606 }
1607
1608 #[test]
1611 fn test_parse_model_documentation() {
1612 let schema = parse_schema(
1613 r#"/// Represents a user in the system
1614model User {
1615 id Int @id
1616}"#,
1617 )
1618 .unwrap();
1619
1620 let user = schema.get_model("User").unwrap();
1621 if let Some(doc) = &user.documentation {
1624 assert!(doc.text.contains("user"));
1625 }
1626 }
1627
1628 #[test]
1631 fn test_parse_complete_schema() {
1632 let schema = parse_schema(
1633 r#"
1634 /// User model
1635 model User {
1636 id Int @id @auto
1637 email String @unique
1638 name String?
1639 role Role @default(User)
1640 posts Post[]
1641 profile Profile?
1642 createdAt DateTime @default(now())
1643 updatedAt DateTime @updated_at
1644
1645 @@map("users")
1646 @@index([email])
1647 }
1648
1649 model Post {
1650 id Int @id @auto
1651 title String
1652 content String?
1653 published Boolean @default(false)
1654 authorId Int
1655 author User @relation(fields: [authorId], references: [id])
1656 tags Tag[]
1657 createdAt DateTime @default(now())
1658
1659 @@index([authorId])
1660 }
1661
1662 model Profile {
1663 id Int @id @auto
1664 bio String?
1665 userId Int @unique
1666 user User @relation(fields: [userId], references: [id])
1667 }
1668
1669 model Tag {
1670 id Int @id @auto
1671 name String @unique
1672 posts Post[]
1673 }
1674
1675 enum Role {
1676 User
1677 Admin
1678 Moderator
1679 }
1680 "#,
1681 )
1682 .unwrap();
1683
1684 assert_eq!(schema.models.len(), 4);
1686 assert!(schema.get_model("User").is_some());
1687 assert!(schema.get_model("Post").is_some());
1688 assert!(schema.get_model("Profile").is_some());
1689 assert!(schema.get_model("Tag").is_some());
1690
1691 assert_eq!(schema.enums.len(), 1);
1693 assert!(schema.get_enum("Role").is_some());
1694
1695 let user = schema.get_model("User").unwrap();
1697 assert_eq!(user.table_name(), "users");
1698 assert_eq!(user.fields.len(), 8);
1699 assert!(user.has_attribute("index"));
1700
1701 let post = schema.get_model("Post").unwrap();
1703 assert!(post.get_field("author").unwrap().is_relation());
1704 }
1705
1706 #[test]
1709 fn test_parse_invalid_syntax() {
1710 let result = parse_schema("model { broken }");
1711 assert!(result.is_err());
1712 }
1713
1714 #[test]
1715 fn test_parse_empty_schema() {
1716 let schema = parse_schema("").unwrap();
1717 assert!(schema.models.is_empty());
1718 assert!(schema.enums.is_empty());
1719 }
1720
1721 #[test]
1722 fn test_parse_whitespace_only() {
1723 let schema = parse_schema(" \n\t \n ").unwrap();
1724 assert!(schema.models.is_empty());
1725 }
1726
1727 #[test]
1728 fn test_parse_comments_only() {
1729 let schema = parse_schema(
1730 r#"
1731 // This is a comment
1732 // Another comment
1733 "#,
1734 )
1735 .unwrap();
1736 assert!(schema.models.is_empty());
1737 }
1738
1739 #[test]
1742 fn test_parse_model_with_no_fields() {
1743 let result = parse_schema(
1745 r#"
1746 model Empty {
1747 }
1748 "#,
1749 );
1750 let _ = result;
1752 }
1753
1754 #[test]
1755 fn test_parse_long_identifier() {
1756 let schema = parse_schema(
1757 r#"
1758 model VeryLongModelNameThatIsStillValid {
1759 someVeryLongFieldNameThatShouldWork Int @id
1760 }
1761 "#,
1762 )
1763 .unwrap();
1764
1765 assert!(
1766 schema
1767 .get_model("VeryLongModelNameThatIsStillValid")
1768 .is_some()
1769 );
1770 }
1771
1772 #[test]
1773 fn test_parse_underscore_identifiers() {
1774 let schema = parse_schema(
1775 r#"
1776 model user_account {
1777 user_id Int @id
1778 created_at DateTime
1779 }
1780 "#,
1781 )
1782 .unwrap();
1783
1784 let model = schema.get_model("user_account").unwrap();
1785 assert!(model.get_field("user_id").is_some());
1786 assert!(model.get_field("created_at").is_some());
1787 }
1788
1789 #[test]
1790 fn test_parse_negative_default() {
1791 let schema = parse_schema(
1792 r#"
1793 model Config {
1794 id Int @id
1795 minValue Int @default(-100)
1796 }
1797 "#,
1798 )
1799 .unwrap();
1800
1801 let config = schema.get_model("Config").unwrap();
1802 let min_value = config.get_field("minValue").unwrap();
1803 let attrs = min_value.extract_attributes();
1804 assert!(attrs.default.is_some());
1805 }
1806
1807 #[test]
1808 fn test_parse_float_default() {
1809 let schema = parse_schema(
1810 r#"
1811 model Product {
1812 id Int @id
1813 price Float @default(9.99)
1814 }
1815 "#,
1816 )
1817 .unwrap();
1818
1819 let product = schema.get_model("Product").unwrap();
1820 let price = product.get_field("price").unwrap();
1821 let attrs = price.extract_attributes();
1822 assert!(attrs.default.is_some());
1823 }
1824
1825 #[test]
1828 fn test_parse_simple_server_group() {
1829 let schema = parse_schema(
1830 r#"
1831 serverGroup MainCluster {
1832 server primary {
1833 url = "postgres://localhost/db"
1834 role = "primary"
1835 }
1836 }
1837 "#,
1838 )
1839 .unwrap();
1840
1841 assert_eq!(schema.server_groups.len(), 1);
1842 let cluster = schema.get_server_group("MainCluster").unwrap();
1843 assert_eq!(cluster.servers.len(), 1);
1844 assert!(cluster.servers.contains_key("primary"));
1845 }
1846
1847 #[test]
1848 fn test_parse_server_group_with_multiple_servers() {
1849 let schema = parse_schema(
1850 r#"
1851 serverGroup ReadReplicas {
1852 server primary {
1853 url = "postgres://primary.db.com/app"
1854 role = "primary"
1855 weight = 1
1856 }
1857
1858 server replica1 {
1859 url = "postgres://replica1.db.com/app"
1860 role = "replica"
1861 weight = 2
1862 }
1863
1864 server replica2 {
1865 url = "postgres://replica2.db.com/app"
1866 role = "replica"
1867 weight = 2
1868 }
1869 }
1870 "#,
1871 )
1872 .unwrap();
1873
1874 let cluster = schema.get_server_group("ReadReplicas").unwrap();
1875 assert_eq!(cluster.servers.len(), 3);
1876
1877 let primary = cluster.servers.get("primary").unwrap();
1878 assert_eq!(primary.role(), Some(ServerRole::Primary));
1879 assert_eq!(primary.weight(), Some(1));
1880
1881 let replica1 = cluster.servers.get("replica1").unwrap();
1882 assert_eq!(replica1.role(), Some(ServerRole::Replica));
1883 assert_eq!(replica1.weight(), Some(2));
1884 }
1885
1886 #[test]
1887 fn test_parse_server_group_with_attributes() {
1888 let schema = parse_schema(
1889 r#"
1890 serverGroup ProductionCluster {
1891 @@strategy(ReadReplica)
1892 @@loadBalance(RoundRobin)
1893
1894 server main {
1895 url = "postgres://main/db"
1896 role = "primary"
1897 }
1898 }
1899 "#,
1900 )
1901 .unwrap();
1902
1903 let cluster = schema.get_server_group("ProductionCluster").unwrap();
1904 assert!(cluster.attributes.iter().any(|a| a.name.name == "strategy"));
1905 assert!(
1906 cluster
1907 .attributes
1908 .iter()
1909 .any(|a| a.name.name == "loadBalance")
1910 );
1911 }
1912
1913 #[test]
1914 fn test_parse_server_group_with_env_vars() {
1915 let schema = parse_schema(
1916 r#"
1917 serverGroup EnvCluster {
1918 server db1 {
1919 url = env("PRIMARY_DB_URL")
1920 role = "primary"
1921 }
1922 }
1923 "#,
1924 )
1925 .unwrap();
1926
1927 let cluster = schema.get_server_group("EnvCluster").unwrap();
1928 let server = cluster.servers.get("db1").unwrap();
1929
1930 if let Some(ServerPropertyValue::EnvVar(var)) = server.get_property("url") {
1932 assert_eq!(var, "PRIMARY_DB_URL");
1933 } else {
1934 panic!("Expected env var for url property");
1935 }
1936 }
1937
1938 #[test]
1939 fn test_parse_server_group_with_boolean_property() {
1940 let schema = parse_schema(
1941 r#"
1942 serverGroup TestCluster {
1943 server replica {
1944 url = "postgres://replica/db"
1945 role = "replica"
1946 readOnly = true
1947 }
1948 }
1949 "#,
1950 )
1951 .unwrap();
1952
1953 let cluster = schema.get_server_group("TestCluster").unwrap();
1954 let server = cluster.servers.get("replica").unwrap();
1955 assert!(server.is_read_only());
1956 }
1957
1958 #[test]
1959 fn test_parse_server_group_with_numeric_properties() {
1960 let schema = parse_schema(
1961 r#"
1962 serverGroup NumericCluster {
1963 server db {
1964 url = "postgres://localhost/db"
1965 weight = 5
1966 priority = 1
1967 maxConnections = 100
1968 }
1969 }
1970 "#,
1971 )
1972 .unwrap();
1973
1974 let cluster = schema.get_server_group("NumericCluster").unwrap();
1975 let server = cluster.servers.get("db").unwrap();
1976
1977 assert_eq!(server.weight(), Some(5));
1978 assert_eq!(server.priority(), Some(1));
1979 assert_eq!(server.max_connections(), Some(100));
1980 }
1981
1982 #[test]
1983 fn test_parse_server_group_with_region() {
1984 let schema = parse_schema(
1985 r#"
1986 serverGroup GeoCluster {
1987 server usEast {
1988 url = "postgres://us-east.db.com/app"
1989 role = "replica"
1990 region = "us-east-1"
1991 }
1992
1993 server usWest {
1994 url = "postgres://us-west.db.com/app"
1995 role = "replica"
1996 region = "us-west-2"
1997 }
1998 }
1999 "#,
2000 )
2001 .unwrap();
2002
2003 let cluster = schema.get_server_group("GeoCluster").unwrap();
2004
2005 let us_east = cluster.servers.get("usEast").unwrap();
2006 assert_eq!(us_east.region(), Some("us-east-1"));
2007
2008 let us_west = cluster.servers.get("usWest").unwrap();
2009 assert_eq!(us_west.region(), Some("us-west-2"));
2010
2011 let us_east_servers = cluster.servers_in_region("us-east-1");
2013 assert_eq!(us_east_servers.len(), 1);
2014 }
2015
2016 #[test]
2017 fn test_parse_multiple_server_groups() {
2018 let schema = parse_schema(
2019 r#"
2020 serverGroup Cluster1 {
2021 server db1 {
2022 url = "postgres://db1/app"
2023 }
2024 }
2025
2026 serverGroup Cluster2 {
2027 server db2 {
2028 url = "postgres://db2/app"
2029 }
2030 }
2031
2032 serverGroup Cluster3 {
2033 server db3 {
2034 url = "postgres://db3/app"
2035 }
2036 }
2037 "#,
2038 )
2039 .unwrap();
2040
2041 assert_eq!(schema.server_groups.len(), 3);
2042 assert!(schema.get_server_group("Cluster1").is_some());
2043 assert!(schema.get_server_group("Cluster2").is_some());
2044 assert!(schema.get_server_group("Cluster3").is_some());
2045 }
2046
2047 #[test]
2048 fn test_parse_schema_with_models_and_server_groups() {
2049 let schema = parse_schema(
2050 r#"
2051 model User {
2052 id Int @id @auto
2053 email String @unique
2054 }
2055
2056 serverGroup Database {
2057 @@strategy(ReadReplica)
2058
2059 server primary {
2060 url = env("DATABASE_URL")
2061 role = "primary"
2062 }
2063 }
2064
2065 model Post {
2066 id Int @id @auto
2067 title String
2068 authorId Int
2069 }
2070 "#,
2071 )
2072 .unwrap();
2073
2074 assert_eq!(schema.models.len(), 2);
2075 assert!(schema.get_model("User").is_some());
2076 assert!(schema.get_model("Post").is_some());
2077
2078 assert_eq!(schema.server_groups.len(), 1);
2079 assert!(schema.get_server_group("Database").is_some());
2080 }
2081
2082 #[test]
2083 fn test_parse_server_group_with_health_check() {
2084 let schema = parse_schema(
2085 r#"
2086 serverGroup HealthyCluster {
2087 server monitored {
2088 url = "postgres://localhost/db"
2089 healthCheck = "/health"
2090 }
2091 }
2092 "#,
2093 )
2094 .unwrap();
2095
2096 let cluster = schema.get_server_group("HealthyCluster").unwrap();
2097 let server = cluster.servers.get("monitored").unwrap();
2098 assert_eq!(server.health_check(), Some("/health"));
2099 }
2100
2101 #[test]
2102 fn test_server_group_failover_order() {
2103 let schema = parse_schema(
2104 r#"
2105 serverGroup FailoverCluster {
2106 server db3 {
2107 url = "postgres://db3/app"
2108 priority = 3
2109 }
2110
2111 server db1 {
2112 url = "postgres://db1/app"
2113 priority = 1
2114 }
2115
2116 server db2 {
2117 url = "postgres://db2/app"
2118 priority = 2
2119 }
2120 }
2121 "#,
2122 )
2123 .unwrap();
2124
2125 let cluster = schema.get_server_group("FailoverCluster").unwrap();
2126 let ordered = cluster.failover_order();
2127
2128 assert_eq!(ordered[0].name.name.as_str(), "db1");
2129 assert_eq!(ordered[1].name.name.as_str(), "db2");
2130 assert_eq!(ordered[2].name.name.as_str(), "db3");
2131 }
2132
2133 #[test]
2134 fn test_server_group_names() {
2135 let schema = parse_schema(
2136 r#"
2137 serverGroup Alpha {
2138 server s1 { url = "pg://a" }
2139 }
2140 serverGroup Beta {
2141 server s2 { url = "pg://b" }
2142 }
2143 "#,
2144 )
2145 .unwrap();
2146
2147 let names: Vec<_> = schema.server_group_names().collect();
2148 assert_eq!(names.len(), 2);
2149 assert!(names.contains(&"Alpha"));
2150 assert!(names.contains(&"Beta"));
2151 }
2152
2153 #[test]
2156 fn test_parse_simple_policy() {
2157 let schema = parse_schema(
2158 r#"
2159 policy UserReadOwn on User {
2160 for SELECT
2161 using "id = current_user_id()"
2162 }
2163 "#,
2164 )
2165 .unwrap();
2166
2167 assert_eq!(schema.policies.len(), 1);
2168 let policy = schema.get_policy("UserReadOwn").unwrap();
2169 assert_eq!(policy.name(), "UserReadOwn");
2170 assert_eq!(policy.table(), "User");
2171 assert!(policy.applies_to(PolicyCommand::Select));
2172 assert!(!policy.applies_to(PolicyCommand::Insert));
2173 assert_eq!(policy.using_expr.as_deref(), Some("id = current_user_id()"));
2174 }
2175
2176 #[test]
2177 fn test_parse_policy_with_multiple_commands() {
2178 let schema = parse_schema(
2179 r#"
2180 policy UserModify on User {
2181 for [SELECT, UPDATE, DELETE]
2182 using "id = auth.uid()"
2183 }
2184 "#,
2185 )
2186 .unwrap();
2187
2188 let policy = schema.get_policy("UserModify").unwrap();
2189 assert!(policy.applies_to(PolicyCommand::Select));
2190 assert!(policy.applies_to(PolicyCommand::Update));
2191 assert!(policy.applies_to(PolicyCommand::Delete));
2192 assert!(!policy.applies_to(PolicyCommand::Insert));
2193 }
2194
2195 #[test]
2196 fn test_parse_policy_with_all_command() {
2197 let schema = parse_schema(
2198 r#"
2199 policy UserAll on User {
2200 for ALL
2201 using "true"
2202 }
2203 "#,
2204 )
2205 .unwrap();
2206
2207 let policy = schema.get_policy("UserAll").unwrap();
2208 assert!(policy.applies_to(PolicyCommand::Select));
2209 assert!(policy.applies_to(PolicyCommand::Insert));
2210 assert!(policy.applies_to(PolicyCommand::Update));
2211 assert!(policy.applies_to(PolicyCommand::Delete));
2212 }
2213
2214 #[test]
2215 fn test_parse_policy_with_roles() {
2216 let schema = parse_schema(
2217 r#"
2218 policy AuthenticatedRead on Document {
2219 for SELECT
2220 to authenticated
2221 using "true"
2222 }
2223 "#,
2224 )
2225 .unwrap();
2226
2227 let policy = schema.get_policy("AuthenticatedRead").unwrap();
2228 let roles = policy.effective_roles();
2229 assert!(roles.contains(&"authenticated"));
2230 }
2231
2232 #[test]
2233 fn test_parse_policy_with_multiple_roles() {
2234 let schema = parse_schema(
2235 r#"
2236 policy AdminModerator on Post {
2237 for [UPDATE, DELETE]
2238 to [admin, moderator]
2239 using "true"
2240 }
2241 "#,
2242 )
2243 .unwrap();
2244
2245 let policy = schema.get_policy("AdminModerator").unwrap();
2246 let roles = policy.effective_roles();
2247 assert!(roles.contains(&"admin"));
2248 assert!(roles.contains(&"moderator"));
2249 }
2250
2251 #[test]
2252 fn test_parse_policy_restrictive() {
2253 let schema = parse_schema(
2254 r#"
2255 policy OrgRestriction on Document {
2256 as RESTRICTIVE
2257 for SELECT
2258 using "org_id = current_org_id()"
2259 }
2260 "#,
2261 )
2262 .unwrap();
2263
2264 let policy = schema.get_policy("OrgRestriction").unwrap();
2265 assert!(policy.is_restrictive());
2266 assert!(!policy.is_permissive());
2267 }
2268
2269 #[test]
2270 fn test_parse_policy_permissive_explicit() {
2271 let schema = parse_schema(
2272 r#"
2273 policy Permissive on User {
2274 as PERMISSIVE
2275 for SELECT
2276 using "true"
2277 }
2278 "#,
2279 )
2280 .unwrap();
2281
2282 let policy = schema.get_policy("Permissive").unwrap();
2283 assert!(policy.is_permissive());
2284 }
2285
2286 #[test]
2287 fn test_parse_policy_with_check() {
2288 let schema = parse_schema(
2289 r#"
2290 policy InsertOwn on Post {
2291 for INSERT
2292 to authenticated
2293 check "author_id = current_user_id()"
2294 }
2295 "#,
2296 )
2297 .unwrap();
2298
2299 let policy = schema.get_policy("InsertOwn").unwrap();
2300 assert!(policy.applies_to(PolicyCommand::Insert));
2301 assert_eq!(
2302 policy.check_expr.as_deref(),
2303 Some("author_id = current_user_id()")
2304 );
2305 assert!(policy.using_expr.is_none());
2306 }
2307
2308 #[test]
2309 fn test_parse_policy_with_both_expressions() {
2310 let schema = parse_schema(
2311 r#"
2312 policy UpdateOwn on Post {
2313 for UPDATE
2314 using "author_id = current_user_id()"
2315 check "author_id = current_user_id()"
2316 }
2317 "#,
2318 )
2319 .unwrap();
2320
2321 let policy = schema.get_policy("UpdateOwn").unwrap();
2322 assert!(policy.using_expr.is_some());
2323 assert!(policy.check_expr.is_some());
2324 }
2325
2326 #[test]
2327 fn test_parse_policy_multiline_expression() {
2328 let schema = parse_schema(
2329 r#"
2330 policy ComplexCheck on Document {
2331 for SELECT
2332 using """
2333 (is_public = true)
2334 OR (owner_id = current_user_id())
2335 OR (id IN (SELECT document_id FROM shares WHERE user_id = current_user_id()))
2336 """
2337 }
2338 "#,
2339 )
2340 .unwrap();
2341
2342 let policy = schema.get_policy("ComplexCheck").unwrap();
2343 assert!(policy.using_expr.is_some());
2344 let expr = policy.using_expr.as_ref().unwrap();
2345 assert!(expr.contains("is_public = true"));
2346 assert!(expr.contains("owner_id = current_user_id()"));
2347 assert!(expr.contains("SELECT document_id FROM shares"));
2348 }
2349
2350 #[test]
2351 fn test_parse_multiple_policies() {
2352 let schema = parse_schema(
2353 r#"
2354 policy UserRead on User {
2355 for SELECT
2356 using "true"
2357 }
2358
2359 policy UserInsert on User {
2360 for INSERT
2361 check "id = current_user_id()"
2362 }
2363
2364 policy PostRead on Post {
2365 for SELECT
2366 using "published = true OR author_id = current_user_id()"
2367 }
2368 "#,
2369 )
2370 .unwrap();
2371
2372 assert_eq!(schema.policies.len(), 3);
2373 assert!(schema.get_policy("UserRead").is_some());
2374 assert!(schema.get_policy("UserInsert").is_some());
2375 assert!(schema.get_policy("PostRead").is_some());
2376 }
2377
2378 #[test]
2379 fn test_parse_policy_with_model() {
2380 let schema = parse_schema(
2381 r#"
2382 model User {
2383 id Int @id @auto
2384 email String @unique
2385 }
2386
2387 policy UserReadOwn on User {
2388 for SELECT
2389 to authenticated
2390 using "id = auth.uid()"
2391 }
2392 "#,
2393 )
2394 .unwrap();
2395
2396 assert_eq!(schema.models.len(), 1);
2397 assert_eq!(schema.policies.len(), 1);
2398
2399 let policies = schema.policies_for("User");
2400 assert_eq!(policies.len(), 1);
2401 assert_eq!(policies[0].name(), "UserReadOwn");
2402 }
2403
2404 #[test]
2405 fn test_parse_policies_for_multiple_models() {
2406 let schema = parse_schema(
2407 r#"
2408 policy UserPolicy1 on User {
2409 for SELECT
2410 using "true"
2411 }
2412
2413 policy UserPolicy2 on User {
2414 for INSERT
2415 check "true"
2416 }
2417
2418 policy PostPolicy on Post {
2419 for SELECT
2420 using "true"
2421 }
2422 "#,
2423 )
2424 .unwrap();
2425
2426 assert_eq!(schema.policies_for("User").len(), 2);
2427 assert_eq!(schema.policies_for("Post").len(), 1);
2428 assert!(schema.has_policies("User"));
2429 assert!(schema.has_policies("Post"));
2430 assert!(!schema.has_policies("Comment"));
2431 }
2432
2433 #[test]
2434 fn test_parse_policy_default_all_command() {
2435 let schema = parse_schema(
2436 r#"
2437 policy DefaultAll on User {
2438 using "id = current_user_id()"
2439 }
2440 "#,
2441 )
2442 .unwrap();
2443
2444 let policy = schema.get_policy("DefaultAll").unwrap();
2445 assert!(policy.applies_to(PolicyCommand::All));
2447 }
2448
2449 #[test]
2450 fn test_parse_policy_case_insensitive_keywords() {
2451 let schema = parse_schema(
2452 r#"
2453 policy CaseTest on User {
2454 for select
2455 as permissive
2456 using "true"
2457 }
2458 "#,
2459 )
2460 .unwrap();
2461
2462 let policy = schema.get_policy("CaseTest").unwrap();
2463 assert!(policy.applies_to(PolicyCommand::Select));
2464 assert!(policy.is_permissive());
2465 }
2466
2467 #[test]
2468 fn test_parse_policy_sql_generation() {
2469 let schema = parse_schema(
2470 r#"
2471 model User {
2472 id Int @id
2473
2474 @@map("users")
2475 }
2476
2477 policy ReadOwn on User {
2478 for SELECT
2479 to authenticated
2480 using "id = auth.uid()"
2481 }
2482 "#,
2483 )
2484 .unwrap();
2485
2486 let policy = schema.get_policy("ReadOwn").unwrap();
2487 let sql = policy.to_sql("users");
2488
2489 assert!(sql.contains("CREATE POLICY ReadOwn ON users"));
2490 assert!(sql.contains("FOR SELECT"));
2491 assert!(sql.contains("TO authenticated"));
2492 assert!(sql.contains("USING (id = auth.uid())"));
2493 }
2494
2495 #[test]
2496 fn test_parse_policy_restrictive_sql() {
2497 let schema = parse_schema(
2498 r#"
2499 policy OrgBoundary on Document {
2500 as RESTRICTIVE
2501 for ALL
2502 using "org_id = current_org_id()"
2503 }
2504 "#,
2505 )
2506 .unwrap();
2507
2508 let policy = schema.get_policy("OrgBoundary").unwrap();
2509 let sql = policy.to_sql("documents");
2510
2511 assert!(sql.contains("AS RESTRICTIVE"));
2512 }
2513
2514 #[test]
2515 fn test_parse_policy_with_documentation() {
2516 let schema = parse_schema(
2517 r#"
2518 /// Users can only read their own data
2519 policy UserIsolation on User {
2520 for SELECT
2521 using "id = current_user_id()"
2522 }
2523 "#,
2524 )
2525 .unwrap();
2526
2527 let policy = schema.get_policy("UserIsolation").unwrap();
2528 if let Some(doc) = &policy.documentation {
2529 assert!(doc.text.contains("their own data"));
2530 }
2531 }
2532
2533 #[test]
2534 fn test_parse_complex_rls_schema() {
2535 let schema = parse_schema(
2536 r#"
2537 model Organization {
2538 id Int @id @auto
2539 name String
2540 }
2541
2542 model User {
2543 id Int @id @auto
2544 orgId Int
2545 email String @unique
2546 }
2547
2548 model Document {
2549 id Int @id @auto
2550 title String
2551 ownerId Int
2552 orgId Int
2553 isPublic Boolean @default(false)
2554 }
2555
2556 /// Organization-level isolation
2557 policy OrgIsolation on Document {
2558 as RESTRICTIVE
2559 for ALL
2560 using "org_id = current_setting('app.current_org')::int"
2561 }
2562
2563 /// Users can read public documents
2564 policy PublicRead on Document {
2565 for SELECT
2566 using "is_public = true"
2567 }
2568
2569 /// Users can read their own documents
2570 policy OwnerRead on Document {
2571 for SELECT
2572 to authenticated
2573 using "owner_id = auth.uid()"
2574 }
2575
2576 /// Users can only modify their own documents
2577 policy OwnerModify on Document {
2578 for [UPDATE, DELETE]
2579 to authenticated
2580 using "owner_id = auth.uid()"
2581 check "owner_id = auth.uid()"
2582 }
2583
2584 /// Users can create documents in their org
2585 policy OrgInsert on Document {
2586 for INSERT
2587 to authenticated
2588 check "org_id = current_setting('app.current_org')::int"
2589 }
2590 "#,
2591 )
2592 .unwrap();
2593
2594 assert_eq!(schema.models.len(), 3);
2595 assert_eq!(schema.policies.len(), 5);
2596
2597 let org_iso = schema.get_policy("OrgIsolation").unwrap();
2599 assert!(org_iso.is_restrictive());
2600
2601 let doc_policies = schema.policies_for("Document");
2603 assert_eq!(doc_policies.len(), 5);
2604 }
2605
2606 #[test]
2609 fn test_parse_policy_with_mssql_schema() {
2610 let schema = parse_schema(
2611 r#"
2612 policy UserFilter on User {
2613 for SELECT
2614 using "UserId = @UserId"
2615 mssqlSchema "RLS"
2616 }
2617 "#,
2618 )
2619 .unwrap();
2620
2621 let policy = schema.get_policy("UserFilter").unwrap();
2622 assert_eq!(policy.mssql_schema(), "RLS");
2623 }
2624
2625 #[test]
2626 fn test_parse_policy_with_mssql_block_single() {
2627 let schema = parse_schema(
2628 r#"
2629 policy UserInsert on User {
2630 for INSERT
2631 check "UserId = @UserId"
2632 mssqlBlock AFTER_INSERT
2633 }
2634 "#,
2635 )
2636 .unwrap();
2637
2638 let policy = schema.get_policy("UserInsert").unwrap();
2639 assert_eq!(policy.mssql_block_operations.len(), 1);
2640 assert_eq!(
2641 policy.mssql_block_operations[0],
2642 MssqlBlockOperation::AfterInsert
2643 );
2644 }
2645
2646 #[test]
2647 fn test_parse_policy_with_mssql_block_list() {
2648 let schema = parse_schema(
2649 r#"
2650 policy UserModify on User {
2651 for [INSERT, UPDATE, DELETE]
2652 check "UserId = @UserId"
2653 mssqlBlock [AFTER_INSERT, AFTER_UPDATE, BEFORE_DELETE]
2654 }
2655 "#,
2656 )
2657 .unwrap();
2658
2659 let policy = schema.get_policy("UserModify").unwrap();
2660 assert_eq!(policy.mssql_block_operations.len(), 3);
2661 assert!(
2662 policy
2663 .mssql_block_operations
2664 .contains(&MssqlBlockOperation::AfterInsert)
2665 );
2666 assert!(
2667 policy
2668 .mssql_block_operations
2669 .contains(&MssqlBlockOperation::AfterUpdate)
2670 );
2671 assert!(
2672 policy
2673 .mssql_block_operations
2674 .contains(&MssqlBlockOperation::BeforeDelete)
2675 );
2676 }
2677
2678 #[test]
2679 fn test_parse_policy_full_mssql_config() {
2680 let schema = parse_schema(
2681 r#"
2682 policy TenantIsolation on Order {
2683 for ALL
2684 using "TenantId = @TenantId"
2685 check "TenantId = @TenantId"
2686 mssqlSchema "MultiTenant"
2687 mssqlBlock [AFTER_INSERT, BEFORE_UPDATE, AFTER_UPDATE, BEFORE_DELETE]
2688 }
2689 "#,
2690 )
2691 .unwrap();
2692
2693 let policy = schema.get_policy("TenantIsolation").unwrap();
2694
2695 assert!(policy.applies_to(PolicyCommand::All));
2697 assert!(policy.using_expr.is_some());
2698 assert!(policy.check_expr.is_some());
2699
2700 assert_eq!(policy.mssql_schema(), "MultiTenant");
2702 assert_eq!(policy.mssql_block_operations.len(), 4);
2703
2704 let mssql = policy.to_mssql_sql("dbo.Orders", "TenantId");
2706 assert!(mssql.schema_sql.contains("MultiTenant"));
2707 assert!(mssql.function_sql.contains("fn_TenantIsolation_predicate"));
2708 }
2709
2710 #[test]
2711 fn test_parse_policy_mssql_block_case_variants() {
2712 let schema = parse_schema(
2714 r#"
2715 policy Test1 on User {
2716 for INSERT
2717 check "true"
2718 mssqlBlock after_insert
2719 }
2720 "#,
2721 )
2722 .unwrap();
2723
2724 let policy = schema.get_policy("Test1").unwrap();
2725 assert_eq!(policy.mssql_block_operations.len(), 1);
2726 assert_eq!(
2727 policy.mssql_block_operations[0],
2728 MssqlBlockOperation::AfterInsert
2729 );
2730 }
2731
2732 #[test]
2733 fn test_parse_mixed_postgres_mssql_schema() {
2734 let schema = parse_schema(
2735 r#"
2736 model User {
2737 id Int @id @auto
2738 email String @unique
2739 }
2740
2741 // PostgreSQL-style policy (works on both, MSSQL uses defaults)
2742 policy UserReadOwn on User {
2743 for SELECT
2744 to authenticated
2745 using "id = current_user_id()"
2746 }
2747
2748 // MSSQL-optimized policy with explicit settings
2749 policy UserModifyOwn on User {
2750 for [INSERT, UPDATE, DELETE]
2751 to authenticated
2752 using "id = current_user_id()"
2753 check "id = current_user_id()"
2754 mssqlSchema "Security"
2755 mssqlBlock [AFTER_INSERT, BEFORE_UPDATE, AFTER_UPDATE, BEFORE_DELETE]
2756 }
2757 "#,
2758 )
2759 .unwrap();
2760
2761 assert_eq!(schema.policies.len(), 2);
2762
2763 let read_policy = schema.get_policy("UserReadOwn").unwrap();
2765 assert_eq!(read_policy.mssql_schema(), "Security"); assert!(read_policy.mssql_block_operations.is_empty()); let modify_policy = schema.get_policy("UserModifyOwn").unwrap();
2770 assert_eq!(modify_policy.mssql_schema(), "Security");
2771 assert_eq!(modify_policy.mssql_block_operations.len(), 4);
2772
2773 let pg_sql = read_policy.to_postgres_sql("users");
2775 assert!(pg_sql.contains("CREATE POLICY UserReadOwn ON users"));
2776
2777 let mssql = modify_policy.to_mssql_sql("dbo.Users", "id");
2779 assert!(mssql.policy_sql.contains("Security.UserModifyOwn"));
2780 }
2781}