1use graphql_parser::query::{
17 Definition, Document, FragmentDefinition, OperationDefinition, Selection, SelectionSet,
18};
19use serde_json::Value as JsonValue;
20use thiserror::Error;
21
22#[derive(Debug, Error, Clone)]
24pub enum ValidationError {
25 #[error("Query exceeds maximum depth of {max_depth}: depth = {actual_depth}")]
27 QueryTooDeep {
28 max_depth: usize,
30 actual_depth: usize,
32 },
33
34 #[error("Query exceeds maximum complexity of {max_complexity}: score = {actual_complexity}")]
36 QueryTooComplex {
37 max_complexity: usize,
39 actual_complexity: usize,
41 },
42
43 #[error("Invalid variables: {0}")]
45 InvalidVariables(String),
46
47 #[error("Malformed GraphQL query: {0}")]
49 MalformedQuery(String),
50}
51
/// Pre-execution guard that rejects GraphQL requests whose query is too deep
/// or too expensive, or whose `variables` payload has the wrong shape.
///
/// Build one with [`RequestValidator::new`] (or `Default`) and adjust it with
/// the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct RequestValidator {
    /// Largest query depth that is still accepted.
    max_depth: usize,
    /// Largest complexity score that is still accepted.
    max_complexity: usize,
    /// Whether [`RequestValidator::validate_query`] enforces `max_depth`.
    validate_depth: bool,
    /// Whether [`RequestValidator::validate_query`] enforces `max_complexity`.
    validate_complexity: bool,
}
64
65impl RequestValidator {
66 #[must_use]
68 pub fn new() -> Self {
69 Self::default()
70 }
71
72 #[must_use]
74 pub fn with_max_depth(mut self, max_depth: usize) -> Self {
75 self.max_depth = max_depth;
76 self
77 }
78
79 #[must_use]
81 pub fn with_max_complexity(mut self, max_complexity: usize) -> Self {
82 self.max_complexity = max_complexity;
83 self
84 }
85
86 #[must_use]
88 pub fn with_depth_validation(mut self, enabled: bool) -> Self {
89 self.validate_depth = enabled;
90 self
91 }
92
93 #[must_use]
95 pub fn with_complexity_validation(mut self, enabled: bool) -> Self {
96 self.validate_complexity = enabled;
97 self
98 }
99
100 pub fn validate_query(&self, query: &str) -> Result<(), ValidationError> {
106 if query.trim().is_empty() {
108 return Err(ValidationError::MalformedQuery("Empty query".to_string()));
109 }
110
111 if !self.validate_depth && !self.validate_complexity {
113 return Ok(());
114 }
115
116 let document = graphql_parser::parse_query::<String>(query)
118 .map_err(|e| ValidationError::MalformedQuery(format!("{e}")))?;
119
120 let fragments: Vec<&FragmentDefinition<String>> = document
122 .definitions
123 .iter()
124 .filter_map(|def| {
125 if let Definition::Fragment(f) = def {
126 Some(f)
127 } else {
128 None
129 }
130 })
131 .collect();
132
133 if self.validate_depth {
135 let depth = self.calculate_depth_ast(&document, &fragments);
136 if depth > self.max_depth {
137 return Err(ValidationError::QueryTooDeep {
138 max_depth: self.max_depth,
139 actual_depth: depth,
140 });
141 }
142 }
143
144 if self.validate_complexity {
146 let complexity = self.calculate_complexity_ast(&document, &fragments);
147 if complexity > self.max_complexity {
148 return Err(ValidationError::QueryTooComplex {
149 max_complexity: self.max_complexity,
150 actual_complexity: complexity,
151 });
152 }
153 }
154
155 Ok(())
156 }
157
158 pub fn validate_variables(&self, variables: Option<&JsonValue>) -> Result<(), ValidationError> {
164 if let Some(vars) = variables {
165 if !vars.is_object() {
166 return Err(ValidationError::InvalidVariables(
167 "Variables must be an object".to_string(),
168 ));
169 }
170 }
171
172 Ok(())
173 }
174
175 fn calculate_depth_ast(
179 &self,
180 document: &Document<String>,
181 fragments: &[&FragmentDefinition<String>],
182 ) -> usize {
183 let mut max_depth = 0;
184
185 for definition in &document.definitions {
186 let depth = match definition {
187 Definition::Operation(op) => match op {
188 OperationDefinition::Query(q) => {
189 self.selection_set_depth(&q.selection_set, fragments, 0)
190 },
191 OperationDefinition::Mutation(m) => {
192 self.selection_set_depth(&m.selection_set, fragments, 0)
193 },
194 OperationDefinition::Subscription(s) => {
195 self.selection_set_depth(&s.selection_set, fragments, 0)
196 },
197 OperationDefinition::SelectionSet(ss) => {
198 self.selection_set_depth(ss, fragments, 0)
199 },
200 },
201 Definition::Fragment(f) => {
202 self.selection_set_depth(&f.selection_set, fragments, 0)
204 },
205 };
206 max_depth = max_depth.max(depth);
207 }
208
209 max_depth
210 }
211
212 fn selection_set_depth(
214 &self,
215 selection_set: &SelectionSet<String>,
216 fragments: &[&FragmentDefinition<String>],
217 recursion_depth: usize,
218 ) -> usize {
219 if recursion_depth > 32 {
221 return self.max_depth + 1;
222 }
223
224 if selection_set.items.is_empty() {
225 return 0;
226 }
227
228 let mut max_child_depth = 0;
229
230 for selection in &selection_set.items {
231 let child_depth = match selection {
232 Selection::Field(field) => {
233 if field.selection_set.items.is_empty() {
234 0
235 } else {
236 self.selection_set_depth(&field.selection_set, fragments, recursion_depth)
237 }
238 },
239 Selection::InlineFragment(inline) => {
240 self.selection_set_depth(&inline.selection_set, fragments, recursion_depth)
241 },
242 Selection::FragmentSpread(spread) => {
243 if let Some(frag) = fragments.iter().find(|f| f.name == spread.fragment_name) {
245 self.selection_set_depth(
246 &frag.selection_set,
247 fragments,
248 recursion_depth + 1,
249 )
250 } else {
251 self.max_depth
253 }
254 },
255 };
256 max_child_depth = max_child_depth.max(child_depth);
257 }
258
259 1 + max_child_depth
260 }
261
262 fn calculate_complexity_ast(
267 &self,
268 document: &Document<String>,
269 fragments: &[&FragmentDefinition<String>],
270 ) -> usize {
271 let mut total = 0;
272
273 for definition in &document.definitions {
274 let cost = match definition {
275 Definition::Operation(op) => match op {
276 OperationDefinition::Query(q) => {
277 self.selection_set_complexity(&q.selection_set, fragments, 0)
278 },
279 OperationDefinition::Mutation(m) => {
280 self.selection_set_complexity(&m.selection_set, fragments, 0)
281 },
282 OperationDefinition::Subscription(s) => {
283 self.selection_set_complexity(&s.selection_set, fragments, 0)
284 },
285 OperationDefinition::SelectionSet(ss) => {
286 self.selection_set_complexity(ss, fragments, 0)
287 },
288 },
289 Definition::Fragment(_) => 0, };
291 total += cost;
292 }
293
294 total
295 }
296
297 fn selection_set_complexity(
302 &self,
303 selection_set: &SelectionSet<String>,
304 fragments: &[&FragmentDefinition<String>],
305 recursion_depth: usize,
306 ) -> usize {
307 if recursion_depth > 32 {
308 return self.max_complexity + 1;
309 }
310
311 let mut total = 0;
312
313 for selection in &selection_set.items {
314 total += match selection {
315 Selection::Field(field) => {
316 let multiplier = Self::extract_limit_multiplier(&field.arguments);
317 if field.selection_set.items.is_empty() {
318 1
320 } else {
321 let nested = self.selection_set_complexity(
323 &field.selection_set,
324 fragments,
325 recursion_depth,
326 );
327 1 + nested * multiplier
328 }
329 },
330 Selection::InlineFragment(inline) => {
331 self.selection_set_complexity(&inline.selection_set, fragments, recursion_depth)
332 },
333 Selection::FragmentSpread(spread) => {
334 if let Some(frag) = fragments.iter().find(|f| f.name == spread.fragment_name) {
335 self.selection_set_complexity(
336 &frag.selection_set,
337 fragments,
338 recursion_depth + 1,
339 )
340 } else {
341 10 }
343 },
344 };
345 }
346
347 total
348 }
349
350 fn extract_limit_multiplier(
355 arguments: &[(String, graphql_parser::query::Value<String>)],
356 ) -> usize {
357 for (name, value) in arguments {
358 if matches!(name.as_str(), "first" | "limit" | "take" | "last") {
359 if let graphql_parser::query::Value::Int(n) = value {
360 let limit = n.as_i64().unwrap_or(10) as usize;
361 return limit.clamp(1, 100);
363 }
364 }
365 }
366 1
368 }
369}
370
371impl Default for RequestValidator {
372 fn default() -> Self {
373 Self {
374 max_depth: 10,
375 max_complexity: 100,
376 validate_depth: true,
377 validate_complexity: true,
378 }
379 }
380}
381
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_empty_query_validation() {
        let validator = RequestValidator::new();
        assert!(validator.validate_query("").is_err());
        assert!(validator.validate_query(" ").is_err());
        assert!(matches!(
            validator.validate_query(""),
            Err(ValidationError::MalformedQuery(_))
        ));
    }

    #[test]
    fn test_query_depth_validation() {
        let validator = RequestValidator::new().with_max_depth(3);

        let shallow = "{ user { id } }";
        assert!(validator.validate_query(shallow).is_ok());

        let deep = "{ user { profile { settings { theme } } } }";
        assert!(matches!(
            validator.validate_query(deep),
            Err(ValidationError::QueryTooDeep { max_depth: 3, actual_depth: 4 })
        ));
    }

    #[test]
    fn test_query_complexity_validation() {
        let validator = RequestValidator::new().with_max_complexity(5);

        // 1 for `user` + 1 per leaf field = 3.
        let simple = "{ user { id name } }";
        assert!(validator.validate_query(simple).is_ok());

        // 1 for `user` + 5 leaf fields = 6 > 5.
        let wide = "{ user { id name email phone address } }";
        assert!(matches!(
            validator.validate_query(wide),
            Err(ValidationError::QueryTooComplex { max_complexity: 5, actual_complexity: 6 })
        ));
    }

    #[test]
    fn test_variables_validation() {
        let validator = RequestValidator::new();

        let valid = serde_json::json!({"id": "123", "name": "John"});
        assert!(validator.validate_variables(Some(&valid)).is_ok());

        assert!(validator.validate_variables(None).is_ok());

        let invalid = serde_json::json!([1, 2, 3]);
        assert!(matches!(
            validator.validate_variables(Some(&invalid)),
            Err(ValidationError::InvalidVariables(_))
        ));
    }

    #[test]
    fn test_disable_validation() {
        let validator = RequestValidator::new()
            .with_depth_validation(false)
            .with_complexity_validation(false)
            .with_max_depth(1)
            .with_max_complexity(1);

        let deep = "{ a { b { c { d { e { f } } } } } }";
        assert!(validator.validate_query(deep).is_ok());
    }

    #[test]
    fn test_fragment_depth_bypass_blocked() {
        let validator = RequestValidator::new().with_max_depth(3);

        let query = "
            fragment Deep on User {
                a { b { c { d { e } } } }
            }
            query { ...Deep }
        ";
        let result = validator.validate_query(query);
        assert!(result.is_err(), "Fragment depth bypass must be blocked");
    }

    #[test]
    fn test_inline_fragment_depth_counted() {
        let validator = RequestValidator::new().with_max_depth(3);

        let query = "
            query {
                ... on User { a { b { c { d } } } }
            }
        ";
        let result = validator.validate_query(query);
        assert!(result.is_err(), "Inline fragment depth must be counted correctly");
    }

    #[test]
    fn test_multiple_fragments_depth() {
        let validator = RequestValidator::new().with_max_depth(4);

        let query = "
            fragment B on Type { x { y { z } } }
            fragment A on Type { inner { ...B } }
            query { ...A }
        ";
        let result = validator.validate_query(query);
        assert!(result.is_err(), "Chained fragment depth must be detected");
    }

    #[test]
    fn test_shallow_fragment_allowed() {
        let validator = RequestValidator::new().with_max_depth(5);

        let query = "
            fragment UserFields on User { id name email }
            query { user { ...UserFields } }
        ";
        assert!(validator.validate_query(query).is_ok(), "Shallow fragments should be allowed");
    }

    #[test]
    fn test_reused_fragment_allowed() {
        let validator = RequestValidator::new();

        let query = "
            fragment Leaf on Node { id }
            query { a { ...Leaf } b { ...Leaf } }
        ";
        assert!(validator.validate_query(query).is_ok(), "Reusing a small fragment is fine");
    }

    #[test]
    fn test_self_referencing_fragment_rejected() {
        let validator = RequestValidator::new();

        let query = "
            fragment Loop on User { friends { ...Loop } }
            query { user { ...Loop } }
        ";
        assert!(
            matches!(validator.validate_query(query), Err(ValidationError::QueryTooDeep { .. })),
            "Cyclic fragments must be rejected by the depth check"
        );
    }

    #[test]
    fn test_fragment_cycle_rejected_by_complexity() {
        let validator = RequestValidator::new().with_depth_validation(false);

        let query = "
            fragment Loop on User { friends { ...Loop } }
            query { user { ...Loop } }
        ";
        assert!(
            matches!(validator.validate_query(query), Err(ValidationError::QueryTooComplex { .. })),
            "Cyclic fragments must be rejected by the complexity check"
        );
    }

    #[test]
    fn test_mutually_recursive_fragments_rejected() {
        let validator = RequestValidator::new();

        let query = "
            fragment A on T { a { ...B } }
            fragment B on T { b { ...A } }
            query { ...A }
        ";
        assert!(validator.validate_query(query).is_err(), "Fragment cycles must be rejected");
    }

    #[test]
    fn test_unknown_fragment_rejected() {
        let validator = RequestValidator::new();
        assert!(
            validator.validate_query("query { ...Missing }").is_err(),
            "Spreads of undefined fragments must be rejected"
        );
    }

    #[test]
    fn test_pagination_limit_multiplier() {
        let validator = RequestValidator::new().with_max_complexity(50);

        let query = "query { users(first: 100) { id name } }";
        let result = validator.validate_query(query);
        assert!(result.is_err(), "High pagination limits must increase complexity");
    }

    #[test]
    fn test_pagination_limit_is_clamped() {
        // `first: 1000` is clamped to a multiplier of 100: 1 + 1 * 100 = 101.
        let query = "query { users(first: 1000) { id } }";

        let exact = RequestValidator::new().with_max_complexity(101);
        assert!(exact.validate_query(query).is_ok());

        let tight = RequestValidator::new().with_max_complexity(100);
        assert!(matches!(
            tight.validate_query(query),
            Err(ValidationError::QueryTooComplex { actual_complexity: 101, .. })
        ));
    }

    #[test]
    fn test_nested_list_multiplier() {
        let validator = RequestValidator::new().with_max_complexity(50);

        let query = "query { users(first: 10) { friends(first: 10) { id } } }";
        let result = validator.validate_query(query);
        assert!(result.is_err(), "Nested list multipliers must compound");
    }

    #[test]
    fn test_simple_query_low_complexity() {
        let validator = RequestValidator::new().with_max_complexity(20);

        let query = "query { user { id name email } }";
        assert!(
            validator.validate_query(query).is_ok(),
            "Simple queries should have low complexity"
        );
    }

    #[test]
    fn test_malformed_query_rejected() {
        let validator = RequestValidator::new();
        let result = validator.validate_query("{ invalid query {{}}");
        assert!(
            matches!(result, Err(ValidationError::MalformedQuery(_))),
            "Malformed queries must be rejected"
        );
    }
}