1mod grammar;
4
5use std::path::Path;
6
7use pest::Parser;
8use smol_str::SmolStr;
9use tracing::{debug, info};
10
11use crate::ast::*;
12use crate::error::{SchemaError, SchemaResult};
13
14pub use grammar::{PraxParser, Rule};
15
16use crate::ast::{Server, ServerGroup, ServerProperty, ServerPropertyValue};
17
18pub fn parse_schema(input: &str) -> SchemaResult<Schema> {
20 debug!(input_len = input.len(), "parse_schema() starting");
21 let pairs = PraxParser::parse(Rule::schema, input)
22 .map_err(|e| SchemaError::syntax(input.to_string(), 0, input.len(), e.to_string()))?;
23
24 let mut schema = Schema::new();
25 let mut current_doc: Option<Documentation> = None;
26
27 let schema_pair = pairs.into_iter().next().unwrap();
29
30 for pair in schema_pair.into_inner() {
31 match pair.as_rule() {
32 Rule::documentation => {
33 let span = pair.as_span();
34 let text = pair
35 .into_inner()
36 .map(|p| p.as_str().trim_start_matches("///").trim())
37 .collect::<Vec<_>>()
38 .join("\n");
39 current_doc = Some(Documentation::new(
40 text,
41 Span::new(span.start(), span.end()),
42 ));
43 }
44 Rule::model_def => {
45 let mut model = parse_model(pair)?;
46 if let Some(doc) = current_doc.take() {
47 model = model.with_documentation(doc);
48 }
49 schema.add_model(model);
50 }
51 Rule::enum_def => {
52 let mut e = parse_enum(pair)?;
53 if let Some(doc) = current_doc.take() {
54 e = e.with_documentation(doc);
55 }
56 schema.add_enum(e);
57 }
58 Rule::type_def => {
59 let mut t = parse_composite_type(pair)?;
60 if let Some(doc) = current_doc.take() {
61 t = t.with_documentation(doc);
62 }
63 schema.add_type(t);
64 }
65 Rule::view_def => {
66 let mut v = parse_view(pair)?;
67 if let Some(doc) = current_doc.take() {
68 v = v.with_documentation(doc);
69 }
70 schema.add_view(v);
71 }
72 Rule::raw_sql_def => {
73 let sql = parse_raw_sql(pair)?;
74 schema.add_raw_sql(sql);
75 }
76 Rule::server_group_def => {
77 let mut sg = parse_server_group(pair)?;
78 if let Some(doc) = current_doc.take() {
79 sg.set_documentation(doc);
80 }
81 schema.add_server_group(sg);
82 }
83 Rule::EOI => {}
84 _ => {}
85 }
86 }
87
88 info!(
89 models = schema.models.len(),
90 enums = schema.enums.len(),
91 types = schema.types.len(),
92 views = schema.views.len(),
93 "Schema parsed successfully"
94 );
95 Ok(schema)
96}
97
98pub fn parse_schema_file(path: impl AsRef<Path>) -> SchemaResult<Schema> {
100 let path = path.as_ref();
101 info!(path = %path.display(), "Loading schema file");
102 let content = std::fs::read_to_string(path).map_err(|e| SchemaError::IoError {
103 path: path.display().to_string(),
104 source: e,
105 })?;
106
107 parse_schema(&content)
108}
109
110fn parse_model(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Model> {
112 let span = pair.as_span();
113 let mut inner = pair.into_inner();
114
115 let name_pair = inner.next().unwrap();
116 let name = Ident::new(
117 name_pair.as_str(),
118 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
119 );
120
121 let mut model = Model::new(name, Span::new(span.start(), span.end()));
122
123 for item in inner {
124 match item.as_rule() {
125 Rule::field_def => {
126 let field = parse_field(item)?;
127 model.add_field(field);
128 }
129 Rule::model_attribute => {
130 let attr = parse_attribute(item)?;
131 model.attributes.push(attr);
132 }
133 Rule::model_body_item => {
134 let inner_item = item.into_inner().next().unwrap();
136 match inner_item.as_rule() {
137 Rule::field_def => {
138 let field = parse_field(inner_item)?;
139 model.add_field(field);
140 }
141 Rule::model_attribute => {
142 let attr = parse_attribute(inner_item)?;
143 model.attributes.push(attr);
144 }
145 _ => {}
146 }
147 }
148 _ => {}
149 }
150 }
151
152 Ok(model)
153}
154
155fn parse_enum(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Enum> {
157 let span = pair.as_span();
158 let mut inner = pair.into_inner();
159
160 let name_pair = inner.next().unwrap();
161 let name = Ident::new(
162 name_pair.as_str(),
163 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
164 );
165
166 let mut e = Enum::new(name, Span::new(span.start(), span.end()));
167
168 for item in inner {
169 match item.as_rule() {
170 Rule::enum_variant => {
171 let variant = parse_enum_variant(item)?;
172 e.add_variant(variant);
173 }
174 Rule::model_attribute => {
175 let attr = parse_attribute(item)?;
176 e.attributes.push(attr);
177 }
178 Rule::enum_body_item => {
179 let inner_item = item.into_inner().next().unwrap();
181 match inner_item.as_rule() {
182 Rule::enum_variant => {
183 let variant = parse_enum_variant(inner_item)?;
184 e.add_variant(variant);
185 }
186 Rule::model_attribute => {
187 let attr = parse_attribute(inner_item)?;
188 e.attributes.push(attr);
189 }
190 _ => {}
191 }
192 }
193 _ => {}
194 }
195 }
196
197 Ok(e)
198}
199
200fn parse_enum_variant(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<EnumVariant> {
202 let span = pair.as_span();
203 let mut inner = pair.into_inner();
204
205 let name_pair = inner.next().unwrap();
206 let name = Ident::new(
207 name_pair.as_str(),
208 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
209 );
210
211 let mut variant = EnumVariant::new(name, Span::new(span.start(), span.end()));
212
213 for item in inner {
214 if item.as_rule() == Rule::field_attribute {
215 let attr = parse_attribute(item)?;
216 variant.attributes.push(attr);
217 }
218 }
219
220 Ok(variant)
221}
222
223fn parse_composite_type(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<CompositeType> {
225 let span = pair.as_span();
226 let mut inner = pair.into_inner();
227
228 let name_pair = inner.next().unwrap();
229 let name = Ident::new(
230 name_pair.as_str(),
231 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
232 );
233
234 let mut t = CompositeType::new(name, Span::new(span.start(), span.end()));
235
236 for item in inner {
237 if item.as_rule() == Rule::field_def {
238 let field = parse_field(item)?;
239 t.add_field(field);
240 }
241 }
242
243 Ok(t)
244}
245
246fn parse_view(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<View> {
248 let span = pair.as_span();
249 let mut inner = pair.into_inner();
250
251 let name_pair = inner.next().unwrap();
252 let name = Ident::new(
253 name_pair.as_str(),
254 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
255 );
256
257 let mut v = View::new(name, Span::new(span.start(), span.end()));
258
259 for item in inner {
260 match item.as_rule() {
261 Rule::field_def => {
262 let field = parse_field(item)?;
263 v.add_field(field);
264 }
265 Rule::model_attribute => {
266 let attr = parse_attribute(item)?;
267 v.attributes.push(attr);
268 }
269 Rule::model_body_item => {
270 let inner_item = item.into_inner().next().unwrap();
272 match inner_item.as_rule() {
273 Rule::field_def => {
274 let field = parse_field(inner_item)?;
275 v.add_field(field);
276 }
277 Rule::model_attribute => {
278 let attr = parse_attribute(inner_item)?;
279 v.attributes.push(attr);
280 }
281 _ => {}
282 }
283 }
284 _ => {}
285 }
286 }
287
288 Ok(v)
289}
290
291fn parse_field(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Field> {
293 let span = pair.as_span();
294 let mut inner = pair.into_inner();
295
296 let name_pair = inner.next().unwrap();
297 let name = Ident::new(
298 name_pair.as_str(),
299 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
300 );
301
302 let type_pair = inner.next().unwrap();
303 let (field_type, modifier) = parse_field_type(type_pair)?;
304
305 let mut attributes = vec![];
306 for item in inner {
307 if item.as_rule() == Rule::field_attribute {
308 let attr = parse_attribute(item)?;
309 attributes.push(attr);
310 }
311 }
312
313 Ok(Field::new(
314 name,
315 field_type,
316 modifier,
317 attributes,
318 Span::new(span.start(), span.end()),
319 ))
320}
321
322fn parse_field_type(
324 pair: pest::iterators::Pair<'_, Rule>,
325) -> SchemaResult<(FieldType, TypeModifier)> {
326 let mut type_name = String::new();
327 let mut modifier = TypeModifier::Required;
328
329 for item in pair.into_inner() {
330 match item.as_rule() {
331 Rule::type_name => {
332 type_name = item.as_str().to_string();
333 }
334 Rule::optional_marker => {
335 modifier = if modifier == TypeModifier::List {
336 TypeModifier::OptionalList
337 } else {
338 TypeModifier::Optional
339 };
340 }
341 Rule::list_marker => {
342 modifier = if modifier == TypeModifier::Optional {
343 TypeModifier::OptionalList
344 } else {
345 TypeModifier::List
346 };
347 }
348 _ => {}
349 }
350 }
351
352 let field_type = if let Some(scalar) = ScalarType::from_str(&type_name) {
353 FieldType::Scalar(scalar)
354 } else {
355 FieldType::Model(SmolStr::new(&type_name))
358 };
359
360 Ok((field_type, modifier))
361}
362
363fn parse_attribute(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Attribute> {
365 let span = pair.as_span();
366 let mut inner = pair.into_inner();
367
368 let name_pair = inner.next().unwrap();
369 let name = Ident::new(
370 name_pair.as_str(),
371 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
372 );
373
374 let mut args = vec![];
375 for item in inner {
376 if item.as_rule() == Rule::attribute_args {
377 args = parse_attribute_args(item)?;
378 }
379 }
380
381 Ok(Attribute::new(
382 name,
383 args,
384 Span::new(span.start(), span.end()),
385 ))
386}
387
388fn parse_attribute_args(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Vec<AttributeArg>> {
390 let mut args = vec![];
391
392 for item in pair.into_inner() {
393 if item.as_rule() == Rule::attribute_arg {
394 let arg = parse_attribute_arg(item)?;
395 args.push(arg);
396 }
397 }
398
399 Ok(args)
400}
401
402fn parse_attribute_arg(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<AttributeArg> {
404 let span = pair.as_span();
405 let mut inner = pair.into_inner();
406
407 let first = inner.next().unwrap();
408
409 if let Some(second) = inner.next() {
411 let name = Ident::new(
413 first.as_str(),
414 Span::new(first.as_span().start(), first.as_span().end()),
415 );
416 let value = parse_attribute_value(second)?;
417 Ok(AttributeArg::named(
418 name,
419 value,
420 Span::new(span.start(), span.end()),
421 ))
422 } else {
423 let value = parse_attribute_value(first)?;
425 Ok(AttributeArg::positional(
426 value,
427 Span::new(span.start(), span.end()),
428 ))
429 }
430}
431
432fn parse_attribute_value(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<AttributeValue> {
434 match pair.as_rule() {
435 Rule::string_literal => {
436 let s = pair.as_str();
437 let unquoted = &s[1..s.len() - 1];
439 Ok(AttributeValue::String(unquoted.to_string()))
440 }
441 Rule::number_literal => {
442 let s = pair.as_str();
443 if s.contains('.') {
444 Ok(AttributeValue::Float(s.parse().unwrap()))
445 } else {
446 Ok(AttributeValue::Int(s.parse().unwrap()))
447 }
448 }
449 Rule::boolean_literal => Ok(AttributeValue::Boolean(pair.as_str() == "true")),
450 Rule::identifier => Ok(AttributeValue::Ident(SmolStr::new(pair.as_str()))),
451 Rule::function_call => {
452 let mut inner = pair.into_inner();
453 let name = SmolStr::new(inner.next().unwrap().as_str());
454 let mut args = vec![];
455 for item in inner {
456 args.push(parse_attribute_value(item)?);
457 }
458 Ok(AttributeValue::Function(name, args))
459 }
460 Rule::field_ref_list => {
461 let refs: Vec<SmolStr> = pair
462 .into_inner()
463 .map(|p| SmolStr::new(p.as_str()))
464 .collect();
465 Ok(AttributeValue::FieldRefList(refs))
466 }
467 Rule::array_literal => {
468 let values: Result<Vec<_>, _> = pair.into_inner().map(parse_attribute_value).collect();
469 Ok(AttributeValue::Array(values?))
470 }
471 Rule::attribute_value => {
472 parse_attribute_value(pair.into_inner().next().unwrap())
474 }
475 _ => {
476 Ok(AttributeValue::Ident(SmolStr::new(pair.as_str())))
478 }
479 }
480}
481
482fn parse_raw_sql(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<RawSql> {
484 let mut inner = pair.into_inner();
485
486 let name = inner.next().unwrap().as_str();
487 let sql = inner.next().unwrap().as_str();
488
489 let sql_content = sql
491 .trim_start_matches("\"\"\"")
492 .trim_end_matches("\"\"\"")
493 .trim();
494
495 Ok(RawSql::new(name, sql_content))
496}
497
498fn parse_server_group(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<ServerGroup> {
500 let span = pair.as_span();
501 let mut inner = pair.into_inner();
502
503 let name_pair = inner.next().unwrap();
504 let name = Ident::new(
505 name_pair.as_str(),
506 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
507 );
508
509 let mut server_group = ServerGroup::new(name, Span::new(span.start(), span.end()));
510
511 for item in inner {
512 match item.as_rule() {
513 Rule::server_group_item => {
514 let inner_item = item.into_inner().next().unwrap();
516 match inner_item.as_rule() {
517 Rule::server_def => {
518 let server = parse_server(inner_item)?;
519 server_group.add_server(server);
520 }
521 Rule::model_attribute => {
522 let attr = parse_attribute(inner_item)?;
523 server_group.add_attribute(attr);
524 }
525 _ => {}
526 }
527 }
528 Rule::server_def => {
529 let server = parse_server(item)?;
530 server_group.add_server(server);
531 }
532 Rule::model_attribute => {
533 let attr = parse_attribute(item)?;
534 server_group.add_attribute(attr);
535 }
536 _ => {}
537 }
538 }
539
540 Ok(server_group)
541}
542
543fn parse_server(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Server> {
545 let span = pair.as_span();
546 let mut inner = pair.into_inner();
547
548 let name_pair = inner.next().unwrap();
549 let name = Ident::new(
550 name_pair.as_str(),
551 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
552 );
553
554 let mut server = Server::new(name, Span::new(span.start(), span.end()));
555
556 for item in inner {
557 if item.as_rule() == Rule::server_property {
558 let prop = parse_server_property(item)?;
559 server.add_property(prop);
560 }
561 }
562
563 Ok(server)
564}
565
566fn parse_server_property(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<ServerProperty> {
568 let span = pair.as_span();
569 let mut inner = pair.into_inner();
570
571 let key_pair = inner.next().unwrap();
572 let key = key_pair.as_str();
573
574 let value_pair = inner.next().unwrap();
575 let value = parse_server_property_value(value_pair)?;
576
577 Ok(ServerProperty::new(
578 key,
579 value,
580 Span::new(span.start(), span.end()),
581 ))
582}
583
584fn extract_string_from_arg(pair: pest::iterators::Pair<'_, Rule>) -> String {
586 match pair.as_rule() {
587 Rule::string_literal => {
588 let s = pair.as_str();
589 s[1..s.len() - 1].to_string()
590 }
591 Rule::attribute_value => {
592 if let Some(inner) = pair.into_inner().next() {
594 extract_string_from_arg(inner)
595 } else {
596 String::new()
597 }
598 }
599 _ => pair.as_str().to_string(),
600 }
601}
602
603fn parse_server_property_value(
605 pair: pest::iterators::Pair<'_, Rule>,
606) -> SchemaResult<ServerPropertyValue> {
607 match pair.as_rule() {
608 Rule::string_literal => {
609 let s = pair.as_str();
610 let unquoted = &s[1..s.len() - 1];
612 Ok(ServerPropertyValue::String(unquoted.to_string()))
613 }
614 Rule::number_literal => {
615 let s = pair.as_str();
616 Ok(ServerPropertyValue::Number(s.parse().unwrap_or(0.0)))
617 }
618 Rule::boolean_literal => Ok(ServerPropertyValue::Boolean(pair.as_str() == "true")),
619 Rule::identifier => Ok(ServerPropertyValue::Identifier(pair.as_str().to_string())),
620 Rule::function_call => {
621 let mut inner = pair.into_inner();
623 let func_name = inner.next().unwrap().as_str();
624 if func_name == "env" {
625 if let Some(arg) = inner.next() {
626 let var_name = extract_string_from_arg(arg);
627 return Ok(ServerPropertyValue::EnvVar(var_name));
628 }
629 }
630 Ok(ServerPropertyValue::Identifier(func_name.to_string()))
632 }
633 Rule::array_literal => {
634 let values: Result<Vec<_>, _> =
635 pair.into_inner().map(parse_server_property_value).collect();
636 Ok(ServerPropertyValue::Array(values?))
637 }
638 Rule::attribute_value => {
639 parse_server_property_value(pair.into_inner().next().unwrap())
641 }
642 _ => {
643 Ok(ServerPropertyValue::Identifier(pair.as_str().to_string()))
645 }
646 }
647}
648
649#[cfg(test)]
650mod tests {
651 use super::*;
652
653 #[test]
656 fn test_parse_simple_model() {
657 let schema = parse_schema(
658 r#"
659 model User {
660 id Int @id @auto
661 email String @unique
662 name String?
663 }
664 "#,
665 )
666 .unwrap();
667
668 assert_eq!(schema.models.len(), 1);
669 let user = schema.get_model("User").unwrap();
670 assert_eq!(user.fields.len(), 3);
671 assert!(user.get_field("id").unwrap().is_id());
672 assert!(user.get_field("email").unwrap().is_unique());
673 assert!(user.get_field("name").unwrap().is_optional());
674 }
675
676 #[test]
677 fn test_parse_model_name() {
678 let schema = parse_schema(
679 r#"
680 model BlogPost {
681 id Int @id
682 }
683 "#,
684 )
685 .unwrap();
686
687 assert!(schema.get_model("BlogPost").is_some());
688 }
689
690 #[test]
691 fn test_parse_multiple_models() {
692 let schema = parse_schema(
693 r#"
694 model User {
695 id Int @id
696 }
697
698 model Post {
699 id Int @id
700 }
701
702 model Comment {
703 id Int @id
704 }
705 "#,
706 )
707 .unwrap();
708
709 assert_eq!(schema.models.len(), 3);
710 assert!(schema.get_model("User").is_some());
711 assert!(schema.get_model("Post").is_some());
712 assert!(schema.get_model("Comment").is_some());
713 }
714
715 #[test]
718 fn test_parse_all_scalar_types() {
719 let schema = parse_schema(
720 r#"
721 model AllTypes {
722 id Int @id
723 big BigInt
724 float_f Float
725 decimal Decimal
726 str String
727 bool Boolean
728 datetime DateTime
729 date Date
730 time Time
731 json Json
732 bytes Bytes
733 uuid Uuid
734 cuid Cuid
735 cuid2 Cuid2
736 nanoid NanoId
737 ulid Ulid
738 }
739 "#,
740 )
741 .unwrap();
742
743 let model = schema.get_model("AllTypes").unwrap();
744 assert_eq!(model.fields.len(), 16);
745
746 assert!(matches!(
747 model.get_field("id").unwrap().field_type,
748 FieldType::Scalar(ScalarType::Int)
749 ));
750 assert!(matches!(
751 model.get_field("big").unwrap().field_type,
752 FieldType::Scalar(ScalarType::BigInt)
753 ));
754 assert!(matches!(
755 model.get_field("str").unwrap().field_type,
756 FieldType::Scalar(ScalarType::String)
757 ));
758 assert!(matches!(
759 model.get_field("bool").unwrap().field_type,
760 FieldType::Scalar(ScalarType::Boolean)
761 ));
762 assert!(matches!(
763 model.get_field("datetime").unwrap().field_type,
764 FieldType::Scalar(ScalarType::DateTime)
765 ));
766 assert!(matches!(
767 model.get_field("uuid").unwrap().field_type,
768 FieldType::Scalar(ScalarType::Uuid)
769 ));
770 assert!(matches!(
771 model.get_field("cuid").unwrap().field_type,
772 FieldType::Scalar(ScalarType::Cuid)
773 ));
774 assert!(matches!(
775 model.get_field("cuid2").unwrap().field_type,
776 FieldType::Scalar(ScalarType::Cuid2)
777 ));
778 assert!(matches!(
779 model.get_field("nanoid").unwrap().field_type,
780 FieldType::Scalar(ScalarType::NanoId)
781 ));
782 assert!(matches!(
783 model.get_field("ulid").unwrap().field_type,
784 FieldType::Scalar(ScalarType::Ulid)
785 ));
786 }
787
788 #[test]
789 fn test_parse_optional_field() {
790 let schema = parse_schema(
791 r#"
792 model User {
793 id Int @id
794 bio String?
795 age Int?
796 }
797 "#,
798 )
799 .unwrap();
800
801 let user = schema.get_model("User").unwrap();
802 assert!(!user.get_field("id").unwrap().is_optional());
803 assert!(user.get_field("bio").unwrap().is_optional());
804 assert!(user.get_field("age").unwrap().is_optional());
805 }
806
807 #[test]
808 fn test_parse_list_field() {
809 let schema = parse_schema(
810 r#"
811 model User {
812 id Int @id
813 tags String[]
814 posts Post[]
815 }
816 "#,
817 )
818 .unwrap();
819
820 let user = schema.get_model("User").unwrap();
821 assert!(user.get_field("tags").unwrap().is_list());
822 assert!(user.get_field("posts").unwrap().is_list());
823 }
824
825 #[test]
826 fn test_parse_optional_list_field() {
827 let schema = parse_schema(
828 r#"
829 model User {
830 id Int @id
831 metadata String[]?
832 }
833 "#,
834 )
835 .unwrap();
836
837 let user = schema.get_model("User").unwrap();
838 let metadata = user.get_field("metadata").unwrap();
839 assert!(metadata.is_list());
840 assert!(metadata.is_optional());
841 }
842
843 #[test]
846 fn test_parse_id_attribute() {
847 let schema = parse_schema(
848 r#"
849 model User {
850 id Int @id
851 }
852 "#,
853 )
854 .unwrap();
855
856 let user = schema.get_model("User").unwrap();
857 assert!(user.get_field("id").unwrap().is_id());
858 }
859
860 #[test]
861 fn test_parse_unique_attribute() {
862 let schema = parse_schema(
863 r#"
864 model User {
865 id Int @id
866 email String @unique
867 }
868 "#,
869 )
870 .unwrap();
871
872 let user = schema.get_model("User").unwrap();
873 assert!(user.get_field("email").unwrap().is_unique());
874 }
875
876 #[test]
877 fn test_parse_default_int() {
878 let schema = parse_schema(
879 r#"
880 model Counter {
881 id Int @id
882 count Int @default(0)
883 }
884 "#,
885 )
886 .unwrap();
887
888 let counter = schema.get_model("Counter").unwrap();
889 let count_field = counter.get_field("count").unwrap();
890 let attrs = count_field.extract_attributes();
891 assert!(attrs.default.is_some());
892 assert_eq!(attrs.default.unwrap().as_int(), Some(0));
893 }
894
895 #[test]
896 fn test_parse_default_string() {
897 let schema = parse_schema(
898 r#"
899 model User {
900 id Int @id
901 status String @default("active")
902 }
903 "#,
904 )
905 .unwrap();
906
907 let user = schema.get_model("User").unwrap();
908 let status = user.get_field("status").unwrap();
909 let attrs = status.extract_attributes();
910 assert!(attrs.default.is_some());
911 assert_eq!(attrs.default.unwrap().as_string(), Some("active"));
912 }
913
914 #[test]
915 fn test_parse_default_boolean() {
916 let schema = parse_schema(
917 r#"
918 model Post {
919 id Int @id
920 published Boolean @default(false)
921 }
922 "#,
923 )
924 .unwrap();
925
926 let post = schema.get_model("Post").unwrap();
927 let published = post.get_field("published").unwrap();
928 let attrs = published.extract_attributes();
929 assert!(attrs.default.is_some());
930 assert_eq!(attrs.default.unwrap().as_bool(), Some(false));
931 }
932
933 #[test]
934 fn test_parse_default_function() {
935 let schema = parse_schema(
936 r#"
937 model User {
938 id Int @id
939 createdAt DateTime @default(now())
940 }
941 "#,
942 )
943 .unwrap();
944
945 let user = schema.get_model("User").unwrap();
946 let created_at = user.get_field("createdAt").unwrap();
947 let attrs = created_at.extract_attributes();
948 assert!(attrs.default.is_some());
949 if let Some(AttributeValue::Function(name, _)) = attrs.default {
950 assert_eq!(name.as_str(), "now");
951 } else {
952 panic!("Expected function default");
953 }
954 }
955
956 #[test]
957 fn test_parse_updated_at_attribute() {
958 let schema = parse_schema(
959 r#"
960 model User {
961 id Int @id
962 updatedAt DateTime @updated_at
963 }
964 "#,
965 )
966 .unwrap();
967
968 let user = schema.get_model("User").unwrap();
969 let updated_at = user.get_field("updatedAt").unwrap();
970 let attrs = updated_at.extract_attributes();
971 assert!(attrs.is_updated_at);
972 }
973
974 #[test]
975 fn test_parse_map_attribute() {
976 let schema = parse_schema(
977 r#"
978 model User {
979 id Int @id
980 email String @map("email_address")
981 }
982 "#,
983 )
984 .unwrap();
985
986 let user = schema.get_model("User").unwrap();
987 let email = user.get_field("email").unwrap();
988 let attrs = email.extract_attributes();
989 assert_eq!(attrs.map, Some("email_address".to_string()));
990 }
991
992 #[test]
993 fn test_parse_multiple_attributes() {
994 let schema = parse_schema(
995 r#"
996 model User {
997 id Int @id @auto
998 email String @unique @index
999 }
1000 "#,
1001 )
1002 .unwrap();
1003
1004 let user = schema.get_model("User").unwrap();
1005 let id = user.get_field("id").unwrap();
1006 let email = user.get_field("email").unwrap();
1007
1008 let id_attrs = id.extract_attributes();
1009 assert!(id_attrs.is_id);
1010 assert!(id_attrs.is_auto);
1011
1012 let email_attrs = email.extract_attributes();
1013 assert!(email_attrs.is_unique);
1014 assert!(email_attrs.is_indexed);
1015 }
1016
1017 #[test]
1020 fn test_parse_model_map_attribute() {
1021 let schema = parse_schema(
1022 r#"
1023 model User {
1024 id Int @id
1025
1026 @@map("app_users")
1027 }
1028 "#,
1029 )
1030 .unwrap();
1031
1032 let user = schema.get_model("User").unwrap();
1033 assert_eq!(user.table_name(), "app_users");
1034 }
1035
1036 #[test]
1037 fn test_parse_model_index_attribute() {
1038 let schema = parse_schema(
1039 r#"
1040 model User {
1041 id Int @id
1042 email String
1043 name String
1044
1045 @@index([email, name])
1046 }
1047 "#,
1048 )
1049 .unwrap();
1050
1051 let user = schema.get_model("User").unwrap();
1052 assert!(user.has_attribute("index"));
1053 }
1054
1055 #[test]
1056 fn test_parse_composite_primary_key() {
1057 let schema = parse_schema(
1058 r#"
1059 model PostTag {
1060 postId Int
1061 tagId Int
1062
1063 @@id([postId, tagId])
1064 }
1065 "#,
1066 )
1067 .unwrap();
1068
1069 let post_tag = schema.get_model("PostTag").unwrap();
1070 assert!(post_tag.has_attribute("id"));
1071 }
1072
1073 #[test]
1076 fn test_parse_enum() {
1077 let schema = parse_schema(
1078 r#"
1079 enum Role {
1080 User
1081 Admin
1082 Moderator
1083 }
1084 "#,
1085 )
1086 .unwrap();
1087
1088 assert_eq!(schema.enums.len(), 1);
1089 let role = schema.get_enum("Role").unwrap();
1090 assert_eq!(role.variants.len(), 3);
1091 }
1092
1093 #[test]
1094 fn test_parse_enum_variant_names() {
1095 let schema = parse_schema(
1096 r#"
1097 enum Status {
1098 Pending
1099 Active
1100 Completed
1101 Cancelled
1102 }
1103 "#,
1104 )
1105 .unwrap();
1106
1107 let status = schema.get_enum("Status").unwrap();
1108 assert!(status.get_variant("Pending").is_some());
1109 assert!(status.get_variant("Active").is_some());
1110 assert!(status.get_variant("Completed").is_some());
1111 assert!(status.get_variant("Cancelled").is_some());
1112 }
1113
1114 #[test]
1115 fn test_parse_enum_with_map() {
1116 let schema = parse_schema(
1117 r#"
1118 enum Role {
1119 User @map("USER")
1120 Admin @map("ADMINISTRATOR")
1121 }
1122 "#,
1123 )
1124 .unwrap();
1125
1126 let role = schema.get_enum("Role").unwrap();
1127 let user_variant = role.get_variant("User").unwrap();
1128 assert_eq!(user_variant.db_value(), "USER");
1129
1130 let admin_variant = role.get_variant("Admin").unwrap();
1131 assert_eq!(admin_variant.db_value(), "ADMINISTRATOR");
1132 }
1133
1134 #[test]
1137 fn test_parse_one_to_many_relation() {
1138 let schema = parse_schema(
1139 r#"
1140 model User {
1141 id Int @id
1142 posts Post[]
1143 }
1144
1145 model Post {
1146 id Int @id
1147 authorId Int
1148 author User @relation(fields: [authorId], references: [id])
1149 }
1150 "#,
1151 )
1152 .unwrap();
1153
1154 let user = schema.get_model("User").unwrap();
1155 let post = schema.get_model("Post").unwrap();
1156
1157 assert!(user.get_field("posts").unwrap().is_list());
1158 assert!(post.get_field("author").unwrap().is_relation());
1159 }
1160
1161 #[test]
1162 fn test_parse_relation_with_actions() {
1163 let schema = parse_schema(
1164 r#"
1165 model Post {
1166 id Int @id
1167 authorId Int
1168 author User @relation(fields: [authorId], references: [id], onDelete: Cascade, onUpdate: Restrict)
1169 }
1170
1171 model User {
1172 id Int @id
1173 posts Post[]
1174 }
1175 "#,
1176 )
1177 .unwrap();
1178
1179 let post = schema.get_model("Post").unwrap();
1180 let author = post.get_field("author").unwrap();
1181 let attrs = author.extract_attributes();
1182
1183 assert!(attrs.relation.is_some());
1184 let rel = attrs.relation.unwrap();
1185 assert_eq!(rel.on_delete, Some(ReferentialAction::Cascade));
1186 assert_eq!(rel.on_update, Some(ReferentialAction::Restrict));
1187 }
1188
1189 #[test]
1192 fn test_parse_model_documentation() {
1193 let schema = parse_schema(
1194 r#"/// Represents a user in the system
1195model User {
1196 id Int @id
1197}"#,
1198 )
1199 .unwrap();
1200
1201 let user = schema.get_model("User").unwrap();
1202 if let Some(doc) = &user.documentation {
1205 assert!(doc.text.contains("user"));
1206 }
1207 }
1208
1209 #[test]
1212 fn test_parse_complete_schema() {
1213 let schema = parse_schema(
1214 r#"
1215 /// User model
1216 model User {
1217 id Int @id @auto
1218 email String @unique
1219 name String?
1220 role Role @default(User)
1221 posts Post[]
1222 profile Profile?
1223 createdAt DateTime @default(now())
1224 updatedAt DateTime @updated_at
1225
1226 @@map("users")
1227 @@index([email])
1228 }
1229
1230 model Post {
1231 id Int @id @auto
1232 title String
1233 content String?
1234 published Boolean @default(false)
1235 authorId Int
1236 author User @relation(fields: [authorId], references: [id])
1237 tags Tag[]
1238 createdAt DateTime @default(now())
1239
1240 @@index([authorId])
1241 }
1242
1243 model Profile {
1244 id Int @id @auto
1245 bio String?
1246 userId Int @unique
1247 user User @relation(fields: [userId], references: [id])
1248 }
1249
1250 model Tag {
1251 id Int @id @auto
1252 name String @unique
1253 posts Post[]
1254 }
1255
1256 enum Role {
1257 User
1258 Admin
1259 Moderator
1260 }
1261 "#,
1262 )
1263 .unwrap();
1264
1265 assert_eq!(schema.models.len(), 4);
1267 assert!(schema.get_model("User").is_some());
1268 assert!(schema.get_model("Post").is_some());
1269 assert!(schema.get_model("Profile").is_some());
1270 assert!(schema.get_model("Tag").is_some());
1271
1272 assert_eq!(schema.enums.len(), 1);
1274 assert!(schema.get_enum("Role").is_some());
1275
1276 let user = schema.get_model("User").unwrap();
1278 assert_eq!(user.table_name(), "users");
1279 assert_eq!(user.fields.len(), 8);
1280 assert!(user.has_attribute("index"));
1281
1282 let post = schema.get_model("Post").unwrap();
1284 assert!(post.get_field("author").unwrap().is_relation());
1285 }
1286
1287 #[test]
1290 fn test_parse_invalid_syntax() {
1291 let result = parse_schema("model { broken }");
1292 assert!(result.is_err());
1293 }
1294
1295 #[test]
1296 fn test_parse_empty_schema() {
1297 let schema = parse_schema("").unwrap();
1298 assert!(schema.models.is_empty());
1299 assert!(schema.enums.is_empty());
1300 }
1301
1302 #[test]
1303 fn test_parse_whitespace_only() {
1304 let schema = parse_schema(" \n\t \n ").unwrap();
1305 assert!(schema.models.is_empty());
1306 }
1307
1308 #[test]
1309 fn test_parse_comments_only() {
1310 let schema = parse_schema(
1311 r#"
1312 // This is a comment
1313 // Another comment
1314 "#,
1315 )
1316 .unwrap();
1317 assert!(schema.models.is_empty());
1318 }
1319
1320 #[test]
1323 fn test_parse_model_with_no_fields() {
1324 let result = parse_schema(
1326 r#"
1327 model Empty {
1328 }
1329 "#,
1330 );
1331 let _ = result;
1333 }
1334
1335 #[test]
1336 fn test_parse_long_identifier() {
1337 let schema = parse_schema(
1338 r#"
1339 model VeryLongModelNameThatIsStillValid {
1340 someVeryLongFieldNameThatShouldWork Int @id
1341 }
1342 "#,
1343 )
1344 .unwrap();
1345
1346 assert!(
1347 schema
1348 .get_model("VeryLongModelNameThatIsStillValid")
1349 .is_some()
1350 );
1351 }
1352
1353 #[test]
1354 fn test_parse_underscore_identifiers() {
1355 let schema = parse_schema(
1356 r#"
1357 model user_account {
1358 user_id Int @id
1359 created_at DateTime
1360 }
1361 "#,
1362 )
1363 .unwrap();
1364
1365 let model = schema.get_model("user_account").unwrap();
1366 assert!(model.get_field("user_id").is_some());
1367 assert!(model.get_field("created_at").is_some());
1368 }
1369
/// `@default` with a negative integer literal is parsed as a default attribute.
#[test]
fn test_parse_negative_default() {
    let src = r#"
model Config {
id Int @id
minValue Int @default(-100)
}
"#;
    let parsed = parse_schema(src).unwrap();

    let field = parsed
        .get_model("Config")
        .unwrap()
        .get_field("minValue")
        .unwrap();
    let field_attrs = field.extract_attributes();
    assert!(field_attrs.default.is_some());
}
1387
/// `@default` with a floating-point literal is parsed as a default attribute.
#[test]
fn test_parse_float_default() {
    let src = r#"
model Product {
id Int @id
price Float @default(9.99)
}
"#;
    let parsed = parse_schema(src).unwrap();

    let price_field = parsed
        .get_model("Product")
        .unwrap()
        .get_field("price")
        .unwrap();
    let price_attrs = price_field.extract_attributes();
    assert!(price_attrs.default.is_some());
}
1405
/// A single `serverGroup` holding one server is registered under the group name,
/// and that server is keyed by its own name.
#[test]
fn test_parse_simple_server_group() {
    let src = r#"
serverGroup MainCluster {
server primary {
url = "postgres://localhost/db"
role = "primary"
}
}
"#;
    let parsed = parse_schema(src).unwrap();
    assert_eq!(parsed.server_groups.len(), 1);

    let group = parsed.get_server_group("MainCluster").unwrap();
    let servers = &group.servers;
    assert_eq!(servers.len(), 1);
    assert!(servers.contains_key("primary"));
}
1427
/// Several servers in one group each keep their own `role` and `weight`.
///
/// Every server declared in the fixture is checked, including `replica2`.
#[test]
fn test_parse_server_group_with_multiple_servers() {
    let schema = parse_schema(
        r#"
serverGroup ReadReplicas {
server primary {
url = "postgres://primary.db.com/app"
role = "primary"
weight = 1
}

server replica1 {
url = "postgres://replica1.db.com/app"
role = "replica"
weight = 2
}

server replica2 {
url = "postgres://replica2.db.com/app"
role = "replica"
weight = 2
}
}
"#,
    )
    .unwrap();

    let cluster = schema.get_server_group("ReadReplicas").unwrap();
    assert_eq!(cluster.servers.len(), 3);

    // (server name, expected role, expected weight) for every declared server.
    let expected = [
        ("primary", ServerRole::Primary, 1),
        ("replica1", ServerRole::Replica, 2),
        ("replica2", ServerRole::Replica, 2),
    ];
    for (name, role, weight) in expected {
        let server = cluster.servers.get(name).unwrap();
        assert_eq!(server.role(), Some(role), "role of server `{}`", name);
        assert_eq!(server.weight(), Some(weight), "weight of server `{}`", name);
    }
}
1466
/// Group-level `@@` attributes such as `@@strategy` and `@@loadBalance` are
/// recorded on the server group by attribute name.
#[test]
fn test_parse_server_group_with_attributes() {
    let src = r#"
serverGroup ProductionCluster {
@@strategy(ReadReplica)
@@loadBalance(RoundRobin)

server main {
url = "postgres://main/db"
role = "primary"
}
}
"#;
    let parsed = parse_schema(src).unwrap();
    let group = parsed.get_server_group("ProductionCluster").unwrap();

    let has_attribute = |wanted: &str| group.attributes.iter().any(|a| a.name.name == wanted);
    assert!(has_attribute("strategy"));
    assert!(has_attribute("loadBalance"));
}
1493
/// A property written as `env("NAME")` is stored as an environment-variable
/// reference rather than a literal string.
#[test]
fn test_parse_server_group_with_env_vars() {
    let src = r#"
serverGroup EnvCluster {
server db1 {
url = env("PRIMARY_DB_URL")
role = "primary"
}
}
"#;
    let parsed = parse_schema(src).unwrap();
    let db1 = parsed
        .get_server_group("EnvCluster")
        .unwrap()
        .servers
        .get("db1")
        .unwrap();

    match db1.get_property("url") {
        Some(ServerPropertyValue::EnvVar(var)) => assert_eq!(var, "PRIMARY_DB_URL"),
        _ => panic!("Expected env var for url property"),
    }
}
1518
/// `readOnly = true` is reported by `is_read_only()`.
#[test]
fn test_parse_server_group_with_boolean_property() {
    let src = r#"
serverGroup TestCluster {
server replica {
url = "postgres://replica/db"
role = "replica"
readOnly = true
}
}
"#;
    let parsed = parse_schema(src).unwrap();
    let replica = parsed
        .get_server_group("TestCluster")
        .unwrap()
        .servers
        .get("replica")
        .unwrap();
    assert!(replica.is_read_only());
}
1538
/// Integer-valued server properties (`weight`, `priority`, `maxConnections`) are
/// exposed through their typed accessors.
#[test]
fn test_parse_server_group_with_numeric_properties() {
    let src = r#"
serverGroup NumericCluster {
server db {
url = "postgres://localhost/db"
weight = 5
priority = 1
maxConnections = 100
}
}
"#;
    let parsed = parse_schema(src).unwrap();
    let group = parsed.get_server_group("NumericCluster").unwrap();
    let db = group.servers.get("db").unwrap();

    assert_eq!(
        (db.weight(), db.priority(), db.max_connections()),
        (Some(5), Some(1), Some(100))
    );
}
1562
/// Each server's `region` property is exposed via `region()`, and
/// `servers_in_region` filters the group by that value.
#[test]
fn test_parse_server_group_with_region() {
    let src = r#"
serverGroup GeoCluster {
server usEast {
url = "postgres://us-east.db.com/app"
role = "replica"
region = "us-east-1"
}

server usWest {
url = "postgres://us-west.db.com/app"
role = "replica"
region = "us-west-2"
}
}
"#;
    let parsed = parse_schema(src).unwrap();
    let geo = parsed.get_server_group("GeoCluster").unwrap();

    for (server_name, expected_region) in [("usEast", "us-east-1"), ("usWest", "us-west-2")] {
        let server = geo.servers.get(server_name).unwrap();
        assert_eq!(server.region(), Some(expected_region));
    }

    assert_eq!(geo.servers_in_region("us-east-1").len(), 1);
}
1596
/// Several top-level `serverGroup` blocks are all registered and are
/// individually retrievable.
#[test]
fn test_parse_multiple_server_groups() {
    let src = r#"
serverGroup Cluster1 {
server db1 {
url = "postgres://db1/app"
}
}

serverGroup Cluster2 {
server db2 {
url = "postgres://db2/app"
}
}

serverGroup Cluster3 {
server db3 {
url = "postgres://db3/app"
}
}
"#;
    let parsed = parse_schema(src).unwrap();
    assert_eq!(parsed.server_groups.len(), 3);

    for group_name in ["Cluster1", "Cluster2", "Cluster3"] {
        assert!(parsed.get_server_group(group_name).is_some());
    }
}
1627
/// Models and server groups may be interleaved in one schema, and each kind of
/// definition is collected into its own table.
#[test]
fn test_parse_schema_with_models_and_server_groups() {
    let src = r#"
model User {
id Int @id @auto
email String @unique
}

serverGroup Database {
@@strategy(ReadReplica)

server primary {
url = env("DATABASE_URL")
role = "primary"
}
}

model Post {
id Int @id @auto
title String
authorId Int
}
"#;
    let parsed = parse_schema(src).unwrap();

    assert_eq!(parsed.models.len(), 2);
    for model_name in ["User", "Post"] {
        assert!(parsed.get_model(model_name).is_some());
    }

    assert_eq!(parsed.server_groups.len(), 1);
    assert!(parsed.get_server_group("Database").is_some());
}
1662
/// A `healthCheck` string property is exposed via `health_check()`.
#[test]
fn test_parse_server_group_with_health_check() {
    let src = r#"
serverGroup HealthyCluster {
server monitored {
url = "postgres://localhost/db"
healthCheck = "/health"
}
}
"#;
    let parsed = parse_schema(src).unwrap();
    let monitored = parsed
        .get_server_group("HealthyCluster")
        .unwrap()
        .servers
        .get("monitored")
        .unwrap();
    assert_eq!(monitored.health_check(), Some("/health"));
}
1681
/// `failover_order()` lists the servers by ascending `priority`, whatever their
/// declaration order in the schema (declared here as db3, db1, db2).
#[test]
fn test_server_group_failover_order() {
    let schema = parse_schema(
        r#"
serverGroup FailoverCluster {
server db3 {
url = "postgres://db3/app"
priority = 3
}

server db1 {
url = "postgres://db1/app"
priority = 1
}

server db2 {
url = "postgres://db2/app"
priority = 2
}
}
"#,
    )
    .unwrap();

    let cluster = schema.get_server_group("FailoverCluster").unwrap();
    let ordered = cluster.failover_order();

    // Compare the whole sequence at once. This also catches missing or extra
    // servers, and a failure prints the full ordering instead of panicking with
    // an index-out-of-bounds error.
    let names: Vec<&str> = ordered.iter().map(|s| s.name.name.as_str()).collect();
    assert_eq!(names, ["db1", "db2", "db3"]);
}
1713
/// `server_group_names()` yields the name of every declared server group.
/// Order is not asserted, so the names are sorted before comparison.
#[test]
fn test_server_group_names() {
    let src = r#"
serverGroup Alpha {
server s1 { url = "pg://a" }
}
serverGroup Beta {
server s2 { url = "pg://b" }
}
"#;
    let parsed = parse_schema(src).unwrap();

    let mut names: Vec<_> = parsed.server_group_names().collect();
    names.sort();
    assert_eq!(names, ["Alpha", "Beta"]);
}
1733}