Skip to main content

heliosdb_proxy/graphql/
validation.rs

//! Query Validation
//!
//! Validates GraphQL queries against schema and complexity limits.
4
5use std::collections::{HashMap, HashSet};
6use std::sync::Arc;
7
8use super::{
9    GraphQLConfig, GraphQLSchema, GraphQLError, ErrorCode,
10    engine::ParsedDocument, engine::ParsedSelection,
11};
12
13/// Query validator
14#[derive(Debug)]
15pub struct QueryValidator {
16    /// Configuration
17    config: Arc<GraphQLConfig>,
18}
19
20impl QueryValidator {
21    /// Create a new validator
22    pub fn new(config: Arc<GraphQLConfig>) -> Self {
23        Self { config }
24    }
25
26    /// Validate a parsed document
27    pub fn validate(
28        &self,
29        document: &ParsedDocument,
30        schema: &GraphQLSchema,
31    ) -> Result<(), GraphQLError> {
32        // Check complexity
33        let complexity = self.calculate_complexity(document, schema)?;
34        if complexity.total > self.config.limits.max_complexity {
35            return Err(GraphQLError::new(
36                format!(
37                    "Query complexity {} exceeds maximum allowed {}",
38                    complexity.total, self.config.limits.max_complexity
39                ),
40                ErrorCode::QueryTooComplex,
41            ));
42        }
43
44        // Check depth
45        if complexity.max_depth > self.config.limits.max_depth {
46            return Err(GraphQLError::new(
47                format!(
48                    "Query depth {} exceeds maximum allowed {}",
49                    complexity.max_depth, self.config.limits.max_depth
50                ),
51                ErrorCode::QueryTooComplex,
52            ));
53        }
54
55        // Check alias count
56        if complexity.alias_count > self.config.limits.max_aliases {
57            return Err(GraphQLError::new(
58                format!(
59                    "Query has {} aliases, maximum allowed is {}",
60                    complexity.alias_count, self.config.limits.max_aliases
61                ),
62                ErrorCode::QueryTooComplex,
63            ));
64        }
65
66        // Check root fields
67        if document.selections.len() > self.config.limits.max_root_fields as usize {
68            return Err(GraphQLError::new(
69                format!(
70                    "Query has {} root fields, maximum allowed is {}",
71                    document.selections.len(), self.config.limits.max_root_fields
72                ),
73                ErrorCode::QueryTooComplex,
74            ));
75        }
76
77        // Validate fields against schema
78        self.validate_fields(&document.selections, "Query", schema)?;
79
80        Ok(())
81    }
82
83    /// Calculate query complexity
84    pub fn calculate_complexity(
85        &self,
86        document: &ParsedDocument,
87        _schema: &GraphQLSchema,
88    ) -> Result<ComplexityResult, GraphQLError> {
89        let mut result = ComplexityResult::default();
90
91        for selection in &document.selections {
92            self.calculate_selection_complexity(selection, 1, &mut result)?;
93        }
94
95        Ok(result)
96    }
97
98    /// Calculate complexity for a selection
99    fn calculate_selection_complexity(
100        &self,
101        selection: &ParsedSelection,
102        depth: u32,
103        result: &mut ComplexityResult,
104    ) -> Result<(), GraphQLError> {
105        // Update max depth
106        result.max_depth = result.max_depth.max(depth);
107
108        // Base cost per field
109        let mut field_cost = 1u32;
110
111        // Count alias
112        if selection.alias.is_some() {
113            result.alias_count += 1;
114        }
115
116        // Multiplier for list fields with limit
117        if let Some(limit) = selection.arguments.get("limit") {
118            if let Some(l) = limit.as_u64() {
119                field_cost = field_cost.saturating_mul(l.min(100) as u32);
120            }
121        }
122
123        // Add to total
124        result.total = result.total.saturating_add(field_cost);
125        result.field_count += 1;
126
127        // Recurse into nested selections
128        for nested in &selection.selections {
129            self.calculate_selection_complexity(nested, depth + 1, result)?;
130        }
131
132        Ok(())
133    }
134
135    /// Validate fields against schema
136    fn validate_fields(
137        &self,
138        selections: &[ParsedSelection],
139        type_name: &str,
140        schema: &GraphQLSchema,
141    ) -> Result<(), GraphQLError> {
142        // Get type from schema
143        let type_def = schema.get_type(type_name);
144
145        for selection in selections {
146            // Skip __typename and introspection fields
147            if selection.name.starts_with("__") {
148                continue;
149            }
150
151            // Check if field exists (for non-Query types)
152            if type_name != "Query" && type_name != "Mutation" {
153                if let Some(ref type_def) = type_def {
154                    let field_exists = type_def.get_field(&selection.name).is_some();
155                    let rel_exists = schema.get_relationships_for(type_name)
156                        .iter()
157                        .any(|r| r.field_name == selection.name);
158
159                    if !field_exists && !rel_exists {
160                        return Err(GraphQLError::validation_error(format!(
161                            "Field '{}' does not exist on type '{}'",
162                            selection.name, type_name
163                        )));
164                    }
165                }
166            }
167
168            // Validate nested selections
169            if !selection.selections.is_empty() {
170                // Determine the type of this field
171                let nested_type = self.get_field_type(&selection.name, type_name, schema);
172                if let Some(nested_type) = nested_type {
173                    self.validate_fields(&selection.selections, &nested_type, schema)?;
174                }
175            }
176        }
177
178        Ok(())
179    }
180
181    /// Get the type of a field
182    fn get_field_type(&self, field_name: &str, parent_type: &str, schema: &GraphQLSchema) -> Option<String> {
183        // Check direct fields
184        if let Some(type_def) = schema.get_type(parent_type) {
185            if let Some(field) = type_def.get_field(field_name) {
186                return Some(self.extract_type_name(&field.graphql_type));
187            }
188        }
189
190        // Check relationships
191        for rel in schema.get_relationships_for(parent_type) {
192            if rel.field_name == field_name {
193                return Some(rel.to_type.clone());
194            }
195        }
196
197        // For queries, the field name is often the type name
198        Some(super::to_pascal_case(field_name))
199    }
200
201    /// Extract the base type name from a field type
202    fn extract_type_name(&self, field_type: &super::introspector::FieldType) -> String {
203        use super::introspector::FieldType;
204
205        match field_type {
206            FieldType::Scalar(s) => s.to_sdl().to_string(),
207            FieldType::Object(name) => name.clone(),
208            FieldType::List(inner) => self.extract_type_name(inner),
209            FieldType::NonNull(inner) => self.extract_type_name(inner),
210        }
211    }
212}
213
/// Metrics gathered while walking a query's selections.
#[derive(Debug, Clone, Default)]
pub struct ComplexityResult {
    /// Sum of the per-field costs
    pub total: u32,
    /// Deepest nesting level seen
    pub max_depth: u32,
    /// How many selections carried an alias
    pub alias_count: u32,
    /// How many fields were visited
    pub field_count: u32,
}
226
227impl ComplexityResult {
228    /// Check if complexity is within limits
229    pub fn is_within_limits(&self, config: &GraphQLConfig) -> bool {
230        self.total <= config.limits.max_complexity
231            && self.max_depth <= config.limits.max_depth
232            && self.alias_count <= config.limits.max_aliases
233    }
234}
235
236/// Validation error
237#[derive(Debug, Clone)]
238pub struct ValidationError {
239    /// Error message
240    pub message: String,
241    /// Error locations
242    pub locations: Vec<(u32, u32)>,
243    /// Validation rule that failed
244    pub rule: ValidationRule,
245}
246
247impl ValidationError {
248    /// Create a new validation error
249    pub fn new(message: impl Into<String>, rule: ValidationRule) -> Self {
250        Self {
251            message: message.into(),
252            locations: Vec::new(),
253            rule,
254        }
255    }
256
257    /// Add a location
258    pub fn at(mut self, line: u32, column: u32) -> Self {
259        self.locations.push((line, column));
260        self
261    }
262}
263
264impl std::fmt::Display for ValidationError {
265    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
266        write!(f, "{}", self.message)
267    }
268}
269
270impl std::error::Error for ValidationError {}
271
/// Identifies which validation rule an error belongs to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ValidationRule {
    /// A selected field is not known
    UnknownField,
    /// A referenced type is not known
    UnknownType,
    /// An argument is not known
    UnknownArgument,
    /// A required argument was not supplied
    MissingArgument,
    /// An argument has the wrong type
    InvalidArgumentType,
    /// The query is too complex
    QueryTooComplex,
    /// The query is nested too deeply
    QueryTooDeep,
    /// The query uses too many aliases
    TooManyAliases,
    /// A field is duplicated
    DuplicateField,
    /// Fragments form a cycle
    FragmentCycle,
    /// A referenced fragment is not known
    UnknownFragment,
    /// A fragment spread is invalid
    InvalidFragmentSpread,
}
300
301/// Rule validator trait
302pub trait RuleValidator: Send + Sync {
303    /// Validate the rule
304    fn validate(
305        &self,
306        document: &ParsedDocument,
307        schema: &GraphQLSchema,
308    ) -> Result<(), ValidationError>;
309
310    /// Get the rule type
311    fn rule(&self) -> ValidationRule;
312}
313
314/// Unknown field validator
315pub struct UnknownFieldValidator;
316
317impl RuleValidator for UnknownFieldValidator {
318    fn validate(
319        &self,
320        document: &ParsedDocument,
321        schema: &GraphQLSchema,
322    ) -> Result<(), ValidationError> {
323        fn check_selections(
324            selections: &[ParsedSelection],
325            type_name: &str,
326            schema: &GraphQLSchema,
327        ) -> Result<(), ValidationError> {
328            for selection in selections {
329                if selection.name.starts_with("__") {
330                    continue;
331                }
332
333                if type_name != "Query" && type_name != "Mutation" {
334                    if let Some(type_def) = schema.get_type(type_name) {
335                        if type_def.get_field(&selection.name).is_none() {
336                            return Err(ValidationError::new(
337                                format!("Unknown field '{}' on type '{}'", selection.name, type_name),
338                                ValidationRule::UnknownField,
339                            ));
340                        }
341                    }
342                }
343
344                // Recurse (would need proper type resolution)
345                if !selection.selections.is_empty() {
346                    check_selections(&selection.selections, &selection.name, schema)?;
347                }
348            }
349            Ok(())
350        }
351
352        check_selections(&document.selections, "Query", schema)
353    }
354
355    fn rule(&self) -> ValidationRule {
356        ValidationRule::UnknownField
357    }
358}
359
360/// Depth validator
361pub struct DepthValidator {
362    max_depth: u32,
363}
364
365impl DepthValidator {
366    pub fn new(max_depth: u32) -> Self {
367        Self { max_depth }
368    }
369}
370
371impl RuleValidator for DepthValidator {
372    fn validate(
373        &self,
374        document: &ParsedDocument,
375        _schema: &GraphQLSchema,
376    ) -> Result<(), ValidationError> {
377        fn check_depth(selections: &[ParsedSelection], current_depth: u32, max_depth: u32) -> Result<(), ValidationError> {
378            if current_depth > max_depth {
379                return Err(ValidationError::new(
380                    format!("Query depth {} exceeds maximum {}", current_depth, max_depth),
381                    ValidationRule::QueryTooDeep,
382                ));
383            }
384
385            for selection in selections {
386                check_depth(&selection.selections, current_depth + 1, max_depth)?;
387            }
388
389            Ok(())
390        }
391
392        check_depth(&document.selections, 1, self.max_depth)
393    }
394
395    fn rule(&self) -> ValidationRule {
396        ValidationRule::QueryTooDeep
397    }
398}
399
400/// Alias count validator
401pub struct AliasValidator {
402    max_aliases: u32,
403}
404
405impl AliasValidator {
406    pub fn new(max_aliases: u32) -> Self {
407        Self { max_aliases }
408    }
409}
410
411impl RuleValidator for AliasValidator {
412    fn validate(
413        &self,
414        document: &ParsedDocument,
415        _schema: &GraphQLSchema,
416    ) -> Result<(), ValidationError> {
417        fn count_aliases(selections: &[ParsedSelection]) -> u32 {
418            let mut count = 0;
419            for selection in selections {
420                if selection.alias.is_some() {
421                    count += 1;
422                }
423                count += count_aliases(&selection.selections);
424            }
425            count
426        }
427
428        let alias_count = count_aliases(&document.selections);
429        if alias_count > self.max_aliases {
430            return Err(ValidationError::new(
431                format!("Query has {} aliases, maximum is {}", alias_count, self.max_aliases),
432                ValidationRule::TooManyAliases,
433            ));
434        }
435
436        Ok(())
437    }
438
439    fn rule(&self) -> ValidationRule {
440        ValidationRule::TooManyAliases
441    }
442}
443
444/// Validation pipeline
445#[derive(Default)]
446pub struct ValidationPipeline {
447    validators: Vec<Box<dyn RuleValidator>>,
448}
449
450impl ValidationPipeline {
451    /// Create a new pipeline
452    pub fn new() -> Self {
453        Self::default()
454    }
455
456    /// Add a validator
457    pub fn add(mut self, validator: impl RuleValidator + 'static) -> Self {
458        self.validators.push(Box::new(validator));
459        self
460    }
461
462    /// Create default validation pipeline
463    pub fn default_pipeline(config: &GraphQLConfig) -> Self {
464        Self::new()
465            .add(UnknownFieldValidator)
466            .add(DepthValidator::new(config.limits.max_depth))
467            .add(AliasValidator::new(config.limits.max_aliases))
468    }
469
470    /// Run all validators
471    pub fn validate(
472        &self,
473        document: &ParsedDocument,
474        schema: &GraphQLSchema,
475    ) -> Result<(), Vec<ValidationError>> {
476        let mut errors = Vec::new();
477
478        for validator in &self.validators {
479            if let Err(e) = validator.validate(document, schema) {
480                errors.push(e);
481            }
482        }
483
484        if errors.is_empty() {
485            Ok(())
486        } else {
487            Err(errors)
488        }
489    }
490}
491
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graphql::{OperationType, introspector::*};

    /// Schema containing a single `User` type with `id` and `name` fields.
    fn create_test_schema() -> GraphQLSchema {
        let mut user_type = GraphQLType::new("User");
        user_type.add_field(GraphQLField::new("id", FieldType::scalar(crate::graphql::GraphQLScalar::ID)));
        user_type.add_field(GraphQLField::new("name", FieldType::scalar(crate::graphql::GraphQLScalar::String)));

        let mut schema = GraphQLSchema::new();
        schema.add_type(user_type);
        schema
    }

    /// Anonymous query document wrapping the given root selections.
    fn create_test_document(selections: Vec<ParsedSelection>) -> ParsedDocument {
        ParsedDocument {
            operation_type: OperationType::Query,
            operation_name: None,
            selections,
            variable_definitions: Vec::new(),
            fragments: HashMap::new(),
        }
    }

    /// Selection with no arguments and no directives.
    fn field(name: &str, alias: Option<&str>, children: Vec<ParsedSelection>) -> ParsedSelection {
        ParsedSelection {
            name: name.to_string(),
            alias: alias.map(String::from),
            arguments: HashMap::new(),
            selections: children,
            directives: vec![],
        }
    }

    /// Validator built from the default configuration.
    fn default_validator() -> QueryValidator {
        QueryValidator::new(Arc::new(GraphQLConfig::default()))
    }

    #[test]
    fn test_complexity_calculation() {
        let validator = default_validator();
        let schema = create_test_schema();

        let document = create_test_document(vec![field(
            "users",
            None,
            vec![field("id", None, vec![]), field("name", None, vec![])],
        )]);

        let result = validator.calculate_complexity(&document, &schema).unwrap();

        assert_eq!(result.field_count, 3); // users, id, name
        assert_eq!(result.max_depth, 2);
        assert_eq!(result.alias_count, 0);
    }

    #[test]
    fn test_complexity_with_limit() {
        let validator = default_validator();
        let schema = create_test_schema();

        let mut args = HashMap::new();
        args.insert("limit".to_string(), serde_json::json!(10));

        let users = ParsedSelection {
            arguments: args,
            ..field("users", None, vec![])
        };
        let document = create_test_document(vec![users]);

        let result = validator.calculate_complexity(&document, &schema).unwrap();

        // With limit of 10, complexity should be multiplied
        assert_eq!(result.total, 10);
    }

    #[test]
    fn test_alias_counting() {
        let validator = default_validator();
        let schema = create_test_schema();

        let document = create_test_document(vec![field(
            "users",
            Some("allUsers"),
            vec![field("id", Some("userId"), vec![])],
        )]);

        let result = validator.calculate_complexity(&document, &schema).unwrap();

        assert_eq!(result.alias_count, 2);
    }

    #[test]
    fn test_depth_validator() {
        let validator = DepthValidator::new(2);
        let schema = create_test_schema();

        // Depth 1 - should pass
        let shallow = create_test_document(vec![field("users", None, vec![])]);
        assert!(validator.validate(&shallow, &schema).is_ok());

        // Depth 3 - should fail
        let comments = field("comments", None, vec![]);
        let posts = field("posts", None, vec![comments]);
        let deep = create_test_document(vec![field("users", None, vec![posts])]);
        assert!(validator.validate(&deep, &schema).is_err());
    }

    #[test]
    fn test_alias_validator() {
        let validator = AliasValidator::new(2);
        let schema = create_test_schema();

        // 2 aliases - should pass
        let within_limit = create_test_document(vec![field(
            "users",
            Some("a1"),
            vec![field("id", Some("a2"), vec![])],
        )]);
        assert!(validator.validate(&within_limit, &schema).is_ok());

        // 3 aliases - should fail
        let exceeds_limit = create_test_document(vec![field(
            "users",
            Some("a1"),
            vec![
                field("id", Some("a2"), vec![]),
                field("name", Some("a3"), vec![]),
            ],
        )]);
        assert!(validator.validate(&exceeds_limit, &schema).is_err());
    }

    #[test]
    fn test_validation_pipeline() {
        let config = GraphQLConfig::default();
        let pipeline = ValidationPipeline::default_pipeline(&config);
        let schema = create_test_schema();

        let document = create_test_document(vec![field("users", None, vec![])]);

        assert!(pipeline.validate(&document, &schema).is_ok());
    }

    #[test]
    fn test_complexity_result_within_limits() {
        let config = GraphQLConfig::default();

        let within = ComplexityResult {
            total: 100,
            max_depth: 5,
            alias_count: 2,
            field_count: 10,
        };
        assert!(within.is_within_limits(&config));

        // Same shape, but with a total far above the first case.
        let exceeds = ComplexityResult {
            total: 10000,
            ..within.clone()
        };
        assert!(!exceeds.is_within_limits(&config));
    }

    #[test]
    fn test_validation_error() {
        let err = ValidationError::new("Test error", ValidationRule::UnknownField)
            .at(1, 10)
            .at(2, 5);

        assert_eq!(err.message, "Test error");
        assert_eq!(err.locations.len(), 2);
        assert_eq!(err.rule, ValidationRule::UnknownField);
    }
}