1use crate::ast::*;
10use crate::error::{SchemaError, SchemaResult};
11
/// Semantic checker for a parsed [`Schema`].
///
/// Errors are accumulated rather than returned at the first failure, so a single
/// call to [`Validator::validate`] reports every problem it finds.
#[derive(Debug)]
pub struct Validator {
    /// Errors collected during the current `validate` call; cleared when it starts.
    errors: Vec<SchemaError>,
}
18
19impl Default for Validator {
20 fn default() -> Self {
21 Self::new()
22 }
23}
24
25impl Validator {
26 pub fn new() -> Self {
28 Self { errors: vec![] }
29 }
30
31 pub fn validate(&mut self, mut schema: Schema) -> SchemaResult<Schema> {
33 self.errors.clear();
34
35 self.check_duplicates(&schema);
37
38 self.resolve_field_types(&mut schema);
40
41 for model in schema.models.values() {
43 self.validate_model(model, &schema);
44 }
45
46 for e in schema.enums.values() {
48 self.validate_enum(e);
49 }
50
51 for t in schema.types.values() {
53 self.validate_composite_type(t, &schema);
54 }
55
56 for v in schema.views.values() {
58 self.validate_view(v, &schema);
59 }
60
61 for sg in schema.server_groups.values() {
63 self.validate_server_group(sg);
64 }
65
66 let relations = self.resolve_relations(&schema);
68 schema.relations = relations;
69
70 if self.errors.is_empty() {
71 Ok(schema)
72 } else {
73 Err(SchemaError::ValidationFailed {
74 count: self.errors.len(),
75 errors: std::mem::take(&mut self.errors),
76 })
77 }
78 }
79
80 fn check_duplicates(&mut self, schema: &Schema) {
82 let mut seen = std::collections::HashSet::new();
83
84 for name in schema.models.keys() {
85 if !seen.insert(name.as_str()) {
86 self.errors
87 .push(SchemaError::duplicate("model", name.as_str()));
88 }
89 }
90
91 for name in schema.enums.keys() {
92 if !seen.insert(name.as_str()) {
93 self.errors
94 .push(SchemaError::duplicate("enum", name.as_str()));
95 }
96 }
97
98 for name in schema.types.keys() {
99 if !seen.insert(name.as_str()) {
100 self.errors
101 .push(SchemaError::duplicate("type", name.as_str()));
102 }
103 }
104
105 for name in schema.views.keys() {
106 if !seen.insert(name.as_str()) {
107 self.errors
108 .push(SchemaError::duplicate("view", name.as_str()));
109 }
110 }
111
112 let mut server_group_names = std::collections::HashSet::new();
114 for name in schema.server_groups.keys() {
115 if !server_group_names.insert(name.as_str()) {
116 self.errors
117 .push(SchemaError::duplicate("serverGroup", name.as_str()));
118 }
119 }
120 }
121
122 fn resolve_field_types(&self, schema: &mut Schema) {
127 let enum_names: std::collections::HashSet<String> =
129 schema.enums.keys().map(|s| s.to_string()).collect();
130 let composite_names: std::collections::HashSet<String> =
131 schema.types.keys().map(|s| s.to_string()).collect();
132
133 for model in schema.models.values_mut() {
135 for field in model.fields.values_mut() {
136 if let FieldType::Model(ref type_name) = field.field_type {
137 let name = type_name.as_str();
138 if enum_names.contains(name) {
139 field.field_type = FieldType::Enum(type_name.clone());
140 } else if composite_names.contains(name) {
141 field.field_type = FieldType::Composite(type_name.clone());
142 }
143 }
144 }
145 }
146
147 for composite in schema.types.values_mut() {
149 for field in composite.fields.values_mut() {
150 if let FieldType::Model(ref type_name) = field.field_type {
151 let name = type_name.as_str();
152 if enum_names.contains(name) {
153 field.field_type = FieldType::Enum(type_name.clone());
154 } else if composite_names.contains(name) {
155 field.field_type = FieldType::Composite(type_name.clone());
156 }
157 }
158 }
159 }
160
161 for view in schema.views.values_mut() {
163 for field in view.fields.values_mut() {
164 if let FieldType::Model(ref type_name) = field.field_type {
165 let name = type_name.as_str();
166 if enum_names.contains(name) {
167 field.field_type = FieldType::Enum(type_name.clone());
168 } else if composite_names.contains(name) {
169 field.field_type = FieldType::Composite(type_name.clone());
170 }
171 }
172 }
173 }
174 }
175
176 fn validate_model(&mut self, model: &Model, schema: &Schema) {
178 let id_fields: Vec<_> = model.fields.values().filter(|f| f.is_id()).collect();
180 if id_fields.is_empty() && !self.has_composite_id(model) {
181 self.errors.push(SchemaError::MissingId {
182 model: model.name().to_string(),
183 });
184 }
185
186 for field in model.fields.values() {
188 self.validate_field(field, model.name(), schema);
189 }
190
191 for attr in &model.attributes {
193 self.validate_model_attribute(attr, model);
194 }
195 }
196
197 fn has_composite_id(&self, model: &Model) -> bool {
199 model.attributes.iter().any(|a| a.is("id"))
200 }
201
202 fn validate_field(&mut self, field: &Field, model_name: &str, schema: &Schema) {
204 match &field.field_type {
206 FieldType::Model(name) => {
207 if schema.models.contains_key(name.as_str()) {
209 } else if schema.enums.contains_key(name.as_str()) {
211 } else if schema.types.contains_key(name.as_str()) {
214 } else {
216 self.errors.push(SchemaError::unknown_type(
217 model_name,
218 field.name(),
219 name.as_str(),
220 ));
221 }
222 }
223 FieldType::Enum(name) => {
224 if !schema.enums.contains_key(name.as_str()) {
225 self.errors.push(SchemaError::unknown_type(
226 model_name,
227 field.name(),
228 name.as_str(),
229 ));
230 }
231 }
232 FieldType::Composite(name) => {
233 if !schema.types.contains_key(name.as_str()) {
234 self.errors.push(SchemaError::unknown_type(
235 model_name,
236 field.name(),
237 name.as_str(),
238 ));
239 }
240 }
241 _ => {}
242 }
243
244 for attr in &field.attributes {
246 self.validate_field_attribute(attr, field, model_name, schema);
247 }
248
249 if let FieldType::Model(ref target_name) = field.field_type {
252 let is_actual_relation = schema.models.contains_key(target_name.as_str())
254 && !schema.enums.contains_key(target_name.as_str())
255 && !schema.types.contains_key(target_name.as_str());
256
257 if is_actual_relation && !field.is_list() {
258 let attrs = field.extract_attributes();
260 if let Some(rel) = attrs.relation.as_ref() {
261 for fk_field in &rel.fields {
263 if !schema
264 .models
265 .get(model_name)
266 .map(|m| m.fields.contains_key(fk_field.as_str()))
267 .unwrap_or(false)
268 {
269 self.errors.push(SchemaError::invalid_relation(
270 model_name,
271 field.name(),
272 format!("foreign key field '{}' does not exist", fk_field),
273 ));
274 }
275 }
276 }
277 }
278 }
279 }
280
281 fn validate_field_attribute(
283 &mut self,
284 attr: &Attribute,
285 field: &Field,
286 model_name: &str,
287 schema: &Schema,
288 ) {
289 match attr.name() {
290 "id" => {
291 if field.field_type.is_relation() {
293 self.errors.push(SchemaError::InvalidAttribute {
294 attribute: "id".to_string(),
295 message: format!(
296 "@id cannot be applied to relation field '{}.{}'",
297 model_name,
298 field.name()
299 ),
300 });
301 }
302 }
303 "auto" => {
304 if !matches!(
306 field.field_type,
307 FieldType::Scalar(ScalarType::Int) | FieldType::Scalar(ScalarType::BigInt)
308 ) {
309 self.errors.push(SchemaError::InvalidAttribute {
310 attribute: "auto".to_string(),
311 message: format!(
312 "@auto can only be applied to Int or BigInt fields, not '{}.{}'",
313 model_name,
314 field.name()
315 ),
316 });
317 }
318 }
319 "default" => {
320 if let Some(value) = attr.first_arg() {
322 self.validate_default_value(value, field, model_name, schema);
323 }
324 }
325 "relation" => {
326 let is_model_ref = matches!(&field.field_type, FieldType::Model(name)
328 if schema.models.contains_key(name.as_str()));
329 if !is_model_ref {
330 self.errors.push(SchemaError::InvalidAttribute {
331 attribute: "relation".to_string(),
332 message: format!(
333 "@relation can only be applied to model reference fields, not '{}.{}'",
334 model_name,
335 field.name()
336 ),
337 });
338 }
339 }
340 "updated_at" => {
341 if !matches!(field.field_type, FieldType::Scalar(ScalarType::DateTime)) {
343 self.errors.push(SchemaError::InvalidAttribute {
344 attribute: "updated_at".to_string(),
345 message: format!(
346 "@updated_at can only be applied to DateTime fields, not '{}.{}'",
347 model_name,
348 field.name()
349 ),
350 });
351 }
352 }
353 _ => {}
354 }
355 }
356
357 fn validate_default_value(
359 &mut self,
360 value: &AttributeValue,
361 field: &Field,
362 model_name: &str,
363 schema: &Schema,
364 ) {
365 match (&field.field_type, value) {
366 (_, AttributeValue::Function(_, _)) => {}
368
369 (FieldType::Scalar(ScalarType::Int), AttributeValue::Int(_)) => {}
371 (FieldType::Scalar(ScalarType::BigInt), AttributeValue::Int(_)) => {}
372
373 (FieldType::Scalar(ScalarType::Float), AttributeValue::Int(_)) => {}
375 (FieldType::Scalar(ScalarType::Float), AttributeValue::Float(_)) => {}
376 (FieldType::Scalar(ScalarType::Decimal), AttributeValue::Int(_)) => {}
377 (FieldType::Scalar(ScalarType::Decimal), AttributeValue::Float(_)) => {}
378
379 (FieldType::Scalar(ScalarType::String), AttributeValue::String(_)) => {}
381
382 (FieldType::Scalar(ScalarType::Json), AttributeValue::String(_))
387 | (FieldType::Scalar(ScalarType::Json), AttributeValue::Array(_))
388 | (FieldType::Scalar(ScalarType::Json), AttributeValue::Boolean(_))
389 | (FieldType::Scalar(ScalarType::Json), AttributeValue::Int(_))
390 | (FieldType::Scalar(ScalarType::Json), AttributeValue::Float(_)) => {}
391
392 (FieldType::Scalar(ScalarType::Boolean), AttributeValue::Boolean(_)) => {}
394
395 (FieldType::Enum(enum_name), AttributeValue::Ident(variant)) => {
397 if let Some(e) = schema.enums.get(enum_name.as_str())
398 && e.get_variant(variant).is_none()
399 {
400 self.errors.push(SchemaError::invalid_field(
401 model_name,
402 field.name(),
403 format!(
404 "default value '{}' is not a valid variant of enum '{}'",
405 variant, enum_name
406 ),
407 ));
408 }
409 }
410
411 (FieldType::Model(type_name), AttributeValue::Ident(variant)) => {
413 if let Some(e) = schema.enums.get(type_name.as_str())
415 && e.get_variant(variant).is_none()
416 {
417 self.errors.push(SchemaError::invalid_field(
418 model_name,
419 field.name(),
420 format!(
421 "default value '{}' is not a valid variant of enum '{}'",
422 variant, type_name
423 ),
424 ));
425 }
426 }
429
430 _ => {
432 self.errors.push(SchemaError::invalid_field(
433 model_name,
434 field.name(),
435 format!(
436 "default value type does not match field type '{}'",
437 field.field_type
438 ),
439 ));
440 }
441 }
442 }
443
444 fn validate_model_attribute(&mut self, attr: &Attribute, model: &Model) {
446 match attr.name() {
447 "index" | "unique" => {
448 if let Some(AttributeValue::FieldRefList(fields)) = attr.first_arg() {
450 for field_name in fields {
451 if !model.fields.contains_key(field_name.as_str()) {
452 self.errors.push(SchemaError::invalid_model(
453 model.name(),
454 format!(
455 "@@{} references non-existent field '{}'",
456 attr.name(),
457 field_name
458 ),
459 ));
460 }
461 }
462 }
463 }
464 "id" => {
465 if let Some(AttributeValue::FieldRefList(fields)) = attr.first_arg() {
467 for field_name in fields {
468 if !model.fields.contains_key(field_name.as_str()) {
469 self.errors.push(SchemaError::invalid_model(
470 model.name(),
471 format!("@@id references non-existent field '{}'", field_name),
472 ));
473 }
474 }
475 }
476 }
477 "search" => {
478 if let Some(AttributeValue::FieldRefList(fields)) = attr.first_arg() {
480 for field_name in fields {
481 if let Some(field) = model.fields.get(field_name.as_str()) {
482 if !matches!(field.field_type, FieldType::Scalar(ScalarType::String)) {
484 self.errors.push(SchemaError::invalid_model(
485 model.name(),
486 format!(
487 "@@search field '{}' must be of type String",
488 field_name
489 ),
490 ));
491 }
492 } else {
493 self.errors.push(SchemaError::invalid_model(
494 model.name(),
495 format!("@@search references non-existent field '{}'", field_name),
496 ));
497 }
498 }
499 }
500 }
501 _ => {}
502 }
503 }
504
505 fn validate_enum(&mut self, e: &Enum) {
507 if e.variants.is_empty() {
508 self.errors.push(SchemaError::invalid_model(
509 e.name(),
510 "enum must have at least one variant".to_string(),
511 ));
512 }
513
514 let mut seen = std::collections::HashSet::new();
516 for variant in &e.variants {
517 if !seen.insert(variant.name()) {
518 self.errors.push(SchemaError::duplicate(
519 format!("enum variant in {}", e.name()),
520 variant.name(),
521 ));
522 }
523 }
524 }
525
526 fn validate_composite_type(&mut self, t: &CompositeType, schema: &Schema) {
528 if t.fields.is_empty() {
529 self.errors.push(SchemaError::invalid_model(
530 t.name(),
531 "composite type must have at least one field".to_string(),
532 ));
533 }
534
535 for field in t.fields.values() {
537 match &field.field_type {
538 FieldType::Model(_) => {
539 self.errors.push(SchemaError::invalid_field(
540 t.name(),
541 field.name(),
542 "composite types cannot have model relations".to_string(),
543 ));
544 }
545 FieldType::Enum(name) => {
546 if !schema.enums.contains_key(name.as_str()) {
547 self.errors.push(SchemaError::unknown_type(
548 t.name(),
549 field.name(),
550 name.as_str(),
551 ));
552 }
553 }
554 FieldType::Composite(name) => {
555 if !schema.types.contains_key(name.as_str()) {
556 self.errors.push(SchemaError::unknown_type(
557 t.name(),
558 field.name(),
559 name.as_str(),
560 ));
561 }
562 }
563 _ => {}
564 }
565 }
566 }
567
568 fn validate_view(&mut self, v: &View, schema: &Schema) {
570 if v.fields.is_empty() {
572 self.errors.push(SchemaError::invalid_model(
573 v.name(),
574 "view must have at least one field".to_string(),
575 ));
576 }
577
578 for field in v.fields.values() {
580 self.validate_field(field, v.name(), schema);
581 }
582 }
583
584 fn validate_server_group(&mut self, sg: &ServerGroup) {
586 if sg.servers.is_empty() {
588 self.errors.push(SchemaError::invalid_model(
589 sg.name.name.as_str(),
590 "serverGroup must have at least one server".to_string(),
591 ));
592 }
593
594 let mut seen_servers = std::collections::HashSet::new();
596 for server_name in sg.servers.keys() {
597 if !seen_servers.insert(server_name.as_str()) {
598 self.errors.push(SchemaError::duplicate(
599 format!("server in serverGroup {}", sg.name.name),
600 server_name.as_str(),
601 ));
602 }
603 }
604
605 for server in sg.servers.values() {
607 self.validate_server(server, sg.name.name.as_str());
608 }
609
610 for attr in &sg.attributes {
612 self.validate_server_group_attribute(attr, sg);
613 }
614
615 if let Some(strategy) = sg.strategy()
617 && strategy == ServerGroupStrategy::ReadReplica
618 {
619 let has_primary = sg
620 .servers
621 .values()
622 .any(|s| s.role() == Some(ServerRole::Primary));
623 if !has_primary {
624 self.errors.push(SchemaError::invalid_model(
625 sg.name.name.as_str(),
626 "ReadReplica strategy requires at least one server with role = \"primary\""
627 .to_string(),
628 ));
629 }
630 }
631 }
632
633 fn validate_server(&mut self, server: &Server, group_name: &str) {
635 if server.url().is_none() {
637 self.errors.push(SchemaError::invalid_model(
638 group_name,
639 format!("server '{}' must have a 'url' property", server.name.name),
640 ));
641 }
642
643 if let Some(weight) = server.weight()
645 && weight == 0
646 {
647 self.errors.push(SchemaError::invalid_model(
648 group_name,
649 format!(
650 "server '{}' weight must be greater than 0",
651 server.name.name
652 ),
653 ));
654 }
655
656 if let Some(priority) = server.priority()
658 && priority == 0
659 {
660 self.errors.push(SchemaError::invalid_model(
661 group_name,
662 format!(
663 "server '{}' priority must be greater than 0",
664 server.name.name
665 ),
666 ));
667 }
668 }
669
670 fn validate_server_group_attribute(&mut self, attr: &Attribute, sg: &ServerGroup) {
672 match attr.name() {
673 "strategy" => {
674 if let Some(arg) = attr.first_arg() {
676 let value_str = arg
677 .as_string()
678 .map(|s| s.to_string())
679 .or_else(|| arg.as_ident().map(|s| s.to_string()));
680 if let Some(val) = value_str
681 && ServerGroupStrategy::parse(&val).is_none()
682 {
683 self.errors.push(SchemaError::InvalidAttribute {
684 attribute: "strategy".to_string(),
685 message: format!(
686 "invalid strategy '{}' for serverGroup '{}'. Valid values: ReadReplica, Sharding, MultiRegion, HighAvailability, Custom",
687 val,
688 sg.name.name
689 ),
690 });
691 }
692 }
693 }
694 "loadBalance" => {
695 if let Some(arg) = attr.first_arg() {
697 let value_str = arg
698 .as_string()
699 .map(|s| s.to_string())
700 .or_else(|| arg.as_ident().map(|s| s.to_string()));
701 if let Some(val) = value_str
702 && LoadBalanceStrategy::parse(&val).is_none()
703 {
704 self.errors.push(SchemaError::InvalidAttribute {
705 attribute: "loadBalance".to_string(),
706 message: format!(
707 "invalid loadBalance '{}' for serverGroup '{}'. Valid values: RoundRobin, Random, LeastConnections, Weighted, Nearest, Sticky",
708 val,
709 sg.name.name
710 ),
711 });
712 }
713 }
714 }
715 _ => {} }
717 }
718
719 fn resolve_relations(&mut self, schema: &Schema) -> Vec<Relation> {
721 let mut relations = Vec::new();
722
723 for model in schema.models.values() {
724 for field in model.fields.values() {
725 if let FieldType::Model(ref target_model) = field.field_type {
726 if schema.enums.contains_key(target_model.as_str()) {
728 continue;
729 }
730
731 if schema.types.contains_key(target_model.as_str()) {
733 continue;
734 }
735
736 if !schema.models.contains_key(target_model.as_str()) {
738 continue;
739 }
740
741 let attrs = field.extract_attributes();
742
743 let relation_type = if field.is_list() {
744 RelationType::OneToMany
746 } else {
747 RelationType::ManyToOne
749 };
750
751 let mut relation = Relation::new(
752 model.name(),
753 field.name(),
754 target_model.as_str(),
755 relation_type,
756 );
757
758 if let Some(rel_attr) = &attrs.relation {
759 if let Some(name) = &rel_attr.name {
760 relation = relation.with_name(name.as_str());
761 }
762 if !rel_attr.fields.is_empty() {
763 relation = relation.with_from_fields(rel_attr.fields.clone());
764 }
765 if !rel_attr.references.is_empty() {
766 relation = relation.with_to_fields(rel_attr.references.clone());
767 }
768 if let Some(action) = rel_attr.on_delete {
769 relation = relation.with_on_delete(action);
770 }
771 if let Some(action) = rel_attr.on_update {
772 relation = relation.with_on_update(action);
773 }
774 if let Some(map) = &rel_attr.map {
775 relation = relation.with_map(map.as_str());
776 }
777 }
778
779 relations.push(relation);
780 }
781 }
782 }
783
784 relations
785 }
786}
787
788pub fn validate_schema(input: &str) -> SchemaResult<Schema> {
790 let schema = crate::parser::parse_schema(input)?;
791 let mut validator = Validator::new();
792 validator.validate(schema)
793}
794
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_validate_simple_model() {
        let schema = validate_schema(
            r#"
            model User {
                id Int @id @auto
                email String @unique
            }
            "#,
        )
        .unwrap();

        assert_eq!(schema.models.len(), 1);
    }

    #[test]
    fn test_validate_model_missing_id() {
        let result = validate_schema(
            r#"
            model User {
                email String
                name String
            }
            "#,
        );

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, SchemaError::ValidationFailed { .. }));
    }

    #[test]
    fn test_validation_failed_count_matches_errors() {
        // Missing id + unknown field type: both errors must be collected, and the
        // reported count must agree with the returned error list.
        let err = validate_schema(
            r#"
            model User {
                email String
                profile UnknownType
            }
            "#,
        )
        .unwrap_err();

        let SchemaError::ValidationFailed { count, errors, .. } = err else {
            panic!("expected ValidationFailed");
        };
        assert_eq!(count, errors.len());
        assert!(count >= 2);
    }

    #[test]
    fn test_validate_model_with_composite_id() {
        let schema = validate_schema(
            r#"
            model PostTag {
                post_id Int
                tag_id Int

                @@id([post_id, tag_id])
            }
            "#,
        )
        .unwrap();

        assert_eq!(schema.models.len(), 1);
    }

    #[test]
    fn test_validate_unknown_type_reference() {
        let result = validate_schema(
            r#"
            model User {
                id Int @id @auto
                profile UnknownType
            }
            "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_enum_reference() {
        let schema = validate_schema(
            r#"
            enum Role {
                User
                Admin
            }

            model User {
                id Int @id @auto
                role Role @default(User)
            }
            "#,
        )
        .unwrap();

        assert_eq!(schema.models.len(), 1);
        assert_eq!(schema.enums.len(), 1);
    }

    #[test]
    fn test_validate_invalid_enum_default() {
        let result = validate_schema(
            r#"
            enum Role {
                User
                Admin
            }

            model User {
                id Int @id @auto
                role Role @default(Unknown)
            }
            "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_default_type_mismatch() {
        let result = validate_schema(
            r#"
            model User {
                id Int @id @auto
                age Int @default("old")
            }
            "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_auto_on_non_int() {
        let result = validate_schema(
            r#"
            model User {
                id String @id @auto
                email String
            }
            "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_updated_at_on_non_datetime() {
        let result = validate_schema(
            r#"
            model User {
                id Int @id @auto
                updated_at String @updated_at
            }
            "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_empty_enum() {
        let result = validate_schema(
            r#"
            enum Empty {
            }

            model User {
                id Int @id @auto
            }
            "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_duplicate_model_names() {
        // Only asserts that validation completes without panicking. Whether two
        // same-named models are rejected depends on how the parser stores them.
        let result = validate_schema(
            r#"
            model User {
                id Int @id @auto
            }

            model User {
                id Int @id @auto
            }
            "#,
        );

        assert!(result.is_ok() || result.is_err());
    }

    #[test]
    fn test_validate_name_shared_by_model_and_enum() {
        let result = validate_schema(
            r#"
            enum Role {
                User
                Admin
            }

            model Role {
                id Int @id @auto
            }
            "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_relation() {
        let schema = validate_schema(
            r#"
            model User {
                id Int @id @auto
                posts Post[]
            }

            model Post {
                id Int @id @auto
                author_id Int
                author User @relation(fields: [author_id], references: [id])
            }
            "#,
        )
        .unwrap();

        assert_eq!(schema.models.len(), 2);
        assert!(!schema.relations.is_empty());
    }

    #[test]
    fn test_validate_relation_unknown_foreign_key() {
        let result = validate_schema(
            r#"
            model User {
                id Int @id @auto
                posts Post[]
            }

            model Post {
                id Int @id @auto
                author User @relation(fields: [missing_id], references: [id])
            }
            "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_index_with_invalid_field() {
        let result = validate_schema(
            r#"
            model User {
                id Int @id @auto
                email String

                @@index([nonexistent])
            }
            "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_search_on_non_string_field() {
        let result = validate_schema(
            r#"
            model Post {
                id Int @id @auto
                views Int

                @@search([views])
            }
            "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_composite_type() {
        // Only asserts that validation completes without panicking.
        let schema = validate_schema(
            r#"
            type Address {
                street String
                city String
                country String @default("US")
            }

            model User {
                id Int @id @auto
                address Address
            }
            "#,
        );

        assert!(schema.is_ok() || schema.is_err());
    }

    #[test]
    fn test_validate_composite_type_unknown_field_type() {
        let result = validate_schema(
            r#"
            type Address {
                street String
                geo UnknownGeo
            }

            model User {
                id Int @id @auto
                address Address
            }
            "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_server_group_basic() {
        let schema = validate_schema(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup MainCluster {
                server primary {
                    url = "postgres://localhost/db"
                    role = "primary"
                }
            }
            "#,
        )
        .unwrap();

        assert_eq!(schema.server_groups.len(), 1);
    }

    #[test]
    fn test_validate_server_group_empty_servers() {
        let result = validate_schema(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup EmptyCluster {
            }
            "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_server_group_missing_url() {
        let result = validate_schema(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup Cluster {
                server db {
                    role = "primary"
                }
            }
            "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_server_group_invalid_strategy() {
        let result = validate_schema(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup Cluster {
                @@strategy(InvalidStrategy)

                server db {
                    url = "postgres://localhost/db"
                }
            }
            "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_server_group_valid_strategy() {
        let schema = validate_schema(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup Cluster {
                @@strategy(ReadReplica)
                @@loadBalance(RoundRobin)

                server primary {
                    url = "postgres://localhost/db"
                    role = "primary"
                }
            }
            "#,
        )
        .unwrap();

        assert_eq!(schema.server_groups.len(), 1);
    }

    #[test]
    fn test_validate_server_group_read_replica_needs_primary() {
        let result = validate_schema(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup Cluster {
                @@strategy(ReadReplica)

                server replica1 {
                    url = "postgres://localhost/db"
                    role = "replica"
                }
            }
            "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_server_group_with_replicas() {
        let schema = validate_schema(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup Cluster {
                @@strategy(ReadReplica)

                server primary {
                    url = "postgres://primary/db"
                    role = "primary"
                    weight = 1
                }

                server replica1 {
                    url = "postgres://replica1/db"
                    role = "replica"
                    weight = 2
                }

                server replica2 {
                    url = "postgres://replica2/db"
                    role = "replica"
                    weight = 2
                    region = "us-west-1"
                }
            }
            "#,
        )
        .unwrap();

        let cluster = schema.get_server_group("Cluster").unwrap();
        assert_eq!(cluster.servers.len(), 3);
    }

    #[test]
    fn test_validate_server_group_zero_weight() {
        let result = validate_schema(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup Cluster {
                server db {
                    url = "postgres://localhost/db"
                    weight = 0
                }
            }
            "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_server_group_invalid_load_balance() {
        let result = validate_schema(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup Cluster {
                @@loadBalance(InvalidStrategy)

                server db {
                    url = "postgres://localhost/db"
                }
            }
            "#,
        );

        assert!(result.is_err());
    }
}