1use std::sync::Arc;
6
7use super::{
8 engine::ParsedDocument, engine::ParsedSelection, ErrorCode, GraphQLConfig, GraphQLError,
9 GraphQLSchema,
10};
11
12#[derive(Debug)]
14pub struct QueryValidator {
15 config: Arc<GraphQLConfig>,
17}
18
19impl QueryValidator {
20 pub fn new(config: Arc<GraphQLConfig>) -> Self {
22 Self { config }
23 }
24
25 #[allow(clippy::result_large_err)]
27 pub fn validate(
28 &self,
29 document: &ParsedDocument,
30 schema: &GraphQLSchema,
31 ) -> Result<(), GraphQLError> {
32 let complexity = self.calculate_complexity(document, schema)?;
34 if complexity.total > self.config.limits.max_complexity {
35 return Err(GraphQLError::new(
36 format!(
37 "Query complexity {} exceeds maximum allowed {}",
38 complexity.total, self.config.limits.max_complexity
39 ),
40 ErrorCode::QueryTooComplex,
41 ));
42 }
43
44 if complexity.max_depth > self.config.limits.max_depth {
46 return Err(GraphQLError::new(
47 format!(
48 "Query depth {} exceeds maximum allowed {}",
49 complexity.max_depth, self.config.limits.max_depth
50 ),
51 ErrorCode::QueryTooComplex,
52 ));
53 }
54
55 if complexity.alias_count > self.config.limits.max_aliases {
57 return Err(GraphQLError::new(
58 format!(
59 "Query has {} aliases, maximum allowed is {}",
60 complexity.alias_count, self.config.limits.max_aliases
61 ),
62 ErrorCode::QueryTooComplex,
63 ));
64 }
65
66 if document.selections.len() > self.config.limits.max_root_fields as usize {
68 return Err(GraphQLError::new(
69 format!(
70 "Query has {} root fields, maximum allowed is {}",
71 document.selections.len(),
72 self.config.limits.max_root_fields
73 ),
74 ErrorCode::QueryTooComplex,
75 ));
76 }
77
78 self.validate_fields(&document.selections, "Query", schema)?;
80
81 Ok(())
82 }
83
84 #[allow(clippy::result_large_err)]
86 pub fn calculate_complexity(
87 &self,
88 document: &ParsedDocument,
89 _schema: &GraphQLSchema,
90 ) -> Result<ComplexityResult, GraphQLError> {
91 let mut result = ComplexityResult::default();
92
93 for selection in &document.selections {
94 self.calculate_selection_complexity(selection, 1, &mut result)?;
95 }
96
97 Ok(result)
98 }
99
100 #[allow(clippy::result_large_err)]
102 fn calculate_selection_complexity(
103 &self,
104 selection: &ParsedSelection,
105 depth: u32,
106 result: &mut ComplexityResult,
107 ) -> Result<(), GraphQLError> {
108 result.max_depth = result.max_depth.max(depth);
110
111 let mut field_cost = 1u32;
113
114 if selection.alias.is_some() {
116 result.alias_count += 1;
117 }
118
119 if let Some(limit) = selection.arguments.get("limit") {
121 if let Some(l) = limit.as_u64() {
122 field_cost = field_cost.saturating_mul(l.min(100) as u32);
123 }
124 }
125
126 result.total = result.total.saturating_add(field_cost);
128 result.field_count += 1;
129
130 for nested in &selection.selections {
132 self.calculate_selection_complexity(nested, depth + 1, result)?;
133 }
134
135 Ok(())
136 }
137
138 #[allow(clippy::result_large_err)]
140 fn validate_fields(
141 &self,
142 selections: &[ParsedSelection],
143 type_name: &str,
144 schema: &GraphQLSchema,
145 ) -> Result<(), GraphQLError> {
146 let type_def = schema.get_type(type_name);
148
149 for selection in selections {
150 if selection.name.starts_with("__") {
152 continue;
153 }
154
155 if type_name != "Query" && type_name != "Mutation" {
157 if let Some(type_def) = type_def {
158 let field_exists = type_def.get_field(&selection.name).is_some();
159 let rel_exists = schema
160 .get_relationships_for(type_name)
161 .iter()
162 .any(|r| r.field_name == selection.name);
163
164 if !field_exists && !rel_exists {
165 return Err(GraphQLError::validation_error(format!(
166 "Field '{}' does not exist on type '{}'",
167 selection.name, type_name
168 )));
169 }
170 }
171 }
172
173 if !selection.selections.is_empty() {
175 let nested_type = self.get_field_type(&selection.name, type_name, schema);
177 if let Some(nested_type) = nested_type {
178 self.validate_fields(&selection.selections, &nested_type, schema)?;
179 }
180 }
181 }
182
183 Ok(())
184 }
185
186 fn get_field_type(
188 &self,
189 field_name: &str,
190 parent_type: &str,
191 schema: &GraphQLSchema,
192 ) -> Option<String> {
193 if let Some(type_def) = schema.get_type(parent_type) {
195 if let Some(field) = type_def.get_field(field_name) {
196 return Some(self.extract_type_name(&field.graphql_type));
197 }
198 }
199
200 for rel in schema.get_relationships_for(parent_type) {
202 if rel.field_name == field_name {
203 return Some(rel.to_type.clone());
204 }
205 }
206
207 Some(super::to_pascal_case(field_name))
209 }
210
211 fn extract_type_name(&self, field_type: &super::introspector::FieldType) -> String {
213 use super::introspector::FieldType;
214
215 match field_type {
216 FieldType::Scalar(s) => s.to_sdl().to_string(),
217 FieldType::Object(name) => name.clone(),
218 FieldType::List(inner) => self.extract_type_name(inner),
219 FieldType::NonNull(inner) => self.extract_type_name(inner),
220 }
221 }
222}
223
224#[derive(Debug, Clone, Default)]
226pub struct ComplexityResult {
227 pub total: u32,
229 pub max_depth: u32,
231 pub alias_count: u32,
233 pub field_count: u32,
235}
236
237impl ComplexityResult {
238 pub fn is_within_limits(&self, config: &GraphQLConfig) -> bool {
240 self.total <= config.limits.max_complexity
241 && self.max_depth <= config.limits.max_depth
242 && self.alias_count <= config.limits.max_aliases
243 }
244}
245
246#[derive(Debug, Clone)]
248pub struct ValidationError {
249 pub message: String,
251 pub locations: Vec<(u32, u32)>,
253 pub rule: ValidationRule,
255}
256
257impl ValidationError {
258 pub fn new(message: impl Into<String>, rule: ValidationRule) -> Self {
260 Self {
261 message: message.into(),
262 locations: Vec::new(),
263 rule,
264 }
265 }
266
267 pub fn at(mut self, line: u32, column: u32) -> Self {
269 self.locations.push((line, column));
270 self
271 }
272}
273
274impl std::fmt::Display for ValidationError {
275 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
276 write!(f, "{}", self.message)
277 }
278}
279
280impl std::error::Error for ValidationError {}
281
/// Identifies which validation rule a [`ValidationError`] came from.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ValidationRule {
    /// A selected field is not defined on its parent type; reported by
    /// `UnknownFieldValidator`.
    UnknownField,
    // NOTE(review): the next four rules are not reported by any validator in
    // this file.
    UnknownType,
    UnknownArgument,
    MissingArgument,
    InvalidArgumentType,
    /// Not reported by any `RuleValidator` in this file; `QueryValidator`
    /// uses `ErrorCode::QueryTooComplex` instead.
    QueryTooComplex,
    /// Selection nesting exceeds the limit; reported by `DepthValidator`.
    QueryTooDeep,
    /// Too many aliased selections; reported by `AliasValidator`.
    TooManyAliases,
    // NOTE(review): the remaining rules are not reported by any validator in
    // this file.
    DuplicateField,
    FragmentCycle,
    UnknownFragment,
    InvalidFragmentSpread,
}
310
311pub trait RuleValidator: Send + Sync {
313 fn validate(
315 &self,
316 document: &ParsedDocument,
317 schema: &GraphQLSchema,
318 ) -> Result<(), ValidationError>;
319
320 fn rule(&self) -> ValidationRule;
322}
323
324pub struct UnknownFieldValidator;
326
327impl RuleValidator for UnknownFieldValidator {
328 fn validate(
329 &self,
330 document: &ParsedDocument,
331 schema: &GraphQLSchema,
332 ) -> Result<(), ValidationError> {
333 fn check_selections(
334 selections: &[ParsedSelection],
335 type_name: &str,
336 schema: &GraphQLSchema,
337 ) -> Result<(), ValidationError> {
338 for selection in selections {
339 if selection.name.starts_with("__") {
340 continue;
341 }
342
343 if type_name != "Query" && type_name != "Mutation" {
344 if let Some(type_def) = schema.get_type(type_name) {
345 if type_def.get_field(&selection.name).is_none() {
346 return Err(ValidationError::new(
347 format!(
348 "Unknown field '{}' on type '{}'",
349 selection.name, type_name
350 ),
351 ValidationRule::UnknownField,
352 ));
353 }
354 }
355 }
356
357 if !selection.selections.is_empty() {
359 check_selections(&selection.selections, &selection.name, schema)?;
360 }
361 }
362 Ok(())
363 }
364
365 check_selections(&document.selections, "Query", schema)
366 }
367
368 fn rule(&self) -> ValidationRule {
369 ValidationRule::UnknownField
370 }
371}
372
373pub struct DepthValidator {
375 max_depth: u32,
376}
377
378impl DepthValidator {
379 pub fn new(max_depth: u32) -> Self {
380 Self { max_depth }
381 }
382}
383
384impl RuleValidator for DepthValidator {
385 fn validate(
386 &self,
387 document: &ParsedDocument,
388 _schema: &GraphQLSchema,
389 ) -> Result<(), ValidationError> {
390 fn check_depth(
391 selections: &[ParsedSelection],
392 current_depth: u32,
393 max_depth: u32,
394 ) -> Result<(), ValidationError> {
395 if current_depth > max_depth {
396 return Err(ValidationError::new(
397 format!(
398 "Query depth {} exceeds maximum {}",
399 current_depth, max_depth
400 ),
401 ValidationRule::QueryTooDeep,
402 ));
403 }
404
405 for selection in selections {
406 check_depth(&selection.selections, current_depth + 1, max_depth)?;
407 }
408
409 Ok(())
410 }
411
412 check_depth(&document.selections, 1, self.max_depth)
413 }
414
415 fn rule(&self) -> ValidationRule {
416 ValidationRule::QueryTooDeep
417 }
418}
419
420pub struct AliasValidator {
422 max_aliases: u32,
423}
424
425impl AliasValidator {
426 pub fn new(max_aliases: u32) -> Self {
427 Self { max_aliases }
428 }
429}
430
431impl RuleValidator for AliasValidator {
432 fn validate(
433 &self,
434 document: &ParsedDocument,
435 _schema: &GraphQLSchema,
436 ) -> Result<(), ValidationError> {
437 fn count_aliases(selections: &[ParsedSelection]) -> u32 {
438 let mut count = 0;
439 for selection in selections {
440 if selection.alias.is_some() {
441 count += 1;
442 }
443 count += count_aliases(&selection.selections);
444 }
445 count
446 }
447
448 let alias_count = count_aliases(&document.selections);
449 if alias_count > self.max_aliases {
450 return Err(ValidationError::new(
451 format!(
452 "Query has {} aliases, maximum is {}",
453 alias_count, self.max_aliases
454 ),
455 ValidationRule::TooManyAliases,
456 ));
457 }
458
459 Ok(())
460 }
461
462 fn rule(&self) -> ValidationRule {
463 ValidationRule::TooManyAliases
464 }
465}
466
467#[derive(Default)]
469pub struct ValidationPipeline {
470 validators: Vec<Box<dyn RuleValidator>>,
471}
472
473impl ValidationPipeline {
474 pub fn new() -> Self {
476 Self::default()
477 }
478
479 #[allow(clippy::should_implement_trait)]
481 pub fn add(mut self, validator: impl RuleValidator + 'static) -> Self {
482 self.validators.push(Box::new(validator));
483 self
484 }
485
486 pub fn default_pipeline(config: &GraphQLConfig) -> Self {
488 Self::new()
489 .add(UnknownFieldValidator)
490 .add(DepthValidator::new(config.limits.max_depth))
491 .add(AliasValidator::new(config.limits.max_aliases))
492 }
493
494 pub fn validate(
496 &self,
497 document: &ParsedDocument,
498 schema: &GraphQLSchema,
499 ) -> Result<(), Vec<ValidationError>> {
500 let mut errors = Vec::new();
501
502 for validator in &self.validators {
503 if let Err(e) = validator.validate(document, schema) {
504 errors.push(e);
505 }
506 }
507
508 if errors.is_empty() {
509 Ok(())
510 } else {
511 Err(errors)
512 }
513 }
514}
515
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graphql::{introspector::*, OperationType};
    use std::collections::HashMap;

    /// Schema containing a single `User` type with `id: ID` and `name: String`.
    fn create_test_schema() -> GraphQLSchema {
        let mut user_type = GraphQLType::new("User");
        user_type.add_field(GraphQLField::new(
            "id",
            FieldType::scalar(crate::graphql::GraphQLScalar::ID),
        ));
        user_type.add_field(GraphQLField::new(
            "name",
            FieldType::scalar(crate::graphql::GraphQLScalar::String),
        ));

        let mut schema = GraphQLSchema::new();
        schema.add_type(user_type);
        schema
    }

    /// Wraps root `selections` in an anonymous query document.
    fn create_test_document(selections: Vec<ParsedSelection>) -> ParsedDocument {
        ParsedDocument {
            operation_type: OperationType::Query,
            operation_name: None,
            selections,
            variable_definitions: Vec::new(),
            fragments: HashMap::new(),
        }
    }

    /// A bare selection: no alias, arguments, children, or directives.
    fn leaf(name: &str) -> ParsedSelection {
        ParsedSelection {
            name: name.to_string(),
            alias: None,
            arguments: HashMap::new(),
            selections: vec![],
            directives: vec![],
        }
    }

    /// `leaf(name)` carrying `alias`.
    fn aliased(name: &str, alias: &str) -> ParsedSelection {
        ParsedSelection {
            alias: Some(alias.to_string()),
            ..leaf(name)
        }
    }

    /// `leaf(name)` with a nested selection set.
    fn parent(name: &str, children: Vec<ParsedSelection>) -> ParsedSelection {
        ParsedSelection {
            selections: children,
            ..leaf(name)
        }
    }

    /// A `QueryValidator` using the default configuration.
    fn default_validator() -> QueryValidator {
        QueryValidator::new(Arc::new(GraphQLConfig::default()))
    }

    #[test]
    fn test_complexity_calculation() {
        let schema = create_test_schema();
        let document =
            create_test_document(vec![parent("users", vec![leaf("id"), leaf("name")])]);

        let result = default_validator()
            .calculate_complexity(&document, &schema)
            .unwrap();

        assert_eq!(result.field_count, 3);
        assert_eq!(result.max_depth, 2);
        assert_eq!(result.alias_count, 0);
    }

    #[test]
    fn test_complexity_with_limit() {
        let schema = create_test_schema();
        let mut users = leaf("users");
        users
            .arguments
            .insert("limit".to_string(), serde_json::json!(10));
        let document = create_test_document(vec![users]);

        let result = default_validator()
            .calculate_complexity(&document, &schema)
            .unwrap();

        assert_eq!(result.total, 10);
    }

    #[test]
    fn test_alias_counting() {
        let schema = create_test_schema();
        let document = create_test_document(vec![ParsedSelection {
            selections: vec![aliased("id", "userId")],
            ..aliased("users", "allUsers")
        }]);

        let result = default_validator()
            .calculate_complexity(&document, &schema)
            .unwrap();

        assert_eq!(result.alias_count, 2);
    }

    #[test]
    fn test_depth_validator() {
        let validator = DepthValidator::new(2);
        let schema = create_test_schema();

        let shallow = create_test_document(vec![leaf("users")]);
        assert!(validator.validate(&shallow, &schema).is_ok());

        let deep = create_test_document(vec![parent(
            "users",
            vec![parent("posts", vec![leaf("comments")])],
        )]);
        assert!(validator.validate(&deep, &schema).is_err());
    }

    #[test]
    fn test_alias_validator() {
        let validator = AliasValidator::new(2);
        let schema = create_test_schema();

        let within_limit = create_test_document(vec![ParsedSelection {
            selections: vec![aliased("id", "a2")],
            ..aliased("users", "a1")
        }]);
        assert!(validator.validate(&within_limit, &schema).is_ok());

        let exceeds_limit = create_test_document(vec![ParsedSelection {
            selections: vec![aliased("id", "a2"), aliased("name", "a3")],
            ..aliased("users", "a1")
        }]);
        assert!(validator.validate(&exceeds_limit, &schema).is_err());
    }

    #[test]
    fn test_validation_pipeline() {
        let config = GraphQLConfig::default();
        let pipeline = ValidationPipeline::default_pipeline(&config);
        let schema = create_test_schema();
        let document = create_test_document(vec![leaf("users")]);

        assert!(pipeline.validate(&document, &schema).is_ok());
    }

    #[test]
    fn test_complexity_result_within_limits() {
        let config = GraphQLConfig::default();

        let within = ComplexityResult {
            total: 100,
            max_depth: 5,
            alias_count: 2,
            field_count: 10,
        };
        assert!(within.is_within_limits(&config));

        let exceeds = ComplexityResult {
            total: 10000,
            ..within
        };
        assert!(!exceeds.is_within_limits(&config));
    }

    #[test]
    fn test_validation_error() {
        let err = ValidationError::new("Test error", ValidationRule::UnknownField)
            .at(1, 10)
            .at(2, 5);

        assert_eq!(err.message, "Test error");
        assert_eq!(err.locations.len(), 2);
        assert_eq!(err.rule, ValidationRule::UnknownField);
    }
}