Skip to main content

fraiseql_server/
validation.rs

1//! GraphQL request validation module.
2//!
3//! Provides validation for GraphQL queries including:
4//! - Query depth validation (prevent deeply nested queries)
5//! - Query complexity scoring (prevent complex queries)
6//! - Variables shape validation (variables, when present, must be a JSON object)
7//!
8//! # Security
9//!
10//! Uses AST-based validation via `graphql-parser` to correctly handle:
11//! - Fragment spreads (which expand to arbitrary depth)
12//! - Inline fragments
13//! - Aliases and multiple operations
14//! - Pagination arguments that multiply result cardinality
15
use std::collections::HashMap;

use graphql_parser::query::{
    Definition, Document, FragmentDefinition, OperationDefinition, Selection, SelectionSet,
};
use serde_json::Value as JsonValue;
use thiserror::Error;
21
22/// Validation error types.
23#[derive(Debug, Error, Clone)]
24pub enum ValidationError {
25    /// Query exceeds maximum allowed depth.
26    #[error("Query exceeds maximum depth of {max_depth}: depth = {actual_depth}")]
27    QueryTooDeep {
28        /// Maximum allowed depth
29        max_depth:    usize,
30        /// Actual query depth
31        actual_depth: usize,
32    },
33
34    /// Query exceeds maximum complexity score.
35    #[error("Query exceeds maximum complexity of {max_complexity}: score = {actual_complexity}")]
36    QueryTooComplex {
37        /// Maximum allowed complexity
38        max_complexity:    usize,
39        /// Actual query complexity
40        actual_complexity: usize,
41    },
42
43    /// Invalid query variables.
44    #[error("Invalid variables: {0}")]
45    InvalidVariables(String),
46
47    /// Malformed GraphQL query.
48    #[error("Malformed GraphQL query: {0}")]
49    MalformedQuery(String),
50}
51
/// Validator applied to incoming GraphQL requests before execution.
///
/// Build one with [`RequestValidator::new`] (or `Default`) and adjust it with the
/// `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct RequestValidator {
    /// Largest query depth that is accepted.
    max_depth: usize,
    /// Largest complexity score that is accepted.
    max_complexity: usize,
    /// Whether the depth limit is enforced.
    validate_depth: bool,
    /// Whether the complexity limit is enforced.
    validate_complexity: bool,
}
64
65impl RequestValidator {
66    /// Create a new validator with default settings.
67    #[must_use]
68    pub fn new() -> Self {
69        Self::default()
70    }
71
72    /// Set maximum query depth.
73    #[must_use]
74    pub fn with_max_depth(mut self, max_depth: usize) -> Self {
75        self.max_depth = max_depth;
76        self
77    }
78
79    /// Set maximum query complexity.
80    #[must_use]
81    pub fn with_max_complexity(mut self, max_complexity: usize) -> Self {
82        self.max_complexity = max_complexity;
83        self
84    }
85
86    /// Enable/disable depth validation.
87    #[must_use]
88    pub fn with_depth_validation(mut self, enabled: bool) -> Self {
89        self.validate_depth = enabled;
90        self
91    }
92
93    /// Enable/disable complexity validation.
94    #[must_use]
95    pub fn with_complexity_validation(mut self, enabled: bool) -> Self {
96        self.validate_complexity = enabled;
97        self
98    }
99
100    /// Validate a GraphQL query string.
101    ///
102    /// # Errors
103    ///
104    /// Returns `ValidationError` if the query violates any validation rules.
105    pub fn validate_query(&self, query: &str) -> Result<(), ValidationError> {
106        // Validate query is not empty
107        if query.trim().is_empty() {
108            return Err(ValidationError::MalformedQuery("Empty query".to_string()));
109        }
110
111        // Skip AST parsing if both validations are disabled
112        if !self.validate_depth && !self.validate_complexity {
113            return Ok(());
114        }
115
116        // Parse the GraphQL query into an AST
117        let document = graphql_parser::parse_query::<String>(query)
118            .map_err(|e| ValidationError::MalformedQuery(format!("{e}")))?;
119
120        // Collect fragment definitions for resolving fragment spreads
121        let fragments: Vec<&FragmentDefinition<String>> = document
122            .definitions
123            .iter()
124            .filter_map(|def| {
125                if let Definition::Fragment(f) = def {
126                    Some(f)
127                } else {
128                    None
129                }
130            })
131            .collect();
132
133        // Check depth if enabled
134        if self.validate_depth {
135            let depth = self.calculate_depth_ast(&document, &fragments);
136            if depth > self.max_depth {
137                return Err(ValidationError::QueryTooDeep {
138                    max_depth:    self.max_depth,
139                    actual_depth: depth,
140                });
141            }
142        }
143
144        // Check complexity if enabled
145        if self.validate_complexity {
146            let complexity = self.calculate_complexity_ast(&document, &fragments);
147            if complexity > self.max_complexity {
148                return Err(ValidationError::QueryTooComplex {
149                    max_complexity:    self.max_complexity,
150                    actual_complexity: complexity,
151                });
152            }
153        }
154
155        Ok(())
156    }
157
158    /// Validate variables JSON.
159    ///
160    /// # Errors
161    ///
162    /// Returns `ValidationError` if variables are invalid.
163    pub fn validate_variables(&self, variables: Option<&JsonValue>) -> Result<(), ValidationError> {
164        if let Some(vars) = variables {
165            if !vars.is_object() {
166                return Err(ValidationError::InvalidVariables(
167                    "Variables must be an object".to_string(),
168                ));
169            }
170        }
171
172        Ok(())
173    }
174
175    /// Calculate query depth using AST walking.
176    ///
177    /// Correctly handles fragment spreads, inline fragments, and nested selections.
178    fn calculate_depth_ast(
179        &self,
180        document: &Document<String>,
181        fragments: &[&FragmentDefinition<String>],
182    ) -> usize {
183        let mut max_depth = 0;
184
185        for definition in &document.definitions {
186            let depth = match definition {
187                Definition::Operation(op) => match op {
188                    OperationDefinition::Query(q) => {
189                        self.selection_set_depth(&q.selection_set, fragments, 0)
190                    },
191                    OperationDefinition::Mutation(m) => {
192                        self.selection_set_depth(&m.selection_set, fragments, 0)
193                    },
194                    OperationDefinition::Subscription(s) => {
195                        self.selection_set_depth(&s.selection_set, fragments, 0)
196                    },
197                    OperationDefinition::SelectionSet(ss) => {
198                        self.selection_set_depth(ss, fragments, 0)
199                    },
200                },
201                Definition::Fragment(f) => {
202                    // Fragment definitions are walked when referenced
203                    self.selection_set_depth(&f.selection_set, fragments, 0)
204                },
205            };
206            max_depth = max_depth.max(depth);
207        }
208
209        max_depth
210    }
211
212    /// Recursively calculate depth of a selection set.
213    fn selection_set_depth(
214        &self,
215        selection_set: &SelectionSet<String>,
216        fragments: &[&FragmentDefinition<String>],
217        recursion_depth: usize,
218    ) -> usize {
219        // Prevent infinite recursion from circular fragment references
220        if recursion_depth > 32 {
221            return self.max_depth + 1;
222        }
223
224        if selection_set.items.is_empty() {
225            return 0;
226        }
227
228        let mut max_child_depth = 0;
229
230        for selection in &selection_set.items {
231            let child_depth = match selection {
232                Selection::Field(field) => {
233                    if field.selection_set.items.is_empty() {
234                        0
235                    } else {
236                        self.selection_set_depth(&field.selection_set, fragments, recursion_depth)
237                    }
238                },
239                Selection::InlineFragment(inline) => {
240                    self.selection_set_depth(&inline.selection_set, fragments, recursion_depth)
241                },
242                Selection::FragmentSpread(spread) => {
243                    // Find the fragment definition and calculate its depth
244                    if let Some(frag) = fragments.iter().find(|f| f.name == spread.fragment_name) {
245                        self.selection_set_depth(
246                            &frag.selection_set,
247                            fragments,
248                            recursion_depth + 1,
249                        )
250                    } else {
251                        // Unknown fragment: be conservative
252                        self.max_depth
253                    }
254                },
255            };
256            max_child_depth = max_child_depth.max(child_depth);
257        }
258
259        1 + max_child_depth
260    }
261
262    /// Calculate query complexity using AST walking.
263    ///
264    /// Each field adds 1 to complexity. Fields with nested selections (list fields)
265    /// multiply the nested cost. Fragment spreads are resolved and counted.
266    fn calculate_complexity_ast(
267        &self,
268        document: &Document<String>,
269        fragments: &[&FragmentDefinition<String>],
270    ) -> usize {
271        let mut total = 0;
272
273        for definition in &document.definitions {
274            let cost = match definition {
275                Definition::Operation(op) => match op {
276                    OperationDefinition::Query(q) => {
277                        self.selection_set_complexity(&q.selection_set, fragments, 0)
278                    },
279                    OperationDefinition::Mutation(m) => {
280                        self.selection_set_complexity(&m.selection_set, fragments, 0)
281                    },
282                    OperationDefinition::Subscription(s) => {
283                        self.selection_set_complexity(&s.selection_set, fragments, 0)
284                    },
285                    OperationDefinition::SelectionSet(ss) => {
286                        self.selection_set_complexity(ss, fragments, 0)
287                    },
288                },
289                Definition::Fragment(_) => 0, // Only counted when referenced
290            };
291            total += cost;
292        }
293
294        total
295    }
296
297    /// Recursively calculate complexity of a selection set.
298    ///
299    /// Each field costs 1. Fields with sub-selections cost 1 + nested cost.
300    /// Arguments like `first`, `limit`, `take` act as multipliers.
301    fn selection_set_complexity(
302        &self,
303        selection_set: &SelectionSet<String>,
304        fragments: &[&FragmentDefinition<String>],
305        recursion_depth: usize,
306    ) -> usize {
307        if recursion_depth > 32 {
308            return self.max_complexity + 1;
309        }
310
311        let mut total = 0;
312
313        for selection in &selection_set.items {
314            total += match selection {
315                Selection::Field(field) => {
316                    let multiplier = Self::extract_limit_multiplier(&field.arguments);
317                    if field.selection_set.items.is_empty() {
318                        // Leaf field
319                        1
320                    } else {
321                        // Field with sub-selections: base cost + nested * multiplier
322                        let nested = self.selection_set_complexity(
323                            &field.selection_set,
324                            fragments,
325                            recursion_depth,
326                        );
327                        1 + nested * multiplier
328                    }
329                },
330                Selection::InlineFragment(inline) => {
331                    self.selection_set_complexity(&inline.selection_set, fragments, recursion_depth)
332                },
333                Selection::FragmentSpread(spread) => {
334                    if let Some(frag) = fragments.iter().find(|f| f.name == spread.fragment_name) {
335                        self.selection_set_complexity(
336                            &frag.selection_set,
337                            fragments,
338                            recursion_depth + 1,
339                        )
340                    } else {
341                        10 // Unknown fragment: conservative estimate
342                    }
343                },
344            };
345        }
346
347        total
348    }
349
350    /// Extract pagination limit from field arguments to use as a cost multiplier.
351    ///
352    /// Looks for `first`, `limit`, `take`, or `last` arguments. Clamps the value
353    /// to prevent absurdly high multipliers.
354    fn extract_limit_multiplier(
355        arguments: &[(String, graphql_parser::query::Value<String>)],
356    ) -> usize {
357        for (name, value) in arguments {
358            if matches!(name.as_str(), "first" | "limit" | "take" | "last") {
359                if let graphql_parser::query::Value::Int(n) = value {
360                    let limit = n.as_i64().unwrap_or(10) as usize;
361                    // Clamp: treat anything > 100 as 100 for cost purposes
362                    return limit.clamp(1, 100);
363                }
364            }
365        }
366        // Default multiplier for fields without explicit limits
367        1
368    }
369}
370
371impl Default for RequestValidator {
372    fn default() -> Self {
373        Self {
374            max_depth:           10,
375            max_complexity:      100,
376            validate_depth:      true,
377            validate_complexity: true,
378        }
379    }
380}
381
382#[cfg(test)]
383mod tests {
384    use super::*;
385
386    #[test]
387    fn test_empty_query_validation() {
388        let validator = RequestValidator::new();
389        assert!(validator.validate_query("").is_err());
390        assert!(validator.validate_query("   ").is_err());
391    }
392
393    #[test]
394    fn test_query_depth_validation() {
395        let validator = RequestValidator::new().with_max_depth(3);
396
397        // Shallow query should pass (depth = 2)
398        let shallow = "{ user { id } }";
399        assert!(validator.validate_query(shallow).is_ok());
400
401        // Deep query should fail (depth = 4)
402        let deep = "{ user { profile { settings { theme } } } }";
403        assert!(validator.validate_query(deep).is_err());
404    }
405
406    #[test]
407    fn test_query_complexity_validation() {
408        let validator = RequestValidator::new().with_max_complexity(5);
409
410        // Simple query should pass (complexity = 3: root + user + id)
411        let simple = "{ user { id name } }";
412        assert!(validator.validate_query(simple).is_ok());
413    }
414
415    #[test]
416    fn test_variables_validation() {
417        let validator = RequestValidator::new();
418
419        // Valid variables object
420        let valid = serde_json::json!({"id": "123", "name": "John"});
421        assert!(validator.validate_variables(Some(&valid)).is_ok());
422
423        // No variables
424        assert!(validator.validate_variables(None).is_ok());
425
426        // Invalid: variables is not an object
427        let invalid = serde_json::json!([1, 2, 3]);
428        assert!(validator.validate_variables(Some(&invalid)).is_err());
429    }
430
431    #[test]
432    fn test_disable_validation() {
433        let validator = RequestValidator::new()
434            .with_depth_validation(false)
435            .with_complexity_validation(false)
436            .with_max_depth(1)
437            .with_max_complexity(1);
438
439        // Even very deep query should pass when validation is disabled
440        let deep = "{ a { b { c { d { e { f } } } } } }";
441        assert!(validator.validate_query(deep).is_ok());
442    }
443
444    // SECURITY: Fragment-based depth bypass tests (VULN #5)
445
446    #[test]
447    fn test_fragment_depth_bypass_blocked() {
448        let validator = RequestValidator::new().with_max_depth(3);
449
450        // Fragment that expands to depth > 3
451        let query = "
452            fragment Deep on User {
453                a { b { c { d { e } } } }
454            }
455            query { ...Deep }
456        ";
457        let result = validator.validate_query(query);
458        assert!(result.is_err(), "Fragment depth bypass must be blocked");
459    }
460
461    #[test]
462    fn test_inline_fragment_depth_counted() {
463        let validator = RequestValidator::new().with_max_depth(3);
464
465        let query = "
466            query {
467                ... on User { a { b { c { d } } } }
468            }
469        ";
470        let result = validator.validate_query(query);
471        assert!(result.is_err(), "Inline fragment depth must be counted correctly");
472    }
473
474    #[test]
475    fn test_multiple_fragments_depth() {
476        let validator = RequestValidator::new().with_max_depth(4);
477
478        // Fragment A references Fragment B, total depth > 4
479        let query = "
480            fragment B on Type { x { y { z } } }
481            fragment A on Type { inner { ...B } }
482            query { ...A }
483        ";
484        let result = validator.validate_query(query);
485        assert!(result.is_err(), "Chained fragment depth must be detected");
486    }
487
488    #[test]
489    fn test_shallow_fragment_allowed() {
490        let validator = RequestValidator::new().with_max_depth(5);
491
492        let query = "
493            fragment UserFields on User { id name email }
494            query { user { ...UserFields } }
495        ";
496        assert!(validator.validate_query(query).is_ok(), "Shallow fragments should be allowed");
497    }
498
499    // SECURITY: Complexity scoring with multipliers (VULN #6)
500
501    #[test]
502    fn test_pagination_limit_multiplier() {
503        let validator = RequestValidator::new().with_max_complexity(50);
504
505        // This query has a high multiplier from the limit argument
506        // users(first: 100) { id name } => 1 + (1 + 1) * 100 = 201
507        let query = "query { users(first: 100) { id name } }";
508        let result = validator.validate_query(query);
509        assert!(result.is_err(), "High pagination limits must increase complexity");
510    }
511
512    #[test]
513    fn test_nested_list_multiplier() {
514        let validator = RequestValidator::new().with_max_complexity(50);
515
516        // Nested lists should compound: users(first:10) { friends(first:10) { id } }
517        // = 1 + (1 + (1)*10)*10 = 1 + 110 = 111
518        let query = "query { users(first: 10) { friends(first: 10) { id } } }";
519        let result = validator.validate_query(query);
520        assert!(result.is_err(), "Nested list multipliers must compound");
521    }
522
523    #[test]
524    fn test_simple_query_low_complexity() {
525        let validator = RequestValidator::new().with_max_complexity(20);
526
527        let query = "query { user { id name email } }";
528        assert!(
529            validator.validate_query(query).is_ok(),
530            "Simple queries should have low complexity"
531        );
532    }
533
534    #[test]
535    fn test_malformed_query_rejected() {
536        let validator = RequestValidator::new();
537        let result = validator.validate_query("{ invalid query {{}}");
538        assert!(result.is_err(), "Malformed queries must be rejected");
539    }
540}