// heliosdb_proxy/graphql/validation.rs

//! Query Validation
//!
//! Validates GraphQL queries against schema and complexity limits.

use std::sync::Arc;

use super::{
    engine::ParsedDocument, engine::ParsedSelection, ErrorCode, GraphQLConfig, GraphQLError,
    GraphQLSchema,
};

12/// Query validator
13#[derive(Debug)]
14pub struct QueryValidator {
15    /// Configuration
16    config: Arc<GraphQLConfig>,
17}
18
19impl QueryValidator {
20    /// Create a new validator
21    pub fn new(config: Arc<GraphQLConfig>) -> Self {
22        Self { config }
23    }
24
25    /// Validate a parsed document
26    #[allow(clippy::result_large_err)]
27    pub fn validate(
28        &self,
29        document: &ParsedDocument,
30        schema: &GraphQLSchema,
31    ) -> Result<(), GraphQLError> {
32        // Check complexity
33        let complexity = self.calculate_complexity(document, schema)?;
34        if complexity.total > self.config.limits.max_complexity {
35            return Err(GraphQLError::new(
36                format!(
37                    "Query complexity {} exceeds maximum allowed {}",
38                    complexity.total, self.config.limits.max_complexity
39                ),
40                ErrorCode::QueryTooComplex,
41            ));
42        }
43
44        // Check depth
45        if complexity.max_depth > self.config.limits.max_depth {
46            return Err(GraphQLError::new(
47                format!(
48                    "Query depth {} exceeds maximum allowed {}",
49                    complexity.max_depth, self.config.limits.max_depth
50                ),
51                ErrorCode::QueryTooComplex,
52            ));
53        }
54
55        // Check alias count
56        if complexity.alias_count > self.config.limits.max_aliases {
57            return Err(GraphQLError::new(
58                format!(
59                    "Query has {} aliases, maximum allowed is {}",
60                    complexity.alias_count, self.config.limits.max_aliases
61                ),
62                ErrorCode::QueryTooComplex,
63            ));
64        }
65
66        // Check root fields
67        if document.selections.len() > self.config.limits.max_root_fields as usize {
68            return Err(GraphQLError::new(
69                format!(
70                    "Query has {} root fields, maximum allowed is {}",
71                    document.selections.len(),
72                    self.config.limits.max_root_fields
73                ),
74                ErrorCode::QueryTooComplex,
75            ));
76        }
77
78        // Validate fields against schema
79        self.validate_fields(&document.selections, "Query", schema)?;
80
81        Ok(())
82    }
83
84    /// Calculate query complexity
85    #[allow(clippy::result_large_err)]
86    pub fn calculate_complexity(
87        &self,
88        document: &ParsedDocument,
89        _schema: &GraphQLSchema,
90    ) -> Result<ComplexityResult, GraphQLError> {
91        let mut result = ComplexityResult::default();
92
93        for selection in &document.selections {
94            self.calculate_selection_complexity(selection, 1, &mut result)?;
95        }
96
97        Ok(result)
98    }
99
100    /// Calculate complexity for a selection
101    #[allow(clippy::result_large_err)]
102    fn calculate_selection_complexity(
103        &self,
104        selection: &ParsedSelection,
105        depth: u32,
106        result: &mut ComplexityResult,
107    ) -> Result<(), GraphQLError> {
108        // Update max depth
109        result.max_depth = result.max_depth.max(depth);
110
111        // Base cost per field
112        let mut field_cost = 1u32;
113
114        // Count alias
115        if selection.alias.is_some() {
116            result.alias_count += 1;
117        }
118
119        // Multiplier for list fields with limit
120        if let Some(limit) = selection.arguments.get("limit") {
121            if let Some(l) = limit.as_u64() {
122                field_cost = field_cost.saturating_mul(l.min(100) as u32);
123            }
124        }
125
126        // Add to total
127        result.total = result.total.saturating_add(field_cost);
128        result.field_count += 1;
129
130        // Recurse into nested selections
131        for nested in &selection.selections {
132            self.calculate_selection_complexity(nested, depth + 1, result)?;
133        }
134
135        Ok(())
136    }
137
138    /// Validate fields against schema
139    #[allow(clippy::result_large_err)]
140    fn validate_fields(
141        &self,
142        selections: &[ParsedSelection],
143        type_name: &str,
144        schema: &GraphQLSchema,
145    ) -> Result<(), GraphQLError> {
146        // Get type from schema
147        let type_def = schema.get_type(type_name);
148
149        for selection in selections {
150            // Skip __typename and introspection fields
151            if selection.name.starts_with("__") {
152                continue;
153            }
154
155            // Check if field exists (for non-Query types)
156            if type_name != "Query" && type_name != "Mutation" {
157                if let Some(type_def) = type_def {
158                    let field_exists = type_def.get_field(&selection.name).is_some();
159                    let rel_exists = schema
160                        .get_relationships_for(type_name)
161                        .iter()
162                        .any(|r| r.field_name == selection.name);
163
164                    if !field_exists && !rel_exists {
165                        return Err(GraphQLError::validation_error(format!(
166                            "Field '{}' does not exist on type '{}'",
167                            selection.name, type_name
168                        )));
169                    }
170                }
171            }
172
173            // Validate nested selections
174            if !selection.selections.is_empty() {
175                // Determine the type of this field
176                let nested_type = self.get_field_type(&selection.name, type_name, schema);
177                if let Some(nested_type) = nested_type {
178                    self.validate_fields(&selection.selections, &nested_type, schema)?;
179                }
180            }
181        }
182
183        Ok(())
184    }
185
186    /// Get the type of a field
187    fn get_field_type(
188        &self,
189        field_name: &str,
190        parent_type: &str,
191        schema: &GraphQLSchema,
192    ) -> Option<String> {
193        // Check direct fields
194        if let Some(type_def) = schema.get_type(parent_type) {
195            if let Some(field) = type_def.get_field(field_name) {
196                return Some(self.extract_type_name(&field.graphql_type));
197            }
198        }
199
200        // Check relationships
201        for rel in schema.get_relationships_for(parent_type) {
202            if rel.field_name == field_name {
203                return Some(rel.to_type.clone());
204            }
205        }
206
207        // For queries, the field name is often the type name
208        Some(super::to_pascal_case(field_name))
209    }
210
211    /// Extract the base type name from a field type
212    fn extract_type_name(&self, field_type: &super::introspector::FieldType) -> String {
213        use super::introspector::FieldType;
214
215        match field_type {
216            FieldType::Scalar(s) => s.to_sdl().to_string(),
217            FieldType::Object(name) => name.clone(),
218            FieldType::List(inner) => self.extract_type_name(inner),
219            FieldType::NonNull(inner) => self.extract_type_name(inner),
220        }
221    }
222}
223
/// Measurements produced while walking a query's selections.
#[derive(Debug, Clone, Default)]
pub struct ComplexityResult {
    /// Summed per-field cost.
    pub total: u32,
    /// Deepest nesting level reached, with root fields at level 1.
    pub max_depth: u32,
    /// How many selections carried an alias.
    pub alias_count: u32,
    /// How many selections were visited.
    pub field_count: u32,
}

237impl ComplexityResult {
238    /// Check if complexity is within limits
239    pub fn is_within_limits(&self, config: &GraphQLConfig) -> bool {
240        self.total <= config.limits.max_complexity
241            && self.max_depth <= config.limits.max_depth
242            && self.alias_count <= config.limits.max_aliases
243    }
244}
245
246/// Validation error
247#[derive(Debug, Clone)]
248pub struct ValidationError {
249    /// Error message
250    pub message: String,
251    /// Error locations
252    pub locations: Vec<(u32, u32)>,
253    /// Validation rule that failed
254    pub rule: ValidationRule,
255}
256
257impl ValidationError {
258    /// Create a new validation error
259    pub fn new(message: impl Into<String>, rule: ValidationRule) -> Self {
260        Self {
261            message: message.into(),
262            locations: Vec::new(),
263            rule,
264        }
265    }
266
267    /// Add a location
268    pub fn at(mut self, line: u32, column: u32) -> Self {
269        self.locations.push((line, column));
270        self
271    }
272}
273
274impl std::fmt::Display for ValidationError {
275    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
276        write!(f, "{}", self.message)
277    }
278}
279
280impl std::error::Error for ValidationError {}
281
/// Identifies which validation rule a [`ValidationError`] came from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ValidationRule {
    /// A selected field is not defined on its parent type.
    UnknownField,
    /// A referenced type is not defined in the schema.
    UnknownType,
    /// An argument is not accepted by the field.
    UnknownArgument,
    /// A required argument was not supplied.
    MissingArgument,
    /// An argument's value has the wrong type.
    InvalidArgumentType,
    /// The query's computed cost is over the limit.
    QueryTooComplex,
    /// The query nests selections too deeply.
    QueryTooDeep,
    /// The query uses more aliases than allowed.
    TooManyAliases,
    /// The same field is selected more than once in a conflicting way.
    DuplicateField,
    /// Fragments spread into each other in a cycle.
    FragmentCycle,
    /// A spread names a fragment that is not defined.
    UnknownFragment,
    /// A fragment is spread where its type condition cannot apply.
    InvalidFragmentSpread,
}

311/// Rule validator trait
312pub trait RuleValidator: Send + Sync {
313    /// Validate the rule
314    fn validate(
315        &self,
316        document: &ParsedDocument,
317        schema: &GraphQLSchema,
318    ) -> Result<(), ValidationError>;
319
320    /// Get the rule type
321    fn rule(&self) -> ValidationRule;
322}
323
/// Rule that rejects selections naming a field their parent type does not define.
///
/// Stateless, so it derives the common value traits (it can be cloned, compared by
/// debug output, and default-constructed).
#[derive(Debug, Clone, Copy, Default)]
pub struct UnknownFieldValidator;

327impl RuleValidator for UnknownFieldValidator {
328    fn validate(
329        &self,
330        document: &ParsedDocument,
331        schema: &GraphQLSchema,
332    ) -> Result<(), ValidationError> {
333        fn check_selections(
334            selections: &[ParsedSelection],
335            type_name: &str,
336            schema: &GraphQLSchema,
337        ) -> Result<(), ValidationError> {
338            for selection in selections {
339                if selection.name.starts_with("__") {
340                    continue;
341                }
342
343                if type_name != "Query" && type_name != "Mutation" {
344                    if let Some(type_def) = schema.get_type(type_name) {
345                        if type_def.get_field(&selection.name).is_none() {
346                            return Err(ValidationError::new(
347                                format!(
348                                    "Unknown field '{}' on type '{}'",
349                                    selection.name, type_name
350                                ),
351                                ValidationRule::UnknownField,
352                            ));
353                        }
354                    }
355                }
356
357                // Recurse (would need proper type resolution)
358                if !selection.selections.is_empty() {
359                    check_selections(&selection.selections, &selection.name, schema)?;
360                }
361            }
362            Ok(())
363        }
364
365        check_selections(&document.selections, "Query", schema)
366    }
367
368    fn rule(&self) -> ValidationRule {
369        ValidationRule::UnknownField
370    }
371}
372
/// Rule that rejects documents whose selections nest deeper than `max_depth`.
///
/// Derives the common value traits so the validator can be logged, copied and compared.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DepthValidator {
    /// Deepest allowed nesting level; root fields are at level 1.
    max_depth: u32,
}

378impl DepthValidator {
379    pub fn new(max_depth: u32) -> Self {
380        Self { max_depth }
381    }
382}
383
384impl RuleValidator for DepthValidator {
385    fn validate(
386        &self,
387        document: &ParsedDocument,
388        _schema: &GraphQLSchema,
389    ) -> Result<(), ValidationError> {
390        fn check_depth(
391            selections: &[ParsedSelection],
392            current_depth: u32,
393            max_depth: u32,
394        ) -> Result<(), ValidationError> {
395            if current_depth > max_depth {
396                return Err(ValidationError::new(
397                    format!(
398                        "Query depth {} exceeds maximum {}",
399                        current_depth, max_depth
400                    ),
401                    ValidationRule::QueryTooDeep,
402                ));
403            }
404
405            for selection in selections {
406                check_depth(&selection.selections, current_depth + 1, max_depth)?;
407            }
408
409            Ok(())
410        }
411
412        check_depth(&document.selections, 1, self.max_depth)
413    }
414
415    fn rule(&self) -> ValidationRule {
416        ValidationRule::QueryTooDeep
417    }
418}
419
/// Rule that rejects documents using more than `max_aliases` field aliases in total.
///
/// Derives the common value traits so the validator can be logged, copied and compared.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AliasValidator {
    /// Largest allowed number of aliased selections across the whole document.
    max_aliases: u32,
}

425impl AliasValidator {
426    pub fn new(max_aliases: u32) -> Self {
427        Self { max_aliases }
428    }
429}
430
431impl RuleValidator for AliasValidator {
432    fn validate(
433        &self,
434        document: &ParsedDocument,
435        _schema: &GraphQLSchema,
436    ) -> Result<(), ValidationError> {
437        fn count_aliases(selections: &[ParsedSelection]) -> u32 {
438            let mut count = 0;
439            for selection in selections {
440                if selection.alias.is_some() {
441                    count += 1;
442                }
443                count += count_aliases(&selection.selections);
444            }
445            count
446        }
447
448        let alias_count = count_aliases(&document.selections);
449        if alias_count > self.max_aliases {
450            return Err(ValidationError::new(
451                format!(
452                    "Query has {} aliases, maximum is {}",
453                    alias_count, self.max_aliases
454                ),
455                ValidationRule::TooManyAliases,
456            ));
457        }
458
459        Ok(())
460    }
461
462    fn rule(&self) -> ValidationRule {
463        ValidationRule::TooManyAliases
464    }
465}
466
467/// Validation pipeline
468#[derive(Default)]
469pub struct ValidationPipeline {
470    validators: Vec<Box<dyn RuleValidator>>,
471}
472
473impl ValidationPipeline {
474    /// Create a new pipeline
475    pub fn new() -> Self {
476        Self::default()
477    }
478
479    /// Add a validator
480    #[allow(clippy::should_implement_trait)]
481    pub fn add(mut self, validator: impl RuleValidator + 'static) -> Self {
482        self.validators.push(Box::new(validator));
483        self
484    }
485
486    /// Create default validation pipeline
487    pub fn default_pipeline(config: &GraphQLConfig) -> Self {
488        Self::new()
489            .add(UnknownFieldValidator)
490            .add(DepthValidator::new(config.limits.max_depth))
491            .add(AliasValidator::new(config.limits.max_aliases))
492    }
493
494    /// Run all validators
495    pub fn validate(
496        &self,
497        document: &ParsedDocument,
498        schema: &GraphQLSchema,
499    ) -> Result<(), Vec<ValidationError>> {
500        let mut errors = Vec::new();
501
502        for validator in &self.validators {
503            if let Err(e) = validator.validate(document, schema) {
504                errors.push(e);
505            }
506        }
507
508        if errors.is_empty() {
509            Ok(())
510        } else {
511            Err(errors)
512        }
513    }
514}
515
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graphql::{introspector::*, OperationType};
    use std::collections::HashMap;

    /// Schema containing one `User` type with scalar `id` and `name` fields.
    fn create_test_schema() -> GraphQLSchema {
        let mut user_type = GraphQLType::new("User");
        user_type.add_field(GraphQLField::new(
            "id",
            FieldType::scalar(crate::graphql::GraphQLScalar::ID),
        ));
        user_type.add_field(GraphQLField::new(
            "name",
            FieldType::scalar(crate::graphql::GraphQLScalar::String),
        ));

        let mut schema = GraphQLSchema::new();
        schema.add_type(user_type);
        schema
    }

    /// Wrap root selections in an anonymous query operation.
    fn create_test_document(selections: Vec<ParsedSelection>) -> ParsedDocument {
        ParsedDocument {
            operation_type: OperationType::Query,
            operation_name: None,
            selections,
            variable_definitions: Vec::new(),
            fragments: HashMap::new(),
        }
    }

    /// Selection with the given name, optional alias and children; no arguments or directives.
    fn make_selection(
        name: &str,
        alias: Option<&str>,
        selections: Vec<ParsedSelection>,
    ) -> ParsedSelection {
        ParsedSelection {
            name: name.to_string(),
            alias: alias.map(String::from),
            arguments: HashMap::new(),
            selections,
            directives: vec![],
        }
    }

    /// Query validator built from the default configuration.
    fn default_validator() -> QueryValidator {
        QueryValidator::new(Arc::new(GraphQLConfig::default()))
    }

    #[test]
    fn test_complexity_calculation() {
        let document = create_test_document(vec![make_selection(
            "users",
            None,
            vec![
                make_selection("id", None, vec![]),
                make_selection("name", None, vec![]),
            ],
        )]);

        let result = default_validator()
            .calculate_complexity(&document, &create_test_schema())
            .unwrap();

        assert_eq!(result.field_count, 3); // users, id, name
        assert_eq!(result.max_depth, 2);
        assert_eq!(result.alias_count, 0);
    }

    #[test]
    fn test_complexity_with_limit() {
        let mut users = make_selection("users", None, vec![]);
        users
            .arguments
            .insert("limit".to_string(), serde_json::json!(10));
        let document = create_test_document(vec![users]);

        let result = default_validator()
            .calculate_complexity(&document, &create_test_schema())
            .unwrap();

        // A `limit` of 10 replaces the field's unit cost.
        assert_eq!(result.total, 10);
    }

    #[test]
    fn test_alias_counting() {
        let document = create_test_document(vec![make_selection(
            "users",
            Some("allUsers"),
            vec![make_selection("id", Some("userId"), vec![])],
        )]);

        let result = default_validator()
            .calculate_complexity(&document, &create_test_schema())
            .unwrap();

        assert_eq!(result.alias_count, 2);
    }

    #[test]
    fn test_depth_validator() {
        let validator = DepthValidator::new(2);
        let schema = create_test_schema();

        // Depth 1 - should pass
        let shallow = create_test_document(vec![make_selection("users", None, vec![])]);
        assert!(validator.validate(&shallow, &schema).is_ok());

        // Depth 3 - should fail
        let deep = create_test_document(vec![make_selection(
            "users",
            None,
            vec![make_selection(
                "posts",
                None,
                vec![make_selection("comments", None, vec![])],
            )],
        )]);
        assert!(validator.validate(&deep, &schema).is_err());
    }

    #[test]
    fn test_alias_validator() {
        let validator = AliasValidator::new(2);
        let schema = create_test_schema();

        // 2 aliases - should pass
        let within_limit = create_test_document(vec![make_selection(
            "users",
            Some("a1"),
            vec![make_selection("id", Some("a2"), vec![])],
        )]);
        assert!(validator.validate(&within_limit, &schema).is_ok());

        // 3 aliases - should fail
        let exceeds_limit = create_test_document(vec![make_selection(
            "users",
            Some("a1"),
            vec![
                make_selection("id", Some("a2"), vec![]),
                make_selection("name", Some("a3"), vec![]),
            ],
        )]);
        assert!(validator.validate(&exceeds_limit, &schema).is_err());
    }

    #[test]
    fn test_validation_pipeline() {
        let pipeline = ValidationPipeline::default_pipeline(&GraphQLConfig::default());
        let document = create_test_document(vec![make_selection("users", None, vec![])]);

        assert!(pipeline.validate(&document, &create_test_schema()).is_ok());
    }

    #[test]
    fn test_complexity_result_within_limits() {
        let config = GraphQLConfig::default();

        let within = ComplexityResult {
            total: 100,
            max_depth: 5,
            alias_count: 2,
            field_count: 10,
        };
        assert!(within.is_within_limits(&config));

        let exceeds = ComplexityResult {
            total: 10000,
            ..within.clone()
        };
        assert!(!exceeds.is_within_limits(&config));
    }

    #[test]
    fn test_validation_error() {
        let err = ValidationError::new("Test error", ValidationRule::UnknownField)
            .at(1, 10)
            .at(2, 5);

        assert_eq!(err.message, "Test error");
        assert_eq!(err.locations.len(), 2);
        assert_eq!(err.rule, ValidationRule::UnknownField);
    }
}