Skip to main content

oxirs_gql/
validation.rs

1//! GraphQL query validation and security features
2//!
3//! This module provides comprehensive validation for GraphQL queries, including
4//! security features like depth limiting, complexity analysis, and schema validation.
5
6use crate::ast::{
7    Definition, Document, Field, OperationDefinition, Selection, SelectionSet, Value,
8    VariableDefinition,
9};
10use crate::types::{GraphQLType, Schema};
11use anyhow::{anyhow, Result};
12use std::collections::{HashMap, HashSet};
13use std::time::Duration;
14
/// Configuration for query validation and security limits.
///
/// Every numeric limit is inclusive: the validator only rejects values strictly greater
/// than the configured maximum.
#[derive(Debug, Clone)]
pub struct ValidationConfig {
    /// Maximum allowed selection nesting depth. Root fields are at depth 0; each field
    /// sub-selection adds one level (fragments do not add a level).
    pub max_depth: usize,
    /// Maximum allowed complexity score, as computed by the validator's cost heuristic.
    pub max_complexity: usize,
    /// Maximum number of aliased fields allowed.
    pub max_aliases: usize,
    /// Maximum number of top-level selections, summed over all operations in a document.
    pub max_root_fields: usize,
    /// Optional query timeout.
    /// NOTE(review): nothing in this module enforces it; presumably the executor reads it — confirm.
    pub query_timeout: Option<Duration>,
    /// When `true`, selecting any field whose name starts with `__` is rejected.
    pub disable_introspection: bool,
    /// Maximum number of fragment definitions in a document.
    pub max_fragments: usize,
    /// Optional whitelist of operation names; named operations not in the set are rejected.
    pub allowed_operations: Option<HashSet<String>>,
    /// Blacklist of field names that may not be selected anywhere in a query.
    pub forbidden_fields: HashSet<String>,
    /// When `false`, every field costs exactly 1 in the complexity score (argument counts,
    /// sub-selection counts and expensive-field multipliers are ignored).
    pub enable_cost_analysis: bool,
}
39
40impl Default for ValidationConfig {
41    fn default() -> Self {
42        Self {
43            max_depth: 10,
44            max_complexity: 1000,
45            max_aliases: 50,
46            max_root_fields: 20,
47            query_timeout: Some(Duration::from_secs(30)),
48            disable_introspection: false,
49            max_fragments: 50,
50            allowed_operations: None,
51            forbidden_fields: HashSet::new(),
52            enable_cost_analysis: true,
53        }
54    }
55}
56
/// Outcome of validating a document: pass/fail flag, collected diagnostics, and the
/// measured depth and complexity.
#[derive(Debug, Clone)]
pub struct ValidationResult {
    /// `true` until an error is recorded through `with_error`.
    pub is_valid: bool,
    /// Rule violations collected during validation.
    pub errors: Vec<ValidationError>,
    /// Non-fatal diagnostics.
    pub warnings: Vec<ValidationWarning>,
    /// Complexity score accumulated over the validated operations.
    pub complexity_score: usize,
    /// Deepest selection nesting observed in any validated operation.
    pub depth: usize,
}
66
67impl Default for ValidationResult {
68    fn default() -> Self {
69        Self::new()
70    }
71}
72
73impl ValidationResult {
74    pub fn new() -> Self {
75        Self {
76            is_valid: true,
77            errors: Vec::new(),
78            warnings: Vec::new(),
79            complexity_score: 0,
80            depth: 0,
81        }
82    }
83
84    pub fn with_error(mut self, error: ValidationError) -> Self {
85        self.is_valid = false;
86        self.errors.push(error);
87        self
88    }
89
90    pub fn with_warning(mut self, warning: ValidationWarning) -> Self {
91        self.warnings.push(warning);
92        self
93    }
94
95    pub fn with_complexity(mut self, complexity: usize) -> Self {
96        self.complexity_score = complexity;
97        self
98    }
99
100    pub fn with_depth(mut self, depth: usize) -> Self {
101        self.depth = depth;
102        self
103    }
104}
105
/// A single rule violation found while validating a document.
#[derive(Debug, Clone)]
pub struct ValidationError {
    /// Human-readable description of the violation.
    pub message: String,
    /// Path to the offending selection, when known; `ValidationError::new` leaves it empty
    /// and `QueryValidator` in this module never fills it in.
    pub path: Vec<String>,
    /// Which rule produced this error.
    pub rule: ValidationRule,
}
113
114impl ValidationError {
115    pub fn new(message: String, rule: ValidationRule) -> Self {
116        Self {
117            message,
118            path: Vec::new(),
119            rule,
120        }
121    }
122
123    pub fn with_path(mut self, path: Vec<String>) -> Self {
124        self.path = path;
125        self
126    }
127}
128
/// A non-fatal diagnostic attached to a [`ValidationResult`].
///
/// NOTE: `QueryValidator` in this module does not currently emit warnings.
#[derive(Debug, Clone)]
pub struct ValidationWarning {
    /// Human-readable description of the concern.
    pub message: String,
    /// Optional hint on how to address the concern.
    pub suggestion: Option<String>,
}
135
136impl ValidationWarning {
137    pub fn new(message: String) -> Self {
138        Self {
139            message,
140            suggestion: None,
141        }
142    }
143
144    pub fn with_suggestion(mut self, suggestion: String) -> Self {
145        self.suggestion = Some(suggestion);
146        self
147    }
148}
149
/// Identifies which validation rule produced a [`ValidationError`].
///
/// The enum carries no data, so it is `Copy` and also `Eq + Hash`. It can be compared and
/// used as a map or set key (for example to count errors per rule) without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ValidationRule {
    /// Selection nesting exceeded `ValidationConfig::max_depth`.
    MaxDepth,
    /// Complexity score exceeded `ValidationConfig::max_complexity`.
    MaxComplexity,
    /// Alias count exceeded `ValidationConfig::max_aliases`.
    MaxAliases,
    /// Top-level selection count exceeded `ValidationConfig::max_root_fields`.
    MaxRootFields,
    /// Fragment definition count exceeded `ValidationConfig::max_fragments`.
    MaxFragments,
    /// A field selection problem.
    FieldValidation,
    /// A type-level problem.
    TypeValidation,
    /// A variable declaration problem (unknown type or incompatible default value).
    VariableValidation,
    /// A fragment definition problem (e.g. unknown type condition).
    FragmentValidation,
    /// Introspection was requested while `ValidationConfig::disable_introspection` is set.
    IntrospectionDisabled,
    /// The operation is not permitted by `ValidationConfig::allowed_operations`.
    OperationNotAllowed,
    /// A field listed in `ValidationConfig::forbidden_fields` was selected.
    ForbiddenField,
    /// A cost-analysis problem.
    CostAnalysis,
}
167
/// Validates GraphQL documents against a [`Schema`] using the limits in a
/// [`ValidationConfig`].
pub struct QueryValidator {
    /// Limits and policy switches applied during validation.
    config: ValidationConfig,
    /// Schema used to resolve root operation types, field types and fragment type conditions.
    schema: Schema,
}
173
174impl QueryValidator {
175    pub fn new(config: ValidationConfig, schema: Schema) -> Self {
176        Self { config, schema }
177    }
178
179    /// Validate a GraphQL document
180    pub fn validate(&self, document: &Document) -> Result<ValidationResult> {
181        let mut result = ValidationResult::new();
182        let mut validation_context = ValidationContext::new(&self.schema);
183
184        // Collect fragments first
185        for definition in &document.definitions {
186            if let Definition::Fragment(fragment) = definition {
187                validation_context.add_fragment(fragment.name.clone(), fragment.clone());
188            }
189        }
190
191        // Validate each operation
192        for definition in &document.definitions {
193            if let Definition::Operation(operation) = definition {
194                result = self.validate_operation(operation, &validation_context, result)?;
195            }
196        }
197
198        // Global validations
199        result = self.validate_fragments(&validation_context, result)?;
200        result = self.validate_global_limits(document, result)?;
201
202        Ok(result)
203    }
204
205    fn validate_operation(
206        &self,
207        operation: &OperationDefinition,
208        context: &ValidationContext,
209        mut result: ValidationResult,
210    ) -> Result<ValidationResult> {
211        // Check operation name whitelist
212        if let Some(ref allowed_ops) = self.config.allowed_operations {
213            if let Some(ref op_name) = operation.name {
214                if !allowed_ops.contains(op_name) {
215                    return Ok(result.with_error(ValidationError::new(
216                        format!("Operation '{op_name}' is not allowed"),
217                        ValidationRule::OperationNotAllowed,
218                    )));
219                }
220            }
221        }
222
223        // Validate variables
224        result = self.validate_variables(&operation.variable_definitions, context, result)?;
225
226        // Get root type based on operation type
227        let root_type_name = match operation.operation_type {
228            crate::ast::OperationType::Query => self
229                .schema
230                .query_type
231                .as_ref()
232                .ok_or_else(|| anyhow!("Schema has no query type"))?,
233            crate::ast::OperationType::Mutation => self
234                .schema
235                .mutation_type
236                .as_ref()
237                .ok_or_else(|| anyhow!("Schema has no mutation type"))?,
238            crate::ast::OperationType::Subscription => self
239                .schema
240                .subscription_type
241                .as_ref()
242                .ok_or_else(|| anyhow!("Schema has no subscription type"))?,
243        };
244
245        // Validate selection set
246        let (depth, complexity) = self.validate_selection_set(
247            &operation.selection_set,
248            root_type_name,
249            context,
250            0,
251            Vec::new(),
252        )?;
253
254        result.depth = depth.max(result.depth);
255        result.complexity_score += complexity;
256
257        // Check depth limit
258        if result.depth > self.config.max_depth {
259            let current_depth = result.depth;
260            result = result.with_error(ValidationError::new(
261                format!(
262                    "Query depth {} exceeds maximum allowed depth {}",
263                    current_depth, self.config.max_depth
264                ),
265                ValidationRule::MaxDepth,
266            ));
267        }
268
269        // Check complexity limit
270        if result.complexity_score > self.config.max_complexity {
271            let current_complexity = result.complexity_score;
272            result = result.with_error(ValidationError::new(
273                format!(
274                    "Query complexity {} exceeds maximum allowed complexity {}",
275                    current_complexity, self.config.max_complexity
276                ),
277                ValidationRule::MaxComplexity,
278            ));
279        }
280
281        Ok(result)
282    }
283
284    fn validate_selection_set(
285        &self,
286        selection_set: &SelectionSet,
287        parent_type_name: &str,
288        context: &ValidationContext,
289        current_depth: usize,
290        path: Vec<String>,
291    ) -> Result<(usize, usize)> {
292        let mut max_depth = current_depth;
293        let mut total_complexity = 0;
294        let mut alias_count = 0;
295
296        let parent_type = self
297            .schema
298            .get_type(parent_type_name)
299            .ok_or_else(|| anyhow!("Type '{}' not found in schema", parent_type_name))?;
300
301        for selection in &selection_set.selections {
302            match selection {
303                Selection::Field(field) => {
304                    // Check for aliases
305                    if field.alias.is_some() {
306                        alias_count += 1;
307                    }
308
309                    // Check forbidden fields
310                    if self.config.forbidden_fields.contains(&field.name) {
311                        return Err(anyhow!("Field '{}' is forbidden", field.name));
312                    }
313
314                    // Check introspection fields
315                    if self.config.disable_introspection && field.name.starts_with("__") {
316                        return Err(anyhow!("Introspection is disabled"));
317                    }
318
319                    // Validate field exists on type
320                    let field_type = self.get_field_type(parent_type, &field.name)?;
321
322                    let mut field_path = path.clone();
323                    field_path.push(field.alias.as_ref().unwrap_or(&field.name).clone());
324
325                    // Calculate field complexity
326                    let field_complexity = self.calculate_field_complexity(field);
327                    total_complexity += field_complexity;
328
329                    // Recurse into nested selection sets
330                    if let Some(ref nested_selection_set) = field.selection_set {
331                        let inner_type_name = self.get_inner_type_name(field_type);
332                        let (nested_depth, nested_complexity) = self.validate_selection_set(
333                            nested_selection_set,
334                            &inner_type_name,
335                            context,
336                            current_depth + 1,
337                            field_path,
338                        )?;
339                        max_depth = max_depth.max(nested_depth);
340                        total_complexity += nested_complexity;
341                    }
342                }
343                Selection::InlineFragment(inline_fragment) => {
344                    let fragment_type =
345                        if let Some(ref type_condition) = inline_fragment.type_condition {
346                            type_condition
347                        } else {
348                            parent_type_name
349                        };
350
351                    let (nested_depth, nested_complexity) = self.validate_selection_set(
352                        &inline_fragment.selection_set,
353                        fragment_type,
354                        context,
355                        current_depth,
356                        path.clone(),
357                    )?;
358                    max_depth = max_depth.max(nested_depth);
359                    total_complexity += nested_complexity;
360                }
361                Selection::FragmentSpread(fragment_spread) => {
362                    if let Some(fragment_def) = context.get_fragment(&fragment_spread.fragment_name)
363                    {
364                        let (nested_depth, nested_complexity) = self.validate_selection_set(
365                            &fragment_def.selection_set,
366                            &fragment_def.type_condition,
367                            context,
368                            current_depth,
369                            path.clone(),
370                        )?;
371                        max_depth = max_depth.max(nested_depth);
372                        total_complexity += nested_complexity;
373                    } else {
374                        return Err(anyhow!(
375                            "Fragment '{}' not found",
376                            fragment_spread.fragment_name
377                        ));
378                    }
379                }
380            }
381        }
382
383        // Check alias limit
384        if alias_count > self.config.max_aliases {
385            return Err(anyhow!(
386                "Too many aliases: {} exceeds limit {}",
387                alias_count,
388                self.config.max_aliases
389            ));
390        }
391
392        Ok((max_depth, total_complexity))
393    }
394
395    fn validate_variables(
396        &self,
397        variable_definitions: &[VariableDefinition],
398        _context: &ValidationContext,
399        mut result: ValidationResult,
400    ) -> Result<ValidationResult> {
401        for var_def in variable_definitions {
402            // Validate variable type exists in schema
403            if !self.type_exists(&var_def.type_) {
404                result = result.with_error(ValidationError::new(
405                    format!(
406                        "Variable type '{}' not found in schema",
407                        var_def.type_.name()
408                    ),
409                    ValidationRule::VariableValidation,
410                ));
411            }
412
413            // Validate default value compatibility
414            if let Some(ref default_value) = var_def.default_value {
415                if !self.is_value_compatible_with_type(default_value, &var_def.type_) {
416                    result = result.with_error(ValidationError::new(
417                        format!(
418                            "Default value for variable '{}' is not compatible with type '{}'",
419                            var_def.variable.name,
420                            var_def.type_.name()
421                        ),
422                        ValidationRule::VariableValidation,
423                    ));
424                }
425            }
426        }
427
428        Ok(result)
429    }
430
431    fn validate_fragments(
432        &self,
433        context: &ValidationContext,
434        mut result: ValidationResult,
435    ) -> Result<ValidationResult> {
436        if context.fragments.len() > self.config.max_fragments {
437            result = result.with_error(ValidationError::new(
438                format!(
439                    "Too many fragments: {} exceeds limit {}",
440                    context.fragments.len(),
441                    self.config.max_fragments
442                ),
443                ValidationRule::MaxFragments,
444            ));
445        }
446
447        // Validate fragment type conditions
448        for (fragment_name, fragment) in &context.fragments {
449            if !self.schema.types.contains_key(&fragment.type_condition) {
450                result = result.with_error(ValidationError::new(
451                    format!(
452                        "Fragment '{}' has unknown type condition '{}'",
453                        fragment_name, fragment.type_condition
454                    ),
455                    ValidationRule::FragmentValidation,
456                ));
457            }
458        }
459
460        Ok(result)
461    }
462
463    fn validate_global_limits(
464        &self,
465        document: &Document,
466        mut result: ValidationResult,
467    ) -> Result<ValidationResult> {
468        let mut root_field_count = 0;
469
470        for definition in &document.definitions {
471            if let Definition::Operation(operation) = definition {
472                root_field_count += operation.selection_set.selections.len();
473            }
474        }
475
476        if root_field_count > self.config.max_root_fields {
477            result = result.with_error(ValidationError::new(
478                format!(
479                    "Too many root fields: {} exceeds limit {}",
480                    root_field_count, self.config.max_root_fields
481                ),
482                ValidationRule::MaxRootFields,
483            ));
484        }
485
486        Ok(result)
487    }
488
489    fn get_field_type<'a>(
490        &self,
491        parent_type: &'a GraphQLType,
492        field_name: &str,
493    ) -> Result<&'a GraphQLType> {
494        match parent_type {
495            GraphQLType::Object(obj) => obj
496                .fields
497                .get(field_name)
498                .map(|field| &field.field_type)
499                .ok_or_else(|| {
500                    anyhow!(
501                        "Field '{}' not found on object type '{}'",
502                        field_name,
503                        obj.name
504                    )
505                }),
506            GraphQLType::Interface(iface) => iface
507                .fields
508                .get(field_name)
509                .map(|field| &field.field_type)
510                .ok_or_else(|| {
511                    anyhow!(
512                        "Field '{}' not found on interface type '{}'",
513                        field_name,
514                        iface.name
515                    )
516                }),
517            _ => Err(anyhow!(
518                "Cannot select field '{}' on non-composite type",
519                field_name
520            )),
521        }
522    }
523
524    #[allow(clippy::only_used_in_recursion)]
525    fn get_inner_type_name(&self, graphql_type: &GraphQLType) -> String {
526        match graphql_type {
527            GraphQLType::NonNull(inner) => self.get_inner_type_name(inner),
528            GraphQLType::List(inner) => self.get_inner_type_name(inner),
529            _ => graphql_type.name().to_string(),
530        }
531    }
532
533    fn calculate_field_complexity(&self, field: &Field) -> usize {
534        if !self.config.enable_cost_analysis {
535            return 1;
536        }
537
538        let mut complexity = 1;
539
540        // Add complexity for arguments
541        complexity += field.arguments.len();
542
543        // Add complexity for nested selections
544        if let Some(ref selection_set) = field.selection_set {
545            complexity += selection_set.selections.len();
546        }
547
548        // Special cases for expensive operations
549        match field.name.as_str() {
550            "sparql" => complexity *= 10, // Raw SPARQL queries are expensive
551            name if name.contains("search") => complexity *= 3,
552            name if name.contains("aggregate") => complexity *= 5,
553            _ => {}
554        }
555
556        complexity
557    }
558
559    fn type_exists(&self, ast_type: &crate::ast::Type) -> bool {
560        match ast_type {
561            crate::ast::Type::NamedType(name) => {
562                self.schema.types.contains_key(name)
563                    || matches!(name.as_str(), "String" | "Int" | "Float" | "Boolean" | "ID")
564            }
565            crate::ast::Type::ListType(inner) => self.type_exists(inner),
566            crate::ast::Type::NonNullType(inner) => self.type_exists(inner),
567        }
568    }
569
570    #[allow(clippy::only_used_in_recursion)]
571    fn is_value_compatible_with_type(&self, value: &Value, ast_type: &crate::ast::Type) -> bool {
572        match (value, ast_type) {
573            (Value::NullValue, crate::ast::Type::NonNullType(_)) => false,
574            (Value::NullValue, _) => true,
575            (Value::StringValue(_), crate::ast::Type::NamedType(name)) => {
576                matches!(name.as_str(), "String" | "ID")
577            }
578            (Value::IntValue(_), crate::ast::Type::NamedType(name)) => {
579                matches!(name.as_str(), "Int" | "ID")
580            }
581            (Value::FloatValue(_), crate::ast::Type::NamedType(name)) => {
582                matches!(name.as_str(), "Float")
583            }
584            (Value::BooleanValue(_), crate::ast::Type::NamedType(name)) => {
585                matches!(name.as_str(), "Boolean")
586            }
587            (Value::ListValue(list), crate::ast::Type::ListType(inner_type)) => list
588                .iter()
589                .all(|item| self.is_value_compatible_with_type(item, inner_type)),
590            (_, crate::ast::Type::NonNullType(inner)) => {
591                self.is_value_compatible_with_type(value, inner)
592            }
593            _ => false,
594        }
595    }
596}
597
598/// Context for validation operations
599struct ValidationContext {
600    #[allow(dead_code)]
601    schema: Schema,
602    fragments: HashMap<String, crate::ast::FragmentDefinition>,
603}
604
605impl ValidationContext {
606    fn new(schema: &Schema) -> Self {
607        Self {
608            schema: schema.clone(),
609            fragments: HashMap::new(),
610        }
611    }
612
613    fn add_fragment(&mut self, name: String, fragment: crate::ast::FragmentDefinition) {
614        self.fragments.insert(name, fragment);
615    }
616
617    fn get_fragment(&self, name: &str) -> Option<&crate::ast::FragmentDefinition> {
618        self.fragments.get(name)
619    }
620}
621
/// Rate-limiting settings for clients submitting queries.
///
/// NOTE(review): nothing in this module enforces these limits; presumably a rate limiter
/// elsewhere consumes them — confirm.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Maximum queries per client per window. The name says "per minute", which matches
    /// the default 60 s `window_duration`; verify how a different window is interpreted.
    pub max_queries_per_minute: usize,
    /// Maximum summed complexity per client per window (same caveat as above).
    pub max_complexity_per_minute: usize,
    /// Length of the rate-limiting window.
    pub window_duration: Duration,
}
632
633impl Default for RateLimitConfig {
634    fn default() -> Self {
635        Self {
636            max_queries_per_minute: 60,
637            max_complexity_per_minute: 10000,
638            window_duration: Duration::from_secs(60),
639        }
640    }
641}
642
/// Unit tests for the validation configuration, result builders and fragment context.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{BuiltinScalars, FieldType, ObjectType};

    /// Builds a minimal schema whose `Query` root type has two `String` fields: `hello`
    /// and `__schema`. The latter is declared explicitly so introspection-style selections
    /// resolve as ordinary fields.
    fn create_test_schema() -> Schema {
        let mut schema = Schema::new();

        let query_type = ObjectType::new("Query".to_string())
            .with_field(
                "hello".to_string(),
                FieldType::new(
                    "hello".to_string(),
                    GraphQLType::Scalar(BuiltinScalars::string()),
                ),
            )
            .with_field(
                "__schema".to_string(),
                FieldType::new(
                    "__schema".to_string(),
                    GraphQLType::Scalar(BuiltinScalars::string()),
                ),
            );

        schema.add_type(GraphQLType::Object(query_type));
        schema.set_query_type("Query".to_string());

        schema
    }

    /// Pins the documented default limits.
    #[test]
    fn test_validation_config_default() {
        let config = ValidationConfig::default();
        assert_eq!(config.max_depth, 10);
        assert_eq!(config.max_complexity, 1000);
        assert!(!config.disable_introspection);
    }

    /// Chaining every builder must record each value, and adding an error flips `is_valid`.
    #[test]
    fn test_validation_result_creation() {
        let result = ValidationResult::new()
            .with_error(ValidationError::new(
                "Test error".to_string(),
                ValidationRule::MaxDepth,
            ))
            .with_warning(ValidationWarning::new("Test warning".to_string()))
            .with_complexity(100)
            .with_depth(5);

        assert!(!result.is_valid);
        assert_eq!(result.errors.len(), 1);
        assert_eq!(result.warnings.len(), 1);
        assert_eq!(result.complexity_score, 100);
        assert_eq!(result.depth, 5);
    }

    /// The validator keeps the configuration it was constructed with.
    #[test]
    fn test_query_validator_creation() {
        let config = ValidationConfig::default();
        let schema = create_test_schema();
        let validator = QueryValidator::new(config, schema);

        // Validator should be created successfully
        assert_eq!(validator.config.max_depth, 10);
    }

    /// `with_path` replaces the empty default path and leaves message and rule intact.
    #[test]
    fn test_validation_error_with_path() {
        let error = ValidationError::new("Test error".to_string(), ValidationRule::FieldValidation)
            .with_path(vec![
                "query".to_string(),
                "user".to_string(),
                "name".to_string(),
            ]);

        assert_eq!(error.message, "Test error");
        assert_eq!(error.path, vec!["query", "user", "name"]);
        assert_eq!(error.rule, ValidationRule::FieldValidation);
    }

    /// `with_suggestion` stores the suggestion alongside the original message.
    #[test]
    fn test_validation_warning_with_suggestion() {
        let warning = ValidationWarning::new("Performance warning".to_string())
            .with_suggestion("Consider using pagination".to_string());

        assert_eq!(warning.message, "Performance warning");
        assert_eq!(
            warning.suggestion,
            Some("Consider using pagination".to_string())
        );
    }

    /// Pins the default rate-limit settings.
    #[test]
    fn test_rate_limit_config_default() {
        let config = RateLimitConfig::default();
        assert_eq!(config.max_queries_per_minute, 60);
        assert_eq!(config.max_complexity_per_minute, 10000);
        assert_eq!(config.window_duration, Duration::from_secs(60));
    }

    /// A fragment registered in the context can be retrieved by name.
    #[test]
    fn test_validation_context() {
        let schema = create_test_schema();
        let mut context = ValidationContext::new(&schema);

        let fragment = crate::ast::FragmentDefinition {
            name: "TestFragment".to_string(),
            type_condition: "Query".to_string(),
            selection_set: crate::ast::SelectionSet { selections: vec![] },
            directives: vec![],
        };

        context.add_fragment("TestFragment".to_string(), fragment.clone());

        let retrieved = context.get_fragment("TestFragment");
        assert!(retrieved.is_some());
        assert_eq!(retrieved.expect("should succeed").name, "TestFragment");
    }
}