1mod grammar;
4
5use std::path::Path;
6
7use pest::Parser;
8use smol_str::SmolStr;
9use tracing::{debug, info};
10
11use crate::ast::*;
12use crate::error::{SchemaError, SchemaResult};
13
14pub use grammar::{PraxParser, Rule};
15
16use crate::ast::{
17 MssqlBlockOperation, Policy, PolicyCommand, PolicyType, Server, ServerGroup, ServerProperty,
18 ServerPropertyValue,
19};
20
21pub fn parse_schema(input: &str) -> SchemaResult<Schema> {
23 debug!(input_len = input.len(), "parse_schema() starting");
24 let pairs = PraxParser::parse(Rule::schema, input)
25 .map_err(|e| SchemaError::syntax(input.to_string(), 0, input.len(), e.to_string()))?;
26
27 let mut schema = Schema::new();
28 let mut current_doc: Option<Documentation> = None;
29
30 let schema_pair = pairs.into_iter().next().unwrap();
32
33 for pair in schema_pair.into_inner() {
34 match pair.as_rule() {
35 Rule::documentation => {
36 let span = pair.as_span();
37 let text = pair
38 .into_inner()
39 .map(|p| p.as_str().trim_start_matches("///").trim())
40 .collect::<Vec<_>>()
41 .join("\n");
42 current_doc = Some(Documentation::new(
43 text,
44 Span::new(span.start(), span.end()),
45 ));
46 }
47 Rule::model_def => {
48 let mut model = parse_model(pair)?;
49 if let Some(doc) = current_doc.take() {
50 model = model.with_documentation(doc);
51 }
52 schema.add_model(model);
53 }
54 Rule::enum_def => {
55 let mut e = parse_enum(pair)?;
56 if let Some(doc) = current_doc.take() {
57 e = e.with_documentation(doc);
58 }
59 schema.add_enum(e);
60 }
61 Rule::type_def => {
62 let mut t = parse_composite_type(pair)?;
63 if let Some(doc) = current_doc.take() {
64 t = t.with_documentation(doc);
65 }
66 schema.add_type(t);
67 }
68 Rule::view_def => {
69 let mut v = parse_view(pair)?;
70 if let Some(doc) = current_doc.take() {
71 v = v.with_documentation(doc);
72 }
73 schema.add_view(v);
74 }
75 Rule::raw_sql_def => {
76 let sql = parse_raw_sql(pair)?;
77 schema.add_raw_sql(sql);
78 }
79 Rule::server_group_def => {
80 let mut sg = parse_server_group(pair)?;
81 if let Some(doc) = current_doc.take() {
82 sg.set_documentation(doc);
83 }
84 schema.add_server_group(sg);
85 }
86 Rule::policy_def => {
87 let mut policy = parse_policy(pair)?;
88 if let Some(doc) = current_doc.take() {
89 policy = policy.with_documentation(doc);
90 }
91 schema.add_policy(policy);
92 }
93 Rule::datasource_def => {
94 let ds = parse_datasource(pair)?;
95 schema.set_datasource(ds);
96 current_doc = None;
97 }
98 Rule::generator_def => {
99 let generator = parse_generator(pair)?;
100 schema.add_generator(generator);
101 current_doc = None;
102 }
103 Rule::EOI => {}
104 _ => {}
105 }
106 }
107
108 info!(
109 models = schema.models.len(),
110 enums = schema.enums.len(),
111 types = schema.types.len(),
112 views = schema.views.len(),
113 generators = schema.generators.len(),
114 policies = schema.policies.len(),
115 "Schema parsed successfully"
116 );
117 Ok(schema)
118}
119
120pub fn parse_schema_file(path: impl AsRef<Path>) -> SchemaResult<Schema> {
122 let path = path.as_ref();
123 info!(path = %path.display(), "Loading schema file");
124 let content = std::fs::read_to_string(path).map_err(|e| SchemaError::IoError {
125 path: path.display().to_string(),
126 source: e,
127 })?;
128
129 parse_schema(&content)
130}
131
132fn parse_model(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Model> {
134 let span = pair.as_span();
135 let mut inner = pair.into_inner();
136
137 let name_pair = inner.next().unwrap();
138 let name = Ident::new(
139 name_pair.as_str(),
140 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
141 );
142
143 let mut model = Model::new(name, Span::new(span.start(), span.end()));
144
145 for item in inner {
146 match item.as_rule() {
147 Rule::field_def => {
148 let field = parse_field(item)?;
149 model.add_field(field);
150 }
151 Rule::model_attribute => {
152 let attr = parse_attribute(item)?;
153 model.attributes.push(attr);
154 }
155 Rule::model_body_item => {
156 let inner_item = item.into_inner().next().unwrap();
158 match inner_item.as_rule() {
159 Rule::field_def => {
160 let field = parse_field(inner_item)?;
161 model.add_field(field);
162 }
163 Rule::model_attribute => {
164 let attr = parse_attribute(inner_item)?;
165 model.attributes.push(attr);
166 }
167 _ => {}
168 }
169 }
170 _ => {}
171 }
172 }
173
174 Ok(model)
175}
176
177fn parse_enum(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Enum> {
179 let span = pair.as_span();
180 let mut inner = pair.into_inner();
181
182 let name_pair = inner.next().unwrap();
183 let name = Ident::new(
184 name_pair.as_str(),
185 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
186 );
187
188 let mut e = Enum::new(name, Span::new(span.start(), span.end()));
189
190 for item in inner {
191 match item.as_rule() {
192 Rule::enum_variant => {
193 let variant = parse_enum_variant(item)?;
194 e.add_variant(variant);
195 }
196 Rule::model_attribute => {
197 let attr = parse_attribute(item)?;
198 e.attributes.push(attr);
199 }
200 Rule::enum_body_item => {
201 let inner_item = item.into_inner().next().unwrap();
203 match inner_item.as_rule() {
204 Rule::enum_variant => {
205 let variant = parse_enum_variant(inner_item)?;
206 e.add_variant(variant);
207 }
208 Rule::model_attribute => {
209 let attr = parse_attribute(inner_item)?;
210 e.attributes.push(attr);
211 }
212 _ => {}
213 }
214 }
215 _ => {}
216 }
217 }
218
219 Ok(e)
220}
221
222fn parse_enum_variant(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<EnumVariant> {
224 let span = pair.as_span();
225 let mut inner = pair.into_inner();
226
227 let name_pair = inner.next().unwrap();
228 let name = Ident::new(
229 name_pair.as_str(),
230 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
231 );
232
233 let mut variant = EnumVariant::new(name, Span::new(span.start(), span.end()));
234
235 for item in inner {
236 if item.as_rule() == Rule::field_attribute {
237 let attr = parse_attribute(item)?;
238 variant.attributes.push(attr);
239 }
240 }
241
242 Ok(variant)
243}
244
245fn parse_composite_type(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<CompositeType> {
247 let span = pair.as_span();
248 let mut inner = pair.into_inner();
249
250 let name_pair = inner.next().unwrap();
251 let name = Ident::new(
252 name_pair.as_str(),
253 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
254 );
255
256 let mut t = CompositeType::new(name, Span::new(span.start(), span.end()));
257
258 for item in inner {
259 if item.as_rule() == Rule::field_def {
260 let field = parse_field(item)?;
261 t.add_field(field);
262 }
263 }
264
265 Ok(t)
266}
267
268fn parse_view(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<View> {
270 let span = pair.as_span();
271 let mut inner = pair.into_inner();
272
273 let name_pair = inner.next().unwrap();
274 let name = Ident::new(
275 name_pair.as_str(),
276 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
277 );
278
279 let mut v = View::new(name, Span::new(span.start(), span.end()));
280
281 for item in inner {
282 match item.as_rule() {
283 Rule::field_def => {
284 let field = parse_field(item)?;
285 v.add_field(field);
286 }
287 Rule::model_attribute => {
288 let attr = parse_attribute(item)?;
289 v.attributes.push(attr);
290 }
291 Rule::model_body_item => {
292 let inner_item = item.into_inner().next().unwrap();
294 match inner_item.as_rule() {
295 Rule::field_def => {
296 let field = parse_field(inner_item)?;
297 v.add_field(field);
298 }
299 Rule::model_attribute => {
300 let attr = parse_attribute(inner_item)?;
301 v.attributes.push(attr);
302 }
303 _ => {}
304 }
305 }
306 _ => {}
307 }
308 }
309
310 Ok(v)
311}
312
313fn parse_field(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Field> {
315 let span = pair.as_span();
316 let mut inner = pair.into_inner();
317
318 let name_pair = inner.next().unwrap();
319 let name = Ident::new(
320 name_pair.as_str(),
321 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
322 );
323
324 let type_pair = inner.next().unwrap();
325 let (field_type, modifier) = parse_field_type(type_pair)?;
326
327 let mut attributes = vec![];
328 for item in inner {
329 if item.as_rule() == Rule::field_attribute {
330 let attr = parse_attribute(item)?;
331 attributes.push(attr);
332 }
333 }
334
335 Ok(Field::new(
336 name,
337 field_type,
338 modifier,
339 attributes,
340 Span::new(span.start(), span.end()),
341 ))
342}
343
344fn parse_field_type(
346 pair: pest::iterators::Pair<'_, Rule>,
347) -> SchemaResult<(FieldType, TypeModifier)> {
348 let mut type_name = String::new();
349 let mut modifier = TypeModifier::Required;
350
351 for item in pair.into_inner() {
352 match item.as_rule() {
353 Rule::type_name => {
354 type_name = item.as_str().to_string();
355 }
356 Rule::optional_marker => {
357 modifier = if modifier == TypeModifier::List {
358 TypeModifier::OptionalList
359 } else {
360 TypeModifier::Optional
361 };
362 }
363 Rule::list_marker => {
364 modifier = if modifier == TypeModifier::Optional {
365 TypeModifier::OptionalList
366 } else {
367 TypeModifier::List
368 };
369 }
370 _ => {}
371 }
372 }
373
374 let field_type = if let Some(scalar) = ScalarType::from_str(&type_name) {
375 FieldType::Scalar(scalar)
376 } else {
377 FieldType::Model(SmolStr::new(&type_name))
380 };
381
382 Ok((field_type, modifier))
383}
384
385fn parse_attribute(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Attribute> {
387 let span = pair.as_span();
388 let mut inner = pair.into_inner();
389
390 let name_pair = inner.next().unwrap();
391 let name = Ident::new(
392 name_pair.as_str(),
393 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
394 );
395
396 let mut args = vec![];
397 for item in inner {
398 if item.as_rule() == Rule::attribute_args {
399 args = parse_attribute_args(item)?;
400 }
401 }
402
403 Ok(Attribute::new(
404 name,
405 args,
406 Span::new(span.start(), span.end()),
407 ))
408}
409
410fn parse_attribute_args(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Vec<AttributeArg>> {
412 let mut args = vec![];
413
414 for item in pair.into_inner() {
415 if item.as_rule() == Rule::attribute_arg {
416 let arg = parse_attribute_arg(item)?;
417 args.push(arg);
418 }
419 }
420
421 Ok(args)
422}
423
424fn parse_attribute_arg(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<AttributeArg> {
426 let span = pair.as_span();
427 let mut inner = pair.into_inner();
428
429 let first = inner.next().unwrap();
430
431 if let Some(second) = inner.next() {
433 let name = Ident::new(
435 first.as_str(),
436 Span::new(first.as_span().start(), first.as_span().end()),
437 );
438 let value = parse_attribute_value(second)?;
439 Ok(AttributeArg::named(
440 name,
441 value,
442 Span::new(span.start(), span.end()),
443 ))
444 } else {
445 let value = parse_attribute_value(first)?;
447 Ok(AttributeArg::positional(
448 value,
449 Span::new(span.start(), span.end()),
450 ))
451 }
452}
453
454fn parse_attribute_value(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<AttributeValue> {
456 match pair.as_rule() {
457 Rule::string_literal => {
458 let s = pair.as_str();
459 let unquoted = &s[1..s.len() - 1];
461 Ok(AttributeValue::String(unquoted.to_string()))
462 }
463 Rule::number_literal => {
464 let s = pair.as_str();
465 if s.contains('.') {
466 Ok(AttributeValue::Float(s.parse().unwrap()))
467 } else {
468 Ok(AttributeValue::Int(s.parse().unwrap()))
469 }
470 }
471 Rule::boolean_literal => Ok(AttributeValue::Boolean(pair.as_str() == "true")),
472 Rule::identifier => Ok(AttributeValue::Ident(SmolStr::new(pair.as_str()))),
473 Rule::function_call => {
474 let mut inner = pair.into_inner();
475 let name = SmolStr::new(inner.next().unwrap().as_str());
476 let mut args = vec![];
477 for item in inner {
478 args.push(parse_attribute_value(item)?);
479 }
480 Ok(AttributeValue::Function(name, args))
481 }
482 Rule::field_ref_list => {
483 let refs: Vec<SmolStr> = pair
484 .into_inner()
485 .map(|p| SmolStr::new(p.as_str()))
486 .collect();
487 Ok(AttributeValue::FieldRefList(refs))
488 }
489 Rule::array_literal => {
490 let values: Result<Vec<_>, _> = pair.into_inner().map(parse_attribute_value).collect();
491 Ok(AttributeValue::Array(values?))
492 }
493 Rule::attribute_value => {
494 parse_attribute_value(pair.into_inner().next().unwrap())
496 }
497 _ => {
498 Ok(AttributeValue::Ident(SmolStr::new(pair.as_str())))
500 }
501 }
502}
503
504fn parse_raw_sql(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<RawSql> {
506 let mut inner = pair.into_inner();
507
508 let name = inner.next().unwrap().as_str();
509 let sql = inner.next().unwrap().as_str();
510
511 let sql_content = sql
513 .trim_start_matches("\"\"\"")
514 .trim_end_matches("\"\"\"")
515 .trim();
516
517 Ok(RawSql::new(name, sql_content))
518}
519
520fn parse_server_group(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<ServerGroup> {
522 let span = pair.as_span();
523 let mut inner = pair.into_inner();
524
525 let name_pair = inner.next().unwrap();
526 let name = Ident::new(
527 name_pair.as_str(),
528 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
529 );
530
531 let mut server_group = ServerGroup::new(name, Span::new(span.start(), span.end()));
532
533 for item in inner {
534 match item.as_rule() {
535 Rule::server_group_item => {
536 let inner_item = item.into_inner().next().unwrap();
538 match inner_item.as_rule() {
539 Rule::server_def => {
540 let server = parse_server(inner_item)?;
541 server_group.add_server(server);
542 }
543 Rule::model_attribute => {
544 let attr = parse_attribute(inner_item)?;
545 server_group.add_attribute(attr);
546 }
547 _ => {}
548 }
549 }
550 Rule::server_def => {
551 let server = parse_server(item)?;
552 server_group.add_server(server);
553 }
554 Rule::model_attribute => {
555 let attr = parse_attribute(item)?;
556 server_group.add_attribute(attr);
557 }
558 _ => {}
559 }
560 }
561
562 Ok(server_group)
563}
564
565fn parse_server(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Server> {
567 let span = pair.as_span();
568 let mut inner = pair.into_inner();
569
570 let name_pair = inner.next().unwrap();
571 let name = Ident::new(
572 name_pair.as_str(),
573 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
574 );
575
576 let mut server = Server::new(name, Span::new(span.start(), span.end()));
577
578 for item in inner {
579 if item.as_rule() == Rule::server_property {
580 let prop = parse_server_property(item)?;
581 server.add_property(prop);
582 }
583 }
584
585 Ok(server)
586}
587
588fn parse_server_property(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<ServerProperty> {
590 let span = pair.as_span();
591 let mut inner = pair.into_inner();
592
593 let key_pair = inner.next().unwrap();
594 let key = key_pair.as_str();
595
596 let value_pair = inner.next().unwrap();
597 let value = parse_server_property_value(value_pair)?;
598
599 Ok(ServerProperty::new(
600 key,
601 value,
602 Span::new(span.start(), span.end()),
603 ))
604}
605
606fn parse_generator(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Generator> {
608 let span = pair.as_span();
609 let mut inner = pair.into_inner();
610
611 let name = inner.next().unwrap().as_str();
612 let mut generator = Generator::new(name, Span::new(span.start(), span.end()));
613
614 for prop in inner {
615 if prop.as_rule() == Rule::datasource_property {
616 let mut prop_inner = prop.into_inner();
617 let key = prop_inner.next().unwrap().as_str();
618 let value_pair = prop_inner.next().unwrap();
619
620 match key {
621 "provider" => {
622 let s = extract_datasource_string(&value_pair);
623 generator.provider = Some(SmolStr::new(s));
624 }
625 "output" => {
626 let s = extract_datasource_string(&value_pair);
627 generator.output = Some(SmolStr::new(s));
628 }
629 "generate" => {
630 generator.generate = parse_generator_toggle(&value_pair);
631 }
632 _ => {
633 let val = parse_generator_value(&value_pair);
634 generator.properties.insert(SmolStr::new(key), val);
635 }
636 }
637 }
638 }
639
640 Ok(generator)
641}
642
643fn parse_generator_toggle(pair: &pest::iterators::Pair<'_, Rule>) -> GeneratorToggle {
645 match pair.as_rule() {
646 Rule::env_function => {
647 let env_var = pair
648 .clone()
649 .into_inner()
650 .next()
651 .map(|p| {
652 let s = p.as_str();
653 SmolStr::new(&s[1..s.len() - 1])
654 })
655 .unwrap_or_default();
656 GeneratorToggle::Env(env_var)
657 }
658 Rule::datasource_value => {
659 let inner = pair.clone().into_inner().next().unwrap();
660 parse_generator_toggle(&inner)
661 }
662 _ => {
663 let s = pair.as_str().trim().trim_matches('"');
664 match s {
665 "true" => GeneratorToggle::Literal(true),
666 "false" => GeneratorToggle::Literal(false),
667 _ => GeneratorToggle::Literal(false),
668 }
669 }
670 }
671}
672
673fn parse_generator_value(pair: &pest::iterators::Pair<'_, Rule>) -> GeneratorValue {
675 match pair.as_rule() {
676 Rule::env_function => {
677 let env_var = pair
678 .clone()
679 .into_inner()
680 .next()
681 .map(|p| {
682 let s = p.as_str();
683 SmolStr::new(&s[1..s.len() - 1])
684 })
685 .unwrap_or_default();
686 GeneratorValue::Env(env_var)
687 }
688 Rule::datasource_value => {
689 let inner = pair.clone().into_inner().next().unwrap();
690 parse_generator_value(&inner)
691 }
692 Rule::string_literal => {
693 let s = pair.as_str();
694 GeneratorValue::String(SmolStr::new(&s[1..s.len() - 1]))
695 }
696 _ => {
697 let s = pair.as_str().trim().trim_matches('"');
698 match s {
699 "true" => GeneratorValue::Bool(true),
700 "false" => GeneratorValue::Bool(false),
701 _ => GeneratorValue::Ident(SmolStr::new(s)),
702 }
703 }
704 }
705}
706
707fn parse_datasource(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Datasource> {
709 let span = pair.as_span();
710 let mut inner = pair.into_inner();
711
712 let name_pair = inner.next().unwrap();
713 let name = name_pair.as_str();
714
715 let mut datasource = Datasource::new(
716 name,
717 DatabaseProvider::PostgreSQL,
718 Span::new(span.start(), span.end()),
719 );
720
721 for prop in inner {
722 if prop.as_rule() == Rule::datasource_property {
723 let mut prop_inner = prop.into_inner();
724 let key = prop_inner.next().unwrap().as_str();
725 let value_pair = prop_inner.next().unwrap();
726
727 match key {
728 "provider" => {
729 let provider_str = extract_datasource_string(&value_pair);
730 if let Some(provider) = DatabaseProvider::from_str(&provider_str) {
731 datasource.provider = provider;
732 }
733 }
734 "url" => {
735 match value_pair.as_rule() {
736 Rule::env_function => {
737 let env_var = value_pair
739 .into_inner()
740 .next()
741 .map(|p| {
742 let s = p.as_str();
743 s[1..s.len() - 1].to_string()
744 })
745 .unwrap_or_default();
746 datasource.url_env = Some(SmolStr::new(env_var));
747 }
748 Rule::string_literal => {
749 let s = value_pair.as_str();
750 let url = &s[1..s.len() - 1];
751 datasource.url = Some(SmolStr::new(url));
752 }
753 _ => {}
754 }
755 }
756 "extensions" => {
757 if value_pair.as_rule() == Rule::extension_array {
758 for ext_item in value_pair.into_inner() {
759 if ext_item.as_rule() == Rule::extension_item {
760 let ext = parse_extension_item(
761 ext_item,
762 Span::new(span.start(), span.end()),
763 )?;
764 datasource.add_extension(ext);
765 }
766 }
767 }
768 }
769 _ => {
770 let value_str = extract_datasource_string(&value_pair);
772 datasource.add_property(key, value_str);
773 }
774 }
775 }
776 }
777
778 Ok(datasource)
779}
780
781fn parse_extension_item(
783 pair: pest::iterators::Pair<'_, Rule>,
784 span: Span,
785) -> SchemaResult<PostgresExtension> {
786 let mut inner = pair.into_inner();
787 let name = inner.next().unwrap().as_str();
788 let mut ext = PostgresExtension::new(name, span);
789
790 if let Some(args_pair) = inner.next() {
792 if args_pair.as_rule() == Rule::extension_args {
793 for arg in args_pair.into_inner() {
794 if arg.as_rule() == Rule::extension_arg {
795 let mut arg_inner = arg.into_inner();
796 let arg_key = arg_inner.next().unwrap().as_str();
797 let arg_value_pair = arg_inner.next().unwrap();
798 let arg_value = {
799 let s = arg_value_pair.as_str();
800 &s[1..s.len() - 1]
801 };
802
803 match arg_key {
804 "schema" => {
805 ext = ext.with_schema(arg_value);
806 }
807 "version" => {
808 ext = ext.with_version(arg_value);
809 }
810 _ => {}
811 }
812 }
813 }
814 }
815 }
816
817 Ok(ext)
818}
819
820fn extract_datasource_string(pair: &pest::iterators::Pair<'_, Rule>) -> String {
822 match pair.as_rule() {
823 Rule::string_literal => {
824 let s = pair.as_str();
825 s[1..s.len() - 1].to_string()
826 }
827 Rule::identifier => pair.as_str().to_string(),
828 Rule::datasource_value => {
829 if let Some(inner) = pair.clone().into_inner().next() {
830 extract_datasource_string(&inner)
831 } else {
832 pair.as_str().to_string()
833 }
834 }
835 _ => pair.as_str().to_string(),
836 }
837}
838
839fn extract_string_from_arg(pair: pest::iterators::Pair<'_, Rule>) -> String {
841 match pair.as_rule() {
842 Rule::string_literal => {
843 let s = pair.as_str();
844 s[1..s.len() - 1].to_string()
845 }
846 Rule::attribute_value => {
847 if let Some(inner) = pair.into_inner().next() {
849 extract_string_from_arg(inner)
850 } else {
851 String::new()
852 }
853 }
854 _ => pair.as_str().to_string(),
855 }
856}
857
858fn parse_server_property_value(
860 pair: pest::iterators::Pair<'_, Rule>,
861) -> SchemaResult<ServerPropertyValue> {
862 match pair.as_rule() {
863 Rule::string_literal => {
864 let s = pair.as_str();
865 let unquoted = &s[1..s.len() - 1];
867 Ok(ServerPropertyValue::String(unquoted.to_string()))
868 }
869 Rule::number_literal => {
870 let s = pair.as_str();
871 Ok(ServerPropertyValue::Number(s.parse().unwrap_or(0.0)))
872 }
873 Rule::boolean_literal => Ok(ServerPropertyValue::Boolean(pair.as_str() == "true")),
874 Rule::identifier => Ok(ServerPropertyValue::Identifier(pair.as_str().to_string())),
875 Rule::function_call => {
876 let mut inner = pair.into_inner();
878 let func_name = inner.next().unwrap().as_str();
879 if func_name == "env" {
880 if let Some(arg) = inner.next() {
881 let var_name = extract_string_from_arg(arg);
882 return Ok(ServerPropertyValue::EnvVar(var_name));
883 }
884 }
885 Ok(ServerPropertyValue::Identifier(func_name.to_string()))
887 }
888 Rule::array_literal => {
889 let values: Result<Vec<_>, _> =
890 pair.into_inner().map(parse_server_property_value).collect();
891 Ok(ServerPropertyValue::Array(values?))
892 }
893 Rule::attribute_value => {
894 parse_server_property_value(pair.into_inner().next().unwrap())
896 }
897 _ => {
898 Ok(ServerPropertyValue::Identifier(pair.as_str().to_string()))
900 }
901 }
902}
903
904fn parse_policy(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Policy> {
906 let span = pair.as_span();
907 let mut inner = pair.into_inner();
908
909 let name_pair = inner.next().unwrap();
911 let name = Ident::new(
912 name_pair.as_str(),
913 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
914 );
915
916 let table_pair = inner.next().unwrap();
918 let table = Ident::new(
919 table_pair.as_str(),
920 Span::new(table_pair.as_span().start(), table_pair.as_span().end()),
921 );
922
923 let mut policy = Policy::new(name, table, Span::new(span.start(), span.end()));
924 policy.commands = vec![];
926
927 for item in inner {
928 match item.as_rule() {
929 Rule::policy_item => {
930 let inner_item = item.into_inner().next().unwrap();
931 parse_policy_item(&mut policy, inner_item)?;
932 }
933 Rule::policy_for
934 | Rule::policy_to
935 | Rule::policy_as
936 | Rule::policy_using
937 | Rule::policy_check => {
938 parse_policy_item(&mut policy, item)?;
939 }
940 _ => {}
941 }
942 }
943
944 if policy.commands.is_empty() {
946 policy.commands.push(PolicyCommand::All);
947 }
948
949 Ok(policy)
950}
951
952fn parse_policy_item(
954 policy: &mut Policy,
955 pair: pest::iterators::Pair<'_, Rule>,
956) -> SchemaResult<()> {
957 match pair.as_rule() {
958 Rule::policy_for => {
959 let inner = pair.into_inner().next().unwrap();
960 match inner.as_rule() {
961 Rule::policy_command => {
962 if let Some(cmd) = PolicyCommand::from_str(inner.as_str()) {
963 policy.add_command(cmd);
964 }
965 }
966 Rule::policy_command_list => {
967 for cmd_pair in inner.into_inner() {
968 if cmd_pair.as_rule() == Rule::policy_command {
969 if let Some(cmd) = PolicyCommand::from_str(cmd_pair.as_str()) {
970 policy.add_command(cmd);
971 }
972 }
973 }
974 }
975 _ => {}
976 }
977 }
978 Rule::policy_to => {
979 let inner = pair.into_inner().next().unwrap();
980 match inner.as_rule() {
981 Rule::identifier => {
982 policy.add_role(inner.as_str());
983 }
984 Rule::policy_role_list => {
985 for role_pair in inner.into_inner() {
986 if role_pair.as_rule() == Rule::identifier {
987 policy.add_role(role_pair.as_str());
988 }
989 }
990 }
991 _ => {}
992 }
993 }
994 Rule::policy_as => {
995 let inner = pair.into_inner().next().unwrap();
996 if inner.as_rule() == Rule::policy_type {
997 if let Some(policy_type) = PolicyType::from_str(inner.as_str()) {
998 policy.policy_type = policy_type;
999 }
1000 }
1001 }
1002 Rule::policy_using => {
1003 let inner = pair.into_inner().next().unwrap();
1004 let expr = extract_policy_expression(&inner);
1005 policy.using_expr = Some(expr);
1006 }
1007 Rule::policy_check => {
1008 let inner = pair.into_inner().next().unwrap();
1009 let expr = extract_policy_expression(&inner);
1010 policy.check_expr = Some(expr);
1011 }
1012 Rule::policy_mssql_schema => {
1013 let inner = pair.into_inner().next().unwrap();
1014 if inner.as_rule() == Rule::string_literal {
1015 let s = inner.as_str();
1016 let schema = &s[1..s.len() - 1]; policy.mssql_schema = Some(SmolStr::new(schema));
1018 }
1019 }
1020 Rule::policy_mssql_block => {
1021 let inner = pair.into_inner().next().unwrap();
1022 match inner.as_rule() {
1023 Rule::mssql_block_op => {
1024 if let Some(op) = MssqlBlockOperation::from_str(inner.as_str()) {
1025 policy.add_mssql_block_operation(op);
1026 }
1027 }
1028 Rule::mssql_block_op_list => {
1029 for op_pair in inner.into_inner() {
1030 if op_pair.as_rule() == Rule::mssql_block_op {
1031 if let Some(op) = MssqlBlockOperation::from_str(op_pair.as_str()) {
1032 policy.add_mssql_block_operation(op);
1033 }
1034 }
1035 }
1036 }
1037 _ => {}
1038 }
1039 }
1040 _ => {}
1041 }
1042 Ok(())
1043}
1044
1045fn extract_policy_expression(pair: &pest::iterators::Pair<'_, Rule>) -> String {
1047 let s = pair.as_str();
1048 match pair.as_rule() {
1049 Rule::multiline_string => {
1050 s.trim_start_matches("\"\"\"")
1052 .trim_end_matches("\"\"\"")
1053 .trim()
1054 .to_string()
1055 }
1056 Rule::string_literal => {
1057 s[1..s.len() - 1].to_string()
1059 }
1060 _ => s.to_string(),
1061 }
1062}
1063
1064#[cfg(test)]
1065mod tests {
1066 use super::*;
1067
1068 #[test]
1071 fn test_parse_simple_model() {
1072 let schema = parse_schema(
1073 r#"
1074 model User {
1075 id Int @id @auto
1076 email String @unique
1077 name String?
1078 }
1079 "#,
1080 )
1081 .unwrap();
1082
1083 assert_eq!(schema.models.len(), 1);
1084 let user = schema.get_model("User").unwrap();
1085 assert_eq!(user.fields.len(), 3);
1086 assert!(user.get_field("id").unwrap().is_id());
1087 assert!(user.get_field("email").unwrap().is_unique());
1088 assert!(user.get_field("name").unwrap().is_optional());
1089 }
1090
1091 #[test]
1092 fn test_parse_model_name() {
1093 let schema = parse_schema(
1094 r#"
1095 model BlogPost {
1096 id Int @id
1097 }
1098 "#,
1099 )
1100 .unwrap();
1101
1102 assert!(schema.get_model("BlogPost").is_some());
1103 }
1104
1105 #[test]
1106 fn test_parse_multiple_models() {
1107 let schema = parse_schema(
1108 r#"
1109 model User {
1110 id Int @id
1111 }
1112
1113 model Post {
1114 id Int @id
1115 }
1116
1117 model Comment {
1118 id Int @id
1119 }
1120 "#,
1121 )
1122 .unwrap();
1123
1124 assert_eq!(schema.models.len(), 3);
1125 assert!(schema.get_model("User").is_some());
1126 assert!(schema.get_model("Post").is_some());
1127 assert!(schema.get_model("Comment").is_some());
1128 }
1129
1130 #[test]
1133 fn test_parse_all_scalar_types() {
1134 let schema = parse_schema(
1135 r#"
1136 model AllTypes {
1137 id Int @id
1138 big BigInt
1139 float_f Float
1140 decimal Decimal
1141 str String
1142 bool Boolean
1143 datetime DateTime
1144 date Date
1145 time Time
1146 json Json
1147 bytes Bytes
1148 uuid Uuid
1149 cuid Cuid
1150 cuid2 Cuid2
1151 nanoid NanoId
1152 ulid Ulid
1153 }
1154 "#,
1155 )
1156 .unwrap();
1157
1158 let model = schema.get_model("AllTypes").unwrap();
1159 assert_eq!(model.fields.len(), 16);
1160
1161 assert!(matches!(
1162 model.get_field("id").unwrap().field_type,
1163 FieldType::Scalar(ScalarType::Int)
1164 ));
1165 assert!(matches!(
1166 model.get_field("big").unwrap().field_type,
1167 FieldType::Scalar(ScalarType::BigInt)
1168 ));
1169 assert!(matches!(
1170 model.get_field("str").unwrap().field_type,
1171 FieldType::Scalar(ScalarType::String)
1172 ));
1173 assert!(matches!(
1174 model.get_field("bool").unwrap().field_type,
1175 FieldType::Scalar(ScalarType::Boolean)
1176 ));
1177 assert!(matches!(
1178 model.get_field("datetime").unwrap().field_type,
1179 FieldType::Scalar(ScalarType::DateTime)
1180 ));
1181 assert!(matches!(
1182 model.get_field("uuid").unwrap().field_type,
1183 FieldType::Scalar(ScalarType::Uuid)
1184 ));
1185 assert!(matches!(
1186 model.get_field("cuid").unwrap().field_type,
1187 FieldType::Scalar(ScalarType::Cuid)
1188 ));
1189 assert!(matches!(
1190 model.get_field("cuid2").unwrap().field_type,
1191 FieldType::Scalar(ScalarType::Cuid2)
1192 ));
1193 assert!(matches!(
1194 model.get_field("nanoid").unwrap().field_type,
1195 FieldType::Scalar(ScalarType::NanoId)
1196 ));
1197 assert!(matches!(
1198 model.get_field("ulid").unwrap().field_type,
1199 FieldType::Scalar(ScalarType::Ulid)
1200 ));
1201 }
1202
1203 #[test]
1204 fn test_parse_optional_field() {
1205 let schema = parse_schema(
1206 r#"
1207 model User {
1208 id Int @id
1209 bio String?
1210 age Int?
1211 }
1212 "#,
1213 )
1214 .unwrap();
1215
1216 let user = schema.get_model("User").unwrap();
1217 assert!(!user.get_field("id").unwrap().is_optional());
1218 assert!(user.get_field("bio").unwrap().is_optional());
1219 assert!(user.get_field("age").unwrap().is_optional());
1220 }
1221
1222 #[test]
1223 fn test_parse_list_field() {
1224 let schema = parse_schema(
1225 r#"
1226 model User {
1227 id Int @id
1228 tags String[]
1229 posts Post[]
1230 }
1231 "#,
1232 )
1233 .unwrap();
1234
1235 let user = schema.get_model("User").unwrap();
1236 assert!(user.get_field("tags").unwrap().is_list());
1237 assert!(user.get_field("posts").unwrap().is_list());
1238 }
1239
1240 #[test]
1241 fn test_parse_optional_list_field() {
1242 let schema = parse_schema(
1243 r#"
1244 model User {
1245 id Int @id
1246 metadata String[]?
1247 }
1248 "#,
1249 )
1250 .unwrap();
1251
1252 let user = schema.get_model("User").unwrap();
1253 let metadata = user.get_field("metadata").unwrap();
1254 assert!(metadata.is_list());
1255 assert!(metadata.is_optional());
1256 }
1257
1258 #[test]
1261 fn test_parse_id_attribute() {
1262 let schema = parse_schema(
1263 r#"
1264 model User {
1265 id Int @id
1266 }
1267 "#,
1268 )
1269 .unwrap();
1270
1271 let user = schema.get_model("User").unwrap();
1272 assert!(user.get_field("id").unwrap().is_id());
1273 }
1274
1275 #[test]
1276 fn test_parse_unique_attribute() {
1277 let schema = parse_schema(
1278 r#"
1279 model User {
1280 id Int @id
1281 email String @unique
1282 }
1283 "#,
1284 )
1285 .unwrap();
1286
1287 let user = schema.get_model("User").unwrap();
1288 assert!(user.get_field("email").unwrap().is_unique());
1289 }
1290
1291 #[test]
1292 fn test_parse_default_int() {
1293 let schema = parse_schema(
1294 r#"
1295 model Counter {
1296 id Int @id
1297 count Int @default(0)
1298 }
1299 "#,
1300 )
1301 .unwrap();
1302
1303 let counter = schema.get_model("Counter").unwrap();
1304 let count_field = counter.get_field("count").unwrap();
1305 let attrs = count_field.extract_attributes();
1306 assert!(attrs.default.is_some());
1307 assert_eq!(attrs.default.unwrap().as_int(), Some(0));
1308 }
1309
1310 #[test]
1311 fn test_parse_default_string() {
1312 let schema = parse_schema(
1313 r#"
1314 model User {
1315 id Int @id
1316 status String @default("active")
1317 }
1318 "#,
1319 )
1320 .unwrap();
1321
1322 let user = schema.get_model("User").unwrap();
1323 let status = user.get_field("status").unwrap();
1324 let attrs = status.extract_attributes();
1325 assert!(attrs.default.is_some());
1326 assert_eq!(attrs.default.unwrap().as_string(), Some("active"));
1327 }
1328
1329 #[test]
1330 fn test_parse_default_boolean() {
1331 let schema = parse_schema(
1332 r#"
1333 model Post {
1334 id Int @id
1335 published Boolean @default(false)
1336 }
1337 "#,
1338 )
1339 .unwrap();
1340
1341 let post = schema.get_model("Post").unwrap();
1342 let published = post.get_field("published").unwrap();
1343 let attrs = published.extract_attributes();
1344 assert!(attrs.default.is_some());
1345 assert_eq!(attrs.default.unwrap().as_bool(), Some(false));
1346 }
1347
1348 #[test]
1349 fn test_parse_default_function() {
1350 let schema = parse_schema(
1351 r#"
1352 model User {
1353 id Int @id
1354 createdAt DateTime @default(now())
1355 }
1356 "#,
1357 )
1358 .unwrap();
1359
1360 let user = schema.get_model("User").unwrap();
1361 let created_at = user.get_field("createdAt").unwrap();
1362 let attrs = created_at.extract_attributes();
1363 assert!(attrs.default.is_some());
1364 if let Some(AttributeValue::Function(name, _)) = attrs.default {
1365 assert_eq!(name.as_str(), "now");
1366 } else {
1367 panic!("Expected function default");
1368 }
1369 }
1370
1371 #[test]
1372 fn test_parse_updated_at_attribute() {
1373 let schema = parse_schema(
1374 r#"
1375 model User {
1376 id Int @id
1377 updatedAt DateTime @updated_at
1378 }
1379 "#,
1380 )
1381 .unwrap();
1382
1383 let user = schema.get_model("User").unwrap();
1384 let updated_at = user.get_field("updatedAt").unwrap();
1385 let attrs = updated_at.extract_attributes();
1386 assert!(attrs.is_updated_at);
1387 }
1388
1389 #[test]
1390 fn test_parse_map_attribute() {
1391 let schema = parse_schema(
1392 r#"
1393 model User {
1394 id Int @id
1395 email String @map("email_address")
1396 }
1397 "#,
1398 )
1399 .unwrap();
1400
1401 let user = schema.get_model("User").unwrap();
1402 let email = user.get_field("email").unwrap();
1403 let attrs = email.extract_attributes();
1404 assert_eq!(attrs.map, Some("email_address".to_string()));
1405 }
1406
1407 #[test]
1408 fn test_parse_multiple_attributes() {
1409 let schema = parse_schema(
1410 r#"
1411 model User {
1412 id Int @id @auto
1413 email String @unique @index
1414 }
1415 "#,
1416 )
1417 .unwrap();
1418
1419 let user = schema.get_model("User").unwrap();
1420 let id = user.get_field("id").unwrap();
1421 let email = user.get_field("email").unwrap();
1422
1423 let id_attrs = id.extract_attributes();
1424 assert!(id_attrs.is_id);
1425 assert!(id_attrs.is_auto);
1426
1427 let email_attrs = email.extract_attributes();
1428 assert!(email_attrs.is_unique);
1429 assert!(email_attrs.is_indexed);
1430 }
1431
1432 #[test]
1435 fn test_parse_model_map_attribute() {
1436 let schema = parse_schema(
1437 r#"
1438 model User {
1439 id Int @id
1440
1441 @@map("app_users")
1442 }
1443 "#,
1444 )
1445 .unwrap();
1446
1447 let user = schema.get_model("User").unwrap();
1448 assert_eq!(user.table_name(), "app_users");
1449 }
1450
1451 #[test]
1452 fn test_parse_model_index_attribute() {
1453 let schema = parse_schema(
1454 r#"
1455 model User {
1456 id Int @id
1457 email String
1458 name String
1459
1460 @@index([email, name])
1461 }
1462 "#,
1463 )
1464 .unwrap();
1465
1466 let user = schema.get_model("User").unwrap();
1467 assert!(user.has_attribute("index"));
1468 }
1469
1470 #[test]
1471 fn test_parse_composite_primary_key() {
1472 let schema = parse_schema(
1473 r#"
1474 model PostTag {
1475 postId Int
1476 tagId Int
1477
1478 @@id([postId, tagId])
1479 }
1480 "#,
1481 )
1482 .unwrap();
1483
1484 let post_tag = schema.get_model("PostTag").unwrap();
1485 assert!(post_tag.has_attribute("id"));
1486 }
1487
1488 #[test]
1491 fn test_parse_enum() {
1492 let schema = parse_schema(
1493 r#"
1494 enum Role {
1495 User
1496 Admin
1497 Moderator
1498 }
1499 "#,
1500 )
1501 .unwrap();
1502
1503 assert_eq!(schema.enums.len(), 1);
1504 let role = schema.get_enum("Role").unwrap();
1505 assert_eq!(role.variants.len(), 3);
1506 }
1507
1508 #[test]
1509 fn test_parse_enum_variant_names() {
1510 let schema = parse_schema(
1511 r#"
1512 enum Status {
1513 Pending
1514 Active
1515 Completed
1516 Cancelled
1517 }
1518 "#,
1519 )
1520 .unwrap();
1521
1522 let status = schema.get_enum("Status").unwrap();
1523 assert!(status.get_variant("Pending").is_some());
1524 assert!(status.get_variant("Active").is_some());
1525 assert!(status.get_variant("Completed").is_some());
1526 assert!(status.get_variant("Cancelled").is_some());
1527 }
1528
1529 #[test]
1530 fn test_parse_enum_with_map() {
1531 let schema = parse_schema(
1532 r#"
1533 enum Role {
1534 User @map("USER")
1535 Admin @map("ADMINISTRATOR")
1536 }
1537 "#,
1538 )
1539 .unwrap();
1540
1541 let role = schema.get_enum("Role").unwrap();
1542 let user_variant = role.get_variant("User").unwrap();
1543 assert_eq!(user_variant.db_value(), "USER");
1544
1545 let admin_variant = role.get_variant("Admin").unwrap();
1546 assert_eq!(admin_variant.db_value(), "ADMINISTRATOR");
1547 }
1548
1549 #[test]
1552 fn test_parse_one_to_many_relation() {
1553 let schema = parse_schema(
1554 r#"
1555 model User {
1556 id Int @id
1557 posts Post[]
1558 }
1559
1560 model Post {
1561 id Int @id
1562 authorId Int
1563 author User @relation(fields: [authorId], references: [id])
1564 }
1565 "#,
1566 )
1567 .unwrap();
1568
1569 let user = schema.get_model("User").unwrap();
1570 let post = schema.get_model("Post").unwrap();
1571
1572 assert!(user.get_field("posts").unwrap().is_list());
1573 assert!(post.get_field("author").unwrap().is_relation());
1574 }
1575
1576 #[test]
1577 fn test_parse_relation_with_actions() {
1578 let schema = parse_schema(
1579 r#"
1580 model Post {
1581 id Int @id
1582 authorId Int
1583 author User @relation(fields: [authorId], references: [id], onDelete: Cascade, onUpdate: Restrict)
1584 }
1585
1586 model User {
1587 id Int @id
1588 posts Post[]
1589 }
1590 "#,
1591 )
1592 .unwrap();
1593
1594 let post = schema.get_model("Post").unwrap();
1595 let author = post.get_field("author").unwrap();
1596 let attrs = author.extract_attributes();
1597
1598 assert!(attrs.relation.is_some());
1599 let rel = attrs.relation.unwrap();
1600 assert_eq!(rel.on_delete, Some(ReferentialAction::Cascade));
1601 assert_eq!(rel.on_update, Some(ReferentialAction::Restrict));
1602 }
1603
1604 #[test]
1607 fn test_parse_model_documentation() {
1608 let schema = parse_schema(
1609 r#"/// Represents a user in the system
1610model User {
1611 id Int @id
1612}"#,
1613 )
1614 .unwrap();
1615
1616 let user = schema.get_model("User").unwrap();
1617 if let Some(doc) = &user.documentation {
1620 assert!(doc.text.contains("user"));
1621 }
1622 }
1623
1624 #[test]
1627 fn test_parse_complete_schema() {
1628 let schema = parse_schema(
1629 r#"
1630 /// User model
1631 model User {
1632 id Int @id @auto
1633 email String @unique
1634 name String?
1635 role Role @default(User)
1636 posts Post[]
1637 profile Profile?
1638 createdAt DateTime @default(now())
1639 updatedAt DateTime @updated_at
1640
1641 @@map("users")
1642 @@index([email])
1643 }
1644
1645 model Post {
1646 id Int @id @auto
1647 title String
1648 content String?
1649 published Boolean @default(false)
1650 authorId Int
1651 author User @relation(fields: [authorId], references: [id])
1652 tags Tag[]
1653 createdAt DateTime @default(now())
1654
1655 @@index([authorId])
1656 }
1657
1658 model Profile {
1659 id Int @id @auto
1660 bio String?
1661 userId Int @unique
1662 user User @relation(fields: [userId], references: [id])
1663 }
1664
1665 model Tag {
1666 id Int @id @auto
1667 name String @unique
1668 posts Post[]
1669 }
1670
1671 enum Role {
1672 User
1673 Admin
1674 Moderator
1675 }
1676 "#,
1677 )
1678 .unwrap();
1679
1680 assert_eq!(schema.models.len(), 4);
1682 assert!(schema.get_model("User").is_some());
1683 assert!(schema.get_model("Post").is_some());
1684 assert!(schema.get_model("Profile").is_some());
1685 assert!(schema.get_model("Tag").is_some());
1686
1687 assert_eq!(schema.enums.len(), 1);
1689 assert!(schema.get_enum("Role").is_some());
1690
1691 let user = schema.get_model("User").unwrap();
1693 assert_eq!(user.table_name(), "users");
1694 assert_eq!(user.fields.len(), 8);
1695 assert!(user.has_attribute("index"));
1696
1697 let post = schema.get_model("Post").unwrap();
1699 assert!(post.get_field("author").unwrap().is_relation());
1700 }
1701
1702 #[test]
1705 fn test_parse_invalid_syntax() {
1706 let result = parse_schema("model { broken }");
1707 assert!(result.is_err());
1708 }
1709
1710 #[test]
1711 fn test_parse_empty_schema() {
1712 let schema = parse_schema("").unwrap();
1713 assert!(schema.models.is_empty());
1714 assert!(schema.enums.is_empty());
1715 }
1716
1717 #[test]
1718 fn test_parse_whitespace_only() {
1719 let schema = parse_schema(" \n\t \n ").unwrap();
1720 assert!(schema.models.is_empty());
1721 }
1722
1723 #[test]
1724 fn test_parse_comments_only() {
1725 let schema = parse_schema(
1726 r#"
1727 // This is a comment
1728 // Another comment
1729 "#,
1730 )
1731 .unwrap();
1732 assert!(schema.models.is_empty());
1733 }
1734
1735 #[test]
1738 fn test_parse_model_with_no_fields() {
1739 let result = parse_schema(
1741 r#"
1742 model Empty {
1743 }
1744 "#,
1745 );
1746 let _ = result;
1748 }
1749
1750 #[test]
1751 fn test_parse_long_identifier() {
1752 let schema = parse_schema(
1753 r#"
1754 model VeryLongModelNameThatIsStillValid {
1755 someVeryLongFieldNameThatShouldWork Int @id
1756 }
1757 "#,
1758 )
1759 .unwrap();
1760
1761 assert!(
1762 schema
1763 .get_model("VeryLongModelNameThatIsStillValid")
1764 .is_some()
1765 );
1766 }
1767
1768 #[test]
1769 fn test_parse_underscore_identifiers() {
1770 let schema = parse_schema(
1771 r#"
1772 model user_account {
1773 user_id Int @id
1774 created_at DateTime
1775 }
1776 "#,
1777 )
1778 .unwrap();
1779
1780 let model = schema.get_model("user_account").unwrap();
1781 assert!(model.get_field("user_id").is_some());
1782 assert!(model.get_field("created_at").is_some());
1783 }
1784
1785 #[test]
1786 fn test_parse_negative_default() {
1787 let schema = parse_schema(
1788 r#"
1789 model Config {
1790 id Int @id
1791 minValue Int @default(-100)
1792 }
1793 "#,
1794 )
1795 .unwrap();
1796
1797 let config = schema.get_model("Config").unwrap();
1798 let min_value = config.get_field("minValue").unwrap();
1799 let attrs = min_value.extract_attributes();
1800 assert!(attrs.default.is_some());
1801 }
1802
1803 #[test]
1804 fn test_parse_float_default() {
1805 let schema = parse_schema(
1806 r#"
1807 model Product {
1808 id Int @id
1809 price Float @default(9.99)
1810 }
1811 "#,
1812 )
1813 .unwrap();
1814
1815 let product = schema.get_model("Product").unwrap();
1816 let price = product.get_field("price").unwrap();
1817 let attrs = price.extract_attributes();
1818 assert!(attrs.default.is_some());
1819 }
1820
1821 #[test]
1824 fn test_parse_simple_server_group() {
1825 let schema = parse_schema(
1826 r#"
1827 serverGroup MainCluster {
1828 server primary {
1829 url = "postgres://localhost/db"
1830 role = "primary"
1831 }
1832 }
1833 "#,
1834 )
1835 .unwrap();
1836
1837 assert_eq!(schema.server_groups.len(), 1);
1838 let cluster = schema.get_server_group("MainCluster").unwrap();
1839 assert_eq!(cluster.servers.len(), 1);
1840 assert!(cluster.servers.contains_key("primary"));
1841 }
1842
1843 #[test]
1844 fn test_parse_server_group_with_multiple_servers() {
1845 let schema = parse_schema(
1846 r#"
1847 serverGroup ReadReplicas {
1848 server primary {
1849 url = "postgres://primary.db.com/app"
1850 role = "primary"
1851 weight = 1
1852 }
1853
1854 server replica1 {
1855 url = "postgres://replica1.db.com/app"
1856 role = "replica"
1857 weight = 2
1858 }
1859
1860 server replica2 {
1861 url = "postgres://replica2.db.com/app"
1862 role = "replica"
1863 weight = 2
1864 }
1865 }
1866 "#,
1867 )
1868 .unwrap();
1869
1870 let cluster = schema.get_server_group("ReadReplicas").unwrap();
1871 assert_eq!(cluster.servers.len(), 3);
1872
1873 let primary = cluster.servers.get("primary").unwrap();
1874 assert_eq!(primary.role(), Some(ServerRole::Primary));
1875 assert_eq!(primary.weight(), Some(1));
1876
1877 let replica1 = cluster.servers.get("replica1").unwrap();
1878 assert_eq!(replica1.role(), Some(ServerRole::Replica));
1879 assert_eq!(replica1.weight(), Some(2));
1880 }
1881
1882 #[test]
1883 fn test_parse_server_group_with_attributes() {
1884 let schema = parse_schema(
1885 r#"
1886 serverGroup ProductionCluster {
1887 @@strategy(ReadReplica)
1888 @@loadBalance(RoundRobin)
1889
1890 server main {
1891 url = "postgres://main/db"
1892 role = "primary"
1893 }
1894 }
1895 "#,
1896 )
1897 .unwrap();
1898
1899 let cluster = schema.get_server_group("ProductionCluster").unwrap();
1900 assert!(cluster.attributes.iter().any(|a| a.name.name == "strategy"));
1901 assert!(
1902 cluster
1903 .attributes
1904 .iter()
1905 .any(|a| a.name.name == "loadBalance")
1906 );
1907 }
1908
1909 #[test]
1910 fn test_parse_server_group_with_env_vars() {
1911 let schema = parse_schema(
1912 r#"
1913 serverGroup EnvCluster {
1914 server db1 {
1915 url = env("PRIMARY_DB_URL")
1916 role = "primary"
1917 }
1918 }
1919 "#,
1920 )
1921 .unwrap();
1922
1923 let cluster = schema.get_server_group("EnvCluster").unwrap();
1924 let server = cluster.servers.get("db1").unwrap();
1925
1926 if let Some(ServerPropertyValue::EnvVar(var)) = server.get_property("url") {
1928 assert_eq!(var, "PRIMARY_DB_URL");
1929 } else {
1930 panic!("Expected env var for url property");
1931 }
1932 }
1933
1934 #[test]
1935 fn test_parse_server_group_with_boolean_property() {
1936 let schema = parse_schema(
1937 r#"
1938 serverGroup TestCluster {
1939 server replica {
1940 url = "postgres://replica/db"
1941 role = "replica"
1942 readOnly = true
1943 }
1944 }
1945 "#,
1946 )
1947 .unwrap();
1948
1949 let cluster = schema.get_server_group("TestCluster").unwrap();
1950 let server = cluster.servers.get("replica").unwrap();
1951 assert!(server.is_read_only());
1952 }
1953
1954 #[test]
1955 fn test_parse_server_group_with_numeric_properties() {
1956 let schema = parse_schema(
1957 r#"
1958 serverGroup NumericCluster {
1959 server db {
1960 url = "postgres://localhost/db"
1961 weight = 5
1962 priority = 1
1963 maxConnections = 100
1964 }
1965 }
1966 "#,
1967 )
1968 .unwrap();
1969
1970 let cluster = schema.get_server_group("NumericCluster").unwrap();
1971 let server = cluster.servers.get("db").unwrap();
1972
1973 assert_eq!(server.weight(), Some(5));
1974 assert_eq!(server.priority(), Some(1));
1975 assert_eq!(server.max_connections(), Some(100));
1976 }
1977
1978 #[test]
1979 fn test_parse_server_group_with_region() {
1980 let schema = parse_schema(
1981 r#"
1982 serverGroup GeoCluster {
1983 server usEast {
1984 url = "postgres://us-east.db.com/app"
1985 role = "replica"
1986 region = "us-east-1"
1987 }
1988
1989 server usWest {
1990 url = "postgres://us-west.db.com/app"
1991 role = "replica"
1992 region = "us-west-2"
1993 }
1994 }
1995 "#,
1996 )
1997 .unwrap();
1998
1999 let cluster = schema.get_server_group("GeoCluster").unwrap();
2000
2001 let us_east = cluster.servers.get("usEast").unwrap();
2002 assert_eq!(us_east.region(), Some("us-east-1"));
2003
2004 let us_west = cluster.servers.get("usWest").unwrap();
2005 assert_eq!(us_west.region(), Some("us-west-2"));
2006
2007 let us_east_servers = cluster.servers_in_region("us-east-1");
2009 assert_eq!(us_east_servers.len(), 1);
2010 }
2011
2012 #[test]
2013 fn test_parse_multiple_server_groups() {
2014 let schema = parse_schema(
2015 r#"
2016 serverGroup Cluster1 {
2017 server db1 {
2018 url = "postgres://db1/app"
2019 }
2020 }
2021
2022 serverGroup Cluster2 {
2023 server db2 {
2024 url = "postgres://db2/app"
2025 }
2026 }
2027
2028 serverGroup Cluster3 {
2029 server db3 {
2030 url = "postgres://db3/app"
2031 }
2032 }
2033 "#,
2034 )
2035 .unwrap();
2036
2037 assert_eq!(schema.server_groups.len(), 3);
2038 assert!(schema.get_server_group("Cluster1").is_some());
2039 assert!(schema.get_server_group("Cluster2").is_some());
2040 assert!(schema.get_server_group("Cluster3").is_some());
2041 }
2042
2043 #[test]
2044 fn test_parse_schema_with_models_and_server_groups() {
2045 let schema = parse_schema(
2046 r#"
2047 model User {
2048 id Int @id @auto
2049 email String @unique
2050 }
2051
2052 serverGroup Database {
2053 @@strategy(ReadReplica)
2054
2055 server primary {
2056 url = env("DATABASE_URL")
2057 role = "primary"
2058 }
2059 }
2060
2061 model Post {
2062 id Int @id @auto
2063 title String
2064 authorId Int
2065 }
2066 "#,
2067 )
2068 .unwrap();
2069
2070 assert_eq!(schema.models.len(), 2);
2071 assert!(schema.get_model("User").is_some());
2072 assert!(schema.get_model("Post").is_some());
2073
2074 assert_eq!(schema.server_groups.len(), 1);
2075 assert!(schema.get_server_group("Database").is_some());
2076 }
2077
2078 #[test]
2079 fn test_parse_server_group_with_health_check() {
2080 let schema = parse_schema(
2081 r#"
2082 serverGroup HealthyCluster {
2083 server monitored {
2084 url = "postgres://localhost/db"
2085 healthCheck = "/health"
2086 }
2087 }
2088 "#,
2089 )
2090 .unwrap();
2091
2092 let cluster = schema.get_server_group("HealthyCluster").unwrap();
2093 let server = cluster.servers.get("monitored").unwrap();
2094 assert_eq!(server.health_check(), Some("/health"));
2095 }
2096
2097 #[test]
2098 fn test_server_group_failover_order() {
2099 let schema = parse_schema(
2100 r#"
2101 serverGroup FailoverCluster {
2102 server db3 {
2103 url = "postgres://db3/app"
2104 priority = 3
2105 }
2106
2107 server db1 {
2108 url = "postgres://db1/app"
2109 priority = 1
2110 }
2111
2112 server db2 {
2113 url = "postgres://db2/app"
2114 priority = 2
2115 }
2116 }
2117 "#,
2118 )
2119 .unwrap();
2120
2121 let cluster = schema.get_server_group("FailoverCluster").unwrap();
2122 let ordered = cluster.failover_order();
2123
2124 assert_eq!(ordered[0].name.name.as_str(), "db1");
2125 assert_eq!(ordered[1].name.name.as_str(), "db2");
2126 assert_eq!(ordered[2].name.name.as_str(), "db3");
2127 }
2128
2129 #[test]
2130 fn test_server_group_names() {
2131 let schema = parse_schema(
2132 r#"
2133 serverGroup Alpha {
2134 server s1 { url = "pg://a" }
2135 }
2136 serverGroup Beta {
2137 server s2 { url = "pg://b" }
2138 }
2139 "#,
2140 )
2141 .unwrap();
2142
2143 let names: Vec<_> = schema.server_group_names().collect();
2144 assert_eq!(names.len(), 2);
2145 assert!(names.contains(&"Alpha"));
2146 assert!(names.contains(&"Beta"));
2147 }
2148
2149 #[test]
2152 fn test_parse_simple_policy() {
2153 let schema = parse_schema(
2154 r#"
2155 policy UserReadOwn on User {
2156 for SELECT
2157 using "id = current_user_id()"
2158 }
2159 "#,
2160 )
2161 .unwrap();
2162
2163 assert_eq!(schema.policies.len(), 1);
2164 let policy = schema.get_policy("UserReadOwn").unwrap();
2165 assert_eq!(policy.name(), "UserReadOwn");
2166 assert_eq!(policy.table(), "User");
2167 assert!(policy.applies_to(PolicyCommand::Select));
2168 assert!(!policy.applies_to(PolicyCommand::Insert));
2169 assert_eq!(policy.using_expr.as_deref(), Some("id = current_user_id()"));
2170 }
2171
2172 #[test]
2173 fn test_parse_policy_with_multiple_commands() {
2174 let schema = parse_schema(
2175 r#"
2176 policy UserModify on User {
2177 for [SELECT, UPDATE, DELETE]
2178 using "id = auth.uid()"
2179 }
2180 "#,
2181 )
2182 .unwrap();
2183
2184 let policy = schema.get_policy("UserModify").unwrap();
2185 assert!(policy.applies_to(PolicyCommand::Select));
2186 assert!(policy.applies_to(PolicyCommand::Update));
2187 assert!(policy.applies_to(PolicyCommand::Delete));
2188 assert!(!policy.applies_to(PolicyCommand::Insert));
2189 }
2190
2191 #[test]
2192 fn test_parse_policy_with_all_command() {
2193 let schema = parse_schema(
2194 r#"
2195 policy UserAll on User {
2196 for ALL
2197 using "true"
2198 }
2199 "#,
2200 )
2201 .unwrap();
2202
2203 let policy = schema.get_policy("UserAll").unwrap();
2204 assert!(policy.applies_to(PolicyCommand::Select));
2205 assert!(policy.applies_to(PolicyCommand::Insert));
2206 assert!(policy.applies_to(PolicyCommand::Update));
2207 assert!(policy.applies_to(PolicyCommand::Delete));
2208 }
2209
2210 #[test]
2211 fn test_parse_policy_with_roles() {
2212 let schema = parse_schema(
2213 r#"
2214 policy AuthenticatedRead on Document {
2215 for SELECT
2216 to authenticated
2217 using "true"
2218 }
2219 "#,
2220 )
2221 .unwrap();
2222
2223 let policy = schema.get_policy("AuthenticatedRead").unwrap();
2224 let roles = policy.effective_roles();
2225 assert!(roles.contains(&"authenticated"));
2226 }
2227
2228 #[test]
2229 fn test_parse_policy_with_multiple_roles() {
2230 let schema = parse_schema(
2231 r#"
2232 policy AdminModerator on Post {
2233 for [UPDATE, DELETE]
2234 to [admin, moderator]
2235 using "true"
2236 }
2237 "#,
2238 )
2239 .unwrap();
2240
2241 let policy = schema.get_policy("AdminModerator").unwrap();
2242 let roles = policy.effective_roles();
2243 assert!(roles.contains(&"admin"));
2244 assert!(roles.contains(&"moderator"));
2245 }
2246
2247 #[test]
2248 fn test_parse_policy_restrictive() {
2249 let schema = parse_schema(
2250 r#"
2251 policy OrgRestriction on Document {
2252 as RESTRICTIVE
2253 for SELECT
2254 using "org_id = current_org_id()"
2255 }
2256 "#,
2257 )
2258 .unwrap();
2259
2260 let policy = schema.get_policy("OrgRestriction").unwrap();
2261 assert!(policy.is_restrictive());
2262 assert!(!policy.is_permissive());
2263 }
2264
2265 #[test]
2266 fn test_parse_policy_permissive_explicit() {
2267 let schema = parse_schema(
2268 r#"
2269 policy Permissive on User {
2270 as PERMISSIVE
2271 for SELECT
2272 using "true"
2273 }
2274 "#,
2275 )
2276 .unwrap();
2277
2278 let policy = schema.get_policy("Permissive").unwrap();
2279 assert!(policy.is_permissive());
2280 }
2281
2282 #[test]
2283 fn test_parse_policy_with_check() {
2284 let schema = parse_schema(
2285 r#"
2286 policy InsertOwn on Post {
2287 for INSERT
2288 to authenticated
2289 check "author_id = current_user_id()"
2290 }
2291 "#,
2292 )
2293 .unwrap();
2294
2295 let policy = schema.get_policy("InsertOwn").unwrap();
2296 assert!(policy.applies_to(PolicyCommand::Insert));
2297 assert_eq!(
2298 policy.check_expr.as_deref(),
2299 Some("author_id = current_user_id()")
2300 );
2301 assert!(policy.using_expr.is_none());
2302 }
2303
2304 #[test]
2305 fn test_parse_policy_with_both_expressions() {
2306 let schema = parse_schema(
2307 r#"
2308 policy UpdateOwn on Post {
2309 for UPDATE
2310 using "author_id = current_user_id()"
2311 check "author_id = current_user_id()"
2312 }
2313 "#,
2314 )
2315 .unwrap();
2316
2317 let policy = schema.get_policy("UpdateOwn").unwrap();
2318 assert!(policy.using_expr.is_some());
2319 assert!(policy.check_expr.is_some());
2320 }
2321
2322 #[test]
2323 fn test_parse_policy_multiline_expression() {
2324 let schema = parse_schema(
2325 r#"
2326 policy ComplexCheck on Document {
2327 for SELECT
2328 using """
2329 (is_public = true)
2330 OR (owner_id = current_user_id())
2331 OR (id IN (SELECT document_id FROM shares WHERE user_id = current_user_id()))
2332 """
2333 }
2334 "#,
2335 )
2336 .unwrap();
2337
2338 let policy = schema.get_policy("ComplexCheck").unwrap();
2339 assert!(policy.using_expr.is_some());
2340 let expr = policy.using_expr.as_ref().unwrap();
2341 assert!(expr.contains("is_public = true"));
2342 assert!(expr.contains("owner_id = current_user_id()"));
2343 assert!(expr.contains("SELECT document_id FROM shares"));
2344 }
2345
2346 #[test]
2347 fn test_parse_multiple_policies() {
2348 let schema = parse_schema(
2349 r#"
2350 policy UserRead on User {
2351 for SELECT
2352 using "true"
2353 }
2354
2355 policy UserInsert on User {
2356 for INSERT
2357 check "id = current_user_id()"
2358 }
2359
2360 policy PostRead on Post {
2361 for SELECT
2362 using "published = true OR author_id = current_user_id()"
2363 }
2364 "#,
2365 )
2366 .unwrap();
2367
2368 assert_eq!(schema.policies.len(), 3);
2369 assert!(schema.get_policy("UserRead").is_some());
2370 assert!(schema.get_policy("UserInsert").is_some());
2371 assert!(schema.get_policy("PostRead").is_some());
2372 }
2373
2374 #[test]
2375 fn test_parse_policy_with_model() {
2376 let schema = parse_schema(
2377 r#"
2378 model User {
2379 id Int @id @auto
2380 email String @unique
2381 }
2382
2383 policy UserReadOwn on User {
2384 for SELECT
2385 to authenticated
2386 using "id = auth.uid()"
2387 }
2388 "#,
2389 )
2390 .unwrap();
2391
2392 assert_eq!(schema.models.len(), 1);
2393 assert_eq!(schema.policies.len(), 1);
2394
2395 let policies = schema.policies_for("User");
2396 assert_eq!(policies.len(), 1);
2397 assert_eq!(policies[0].name(), "UserReadOwn");
2398 }
2399
2400 #[test]
2401 fn test_parse_policies_for_multiple_models() {
2402 let schema = parse_schema(
2403 r#"
2404 policy UserPolicy1 on User {
2405 for SELECT
2406 using "true"
2407 }
2408
2409 policy UserPolicy2 on User {
2410 for INSERT
2411 check "true"
2412 }
2413
2414 policy PostPolicy on Post {
2415 for SELECT
2416 using "true"
2417 }
2418 "#,
2419 )
2420 .unwrap();
2421
2422 assert_eq!(schema.policies_for("User").len(), 2);
2423 assert_eq!(schema.policies_for("Post").len(), 1);
2424 assert!(schema.has_policies("User"));
2425 assert!(schema.has_policies("Post"));
2426 assert!(!schema.has_policies("Comment"));
2427 }
2428
2429 #[test]
2430 fn test_parse_policy_default_all_command() {
2431 let schema = parse_schema(
2432 r#"
2433 policy DefaultAll on User {
2434 using "id = current_user_id()"
2435 }
2436 "#,
2437 )
2438 .unwrap();
2439
2440 let policy = schema.get_policy("DefaultAll").unwrap();
2441 assert!(policy.applies_to(PolicyCommand::All));
2443 }
2444
2445 #[test]
2446 fn test_parse_policy_case_insensitive_keywords() {
2447 let schema = parse_schema(
2448 r#"
2449 policy CaseTest on User {
2450 for select
2451 as permissive
2452 using "true"
2453 }
2454 "#,
2455 )
2456 .unwrap();
2457
2458 let policy = schema.get_policy("CaseTest").unwrap();
2459 assert!(policy.applies_to(PolicyCommand::Select));
2460 assert!(policy.is_permissive());
2461 }
2462
2463 #[test]
2464 fn test_parse_policy_sql_generation() {
2465 let schema = parse_schema(
2466 r#"
2467 model User {
2468 id Int @id
2469
2470 @@map("users")
2471 }
2472
2473 policy ReadOwn on User {
2474 for SELECT
2475 to authenticated
2476 using "id = auth.uid()"
2477 }
2478 "#,
2479 )
2480 .unwrap();
2481
2482 let policy = schema.get_policy("ReadOwn").unwrap();
2483 let sql = policy.to_sql("users");
2484
2485 assert!(sql.contains("CREATE POLICY ReadOwn ON users"));
2486 assert!(sql.contains("FOR SELECT"));
2487 assert!(sql.contains("TO authenticated"));
2488 assert!(sql.contains("USING (id = auth.uid())"));
2489 }
2490
2491 #[test]
2492 fn test_parse_policy_restrictive_sql() {
2493 let schema = parse_schema(
2494 r#"
2495 policy OrgBoundary on Document {
2496 as RESTRICTIVE
2497 for ALL
2498 using "org_id = current_org_id()"
2499 }
2500 "#,
2501 )
2502 .unwrap();
2503
2504 let policy = schema.get_policy("OrgBoundary").unwrap();
2505 let sql = policy.to_sql("documents");
2506
2507 assert!(sql.contains("AS RESTRICTIVE"));
2508 }
2509
2510 #[test]
2511 fn test_parse_policy_with_documentation() {
2512 let schema = parse_schema(
2513 r#"
2514 /// Users can only read their own data
2515 policy UserIsolation on User {
2516 for SELECT
2517 using "id = current_user_id()"
2518 }
2519 "#,
2520 )
2521 .unwrap();
2522
2523 let policy = schema.get_policy("UserIsolation").unwrap();
2524 if let Some(doc) = &policy.documentation {
2525 assert!(doc.text.contains("their own data"));
2526 }
2527 }
2528
2529 #[test]
2530 fn test_parse_complex_rls_schema() {
2531 let schema = parse_schema(
2532 r#"
2533 model Organization {
2534 id Int @id @auto
2535 name String
2536 }
2537
2538 model User {
2539 id Int @id @auto
2540 orgId Int
2541 email String @unique
2542 }
2543
2544 model Document {
2545 id Int @id @auto
2546 title String
2547 ownerId Int
2548 orgId Int
2549 isPublic Boolean @default(false)
2550 }
2551
2552 /// Organization-level isolation
2553 policy OrgIsolation on Document {
2554 as RESTRICTIVE
2555 for ALL
2556 using "org_id = current_setting('app.current_org')::int"
2557 }
2558
2559 /// Users can read public documents
2560 policy PublicRead on Document {
2561 for SELECT
2562 using "is_public = true"
2563 }
2564
2565 /// Users can read their own documents
2566 policy OwnerRead on Document {
2567 for SELECT
2568 to authenticated
2569 using "owner_id = auth.uid()"
2570 }
2571
2572 /// Users can only modify their own documents
2573 policy OwnerModify on Document {
2574 for [UPDATE, DELETE]
2575 to authenticated
2576 using "owner_id = auth.uid()"
2577 check "owner_id = auth.uid()"
2578 }
2579
2580 /// Users can create documents in their org
2581 policy OrgInsert on Document {
2582 for INSERT
2583 to authenticated
2584 check "org_id = current_setting('app.current_org')::int"
2585 }
2586 "#,
2587 )
2588 .unwrap();
2589
2590 assert_eq!(schema.models.len(), 3);
2591 assert_eq!(schema.policies.len(), 5);
2592
2593 let org_iso = schema.get_policy("OrgIsolation").unwrap();
2595 assert!(org_iso.is_restrictive());
2596
2597 let doc_policies = schema.policies_for("Document");
2599 assert_eq!(doc_policies.len(), 5);
2600 }
2601
2602 #[test]
2605 fn test_parse_policy_with_mssql_schema() {
2606 let schema = parse_schema(
2607 r#"
2608 policy UserFilter on User {
2609 for SELECT
2610 using "UserId = @UserId"
2611 mssqlSchema "RLS"
2612 }
2613 "#,
2614 )
2615 .unwrap();
2616
2617 let policy = schema.get_policy("UserFilter").unwrap();
2618 assert_eq!(policy.mssql_schema(), "RLS");
2619 }
2620
2621 #[test]
2622 fn test_parse_policy_with_mssql_block_single() {
2623 let schema = parse_schema(
2624 r#"
2625 policy UserInsert on User {
2626 for INSERT
2627 check "UserId = @UserId"
2628 mssqlBlock AFTER_INSERT
2629 }
2630 "#,
2631 )
2632 .unwrap();
2633
2634 let policy = schema.get_policy("UserInsert").unwrap();
2635 assert_eq!(policy.mssql_block_operations.len(), 1);
2636 assert_eq!(
2637 policy.mssql_block_operations[0],
2638 MssqlBlockOperation::AfterInsert
2639 );
2640 }
2641
2642 #[test]
2643 fn test_parse_policy_with_mssql_block_list() {
2644 let schema = parse_schema(
2645 r#"
2646 policy UserModify on User {
2647 for [INSERT, UPDATE, DELETE]
2648 check "UserId = @UserId"
2649 mssqlBlock [AFTER_INSERT, AFTER_UPDATE, BEFORE_DELETE]
2650 }
2651 "#,
2652 )
2653 .unwrap();
2654
2655 let policy = schema.get_policy("UserModify").unwrap();
2656 assert_eq!(policy.mssql_block_operations.len(), 3);
2657 assert!(
2658 policy
2659 .mssql_block_operations
2660 .contains(&MssqlBlockOperation::AfterInsert)
2661 );
2662 assert!(
2663 policy
2664 .mssql_block_operations
2665 .contains(&MssqlBlockOperation::AfterUpdate)
2666 );
2667 assert!(
2668 policy
2669 .mssql_block_operations
2670 .contains(&MssqlBlockOperation::BeforeDelete)
2671 );
2672 }
2673
#[test]
fn test_parse_policy_full_mssql_config() {
    // A policy using every MSSQL-related option at once, checked both at the
    // AST level and in the generated MSSQL SQL.
    let source = r#"
policy TenantIsolation on Order {
for ALL
using "TenantId = @TenantId"
check "TenantId = @TenantId"
mssqlSchema "MultiTenant"
mssqlBlock [AFTER_INSERT, BEFORE_UPDATE, AFTER_UPDATE, BEFORE_DELETE]
}
"#;
    let schema = parse_schema(source).unwrap();
    let policy = schema.get_policy("TenantIsolation").unwrap();

    // Command selection and both predicate expressions.
    assert!(policy.applies_to(PolicyCommand::All));
    assert!(policy.using_expr.is_some());
    assert!(policy.check_expr.is_some());

    // Explicit MSSQL settings from the policy body.
    assert_eq!(policy.mssql_schema(), "MultiTenant");
    assert_eq!(policy.mssql_block_operations.len(), 4);

    // The generated DDL should reference the custom schema and the
    // per-policy predicate function name.
    let generated = policy.to_mssql_sql("dbo.Orders", "TenantId");
    assert!(generated.schema_sql.contains("MultiTenant"));
    assert!(generated
        .function_sql
        .contains("fn_TenantIsolation_predicate"));
}
2705
#[test]
fn test_parse_policy_mssql_block_case_variants() {
    // The block operation is spelled in lowercase here; the test expects it
    // to map to the same variant as `AFTER_INSERT`.
    let source = r#"
policy Test1 on User {
for INSERT
check "true"
mssqlBlock after_insert
}
"#;
    let schema = parse_schema(source).unwrap();

    let policy = schema.get_policy("Test1").unwrap();
    let ops = &policy.mssql_block_operations;
    assert_eq!(ops.len(), 1);
    assert_eq!(ops[0], MssqlBlockOperation::AfterInsert);
}
2727
#[test]
fn test_parse_mixed_postgres_mssql_schema() {
    // One schema mixes a policy with no MSSQL options and a policy with
    // explicit MSSQL options; both must parse and generate SQL.
    let source = r#"
model User {
id Int @id @auto
email String @unique
}

// PostgreSQL-style policy (works on both, MSSQL uses defaults)
policy UserReadOwn on User {
for SELECT
to authenticated
using "id = current_user_id()"
}

// MSSQL-optimized policy with explicit settings
policy UserModifyOwn on User {
for [INSERT, UPDATE, DELETE]
to authenticated
using "id = current_user_id()"
check "id = current_user_id()"
mssqlSchema "Security"
mssqlBlock [AFTER_INSERT, BEFORE_UPDATE, AFTER_UPDATE, BEFORE_DELETE]
}
"#;
    let schema = parse_schema(source).unwrap();
    assert_eq!(schema.policies.len(), 2);

    // UserReadOwn declares no MSSQL options, so these assertions pin the
    // fallback values: schema "Security" and no block operations.
    let read_policy = schema.get_policy("UserReadOwn").unwrap();
    assert_eq!(read_policy.mssql_schema(), "Security");
    assert!(read_policy.mssql_block_operations.is_empty());

    // UserModifyOwn sets both options explicitly.
    let modify_policy = schema.get_policy("UserModifyOwn").unwrap();
    assert_eq!(modify_policy.mssql_schema(), "Security");
    assert_eq!(modify_policy.mssql_block_operations.len(), 4);

    // Each policy still renders for its target dialect.
    let postgres_ddl = read_policy.to_postgres_sql("users");
    assert!(postgres_ddl.contains("CREATE POLICY UserReadOwn ON users"));

    let mssql_ddl = modify_policy.to_mssql_sql("dbo.Users", "id");
    assert!(mssql_ddl.policy_sql.contains("Security.UserModifyOwn"));
}
2777}