1use crate::ast::*;
10use crate::error::{SchemaError, SchemaResult};
11
/// Semantic validator for a parsed [`Schema`].
///
/// Problems discovered while validating are accumulated in `errors` and
/// reported together at the end of [`Validator::validate`].
#[derive(Debug)]
pub struct Validator {
    /// Errors collected during the current validation run.
    errors: Vec<SchemaError>,
}
18
19impl Default for Validator {
20 fn default() -> Self {
21 Self::new()
22 }
23}
24
25impl Validator {
26 pub fn new() -> Self {
28 Self { errors: vec![] }
29 }
30
31 pub fn validate(&mut self, mut schema: Schema) -> SchemaResult<Schema> {
33 self.errors.clear();
34
35 self.check_duplicates(&schema);
37
38 self.resolve_field_types(&mut schema);
40
41 for model in schema.models.values() {
43 self.validate_model(model, &schema);
44 }
45
46 for e in schema.enums.values() {
48 self.validate_enum(e);
49 }
50
51 for t in schema.types.values() {
53 self.validate_composite_type(t, &schema);
54 }
55
56 for v in schema.views.values() {
58 self.validate_view(v, &schema);
59 }
60
61 for sg in schema.server_groups.values() {
63 self.validate_server_group(sg);
64 }
65
66 let relations = self.resolve_relations(&schema);
68 schema.relations = relations;
69
70 if self.errors.is_empty() {
71 Ok(schema)
72 } else {
73 Err(SchemaError::ValidationFailed {
74 count: self.errors.len(),
75 errors: std::mem::take(&mut self.errors),
76 })
77 }
78 }
79
80 fn check_duplicates(&mut self, schema: &Schema) {
82 let mut seen = std::collections::HashSet::new();
83
84 for name in schema.models.keys() {
85 if !seen.insert(name.as_str()) {
86 self.errors
87 .push(SchemaError::duplicate("model", name.as_str()));
88 }
89 }
90
91 for name in schema.enums.keys() {
92 if !seen.insert(name.as_str()) {
93 self.errors
94 .push(SchemaError::duplicate("enum", name.as_str()));
95 }
96 }
97
98 for name in schema.types.keys() {
99 if !seen.insert(name.as_str()) {
100 self.errors
101 .push(SchemaError::duplicate("type", name.as_str()));
102 }
103 }
104
105 for name in schema.views.keys() {
106 if !seen.insert(name.as_str()) {
107 self.errors
108 .push(SchemaError::duplicate("view", name.as_str()));
109 }
110 }
111
112 let mut server_group_names = std::collections::HashSet::new();
114 for name in schema.server_groups.keys() {
115 if !server_group_names.insert(name.as_str()) {
116 self.errors
117 .push(SchemaError::duplicate("serverGroup", name.as_str()));
118 }
119 }
120 }
121
122 fn resolve_field_types(&self, schema: &mut Schema) {
127 let enum_names: std::collections::HashSet<String> =
129 schema.enums.keys().map(|s| s.to_string()).collect();
130 let composite_names: std::collections::HashSet<String> =
131 schema.types.keys().map(|s| s.to_string()).collect();
132
133 for model in schema.models.values_mut() {
135 for field in model.fields.values_mut() {
136 if let FieldType::Model(ref type_name) = field.field_type {
137 let name = type_name.as_str();
138 if enum_names.contains(name) {
139 field.field_type = FieldType::Enum(type_name.clone());
140 } else if composite_names.contains(name) {
141 field.field_type = FieldType::Composite(type_name.clone());
142 }
143 }
144 }
145 }
146
147 for composite in schema.types.values_mut() {
149 for field in composite.fields.values_mut() {
150 if let FieldType::Model(ref type_name) = field.field_type {
151 let name = type_name.as_str();
152 if enum_names.contains(name) {
153 field.field_type = FieldType::Enum(type_name.clone());
154 } else if composite_names.contains(name) {
155 field.field_type = FieldType::Composite(type_name.clone());
156 }
157 }
158 }
159 }
160
161 for view in schema.views.values_mut() {
163 for field in view.fields.values_mut() {
164 if let FieldType::Model(ref type_name) = field.field_type {
165 let name = type_name.as_str();
166 if enum_names.contains(name) {
167 field.field_type = FieldType::Enum(type_name.clone());
168 } else if composite_names.contains(name) {
169 field.field_type = FieldType::Composite(type_name.clone());
170 }
171 }
172 }
173 }
174 }
175
176 fn validate_model(&mut self, model: &Model, schema: &Schema) {
178 let id_fields: Vec<_> = model.fields.values().filter(|f| f.is_id()).collect();
180 if id_fields.is_empty() && !self.has_composite_id(model) {
181 self.errors.push(SchemaError::MissingId {
182 model: model.name().to_string(),
183 });
184 }
185
186 for field in model.fields.values() {
188 self.validate_field(field, model.name(), schema);
189 }
190
191 for attr in &model.attributes {
193 self.validate_model_attribute(attr, model);
194 }
195 }
196
197 fn has_composite_id(&self, model: &Model) -> bool {
199 model.attributes.iter().any(|a| a.is("id"))
200 }
201
202 fn validate_field(&mut self, field: &Field, model_name: &str, schema: &Schema) {
204 match &field.field_type {
206 FieldType::Model(name) => {
207 if schema.models.contains_key(name.as_str()) {
209 } else if schema.enums.contains_key(name.as_str()) {
211 } else if schema.types.contains_key(name.as_str()) {
214 } else {
216 self.errors.push(SchemaError::unknown_type(
217 model_name,
218 field.name(),
219 name.as_str(),
220 ));
221 }
222 }
223 FieldType::Enum(name) => {
224 if !schema.enums.contains_key(name.as_str()) {
225 self.errors.push(SchemaError::unknown_type(
226 model_name,
227 field.name(),
228 name.as_str(),
229 ));
230 }
231 }
232 FieldType::Composite(name) => {
233 if !schema.types.contains_key(name.as_str()) {
234 self.errors.push(SchemaError::unknown_type(
235 model_name,
236 field.name(),
237 name.as_str(),
238 ));
239 }
240 }
241 _ => {}
242 }
243
244 for attr in &field.attributes {
246 self.validate_field_attribute(attr, field, model_name, schema);
247 }
248
249 if let FieldType::Model(ref target_name) = field.field_type {
252 let is_actual_relation = schema.models.contains_key(target_name.as_str())
254 && !schema.enums.contains_key(target_name.as_str())
255 && !schema.types.contains_key(target_name.as_str());
256
257 if is_actual_relation && !field.is_list() {
258 let attrs = field.extract_attributes();
260 if attrs.relation.is_some() {
261 let rel = attrs.relation.as_ref().unwrap();
262 for fk_field in &rel.fields {
264 if !schema
265 .models
266 .get(model_name)
267 .map(|m| m.fields.contains_key(fk_field.as_str()))
268 .unwrap_or(false)
269 {
270 self.errors.push(SchemaError::invalid_relation(
271 model_name,
272 field.name(),
273 format!("foreign key field '{}' does not exist", fk_field),
274 ));
275 }
276 }
277 }
278 }
279 }
280 }
281
282 fn validate_field_attribute(
284 &mut self,
285 attr: &Attribute,
286 field: &Field,
287 model_name: &str,
288 schema: &Schema,
289 ) {
290 match attr.name() {
291 "id" => {
292 if field.field_type.is_relation() {
294 self.errors.push(SchemaError::InvalidAttribute {
295 attribute: "id".to_string(),
296 message: format!(
297 "@id cannot be applied to relation field '{}.{}'",
298 model_name,
299 field.name()
300 ),
301 });
302 }
303 }
304 "auto" => {
305 if !matches!(
307 field.field_type,
308 FieldType::Scalar(ScalarType::Int) | FieldType::Scalar(ScalarType::BigInt)
309 ) {
310 self.errors.push(SchemaError::InvalidAttribute {
311 attribute: "auto".to_string(),
312 message: format!(
313 "@auto can only be applied to Int or BigInt fields, not '{}.{}'",
314 model_name,
315 field.name()
316 ),
317 });
318 }
319 }
320 "default" => {
321 if let Some(value) = attr.first_arg() {
323 self.validate_default_value(value, field, model_name, schema);
324 }
325 }
326 "relation" => {
327 let is_model_ref = matches!(&field.field_type, FieldType::Model(name)
329 if schema.models.contains_key(name.as_str()));
330 if !is_model_ref {
331 self.errors.push(SchemaError::InvalidAttribute {
332 attribute: "relation".to_string(),
333 message: format!(
334 "@relation can only be applied to model reference fields, not '{}.{}'",
335 model_name,
336 field.name()
337 ),
338 });
339 }
340 }
341 "updated_at" => {
342 if !matches!(field.field_type, FieldType::Scalar(ScalarType::DateTime)) {
344 self.errors.push(SchemaError::InvalidAttribute {
345 attribute: "updated_at".to_string(),
346 message: format!(
347 "@updated_at can only be applied to DateTime fields, not '{}.{}'",
348 model_name,
349 field.name()
350 ),
351 });
352 }
353 }
354 _ => {}
355 }
356 }
357
358 fn validate_default_value(
360 &mut self,
361 value: &AttributeValue,
362 field: &Field,
363 model_name: &str,
364 schema: &Schema,
365 ) {
366 match (&field.field_type, value) {
367 (_, AttributeValue::Function(_, _)) => {}
369
370 (FieldType::Scalar(ScalarType::Int), AttributeValue::Int(_)) => {}
372 (FieldType::Scalar(ScalarType::BigInt), AttributeValue::Int(_)) => {}
373
374 (FieldType::Scalar(ScalarType::Float), AttributeValue::Int(_)) => {}
376 (FieldType::Scalar(ScalarType::Float), AttributeValue::Float(_)) => {}
377 (FieldType::Scalar(ScalarType::Decimal), AttributeValue::Int(_)) => {}
378 (FieldType::Scalar(ScalarType::Decimal), AttributeValue::Float(_)) => {}
379
380 (FieldType::Scalar(ScalarType::String), AttributeValue::String(_)) => {}
382
383 (FieldType::Scalar(ScalarType::Boolean), AttributeValue::Boolean(_)) => {}
385
386 (FieldType::Enum(enum_name), AttributeValue::Ident(variant)) => {
388 if let Some(e) = schema.enums.get(enum_name.as_str()) {
389 if e.get_variant(variant).is_none() {
390 self.errors.push(SchemaError::invalid_field(
391 model_name,
392 field.name(),
393 format!(
394 "default value '{}' is not a valid variant of enum '{}'",
395 variant, enum_name
396 ),
397 ));
398 }
399 }
400 }
401
402 (FieldType::Model(type_name), AttributeValue::Ident(variant)) => {
404 if let Some(e) = schema.enums.get(type_name.as_str()) {
406 if e.get_variant(variant).is_none() {
407 self.errors.push(SchemaError::invalid_field(
408 model_name,
409 field.name(),
410 format!(
411 "default value '{}' is not a valid variant of enum '{}'",
412 variant, type_name
413 ),
414 ));
415 }
416 }
417 }
420
421 _ => {
423 self.errors.push(SchemaError::invalid_field(
424 model_name,
425 field.name(),
426 format!(
427 "default value type does not match field type '{}'",
428 field.field_type
429 ),
430 ));
431 }
432 }
433 }
434
435 fn validate_model_attribute(&mut self, attr: &Attribute, model: &Model) {
437 match attr.name() {
438 "index" | "unique" => {
439 if let Some(AttributeValue::FieldRefList(fields)) = attr.first_arg() {
441 for field_name in fields {
442 if !model.fields.contains_key(field_name.as_str()) {
443 self.errors.push(SchemaError::invalid_model(
444 model.name(),
445 format!(
446 "@@{} references non-existent field '{}'",
447 attr.name(),
448 field_name
449 ),
450 ));
451 }
452 }
453 }
454 }
455 "id" => {
456 if let Some(AttributeValue::FieldRefList(fields)) = attr.first_arg() {
458 for field_name in fields {
459 if !model.fields.contains_key(field_name.as_str()) {
460 self.errors.push(SchemaError::invalid_model(
461 model.name(),
462 format!("@@id references non-existent field '{}'", field_name),
463 ));
464 }
465 }
466 }
467 }
468 "search" => {
469 if let Some(AttributeValue::FieldRefList(fields)) = attr.first_arg() {
471 for field_name in fields {
472 if let Some(field) = model.fields.get(field_name.as_str()) {
473 if !matches!(field.field_type, FieldType::Scalar(ScalarType::String)) {
475 self.errors.push(SchemaError::invalid_model(
476 model.name(),
477 format!(
478 "@@search field '{}' must be of type String",
479 field_name
480 ),
481 ));
482 }
483 } else {
484 self.errors.push(SchemaError::invalid_model(
485 model.name(),
486 format!("@@search references non-existent field '{}'", field_name),
487 ));
488 }
489 }
490 }
491 }
492 _ => {}
493 }
494 }
495
496 fn validate_enum(&mut self, e: &Enum) {
498 if e.variants.is_empty() {
499 self.errors.push(SchemaError::invalid_model(
500 e.name(),
501 "enum must have at least one variant".to_string(),
502 ));
503 }
504
505 let mut seen = std::collections::HashSet::new();
507 for variant in &e.variants {
508 if !seen.insert(variant.name()) {
509 self.errors.push(SchemaError::duplicate(
510 format!("enum variant in {}", e.name()),
511 variant.name(),
512 ));
513 }
514 }
515 }
516
517 fn validate_composite_type(&mut self, t: &CompositeType, schema: &Schema) {
519 if t.fields.is_empty() {
520 self.errors.push(SchemaError::invalid_model(
521 t.name(),
522 "composite type must have at least one field".to_string(),
523 ));
524 }
525
526 for field in t.fields.values() {
528 match &field.field_type {
529 FieldType::Model(_) => {
530 self.errors.push(SchemaError::invalid_field(
531 t.name(),
532 field.name(),
533 "composite types cannot have model relations".to_string(),
534 ));
535 }
536 FieldType::Enum(name) => {
537 if !schema.enums.contains_key(name.as_str()) {
538 self.errors.push(SchemaError::unknown_type(
539 t.name(),
540 field.name(),
541 name.as_str(),
542 ));
543 }
544 }
545 FieldType::Composite(name) => {
546 if !schema.types.contains_key(name.as_str()) {
547 self.errors.push(SchemaError::unknown_type(
548 t.name(),
549 field.name(),
550 name.as_str(),
551 ));
552 }
553 }
554 _ => {}
555 }
556 }
557 }
558
559 fn validate_view(&mut self, v: &View, schema: &Schema) {
561 if v.fields.is_empty() {
563 self.errors.push(SchemaError::invalid_model(
564 v.name(),
565 "view must have at least one field".to_string(),
566 ));
567 }
568
569 for field in v.fields.values() {
571 self.validate_field(field, v.name(), schema);
572 }
573 }
574
575 fn validate_server_group(&mut self, sg: &ServerGroup) {
577 if sg.servers.is_empty() {
579 self.errors.push(SchemaError::invalid_model(
580 sg.name.name.as_str(),
581 "serverGroup must have at least one server".to_string(),
582 ));
583 }
584
585 let mut seen_servers = std::collections::HashSet::new();
587 for server_name in sg.servers.keys() {
588 if !seen_servers.insert(server_name.as_str()) {
589 self.errors.push(SchemaError::duplicate(
590 format!("server in serverGroup {}", sg.name.name),
591 server_name.as_str(),
592 ));
593 }
594 }
595
596 for server in sg.servers.values() {
598 self.validate_server(server, sg.name.name.as_str());
599 }
600
601 for attr in &sg.attributes {
603 self.validate_server_group_attribute(attr, sg);
604 }
605
606 if let Some(strategy) = sg.strategy() {
608 if strategy == ServerGroupStrategy::ReadReplica {
609 let has_primary = sg
610 .servers
611 .values()
612 .any(|s| s.role() == Some(ServerRole::Primary));
613 if !has_primary {
614 self.errors.push(SchemaError::invalid_model(
615 sg.name.name.as_str(),
616 "ReadReplica strategy requires at least one server with role = \"primary\""
617 .to_string(),
618 ));
619 }
620 }
621 }
622 }
623
624 fn validate_server(&mut self, server: &Server, group_name: &str) {
626 if server.url().is_none() {
628 self.errors.push(SchemaError::invalid_model(
629 group_name,
630 format!("server '{}' must have a 'url' property", server.name.name),
631 ));
632 }
633
634 if let Some(weight) = server.weight() {
636 if weight == 0 {
637 self.errors.push(SchemaError::invalid_model(
638 group_name,
639 format!(
640 "server '{}' weight must be greater than 0",
641 server.name.name
642 ),
643 ));
644 }
645 }
646
647 if let Some(priority) = server.priority() {
649 if priority == 0 {
650 self.errors.push(SchemaError::invalid_model(
651 group_name,
652 format!(
653 "server '{}' priority must be greater than 0",
654 server.name.name
655 ),
656 ));
657 }
658 }
659 }
660
661 fn validate_server_group_attribute(&mut self, attr: &Attribute, sg: &ServerGroup) {
663 match attr.name() {
664 "strategy" => {
665 if let Some(arg) = attr.first_arg() {
667 let value_str = arg
668 .as_string()
669 .map(|s| s.to_string())
670 .or_else(|| arg.as_ident().map(|s| s.to_string()));
671 if let Some(val) = value_str {
672 if ServerGroupStrategy::parse(&val).is_none() {
673 self.errors.push(SchemaError::InvalidAttribute {
674 attribute: "strategy".to_string(),
675 message: format!(
676 "invalid strategy '{}' for serverGroup '{}'. Valid values: ReadReplica, Sharding, MultiRegion, HighAvailability, Custom",
677 val,
678 sg.name.name
679 ),
680 });
681 }
682 }
683 }
684 }
685 "loadBalance" => {
686 if let Some(arg) = attr.first_arg() {
688 let value_str = arg
689 .as_string()
690 .map(|s| s.to_string())
691 .or_else(|| arg.as_ident().map(|s| s.to_string()));
692 if let Some(val) = value_str {
693 if LoadBalanceStrategy::parse(&val).is_none() {
694 self.errors.push(SchemaError::InvalidAttribute {
695 attribute: "loadBalance".to_string(),
696 message: format!(
697 "invalid loadBalance '{}' for serverGroup '{}'. Valid values: RoundRobin, Random, LeastConnections, Weighted, Nearest, Sticky",
698 val,
699 sg.name.name
700 ),
701 });
702 }
703 }
704 }
705 }
706 _ => {} }
708 }
709
710 fn resolve_relations(&mut self, schema: &Schema) -> Vec<Relation> {
712 let mut relations = Vec::new();
713
714 for model in schema.models.values() {
715 for field in model.fields.values() {
716 if let FieldType::Model(ref target_model) = field.field_type {
717 if schema.enums.contains_key(target_model.as_str()) {
719 continue;
720 }
721
722 if schema.types.contains_key(target_model.as_str()) {
724 continue;
725 }
726
727 if !schema.models.contains_key(target_model.as_str()) {
729 continue;
730 }
731
732 let attrs = field.extract_attributes();
733
734 let relation_type = if field.is_list() {
735 RelationType::OneToMany
737 } else {
738 RelationType::ManyToOne
740 };
741
742 let mut relation = Relation::new(
743 model.name(),
744 field.name(),
745 target_model.as_str(),
746 relation_type,
747 );
748
749 if let Some(rel_attr) = &attrs.relation {
750 if let Some(name) = &rel_attr.name {
751 relation = relation.with_name(name.as_str());
752 }
753 if !rel_attr.fields.is_empty() {
754 relation = relation.with_from_fields(rel_attr.fields.clone());
755 }
756 if !rel_attr.references.is_empty() {
757 relation = relation.with_to_fields(rel_attr.references.clone());
758 }
759 if let Some(action) = rel_attr.on_delete {
760 relation = relation.with_on_delete(action);
761 }
762 if let Some(action) = rel_attr.on_update {
763 relation = relation.with_on_update(action);
764 }
765 if let Some(map) = &rel_attr.map {
766 relation = relation.with_map(map.as_str());
767 }
768 }
769
770 relations.push(relation);
771 }
772 }
773 }
774
775 relations
776 }
777}
778
779pub fn validate_schema(input: &str) -> SchemaResult<Schema> {
781 let schema = crate::parser::parse_schema(input)?;
782 let mut validator = Validator::new();
783 validator.validate(schema)
784}
785
#[cfg(test)]
mod tests {
    use super::*;

    /// Validates `source` and returns the schema, panicking if it is rejected.
    fn accepted(source: &str) -> Schema {
        validate_schema(source).unwrap()
    }

    /// Asserts that `source` fails parsing or validation.
    fn assert_rejected(source: &str) {
        assert!(validate_schema(source).is_err());
    }

    #[test]
    fn test_validate_simple_model() {
        let validated = accepted(
            r#"
            model User {
                id    Int    @id @auto
                email String @unique
            }
        "#,
        );

        assert_eq!(validated.models.len(), 1);
    }

    #[test]
    fn test_validate_model_missing_id() {
        let failure = validate_schema(
            r#"
            model User {
                email String
                name  String
            }
        "#,
        )
        .unwrap_err();

        assert!(matches!(failure, SchemaError::ValidationFailed { .. }));
    }

    #[test]
    fn test_validate_model_with_composite_id() {
        let validated = accepted(
            r#"
            model PostTag {
                post_id Int
                tag_id  Int

                @@id([post_id, tag_id])
            }
        "#,
        );

        assert_eq!(validated.models.len(), 1);
    }

    #[test]
    fn test_validate_unknown_type_reference() {
        assert_rejected(
            r#"
            model User {
                id      Int @id @auto
                profile UnknownType
            }
        "#,
        );
    }

    #[test]
    fn test_validate_enum_reference() {
        let validated = accepted(
            r#"
            enum Role {
                User
                Admin
            }

            model User {
                id   Int  @id @auto
                role Role @default(User)
            }
        "#,
        );

        assert_eq!(validated.models.len(), 1);
        assert_eq!(validated.enums.len(), 1);
    }

    #[test]
    fn test_validate_invalid_enum_default() {
        assert_rejected(
            r#"
            enum Role {
                User
                Admin
            }

            model User {
                id   Int  @id @auto
                role Role @default(Unknown)
            }
        "#,
        );
    }

    #[test]
    fn test_validate_auto_on_non_int() {
        assert_rejected(
            r#"
            model User {
                id    String @id @auto
                email String
            }
        "#,
        );
    }

    #[test]
    fn test_validate_updated_at_on_non_datetime() {
        assert_rejected(
            r#"
            model User {
                id         Int    @id @auto
                updated_at String @updated_at
            }
        "#,
        );
    }

    #[test]
    fn test_validate_empty_enum() {
        assert_rejected(
            r#"
            enum Empty {
            }

            model User {
                id Int @id @auto
            }
        "#,
        );
    }

    #[test]
    fn test_validate_duplicate_model_names() {
        // NOTE(review): the outcome is deliberately not pinned here (the
        // original assertion accepted both `Ok` and `Err`), so this only
        // checks that duplicate model names do not make validation panic.
        let _ = validate_schema(
            r#"
            model User {
                id Int @id @auto
            }

            model User {
                id Int @id @auto
            }
        "#,
        );
    }

    #[test]
    fn test_validate_relation() {
        let validated = accepted(
            r#"
            model User {
                id    Int    @id @auto
                posts Post[]
            }

            model Post {
                id        Int  @id @auto
                author_id Int
                author    User @relation(fields: [author_id], references: [id])
            }
        "#,
        );

        assert_eq!(validated.models.len(), 2);
        assert!(!validated.relations.is_empty());
    }

    #[test]
    fn test_validate_index_with_invalid_field() {
        assert_rejected(
            r#"
            model User {
                id    Int    @id @auto
                email String

                @@index([nonexistent])
            }
        "#,
        );
    }

    #[test]
    fn test_validate_search_on_non_string_field() {
        assert_rejected(
            r#"
            model Post {
                id    Int @id @auto
                views Int

                @@search([views])
            }
        "#,
        );
    }

    #[test]
    fn test_validate_composite_type() {
        // NOTE(review): the outcome is deliberately not pinned here (the
        // original assertion accepted both `Ok` and `Err`), so this only
        // checks that a composite type does not make validation panic.
        let _ = validate_schema(
            r#"
            type Address {
                street  String
                city    String
                country String @default("US")
            }

            model User {
                id      Int     @id @auto
                address Address
            }
        "#,
        );
    }

    #[test]
    fn test_validate_server_group_basic() {
        let validated = accepted(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup MainCluster {
                server primary {
                    url = "postgres://localhost/db"
                    role = "primary"
                }
            }
        "#,
        );

        assert_eq!(validated.server_groups.len(), 1);
    }

    #[test]
    fn test_validate_server_group_empty_servers() {
        assert_rejected(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup EmptyCluster {
            }
        "#,
        );
    }

    #[test]
    fn test_validate_server_group_missing_url() {
        assert_rejected(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup Cluster {
                server db {
                    role = "primary"
                }
            }
        "#,
        );
    }

    #[test]
    fn test_validate_server_group_invalid_strategy() {
        assert_rejected(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup Cluster {
                @@strategy(InvalidStrategy)

                server db {
                    url = "postgres://localhost/db"
                }
            }
        "#,
        );
    }

    #[test]
    fn test_validate_server_group_valid_strategy() {
        let validated = accepted(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup Cluster {
                @@strategy(ReadReplica)
                @@loadBalance(RoundRobin)

                server primary {
                    url = "postgres://localhost/db"
                    role = "primary"
                }
            }
        "#,
        );

        assert_eq!(validated.server_groups.len(), 1);
    }

    #[test]
    fn test_validate_server_group_read_replica_needs_primary() {
        assert_rejected(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup Cluster {
                @@strategy(ReadReplica)

                server replica1 {
                    url = "postgres://localhost/db"
                    role = "replica"
                }
            }
        "#,
        );
    }

    #[test]
    fn test_validate_server_group_with_replicas() {
        let validated = accepted(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup Cluster {
                @@strategy(ReadReplica)

                server primary {
                    url = "postgres://primary/db"
                    role = "primary"
                    weight = 1
                }

                server replica1 {
                    url = "postgres://replica1/db"
                    role = "replica"
                    weight = 2
                }

                server replica2 {
                    url = "postgres://replica2/db"
                    role = "replica"
                    weight = 2
                    region = "us-west-1"
                }
            }
        "#,
        );

        let cluster = validated.get_server_group("Cluster").unwrap();
        assert_eq!(cluster.servers.len(), 3);
    }

    #[test]
    fn test_validate_server_group_zero_weight() {
        assert_rejected(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup Cluster {
                server db {
                    url = "postgres://localhost/db"
                    weight = 0
                }
            }
        "#,
        );
    }

    #[test]
    fn test_validate_server_group_invalid_load_balance() {
        assert_rejected(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup Cluster {
                @@loadBalance(InvalidStrategy)

                server db {
                    url = "postgres://localhost/db"
                }
            }
        "#,
        );
    }
}