Skip to main content

prax_schema/
validator.rs

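//! Schema validation and semantic analysis.
//!
//! This module validates parsed schemas for semantic correctness:
//! - All type references resolve to a declared model, enum, or composite type
//! - Relations are properly defined (foreign-key fields named in `@relation` exist)
//! - Required attributes are present (every model has an `@id` field or `@@id`)
//! - No duplicate definitions among models, enums, composite types, and views
//!   (server group names are checked in their own namespace)
//! - Attributes are applied to compatible fields, and `@default` values match
//!   the field type
//! - `@map` / `@@map` names are safe to splice into SQL as identifiers
//! - Enums, composite types, views, and server groups are well formed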
8
9use crate::ast::*;
10use crate::error::{SchemaError, SchemaResult};
11
12/// Schema validator for semantic analysis.
13#[derive(Debug)]
14pub struct Validator {
15    /// Collected validation errors.
16    errors: Vec<SchemaError>,
17}
18
19impl Default for Validator {
20    fn default() -> Self {
21        Self::new()
22    }
23}
24
25impl Validator {
26    /// Create a new validator.
27    pub fn new() -> Self {
28        Self { errors: vec![] }
29    }
30
31    /// Validate a schema and return the validated schema or errors.
32    pub fn validate(&mut self, mut schema: Schema) -> SchemaResult<Schema> {
33        self.errors.clear();
34
35        // Check for duplicate definitions
36        self.check_duplicates(&schema);
37
38        // Resolve field types (convert Model references to Enum or Composite where appropriate)
39        self.resolve_field_types(&mut schema);
40
41        // Validate each model
42        for model in schema.models.values() {
43            self.validate_model(model, &schema);
44        }
45
46        // Validate computed / virtual field combinations
47        self.validate_computed_fields(&schema);
48
49        // Validate each enum
50        for e in schema.enums.values() {
51            self.validate_enum(e);
52        }
53
54        // Validate each composite type
55        for t in schema.types.values() {
56            self.validate_composite_type(t, &schema);
57        }
58
59        // Validate each view
60        for v in schema.views.values() {
61            self.validate_view(v, &schema);
62        }
63
64        // Validate each server group
65        for sg in schema.server_groups.values() {
66            self.validate_server_group(sg);
67        }
68
69        // Resolve relations
70        let relations = self.resolve_relations(&schema);
71        schema.relations = relations;
72
73        if self.errors.is_empty() {
74            Ok(schema)
75        } else {
76            Err(SchemaError::ValidationFailed {
77                count: self.errors.len(),
78                errors: std::mem::take(&mut self.errors),
79            })
80        }
81    }
82
83    /// Check for duplicate model, enum, or type names.
84    fn check_duplicates(&mut self, schema: &Schema) {
85        let mut seen = std::collections::HashSet::new();
86
87        for name in schema.models.keys() {
88            if !seen.insert(name.as_str()) {
89                self.errors
90                    .push(SchemaError::duplicate("model", name.as_str()));
91            }
92        }
93
94        for name in schema.enums.keys() {
95            if !seen.insert(name.as_str()) {
96                self.errors
97                    .push(SchemaError::duplicate("enum", name.as_str()));
98            }
99        }
100
101        for name in schema.types.keys() {
102            if !seen.insert(name.as_str()) {
103                self.errors
104                    .push(SchemaError::duplicate("type", name.as_str()));
105            }
106        }
107
108        for name in schema.views.keys() {
109            if !seen.insert(name.as_str()) {
110                self.errors
111                    .push(SchemaError::duplicate("view", name.as_str()));
112            }
113        }
114
115        // Check server group names (separately, since they don't conflict with types)
116        let mut server_group_names = std::collections::HashSet::new();
117        for name in schema.server_groups.keys() {
118            if !server_group_names.insert(name.as_str()) {
119                self.errors
120                    .push(SchemaError::duplicate("serverGroup", name.as_str()));
121            }
122        }
123    }
124
125    /// Resolve field types to their correct types (Enum or Composite) instead of Model.
126    ///
127    /// The parser initially treats all non-scalar type references as Model references.
128    /// This pass corrects them to Enum or Composite where appropriate.
129    fn resolve_field_types(&self, schema: &mut Schema) {
130        // Collect enum and composite type names into owned strings to avoid borrow conflicts
131        let enum_names: std::collections::HashSet<String> =
132            schema.enums.keys().map(|s| s.to_string()).collect();
133        let composite_names: std::collections::HashSet<String> =
134            schema.types.keys().map(|s| s.to_string()).collect();
135
136        // Update field types in models
137        for model in schema.models.values_mut() {
138            for field in model.fields.values_mut() {
139                if let FieldType::Model(ref type_name) = field.field_type {
140                    let name = type_name.as_str();
141                    if enum_names.contains(name) {
142                        field.field_type = FieldType::Enum(type_name.clone());
143                    } else if composite_names.contains(name) {
144                        field.field_type = FieldType::Composite(type_name.clone());
145                    }
146                }
147            }
148        }
149
150        // Also update field types in composite types
151        for composite in schema.types.values_mut() {
152            for field in composite.fields.values_mut() {
153                if let FieldType::Model(ref type_name) = field.field_type {
154                    let name = type_name.as_str();
155                    if enum_names.contains(name) {
156                        field.field_type = FieldType::Enum(type_name.clone());
157                    } else if composite_names.contains(name) {
158                        field.field_type = FieldType::Composite(type_name.clone());
159                    }
160                }
161            }
162        }
163
164        // Also update field types in views
165        for view in schema.views.values_mut() {
166            for field in view.fields.values_mut() {
167                if let FieldType::Model(ref type_name) = field.field_type {
168                    let name = type_name.as_str();
169                    if enum_names.contains(name) {
170                        field.field_type = FieldType::Enum(type_name.clone());
171                    } else if composite_names.contains(name) {
172                        field.field_type = FieldType::Composite(type_name.clone());
173                    }
174                }
175            }
176        }
177    }
178
179    /// Validate a model definition.
180    fn validate_model(&mut self, model: &Model, schema: &Schema) {
181        // Check for @id field
182        let id_fields: Vec<_> = model.fields.values().filter(|f| f.is_id()).collect();
183        if id_fields.is_empty() && !self.has_composite_id(model) {
184            self.errors.push(SchemaError::MissingId {
185                model: model.name().to_string(),
186            });
187        }
188
189        // Validate each field
190        for field in model.fields.values() {
191            self.validate_field(field, model.name(), schema);
192        }
193
194        // Validate model attributes
195        for attr in &model.attributes {
196            self.validate_model_attribute(attr, model);
197        }
198    }
199
200    /// Check if model has a composite ID (@@id attribute).
201    fn has_composite_id(&self, model: &Model) -> bool {
202        model.attributes.iter().any(|a| a.is("id"))
203    }
204
205    /// Validate a field definition.
206    fn validate_field(&mut self, field: &Field, model_name: &str, schema: &Schema) {
207        // Validate type references
208        match &field.field_type {
209            FieldType::Model(name) => {
210                // Check if it's actually a model, enum, or composite type
211                if schema.models.contains_key(name.as_str()) {
212                    // Valid model reference
213                } else if schema.enums.contains_key(name.as_str()) {
214                    // Parser initially treats non-scalar types as Model references
215                    // This is actually an enum type - we'll handle this during resolution
216                } else if schema.types.contains_key(name.as_str()) {
217                    // This is a composite type
218                } else {
219                    self.errors.push(SchemaError::unknown_type(
220                        model_name,
221                        field.name(),
222                        name.as_str(),
223                    ));
224                }
225            }
226            FieldType::Enum(name) => {
227                if !schema.enums.contains_key(name.as_str()) {
228                    self.errors.push(SchemaError::unknown_type(
229                        model_name,
230                        field.name(),
231                        name.as_str(),
232                    ));
233                }
234            }
235            FieldType::Composite(name) => {
236                if !schema.types.contains_key(name.as_str()) {
237                    self.errors.push(SchemaError::unknown_type(
238                        model_name,
239                        field.name(),
240                        name.as_str(),
241                    ));
242                }
243            }
244            _ => {}
245        }
246
247        // Validate field attributes
248        for attr in &field.attributes {
249            self.validate_field_attribute(attr, field, model_name, schema);
250        }
251
252        // Validate relation fields have @relation or are back-references
253        // Only check actual model relations (not enums or composite types parsed as Model)
254        if let FieldType::Model(ref target_name) = field.field_type {
255            // Skip validation for enum and composite type references
256            let is_actual_relation = schema.models.contains_key(target_name.as_str())
257                && !schema.enums.contains_key(target_name.as_str())
258                && !schema.types.contains_key(target_name.as_str());
259
260            if is_actual_relation && !field.is_list() {
261                // One-side of relation should have foreign key fields
262                let attrs = field.extract_attributes();
263                if let Some(rel) = attrs.relation.as_ref() {
264                    // Validate foreign key fields exist
265                    for fk_field in &rel.fields {
266                        if !schema
267                            .models
268                            .get(model_name)
269                            .map(|m| m.fields.contains_key(fk_field.as_str()))
270                            .unwrap_or(false)
271                        {
272                            self.errors.push(SchemaError::invalid_relation(
273                                model_name,
274                                field.name(),
275                                format!("foreign key field '{}' does not exist", fk_field),
276                            ));
277                        }
278                    }
279                }
280            }
281        }
282    }
283
284    /// Validate a field attribute.
285    fn validate_field_attribute(
286        &mut self,
287        attr: &Attribute,
288        field: &Field,
289        model_name: &str,
290        schema: &Schema,
291    ) {
292        match attr.name() {
293            "id" => {
294                // @id should be on a scalar or composite type, not a relation
295                if field.field_type.is_relation() {
296                    self.errors.push(SchemaError::InvalidAttribute {
297                        attribute: "id".to_string(),
298                        message: format!(
299                            "@id cannot be applied to relation field '{}.{}'",
300                            model_name,
301                            field.name()
302                        ),
303                    });
304                }
305            }
306            "auto" => {
307                // @auto should only be on Int or BigInt
308                if !matches!(
309                    field.field_type,
310                    FieldType::Scalar(ScalarType::Int) | FieldType::Scalar(ScalarType::BigInt)
311                ) {
312                    self.errors.push(SchemaError::InvalidAttribute {
313                        attribute: "auto".to_string(),
314                        message: format!(
315                            "@auto can only be applied to Int or BigInt fields, not '{}.{}'",
316                            model_name,
317                            field.name()
318                        ),
319                    });
320                }
321            }
322            "default" => {
323                // Validate default value type matches field type
324                if let Some(value) = attr.first_arg() {
325                    self.validate_default_value(value, field, model_name, schema);
326                }
327            }
328            "relation" => {
329                // Validate relation attribute - should only be on actual model references
330                let is_model_ref = matches!(&field.field_type, FieldType::Model(name)
331                    if schema.models.contains_key(name.as_str()));
332                if !is_model_ref {
333                    self.errors.push(SchemaError::InvalidAttribute {
334                        attribute: "relation".to_string(),
335                        message: format!(
336                            "@relation can only be applied to model reference fields, not '{}.{}'",
337                            model_name,
338                            field.name()
339                        ),
340                    });
341                }
342            }
343            "updated_at" => {
344                // @updated_at should only be on DateTime
345                if !matches!(field.field_type, FieldType::Scalar(ScalarType::DateTime)) {
346                    self.errors.push(SchemaError::InvalidAttribute {
347                        attribute: "updated_at".to_string(),
348                        message: format!(
349                            "@updated_at can only be applied to DateTime fields, not '{}.{}'",
350                            model_name,
351                            field.name()
352                        ),
353                    });
354                }
355            }
356            "map" => {
357                // @map("col") rewrites the SQL column name. Identifiers must
358                // be safe for direct splicing into queries — see the
359                // `.cursor/rules/sql-safety.mdc` trust-boundary subsection.
360                if let Some(AttributeValue::String(name)) = attr.first_arg()
361                    && !is_safe_sql_identifier(name)
362                {
363                    self.errors.push(SchemaError::invalid_field(
364                        model_name,
365                        field.name(),
366                        format!(
367                            "@map(\"{}\") contains characters outside [A-Za-z0-9_.]; \
368                             SQL identifiers must be safe to splice into queries",
369                            name
370                        ),
371                    ));
372                }
373            }
374            _ => {}
375        }
376    }
377
378    /// Validate a default value matches the field type.
379    fn validate_default_value(
380        &mut self,
381        value: &AttributeValue,
382        field: &Field,
383        model_name: &str,
384        schema: &Schema,
385    ) {
386        match (&field.field_type, value) {
387            // Functions are generally allowed (now(), uuid(), etc.)
388            (_, AttributeValue::Function(_, _)) => {}
389
390            // Int fields should have int defaults
391            (FieldType::Scalar(ScalarType::Int), AttributeValue::Int(_)) => {}
392            (FieldType::Scalar(ScalarType::BigInt), AttributeValue::Int(_)) => {}
393
394            // Float fields can have int or float defaults
395            (FieldType::Scalar(ScalarType::Float), AttributeValue::Int(_)) => {}
396            (FieldType::Scalar(ScalarType::Float), AttributeValue::Float(_)) => {}
397            (FieldType::Scalar(ScalarType::Decimal), AttributeValue::Int(_)) => {}
398            (FieldType::Scalar(ScalarType::Decimal), AttributeValue::Float(_)) => {}
399
400            // String fields should have string defaults
401            (FieldType::Scalar(ScalarType::String), AttributeValue::String(_)) => {}
402
403            // Json fields accept any constant as the @default — the payload is
404            // stored as a text literal that the database parses into jsonb.
405            // Prisma writes empty objects/arrays as `@default("[]")` or
406            // `@default("{}")`, so accept string, array, and scalar primitives.
407            (FieldType::Scalar(ScalarType::Json), AttributeValue::String(_))
408            | (FieldType::Scalar(ScalarType::Json), AttributeValue::Array(_))
409            | (FieldType::Scalar(ScalarType::Json), AttributeValue::Boolean(_))
410            | (FieldType::Scalar(ScalarType::Json), AttributeValue::Int(_))
411            | (FieldType::Scalar(ScalarType::Json), AttributeValue::Float(_)) => {}
412
413            // Boolean fields should have boolean defaults
414            (FieldType::Scalar(ScalarType::Boolean), AttributeValue::Boolean(_)) => {}
415
416            // Enum fields should have ident defaults matching a variant
417            (FieldType::Enum(enum_name), AttributeValue::Ident(variant)) => {
418                if let Some(e) = schema.enums.get(enum_name.as_str())
419                    && e.get_variant(variant).is_none()
420                {
421                    self.errors.push(SchemaError::invalid_field(
422                        model_name,
423                        field.name(),
424                        format!(
425                            "default value '{}' is not a valid variant of enum '{}'",
426                            variant, enum_name
427                        ),
428                    ));
429                }
430            }
431
432            // Model type might actually be an enum (parser treats non-scalar as Model initially)
433            (FieldType::Model(type_name), AttributeValue::Ident(variant)) => {
434                // Check if this is actually an enum reference
435                if let Some(e) = schema.enums.get(type_name.as_str())
436                    && e.get_variant(variant).is_none()
437                {
438                    self.errors.push(SchemaError::invalid_field(
439                        model_name,
440                        field.name(),
441                        format!(
442                            "default value '{}' is not a valid variant of enum '{}'",
443                            variant, type_name
444                        ),
445                    ));
446                }
447                // If it's a real model reference with an ident default, that's an error
448                // but we skip that here since it's likely a valid enum
449            }
450
451            // Type mismatch
452            _ => {
453                self.errors.push(SchemaError::invalid_field(
454                    model_name,
455                    field.name(),
456                    format!(
457                        "default value type does not match field type '{}'",
458                        field.field_type
459                    ),
460                ));
461            }
462        }
463    }
464
465    /// Validate a model-level attribute.
466    fn validate_model_attribute(&mut self, attr: &Attribute, model: &Model) {
467        match attr.name() {
468            "index" | "unique" => {
469                // Validate referenced fields exist
470                if let Some(AttributeValue::FieldRefList(fields)) = attr.first_arg() {
471                    for field_name in fields {
472                        if !model.fields.contains_key(field_name.as_str()) {
473                            self.errors.push(SchemaError::invalid_model(
474                                model.name(),
475                                format!(
476                                    "@@{} references non-existent field '{}'",
477                                    attr.name(),
478                                    field_name
479                                ),
480                            ));
481                        }
482                    }
483                }
484            }
485            "id" => {
486                // Composite primary key
487                if let Some(AttributeValue::FieldRefList(fields)) = attr.first_arg() {
488                    for field_name in fields {
489                        if !model.fields.contains_key(field_name.as_str()) {
490                            self.errors.push(SchemaError::invalid_model(
491                                model.name(),
492                                format!("@@id references non-existent field '{}'", field_name),
493                            ));
494                        }
495                    }
496                }
497            }
498            "search" => {
499                // Full-text search on fields
500                if let Some(AttributeValue::FieldRefList(fields)) = attr.first_arg() {
501                    for field_name in fields {
502                        if let Some(field) = model.fields.get(field_name.as_str()) {
503                            // Only string fields can be searched
504                            if !matches!(field.field_type, FieldType::Scalar(ScalarType::String)) {
505                                self.errors.push(SchemaError::invalid_model(
506                                    model.name(),
507                                    format!(
508                                        "@@search field '{}' must be of type String",
509                                        field_name
510                                    ),
511                                ));
512                            }
513                        } else {
514                            self.errors.push(SchemaError::invalid_model(
515                                model.name(),
516                                format!("@@search references non-existent field '{}'", field_name),
517                            ));
518                        }
519                    }
520                }
521            }
522            "map" => {
523                // @@map("table_name") — the value is flowed verbatim into
524                // generated SQL via `RelationFilterMeta::PARENT_TABLE` /
525                // `CHILD_TABLE`. Per `.cursor/rules/sql-safety.mdc`,
526                // identifiers must be whitelisted: enforce ASCII
527                // alphanumeric + underscore + dot (for schema-qualified
528                // names) here so the trust boundary is actually
529                // enforced, not just documented.
530                if let Some(AttributeValue::String(name)) = attr.first_arg()
531                    && !is_safe_sql_identifier(name)
532                {
533                    self.errors.push(SchemaError::invalid_model(
534                        model.name(),
535                        format!(
536                            "@@map(\"{}\") contains characters outside [A-Za-z0-9_.]; \
537                             SQL identifiers must be safe to splice into queries",
538                            name
539                        ),
540                    ));
541                }
542            }
543            _ => {}
544        }
545    }
546
547    /// Validate an enum definition.
548    fn validate_enum(&mut self, e: &Enum) {
549        if e.variants.is_empty() {
550            self.errors.push(SchemaError::invalid_model(
551                e.name(),
552                "enum must have at least one variant".to_string(),
553            ));
554        }
555
556        // Check for duplicate variant names
557        let mut seen = std::collections::HashSet::new();
558        for variant in &e.variants {
559            if !seen.insert(variant.name()) {
560                self.errors.push(SchemaError::duplicate(
561                    format!("enum variant in {}", e.name()),
562                    variant.name(),
563                ));
564            }
565        }
566    }
567
568    /// Validate a composite type definition.
569    fn validate_composite_type(&mut self, t: &CompositeType, schema: &Schema) {
570        if t.fields.is_empty() {
571            self.errors.push(SchemaError::invalid_model(
572                t.name(),
573                "composite type must have at least one field".to_string(),
574            ));
575        }
576
577        // Validate field types
578        for field in t.fields.values() {
579            match &field.field_type {
580                FieldType::Model(_) => {
581                    self.errors.push(SchemaError::invalid_field(
582                        t.name(),
583                        field.name(),
584                        "composite types cannot have model relations".to_string(),
585                    ));
586                }
587                FieldType::Enum(name) => {
588                    if !schema.enums.contains_key(name.as_str()) {
589                        self.errors.push(SchemaError::unknown_type(
590                            t.name(),
591                            field.name(),
592                            name.as_str(),
593                        ));
594                    }
595                }
596                FieldType::Composite(name) => {
597                    if !schema.types.contains_key(name.as_str()) {
598                        self.errors.push(SchemaError::unknown_type(
599                            t.name(),
600                            field.name(),
601                            name.as_str(),
602                        ));
603                    }
604                }
605                _ => {}
606            }
607        }
608    }
609
610    /// Validate a view definition.
611    fn validate_view(&mut self, v: &View, schema: &Schema) {
612        // Views should have at least one field
613        if v.fields.is_empty() {
614            self.errors.push(SchemaError::invalid_model(
615                v.name(),
616                "view must have at least one field".to_string(),
617            ));
618        }
619
620        // Validate field types
621        for field in v.fields.values() {
622            self.validate_field(field, v.name(), schema);
623        }
624    }
625
626    /// Validate a server group definition.
627    fn validate_server_group(&mut self, sg: &ServerGroup) {
628        // Server groups should have at least one server
629        if sg.servers.is_empty() {
630            self.errors.push(SchemaError::invalid_model(
631                sg.name.name.as_str(),
632                "serverGroup must have at least one server".to_string(),
633            ));
634        }
635
636        // Check for duplicate server names within the group
637        let mut seen_servers = std::collections::HashSet::new();
638        for server_name in sg.servers.keys() {
639            if !seen_servers.insert(server_name.as_str()) {
640                self.errors.push(SchemaError::duplicate(
641                    format!("server in serverGroup {}", sg.name.name),
642                    server_name.as_str(),
643                ));
644            }
645        }
646
647        // Validate each server
648        for server in sg.servers.values() {
649            self.validate_server(server, sg.name.name.as_str());
650        }
651
652        // Validate server group attributes
653        for attr in &sg.attributes {
654            self.validate_server_group_attribute(attr, sg);
655        }
656
657        // Check for at least one primary server in read replica strategy
658        if let Some(strategy) = sg.strategy()
659            && strategy == ServerGroupStrategy::ReadReplica
660        {
661            let has_primary = sg
662                .servers
663                .values()
664                .any(|s| s.role() == Some(ServerRole::Primary));
665            if !has_primary {
666                self.errors.push(SchemaError::invalid_model(
667                    sg.name.name.as_str(),
668                    "ReadReplica strategy requires at least one server with role = \"primary\""
669                        .to_string(),
670                ));
671            }
672        }
673    }
674
675    /// Validate an individual server definition.
676    fn validate_server(&mut self, server: &Server, group_name: &str) {
677        // Server should have a URL property
678        if server.url().is_none() {
679            self.errors.push(SchemaError::invalid_model(
680                group_name,
681                format!("server '{}' must have a 'url' property", server.name.name),
682            ));
683        }
684
685        // Validate weight is positive if specified
686        if let Some(weight) = server.weight()
687            && weight == 0
688        {
689            self.errors.push(SchemaError::invalid_model(
690                group_name,
691                format!(
692                    "server '{}' weight must be greater than 0",
693                    server.name.name
694                ),
695            ));
696        }
697
698        // Validate priority is positive if specified
699        if let Some(priority) = server.priority()
700            && priority == 0
701        {
702            self.errors.push(SchemaError::invalid_model(
703                group_name,
704                format!(
705                    "server '{}' priority must be greater than 0",
706                    server.name.name
707                ),
708            ));
709        }
710    }
711
712    /// Validate a server group attribute.
713    fn validate_server_group_attribute(&mut self, attr: &Attribute, sg: &ServerGroup) {
714        match attr.name() {
715            "strategy" => {
716                // Validate strategy value
717                if let Some(arg) = attr.first_arg() {
718                    let value_str = arg
719                        .as_string()
720                        .map(|s| s.to_string())
721                        .or_else(|| arg.as_ident().map(|s| s.to_string()));
722                    if let Some(val) = value_str
723                        && ServerGroupStrategy::parse(&val).is_none()
724                    {
725                        self.errors.push(SchemaError::InvalidAttribute {
726                                attribute: "strategy".to_string(),
727                                message: format!(
728                                    "invalid strategy '{}' for serverGroup '{}'. Valid values: ReadReplica, Sharding, MultiRegion, HighAvailability, Custom",
729                                    val,
730                                    sg.name.name
731                                ),
732                            });
733                    }
734                }
735            }
736            "loadBalance" => {
737                // Validate load balance value
738                if let Some(arg) = attr.first_arg() {
739                    let value_str = arg
740                        .as_string()
741                        .map(|s| s.to_string())
742                        .or_else(|| arg.as_ident().map(|s| s.to_string()));
743                    if let Some(val) = value_str
744                        && LoadBalanceStrategy::parse(&val).is_none()
745                    {
746                        self.errors.push(SchemaError::InvalidAttribute {
747                                attribute: "loadBalance".to_string(),
748                                message: format!(
749                                    "invalid loadBalance '{}' for serverGroup '{}'. Valid values: RoundRobin, Random, LeastConnections, Weighted, Nearest, Sticky",
750                                    val,
751                                    sg.name.name
752                                ),
753                            });
754                    }
755                }
756            }
757            _ => {} // Other attributes are allowed
758        }
759    }
760
761    /// Validate computed and virtual field combinations within every model.
762    ///
763    /// Illegal cases:
764    /// - `@generated` + `@id` or `@auto`
765    /// - `@generated` + an aggregate attribute on the same field
766    /// - Empty expression inside `@generated("")`
767    /// - `@count(rel.field)` — count takes only a relation name, not a dotted path
768    /// - `@sum/@avg/@min/@max(rel)` — non-Count aggregates need `rel.field`
769    /// - Unknown relation name in any aggregate attribute
770    /// - `@stored` or `@virtual` without a sibling `@generated`
771    fn validate_computed_fields(&mut self, schema: &Schema) {
772        for model in schema.models.values() {
773            for field in model.fields.values() {
774                let attrs = field.extract_attributes();
775
776                // ── @generated validations ─────────────────────────────────
777                if let Some(g) = &attrs.generated {
778                    if attrs.is_id || attrs.is_auto {
779                        self.errors.push(SchemaError::invalid_field(
780                            model.name(),
781                            field.name(),
782                            format!(
783                                "field `{}` cannot be both @generated and @id/@auto",
784                                field.name()
785                            ),
786                        ));
787                    }
788                    if attrs.aggregate.is_some() {
789                        self.errors.push(SchemaError::invalid_field(
790                            model.name(),
791                            field.name(),
792                            format!(
793                                "field `{}` cannot be both @generated and an aggregate",
794                                field.name()
795                            ),
796                        ));
797                    }
798                    if g.expression.trim().is_empty() {
799                        self.errors.push(SchemaError::invalid_field(
800                            model.name(),
801                            field.name(),
802                            format!(
803                                "field `{}`: @generated expression must not be empty",
804                                field.name()
805                            ),
806                        ));
807                    }
808                }
809
810                // ── aggregate validations ──────────────────────────────────
811                // Check each raw aggregate attribute so we can also detect
812                // malformed forms that the extractor silently drops (e.g.
813                // `@sum(posts)` with no dot, `@count(posts.id)` with a dot).
814                for raw_attr in &field.attributes {
815                    match raw_attr.name() {
816                        "count" => {
817                            // @count must have a plain relation name with no dot.
818                            // first_path_arg() returns the raw string/ident;
819                            // if it contains a dot the user wrote `@count(rel.field)`.
820                            if let Some(path) = raw_attr.first_path_arg() {
821                                if path.contains('.') {
822                                    self.errors.push(SchemaError::invalid_field(
823                                        model.name(),
824                                        field.name(),
825                                        format!(
826                                            "field `{}`: @count takes a relation name, not `relation.field`",
827                                            field.name()
828                                        ),
829                                    ));
830                                } else {
831                                    // Relation must exist on this model.
832                                    let rel_exists = model
833                                        .fields
834                                        .values()
835                                        .any(|f| f.name() == path && f.is_list());
836                                    if !rel_exists {
837                                        self.errors.push(SchemaError::invalid_field(
838                                            model.name(),
839                                            field.name(),
840                                            format!(
841                                                "field `{}`: unknown relation `{}` in @count",
842                                                field.name(),
843                                                path
844                                            ),
845                                        ));
846                                    }
847                                }
848                            }
849                        }
850                        "sum" | "avg" | "min" | "max" => {
851                            let kind_name = raw_attr.name();
852                            if let Some(path) = raw_attr.first_path_arg() {
853                                if let Some((rel, _field_part)) = path.split_once('.') {
854                                    // Validate relation exists.
855                                    let rel_exists = model
856                                        .fields
857                                        .values()
858                                        .any(|f| f.name() == rel && f.is_list());
859                                    if !rel_exists {
860                                        self.errors.push(SchemaError::invalid_field(
861                                            model.name(),
862                                            field.name(),
863                                            format!(
864                                                "field `{}`: unknown relation `{}` in @{}",
865                                                field.name(),
866                                                rel,
867                                                kind_name
868                                            ),
869                                        ));
870                                    }
871                                } else {
872                                    // No dot — missing field path.
873                                    self.errors.push(SchemaError::invalid_field(
874                                        model.name(),
875                                        field.name(),
876                                        format!(
877                                            "field `{}`: @{} requires `relation.field`",
878                                            field.name(),
879                                            kind_name
880                                        ),
881                                    ));
882                                }
883                            }
884                        }
885                        _ => {}
886                    }
887                }
888
889                // ── orphan @stored / @virtual ──────────────────────────────
890                let has_generated = attrs.generated.is_some();
891                for attr_name in ["stored", "virtual"] {
892                    if field.has_attribute(attr_name) && !has_generated {
893                        self.errors.push(SchemaError::invalid_field(
894                            model.name(),
895                            field.name(),
896                            format!(
897                                "field `{}`: @{} is only valid alongside @generated",
898                                field.name(),
899                                attr_name
900                            ),
901                        ));
902                    }
903                }
904            }
905        }
906    }
907
908    /// Resolve all relations in the schema.
909    fn resolve_relations(&mut self, schema: &Schema) -> Vec<Relation> {
910        let mut relations = Vec::new();
911
912        for model in schema.models.values() {
913            for field in model.fields.values() {
914                if let FieldType::Model(ref target_model) = field.field_type {
915                    // Skip if this is actually an enum reference (parser treats non-scalar as Model initially)
916                    if schema.enums.contains_key(target_model.as_str()) {
917                        continue;
918                    }
919
920                    // Skip if this is actually a composite type reference
921                    if schema.types.contains_key(target_model.as_str()) {
922                        continue;
923                    }
924
925                    // Skip if the target model doesn't exist (error was already reported)
926                    if !schema.models.contains_key(target_model.as_str()) {
927                        continue;
928                    }
929
930                    let attrs = field.extract_attributes();
931
932                    let relation_type = if field.is_list() {
933                        // This model has many of target
934                        RelationType::OneToMany
935                    } else {
936                        // This model has one of target
937                        RelationType::ManyToOne
938                    };
939
940                    let mut relation = Relation::new(
941                        model.name(),
942                        field.name(),
943                        target_model.as_str(),
944                        relation_type,
945                    );
946
947                    if let Some(rel_attr) = &attrs.relation {
948                        if let Some(name) = &rel_attr.name {
949                            relation = relation.with_name(name.as_str());
950                        }
951                        if !rel_attr.fields.is_empty() {
952                            relation = relation.with_from_fields(rel_attr.fields.clone());
953                        }
954                        if !rel_attr.references.is_empty() {
955                            relation = relation.with_to_fields(rel_attr.references.clone());
956                        }
957                        if let Some(action) = rel_attr.on_delete {
958                            relation = relation.with_on_delete(action);
959                        }
960                        if let Some(action) = rel_attr.on_update {
961                            relation = relation.with_on_update(action);
962                        }
963                        if let Some(map) = &rel_attr.map {
964                            relation = relation.with_map(map.as_str());
965                        }
966                    }
967
968                    relations.push(relation);
969                }
970            }
971        }
972
973        relations
974    }
975}
976
/// Whether a string is a safe SQL identifier.
///
/// A safe identifier is one or more dot-separated segments (`users`,
/// `public.users`). Each segment must be a non-empty run of ASCII
/// alphanumerics and underscores.
///
/// Dots are allowed only as separators for schema-qualified names, so
/// strings with empty segments are rejected: `.`, `a..b`, `.users`,
/// `users.`.
///
/// This is defense-in-depth for the compile-time-trusted schema-author
/// boundary described in `.cursor/rules/sql-safety.mdc`.
fn is_safe_sql_identifier(s: &str) -> bool {
    // `"".split('.')` yields a single empty segment, so the empty string is
    // rejected by the same per-segment check.
    s.split('.').all(|segment| {
        !segment.is_empty()
            && segment
                .chars()
                .all(|c| c.is_ascii_alphanumeric() || c == '_')
    })
}
986
987/// Validate a schema string and return the validated schema.
988pub fn validate_schema(input: &str) -> SchemaResult<Schema> {
989    let schema = crate::parser::parse_schema(input)?;
990    let mut validator = Validator::new();
991    validator.validate(schema)
992}
993
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_validate_simple_model() {
        let schema = validate_schema(
            r#"
            model User {
                id    Int    @id @auto
                email String @unique
            }
        "#,
        )
        .unwrap();

        assert_eq!(schema.models.len(), 1);
    }

    #[test]
    fn test_validate_model_missing_id() {
        let result = validate_schema(
            r#"
            model User {
                email String
                name  String
            }
        "#,
        );

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, SchemaError::ValidationFailed { .. }));
    }

    #[test]
    fn test_validate_model_with_composite_id() {
        let schema = validate_schema(
            r#"
            model PostTag {
                post_id Int
                tag_id  Int

                @@id([post_id, tag_id])
            }
        "#,
        )
        .unwrap();

        assert_eq!(schema.models.len(), 1);
    }

    #[test]
    fn test_validate_unknown_type_reference() {
        let result = validate_schema(
            r#"
            model User {
                id      Int    @id @auto
                profile UnknownType
            }
        "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_enum_reference() {
        let schema = validate_schema(
            r#"
            enum Role {
                User
                Admin
            }

            model User {
                id   Int    @id @auto
                role Role   @default(User)
            }
        "#,
        )
        .unwrap();

        assert_eq!(schema.models.len(), 1);
        assert_eq!(schema.enums.len(), 1);
    }

    #[test]
    fn test_validate_invalid_enum_default() {
        let result = validate_schema(
            r#"
            enum Role {
                User
                Admin
            }

            model User {
                id   Int    @id @auto
                role Role   @default(Unknown)
            }
        "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_auto_on_non_int() {
        let result = validate_schema(
            r#"
            model User {
                id    String @id @auto
                email String
            }
        "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_updated_at_on_non_datetime() {
        let result = validate_schema(
            r#"
            model User {
                id         Int    @id @auto
                updated_at String @updated_at
            }
        "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_empty_enum() {
        let result = validate_schema(
            r#"
            enum Empty {
            }

            model User {
                id Int @id @auto
            }
        "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_duplicate_model_names() {
        let result = validate_schema(
            r#"
            model User {
                id Int @id @auto
            }

            model User {
                id Int @id @auto
            }
        "#,
        );

        // Depending on the grammar, the duplicate is either rejected or
        // collapsed into a single entry. Either way, accepting the schema
        // must never yield two `User` models.
        if let Ok(schema) = result {
            assert_eq!(schema.models.len(), 1);
        }
    }

    #[test]
    fn test_validate_relation() {
        let schema = validate_schema(
            r#"
            model User {
                id    Int    @id @auto
                posts Post[]
            }

            model Post {
                id        Int    @id @auto
                author_id Int
                author    User   @relation(fields: [author_id], references: [id])
            }
        "#,
        )
        .unwrap();

        assert_eq!(schema.models.len(), 2);
        assert!(!schema.relations.is_empty());
    }

    #[test]
    fn test_validate_index_with_invalid_field() {
        let result = validate_schema(
            r#"
            model User {
                id    Int    @id @auto
                email String

                @@index([nonexistent])
            }
        "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_search_on_non_string_field() {
        let result = validate_schema(
            r#"
            model Post {
                id    Int    @id @auto
                views Int

                @@search([views])
            }
        "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_composite_type() {
        let schema = validate_schema(
            r#"
            type Address {
                street  String
                city    String
                country String @default("US")
            }

            model User {
                id      Int     @id @auto
                address Address
            }
        "#,
        );

        // Composite type support depends on parser handling. If the schema is
        // accepted, `Address` must have been registered as a composite type;
        // otherwise `address Address` would be an unknown type reference.
        if let Ok(schema) = schema {
            assert_eq!(schema.types.len(), 1);
            assert_eq!(schema.models.len(), 1);
        }
    }

    // ==================== Server Group Validation Tests ====================

    #[test]
    fn test_validate_server_group_basic() {
        let schema = validate_schema(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup MainCluster {
                server primary {
                    url = "postgres://localhost/db"
                    role = "primary"
                }
            }
        "#,
        )
        .unwrap();

        assert_eq!(schema.server_groups.len(), 1);
    }

    #[test]
    fn test_validate_server_group_empty_servers() {
        let result = validate_schema(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup EmptyCluster {
            }
        "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_server_group_missing_url() {
        let result = validate_schema(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup Cluster {
                server db {
                    role = "primary"
                }
            }
        "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_server_group_invalid_strategy() {
        let result = validate_schema(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup Cluster {
                @@strategy(InvalidStrategy)

                server db {
                    url = "postgres://localhost/db"
                }
            }
        "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_server_group_valid_strategy() {
        let schema = validate_schema(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup Cluster {
                @@strategy(ReadReplica)
                @@loadBalance(RoundRobin)

                server primary {
                    url = "postgres://localhost/db"
                    role = "primary"
                }
            }
        "#,
        )
        .unwrap();

        assert_eq!(schema.server_groups.len(), 1);
    }

    #[test]
    fn test_validate_server_group_read_replica_needs_primary() {
        let result = validate_schema(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup Cluster {
                @@strategy(ReadReplica)

                server replica1 {
                    url = "postgres://localhost/db"
                    role = "replica"
                }
            }
        "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_server_group_with_replicas() {
        let schema = validate_schema(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup Cluster {
                @@strategy(ReadReplica)

                server primary {
                    url = "postgres://primary/db"
                    role = "primary"
                    weight = 1
                }

                server replica1 {
                    url = "postgres://replica1/db"
                    role = "replica"
                    weight = 2
                }

                server replica2 {
                    url = "postgres://replica2/db"
                    role = "replica"
                    weight = 2
                    region = "us-west-1"
                }
            }
        "#,
        )
        .unwrap();

        let cluster = schema.get_server_group("Cluster").unwrap();
        assert_eq!(cluster.servers.len(), 3);
    }

    #[test]
    fn test_validate_server_group_zero_weight() {
        let result = validate_schema(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup Cluster {
                server db {
                    url = "postgres://localhost/db"
                    weight = 0
                }
            }
        "#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_server_group_invalid_load_balance() {
        let result = validate_schema(
            r#"
            model User {
                id Int @id @auto
            }

            serverGroup Cluster {
                @@loadBalance(InvalidStrategy)

                server db {
                    url = "postgres://localhost/db"
                }
            }
        "#,
        );

        assert!(result.is_err());
    }

    // ==================== SQL Identifier Safety Tests ====================

    #[test]
    fn test_is_safe_sql_identifier() {
        assert!(is_safe_sql_identifier("users"));
        assert!(is_safe_sql_identifier("public.users"));
        assert!(is_safe_sql_identifier("user_accounts_2"));
        assert!(!is_safe_sql_identifier(""));
        assert!(!is_safe_sql_identifier("users; DROP TABLE users"));
        assert!(!is_safe_sql_identifier("\"users\""));
        assert!(!is_safe_sql_identifier("naïve"));
    }
}