1use std::collections::{HashMap, HashSet};
6use std::sync::Arc;
7
8use super::{
9 GraphQLConfig, GraphQLSchema, GraphQLError, ErrorCode,
10 engine::ParsedDocument, engine::ParsedSelection,
11};
12
13#[derive(Debug)]
15pub struct QueryValidator {
16 config: Arc<GraphQLConfig>,
18}
19
20impl QueryValidator {
21 pub fn new(config: Arc<GraphQLConfig>) -> Self {
23 Self { config }
24 }
25
26 pub fn validate(
28 &self,
29 document: &ParsedDocument,
30 schema: &GraphQLSchema,
31 ) -> Result<(), GraphQLError> {
32 let complexity = self.calculate_complexity(document, schema)?;
34 if complexity.total > self.config.limits.max_complexity {
35 return Err(GraphQLError::new(
36 format!(
37 "Query complexity {} exceeds maximum allowed {}",
38 complexity.total, self.config.limits.max_complexity
39 ),
40 ErrorCode::QueryTooComplex,
41 ));
42 }
43
44 if complexity.max_depth > self.config.limits.max_depth {
46 return Err(GraphQLError::new(
47 format!(
48 "Query depth {} exceeds maximum allowed {}",
49 complexity.max_depth, self.config.limits.max_depth
50 ),
51 ErrorCode::QueryTooComplex,
52 ));
53 }
54
55 if complexity.alias_count > self.config.limits.max_aliases {
57 return Err(GraphQLError::new(
58 format!(
59 "Query has {} aliases, maximum allowed is {}",
60 complexity.alias_count, self.config.limits.max_aliases
61 ),
62 ErrorCode::QueryTooComplex,
63 ));
64 }
65
66 if document.selections.len() > self.config.limits.max_root_fields as usize {
68 return Err(GraphQLError::new(
69 format!(
70 "Query has {} root fields, maximum allowed is {}",
71 document.selections.len(), self.config.limits.max_root_fields
72 ),
73 ErrorCode::QueryTooComplex,
74 ));
75 }
76
77 self.validate_fields(&document.selections, "Query", schema)?;
79
80 Ok(())
81 }
82
83 pub fn calculate_complexity(
85 &self,
86 document: &ParsedDocument,
87 _schema: &GraphQLSchema,
88 ) -> Result<ComplexityResult, GraphQLError> {
89 let mut result = ComplexityResult::default();
90
91 for selection in &document.selections {
92 self.calculate_selection_complexity(selection, 1, &mut result)?;
93 }
94
95 Ok(result)
96 }
97
98 fn calculate_selection_complexity(
100 &self,
101 selection: &ParsedSelection,
102 depth: u32,
103 result: &mut ComplexityResult,
104 ) -> Result<(), GraphQLError> {
105 result.max_depth = result.max_depth.max(depth);
107
108 let mut field_cost = 1u32;
110
111 if selection.alias.is_some() {
113 result.alias_count += 1;
114 }
115
116 if let Some(limit) = selection.arguments.get("limit") {
118 if let Some(l) = limit.as_u64() {
119 field_cost = field_cost.saturating_mul(l.min(100) as u32);
120 }
121 }
122
123 result.total = result.total.saturating_add(field_cost);
125 result.field_count += 1;
126
127 for nested in &selection.selections {
129 self.calculate_selection_complexity(nested, depth + 1, result)?;
130 }
131
132 Ok(())
133 }
134
135 fn validate_fields(
137 &self,
138 selections: &[ParsedSelection],
139 type_name: &str,
140 schema: &GraphQLSchema,
141 ) -> Result<(), GraphQLError> {
142 let type_def = schema.get_type(type_name);
144
145 for selection in selections {
146 if selection.name.starts_with("__") {
148 continue;
149 }
150
151 if type_name != "Query" && type_name != "Mutation" {
153 if let Some(ref type_def) = type_def {
154 let field_exists = type_def.get_field(&selection.name).is_some();
155 let rel_exists = schema.get_relationships_for(type_name)
156 .iter()
157 .any(|r| r.field_name == selection.name);
158
159 if !field_exists && !rel_exists {
160 return Err(GraphQLError::validation_error(format!(
161 "Field '{}' does not exist on type '{}'",
162 selection.name, type_name
163 )));
164 }
165 }
166 }
167
168 if !selection.selections.is_empty() {
170 let nested_type = self.get_field_type(&selection.name, type_name, schema);
172 if let Some(nested_type) = nested_type {
173 self.validate_fields(&selection.selections, &nested_type, schema)?;
174 }
175 }
176 }
177
178 Ok(())
179 }
180
181 fn get_field_type(&self, field_name: &str, parent_type: &str, schema: &GraphQLSchema) -> Option<String> {
183 if let Some(type_def) = schema.get_type(parent_type) {
185 if let Some(field) = type_def.get_field(field_name) {
186 return Some(self.extract_type_name(&field.graphql_type));
187 }
188 }
189
190 for rel in schema.get_relationships_for(parent_type) {
192 if rel.field_name == field_name {
193 return Some(rel.to_type.clone());
194 }
195 }
196
197 Some(super::to_pascal_case(field_name))
199 }
200
201 fn extract_type_name(&self, field_type: &super::introspector::FieldType) -> String {
203 use super::introspector::FieldType;
204
205 match field_type {
206 FieldType::Scalar(s) => s.to_sdl().to_string(),
207 FieldType::Object(name) => name.clone(),
208 FieldType::List(inner) => self.extract_type_name(inner),
209 FieldType::NonNull(inner) => self.extract_type_name(inner),
210 }
211 }
212}
213
/// Metrics accumulated while walking a document's selection tree.
#[derive(Debug, Clone, Default)]
pub struct ComplexityResult {
    /// Sum of per-field costs.
    pub total: u32,
    /// Deepest nesting level seen.
    pub max_depth: u32,
    /// Number of selections that carry an alias.
    pub alias_count: u32,
    /// Number of selections visited.
    pub field_count: u32,
}
226
227impl ComplexityResult {
228 pub fn is_within_limits(&self, config: &GraphQLConfig) -> bool {
230 self.total <= config.limits.max_complexity
231 && self.max_depth <= config.limits.max_depth
232 && self.alias_count <= config.limits.max_aliases
233 }
234}
235
/// A single rule violation produced by a [`RuleValidator`].
#[derive(Debug, Clone)]
pub struct ValidationError {
    /// Human-readable description; also the `Display` output.
    pub message: String,
    /// `(line, column)` pairs attached through [`ValidationError::at`], in call order.
    pub locations: Vec<(u32, u32)>,
    /// The rule that was violated.
    pub rule: ValidationRule,
}

impl ValidationError {
    /// Creates an error for `rule` with no source locations attached.
    pub fn new(message: impl Into<String>, rule: ValidationRule) -> Self {
        let message = message.into();
        Self {
            message,
            rule,
            locations: Vec::new(),
        }
    }

    /// Appends a `(line, column)` location and returns the error, so calls can be chained.
    pub fn at(self, line: u32, column: u32) -> Self {
        let mut locations = self.locations;
        locations.push((line, column));
        Self { locations, ..self }
    }
}

impl std::fmt::Display for ValidationError {
    /// Writes only the message. Width and fill flags of the outer format are ignored.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

impl std::error::Error for ValidationError {}

/// Identifies which validation rule produced a [`ValidationError`].
///
/// In this module, `UnknownField`, `QueryTooDeep` and `TooManyAliases` are
/// reported by validators. The remaining variants presumably exist for rules
/// implemented elsewhere or not yet implemented — TODO confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ValidationRule {
    /// A selected field does not exist on its parent type.
    UnknownField,
    UnknownType,
    UnknownArgument,
    MissingArgument,
    InvalidArgumentType,
    QueryTooComplex,
    /// Selection nesting exceeds the configured maximum depth.
    QueryTooDeep,
    /// The document uses more aliases than allowed.
    TooManyAliases,
    DuplicateField,
    FragmentCycle,
    UnknownFragment,
    InvalidFragmentSpread,
}
300
301pub trait RuleValidator: Send + Sync {
303 fn validate(
305 &self,
306 document: &ParsedDocument,
307 schema: &GraphQLSchema,
308 ) -> Result<(), ValidationError>;
309
310 fn rule(&self) -> ValidationRule;
312}
313
314pub struct UnknownFieldValidator;
316
317impl RuleValidator for UnknownFieldValidator {
318 fn validate(
319 &self,
320 document: &ParsedDocument,
321 schema: &GraphQLSchema,
322 ) -> Result<(), ValidationError> {
323 fn check_selections(
324 selections: &[ParsedSelection],
325 type_name: &str,
326 schema: &GraphQLSchema,
327 ) -> Result<(), ValidationError> {
328 for selection in selections {
329 if selection.name.starts_with("__") {
330 continue;
331 }
332
333 if type_name != "Query" && type_name != "Mutation" {
334 if let Some(type_def) = schema.get_type(type_name) {
335 if type_def.get_field(&selection.name).is_none() {
336 return Err(ValidationError::new(
337 format!("Unknown field '{}' on type '{}'", selection.name, type_name),
338 ValidationRule::UnknownField,
339 ));
340 }
341 }
342 }
343
344 if !selection.selections.is_empty() {
346 check_selections(&selection.selections, &selection.name, schema)?;
347 }
348 }
349 Ok(())
350 }
351
352 check_selections(&document.selections, "Query", schema)
353 }
354
355 fn rule(&self) -> ValidationRule {
356 ValidationRule::UnknownField
357 }
358}
359
/// Rule that rejects documents whose selections nest deeper than a fixed limit.
pub struct DepthValidator {
    /// Deepest allowed nesting level; root fields count as depth 1.
    max_depth: u32,
}

impl DepthValidator {
    /// Creates the rule with `max_depth` as the deepest permitted level.
    pub fn new(max_depth: u32) -> Self {
        DepthValidator { max_depth }
    }
}
370
371impl RuleValidator for DepthValidator {
372 fn validate(
373 &self,
374 document: &ParsedDocument,
375 _schema: &GraphQLSchema,
376 ) -> Result<(), ValidationError> {
377 fn check_depth(selections: &[ParsedSelection], current_depth: u32, max_depth: u32) -> Result<(), ValidationError> {
378 if current_depth > max_depth {
379 return Err(ValidationError::new(
380 format!("Query depth {} exceeds maximum {}", current_depth, max_depth),
381 ValidationRule::QueryTooDeep,
382 ));
383 }
384
385 for selection in selections {
386 check_depth(&selection.selections, current_depth + 1, max_depth)?;
387 }
388
389 Ok(())
390 }
391
392 check_depth(&document.selections, 1, self.max_depth)
393 }
394
395 fn rule(&self) -> ValidationRule {
396 ValidationRule::QueryTooDeep
397 }
398}
399
/// Rule that caps how many aliased selections a document may contain.
pub struct AliasValidator {
    /// Largest permitted number of aliases, counted across all nesting levels.
    max_aliases: u32,
}

impl AliasValidator {
    /// Creates the rule allowing at most `max_aliases` aliases.
    pub fn new(max_aliases: u32) -> Self {
        AliasValidator { max_aliases }
    }
}
410
411impl RuleValidator for AliasValidator {
412 fn validate(
413 &self,
414 document: &ParsedDocument,
415 _schema: &GraphQLSchema,
416 ) -> Result<(), ValidationError> {
417 fn count_aliases(selections: &[ParsedSelection]) -> u32 {
418 let mut count = 0;
419 for selection in selections {
420 if selection.alias.is_some() {
421 count += 1;
422 }
423 count += count_aliases(&selection.selections);
424 }
425 count
426 }
427
428 let alias_count = count_aliases(&document.selections);
429 if alias_count > self.max_aliases {
430 return Err(ValidationError::new(
431 format!("Query has {} aliases, maximum is {}", alias_count, self.max_aliases),
432 ValidationRule::TooManyAliases,
433 ));
434 }
435
436 Ok(())
437 }
438
439 fn rule(&self) -> ValidationRule {
440 ValidationRule::TooManyAliases
441 }
442}
443
444#[derive(Default)]
446pub struct ValidationPipeline {
447 validators: Vec<Box<dyn RuleValidator>>,
448}
449
450impl ValidationPipeline {
451 pub fn new() -> Self {
453 Self::default()
454 }
455
456 pub fn add(mut self, validator: impl RuleValidator + 'static) -> Self {
458 self.validators.push(Box::new(validator));
459 self
460 }
461
462 pub fn default_pipeline(config: &GraphQLConfig) -> Self {
464 Self::new()
465 .add(UnknownFieldValidator)
466 .add(DepthValidator::new(config.limits.max_depth))
467 .add(AliasValidator::new(config.limits.max_aliases))
468 }
469
470 pub fn validate(
472 &self,
473 document: &ParsedDocument,
474 schema: &GraphQLSchema,
475 ) -> Result<(), Vec<ValidationError>> {
476 let mut errors = Vec::new();
477
478 for validator in &self.validators {
479 if let Err(e) = validator.validate(document, schema) {
480 errors.push(e);
481 }
482 }
483
484 if errors.is_empty() {
485 Ok(())
486 } else {
487 Err(errors)
488 }
489 }
490}
491
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graphql::{OperationType, introspector::*};

    /// Schema containing a single `User` type with `id` and `name` fields.
    fn create_test_schema() -> GraphQLSchema {
        let mut user_type = GraphQLType::new("User");
        user_type.add_field(GraphQLField::new(
            "id",
            FieldType::scalar(crate::graphql::GraphQLScalar::ID),
        ));
        user_type.add_field(GraphQLField::new(
            "name",
            FieldType::scalar(crate::graphql::GraphQLScalar::String),
        ));

        let mut schema = GraphQLSchema::new();
        schema.add_type(user_type);
        schema
    }

    /// Anonymous query document whose root selection set is `selections`.
    fn create_test_document(selections: Vec<ParsedSelection>) -> ParsedDocument {
        ParsedDocument {
            operation_type: OperationType::Query,
            operation_name: None,
            selections,
            variable_definitions: Vec::new(),
            fragments: HashMap::new(),
        }
    }

    /// Selection with the given alias and children, and no arguments or directives.
    fn field(name: &str, alias: Option<&str>, selections: Vec<ParsedSelection>) -> ParsedSelection {
        ParsedSelection {
            name: name.to_string(),
            alias: alias.map(str::to_string),
            arguments: HashMap::new(),
            selections,
            directives: Vec::new(),
        }
    }

    /// Selection with no alias and no children.
    fn leaf(name: &str) -> ParsedSelection {
        field(name, None, Vec::new())
    }

    /// `QueryValidator` built from the default configuration.
    fn default_validator() -> QueryValidator {
        QueryValidator::new(Arc::new(GraphQLConfig::default()))
    }

    #[test]
    fn test_complexity_calculation() {
        let validator = default_validator();
        let schema = create_test_schema();
        let document =
            create_test_document(vec![field("users", None, vec![leaf("id"), leaf("name")])]);

        let result = validator.calculate_complexity(&document, &schema).unwrap();

        assert_eq!(result.field_count, 3);
        assert_eq!(result.max_depth, 2);
        assert_eq!(result.alias_count, 0);
    }

    #[test]
    fn test_complexity_with_limit() {
        let validator = default_validator();
        let schema = create_test_schema();

        let mut args = HashMap::new();
        args.insert("limit".to_string(), serde_json::json!(10));
        let document = create_test_document(vec![ParsedSelection {
            arguments: args,
            ..leaf("users")
        }]);

        let result = validator.calculate_complexity(&document, &schema).unwrap();

        assert_eq!(result.total, 10);
    }

    #[test]
    fn test_alias_counting() {
        let validator = default_validator();
        let schema = create_test_schema();
        let document = create_test_document(vec![field(
            "users",
            Some("allUsers"),
            vec![field("id", Some("userId"), Vec::new())],
        )]);

        let result = validator.calculate_complexity(&document, &schema).unwrap();

        assert_eq!(result.alias_count, 2);
    }

    #[test]
    fn test_depth_validator() {
        let validator = DepthValidator::new(2);
        let schema = create_test_schema();

        let shallow = create_test_document(vec![leaf("users")]);
        assert!(validator.validate(&shallow, &schema).is_ok());

        let deep = create_test_document(vec![field(
            "users",
            None,
            vec![field("posts", None, vec![leaf("comments")])],
        )]);
        assert!(validator.validate(&deep, &schema).is_err());
    }

    #[test]
    fn test_alias_validator() {
        let validator = AliasValidator::new(2);
        let schema = create_test_schema();

        let within_limit = create_test_document(vec![field(
            "users",
            Some("a1"),
            vec![field("id", Some("a2"), Vec::new())],
        )]);
        assert!(validator.validate(&within_limit, &schema).is_ok());

        let exceeds_limit = create_test_document(vec![field(
            "users",
            Some("a1"),
            vec![
                field("id", Some("a2"), Vec::new()),
                field("name", Some("a3"), Vec::new()),
            ],
        )]);
        assert!(validator.validate(&exceeds_limit, &schema).is_err());
    }

    #[test]
    fn test_validation_pipeline() {
        let config = GraphQLConfig::default();
        let pipeline = ValidationPipeline::default_pipeline(&config);
        let schema = create_test_schema();
        let document = create_test_document(vec![leaf("users")]);

        assert!(pipeline.validate(&document, &schema).is_ok());
    }

    #[test]
    fn test_complexity_result_within_limits() {
        let config = GraphQLConfig::default();

        let within = ComplexityResult {
            total: 100,
            max_depth: 5,
            alias_count: 2,
            field_count: 10,
        };
        assert!(within.is_within_limits(&config));

        let exceeds = ComplexityResult {
            total: 10000,
            ..within.clone()
        };
        assert!(!exceeds.is_within_limits(&config));
    }

    #[test]
    fn test_validation_error() {
        let err = ValidationError::new("Test error", ValidationRule::UnknownField)
            .at(1, 10)
            .at(2, 5);

        assert_eq!(err.message, "Test error");
        assert_eq!(err.locations.len(), 2);
        assert_eq!(err.rule, ValidationRule::UnknownField);
    }
}