Skip to main content

prax_schema/
validator.rs

1//! Schema validation and semantic analysis.
2//!
3//! This module validates parsed schemas for semantic correctness:
4//! - All type references are valid
5//! - Relations are properly defined
6//! - Required attributes are present
7//! - No duplicate definitions
8
9use crate::ast::*;
10use crate::error::{SchemaError, SchemaResult};
11
12/// Schema validator for semantic analysis.
13#[derive(Debug)]
14pub struct Validator {
15    /// Collected validation errors.
16    errors: Vec<SchemaError>,
17}
18
19impl Default for Validator {
20    fn default() -> Self {
21        Self::new()
22    }
23}
24
25impl Validator {
26    /// Create a new validator.
27    pub fn new() -> Self {
28        Self { errors: vec![] }
29    }
30
31    /// Validate a schema and return the validated schema or errors.
32    pub fn validate(&mut self, mut schema: Schema) -> SchemaResult<Schema> {
33        self.errors.clear();
34
35        // Check for duplicate definitions
36        self.check_duplicates(&schema);
37
38        // Resolve field types (convert Model references to Enum or Composite where appropriate)
39        self.resolve_field_types(&mut schema);
40
41        // Validate each model
42        for model in schema.models.values() {
43            self.validate_model(model, &schema);
44        }
45
46        // Validate each enum
47        for e in schema.enums.values() {
48            self.validate_enum(e);
49        }
50
51        // Validate each composite type
52        for t in schema.types.values() {
53            self.validate_composite_type(t, &schema);
54        }
55
56        // Validate each view
57        for v in schema.views.values() {
58            self.validate_view(v, &schema);
59        }
60
61        // Validate each server group
62        for sg in schema.server_groups.values() {
63            self.validate_server_group(sg);
64        }
65
66        // Resolve relations
67        let relations = self.resolve_relations(&schema);
68        schema.relations = relations;
69
70        if self.errors.is_empty() {
71            Ok(schema)
72        } else {
73            Err(SchemaError::ValidationFailed {
74                count: self.errors.len(),
75                errors: std::mem::take(&mut self.errors),
76            })
77        }
78    }
79
80    /// Check for duplicate model, enum, or type names.
81    fn check_duplicates(&mut self, schema: &Schema) {
82        let mut seen = std::collections::HashSet::new();
83
84        for name in schema.models.keys() {
85            if !seen.insert(name.as_str()) {
86                self.errors
87                    .push(SchemaError::duplicate("model", name.as_str()));
88            }
89        }
90
91        for name in schema.enums.keys() {
92            if !seen.insert(name.as_str()) {
93                self.errors
94                    .push(SchemaError::duplicate("enum", name.as_str()));
95            }
96        }
97
98        for name in schema.types.keys() {
99            if !seen.insert(name.as_str()) {
100                self.errors
101                    .push(SchemaError::duplicate("type", name.as_str()));
102            }
103        }
104
105        for name in schema.views.keys() {
106            if !seen.insert(name.as_str()) {
107                self.errors
108                    .push(SchemaError::duplicate("view", name.as_str()));
109            }
110        }
111
112        // Check server group names (separately, since they don't conflict with types)
113        let mut server_group_names = std::collections::HashSet::new();
114        for name in schema.server_groups.keys() {
115            if !server_group_names.insert(name.as_str()) {
116                self.errors
117                    .push(SchemaError::duplicate("serverGroup", name.as_str()));
118            }
119        }
120    }
121
122    /// Resolve field types to their correct types (Enum or Composite) instead of Model.
123    ///
124    /// The parser initially treats all non-scalar type references as Model references.
125    /// This pass corrects them to Enum or Composite where appropriate.
126    fn resolve_field_types(&self, schema: &mut Schema) {
127        // Collect enum and composite type names into owned strings to avoid borrow conflicts
128        let enum_names: std::collections::HashSet<String> =
129            schema.enums.keys().map(|s| s.to_string()).collect();
130        let composite_names: std::collections::HashSet<String> =
131            schema.types.keys().map(|s| s.to_string()).collect();
132
133        // Update field types in models
134        for model in schema.models.values_mut() {
135            for field in model.fields.values_mut() {
136                if let FieldType::Model(ref type_name) = field.field_type {
137                    let name = type_name.as_str();
138                    if enum_names.contains(name) {
139                        field.field_type = FieldType::Enum(type_name.clone());
140                    } else if composite_names.contains(name) {
141                        field.field_type = FieldType::Composite(type_name.clone());
142                    }
143                }
144            }
145        }
146
147        // Also update field types in composite types
148        for composite in schema.types.values_mut() {
149            for field in composite.fields.values_mut() {
150                if let FieldType::Model(ref type_name) = field.field_type {
151                    let name = type_name.as_str();
152                    if enum_names.contains(name) {
153                        field.field_type = FieldType::Enum(type_name.clone());
154                    } else if composite_names.contains(name) {
155                        field.field_type = FieldType::Composite(type_name.clone());
156                    }
157                }
158            }
159        }
160
161        // Also update field types in views
162        for view in schema.views.values_mut() {
163            for field in view.fields.values_mut() {
164                if let FieldType::Model(ref type_name) = field.field_type {
165                    let name = type_name.as_str();
166                    if enum_names.contains(name) {
167                        field.field_type = FieldType::Enum(type_name.clone());
168                    } else if composite_names.contains(name) {
169                        field.field_type = FieldType::Composite(type_name.clone());
170                    }
171                }
172            }
173        }
174    }
175
176    /// Validate a model definition.
177    fn validate_model(&mut self, model: &Model, schema: &Schema) {
178        // Check for @id field
179        let id_fields: Vec<_> = model.fields.values().filter(|f| f.is_id()).collect();
180        if id_fields.is_empty() && !self.has_composite_id(model) {
181            self.errors.push(SchemaError::MissingId {
182                model: model.name().to_string(),
183            });
184        }
185
186        // Validate each field
187        for field in model.fields.values() {
188            self.validate_field(field, model.name(), schema);
189        }
190
191        // Validate model attributes
192        for attr in &model.attributes {
193            self.validate_model_attribute(attr, model);
194        }
195    }
196
197    /// Check if model has a composite ID (@@id attribute).
198    fn has_composite_id(&self, model: &Model) -> bool {
199        model.attributes.iter().any(|a| a.is("id"))
200    }
201
202    /// Validate a field definition.
203    fn validate_field(&mut self, field: &Field, model_name: &str, schema: &Schema) {
204        // Validate type references
205        match &field.field_type {
206            FieldType::Model(name) => {
207                // Check if it's actually a model, enum, or composite type
208                if schema.models.contains_key(name.as_str()) {
209                    // Valid model reference
210                } else if schema.enums.contains_key(name.as_str()) {
211                    // Parser initially treats non-scalar types as Model references
212                    // This is actually an enum type - we'll handle this during resolution
213                } else if schema.types.contains_key(name.as_str()) {
214                    // This is a composite type
215                } else {
216                    self.errors.push(SchemaError::unknown_type(
217                        model_name,
218                        field.name(),
219                        name.as_str(),
220                    ));
221                }
222            }
223            FieldType::Enum(name) => {
224                if !schema.enums.contains_key(name.as_str()) {
225                    self.errors.push(SchemaError::unknown_type(
226                        model_name,
227                        field.name(),
228                        name.as_str(),
229                    ));
230                }
231            }
232            FieldType::Composite(name) => {
233                if !schema.types.contains_key(name.as_str()) {
234                    self.errors.push(SchemaError::unknown_type(
235                        model_name,
236                        field.name(),
237                        name.as_str(),
238                    ));
239                }
240            }
241            _ => {}
242        }
243
244        // Validate field attributes
245        for attr in &field.attributes {
246            self.validate_field_attribute(attr, field, model_name, schema);
247        }
248
249        // Validate relation fields have @relation or are back-references
250        // Only check actual model relations (not enums or composite types parsed as Model)
251        if let FieldType::Model(ref target_name) = field.field_type {
252            // Skip validation for enum and composite type references
253            let is_actual_relation = schema.models.contains_key(target_name.as_str())
254                && !schema.enums.contains_key(target_name.as_str())
255                && !schema.types.contains_key(target_name.as_str());
256
257            if is_actual_relation && !field.is_list() {
258                // One-side of relation should have foreign key fields
259                let attrs = field.extract_attributes();
260                if attrs.relation.is_some() {
261                    let rel = attrs.relation.as_ref().unwrap();
262                    // Validate foreign key fields exist
263                    for fk_field in &rel.fields {
264                        if !schema
265                            .models
266                            .get(model_name)
267                            .map(|m| m.fields.contains_key(fk_field.as_str()))
268                            .unwrap_or(false)
269                        {
270                            self.errors.push(SchemaError::invalid_relation(
271                                model_name,
272                                field.name(),
273                                format!("foreign key field '{}' does not exist", fk_field),
274                            ));
275                        }
276                    }
277                }
278            }
279        }
280    }
281
282    /// Validate a field attribute.
283    fn validate_field_attribute(
284        &mut self,
285        attr: &Attribute,
286        field: &Field,
287        model_name: &str,
288        schema: &Schema,
289    ) {
290        match attr.name() {
291            "id" => {
292                // @id should be on a scalar or composite type, not a relation
293                if field.field_type.is_relation() {
294                    self.errors.push(SchemaError::InvalidAttribute {
295                        attribute: "id".to_string(),
296                        message: format!(
297                            "@id cannot be applied to relation field '{}.{}'",
298                            model_name,
299                            field.name()
300                        ),
301                    });
302                }
303            }
304            "auto" => {
305                // @auto should only be on Int or BigInt
306                if !matches!(
307                    field.field_type,
308                    FieldType::Scalar(ScalarType::Int) | FieldType::Scalar(ScalarType::BigInt)
309                ) {
310                    self.errors.push(SchemaError::InvalidAttribute {
311                        attribute: "auto".to_string(),
312                        message: format!(
313                            "@auto can only be applied to Int or BigInt fields, not '{}.{}'",
314                            model_name,
315                            field.name()
316                        ),
317                    });
318                }
319            }
320            "default" => {
321                // Validate default value type matches field type
322                if let Some(value) = attr.first_arg() {
323                    self.validate_default_value(value, field, model_name, schema);
324                }
325            }
326            "relation" => {
327                // Validate relation attribute - should only be on actual model references
328                let is_model_ref = matches!(&field.field_type, FieldType::Model(name)
329                    if schema.models.contains_key(name.as_str()));
330                if !is_model_ref {
331                    self.errors.push(SchemaError::InvalidAttribute {
332                        attribute: "relation".to_string(),
333                        message: format!(
334                            "@relation can only be applied to model reference fields, not '{}.{}'",
335                            model_name,
336                            field.name()
337                        ),
338                    });
339                }
340            }
341            "updated_at" => {
342                // @updated_at should only be on DateTime
343                if !matches!(field.field_type, FieldType::Scalar(ScalarType::DateTime)) {
344                    self.errors.push(SchemaError::InvalidAttribute {
345                        attribute: "updated_at".to_string(),
346                        message: format!(
347                            "@updated_at can only be applied to DateTime fields, not '{}.{}'",
348                            model_name,
349                            field.name()
350                        ),
351                    });
352                }
353            }
354            _ => {}
355        }
356    }
357
358    /// Validate a default value matches the field type.
359    fn validate_default_value(
360        &mut self,
361        value: &AttributeValue,
362        field: &Field,
363        model_name: &str,
364        schema: &Schema,
365    ) {
366        match (&field.field_type, value) {
367            // Functions are generally allowed (now(), uuid(), etc.)
368            (_, AttributeValue::Function(_, _)) => {}
369
370            // Int fields should have int defaults
371            (FieldType::Scalar(ScalarType::Int), AttributeValue::Int(_)) => {}
372            (FieldType::Scalar(ScalarType::BigInt), AttributeValue::Int(_)) => {}
373
374            // Float fields can have int or float defaults
375            (FieldType::Scalar(ScalarType::Float), AttributeValue::Int(_)) => {}
376            (FieldType::Scalar(ScalarType::Float), AttributeValue::Float(_)) => {}
377            (FieldType::Scalar(ScalarType::Decimal), AttributeValue::Int(_)) => {}
378            (FieldType::Scalar(ScalarType::Decimal), AttributeValue::Float(_)) => {}
379
380            // String fields should have string defaults
381            (FieldType::Scalar(ScalarType::String), AttributeValue::String(_)) => {}
382
383            // Json fields accept any constant as the @default — the payload is
384            // stored as a text literal that the database parses into jsonb.
385            // Prisma writes empty objects/arrays as `@default("[]")` or
386            // `@default("{}")`, so accept string, array, and scalar primitives.
387            (FieldType::Scalar(ScalarType::Json), AttributeValue::String(_))
388            | (FieldType::Scalar(ScalarType::Json), AttributeValue::Array(_))
389            | (FieldType::Scalar(ScalarType::Json), AttributeValue::Boolean(_))
390            | (FieldType::Scalar(ScalarType::Json), AttributeValue::Int(_))
391            | (FieldType::Scalar(ScalarType::Json), AttributeValue::Float(_)) => {}
392
393            // Boolean fields should have boolean defaults
394            (FieldType::Scalar(ScalarType::Boolean), AttributeValue::Boolean(_)) => {}
395
396            // Enum fields should have ident defaults matching a variant
397            (FieldType::Enum(enum_name), AttributeValue::Ident(variant)) => {
398                if let Some(e) = schema.enums.get(enum_name.as_str()) {
399                    if e.get_variant(variant).is_none() {
400                        self.errors.push(SchemaError::invalid_field(
401                            model_name,
402                            field.name(),
403                            format!(
404                                "default value '{}' is not a valid variant of enum '{}'",
405                                variant, enum_name
406                            ),
407                        ));
408                    }
409                }
410            }
411
412            // Model type might actually be an enum (parser treats non-scalar as Model initially)
413            (FieldType::Model(type_name), AttributeValue::Ident(variant)) => {
414                // Check if this is actually an enum reference
415                if let Some(e) = schema.enums.get(type_name.as_str()) {
416                    if e.get_variant(variant).is_none() {
417                        self.errors.push(SchemaError::invalid_field(
418                            model_name,
419                            field.name(),
420                            format!(
421                                "default value '{}' is not a valid variant of enum '{}'",
422                                variant, type_name
423                            ),
424                        ));
425                    }
426                }
427                // If it's a real model reference with an ident default, that's an error
428                // but we skip that here since it's likely a valid enum
429            }
430
431            // Type mismatch
432            _ => {
433                self.errors.push(SchemaError::invalid_field(
434                    model_name,
435                    field.name(),
436                    format!(
437                        "default value type does not match field type '{}'",
438                        field.field_type
439                    ),
440                ));
441            }
442        }
443    }
444
445    /// Validate a model-level attribute.
446    fn validate_model_attribute(&mut self, attr: &Attribute, model: &Model) {
447        match attr.name() {
448            "index" | "unique" => {
449                // Validate referenced fields exist
450                if let Some(AttributeValue::FieldRefList(fields)) = attr.first_arg() {
451                    for field_name in fields {
452                        if !model.fields.contains_key(field_name.as_str()) {
453                            self.errors.push(SchemaError::invalid_model(
454                                model.name(),
455                                format!(
456                                    "@@{} references non-existent field '{}'",
457                                    attr.name(),
458                                    field_name
459                                ),
460                            ));
461                        }
462                    }
463                }
464            }
465            "id" => {
466                // Composite primary key
467                if let Some(AttributeValue::FieldRefList(fields)) = attr.first_arg() {
468                    for field_name in fields {
469                        if !model.fields.contains_key(field_name.as_str()) {
470                            self.errors.push(SchemaError::invalid_model(
471                                model.name(),
472                                format!("@@id references non-existent field '{}'", field_name),
473                            ));
474                        }
475                    }
476                }
477            }
478            "search" => {
479                // Full-text search on fields
480                if let Some(AttributeValue::FieldRefList(fields)) = attr.first_arg() {
481                    for field_name in fields {
482                        if let Some(field) = model.fields.get(field_name.as_str()) {
483                            // Only string fields can be searched
484                            if !matches!(field.field_type, FieldType::Scalar(ScalarType::String)) {
485                                self.errors.push(SchemaError::invalid_model(
486                                    model.name(),
487                                    format!(
488                                        "@@search field '{}' must be of type String",
489                                        field_name
490                                    ),
491                                ));
492                            }
493                        } else {
494                            self.errors.push(SchemaError::invalid_model(
495                                model.name(),
496                                format!("@@search references non-existent field '{}'", field_name),
497                            ));
498                        }
499                    }
500                }
501            }
502            _ => {}
503        }
504    }
505
506    /// Validate an enum definition.
507    fn validate_enum(&mut self, e: &Enum) {
508        if e.variants.is_empty() {
509            self.errors.push(SchemaError::invalid_model(
510                e.name(),
511                "enum must have at least one variant".to_string(),
512            ));
513        }
514
515        // Check for duplicate variant names
516        let mut seen = std::collections::HashSet::new();
517        for variant in &e.variants {
518            if !seen.insert(variant.name()) {
519                self.errors.push(SchemaError::duplicate(
520                    format!("enum variant in {}", e.name()),
521                    variant.name(),
522                ));
523            }
524        }
525    }
526
527    /// Validate a composite type definition.
528    fn validate_composite_type(&mut self, t: &CompositeType, schema: &Schema) {
529        if t.fields.is_empty() {
530            self.errors.push(SchemaError::invalid_model(
531                t.name(),
532                "composite type must have at least one field".to_string(),
533            ));
534        }
535
536        // Validate field types
537        for field in t.fields.values() {
538            match &field.field_type {
539                FieldType::Model(_) => {
540                    self.errors.push(SchemaError::invalid_field(
541                        t.name(),
542                        field.name(),
543                        "composite types cannot have model relations".to_string(),
544                    ));
545                }
546                FieldType::Enum(name) => {
547                    if !schema.enums.contains_key(name.as_str()) {
548                        self.errors.push(SchemaError::unknown_type(
549                            t.name(),
550                            field.name(),
551                            name.as_str(),
552                        ));
553                    }
554                }
555                FieldType::Composite(name) => {
556                    if !schema.types.contains_key(name.as_str()) {
557                        self.errors.push(SchemaError::unknown_type(
558                            t.name(),
559                            field.name(),
560                            name.as_str(),
561                        ));
562                    }
563                }
564                _ => {}
565            }
566        }
567    }
568
569    /// Validate a view definition.
570    fn validate_view(&mut self, v: &View, schema: &Schema) {
571        // Views should have at least one field
572        if v.fields.is_empty() {
573            self.errors.push(SchemaError::invalid_model(
574                v.name(),
575                "view must have at least one field".to_string(),
576            ));
577        }
578
579        // Validate field types
580        for field in v.fields.values() {
581            self.validate_field(field, v.name(), schema);
582        }
583    }
584
585    /// Validate a server group definition.
586    fn validate_server_group(&mut self, sg: &ServerGroup) {
587        // Server groups should have at least one server
588        if sg.servers.is_empty() {
589            self.errors.push(SchemaError::invalid_model(
590                sg.name.name.as_str(),
591                "serverGroup must have at least one server".to_string(),
592            ));
593        }
594
595        // Check for duplicate server names within the group
596        let mut seen_servers = std::collections::HashSet::new();
597        for server_name in sg.servers.keys() {
598            if !seen_servers.insert(server_name.as_str()) {
599                self.errors.push(SchemaError::duplicate(
600                    format!("server in serverGroup {}", sg.name.name),
601                    server_name.as_str(),
602                ));
603            }
604        }
605
606        // Validate each server
607        for server in sg.servers.values() {
608            self.validate_server(server, sg.name.name.as_str());
609        }
610
611        // Validate server group attributes
612        for attr in &sg.attributes {
613            self.validate_server_group_attribute(attr, sg);
614        }
615
616        // Check for at least one primary server in read replica strategy
617        if let Some(strategy) = sg.strategy() {
618            if strategy == ServerGroupStrategy::ReadReplica {
619                let has_primary = sg
620                    .servers
621                    .values()
622                    .any(|s| s.role() == Some(ServerRole::Primary));
623                if !has_primary {
624                    self.errors.push(SchemaError::invalid_model(
625                        sg.name.name.as_str(),
626                        "ReadReplica strategy requires at least one server with role = \"primary\""
627                            .to_string(),
628                    ));
629                }
630            }
631        }
632    }
633
634    /// Validate an individual server definition.
635    fn validate_server(&mut self, server: &Server, group_name: &str) {
636        // Server should have a URL property
637        if server.url().is_none() {
638            self.errors.push(SchemaError::invalid_model(
639                group_name,
640                format!("server '{}' must have a 'url' property", server.name.name),
641            ));
642        }
643
644        // Validate weight is positive if specified
645        if let Some(weight) = server.weight() {
646            if weight == 0 {
647                self.errors.push(SchemaError::invalid_model(
648                    group_name,
649                    format!(
650                        "server '{}' weight must be greater than 0",
651                        server.name.name
652                    ),
653                ));
654            }
655        }
656
657        // Validate priority is positive if specified
658        if let Some(priority) = server.priority() {
659            if priority == 0 {
660                self.errors.push(SchemaError::invalid_model(
661                    group_name,
662                    format!(
663                        "server '{}' priority must be greater than 0",
664                        server.name.name
665                    ),
666                ));
667            }
668        }
669    }
670
671    /// Validate a server group attribute.
672    fn validate_server_group_attribute(&mut self, attr: &Attribute, sg: &ServerGroup) {
673        match attr.name() {
674            "strategy" => {
675                // Validate strategy value
676                if let Some(arg) = attr.first_arg() {
677                    let value_str = arg
678                        .as_string()
679                        .map(|s| s.to_string())
680                        .or_else(|| arg.as_ident().map(|s| s.to_string()));
681                    if let Some(val) = value_str {
682                        if ServerGroupStrategy::parse(&val).is_none() {
683                            self.errors.push(SchemaError::InvalidAttribute {
684                                attribute: "strategy".to_string(),
685                                message: format!(
686                                    "invalid strategy '{}' for serverGroup '{}'. Valid values: ReadReplica, Sharding, MultiRegion, HighAvailability, Custom",
687                                    val,
688                                    sg.name.name
689                                ),
690                            });
691                        }
692                    }
693                }
694            }
695            "loadBalance" => {
696                // Validate load balance value
697                if let Some(arg) = attr.first_arg() {
698                    let value_str = arg
699                        .as_string()
700                        .map(|s| s.to_string())
701                        .or_else(|| arg.as_ident().map(|s| s.to_string()));
702                    if let Some(val) = value_str {
703                        if LoadBalanceStrategy::parse(&val).is_none() {
704                            self.errors.push(SchemaError::InvalidAttribute {
705                                attribute: "loadBalance".to_string(),
706                                message: format!(
707                                    "invalid loadBalance '{}' for serverGroup '{}'. Valid values: RoundRobin, Random, LeastConnections, Weighted, Nearest, Sticky",
708                                    val,
709                                    sg.name.name
710                                ),
711                            });
712                        }
713                    }
714                }
715            }
716            _ => {} // Other attributes are allowed
717        }
718    }
719
720    /// Resolve all relations in the schema.
721    fn resolve_relations(&mut self, schema: &Schema) -> Vec<Relation> {
722        let mut relations = Vec::new();
723
724        for model in schema.models.values() {
725            for field in model.fields.values() {
726                if let FieldType::Model(ref target_model) = field.field_type {
727                    // Skip if this is actually an enum reference (parser treats non-scalar as Model initially)
728                    if schema.enums.contains_key(target_model.as_str()) {
729                        continue;
730                    }
731
732                    // Skip if this is actually a composite type reference
733                    if schema.types.contains_key(target_model.as_str()) {
734                        continue;
735                    }
736
737                    // Skip if the target model doesn't exist (error was already reported)
738                    if !schema.models.contains_key(target_model.as_str()) {
739                        continue;
740                    }
741
742                    let attrs = field.extract_attributes();
743
744                    let relation_type = if field.is_list() {
745                        // This model has many of target
746                        RelationType::OneToMany
747                    } else {
748                        // This model has one of target
749                        RelationType::ManyToOne
750                    };
751
752                    let mut relation = Relation::new(
753                        model.name(),
754                        field.name(),
755                        target_model.as_str(),
756                        relation_type,
757                    );
758
759                    if let Some(rel_attr) = &attrs.relation {
760                        if let Some(name) = &rel_attr.name {
761                            relation = relation.with_name(name.as_str());
762                        }
763                        if !rel_attr.fields.is_empty() {
764                            relation = relation.with_from_fields(rel_attr.fields.clone());
765                        }
766                        if !rel_attr.references.is_empty() {
767                            relation = relation.with_to_fields(rel_attr.references.clone());
768                        }
769                        if let Some(action) = rel_attr.on_delete {
770                            relation = relation.with_on_delete(action);
771                        }
772                        if let Some(action) = rel_attr.on_update {
773                            relation = relation.with_on_update(action);
774                        }
775                        if let Some(map) = &rel_attr.map {
776                            relation = relation.with_map(map.as_str());
777                        }
778                    }
779
780                    relations.push(relation);
781                }
782            }
783        }
784
785        relations
786    }
787}
788
789/// Validate a schema string and return the validated schema.
790pub fn validate_schema(input: &str) -> SchemaResult<Schema> {
791    let schema = crate::parser::parse_schema(input)?;
792    let mut validator = Validator::new();
793    validator.validate(schema)
794}
795
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_validate_simple_model() {
        let schema = validate_schema(
            r#"
            model User {
                id    Int    @id @auto
                email String @unique
            }
        "#,
        )
        .unwrap();

        assert_eq!(schema.models.len(), 1);
    }

    #[test]
    fn test_validate_model_missing_id() {
        let result = validate_schema(
            r#"
            model User {
                email String
                name  String
            }
        "#,
        );

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, SchemaError::ValidationFailed { .. }));
    }

    #[test]
    fn test_validate_model_with_composite_id() {
        let schema = validate_schema(
            r#"
            model PostTag {
                post_id Int
                tag_id  Int

                @@id([post_id, tag_id])
            }
        "#,
        )
        .unwrap();

        assert_eq!(schema.models.len(), 1);
    }

    #[test]
    fn test_validate_unknown_type_reference() {
        let result = validate_schema(
            r#"
            model User {
                id      Int    @id @auto
                profile UnknownType
            }
        "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_enum_reference() {
        let schema = validate_schema(
            r#"
            enum Role {
                User
                Admin
            }

            model User {
                id   Int    @id @auto
                role Role   @default(User)
            }
        "#,
        )
        .unwrap();

        assert_eq!(schema.models.len(), 1);
        assert_eq!(schema.enums.len(), 1);
    }

    #[test]
    fn test_validate_invalid_enum_default() {
        let result = validate_schema(
            r#"
            enum Role {
                User
                Admin
            }

            model User {
                id   Int    @id @auto
                role Role   @default(Unknown)
            }
        "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_auto_on_non_int() {
        let result = validate_schema(
            r#"
            model User {
                id    String @id @auto
                email String
            }
        "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_updated_at_on_non_datetime() {
        let result = validate_schema(
            r#"
            model User {
                id         Int    @id @auto
                updated_at String @updated_at
            }
        "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_empty_enum() {
        let result = validate_schema(
            r#"
            enum Empty {
            }

            model User {
                id Int @id @auto
            }
        "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_duplicate_model_names() {
        let result = validate_schema(
            r#"
            model User {
                id Int @id @auto
            }

            model User {
                id Int @id @auto
            }
        "#,
        );

        // Depending on the grammar, the duplicate `User` may be rejected
        // (by the parser or by the duplicate check) or collapsed into a
        // single entry. Either way, validation must never hand back a schema
        // containing two `User` models.
        if let Ok(schema) = result {
            assert_eq!(schema.models.len(), 1);
        }
    }

    #[test]
    fn test_validate_relation() {
        let schema = validate_schema(
            r#"
            model User {
                id    Int    @id @auto
                posts Post[]
            }

            model Post {
                id        Int    @id @auto
                author_id Int
                author    User   @relation(fields: [author_id], references: [id])
            }
        "#,
        )
        .unwrap();

        assert_eq!(schema.models.len(), 2);
        assert!(!schema.relations.is_empty());
    }

    #[test]
    fn test_validate_index_with_invalid_field() {
        let result = validate_schema(
            r#"
            model User {
                id    Int    @id @auto
                email String

                @@index([nonexistent])
            }
        "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_search_on_non_string_field() {
        let result = validate_schema(
            r#"
            model Post {
                id    Int    @id @auto
                views Int

                @@search([views])
            }
        "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_composite_type() {
        let result = validate_schema(
            r#"
            type Address {
                street  String
                city    String
                country String @default("US")
            }

            model User {
                id      Int     @id @auto
                address Address
            }
        "#,
        );

        // Composite-type support depends on parser handling. If the schema
        // is accepted, both the composite type and the model that references
        // it must be present. (Otherwise `address Address` would have been an
        // unknown type reference and validation would have failed.)
        if let Ok(schema) = result {
            assert_eq!(schema.types.len(), 1);
            assert_eq!(schema.models.len(), 1);
        }
    }

    // ==================== Server Group Validation Tests ====================

    #[test]
    fn test_validate_server_group_basic() {
        let schema = validate_schema(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup MainCluster {
                server primary {
                    url = "postgres://localhost/db"
                    role = "primary"
                }
            }
        "#,
        )
        .unwrap();

        assert_eq!(schema.server_groups.len(), 1);
    }

    #[test]
    fn test_validate_server_group_empty_servers() {
        let result = validate_schema(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup EmptyCluster {
            }
        "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_server_group_missing_url() {
        let result = validate_schema(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup Cluster {
                server db {
                    role = "primary"
                }
            }
        "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_server_group_invalid_strategy() {
        let result = validate_schema(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup Cluster {
                @@strategy(InvalidStrategy)

                server db {
                    url = "postgres://localhost/db"
                }
            }
        "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_server_group_valid_strategy() {
        let schema = validate_schema(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup Cluster {
                @@strategy(ReadReplica)
                @@loadBalance(RoundRobin)

                server primary {
                    url = "postgres://localhost/db"
                    role = "primary"
                }
            }
        "#,
        )
        .unwrap();

        assert_eq!(schema.server_groups.len(), 1);
    }

    #[test]
    fn test_validate_server_group_read_replica_needs_primary() {
        let result = validate_schema(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup Cluster {
                @@strategy(ReadReplica)

                server replica1 {
                    url = "postgres://localhost/db"
                    role = "replica"
                }
            }
        "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_server_group_with_replicas() {
        let schema = validate_schema(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup Cluster {
                @@strategy(ReadReplica)

                server primary {
                    url = "postgres://primary/db"
                    role = "primary"
                    weight = 1
                }

                server replica1 {
                    url = "postgres://replica1/db"
                    role = "replica"
                    weight = 2
                }

                server replica2 {
                    url = "postgres://replica2/db"
                    role = "replica"
                    weight = 2
                    region = "us-west-1"
                }
            }
        "#,
        )
        .unwrap();

        let cluster = schema.get_server_group("Cluster").unwrap();
        assert_eq!(cluster.servers.len(), 3);
    }

    #[test]
    fn test_validate_server_group_zero_weight() {
        let result = validate_schema(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup Cluster {
                server db {
                    url = "postgres://localhost/db"
                    weight = 0
                }
            }
        "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_server_group_invalid_load_balance() {
        let result = validate_schema(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup Cluster {
                @@loadBalance(InvalidStrategy)

                server db {
                    url = "postgres://localhost/db"
                }
            }
        "#,
        );

        assert!(result.is_err());
    }
}