1mod grammar;
4
5use std::path::Path;
6
7use pest::Parser;
8use smol_str::SmolStr;
9use tracing::{debug, info};
10
11use crate::ast::*;
12use crate::error::{SchemaError, SchemaResult};
13
14pub use grammar::{PraxParser, Rule};
15
16use crate::ast::{Server, ServerGroup, ServerProperty, ServerPropertyValue};
17
18pub fn parse_schema(input: &str) -> SchemaResult<Schema> {
20 debug!(input_len = input.len(), "parse_schema() starting");
21 let pairs = PraxParser::parse(Rule::schema, input)
22 .map_err(|e| SchemaError::syntax(input.to_string(), 0, input.len(), e.to_string()))?;
23
24 let mut schema = Schema::new();
25 let mut current_doc: Option<Documentation> = None;
26
27 let schema_pair = pairs.into_iter().next().unwrap();
29
30 for pair in schema_pair.into_inner() {
31 match pair.as_rule() {
32 Rule::documentation => {
33 let span = pair.as_span();
34 let text = pair
35 .into_inner()
36 .map(|p| p.as_str().trim_start_matches("///").trim())
37 .collect::<Vec<_>>()
38 .join("\n");
39 current_doc = Some(Documentation::new(
40 text,
41 Span::new(span.start(), span.end()),
42 ));
43 }
44 Rule::model_def => {
45 let mut model = parse_model(pair)?;
46 if let Some(doc) = current_doc.take() {
47 model = model.with_documentation(doc);
48 }
49 schema.add_model(model);
50 }
51 Rule::enum_def => {
52 let mut e = parse_enum(pair)?;
53 if let Some(doc) = current_doc.take() {
54 e = e.with_documentation(doc);
55 }
56 schema.add_enum(e);
57 }
58 Rule::type_def => {
59 let mut t = parse_composite_type(pair)?;
60 if let Some(doc) = current_doc.take() {
61 t = t.with_documentation(doc);
62 }
63 schema.add_type(t);
64 }
65 Rule::view_def => {
66 let mut v = parse_view(pair)?;
67 if let Some(doc) = current_doc.take() {
68 v = v.with_documentation(doc);
69 }
70 schema.add_view(v);
71 }
72 Rule::raw_sql_def => {
73 let sql = parse_raw_sql(pair)?;
74 schema.add_raw_sql(sql);
75 }
76 Rule::server_group_def => {
77 let mut sg = parse_server_group(pair)?;
78 if let Some(doc) = current_doc.take() {
79 sg.set_documentation(doc);
80 }
81 schema.add_server_group(sg);
82 }
83 Rule::EOI => {}
84 _ => {}
85 }
86 }
87
88 info!(
89 models = schema.models.len(),
90 enums = schema.enums.len(),
91 types = schema.types.len(),
92 views = schema.views.len(),
93 "Schema parsed successfully"
94 );
95 Ok(schema)
96}
97
98pub fn parse_schema_file(path: impl AsRef<Path>) -> SchemaResult<Schema> {
100 let path = path.as_ref();
101 info!(path = %path.display(), "Loading schema file");
102 let content = std::fs::read_to_string(path).map_err(|e| SchemaError::IoError {
103 path: path.display().to_string(),
104 source: e,
105 })?;
106
107 parse_schema(&content)
108}
109
110fn parse_model(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Model> {
112 let span = pair.as_span();
113 let mut inner = pair.into_inner();
114
115 let name_pair = inner.next().unwrap();
116 let name = Ident::new(
117 name_pair.as_str(),
118 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
119 );
120
121 let mut model = Model::new(name, Span::new(span.start(), span.end()));
122
123 for item in inner {
124 match item.as_rule() {
125 Rule::field_def => {
126 let field = parse_field(item)?;
127 model.add_field(field);
128 }
129 Rule::model_attribute => {
130 let attr = parse_attribute(item)?;
131 model.attributes.push(attr);
132 }
133 Rule::model_body_item => {
134 let inner_item = item.into_inner().next().unwrap();
136 match inner_item.as_rule() {
137 Rule::field_def => {
138 let field = parse_field(inner_item)?;
139 model.add_field(field);
140 }
141 Rule::model_attribute => {
142 let attr = parse_attribute(inner_item)?;
143 model.attributes.push(attr);
144 }
145 _ => {}
146 }
147 }
148 _ => {}
149 }
150 }
151
152 Ok(model)
153}
154
155fn parse_enum(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Enum> {
157 let span = pair.as_span();
158 let mut inner = pair.into_inner();
159
160 let name_pair = inner.next().unwrap();
161 let name = Ident::new(
162 name_pair.as_str(),
163 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
164 );
165
166 let mut e = Enum::new(name, Span::new(span.start(), span.end()));
167
168 for item in inner {
169 match item.as_rule() {
170 Rule::enum_variant => {
171 let variant = parse_enum_variant(item)?;
172 e.add_variant(variant);
173 }
174 Rule::model_attribute => {
175 let attr = parse_attribute(item)?;
176 e.attributes.push(attr);
177 }
178 Rule::enum_body_item => {
179 let inner_item = item.into_inner().next().unwrap();
181 match inner_item.as_rule() {
182 Rule::enum_variant => {
183 let variant = parse_enum_variant(inner_item)?;
184 e.add_variant(variant);
185 }
186 Rule::model_attribute => {
187 let attr = parse_attribute(inner_item)?;
188 e.attributes.push(attr);
189 }
190 _ => {}
191 }
192 }
193 _ => {}
194 }
195 }
196
197 Ok(e)
198}
199
200fn parse_enum_variant(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<EnumVariant> {
202 let span = pair.as_span();
203 let mut inner = pair.into_inner();
204
205 let name_pair = inner.next().unwrap();
206 let name = Ident::new(
207 name_pair.as_str(),
208 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
209 );
210
211 let mut variant = EnumVariant::new(name, Span::new(span.start(), span.end()));
212
213 for item in inner {
214 if item.as_rule() == Rule::field_attribute {
215 let attr = parse_attribute(item)?;
216 variant.attributes.push(attr);
217 }
218 }
219
220 Ok(variant)
221}
222
223fn parse_composite_type(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<CompositeType> {
225 let span = pair.as_span();
226 let mut inner = pair.into_inner();
227
228 let name_pair = inner.next().unwrap();
229 let name = Ident::new(
230 name_pair.as_str(),
231 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
232 );
233
234 let mut t = CompositeType::new(name, Span::new(span.start(), span.end()));
235
236 for item in inner {
237 if item.as_rule() == Rule::field_def {
238 let field = parse_field(item)?;
239 t.add_field(field);
240 }
241 }
242
243 Ok(t)
244}
245
246fn parse_view(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<View> {
248 let span = pair.as_span();
249 let mut inner = pair.into_inner();
250
251 let name_pair = inner.next().unwrap();
252 let name = Ident::new(
253 name_pair.as_str(),
254 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
255 );
256
257 let mut v = View::new(name, Span::new(span.start(), span.end()));
258
259 for item in inner {
260 match item.as_rule() {
261 Rule::field_def => {
262 let field = parse_field(item)?;
263 v.add_field(field);
264 }
265 Rule::model_attribute => {
266 let attr = parse_attribute(item)?;
267 v.attributes.push(attr);
268 }
269 Rule::model_body_item => {
270 let inner_item = item.into_inner().next().unwrap();
272 match inner_item.as_rule() {
273 Rule::field_def => {
274 let field = parse_field(inner_item)?;
275 v.add_field(field);
276 }
277 Rule::model_attribute => {
278 let attr = parse_attribute(inner_item)?;
279 v.attributes.push(attr);
280 }
281 _ => {}
282 }
283 }
284 _ => {}
285 }
286 }
287
288 Ok(v)
289}
290
291fn parse_field(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Field> {
293 let span = pair.as_span();
294 let mut inner = pair.into_inner();
295
296 let name_pair = inner.next().unwrap();
297 let name = Ident::new(
298 name_pair.as_str(),
299 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
300 );
301
302 let type_pair = inner.next().unwrap();
303 let (field_type, modifier) = parse_field_type(type_pair)?;
304
305 let mut attributes = vec![];
306 for item in inner {
307 if item.as_rule() == Rule::field_attribute {
308 let attr = parse_attribute(item)?;
309 attributes.push(attr);
310 }
311 }
312
313 Ok(Field::new(
314 name,
315 field_type,
316 modifier,
317 attributes,
318 Span::new(span.start(), span.end()),
319 ))
320}
321
322fn parse_field_type(
324 pair: pest::iterators::Pair<'_, Rule>,
325) -> SchemaResult<(FieldType, TypeModifier)> {
326 let mut type_name = String::new();
327 let mut modifier = TypeModifier::Required;
328
329 for item in pair.into_inner() {
330 match item.as_rule() {
331 Rule::type_name => {
332 type_name = item.as_str().to_string();
333 }
334 Rule::optional_marker => {
335 modifier = if modifier == TypeModifier::List {
336 TypeModifier::OptionalList
337 } else {
338 TypeModifier::Optional
339 };
340 }
341 Rule::list_marker => {
342 modifier = if modifier == TypeModifier::Optional {
343 TypeModifier::OptionalList
344 } else {
345 TypeModifier::List
346 };
347 }
348 _ => {}
349 }
350 }
351
352 let field_type = if let Some(scalar) = ScalarType::from_str(&type_name) {
353 FieldType::Scalar(scalar)
354 } else {
355 FieldType::Model(SmolStr::new(&type_name))
358 };
359
360 Ok((field_type, modifier))
361}
362
363fn parse_attribute(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Attribute> {
365 let span = pair.as_span();
366 let mut inner = pair.into_inner();
367
368 let name_pair = inner.next().unwrap();
369 let name = Ident::new(
370 name_pair.as_str(),
371 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
372 );
373
374 let mut args = vec![];
375 for item in inner {
376 if item.as_rule() == Rule::attribute_args {
377 args = parse_attribute_args(item)?;
378 }
379 }
380
381 Ok(Attribute::new(
382 name,
383 args,
384 Span::new(span.start(), span.end()),
385 ))
386}
387
388fn parse_attribute_args(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Vec<AttributeArg>> {
390 let mut args = vec![];
391
392 for item in pair.into_inner() {
393 if item.as_rule() == Rule::attribute_arg {
394 let arg = parse_attribute_arg(item)?;
395 args.push(arg);
396 }
397 }
398
399 Ok(args)
400}
401
402fn parse_attribute_arg(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<AttributeArg> {
404 let span = pair.as_span();
405 let mut inner = pair.into_inner();
406
407 let first = inner.next().unwrap();
408
409 if let Some(second) = inner.next() {
411 let name = Ident::new(
413 first.as_str(),
414 Span::new(first.as_span().start(), first.as_span().end()),
415 );
416 let value = parse_attribute_value(second)?;
417 Ok(AttributeArg::named(
418 name,
419 value,
420 Span::new(span.start(), span.end()),
421 ))
422 } else {
423 let value = parse_attribute_value(first)?;
425 Ok(AttributeArg::positional(
426 value,
427 Span::new(span.start(), span.end()),
428 ))
429 }
430}
431
432fn parse_attribute_value(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<AttributeValue> {
434 match pair.as_rule() {
435 Rule::string_literal => {
436 let s = pair.as_str();
437 let unquoted = &s[1..s.len() - 1];
439 Ok(AttributeValue::String(unquoted.to_string()))
440 }
441 Rule::number_literal => {
442 let s = pair.as_str();
443 if s.contains('.') {
444 Ok(AttributeValue::Float(s.parse().unwrap()))
445 } else {
446 Ok(AttributeValue::Int(s.parse().unwrap()))
447 }
448 }
449 Rule::boolean_literal => Ok(AttributeValue::Boolean(pair.as_str() == "true")),
450 Rule::identifier => Ok(AttributeValue::Ident(SmolStr::new(pair.as_str()))),
451 Rule::function_call => {
452 let mut inner = pair.into_inner();
453 let name = SmolStr::new(inner.next().unwrap().as_str());
454 let mut args = vec![];
455 for item in inner {
456 args.push(parse_attribute_value(item)?);
457 }
458 Ok(AttributeValue::Function(name, args))
459 }
460 Rule::field_ref_list => {
461 let refs: Vec<SmolStr> = pair
462 .into_inner()
463 .map(|p| SmolStr::new(p.as_str()))
464 .collect();
465 Ok(AttributeValue::FieldRefList(refs))
466 }
467 Rule::array_literal => {
468 let values: Result<Vec<_>, _> = pair.into_inner().map(parse_attribute_value).collect();
469 Ok(AttributeValue::Array(values?))
470 }
471 Rule::attribute_value => {
472 parse_attribute_value(pair.into_inner().next().unwrap())
474 }
475 _ => {
476 Ok(AttributeValue::Ident(SmolStr::new(pair.as_str())))
478 }
479 }
480}
481
482fn parse_raw_sql(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<RawSql> {
484 let mut inner = pair.into_inner();
485
486 let name = inner.next().unwrap().as_str();
487 let sql = inner.next().unwrap().as_str();
488
489 let sql_content = sql
491 .trim_start_matches("\"\"\"")
492 .trim_end_matches("\"\"\"")
493 .trim();
494
495 Ok(RawSql::new(name, sql_content))
496}
497
498fn parse_server_group(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<ServerGroup> {
500 let span = pair.as_span();
501 let mut inner = pair.into_inner();
502
503 let name_pair = inner.next().unwrap();
504 let name = Ident::new(
505 name_pair.as_str(),
506 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
507 );
508
509 let mut server_group = ServerGroup::new(name, Span::new(span.start(), span.end()));
510
511 for item in inner {
512 match item.as_rule() {
513 Rule::server_group_item => {
514 let inner_item = item.into_inner().next().unwrap();
516 match inner_item.as_rule() {
517 Rule::server_def => {
518 let server = parse_server(inner_item)?;
519 server_group.add_server(server);
520 }
521 Rule::model_attribute => {
522 let attr = parse_attribute(inner_item)?;
523 server_group.add_attribute(attr);
524 }
525 _ => {}
526 }
527 }
528 Rule::server_def => {
529 let server = parse_server(item)?;
530 server_group.add_server(server);
531 }
532 Rule::model_attribute => {
533 let attr = parse_attribute(item)?;
534 server_group.add_attribute(attr);
535 }
536 _ => {}
537 }
538 }
539
540 Ok(server_group)
541}
542
543fn parse_server(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Server> {
545 let span = pair.as_span();
546 let mut inner = pair.into_inner();
547
548 let name_pair = inner.next().unwrap();
549 let name = Ident::new(
550 name_pair.as_str(),
551 Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
552 );
553
554 let mut server = Server::new(name, Span::new(span.start(), span.end()));
555
556 for item in inner {
557 if item.as_rule() == Rule::server_property {
558 let prop = parse_server_property(item)?;
559 server.add_property(prop);
560 }
561 }
562
563 Ok(server)
564}
565
566fn parse_server_property(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<ServerProperty> {
568 let span = pair.as_span();
569 let mut inner = pair.into_inner();
570
571 let key_pair = inner.next().unwrap();
572 let key = key_pair.as_str();
573
574 let value_pair = inner.next().unwrap();
575 let value = parse_server_property_value(value_pair)?;
576
577 Ok(ServerProperty::new(key, value, Span::new(span.start(), span.end())))
578}
579
580fn extract_string_from_arg(pair: pest::iterators::Pair<'_, Rule>) -> String {
582 match pair.as_rule() {
583 Rule::string_literal => {
584 let s = pair.as_str();
585 s[1..s.len() - 1].to_string()
586 }
587 Rule::attribute_value => {
588 if let Some(inner) = pair.into_inner().next() {
590 extract_string_from_arg(inner)
591 } else {
592 String::new()
593 }
594 }
595 _ => pair.as_str().to_string(),
596 }
597}
598
599fn parse_server_property_value(
601 pair: pest::iterators::Pair<'_, Rule>,
602) -> SchemaResult<ServerPropertyValue> {
603 match pair.as_rule() {
604 Rule::string_literal => {
605 let s = pair.as_str();
606 let unquoted = &s[1..s.len() - 1];
608 Ok(ServerPropertyValue::String(unquoted.to_string()))
609 }
610 Rule::number_literal => {
611 let s = pair.as_str();
612 Ok(ServerPropertyValue::Number(s.parse().unwrap_or(0.0)))
613 }
614 Rule::boolean_literal => {
615 Ok(ServerPropertyValue::Boolean(pair.as_str() == "true"))
616 }
617 Rule::identifier => {
618 Ok(ServerPropertyValue::Identifier(pair.as_str().to_string()))
619 }
620 Rule::function_call => {
621 let mut inner = pair.into_inner();
623 let func_name = inner.next().unwrap().as_str();
624 if func_name == "env" {
625 if let Some(arg) = inner.next() {
626 let var_name = extract_string_from_arg(arg);
627 return Ok(ServerPropertyValue::EnvVar(var_name));
628 }
629 }
630 Ok(ServerPropertyValue::Identifier(func_name.to_string()))
632 }
633 Rule::array_literal => {
634 let values: Result<Vec<_>, _> = pair
635 .into_inner()
636 .map(parse_server_property_value)
637 .collect();
638 Ok(ServerPropertyValue::Array(values?))
639 }
640 Rule::attribute_value => {
641 parse_server_property_value(pair.into_inner().next().unwrap())
643 }
644 _ => {
645 Ok(ServerPropertyValue::Identifier(pair.as_str().to_string()))
647 }
648 }
649}
650
651#[cfg(test)]
652mod tests {
653 use super::*;
654
655 #[test]
658 fn test_parse_simple_model() {
659 let schema = parse_schema(
660 r#"
661 model User {
662 id Int @id @auto
663 email String @unique
664 name String?
665 }
666 "#,
667 )
668 .unwrap();
669
670 assert_eq!(schema.models.len(), 1);
671 let user = schema.get_model("User").unwrap();
672 assert_eq!(user.fields.len(), 3);
673 assert!(user.get_field("id").unwrap().is_id());
674 assert!(user.get_field("email").unwrap().is_unique());
675 assert!(user.get_field("name").unwrap().is_optional());
676 }
677
678 #[test]
679 fn test_parse_model_name() {
680 let schema = parse_schema(
681 r#"
682 model BlogPost {
683 id Int @id
684 }
685 "#,
686 )
687 .unwrap();
688
689 assert!(schema.get_model("BlogPost").is_some());
690 }
691
692 #[test]
693 fn test_parse_multiple_models() {
694 let schema = parse_schema(
695 r#"
696 model User {
697 id Int @id
698 }
699
700 model Post {
701 id Int @id
702 }
703
704 model Comment {
705 id Int @id
706 }
707 "#,
708 )
709 .unwrap();
710
711 assert_eq!(schema.models.len(), 3);
712 assert!(schema.get_model("User").is_some());
713 assert!(schema.get_model("Post").is_some());
714 assert!(schema.get_model("Comment").is_some());
715 }
716
717 #[test]
720 fn test_parse_all_scalar_types() {
721 let schema = parse_schema(
722 r#"
723 model AllTypes {
724 id Int @id
725 big BigInt
726 float_f Float
727 decimal Decimal
728 str String
729 bool Boolean
730 datetime DateTime
731 date Date
732 time Time
733 json Json
734 bytes Bytes
735 uuid Uuid
736 cuid Cuid
737 cuid2 Cuid2
738 nanoid NanoId
739 ulid Ulid
740 }
741 "#,
742 )
743 .unwrap();
744
745 let model = schema.get_model("AllTypes").unwrap();
746 assert_eq!(model.fields.len(), 16);
747
748 assert!(matches!(
749 model.get_field("id").unwrap().field_type,
750 FieldType::Scalar(ScalarType::Int)
751 ));
752 assert!(matches!(
753 model.get_field("big").unwrap().field_type,
754 FieldType::Scalar(ScalarType::BigInt)
755 ));
756 assert!(matches!(
757 model.get_field("str").unwrap().field_type,
758 FieldType::Scalar(ScalarType::String)
759 ));
760 assert!(matches!(
761 model.get_field("bool").unwrap().field_type,
762 FieldType::Scalar(ScalarType::Boolean)
763 ));
764 assert!(matches!(
765 model.get_field("datetime").unwrap().field_type,
766 FieldType::Scalar(ScalarType::DateTime)
767 ));
768 assert!(matches!(
769 model.get_field("uuid").unwrap().field_type,
770 FieldType::Scalar(ScalarType::Uuid)
771 ));
772 assert!(matches!(
773 model.get_field("cuid").unwrap().field_type,
774 FieldType::Scalar(ScalarType::Cuid)
775 ));
776 assert!(matches!(
777 model.get_field("cuid2").unwrap().field_type,
778 FieldType::Scalar(ScalarType::Cuid2)
779 ));
780 assert!(matches!(
781 model.get_field("nanoid").unwrap().field_type,
782 FieldType::Scalar(ScalarType::NanoId)
783 ));
784 assert!(matches!(
785 model.get_field("ulid").unwrap().field_type,
786 FieldType::Scalar(ScalarType::Ulid)
787 ));
788 }
789
790 #[test]
791 fn test_parse_optional_field() {
792 let schema = parse_schema(
793 r#"
794 model User {
795 id Int @id
796 bio String?
797 age Int?
798 }
799 "#,
800 )
801 .unwrap();
802
803 let user = schema.get_model("User").unwrap();
804 assert!(!user.get_field("id").unwrap().is_optional());
805 assert!(user.get_field("bio").unwrap().is_optional());
806 assert!(user.get_field("age").unwrap().is_optional());
807 }
808
809 #[test]
810 fn test_parse_list_field() {
811 let schema = parse_schema(
812 r#"
813 model User {
814 id Int @id
815 tags String[]
816 posts Post[]
817 }
818 "#,
819 )
820 .unwrap();
821
822 let user = schema.get_model("User").unwrap();
823 assert!(user.get_field("tags").unwrap().is_list());
824 assert!(user.get_field("posts").unwrap().is_list());
825 }
826
827 #[test]
828 fn test_parse_optional_list_field() {
829 let schema = parse_schema(
830 r#"
831 model User {
832 id Int @id
833 metadata String[]?
834 }
835 "#,
836 )
837 .unwrap();
838
839 let user = schema.get_model("User").unwrap();
840 let metadata = user.get_field("metadata").unwrap();
841 assert!(metadata.is_list());
842 assert!(metadata.is_optional());
843 }
844
845 #[test]
848 fn test_parse_id_attribute() {
849 let schema = parse_schema(
850 r#"
851 model User {
852 id Int @id
853 }
854 "#,
855 )
856 .unwrap();
857
858 let user = schema.get_model("User").unwrap();
859 assert!(user.get_field("id").unwrap().is_id());
860 }
861
862 #[test]
863 fn test_parse_unique_attribute() {
864 let schema = parse_schema(
865 r#"
866 model User {
867 id Int @id
868 email String @unique
869 }
870 "#,
871 )
872 .unwrap();
873
874 let user = schema.get_model("User").unwrap();
875 assert!(user.get_field("email").unwrap().is_unique());
876 }
877
878 #[test]
879 fn test_parse_default_int() {
880 let schema = parse_schema(
881 r#"
882 model Counter {
883 id Int @id
884 count Int @default(0)
885 }
886 "#,
887 )
888 .unwrap();
889
890 let counter = schema.get_model("Counter").unwrap();
891 let count_field = counter.get_field("count").unwrap();
892 let attrs = count_field.extract_attributes();
893 assert!(attrs.default.is_some());
894 assert_eq!(attrs.default.unwrap().as_int(), Some(0));
895 }
896
897 #[test]
898 fn test_parse_default_string() {
899 let schema = parse_schema(
900 r#"
901 model User {
902 id Int @id
903 status String @default("active")
904 }
905 "#,
906 )
907 .unwrap();
908
909 let user = schema.get_model("User").unwrap();
910 let status = user.get_field("status").unwrap();
911 let attrs = status.extract_attributes();
912 assert!(attrs.default.is_some());
913 assert_eq!(attrs.default.unwrap().as_string(), Some("active"));
914 }
915
916 #[test]
917 fn test_parse_default_boolean() {
918 let schema = parse_schema(
919 r#"
920 model Post {
921 id Int @id
922 published Boolean @default(false)
923 }
924 "#,
925 )
926 .unwrap();
927
928 let post = schema.get_model("Post").unwrap();
929 let published = post.get_field("published").unwrap();
930 let attrs = published.extract_attributes();
931 assert!(attrs.default.is_some());
932 assert_eq!(attrs.default.unwrap().as_bool(), Some(false));
933 }
934
935 #[test]
936 fn test_parse_default_function() {
937 let schema = parse_schema(
938 r#"
939 model User {
940 id Int @id
941 createdAt DateTime @default(now())
942 }
943 "#,
944 )
945 .unwrap();
946
947 let user = schema.get_model("User").unwrap();
948 let created_at = user.get_field("createdAt").unwrap();
949 let attrs = created_at.extract_attributes();
950 assert!(attrs.default.is_some());
951 if let Some(AttributeValue::Function(name, _)) = attrs.default {
952 assert_eq!(name.as_str(), "now");
953 } else {
954 panic!("Expected function default");
955 }
956 }
957
958 #[test]
959 fn test_parse_updated_at_attribute() {
960 let schema = parse_schema(
961 r#"
962 model User {
963 id Int @id
964 updatedAt DateTime @updated_at
965 }
966 "#,
967 )
968 .unwrap();
969
970 let user = schema.get_model("User").unwrap();
971 let updated_at = user.get_field("updatedAt").unwrap();
972 let attrs = updated_at.extract_attributes();
973 assert!(attrs.is_updated_at);
974 }
975
976 #[test]
977 fn test_parse_map_attribute() {
978 let schema = parse_schema(
979 r#"
980 model User {
981 id Int @id
982 email String @map("email_address")
983 }
984 "#,
985 )
986 .unwrap();
987
988 let user = schema.get_model("User").unwrap();
989 let email = user.get_field("email").unwrap();
990 let attrs = email.extract_attributes();
991 assert_eq!(attrs.map, Some("email_address".to_string()));
992 }
993
994 #[test]
995 fn test_parse_multiple_attributes() {
996 let schema = parse_schema(
997 r#"
998 model User {
999 id Int @id @auto
1000 email String @unique @index
1001 }
1002 "#,
1003 )
1004 .unwrap();
1005
1006 let user = schema.get_model("User").unwrap();
1007 let id = user.get_field("id").unwrap();
1008 let email = user.get_field("email").unwrap();
1009
1010 let id_attrs = id.extract_attributes();
1011 assert!(id_attrs.is_id);
1012 assert!(id_attrs.is_auto);
1013
1014 let email_attrs = email.extract_attributes();
1015 assert!(email_attrs.is_unique);
1016 assert!(email_attrs.is_indexed);
1017 }
1018
1019 #[test]
1022 fn test_parse_model_map_attribute() {
1023 let schema = parse_schema(
1024 r#"
1025 model User {
1026 id Int @id
1027
1028 @@map("app_users")
1029 }
1030 "#,
1031 )
1032 .unwrap();
1033
1034 let user = schema.get_model("User").unwrap();
1035 assert_eq!(user.table_name(), "app_users");
1036 }
1037
1038 #[test]
1039 fn test_parse_model_index_attribute() {
1040 let schema = parse_schema(
1041 r#"
1042 model User {
1043 id Int @id
1044 email String
1045 name String
1046
1047 @@index([email, name])
1048 }
1049 "#,
1050 )
1051 .unwrap();
1052
1053 let user = schema.get_model("User").unwrap();
1054 assert!(user.has_attribute("index"));
1055 }
1056
1057 #[test]
1058 fn test_parse_composite_primary_key() {
1059 let schema = parse_schema(
1060 r#"
1061 model PostTag {
1062 postId Int
1063 tagId Int
1064
1065 @@id([postId, tagId])
1066 }
1067 "#,
1068 )
1069 .unwrap();
1070
1071 let post_tag = schema.get_model("PostTag").unwrap();
1072 assert!(post_tag.has_attribute("id"));
1073 }
1074
1075 #[test]
1078 fn test_parse_enum() {
1079 let schema = parse_schema(
1080 r#"
1081 enum Role {
1082 User
1083 Admin
1084 Moderator
1085 }
1086 "#,
1087 )
1088 .unwrap();
1089
1090 assert_eq!(schema.enums.len(), 1);
1091 let role = schema.get_enum("Role").unwrap();
1092 assert_eq!(role.variants.len(), 3);
1093 }
1094
1095 #[test]
1096 fn test_parse_enum_variant_names() {
1097 let schema = parse_schema(
1098 r#"
1099 enum Status {
1100 Pending
1101 Active
1102 Completed
1103 Cancelled
1104 }
1105 "#,
1106 )
1107 .unwrap();
1108
1109 let status = schema.get_enum("Status").unwrap();
1110 assert!(status.get_variant("Pending").is_some());
1111 assert!(status.get_variant("Active").is_some());
1112 assert!(status.get_variant("Completed").is_some());
1113 assert!(status.get_variant("Cancelled").is_some());
1114 }
1115
1116 #[test]
1117 fn test_parse_enum_with_map() {
1118 let schema = parse_schema(
1119 r#"
1120 enum Role {
1121 User @map("USER")
1122 Admin @map("ADMINISTRATOR")
1123 }
1124 "#,
1125 )
1126 .unwrap();
1127
1128 let role = schema.get_enum("Role").unwrap();
1129 let user_variant = role.get_variant("User").unwrap();
1130 assert_eq!(user_variant.db_value(), "USER");
1131
1132 let admin_variant = role.get_variant("Admin").unwrap();
1133 assert_eq!(admin_variant.db_value(), "ADMINISTRATOR");
1134 }
1135
1136 #[test]
1139 fn test_parse_one_to_many_relation() {
1140 let schema = parse_schema(
1141 r#"
1142 model User {
1143 id Int @id
1144 posts Post[]
1145 }
1146
1147 model Post {
1148 id Int @id
1149 authorId Int
1150 author User @relation(fields: [authorId], references: [id])
1151 }
1152 "#,
1153 )
1154 .unwrap();
1155
1156 let user = schema.get_model("User").unwrap();
1157 let post = schema.get_model("Post").unwrap();
1158
1159 assert!(user.get_field("posts").unwrap().is_list());
1160 assert!(post.get_field("author").unwrap().is_relation());
1161 }
1162
1163 #[test]
1164 fn test_parse_relation_with_actions() {
1165 let schema = parse_schema(
1166 r#"
1167 model Post {
1168 id Int @id
1169 authorId Int
1170 author User @relation(fields: [authorId], references: [id], onDelete: Cascade, onUpdate: Restrict)
1171 }
1172
1173 model User {
1174 id Int @id
1175 posts Post[]
1176 }
1177 "#,
1178 )
1179 .unwrap();
1180
1181 let post = schema.get_model("Post").unwrap();
1182 let author = post.get_field("author").unwrap();
1183 let attrs = author.extract_attributes();
1184
1185 assert!(attrs.relation.is_some());
1186 let rel = attrs.relation.unwrap();
1187 assert_eq!(rel.on_delete, Some(ReferentialAction::Cascade));
1188 assert_eq!(rel.on_update, Some(ReferentialAction::Restrict));
1189 }
1190
1191 #[test]
1194 fn test_parse_model_documentation() {
1195 let schema = parse_schema(
1196 r#"/// Represents a user in the system
1197model User {
1198 id Int @id
1199}"#,
1200 )
1201 .unwrap();
1202
1203 let user = schema.get_model("User").unwrap();
1204 if let Some(doc) = &user.documentation {
1207 assert!(doc.text.contains("user"));
1208 }
1209 }
1210
1211 #[test]
1214 fn test_parse_complete_schema() {
1215 let schema = parse_schema(
1216 r#"
1217 /// User model
1218 model User {
1219 id Int @id @auto
1220 email String @unique
1221 name String?
1222 role Role @default(User)
1223 posts Post[]
1224 profile Profile?
1225 createdAt DateTime @default(now())
1226 updatedAt DateTime @updated_at
1227
1228 @@map("users")
1229 @@index([email])
1230 }
1231
1232 model Post {
1233 id Int @id @auto
1234 title String
1235 content String?
1236 published Boolean @default(false)
1237 authorId Int
1238 author User @relation(fields: [authorId], references: [id])
1239 tags Tag[]
1240 createdAt DateTime @default(now())
1241
1242 @@index([authorId])
1243 }
1244
1245 model Profile {
1246 id Int @id @auto
1247 bio String?
1248 userId Int @unique
1249 user User @relation(fields: [userId], references: [id])
1250 }
1251
1252 model Tag {
1253 id Int @id @auto
1254 name String @unique
1255 posts Post[]
1256 }
1257
1258 enum Role {
1259 User
1260 Admin
1261 Moderator
1262 }
1263 "#,
1264 )
1265 .unwrap();
1266
1267 assert_eq!(schema.models.len(), 4);
1269 assert!(schema.get_model("User").is_some());
1270 assert!(schema.get_model("Post").is_some());
1271 assert!(schema.get_model("Profile").is_some());
1272 assert!(schema.get_model("Tag").is_some());
1273
1274 assert_eq!(schema.enums.len(), 1);
1276 assert!(schema.get_enum("Role").is_some());
1277
1278 let user = schema.get_model("User").unwrap();
1280 assert_eq!(user.table_name(), "users");
1281 assert_eq!(user.fields.len(), 8);
1282 assert!(user.has_attribute("index"));
1283
1284 let post = schema.get_model("Post").unwrap();
1286 assert!(post.get_field("author").unwrap().is_relation());
1287 }
1288
1289 #[test]
1292 fn test_parse_invalid_syntax() {
1293 let result = parse_schema("model { broken }");
1294 assert!(result.is_err());
1295 }
1296
1297 #[test]
1298 fn test_parse_empty_schema() {
1299 let schema = parse_schema("").unwrap();
1300 assert!(schema.models.is_empty());
1301 assert!(schema.enums.is_empty());
1302 }
1303
1304 #[test]
1305 fn test_parse_whitespace_only() {
1306 let schema = parse_schema(" \n\t \n ").unwrap();
1307 assert!(schema.models.is_empty());
1308 }
1309
1310 #[test]
1311 fn test_parse_comments_only() {
1312 let schema = parse_schema(
1313 r#"
1314 // This is a comment
1315 // Another comment
1316 "#,
1317 )
1318 .unwrap();
1319 assert!(schema.models.is_empty());
1320 }
1321
1322 #[test]
1325 fn test_parse_model_with_no_fields() {
1326 let result = parse_schema(
1328 r#"
1329 model Empty {
1330 }
1331 "#,
1332 );
1333 let _ = result;
1335 }
1336
1337 #[test]
1338 fn test_parse_long_identifier() {
1339 let schema = parse_schema(
1340 r#"
1341 model VeryLongModelNameThatIsStillValid {
1342 someVeryLongFieldNameThatShouldWork Int @id
1343 }
1344 "#,
1345 )
1346 .unwrap();
1347
1348 assert!(
1349 schema
1350 .get_model("VeryLongModelNameThatIsStillValid")
1351 .is_some()
1352 );
1353 }
1354
1355 #[test]
1356 fn test_parse_underscore_identifiers() {
1357 let schema = parse_schema(
1358 r#"
1359 model user_account {
1360 user_id Int @id
1361 created_at DateTime
1362 }
1363 "#,
1364 )
1365 .unwrap();
1366
1367 let model = schema.get_model("user_account").unwrap();
1368 assert!(model.get_field("user_id").is_some());
1369 assert!(model.get_field("created_at").is_some());
1370 }
1371
1372 #[test]
1373 fn test_parse_negative_default() {
1374 let schema = parse_schema(
1375 r#"
1376 model Config {
1377 id Int @id
1378 minValue Int @default(-100)
1379 }
1380 "#,
1381 )
1382 .unwrap();
1383
1384 let config = schema.get_model("Config").unwrap();
1385 let min_value = config.get_field("minValue").unwrap();
1386 let attrs = min_value.extract_attributes();
1387 assert!(attrs.default.is_some());
1388 }
1389
1390 #[test]
1391 fn test_parse_float_default() {
1392 let schema = parse_schema(
1393 r#"
1394 model Product {
1395 id Int @id
1396 price Float @default(9.99)
1397 }
1398 "#,
1399 )
1400 .unwrap();
1401
1402 let product = schema.get_model("Product").unwrap();
1403 let price = product.get_field("price").unwrap();
1404 let attrs = price.extract_attributes();
1405 assert!(attrs.default.is_some());
1406 }
1407
1408 #[test]
1411 fn test_parse_simple_server_group() {
1412 let schema = parse_schema(
1413 r#"
1414 serverGroup MainCluster {
1415 server primary {
1416 url = "postgres://localhost/db"
1417 role = "primary"
1418 }
1419 }
1420 "#,
1421 )
1422 .unwrap();
1423
1424 assert_eq!(schema.server_groups.len(), 1);
1425 let cluster = schema.get_server_group("MainCluster").unwrap();
1426 assert_eq!(cluster.servers.len(), 1);
1427 assert!(cluster.servers.contains_key("primary"));
1428 }
1429
1430 #[test]
1431 fn test_parse_server_group_with_multiple_servers() {
1432 let schema = parse_schema(
1433 r#"
1434 serverGroup ReadReplicas {
1435 server primary {
1436 url = "postgres://primary.db.com/app"
1437 role = "primary"
1438 weight = 1
1439 }
1440
1441 server replica1 {
1442 url = "postgres://replica1.db.com/app"
1443 role = "replica"
1444 weight = 2
1445 }
1446
1447 server replica2 {
1448 url = "postgres://replica2.db.com/app"
1449 role = "replica"
1450 weight = 2
1451 }
1452 }
1453 "#,
1454 )
1455 .unwrap();
1456
1457 let cluster = schema.get_server_group("ReadReplicas").unwrap();
1458 assert_eq!(cluster.servers.len(), 3);
1459
1460 let primary = cluster.servers.get("primary").unwrap();
1461 assert_eq!(primary.role(), Some(ServerRole::Primary));
1462 assert_eq!(primary.weight(), Some(1));
1463
1464 let replica1 = cluster.servers.get("replica1").unwrap();
1465 assert_eq!(replica1.role(), Some(ServerRole::Replica));
1466 assert_eq!(replica1.weight(), Some(2));
1467 }
1468
1469 #[test]
1470 fn test_parse_server_group_with_attributes() {
1471 let schema = parse_schema(
1472 r#"
1473 serverGroup ProductionCluster {
1474 @@strategy(ReadReplica)
1475 @@loadBalance(RoundRobin)
1476
1477 server main {
1478 url = "postgres://main/db"
1479 role = "primary"
1480 }
1481 }
1482 "#,
1483 )
1484 .unwrap();
1485
1486 let cluster = schema.get_server_group("ProductionCluster").unwrap();
1487 assert!(cluster.attributes.iter().any(|a| a.name.name == "strategy"));
1488 assert!(cluster.attributes.iter().any(|a| a.name.name == "loadBalance"));
1489 }
1490
1491 #[test]
1492 fn test_parse_server_group_with_env_vars() {
1493 let schema = parse_schema(
1494 r#"
1495 serverGroup EnvCluster {
1496 server db1 {
1497 url = env("PRIMARY_DB_URL")
1498 role = "primary"
1499 }
1500 }
1501 "#,
1502 )
1503 .unwrap();
1504
1505 let cluster = schema.get_server_group("EnvCluster").unwrap();
1506 let server = cluster.servers.get("db1").unwrap();
1507
1508 if let Some(ServerPropertyValue::EnvVar(var)) = server.get_property("url") {
1510 assert_eq!(var, "PRIMARY_DB_URL");
1511 } else {
1512 panic!("Expected env var for url property");
1513 }
1514 }
1515
1516 #[test]
1517 fn test_parse_server_group_with_boolean_property() {
1518 let schema = parse_schema(
1519 r#"
1520 serverGroup TestCluster {
1521 server replica {
1522 url = "postgres://replica/db"
1523 role = "replica"
1524 readOnly = true
1525 }
1526 }
1527 "#,
1528 )
1529 .unwrap();
1530
1531 let cluster = schema.get_server_group("TestCluster").unwrap();
1532 let server = cluster.servers.get("replica").unwrap();
1533 assert!(server.is_read_only());
1534 }
1535
1536 #[test]
1537 fn test_parse_server_group_with_numeric_properties() {
1538 let schema = parse_schema(
1539 r#"
1540 serverGroup NumericCluster {
1541 server db {
1542 url = "postgres://localhost/db"
1543 weight = 5
1544 priority = 1
1545 maxConnections = 100
1546 }
1547 }
1548 "#,
1549 )
1550 .unwrap();
1551
1552 let cluster = schema.get_server_group("NumericCluster").unwrap();
1553 let server = cluster.servers.get("db").unwrap();
1554
1555 assert_eq!(server.weight(), Some(5));
1556 assert_eq!(server.priority(), Some(1));
1557 assert_eq!(server.max_connections(), Some(100));
1558 }
1559
1560 #[test]
1561 fn test_parse_server_group_with_region() {
1562 let schema = parse_schema(
1563 r#"
1564 serverGroup GeoCluster {
1565 server usEast {
1566 url = "postgres://us-east.db.com/app"
1567 role = "replica"
1568 region = "us-east-1"
1569 }
1570
1571 server usWest {
1572 url = "postgres://us-west.db.com/app"
1573 role = "replica"
1574 region = "us-west-2"
1575 }
1576 }
1577 "#,
1578 )
1579 .unwrap();
1580
1581 let cluster = schema.get_server_group("GeoCluster").unwrap();
1582
1583 let us_east = cluster.servers.get("usEast").unwrap();
1584 assert_eq!(us_east.region(), Some("us-east-1"));
1585
1586 let us_west = cluster.servers.get("usWest").unwrap();
1587 assert_eq!(us_west.region(), Some("us-west-2"));
1588
1589 let us_east_servers = cluster.servers_in_region("us-east-1");
1591 assert_eq!(us_east_servers.len(), 1);
1592 }
1593
1594 #[test]
1595 fn test_parse_multiple_server_groups() {
1596 let schema = parse_schema(
1597 r#"
1598 serverGroup Cluster1 {
1599 server db1 {
1600 url = "postgres://db1/app"
1601 }
1602 }
1603
1604 serverGroup Cluster2 {
1605 server db2 {
1606 url = "postgres://db2/app"
1607 }
1608 }
1609
1610 serverGroup Cluster3 {
1611 server db3 {
1612 url = "postgres://db3/app"
1613 }
1614 }
1615 "#,
1616 )
1617 .unwrap();
1618
1619 assert_eq!(schema.server_groups.len(), 3);
1620 assert!(schema.get_server_group("Cluster1").is_some());
1621 assert!(schema.get_server_group("Cluster2").is_some());
1622 assert!(schema.get_server_group("Cluster3").is_some());
1623 }
1624
1625 #[test]
1626 fn test_parse_schema_with_models_and_server_groups() {
1627 let schema = parse_schema(
1628 r#"
1629 model User {
1630 id Int @id @auto
1631 email String @unique
1632 }
1633
1634 serverGroup Database {
1635 @@strategy(ReadReplica)
1636
1637 server primary {
1638 url = env("DATABASE_URL")
1639 role = "primary"
1640 }
1641 }
1642
1643 model Post {
1644 id Int @id @auto
1645 title String
1646 authorId Int
1647 }
1648 "#,
1649 )
1650 .unwrap();
1651
1652 assert_eq!(schema.models.len(), 2);
1653 assert!(schema.get_model("User").is_some());
1654 assert!(schema.get_model("Post").is_some());
1655
1656 assert_eq!(schema.server_groups.len(), 1);
1657 assert!(schema.get_server_group("Database").is_some());
1658 }
1659
1660 #[test]
1661 fn test_parse_server_group_with_health_check() {
1662 let schema = parse_schema(
1663 r#"
1664 serverGroup HealthyCluster {
1665 server monitored {
1666 url = "postgres://localhost/db"
1667 healthCheck = "/health"
1668 }
1669 }
1670 "#,
1671 )
1672 .unwrap();
1673
1674 let cluster = schema.get_server_group("HealthyCluster").unwrap();
1675 let server = cluster.servers.get("monitored").unwrap();
1676 assert_eq!(server.health_check(), Some("/health"));
1677 }
1678
1679 #[test]
1680 fn test_server_group_failover_order() {
1681 let schema = parse_schema(
1682 r#"
1683 serverGroup FailoverCluster {
1684 server db3 {
1685 url = "postgres://db3/app"
1686 priority = 3
1687 }
1688
1689 server db1 {
1690 url = "postgres://db1/app"
1691 priority = 1
1692 }
1693
1694 server db2 {
1695 url = "postgres://db2/app"
1696 priority = 2
1697 }
1698 }
1699 "#,
1700 )
1701 .unwrap();
1702
1703 let cluster = schema.get_server_group("FailoverCluster").unwrap();
1704 let ordered = cluster.failover_order();
1705
1706 assert_eq!(ordered[0].name.name.as_str(), "db1");
1707 assert_eq!(ordered[1].name.name.as_str(), "db2");
1708 assert_eq!(ordered[2].name.name.as_str(), "db3");
1709 }
1710
1711 #[test]
1712 fn test_server_group_names() {
1713 let schema = parse_schema(
1714 r#"
1715 serverGroup Alpha {
1716 server s1 { url = "pg://a" }
1717 }
1718 serverGroup Beta {
1719 server s2 { url = "pg://b" }
1720 }
1721 "#,
1722 )
1723 .unwrap();
1724
1725 let names: Vec<_> = schema.server_group_names().collect();
1726 assert_eq!(names.len(), 2);
1727 assert!(names.contains(&"Alpha"));
1728 assert!(names.contains(&"Beta"));
1729 }
1730}