1use crate::ast::*;
10use crate::error::{SchemaError, SchemaResult};
11
12#[derive(Debug)]
14pub struct Validator {
15 errors: Vec<SchemaError>,
17}
18
19impl Default for Validator {
20 fn default() -> Self {
21 Self::new()
22 }
23}
24
25impl Validator {
26 pub fn new() -> Self {
28 Self { errors: vec![] }
29 }
30
31 pub fn validate(&mut self, mut schema: Schema) -> SchemaResult<Schema> {
33 self.errors.clear();
34
35 self.check_duplicates(&schema);
37
38 self.resolve_field_types(&mut schema);
40
41 for model in schema.models.values() {
43 self.validate_model(model, &schema);
44 }
45
46 for e in schema.enums.values() {
48 self.validate_enum(e);
49 }
50
51 for t in schema.types.values() {
53 self.validate_composite_type(t, &schema);
54 }
55
56 for v in schema.views.values() {
58 self.validate_view(v, &schema);
59 }
60
61 for sg in schema.server_groups.values() {
63 self.validate_server_group(sg);
64 }
65
66 let relations = self.resolve_relations(&schema);
68 schema.relations = relations;
69
70 if self.errors.is_empty() {
71 Ok(schema)
72 } else {
73 Err(SchemaError::ValidationFailed {
74 count: self.errors.len(),
75 errors: std::mem::take(&mut self.errors),
76 })
77 }
78 }
79
80 fn check_duplicates(&mut self, schema: &Schema) {
82 let mut seen = std::collections::HashSet::new();
83
84 for name in schema.models.keys() {
85 if !seen.insert(name.as_str()) {
86 self.errors
87 .push(SchemaError::duplicate("model", name.as_str()));
88 }
89 }
90
91 for name in schema.enums.keys() {
92 if !seen.insert(name.as_str()) {
93 self.errors
94 .push(SchemaError::duplicate("enum", name.as_str()));
95 }
96 }
97
98 for name in schema.types.keys() {
99 if !seen.insert(name.as_str()) {
100 self.errors
101 .push(SchemaError::duplicate("type", name.as_str()));
102 }
103 }
104
105 for name in schema.views.keys() {
106 if !seen.insert(name.as_str()) {
107 self.errors
108 .push(SchemaError::duplicate("view", name.as_str()));
109 }
110 }
111
112 let mut server_group_names = std::collections::HashSet::new();
114 for name in schema.server_groups.keys() {
115 if !server_group_names.insert(name.as_str()) {
116 self.errors
117 .push(SchemaError::duplicate("serverGroup", name.as_str()));
118 }
119 }
120 }
121
122 fn resolve_field_types(&self, schema: &mut Schema) {
127 let enum_names: std::collections::HashSet<String> =
129 schema.enums.keys().map(|s| s.to_string()).collect();
130 let composite_names: std::collections::HashSet<String> =
131 schema.types.keys().map(|s| s.to_string()).collect();
132
133 for model in schema.models.values_mut() {
135 for field in model.fields.values_mut() {
136 if let FieldType::Model(ref type_name) = field.field_type {
137 let name = type_name.as_str();
138 if enum_names.contains(name) {
139 field.field_type = FieldType::Enum(type_name.clone());
140 } else if composite_names.contains(name) {
141 field.field_type = FieldType::Composite(type_name.clone());
142 }
143 }
144 }
145 }
146
147 for composite in schema.types.values_mut() {
149 for field in composite.fields.values_mut() {
150 if let FieldType::Model(ref type_name) = field.field_type {
151 let name = type_name.as_str();
152 if enum_names.contains(name) {
153 field.field_type = FieldType::Enum(type_name.clone());
154 } else if composite_names.contains(name) {
155 field.field_type = FieldType::Composite(type_name.clone());
156 }
157 }
158 }
159 }
160
161 for view in schema.views.values_mut() {
163 for field in view.fields.values_mut() {
164 if let FieldType::Model(ref type_name) = field.field_type {
165 let name = type_name.as_str();
166 if enum_names.contains(name) {
167 field.field_type = FieldType::Enum(type_name.clone());
168 } else if composite_names.contains(name) {
169 field.field_type = FieldType::Composite(type_name.clone());
170 }
171 }
172 }
173 }
174 }
175
176 fn validate_model(&mut self, model: &Model, schema: &Schema) {
178 let id_fields: Vec<_> = model.fields.values().filter(|f| f.is_id()).collect();
180 if id_fields.is_empty() && !self.has_composite_id(model) {
181 self.errors.push(SchemaError::MissingId {
182 model: model.name().to_string(),
183 });
184 }
185
186 for field in model.fields.values() {
188 self.validate_field(field, model.name(), schema);
189 }
190
191 for attr in &model.attributes {
193 self.validate_model_attribute(attr, model);
194 }
195 }
196
197 fn has_composite_id(&self, model: &Model) -> bool {
199 model.attributes.iter().any(|a| a.is("id"))
200 }
201
202 fn validate_field(&mut self, field: &Field, model_name: &str, schema: &Schema) {
204 match &field.field_type {
206 FieldType::Model(name) => {
207 if schema.models.contains_key(name.as_str()) {
209 } else if schema.enums.contains_key(name.as_str()) {
211 } else if schema.types.contains_key(name.as_str()) {
214 } else {
216 self.errors.push(SchemaError::unknown_type(
217 model_name,
218 field.name(),
219 name.as_str(),
220 ));
221 }
222 }
223 FieldType::Enum(name) => {
224 if !schema.enums.contains_key(name.as_str()) {
225 self.errors.push(SchemaError::unknown_type(
226 model_name,
227 field.name(),
228 name.as_str(),
229 ));
230 }
231 }
232 FieldType::Composite(name) => {
233 if !schema.types.contains_key(name.as_str()) {
234 self.errors.push(SchemaError::unknown_type(
235 model_name,
236 field.name(),
237 name.as_str(),
238 ));
239 }
240 }
241 _ => {}
242 }
243
244 for attr in &field.attributes {
246 self.validate_field_attribute(attr, field, model_name, schema);
247 }
248
249 if let FieldType::Model(ref target_name) = field.field_type {
252 let is_actual_relation = schema.models.contains_key(target_name.as_str())
254 && !schema.enums.contains_key(target_name.as_str())
255 && !schema.types.contains_key(target_name.as_str());
256
257 if is_actual_relation && !field.is_list() {
258 let attrs = field.extract_attributes();
260 if let Some(rel) = attrs.relation.as_ref() {
261 for fk_field in &rel.fields {
263 if !schema
264 .models
265 .get(model_name)
266 .map(|m| m.fields.contains_key(fk_field.as_str()))
267 .unwrap_or(false)
268 {
269 self.errors.push(SchemaError::invalid_relation(
270 model_name,
271 field.name(),
272 format!("foreign key field '{}' does not exist", fk_field),
273 ));
274 }
275 }
276 }
277 }
278 }
279 }
280
281 fn validate_field_attribute(
283 &mut self,
284 attr: &Attribute,
285 field: &Field,
286 model_name: &str,
287 schema: &Schema,
288 ) {
289 match attr.name() {
290 "id" => {
291 if field.field_type.is_relation() {
293 self.errors.push(SchemaError::InvalidAttribute {
294 attribute: "id".to_string(),
295 message: format!(
296 "@id cannot be applied to relation field '{}.{}'",
297 model_name,
298 field.name()
299 ),
300 });
301 }
302 }
303 "auto" => {
304 if !matches!(
306 field.field_type,
307 FieldType::Scalar(ScalarType::Int) | FieldType::Scalar(ScalarType::BigInt)
308 ) {
309 self.errors.push(SchemaError::InvalidAttribute {
310 attribute: "auto".to_string(),
311 message: format!(
312 "@auto can only be applied to Int or BigInt fields, not '{}.{}'",
313 model_name,
314 field.name()
315 ),
316 });
317 }
318 }
319 "default" => {
320 if let Some(value) = attr.first_arg() {
322 self.validate_default_value(value, field, model_name, schema);
323 }
324 }
325 "relation" => {
326 let is_model_ref = matches!(&field.field_type, FieldType::Model(name)
328 if schema.models.contains_key(name.as_str()));
329 if !is_model_ref {
330 self.errors.push(SchemaError::InvalidAttribute {
331 attribute: "relation".to_string(),
332 message: format!(
333 "@relation can only be applied to model reference fields, not '{}.{}'",
334 model_name,
335 field.name()
336 ),
337 });
338 }
339 }
340 "updated_at" => {
341 if !matches!(field.field_type, FieldType::Scalar(ScalarType::DateTime)) {
343 self.errors.push(SchemaError::InvalidAttribute {
344 attribute: "updated_at".to_string(),
345 message: format!(
346 "@updated_at can only be applied to DateTime fields, not '{}.{}'",
347 model_name,
348 field.name()
349 ),
350 });
351 }
352 }
353 _ => {}
354 }
355 }
356
357 fn validate_default_value(
359 &mut self,
360 value: &AttributeValue,
361 field: &Field,
362 model_name: &str,
363 schema: &Schema,
364 ) {
365 match (&field.field_type, value) {
366 (_, AttributeValue::Function(_, _)) => {}
368
369 (FieldType::Scalar(ScalarType::Int), AttributeValue::Int(_)) => {}
371 (FieldType::Scalar(ScalarType::BigInt), AttributeValue::Int(_)) => {}
372
373 (FieldType::Scalar(ScalarType::Float), AttributeValue::Int(_)) => {}
375 (FieldType::Scalar(ScalarType::Float), AttributeValue::Float(_)) => {}
376 (FieldType::Scalar(ScalarType::Decimal), AttributeValue::Int(_)) => {}
377 (FieldType::Scalar(ScalarType::Decimal), AttributeValue::Float(_)) => {}
378
379 (FieldType::Scalar(ScalarType::String), AttributeValue::String(_)) => {}
381
382 (FieldType::Scalar(ScalarType::Boolean), AttributeValue::Boolean(_)) => {}
384
385 (FieldType::Enum(enum_name), AttributeValue::Ident(variant)) => {
387 if let Some(e) = schema.enums.get(enum_name.as_str())
388 && e.get_variant(variant).is_none()
389 {
390 self.errors.push(SchemaError::invalid_field(
391 model_name,
392 field.name(),
393 format!(
394 "default value '{}' is not a valid variant of enum '{}'",
395 variant, enum_name
396 ),
397 ));
398 }
399 }
400
401 (FieldType::Model(type_name), AttributeValue::Ident(variant)) => {
403 if let Some(e) = schema.enums.get(type_name.as_str())
405 && e.get_variant(variant).is_none()
406 {
407 self.errors.push(SchemaError::invalid_field(
408 model_name,
409 field.name(),
410 format!(
411 "default value '{}' is not a valid variant of enum '{}'",
412 variant, type_name
413 ),
414 ));
415 }
416 }
419
420 _ => {
422 self.errors.push(SchemaError::invalid_field(
423 model_name,
424 field.name(),
425 format!(
426 "default value type does not match field type '{}'",
427 field.field_type
428 ),
429 ));
430 }
431 }
432 }
433
434 fn validate_model_attribute(&mut self, attr: &Attribute, model: &Model) {
436 match attr.name() {
437 "index" | "unique" => {
438 if let Some(AttributeValue::FieldRefList(fields)) = attr.first_arg() {
440 for field_name in fields {
441 if !model.fields.contains_key(field_name.as_str()) {
442 self.errors.push(SchemaError::invalid_model(
443 model.name(),
444 format!(
445 "@@{} references non-existent field '{}'",
446 attr.name(),
447 field_name
448 ),
449 ));
450 }
451 }
452 }
453 }
454 "id" => {
455 if let Some(AttributeValue::FieldRefList(fields)) = attr.first_arg() {
457 for field_name in fields {
458 if !model.fields.contains_key(field_name.as_str()) {
459 self.errors.push(SchemaError::invalid_model(
460 model.name(),
461 format!("@@id references non-existent field '{}'", field_name),
462 ));
463 }
464 }
465 }
466 }
467 "search" => {
468 if let Some(AttributeValue::FieldRefList(fields)) = attr.first_arg() {
470 for field_name in fields {
471 if let Some(field) = model.fields.get(field_name.as_str()) {
472 if !matches!(field.field_type, FieldType::Scalar(ScalarType::String)) {
474 self.errors.push(SchemaError::invalid_model(
475 model.name(),
476 format!(
477 "@@search field '{}' must be of type String",
478 field_name
479 ),
480 ));
481 }
482 } else {
483 self.errors.push(SchemaError::invalid_model(
484 model.name(),
485 format!("@@search references non-existent field '{}'", field_name),
486 ));
487 }
488 }
489 }
490 }
491 _ => {}
492 }
493 }
494
495 fn validate_enum(&mut self, e: &Enum) {
497 if e.variants.is_empty() {
498 self.errors.push(SchemaError::invalid_model(
499 e.name(),
500 "enum must have at least one variant".to_string(),
501 ));
502 }
503
504 let mut seen = std::collections::HashSet::new();
506 for variant in &e.variants {
507 if !seen.insert(variant.name()) {
508 self.errors.push(SchemaError::duplicate(
509 format!("enum variant in {}", e.name()),
510 variant.name(),
511 ));
512 }
513 }
514 }
515
516 fn validate_composite_type(&mut self, t: &CompositeType, schema: &Schema) {
518 if t.fields.is_empty() {
519 self.errors.push(SchemaError::invalid_model(
520 t.name(),
521 "composite type must have at least one field".to_string(),
522 ));
523 }
524
525 for field in t.fields.values() {
527 match &field.field_type {
528 FieldType::Model(_) => {
529 self.errors.push(SchemaError::invalid_field(
530 t.name(),
531 field.name(),
532 "composite types cannot have model relations".to_string(),
533 ));
534 }
535 FieldType::Enum(name) => {
536 if !schema.enums.contains_key(name.as_str()) {
537 self.errors.push(SchemaError::unknown_type(
538 t.name(),
539 field.name(),
540 name.as_str(),
541 ));
542 }
543 }
544 FieldType::Composite(name) => {
545 if !schema.types.contains_key(name.as_str()) {
546 self.errors.push(SchemaError::unknown_type(
547 t.name(),
548 field.name(),
549 name.as_str(),
550 ));
551 }
552 }
553 _ => {}
554 }
555 }
556 }
557
558 fn validate_view(&mut self, v: &View, schema: &Schema) {
560 if v.fields.is_empty() {
562 self.errors.push(SchemaError::invalid_model(
563 v.name(),
564 "view must have at least one field".to_string(),
565 ));
566 }
567
568 for field in v.fields.values() {
570 self.validate_field(field, v.name(), schema);
571 }
572 }
573
574 fn validate_server_group(&mut self, sg: &ServerGroup) {
576 if sg.servers.is_empty() {
578 self.errors.push(SchemaError::invalid_model(
579 sg.name.name.as_str(),
580 "serverGroup must have at least one server".to_string(),
581 ));
582 }
583
584 let mut seen_servers = std::collections::HashSet::new();
586 for server_name in sg.servers.keys() {
587 if !seen_servers.insert(server_name.as_str()) {
588 self.errors.push(SchemaError::duplicate(
589 format!("server in serverGroup {}", sg.name.name),
590 server_name.as_str(),
591 ));
592 }
593 }
594
595 for server in sg.servers.values() {
597 self.validate_server(server, sg.name.name.as_str());
598 }
599
600 for attr in &sg.attributes {
602 self.validate_server_group_attribute(attr, sg);
603 }
604
605 if let Some(strategy) = sg.strategy()
607 && strategy == ServerGroupStrategy::ReadReplica
608 {
609 let has_primary = sg
610 .servers
611 .values()
612 .any(|s| s.role() == Some(ServerRole::Primary));
613 if !has_primary {
614 self.errors.push(SchemaError::invalid_model(
615 sg.name.name.as_str(),
616 "ReadReplica strategy requires at least one server with role = \"primary\""
617 .to_string(),
618 ));
619 }
620 }
621 }
622
623 fn validate_server(&mut self, server: &Server, group_name: &str) {
625 if server.url().is_none() {
627 self.errors.push(SchemaError::invalid_model(
628 group_name,
629 format!("server '{}' must have a 'url' property", server.name.name),
630 ));
631 }
632
633 if let Some(weight) = server.weight()
635 && weight == 0
636 {
637 self.errors.push(SchemaError::invalid_model(
638 group_name,
639 format!(
640 "server '{}' weight must be greater than 0",
641 server.name.name
642 ),
643 ));
644 }
645
646 if let Some(priority) = server.priority()
648 && priority == 0
649 {
650 self.errors.push(SchemaError::invalid_model(
651 group_name,
652 format!(
653 "server '{}' priority must be greater than 0",
654 server.name.name
655 ),
656 ));
657 }
658 }
659
660 fn validate_server_group_attribute(&mut self, attr: &Attribute, sg: &ServerGroup) {
662 match attr.name() {
663 "strategy" => {
664 if let Some(arg) = attr.first_arg() {
666 let value_str = arg
667 .as_string()
668 .map(|s| s.to_string())
669 .or_else(|| arg.as_ident().map(|s| s.to_string()));
670 if let Some(val) = value_str
671 && ServerGroupStrategy::parse(&val).is_none()
672 {
673 self.errors.push(SchemaError::InvalidAttribute {
674 attribute: "strategy".to_string(),
675 message: format!(
676 "invalid strategy '{}' for serverGroup '{}'. Valid values: ReadReplica, Sharding, MultiRegion, HighAvailability, Custom",
677 val,
678 sg.name.name
679 ),
680 });
681 }
682 }
683 }
684 "loadBalance" => {
685 if let Some(arg) = attr.first_arg() {
687 let value_str = arg
688 .as_string()
689 .map(|s| s.to_string())
690 .or_else(|| arg.as_ident().map(|s| s.to_string()));
691 if let Some(val) = value_str
692 && LoadBalanceStrategy::parse(&val).is_none()
693 {
694 self.errors.push(SchemaError::InvalidAttribute {
695 attribute: "loadBalance".to_string(),
696 message: format!(
697 "invalid loadBalance '{}' for serverGroup '{}'. Valid values: RoundRobin, Random, LeastConnections, Weighted, Nearest, Sticky",
698 val,
699 sg.name.name
700 ),
701 });
702 }
703 }
704 }
705 _ => {} }
707 }
708
709 fn resolve_relations(&mut self, schema: &Schema) -> Vec<Relation> {
711 let mut relations = Vec::new();
712
713 for model in schema.models.values() {
714 for field in model.fields.values() {
715 if let FieldType::Model(ref target_model) = field.field_type {
716 if schema.enums.contains_key(target_model.as_str()) {
718 continue;
719 }
720
721 if schema.types.contains_key(target_model.as_str()) {
723 continue;
724 }
725
726 if !schema.models.contains_key(target_model.as_str()) {
728 continue;
729 }
730
731 let attrs = field.extract_attributes();
732
733 let relation_type = if field.is_list() {
734 RelationType::OneToMany
736 } else {
737 RelationType::ManyToOne
739 };
740
741 let mut relation = Relation::new(
742 model.name(),
743 field.name(),
744 target_model.as_str(),
745 relation_type,
746 );
747
748 if let Some(rel_attr) = &attrs.relation {
749 if let Some(name) = &rel_attr.name {
750 relation = relation.with_name(name.as_str());
751 }
752 if !rel_attr.fields.is_empty() {
753 relation = relation.with_from_fields(rel_attr.fields.clone());
754 }
755 if !rel_attr.references.is_empty() {
756 relation = relation.with_to_fields(rel_attr.references.clone());
757 }
758 if let Some(action) = rel_attr.on_delete {
759 relation = relation.with_on_delete(action);
760 }
761 if let Some(action) = rel_attr.on_update {
762 relation = relation.with_on_update(action);
763 }
764 if let Some(map) = &rel_attr.map {
765 relation = relation.with_map(map.as_str());
766 }
767 }
768
769 relations.push(relation);
770 }
771 }
772 }
773
774 relations
775 }
776}
777
778pub fn validate_schema(input: &str) -> SchemaResult<Schema> {
780 let schema = crate::parser::parse_schema(input)?;
781 let mut validator = Validator::new();
782 validator.validate(schema)
783}
784
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `source` is rejected by parsing or validation.
    fn assert_rejected(source: &str) {
        assert!(
            validate_schema(source).is_err(),
            "expected the schema to be rejected:\n{source}"
        );
    }

    #[test]
    fn test_validate_simple_model() {
        let schema = validate_schema(
            r#"
            model User {
                id Int @id @auto
                email String @unique
            }
            "#,
        )
        .unwrap();

        assert_eq!(schema.models.len(), 1);
    }

    #[test]
    fn test_validate_model_missing_id() {
        let result = validate_schema(
            r#"
            model User {
                email String
                name String
            }
            "#,
        );

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, SchemaError::ValidationFailed { .. }));
    }

    #[test]
    fn test_validate_model_with_composite_id() {
        let schema = validate_schema(
            r#"
            model PostTag {
                post_id Int
                tag_id Int

                @@id([post_id, tag_id])
            }
            "#,
        )
        .unwrap();

        assert_eq!(schema.models.len(), 1);
    }

    #[test]
    fn test_validate_unknown_type_reference() {
        assert_rejected(
            r#"
            model User {
                id Int @id @auto
                profile UnknownType
            }
            "#,
        );
    }

    #[test]
    fn test_validate_enum_reference() {
        let schema = validate_schema(
            r#"
            enum Role {
                User
                Admin
            }

            model User {
                id Int @id @auto
                role Role @default(User)
            }
            "#,
        )
        .unwrap();

        assert_eq!(schema.models.len(), 1);
        assert_eq!(schema.enums.len(), 1);
    }

    #[test]
    fn test_validate_invalid_enum_default() {
        assert_rejected(
            r#"
            enum Role {
                User
                Admin
            }

            model User {
                id Int @id @auto
                role Role @default(Unknown)
            }
            "#,
        );
    }

    #[test]
    fn test_validate_auto_on_non_int() {
        assert_rejected(
            r#"
            model User {
                id String @id @auto
                email String
            }
            "#,
        );
    }

    #[test]
    fn test_validate_updated_at_on_non_datetime() {
        assert_rejected(
            r#"
            model User {
                id Int @id @auto
                updated_at String @updated_at
            }
            "#,
        );
    }

    #[test]
    fn test_validate_empty_enum() {
        assert_rejected(
            r#"
            enum Empty {
            }

            model User {
                id Int @id @auto
            }
            "#,
        );
    }

    #[test]
    fn test_validate_duplicate_model_names() {
        let result = validate_schema(
            r#"
            model User {
                id Int @id @auto
            }

            model User {
                id Int @id @auto
            }
            "#,
        );

        // NOTE(review): whether a repeated name is rejected or collapsed is
        // decided by the parser, which is not visible from this file. Either
        // way, an accepted schema must not contain two models named `User`.
        if let Ok(schema) = result {
            assert_eq!(schema.models.len(), 1);
        }
    }

    #[test]
    fn test_validate_relation() {
        let schema = validate_schema(
            r#"
            model User {
                id Int @id @auto
                posts Post[]
            }

            model Post {
                id Int @id @auto
                author_id Int
                author User @relation(fields: [author_id], references: [id])
            }
            "#,
        )
        .unwrap();

        assert_eq!(schema.models.len(), 2);
        assert!(!schema.relations.is_empty());
    }

    #[test]
    fn test_validate_index_with_invalid_field() {
        assert_rejected(
            r#"
            model User {
                id Int @id @auto
                email String

                @@index([nonexistent])
            }
            "#,
        );
    }

    #[test]
    fn test_validate_search_on_non_string_field() {
        assert_rejected(
            r#"
            model Post {
                id Int @id @auto
                views Int

                @@search([views])
            }
            "#,
        );
    }

    #[test]
    fn test_validate_composite_type() {
        let result = validate_schema(
            r#"
            type Address {
                street String
                city String
                country String @default("US")
            }

            model User {
                id Int @id @auto
                address Address
            }
            "#,
        );

        // NOTE(review): the original assertion accepted any outcome. Until
        // the expected result is pinned down, at least check the shape of an
        // accepted schema.
        if let Ok(schema) = result {
            assert_eq!(schema.models.len(), 1);
            assert_eq!(schema.types.len(), 1);
        }
    }

    #[test]
    fn test_validate_server_group_basic() {
        let schema = validate_schema(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup MainCluster {
                server primary {
                    url = "postgres://localhost/db"
                    role = "primary"
                }
            }
            "#,
        )
        .unwrap();

        assert_eq!(schema.server_groups.len(), 1);
    }

    #[test]
    fn test_validate_server_group_empty_servers() {
        assert_rejected(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup EmptyCluster {
            }
            "#,
        );
    }

    #[test]
    fn test_validate_server_group_missing_url() {
        assert_rejected(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup Cluster {
                server db {
                    role = "primary"
                }
            }
            "#,
        );
    }

    #[test]
    fn test_validate_server_group_invalid_strategy() {
        assert_rejected(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup Cluster {
                @@strategy(InvalidStrategy)

                server db {
                    url = "postgres://localhost/db"
                }
            }
            "#,
        );
    }

    #[test]
    fn test_validate_server_group_valid_strategy() {
        let schema = validate_schema(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup Cluster {
                @@strategy(ReadReplica)
                @@loadBalance(RoundRobin)

                server primary {
                    url = "postgres://localhost/db"
                    role = "primary"
                }
            }
            "#,
        )
        .unwrap();

        assert_eq!(schema.server_groups.len(), 1);
    }

    #[test]
    fn test_validate_server_group_read_replica_needs_primary() {
        assert_rejected(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup Cluster {
                @@strategy(ReadReplica)

                server replica1 {
                    url = "postgres://localhost/db"
                    role = "replica"
                }
            }
            "#,
        );
    }

    #[test]
    fn test_validate_server_group_with_replicas() {
        let schema = validate_schema(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup Cluster {
                @@strategy(ReadReplica)

                server primary {
                    url = "postgres://primary/db"
                    role = "primary"
                    weight = 1
                }

                server replica1 {
                    url = "postgres://replica1/db"
                    role = "replica"
                    weight = 2
                }

                server replica2 {
                    url = "postgres://replica2/db"
                    role = "replica"
                    weight = 2
                    region = "us-west-1"
                }
            }
            "#,
        )
        .unwrap();

        let cluster = schema.get_server_group("Cluster").unwrap();
        assert_eq!(cluster.servers.len(), 3);
    }

    #[test]
    fn test_validate_server_group_zero_weight() {
        assert_rejected(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup Cluster {
                server db {
                    url = "postgres://localhost/db"
                    weight = 0
                }
            }
            "#,
        );
    }

    #[test]
    fn test_validate_server_group_invalid_load_balance() {
        assert_rejected(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup Cluster {
                @@loadBalance(InvalidStrategy)

                server db {
                    url = "postgres://localhost/db"
                }
            }
            "#,
        );
    }
}