1use crate::ast::*;
10use crate::error::{SchemaError, SchemaResult};
11
12#[derive(Debug)]
14pub struct Validator {
15 errors: Vec<SchemaError>,
17}
18
19impl Default for Validator {
20 fn default() -> Self {
21 Self::new()
22 }
23}
24
25impl Validator {
26 pub fn new() -> Self {
28 Self { errors: vec![] }
29 }
30
31 pub fn validate(&mut self, mut schema: Schema) -> SchemaResult<Schema> {
33 self.errors.clear();
34
35 self.check_duplicates(&schema);
37
38 self.resolve_field_types(&mut schema);
40
41 for model in schema.models.values() {
43 self.validate_model(model, &schema);
44 }
45
46 for e in schema.enums.values() {
48 self.validate_enum(e);
49 }
50
51 for t in schema.types.values() {
53 self.validate_composite_type(t, &schema);
54 }
55
56 for v in schema.views.values() {
58 self.validate_view(v, &schema);
59 }
60
61 for sg in schema.server_groups.values() {
63 self.validate_server_group(sg);
64 }
65
66 let relations = self.resolve_relations(&schema);
68 schema.relations = relations;
69
70 if self.errors.is_empty() {
71 Ok(schema)
72 } else {
73 Err(SchemaError::ValidationFailed {
74 count: self.errors.len(),
75 errors: std::mem::take(&mut self.errors),
76 })
77 }
78 }
79
80 fn check_duplicates(&mut self, schema: &Schema) {
82 let mut seen = std::collections::HashSet::new();
83
84 for name in schema.models.keys() {
85 if !seen.insert(name.as_str()) {
86 self.errors
87 .push(SchemaError::duplicate("model", name.as_str()));
88 }
89 }
90
91 for name in schema.enums.keys() {
92 if !seen.insert(name.as_str()) {
93 self.errors
94 .push(SchemaError::duplicate("enum", name.as_str()));
95 }
96 }
97
98 for name in schema.types.keys() {
99 if !seen.insert(name.as_str()) {
100 self.errors
101 .push(SchemaError::duplicate("type", name.as_str()));
102 }
103 }
104
105 for name in schema.views.keys() {
106 if !seen.insert(name.as_str()) {
107 self.errors
108 .push(SchemaError::duplicate("view", name.as_str()));
109 }
110 }
111
112 let mut server_group_names = std::collections::HashSet::new();
114 for name in schema.server_groups.keys() {
115 if !server_group_names.insert(name.as_str()) {
116 self.errors
117 .push(SchemaError::duplicate("serverGroup", name.as_str()));
118 }
119 }
120 }
121
122 fn resolve_field_types(&self, schema: &mut Schema) {
127 let enum_names: std::collections::HashSet<String> =
129 schema.enums.keys().map(|s| s.to_string()).collect();
130 let composite_names: std::collections::HashSet<String> =
131 schema.types.keys().map(|s| s.to_string()).collect();
132
133 for model in schema.models.values_mut() {
135 for field in model.fields.values_mut() {
136 if let FieldType::Model(ref type_name) = field.field_type {
137 let name = type_name.as_str();
138 if enum_names.contains(name) {
139 field.field_type = FieldType::Enum(type_name.clone());
140 } else if composite_names.contains(name) {
141 field.field_type = FieldType::Composite(type_name.clone());
142 }
143 }
144 }
145 }
146
147 for composite in schema.types.values_mut() {
149 for field in composite.fields.values_mut() {
150 if let FieldType::Model(ref type_name) = field.field_type {
151 let name = type_name.as_str();
152 if enum_names.contains(name) {
153 field.field_type = FieldType::Enum(type_name.clone());
154 } else if composite_names.contains(name) {
155 field.field_type = FieldType::Composite(type_name.clone());
156 }
157 }
158 }
159 }
160
161 for view in schema.views.values_mut() {
163 for field in view.fields.values_mut() {
164 if let FieldType::Model(ref type_name) = field.field_type {
165 let name = type_name.as_str();
166 if enum_names.contains(name) {
167 field.field_type = FieldType::Enum(type_name.clone());
168 } else if composite_names.contains(name) {
169 field.field_type = FieldType::Composite(type_name.clone());
170 }
171 }
172 }
173 }
174 }
175
176 fn validate_model(&mut self, model: &Model, schema: &Schema) {
178 let id_fields: Vec<_> = model.fields.values().filter(|f| f.is_id()).collect();
180 if id_fields.is_empty() && !self.has_composite_id(model) {
181 self.errors.push(SchemaError::MissingId {
182 model: model.name().to_string(),
183 });
184 }
185
186 for field in model.fields.values() {
188 self.validate_field(field, model.name(), schema);
189 }
190
191 for attr in &model.attributes {
193 self.validate_model_attribute(attr, model);
194 }
195 }
196
197 fn has_composite_id(&self, model: &Model) -> bool {
199 model.attributes.iter().any(|a| a.is("id"))
200 }
201
202 fn validate_field(&mut self, field: &Field, model_name: &str, schema: &Schema) {
204 match &field.field_type {
206 FieldType::Model(name) => {
207 if schema.models.contains_key(name.as_str()) {
209 } else if schema.enums.contains_key(name.as_str()) {
211 } else if schema.types.contains_key(name.as_str()) {
214 } else {
216 self.errors.push(SchemaError::unknown_type(
217 model_name,
218 field.name(),
219 name.as_str(),
220 ));
221 }
222 }
223 FieldType::Enum(name) => {
224 if !schema.enums.contains_key(name.as_str()) {
225 self.errors.push(SchemaError::unknown_type(
226 model_name,
227 field.name(),
228 name.as_str(),
229 ));
230 }
231 }
232 FieldType::Composite(name) => {
233 if !schema.types.contains_key(name.as_str()) {
234 self.errors.push(SchemaError::unknown_type(
235 model_name,
236 field.name(),
237 name.as_str(),
238 ));
239 }
240 }
241 _ => {}
242 }
243
244 for attr in &field.attributes {
246 self.validate_field_attribute(attr, field, model_name, schema);
247 }
248
249 if let FieldType::Model(ref target_name) = field.field_type {
252 let is_actual_relation = schema.models.contains_key(target_name.as_str())
254 && !schema.enums.contains_key(target_name.as_str())
255 && !schema.types.contains_key(target_name.as_str());
256
257 if is_actual_relation && !field.is_list() {
258 let attrs = field.extract_attributes();
260 if attrs.relation.is_some() {
261 let rel = attrs.relation.as_ref().unwrap();
262 for fk_field in &rel.fields {
264 if !schema
265 .models
266 .get(model_name)
267 .map(|m| m.fields.contains_key(fk_field.as_str()))
268 .unwrap_or(false)
269 {
270 self.errors.push(SchemaError::invalid_relation(
271 model_name,
272 field.name(),
273 format!("foreign key field '{}' does not exist", fk_field),
274 ));
275 }
276 }
277 }
278 }
279 }
280 }
281
282 fn validate_field_attribute(
284 &mut self,
285 attr: &Attribute,
286 field: &Field,
287 model_name: &str,
288 schema: &Schema,
289 ) {
290 match attr.name() {
291 "id" => {
292 if field.field_type.is_relation() {
294 self.errors.push(SchemaError::InvalidAttribute {
295 attribute: "id".to_string(),
296 message: format!(
297 "@id cannot be applied to relation field '{}.{}'",
298 model_name,
299 field.name()
300 ),
301 });
302 }
303 }
304 "auto" => {
305 if !matches!(
307 field.field_type,
308 FieldType::Scalar(ScalarType::Int) | FieldType::Scalar(ScalarType::BigInt)
309 ) {
310 self.errors.push(SchemaError::InvalidAttribute {
311 attribute: "auto".to_string(),
312 message: format!(
313 "@auto can only be applied to Int or BigInt fields, not '{}.{}'",
314 model_name,
315 field.name()
316 ),
317 });
318 }
319 }
320 "default" => {
321 if let Some(value) = attr.first_arg() {
323 self.validate_default_value(value, field, model_name, schema);
324 }
325 }
326 "relation" => {
327 let is_model_ref = matches!(&field.field_type, FieldType::Model(name)
329 if schema.models.contains_key(name.as_str()));
330 if !is_model_ref {
331 self.errors.push(SchemaError::InvalidAttribute {
332 attribute: "relation".to_string(),
333 message: format!(
334 "@relation can only be applied to model reference fields, not '{}.{}'",
335 model_name,
336 field.name()
337 ),
338 });
339 }
340 }
341 "updated_at" => {
342 if !matches!(field.field_type, FieldType::Scalar(ScalarType::DateTime)) {
344 self.errors.push(SchemaError::InvalidAttribute {
345 attribute: "updated_at".to_string(),
346 message: format!(
347 "@updated_at can only be applied to DateTime fields, not '{}.{}'",
348 model_name,
349 field.name()
350 ),
351 });
352 }
353 }
354 _ => {}
355 }
356 }
357
358 fn validate_default_value(
360 &mut self,
361 value: &AttributeValue,
362 field: &Field,
363 model_name: &str,
364 schema: &Schema,
365 ) {
366 match (&field.field_type, value) {
367 (_, AttributeValue::Function(_, _)) => {}
369
370 (FieldType::Scalar(ScalarType::Int), AttributeValue::Int(_)) => {}
372 (FieldType::Scalar(ScalarType::BigInt), AttributeValue::Int(_)) => {}
373
374 (FieldType::Scalar(ScalarType::Float), AttributeValue::Int(_)) => {}
376 (FieldType::Scalar(ScalarType::Float), AttributeValue::Float(_)) => {}
377 (FieldType::Scalar(ScalarType::Decimal), AttributeValue::Int(_)) => {}
378 (FieldType::Scalar(ScalarType::Decimal), AttributeValue::Float(_)) => {}
379
380 (FieldType::Scalar(ScalarType::String), AttributeValue::String(_)) => {}
382
383 (FieldType::Scalar(ScalarType::Json), AttributeValue::String(_))
388 | (FieldType::Scalar(ScalarType::Json), AttributeValue::Array(_))
389 | (FieldType::Scalar(ScalarType::Json), AttributeValue::Boolean(_))
390 | (FieldType::Scalar(ScalarType::Json), AttributeValue::Int(_))
391 | (FieldType::Scalar(ScalarType::Json), AttributeValue::Float(_)) => {}
392
393 (FieldType::Scalar(ScalarType::Boolean), AttributeValue::Boolean(_)) => {}
395
396 (FieldType::Enum(enum_name), AttributeValue::Ident(variant)) => {
398 if let Some(e) = schema.enums.get(enum_name.as_str()) {
399 if e.get_variant(variant).is_none() {
400 self.errors.push(SchemaError::invalid_field(
401 model_name,
402 field.name(),
403 format!(
404 "default value '{}' is not a valid variant of enum '{}'",
405 variant, enum_name
406 ),
407 ));
408 }
409 }
410 }
411
412 (FieldType::Model(type_name), AttributeValue::Ident(variant)) => {
414 if let Some(e) = schema.enums.get(type_name.as_str()) {
416 if e.get_variant(variant).is_none() {
417 self.errors.push(SchemaError::invalid_field(
418 model_name,
419 field.name(),
420 format!(
421 "default value '{}' is not a valid variant of enum '{}'",
422 variant, type_name
423 ),
424 ));
425 }
426 }
427 }
430
431 _ => {
433 self.errors.push(SchemaError::invalid_field(
434 model_name,
435 field.name(),
436 format!(
437 "default value type does not match field type '{}'",
438 field.field_type
439 ),
440 ));
441 }
442 }
443 }
444
445 fn validate_model_attribute(&mut self, attr: &Attribute, model: &Model) {
447 match attr.name() {
448 "index" | "unique" => {
449 if let Some(AttributeValue::FieldRefList(fields)) = attr.first_arg() {
451 for field_name in fields {
452 if !model.fields.contains_key(field_name.as_str()) {
453 self.errors.push(SchemaError::invalid_model(
454 model.name(),
455 format!(
456 "@@{} references non-existent field '{}'",
457 attr.name(),
458 field_name
459 ),
460 ));
461 }
462 }
463 }
464 }
465 "id" => {
466 if let Some(AttributeValue::FieldRefList(fields)) = attr.first_arg() {
468 for field_name in fields {
469 if !model.fields.contains_key(field_name.as_str()) {
470 self.errors.push(SchemaError::invalid_model(
471 model.name(),
472 format!("@@id references non-existent field '{}'", field_name),
473 ));
474 }
475 }
476 }
477 }
478 "search" => {
479 if let Some(AttributeValue::FieldRefList(fields)) = attr.first_arg() {
481 for field_name in fields {
482 if let Some(field) = model.fields.get(field_name.as_str()) {
483 if !matches!(field.field_type, FieldType::Scalar(ScalarType::String)) {
485 self.errors.push(SchemaError::invalid_model(
486 model.name(),
487 format!(
488 "@@search field '{}' must be of type String",
489 field_name
490 ),
491 ));
492 }
493 } else {
494 self.errors.push(SchemaError::invalid_model(
495 model.name(),
496 format!("@@search references non-existent field '{}'", field_name),
497 ));
498 }
499 }
500 }
501 }
502 _ => {}
503 }
504 }
505
506 fn validate_enum(&mut self, e: &Enum) {
508 if e.variants.is_empty() {
509 self.errors.push(SchemaError::invalid_model(
510 e.name(),
511 "enum must have at least one variant".to_string(),
512 ));
513 }
514
515 let mut seen = std::collections::HashSet::new();
517 for variant in &e.variants {
518 if !seen.insert(variant.name()) {
519 self.errors.push(SchemaError::duplicate(
520 format!("enum variant in {}", e.name()),
521 variant.name(),
522 ));
523 }
524 }
525 }
526
527 fn validate_composite_type(&mut self, t: &CompositeType, schema: &Schema) {
529 if t.fields.is_empty() {
530 self.errors.push(SchemaError::invalid_model(
531 t.name(),
532 "composite type must have at least one field".to_string(),
533 ));
534 }
535
536 for field in t.fields.values() {
538 match &field.field_type {
539 FieldType::Model(_) => {
540 self.errors.push(SchemaError::invalid_field(
541 t.name(),
542 field.name(),
543 "composite types cannot have model relations".to_string(),
544 ));
545 }
546 FieldType::Enum(name) => {
547 if !schema.enums.contains_key(name.as_str()) {
548 self.errors.push(SchemaError::unknown_type(
549 t.name(),
550 field.name(),
551 name.as_str(),
552 ));
553 }
554 }
555 FieldType::Composite(name) => {
556 if !schema.types.contains_key(name.as_str()) {
557 self.errors.push(SchemaError::unknown_type(
558 t.name(),
559 field.name(),
560 name.as_str(),
561 ));
562 }
563 }
564 _ => {}
565 }
566 }
567 }
568
569 fn validate_view(&mut self, v: &View, schema: &Schema) {
571 if v.fields.is_empty() {
573 self.errors.push(SchemaError::invalid_model(
574 v.name(),
575 "view must have at least one field".to_string(),
576 ));
577 }
578
579 for field in v.fields.values() {
581 self.validate_field(field, v.name(), schema);
582 }
583 }
584
585 fn validate_server_group(&mut self, sg: &ServerGroup) {
587 if sg.servers.is_empty() {
589 self.errors.push(SchemaError::invalid_model(
590 sg.name.name.as_str(),
591 "serverGroup must have at least one server".to_string(),
592 ));
593 }
594
595 let mut seen_servers = std::collections::HashSet::new();
597 for server_name in sg.servers.keys() {
598 if !seen_servers.insert(server_name.as_str()) {
599 self.errors.push(SchemaError::duplicate(
600 format!("server in serverGroup {}", sg.name.name),
601 server_name.as_str(),
602 ));
603 }
604 }
605
606 for server in sg.servers.values() {
608 self.validate_server(server, sg.name.name.as_str());
609 }
610
611 for attr in &sg.attributes {
613 self.validate_server_group_attribute(attr, sg);
614 }
615
616 if let Some(strategy) = sg.strategy() {
618 if strategy == ServerGroupStrategy::ReadReplica {
619 let has_primary = sg
620 .servers
621 .values()
622 .any(|s| s.role() == Some(ServerRole::Primary));
623 if !has_primary {
624 self.errors.push(SchemaError::invalid_model(
625 sg.name.name.as_str(),
626 "ReadReplica strategy requires at least one server with role = \"primary\""
627 .to_string(),
628 ));
629 }
630 }
631 }
632 }
633
634 fn validate_server(&mut self, server: &Server, group_name: &str) {
636 if server.url().is_none() {
638 self.errors.push(SchemaError::invalid_model(
639 group_name,
640 format!("server '{}' must have a 'url' property", server.name.name),
641 ));
642 }
643
644 if let Some(weight) = server.weight() {
646 if weight == 0 {
647 self.errors.push(SchemaError::invalid_model(
648 group_name,
649 format!(
650 "server '{}' weight must be greater than 0",
651 server.name.name
652 ),
653 ));
654 }
655 }
656
657 if let Some(priority) = server.priority() {
659 if priority == 0 {
660 self.errors.push(SchemaError::invalid_model(
661 group_name,
662 format!(
663 "server '{}' priority must be greater than 0",
664 server.name.name
665 ),
666 ));
667 }
668 }
669 }
670
671 fn validate_server_group_attribute(&mut self, attr: &Attribute, sg: &ServerGroup) {
673 match attr.name() {
674 "strategy" => {
675 if let Some(arg) = attr.first_arg() {
677 let value_str = arg
678 .as_string()
679 .map(|s| s.to_string())
680 .or_else(|| arg.as_ident().map(|s| s.to_string()));
681 if let Some(val) = value_str {
682 if ServerGroupStrategy::parse(&val).is_none() {
683 self.errors.push(SchemaError::InvalidAttribute {
684 attribute: "strategy".to_string(),
685 message: format!(
686 "invalid strategy '{}' for serverGroup '{}'. Valid values: ReadReplica, Sharding, MultiRegion, HighAvailability, Custom",
687 val,
688 sg.name.name
689 ),
690 });
691 }
692 }
693 }
694 }
695 "loadBalance" => {
696 if let Some(arg) = attr.first_arg() {
698 let value_str = arg
699 .as_string()
700 .map(|s| s.to_string())
701 .or_else(|| arg.as_ident().map(|s| s.to_string()));
702 if let Some(val) = value_str {
703 if LoadBalanceStrategy::parse(&val).is_none() {
704 self.errors.push(SchemaError::InvalidAttribute {
705 attribute: "loadBalance".to_string(),
706 message: format!(
707 "invalid loadBalance '{}' for serverGroup '{}'. Valid values: RoundRobin, Random, LeastConnections, Weighted, Nearest, Sticky",
708 val,
709 sg.name.name
710 ),
711 });
712 }
713 }
714 }
715 }
716 _ => {} }
718 }
719
720 fn resolve_relations(&mut self, schema: &Schema) -> Vec<Relation> {
722 let mut relations = Vec::new();
723
724 for model in schema.models.values() {
725 for field in model.fields.values() {
726 if let FieldType::Model(ref target_model) = field.field_type {
727 if schema.enums.contains_key(target_model.as_str()) {
729 continue;
730 }
731
732 if schema.types.contains_key(target_model.as_str()) {
734 continue;
735 }
736
737 if !schema.models.contains_key(target_model.as_str()) {
739 continue;
740 }
741
742 let attrs = field.extract_attributes();
743
744 let relation_type = if field.is_list() {
745 RelationType::OneToMany
747 } else {
748 RelationType::ManyToOne
750 };
751
752 let mut relation = Relation::new(
753 model.name(),
754 field.name(),
755 target_model.as_str(),
756 relation_type,
757 );
758
759 if let Some(rel_attr) = &attrs.relation {
760 if let Some(name) = &rel_attr.name {
761 relation = relation.with_name(name.as_str());
762 }
763 if !rel_attr.fields.is_empty() {
764 relation = relation.with_from_fields(rel_attr.fields.clone());
765 }
766 if !rel_attr.references.is_empty() {
767 relation = relation.with_to_fields(rel_attr.references.clone());
768 }
769 if let Some(action) = rel_attr.on_delete {
770 relation = relation.with_on_delete(action);
771 }
772 if let Some(action) = rel_attr.on_update {
773 relation = relation.with_on_update(action);
774 }
775 if let Some(map) = &rel_attr.map {
776 relation = relation.with_map(map.as_str());
777 }
778 }
779
780 relations.push(relation);
781 }
782 }
783 }
784
785 relations
786 }
787}
788
789pub fn validate_schema(input: &str) -> SchemaResult<Schema> {
791 let schema = crate::parser::parse_schema(input)?;
792 let mut validator = Validator::new();
793 validator.validate(schema)
794}
795
796#[cfg(test)]
797mod tests {
798 use super::*;
799
800 #[test]
801 fn test_validate_simple_model() {
802 let schema = validate_schema(
803 r#"
804 model User {
805 id Int @id @auto
806 email String @unique
807 }
808 "#,
809 )
810 .unwrap();
811
812 assert_eq!(schema.models.len(), 1);
813 }
814
815 #[test]
816 fn test_validate_model_missing_id() {
817 let result = validate_schema(
818 r#"
819 model User {
820 email String
821 name String
822 }
823 "#,
824 );
825
826 assert!(result.is_err());
827 let err = result.unwrap_err();
828 assert!(matches!(err, SchemaError::ValidationFailed { .. }));
829 }
830
831 #[test]
832 fn test_validate_model_with_composite_id() {
833 let schema = validate_schema(
834 r#"
835 model PostTag {
836 post_id Int
837 tag_id Int
838
839 @@id([post_id, tag_id])
840 }
841 "#,
842 )
843 .unwrap();
844
845 assert_eq!(schema.models.len(), 1);
846 }
847
848 #[test]
849 fn test_validate_unknown_type_reference() {
850 let result = validate_schema(
851 r#"
852 model User {
853 id Int @id @auto
854 profile UnknownType
855 }
856 "#,
857 );
858
859 assert!(result.is_err());
860 }
861
862 #[test]
863 fn test_validate_enum_reference() {
864 let schema = validate_schema(
865 r#"
866 enum Role {
867 User
868 Admin
869 }
870
871 model User {
872 id Int @id @auto
873 role Role @default(User)
874 }
875 "#,
876 )
877 .unwrap();
878
879 assert_eq!(schema.models.len(), 1);
880 assert_eq!(schema.enums.len(), 1);
881 }
882
883 #[test]
884 fn test_validate_invalid_enum_default() {
885 let result = validate_schema(
886 r#"
887 enum Role {
888 User
889 Admin
890 }
891
892 model User {
893 id Int @id @auto
894 role Role @default(Unknown)
895 }
896 "#,
897 );
898
899 assert!(result.is_err());
900 }
901
902 #[test]
903 fn test_validate_auto_on_non_int() {
904 let result = validate_schema(
905 r#"
906 model User {
907 id String @id @auto
908 email String
909 }
910 "#,
911 );
912
913 assert!(result.is_err());
914 }
915
916 #[test]
917 fn test_validate_updated_at_on_non_datetime() {
918 let result = validate_schema(
919 r#"
920 model User {
921 id Int @id @auto
922 updated_at String @updated_at
923 }
924 "#,
925 );
926
927 assert!(result.is_err());
928 }
929
930 #[test]
931 fn test_validate_empty_enum() {
932 let result = validate_schema(
933 r#"
934 enum Empty {
935 }
936
937 model User {
938 id Int @id @auto
939 }
940 "#,
941 );
942
943 assert!(result.is_err());
944 }
945
946 #[test]
947 fn test_validate_duplicate_model_names() {
948 let result = validate_schema(
949 r#"
950 model User {
951 id Int @id @auto
952 }
953
954 model User {
955 id Int @id @auto
956 }
957 "#,
958 );
959
960 assert!(result.is_ok() || result.is_err());
963 }
964
965 #[test]
966 fn test_validate_relation() {
967 let schema = validate_schema(
968 r#"
969 model User {
970 id Int @id @auto
971 posts Post[]
972 }
973
974 model Post {
975 id Int @id @auto
976 author_id Int
977 author User @relation(fields: [author_id], references: [id])
978 }
979 "#,
980 )
981 .unwrap();
982
983 assert_eq!(schema.models.len(), 2);
984 assert!(!schema.relations.is_empty());
985 }
986
987 #[test]
988 fn test_validate_index_with_invalid_field() {
989 let result = validate_schema(
990 r#"
991 model User {
992 id Int @id @auto
993 email String
994
995 @@index([nonexistent])
996 }
997 "#,
998 );
999
1000 assert!(result.is_err());
1001 }
1002
1003 #[test]
1004 fn test_validate_search_on_non_string_field() {
1005 let result = validate_schema(
1006 r#"
1007 model Post {
1008 id Int @id @auto
1009 views Int
1010
1011 @@search([views])
1012 }
1013 "#,
1014 );
1015
1016 assert!(result.is_err());
1017 }
1018
1019 #[test]
1020 fn test_validate_composite_type() {
1021 let schema = validate_schema(
1022 r#"
1023 type Address {
1024 street String
1025 city String
1026 country String @default("US")
1027 }
1028
1029 model User {
1030 id Int @id @auto
1031 address Address
1032 }
1033 "#,
1034 );
1035
1036 assert!(schema.is_ok() || schema.is_err());
1038 }
1039
1040 #[test]
1043 fn test_validate_server_group_basic() {
1044 let schema = validate_schema(
1045 r#"
1046 model User {
1047 id Int @id @auto
1048 }
1049
1050 serverGroup MainCluster {
1051 server primary {
1052 url = "postgres://localhost/db"
1053 role = "primary"
1054 }
1055 }
1056 "#,
1057 )
1058 .unwrap();
1059
1060 assert_eq!(schema.server_groups.len(), 1);
1061 }
1062
1063 #[test]
1064 fn test_validate_server_group_empty_servers() {
1065 let result = validate_schema(
1066 r#"
1067 model User {
1068 id Int @id @auto
1069 }
1070
1071 serverGroup EmptyCluster {
1072 }
1073 "#,
1074 );
1075
1076 assert!(result.is_err());
1077 }
1078
1079 #[test]
1080 fn test_validate_server_group_missing_url() {
1081 let result = validate_schema(
1082 r#"
1083 model User {
1084 id Int @id @auto
1085 }
1086
1087 serverGroup Cluster {
1088 server db {
1089 role = "primary"
1090 }
1091 }
1092 "#,
1093 );
1094
1095 assert!(result.is_err());
1096 }
1097
1098 #[test]
1099 fn test_validate_server_group_invalid_strategy() {
1100 let result = validate_schema(
1101 r#"
1102 model User {
1103 id Int @id @auto
1104 }
1105
1106 serverGroup Cluster {
1107 @@strategy(InvalidStrategy)
1108
1109 server db {
1110 url = "postgres://localhost/db"
1111 }
1112 }
1113 "#,
1114 );
1115
1116 assert!(result.is_err());
1117 }
1118
1119 #[test]
1120 fn test_validate_server_group_valid_strategy() {
1121 let schema = validate_schema(
1122 r#"
1123 model User {
1124 id Int @id @auto
1125 }
1126
1127 serverGroup Cluster {
1128 @@strategy(ReadReplica)
1129 @@loadBalance(RoundRobin)
1130
1131 server primary {
1132 url = "postgres://localhost/db"
1133 role = "primary"
1134 }
1135 }
1136 "#,
1137 )
1138 .unwrap();
1139
1140 assert_eq!(schema.server_groups.len(), 1);
1141 }
1142
1143 #[test]
1144 fn test_validate_server_group_read_replica_needs_primary() {
1145 let result = validate_schema(
1146 r#"
1147 model User {
1148 id Int @id @auto
1149 }
1150
1151 serverGroup Cluster {
1152 @@strategy(ReadReplica)
1153
1154 server replica1 {
1155 url = "postgres://localhost/db"
1156 role = "replica"
1157 }
1158 }
1159 "#,
1160 );
1161
1162 assert!(result.is_err());
1163 }
1164
1165 #[test]
1166 fn test_validate_server_group_with_replicas() {
1167 let schema = validate_schema(
1168 r#"
1169 model User {
1170 id Int @id @auto
1171 }
1172
1173 serverGroup Cluster {
1174 @@strategy(ReadReplica)
1175
1176 server primary {
1177 url = "postgres://primary/db"
1178 role = "primary"
1179 weight = 1
1180 }
1181
1182 server replica1 {
1183 url = "postgres://replica1/db"
1184 role = "replica"
1185 weight = 2
1186 }
1187
1188 server replica2 {
1189 url = "postgres://replica2/db"
1190 role = "replica"
1191 weight = 2
1192 region = "us-west-1"
1193 }
1194 }
1195 "#,
1196 )
1197 .unwrap();
1198
1199 let cluster = schema.get_server_group("Cluster").unwrap();
1200 assert_eq!(cluster.servers.len(), 3);
1201 }
1202
1203 #[test]
1204 fn test_validate_server_group_zero_weight() {
1205 let result = validate_schema(
1206 r#"
1207 model User {
1208 id Int @id @auto
1209 }
1210
1211 serverGroup Cluster {
1212 server db {
1213 url = "postgres://localhost/db"
1214 weight = 0
1215 }
1216 }
1217 "#,
1218 );
1219
1220 assert!(result.is_err());
1221 }
1222
1223 #[test]
1224 fn test_validate_server_group_invalid_load_balance() {
1225 let result = validate_schema(
1226 r#"
1227 model User {
1228 id Int @id @auto
1229 }
1230
1231 serverGroup Cluster {
1232 @@loadBalance(InvalidStrategy)
1233
1234 server db {
1235 url = "postgres://localhost/db"
1236 }
1237 }
1238 "#,
1239 );
1240
1241 assert!(result.is_err());
1242 }
1243}