Skip to main content

prax_schema/
validator.rs

1//! Schema validation and semantic analysis.
2//!
3//! This module validates parsed schemas for semantic correctness:
4//! - All type references are valid
5//! - Relations are properly defined
6//! - Required attributes are present
7//! - No duplicate definitions
8
9use crate::ast::*;
10use crate::error::{SchemaError, SchemaResult};
11
12/// Schema validator for semantic analysis.
13#[derive(Debug)]
14pub struct Validator {
15    /// Collected validation errors.
16    errors: Vec<SchemaError>,
17}
18
19impl Default for Validator {
20    fn default() -> Self {
21        Self::new()
22    }
23}
24
25impl Validator {
26    /// Create a new validator.
27    pub fn new() -> Self {
28        Self { errors: vec![] }
29    }
30
31    /// Validate a schema and return the validated schema or errors.
32    pub fn validate(&mut self, mut schema: Schema) -> SchemaResult<Schema> {
33        self.errors.clear();
34
35        // Check for duplicate definitions
36        self.check_duplicates(&schema);
37
38        // Resolve field types (convert Model references to Enum or Composite where appropriate)
39        self.resolve_field_types(&mut schema);
40
41        // Validate each model
42        for model in schema.models.values() {
43            self.validate_model(model, &schema);
44        }
45
46        // Validate each enum
47        for e in schema.enums.values() {
48            self.validate_enum(e);
49        }
50
51        // Validate each composite type
52        for t in schema.types.values() {
53            self.validate_composite_type(t, &schema);
54        }
55
56        // Validate each view
57        for v in schema.views.values() {
58            self.validate_view(v, &schema);
59        }
60
61        // Validate each server group
62        for sg in schema.server_groups.values() {
63            self.validate_server_group(sg);
64        }
65
66        // Resolve relations
67        let relations = self.resolve_relations(&schema);
68        schema.relations = relations;
69
70        if self.errors.is_empty() {
71            Ok(schema)
72        } else {
73            Err(SchemaError::ValidationFailed {
74                count: self.errors.len(),
75                errors: std::mem::take(&mut self.errors),
76            })
77        }
78    }
79
80    /// Check for duplicate model, enum, or type names.
81    fn check_duplicates(&mut self, schema: &Schema) {
82        let mut seen = std::collections::HashSet::new();
83
84        for name in schema.models.keys() {
85            if !seen.insert(name.as_str()) {
86                self.errors
87                    .push(SchemaError::duplicate("model", name.as_str()));
88            }
89        }
90
91        for name in schema.enums.keys() {
92            if !seen.insert(name.as_str()) {
93                self.errors
94                    .push(SchemaError::duplicate("enum", name.as_str()));
95            }
96        }
97
98        for name in schema.types.keys() {
99            if !seen.insert(name.as_str()) {
100                self.errors
101                    .push(SchemaError::duplicate("type", name.as_str()));
102            }
103        }
104
105        for name in schema.views.keys() {
106            if !seen.insert(name.as_str()) {
107                self.errors
108                    .push(SchemaError::duplicate("view", name.as_str()));
109            }
110        }
111
112        // Check server group names (separately, since they don't conflict with types)
113        let mut server_group_names = std::collections::HashSet::new();
114        for name in schema.server_groups.keys() {
115            if !server_group_names.insert(name.as_str()) {
116                self.errors
117                    .push(SchemaError::duplicate("serverGroup", name.as_str()));
118            }
119        }
120    }
121
122    /// Resolve field types to their correct types (Enum or Composite) instead of Model.
123    ///
124    /// The parser initially treats all non-scalar type references as Model references.
125    /// This pass corrects them to Enum or Composite where appropriate.
126    fn resolve_field_types(&self, schema: &mut Schema) {
127        // Collect enum and composite type names into owned strings to avoid borrow conflicts
128        let enum_names: std::collections::HashSet<String> =
129            schema.enums.keys().map(|s| s.to_string()).collect();
130        let composite_names: std::collections::HashSet<String> =
131            schema.types.keys().map(|s| s.to_string()).collect();
132
133        // Update field types in models
134        for model in schema.models.values_mut() {
135            for field in model.fields.values_mut() {
136                if let FieldType::Model(ref type_name) = field.field_type {
137                    let name = type_name.as_str();
138                    if enum_names.contains(name) {
139                        field.field_type = FieldType::Enum(type_name.clone());
140                    } else if composite_names.contains(name) {
141                        field.field_type = FieldType::Composite(type_name.clone());
142                    }
143                }
144            }
145        }
146
147        // Also update field types in composite types
148        for composite in schema.types.values_mut() {
149            for field in composite.fields.values_mut() {
150                if let FieldType::Model(ref type_name) = field.field_type {
151                    let name = type_name.as_str();
152                    if enum_names.contains(name) {
153                        field.field_type = FieldType::Enum(type_name.clone());
154                    } else if composite_names.contains(name) {
155                        field.field_type = FieldType::Composite(type_name.clone());
156                    }
157                }
158            }
159        }
160
161        // Also update field types in views
162        for view in schema.views.values_mut() {
163            for field in view.fields.values_mut() {
164                if let FieldType::Model(ref type_name) = field.field_type {
165                    let name = type_name.as_str();
166                    if enum_names.contains(name) {
167                        field.field_type = FieldType::Enum(type_name.clone());
168                    } else if composite_names.contains(name) {
169                        field.field_type = FieldType::Composite(type_name.clone());
170                    }
171                }
172            }
173        }
174    }
175
176    /// Validate a model definition.
177    fn validate_model(&mut self, model: &Model, schema: &Schema) {
178        // Check for @id field
179        let id_fields: Vec<_> = model.fields.values().filter(|f| f.is_id()).collect();
180        if id_fields.is_empty() && !self.has_composite_id(model) {
181            self.errors.push(SchemaError::MissingId {
182                model: model.name().to_string(),
183            });
184        }
185
186        // Validate each field
187        for field in model.fields.values() {
188            self.validate_field(field, model.name(), schema);
189        }
190
191        // Validate model attributes
192        for attr in &model.attributes {
193            self.validate_model_attribute(attr, model);
194        }
195    }
196
197    /// Check if model has a composite ID (@@id attribute).
198    fn has_composite_id(&self, model: &Model) -> bool {
199        model.attributes.iter().any(|a| a.is("id"))
200    }
201
202    /// Validate a field definition.
203    fn validate_field(&mut self, field: &Field, model_name: &str, schema: &Schema) {
204        // Validate type references
205        match &field.field_type {
206            FieldType::Model(name) => {
207                // Check if it's actually a model, enum, or composite type
208                if schema.models.contains_key(name.as_str()) {
209                    // Valid model reference
210                } else if schema.enums.contains_key(name.as_str()) {
211                    // Parser initially treats non-scalar types as Model references
212                    // This is actually an enum type - we'll handle this during resolution
213                } else if schema.types.contains_key(name.as_str()) {
214                    // This is a composite type
215                } else {
216                    self.errors.push(SchemaError::unknown_type(
217                        model_name,
218                        field.name(),
219                        name.as_str(),
220                    ));
221                }
222            }
223            FieldType::Enum(name) => {
224                if !schema.enums.contains_key(name.as_str()) {
225                    self.errors.push(SchemaError::unknown_type(
226                        model_name,
227                        field.name(),
228                        name.as_str(),
229                    ));
230                }
231            }
232            FieldType::Composite(name) => {
233                if !schema.types.contains_key(name.as_str()) {
234                    self.errors.push(SchemaError::unknown_type(
235                        model_name,
236                        field.name(),
237                        name.as_str(),
238                    ));
239                }
240            }
241            _ => {}
242        }
243
244        // Validate field attributes
245        for attr in &field.attributes {
246            self.validate_field_attribute(attr, field, model_name, schema);
247        }
248
249        // Validate relation fields have @relation or are back-references
250        // Only check actual model relations (not enums or composite types parsed as Model)
251        if let FieldType::Model(ref target_name) = field.field_type {
252            // Skip validation for enum and composite type references
253            let is_actual_relation = schema.models.contains_key(target_name.as_str())
254                && !schema.enums.contains_key(target_name.as_str())
255                && !schema.types.contains_key(target_name.as_str());
256
257            if is_actual_relation && !field.is_list() {
258                // One-side of relation should have foreign key fields
259                let attrs = field.extract_attributes();
260                if let Some(rel) = attrs.relation.as_ref() {
261                    // Validate foreign key fields exist
262                    for fk_field in &rel.fields {
263                        if !schema
264                            .models
265                            .get(model_name)
266                            .map(|m| m.fields.contains_key(fk_field.as_str()))
267                            .unwrap_or(false)
268                        {
269                            self.errors.push(SchemaError::invalid_relation(
270                                model_name,
271                                field.name(),
272                                format!("foreign key field '{}' does not exist", fk_field),
273                            ));
274                        }
275                    }
276                }
277            }
278        }
279    }
280
281    /// Validate a field attribute.
282    fn validate_field_attribute(
283        &mut self,
284        attr: &Attribute,
285        field: &Field,
286        model_name: &str,
287        schema: &Schema,
288    ) {
289        match attr.name() {
290            "id" => {
291                // @id should be on a scalar or composite type, not a relation
292                if field.field_type.is_relation() {
293                    self.errors.push(SchemaError::InvalidAttribute {
294                        attribute: "id".to_string(),
295                        message: format!(
296                            "@id cannot be applied to relation field '{}.{}'",
297                            model_name,
298                            field.name()
299                        ),
300                    });
301                }
302            }
303            "auto" => {
304                // @auto should only be on Int or BigInt
305                if !matches!(
306                    field.field_type,
307                    FieldType::Scalar(ScalarType::Int) | FieldType::Scalar(ScalarType::BigInt)
308                ) {
309                    self.errors.push(SchemaError::InvalidAttribute {
310                        attribute: "auto".to_string(),
311                        message: format!(
312                            "@auto can only be applied to Int or BigInt fields, not '{}.{}'",
313                            model_name,
314                            field.name()
315                        ),
316                    });
317                }
318            }
319            "default" => {
320                // Validate default value type matches field type
321                if let Some(value) = attr.first_arg() {
322                    self.validate_default_value(value, field, model_name, schema);
323                }
324            }
325            "relation" => {
326                // Validate relation attribute - should only be on actual model references
327                let is_model_ref = matches!(&field.field_type, FieldType::Model(name)
328                    if schema.models.contains_key(name.as_str()));
329                if !is_model_ref {
330                    self.errors.push(SchemaError::InvalidAttribute {
331                        attribute: "relation".to_string(),
332                        message: format!(
333                            "@relation can only be applied to model reference fields, not '{}.{}'",
334                            model_name,
335                            field.name()
336                        ),
337                    });
338                }
339            }
340            "updated_at" => {
341                // @updated_at should only be on DateTime
342                if !matches!(field.field_type, FieldType::Scalar(ScalarType::DateTime)) {
343                    self.errors.push(SchemaError::InvalidAttribute {
344                        attribute: "updated_at".to_string(),
345                        message: format!(
346                            "@updated_at can only be applied to DateTime fields, not '{}.{}'",
347                            model_name,
348                            field.name()
349                        ),
350                    });
351                }
352            }
353            _ => {}
354        }
355    }
356
357    /// Validate a default value matches the field type.
358    fn validate_default_value(
359        &mut self,
360        value: &AttributeValue,
361        field: &Field,
362        model_name: &str,
363        schema: &Schema,
364    ) {
365        match (&field.field_type, value) {
366            // Functions are generally allowed (now(), uuid(), etc.)
367            (_, AttributeValue::Function(_, _)) => {}
368
369            // Int fields should have int defaults
370            (FieldType::Scalar(ScalarType::Int), AttributeValue::Int(_)) => {}
371            (FieldType::Scalar(ScalarType::BigInt), AttributeValue::Int(_)) => {}
372
373            // Float fields can have int or float defaults
374            (FieldType::Scalar(ScalarType::Float), AttributeValue::Int(_)) => {}
375            (FieldType::Scalar(ScalarType::Float), AttributeValue::Float(_)) => {}
376            (FieldType::Scalar(ScalarType::Decimal), AttributeValue::Int(_)) => {}
377            (FieldType::Scalar(ScalarType::Decimal), AttributeValue::Float(_)) => {}
378
379            // String fields should have string defaults
380            (FieldType::Scalar(ScalarType::String), AttributeValue::String(_)) => {}
381
382            // Json fields accept any constant as the @default — the payload is
383            // stored as a text literal that the database parses into jsonb.
384            // Prisma writes empty objects/arrays as `@default("[]")` or
385            // `@default("{}")`, so accept string, array, and scalar primitives.
386            (FieldType::Scalar(ScalarType::Json), AttributeValue::String(_))
387            | (FieldType::Scalar(ScalarType::Json), AttributeValue::Array(_))
388            | (FieldType::Scalar(ScalarType::Json), AttributeValue::Boolean(_))
389            | (FieldType::Scalar(ScalarType::Json), AttributeValue::Int(_))
390            | (FieldType::Scalar(ScalarType::Json), AttributeValue::Float(_)) => {}
391
392            // Boolean fields should have boolean defaults
393            (FieldType::Scalar(ScalarType::Boolean), AttributeValue::Boolean(_)) => {}
394
395            // Enum fields should have ident defaults matching a variant
396            (FieldType::Enum(enum_name), AttributeValue::Ident(variant)) => {
397                if let Some(e) = schema.enums.get(enum_name.as_str())
398                    && e.get_variant(variant).is_none()
399                {
400                    self.errors.push(SchemaError::invalid_field(
401                        model_name,
402                        field.name(),
403                        format!(
404                            "default value '{}' is not a valid variant of enum '{}'",
405                            variant, enum_name
406                        ),
407                    ));
408                }
409            }
410
411            // Model type might actually be an enum (parser treats non-scalar as Model initially)
412            (FieldType::Model(type_name), AttributeValue::Ident(variant)) => {
413                // Check if this is actually an enum reference
414                if let Some(e) = schema.enums.get(type_name.as_str())
415                    && e.get_variant(variant).is_none()
416                {
417                    self.errors.push(SchemaError::invalid_field(
418                        model_name,
419                        field.name(),
420                        format!(
421                            "default value '{}' is not a valid variant of enum '{}'",
422                            variant, type_name
423                        ),
424                    ));
425                }
426                // If it's a real model reference with an ident default, that's an error
427                // but we skip that here since it's likely a valid enum
428            }
429
430            // Type mismatch
431            _ => {
432                self.errors.push(SchemaError::invalid_field(
433                    model_name,
434                    field.name(),
435                    format!(
436                        "default value type does not match field type '{}'",
437                        field.field_type
438                    ),
439                ));
440            }
441        }
442    }
443
444    /// Validate a model-level attribute.
445    fn validate_model_attribute(&mut self, attr: &Attribute, model: &Model) {
446        match attr.name() {
447            "index" | "unique" => {
448                // Validate referenced fields exist
449                if let Some(AttributeValue::FieldRefList(fields)) = attr.first_arg() {
450                    for field_name in fields {
451                        if !model.fields.contains_key(field_name.as_str()) {
452                            self.errors.push(SchemaError::invalid_model(
453                                model.name(),
454                                format!(
455                                    "@@{} references non-existent field '{}'",
456                                    attr.name(),
457                                    field_name
458                                ),
459                            ));
460                        }
461                    }
462                }
463            }
464            "id" => {
465                // Composite primary key
466                if let Some(AttributeValue::FieldRefList(fields)) = attr.first_arg() {
467                    for field_name in fields {
468                        if !model.fields.contains_key(field_name.as_str()) {
469                            self.errors.push(SchemaError::invalid_model(
470                                model.name(),
471                                format!("@@id references non-existent field '{}'", field_name),
472                            ));
473                        }
474                    }
475                }
476            }
477            "search" => {
478                // Full-text search on fields
479                if let Some(AttributeValue::FieldRefList(fields)) = attr.first_arg() {
480                    for field_name in fields {
481                        if let Some(field) = model.fields.get(field_name.as_str()) {
482                            // Only string fields can be searched
483                            if !matches!(field.field_type, FieldType::Scalar(ScalarType::String)) {
484                                self.errors.push(SchemaError::invalid_model(
485                                    model.name(),
486                                    format!(
487                                        "@@search field '{}' must be of type String",
488                                        field_name
489                                    ),
490                                ));
491                            }
492                        } else {
493                            self.errors.push(SchemaError::invalid_model(
494                                model.name(),
495                                format!("@@search references non-existent field '{}'", field_name),
496                            ));
497                        }
498                    }
499                }
500            }
501            _ => {}
502        }
503    }
504
505    /// Validate an enum definition.
506    fn validate_enum(&mut self, e: &Enum) {
507        if e.variants.is_empty() {
508            self.errors.push(SchemaError::invalid_model(
509                e.name(),
510                "enum must have at least one variant".to_string(),
511            ));
512        }
513
514        // Check for duplicate variant names
515        let mut seen = std::collections::HashSet::new();
516        for variant in &e.variants {
517            if !seen.insert(variant.name()) {
518                self.errors.push(SchemaError::duplicate(
519                    format!("enum variant in {}", e.name()),
520                    variant.name(),
521                ));
522            }
523        }
524    }
525
526    /// Validate a composite type definition.
527    fn validate_composite_type(&mut self, t: &CompositeType, schema: &Schema) {
528        if t.fields.is_empty() {
529            self.errors.push(SchemaError::invalid_model(
530                t.name(),
531                "composite type must have at least one field".to_string(),
532            ));
533        }
534
535        // Validate field types
536        for field in t.fields.values() {
537            match &field.field_type {
538                FieldType::Model(_) => {
539                    self.errors.push(SchemaError::invalid_field(
540                        t.name(),
541                        field.name(),
542                        "composite types cannot have model relations".to_string(),
543                    ));
544                }
545                FieldType::Enum(name) => {
546                    if !schema.enums.contains_key(name.as_str()) {
547                        self.errors.push(SchemaError::unknown_type(
548                            t.name(),
549                            field.name(),
550                            name.as_str(),
551                        ));
552                    }
553                }
554                FieldType::Composite(name) => {
555                    if !schema.types.contains_key(name.as_str()) {
556                        self.errors.push(SchemaError::unknown_type(
557                            t.name(),
558                            field.name(),
559                            name.as_str(),
560                        ));
561                    }
562                }
563                _ => {}
564            }
565        }
566    }
567
568    /// Validate a view definition.
569    fn validate_view(&mut self, v: &View, schema: &Schema) {
570        // Views should have at least one field
571        if v.fields.is_empty() {
572            self.errors.push(SchemaError::invalid_model(
573                v.name(),
574                "view must have at least one field".to_string(),
575            ));
576        }
577
578        // Validate field types
579        for field in v.fields.values() {
580            self.validate_field(field, v.name(), schema);
581        }
582    }
583
584    /// Validate a server group definition.
585    fn validate_server_group(&mut self, sg: &ServerGroup) {
586        // Server groups should have at least one server
587        if sg.servers.is_empty() {
588            self.errors.push(SchemaError::invalid_model(
589                sg.name.name.as_str(),
590                "serverGroup must have at least one server".to_string(),
591            ));
592        }
593
594        // Check for duplicate server names within the group
595        let mut seen_servers = std::collections::HashSet::new();
596        for server_name in sg.servers.keys() {
597            if !seen_servers.insert(server_name.as_str()) {
598                self.errors.push(SchemaError::duplicate(
599                    format!("server in serverGroup {}", sg.name.name),
600                    server_name.as_str(),
601                ));
602            }
603        }
604
605        // Validate each server
606        for server in sg.servers.values() {
607            self.validate_server(server, sg.name.name.as_str());
608        }
609
610        // Validate server group attributes
611        for attr in &sg.attributes {
612            self.validate_server_group_attribute(attr, sg);
613        }
614
615        // Check for at least one primary server in read replica strategy
616        if let Some(strategy) = sg.strategy()
617            && strategy == ServerGroupStrategy::ReadReplica
618        {
619            let has_primary = sg
620                .servers
621                .values()
622                .any(|s| s.role() == Some(ServerRole::Primary));
623            if !has_primary {
624                self.errors.push(SchemaError::invalid_model(
625                    sg.name.name.as_str(),
626                    "ReadReplica strategy requires at least one server with role = \"primary\""
627                        .to_string(),
628                ));
629            }
630        }
631    }
632
633    /// Validate an individual server definition.
634    fn validate_server(&mut self, server: &Server, group_name: &str) {
635        // Server should have a URL property
636        if server.url().is_none() {
637            self.errors.push(SchemaError::invalid_model(
638                group_name,
639                format!("server '{}' must have a 'url' property", server.name.name),
640            ));
641        }
642
643        // Validate weight is positive if specified
644        if let Some(weight) = server.weight()
645            && weight == 0
646        {
647            self.errors.push(SchemaError::invalid_model(
648                group_name,
649                format!(
650                    "server '{}' weight must be greater than 0",
651                    server.name.name
652                ),
653            ));
654        }
655
656        // Validate priority is positive if specified
657        if let Some(priority) = server.priority()
658            && priority == 0
659        {
660            self.errors.push(SchemaError::invalid_model(
661                group_name,
662                format!(
663                    "server '{}' priority must be greater than 0",
664                    server.name.name
665                ),
666            ));
667        }
668    }
669
670    /// Validate a server group attribute.
671    fn validate_server_group_attribute(&mut self, attr: &Attribute, sg: &ServerGroup) {
672        match attr.name() {
673            "strategy" => {
674                // Validate strategy value
675                if let Some(arg) = attr.first_arg() {
676                    let value_str = arg
677                        .as_string()
678                        .map(|s| s.to_string())
679                        .or_else(|| arg.as_ident().map(|s| s.to_string()));
680                    if let Some(val) = value_str
681                        && ServerGroupStrategy::parse(&val).is_none()
682                    {
683                        self.errors.push(SchemaError::InvalidAttribute {
684                                attribute: "strategy".to_string(),
685                                message: format!(
686                                    "invalid strategy '{}' for serverGroup '{}'. Valid values: ReadReplica, Sharding, MultiRegion, HighAvailability, Custom",
687                                    val,
688                                    sg.name.name
689                                ),
690                            });
691                    }
692                }
693            }
694            "loadBalance" => {
695                // Validate load balance value
696                if let Some(arg) = attr.first_arg() {
697                    let value_str = arg
698                        .as_string()
699                        .map(|s| s.to_string())
700                        .or_else(|| arg.as_ident().map(|s| s.to_string()));
701                    if let Some(val) = value_str
702                        && LoadBalanceStrategy::parse(&val).is_none()
703                    {
704                        self.errors.push(SchemaError::InvalidAttribute {
705                                attribute: "loadBalance".to_string(),
706                                message: format!(
707                                    "invalid loadBalance '{}' for serverGroup '{}'. Valid values: RoundRobin, Random, LeastConnections, Weighted, Nearest, Sticky",
708                                    val,
709                                    sg.name.name
710                                ),
711                            });
712                    }
713                }
714            }
715            _ => {} // Other attributes are allowed
716        }
717    }
718
719    /// Resolve all relations in the schema.
720    fn resolve_relations(&mut self, schema: &Schema) -> Vec<Relation> {
721        let mut relations = Vec::new();
722
723        for model in schema.models.values() {
724            for field in model.fields.values() {
725                if let FieldType::Model(ref target_model) = field.field_type {
726                    // Skip if this is actually an enum reference (parser treats non-scalar as Model initially)
727                    if schema.enums.contains_key(target_model.as_str()) {
728                        continue;
729                    }
730
731                    // Skip if this is actually a composite type reference
732                    if schema.types.contains_key(target_model.as_str()) {
733                        continue;
734                    }
735
736                    // Skip if the target model doesn't exist (error was already reported)
737                    if !schema.models.contains_key(target_model.as_str()) {
738                        continue;
739                    }
740
741                    let attrs = field.extract_attributes();
742
743                    let relation_type = if field.is_list() {
744                        // This model has many of target
745                        RelationType::OneToMany
746                    } else {
747                        // This model has one of target
748                        RelationType::ManyToOne
749                    };
750
751                    let mut relation = Relation::new(
752                        model.name(),
753                        field.name(),
754                        target_model.as_str(),
755                        relation_type,
756                    );
757
758                    if let Some(rel_attr) = &attrs.relation {
759                        if let Some(name) = &rel_attr.name {
760                            relation = relation.with_name(name.as_str());
761                        }
762                        if !rel_attr.fields.is_empty() {
763                            relation = relation.with_from_fields(rel_attr.fields.clone());
764                        }
765                        if !rel_attr.references.is_empty() {
766                            relation = relation.with_to_fields(rel_attr.references.clone());
767                        }
768                        if let Some(action) = rel_attr.on_delete {
769                            relation = relation.with_on_delete(action);
770                        }
771                        if let Some(action) = rel_attr.on_update {
772                            relation = relation.with_on_update(action);
773                        }
774                        if let Some(map) = &rel_attr.map {
775                            relation = relation.with_map(map.as_str());
776                        }
777                    }
778
779                    relations.push(relation);
780                }
781            }
782        }
783
784        relations
785    }
786}
787
788/// Validate a schema string and return the validated schema.
789pub fn validate_schema(input: &str) -> SchemaResult<Schema> {
790    let schema = crate::parser::parse_schema(input)?;
791    let mut validator = Validator::new();
792    validator.validate(schema)
793}
794
#[cfg(test)]
mod tests {
    // Unit tests for schema validation. Each test feeds a schema source string
    // through `validate_schema` and checks the outcome.
    use super::*;

    #[test]
    fn test_validate_simple_model() {
        let schema = validate_schema(
            r#"
            model User {
                id    Int    @id @auto
                email String @unique
            }
        "#,
        )
        .unwrap();

        assert_eq!(schema.models.len(), 1);
    }

    #[test]
    fn test_validate_model_missing_id() {
        let result = validate_schema(
            r#"
            model User {
                email String
                name  String
            }
        "#,
        );

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, SchemaError::ValidationFailed { .. }));
    }

    #[test]
    fn test_validate_model_with_composite_id() {
        let schema = validate_schema(
            r#"
            model PostTag {
                post_id Int
                tag_id  Int

                @@id([post_id, tag_id])
            }
        "#,
        )
        .unwrap();

        assert_eq!(schema.models.len(), 1);
    }

    #[test]
    fn test_validate_unknown_type_reference() {
        let result = validate_schema(
            r#"
            model User {
                id      Int    @id @auto
                profile UnknownType
            }
        "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_enum_reference() {
        let schema = validate_schema(
            r#"
            enum Role {
                User
                Admin
            }

            model User {
                id   Int    @id @auto
                role Role   @default(User)
            }
        "#,
        )
        .unwrap();

        assert_eq!(schema.models.len(), 1);
        assert_eq!(schema.enums.len(), 1);
    }

    #[test]
    fn test_validate_invalid_enum_default() {
        let result = validate_schema(
            r#"
            enum Role {
                User
                Admin
            }

            model User {
                id   Int    @id @auto
                role Role   @default(Unknown)
            }
        "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_auto_on_non_int() {
        let result = validate_schema(
            r#"
            model User {
                id    String @id @auto
                email String
            }
        "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_updated_at_on_non_datetime() {
        let result = validate_schema(
            r#"
            model User {
                id         Int    @id @auto
                updated_at String @updated_at
            }
        "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_empty_enum() {
        let result = validate_schema(
            r#"
            enum Empty {
            }

            model User {
                id Int @id @auto
            }
        "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_duplicate_model_names() {
        let result = validate_schema(
            r#"
            model User {
                id Int @id @auto
            }

            model User {
                id Int @id @auto
            }
        "#,
        );

        // A repeated model name may be rejected (by the parser or the duplicate
        // check) or collapsed into a single entry; this test does not pin down
        // which. It must never produce two `User` models, though: `models` is
        // keyed by model name, so a successful result holds exactly one entry.
        // TODO(review): decide the intended behavior and assert `is_err()` if
        // duplicate definitions are meant to be rejected.
        if let Ok(schema) = &result {
            assert_eq!(schema.models.len(), 1);
            assert!(schema.models.contains_key("User"));
        }
    }

    #[test]
    fn test_validate_relation() {
        let schema = validate_schema(
            r#"
            model User {
                id    Int    @id @auto
                posts Post[]
            }

            model Post {
                id        Int    @id @auto
                author_id Int
                author    User   @relation(fields: [author_id], references: [id])
            }
        "#,
        )
        .unwrap();

        assert_eq!(schema.models.len(), 2);
        assert!(!schema.relations.is_empty());
    }

    #[test]
    fn test_validate_index_with_invalid_field() {
        let result = validate_schema(
            r#"
            model User {
                id    Int    @id @auto
                email String

                @@index([nonexistent])
            }
        "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_search_on_non_string_field() {
        let result = validate_schema(
            r#"
            model Post {
                id    Int    @id @auto
                views Int

                @@search([views])
            }
        "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_composite_type() {
        let schema = validate_schema(
            r#"
            type Address {
                street  String
                city    String
                country String @default("US")
            }

            model User {
                id      Int     @id @auto
                address Address
            }
        "#,
        );

        // Composite-type support depends on parser handling, so an error is
        // tolerated here. When validation succeeds, `Address` must be registered
        // as a composite type, and a field of that type must not be turned into
        // a relation (relations only target models).
        if let Ok(schema) = &schema {
            assert!(schema.types.contains_key("Address"));
            assert!(schema.relations.is_empty());
        }
    }

    // ==================== Server Group Validation Tests ====================

    #[test]
    fn test_validate_server_group_basic() {
        let schema = validate_schema(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup MainCluster {
                server primary {
                    url = "postgres://localhost/db"
                    role = "primary"
                }
            }
        "#,
        )
        .unwrap();

        assert_eq!(schema.server_groups.len(), 1);
    }

    #[test]
    fn test_validate_server_group_empty_servers() {
        let result = validate_schema(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup EmptyCluster {
            }
        "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_server_group_missing_url() {
        let result = validate_schema(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup Cluster {
                server db {
                    role = "primary"
                }
            }
        "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_server_group_invalid_strategy() {
        let result = validate_schema(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup Cluster {
                @@strategy(InvalidStrategy)

                server db {
                    url = "postgres://localhost/db"
                }
            }
        "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_server_group_valid_strategy() {
        let schema = validate_schema(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup Cluster {
                @@strategy(ReadReplica)
                @@loadBalance(RoundRobin)

                server primary {
                    url = "postgres://localhost/db"
                    role = "primary"
                }
            }
        "#,
        )
        .unwrap();

        assert_eq!(schema.server_groups.len(), 1);
    }

    #[test]
    fn test_validate_server_group_read_replica_needs_primary() {
        let result = validate_schema(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup Cluster {
                @@strategy(ReadReplica)

                server replica1 {
                    url = "postgres://localhost/db"
                    role = "replica"
                }
            }
        "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_server_group_with_replicas() {
        let schema = validate_schema(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup Cluster {
                @@strategy(ReadReplica)

                server primary {
                    url = "postgres://primary/db"
                    role = "primary"
                    weight = 1
                }

                server replica1 {
                    url = "postgres://replica1/db"
                    role = "replica"
                    weight = 2
                }

                server replica2 {
                    url = "postgres://replica2/db"
                    role = "replica"
                    weight = 2
                    region = "us-west-1"
                }
            }
        "#,
        )
        .unwrap();

        let cluster = schema.get_server_group("Cluster").unwrap();
        assert_eq!(cluster.servers.len(), 3);
    }

    #[test]
    fn test_validate_server_group_zero_weight() {
        let result = validate_schema(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup Cluster {
                server db {
                    url = "postgres://localhost/db"
                    weight = 0
                }
            }
        "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_server_group_invalid_load_balance() {
        let result = validate_schema(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup Cluster {
                @@loadBalance(InvalidStrategy)

                server db {
                    url = "postgres://localhost/db"
                }
            }
        "#,
        );

        assert!(result.is_err());
    }
}