1use crate::ast::{
7 Definition, Document, Field, OperationDefinition, Selection, SelectionSet, Value,
8 VariableDefinition,
9};
10use crate::types::{GraphQLType, Schema};
11use anyhow::{anyhow, Result};
12use std::collections::{HashMap, HashSet};
13use std::time::Duration;
14
15#[derive(Debug, Clone)]
17pub struct ValidationConfig {
18 pub max_depth: usize,
20 pub max_complexity: usize,
22 pub max_aliases: usize,
24 pub max_root_fields: usize,
26 pub query_timeout: Option<Duration>,
28 pub disable_introspection: bool,
30 pub max_fragments: usize,
32 pub allowed_operations: Option<HashSet<String>>,
34 pub forbidden_fields: HashSet<String>,
36 pub enable_cost_analysis: bool,
38}
39
40impl Default for ValidationConfig {
41 fn default() -> Self {
42 Self {
43 max_depth: 10,
44 max_complexity: 1000,
45 max_aliases: 50,
46 max_root_fields: 20,
47 query_timeout: Some(Duration::from_secs(30)),
48 disable_introspection: false,
49 max_fragments: 50,
50 allowed_operations: None,
51 forbidden_fields: HashSet::new(),
52 enable_cost_analysis: true,
53 }
54 }
55}
56
57#[derive(Debug, Clone)]
59pub struct ValidationResult {
60 pub is_valid: bool,
61 pub errors: Vec<ValidationError>,
62 pub warnings: Vec<ValidationWarning>,
63 pub complexity_score: usize,
64 pub depth: usize,
65}
66
67impl Default for ValidationResult {
68 fn default() -> Self {
69 Self::new()
70 }
71}
72
73impl ValidationResult {
74 pub fn new() -> Self {
75 Self {
76 is_valid: true,
77 errors: Vec::new(),
78 warnings: Vec::new(),
79 complexity_score: 0,
80 depth: 0,
81 }
82 }
83
84 pub fn with_error(mut self, error: ValidationError) -> Self {
85 self.is_valid = false;
86 self.errors.push(error);
87 self
88 }
89
90 pub fn with_warning(mut self, warning: ValidationWarning) -> Self {
91 self.warnings.push(warning);
92 self
93 }
94
95 pub fn with_complexity(mut self, complexity: usize) -> Self {
96 self.complexity_score = complexity;
97 self
98 }
99
100 pub fn with_depth(mut self, depth: usize) -> Self {
101 self.depth = depth;
102 self
103 }
104}
105
106#[derive(Debug, Clone)]
108pub struct ValidationError {
109 pub message: String,
110 pub path: Vec<String>,
111 pub rule: ValidationRule,
112}
113
114impl ValidationError {
115 pub fn new(message: String, rule: ValidationRule) -> Self {
116 Self {
117 message,
118 path: Vec::new(),
119 rule,
120 }
121 }
122
123 pub fn with_path(mut self, path: Vec<String>) -> Self {
124 self.path = path;
125 self
126 }
127}
128
129#[derive(Debug, Clone)]
131pub struct ValidationWarning {
132 pub message: String,
133 pub suggestion: Option<String>,
134}
135
136impl ValidationWarning {
137 pub fn new(message: String) -> Self {
138 Self {
139 message,
140 suggestion: None,
141 }
142 }
143
144 pub fn with_suggestion(mut self, suggestion: String) -> Self {
145 self.suggestion = Some(suggestion);
146 self
147 }
148}
149
150#[derive(Debug, Clone, PartialEq)]
152pub enum ValidationRule {
153 MaxDepth,
154 MaxComplexity,
155 MaxAliases,
156 MaxRootFields,
157 MaxFragments,
158 FieldValidation,
159 TypeValidation,
160 VariableValidation,
161 FragmentValidation,
162 IntrospectionDisabled,
163 OperationNotAllowed,
164 ForbiddenField,
165 CostAnalysis,
166}
167
168pub struct QueryValidator {
170 config: ValidationConfig,
171 schema: Schema,
172}
173
174impl QueryValidator {
175 pub fn new(config: ValidationConfig, schema: Schema) -> Self {
176 Self { config, schema }
177 }
178
179 pub fn validate(&self, document: &Document) -> Result<ValidationResult> {
181 let mut result = ValidationResult::new();
182 let mut validation_context = ValidationContext::new(&self.schema);
183
184 for definition in &document.definitions {
186 if let Definition::Fragment(fragment) = definition {
187 validation_context.add_fragment(fragment.name.clone(), fragment.clone());
188 }
189 }
190
191 for definition in &document.definitions {
193 if let Definition::Operation(operation) = definition {
194 result = self.validate_operation(operation, &validation_context, result)?;
195 }
196 }
197
198 result = self.validate_fragments(&validation_context, result)?;
200 result = self.validate_global_limits(document, result)?;
201
202 Ok(result)
203 }
204
205 fn validate_operation(
206 &self,
207 operation: &OperationDefinition,
208 context: &ValidationContext,
209 mut result: ValidationResult,
210 ) -> Result<ValidationResult> {
211 if let Some(ref allowed_ops) = self.config.allowed_operations {
213 if let Some(ref op_name) = operation.name {
214 if !allowed_ops.contains(op_name) {
215 return Ok(result.with_error(ValidationError::new(
216 format!("Operation '{op_name}' is not allowed"),
217 ValidationRule::OperationNotAllowed,
218 )));
219 }
220 }
221 }
222
223 result = self.validate_variables(&operation.variable_definitions, context, result)?;
225
226 let root_type_name = match operation.operation_type {
228 crate::ast::OperationType::Query => self
229 .schema
230 .query_type
231 .as_ref()
232 .ok_or_else(|| anyhow!("Schema has no query type"))?,
233 crate::ast::OperationType::Mutation => self
234 .schema
235 .mutation_type
236 .as_ref()
237 .ok_or_else(|| anyhow!("Schema has no mutation type"))?,
238 crate::ast::OperationType::Subscription => self
239 .schema
240 .subscription_type
241 .as_ref()
242 .ok_or_else(|| anyhow!("Schema has no subscription type"))?,
243 };
244
245 let (depth, complexity) = self.validate_selection_set(
247 &operation.selection_set,
248 root_type_name,
249 context,
250 0,
251 Vec::new(),
252 )?;
253
254 result.depth = depth.max(result.depth);
255 result.complexity_score += complexity;
256
257 if result.depth > self.config.max_depth {
259 let current_depth = result.depth;
260 result = result.with_error(ValidationError::new(
261 format!(
262 "Query depth {} exceeds maximum allowed depth {}",
263 current_depth, self.config.max_depth
264 ),
265 ValidationRule::MaxDepth,
266 ));
267 }
268
269 if result.complexity_score > self.config.max_complexity {
271 let current_complexity = result.complexity_score;
272 result = result.with_error(ValidationError::new(
273 format!(
274 "Query complexity {} exceeds maximum allowed complexity {}",
275 current_complexity, self.config.max_complexity
276 ),
277 ValidationRule::MaxComplexity,
278 ));
279 }
280
281 Ok(result)
282 }
283
284 fn validate_selection_set(
285 &self,
286 selection_set: &SelectionSet,
287 parent_type_name: &str,
288 context: &ValidationContext,
289 current_depth: usize,
290 path: Vec<String>,
291 ) -> Result<(usize, usize)> {
292 let mut max_depth = current_depth;
293 let mut total_complexity = 0;
294 let mut alias_count = 0;
295
296 let parent_type = self
297 .schema
298 .get_type(parent_type_name)
299 .ok_or_else(|| anyhow!("Type '{}' not found in schema", parent_type_name))?;
300
301 for selection in &selection_set.selections {
302 match selection {
303 Selection::Field(field) => {
304 if field.alias.is_some() {
306 alias_count += 1;
307 }
308
309 if self.config.forbidden_fields.contains(&field.name) {
311 return Err(anyhow!("Field '{}' is forbidden", field.name));
312 }
313
314 if self.config.disable_introspection && field.name.starts_with("__") {
316 return Err(anyhow!("Introspection is disabled"));
317 }
318
319 let field_type = self.get_field_type(parent_type, &field.name)?;
321
322 let mut field_path = path.clone();
323 field_path.push(field.alias.as_ref().unwrap_or(&field.name).clone());
324
325 let field_complexity = self.calculate_field_complexity(field);
327 total_complexity += field_complexity;
328
329 if let Some(ref nested_selection_set) = field.selection_set {
331 let inner_type_name = self.get_inner_type_name(field_type);
332 let (nested_depth, nested_complexity) = self.validate_selection_set(
333 nested_selection_set,
334 &inner_type_name,
335 context,
336 current_depth + 1,
337 field_path,
338 )?;
339 max_depth = max_depth.max(nested_depth);
340 total_complexity += nested_complexity;
341 }
342 }
343 Selection::InlineFragment(inline_fragment) => {
344 let fragment_type =
345 if let Some(ref type_condition) = inline_fragment.type_condition {
346 type_condition
347 } else {
348 parent_type_name
349 };
350
351 let (nested_depth, nested_complexity) = self.validate_selection_set(
352 &inline_fragment.selection_set,
353 fragment_type,
354 context,
355 current_depth,
356 path.clone(),
357 )?;
358 max_depth = max_depth.max(nested_depth);
359 total_complexity += nested_complexity;
360 }
361 Selection::FragmentSpread(fragment_spread) => {
362 if let Some(fragment_def) = context.get_fragment(&fragment_spread.fragment_name)
363 {
364 let (nested_depth, nested_complexity) = self.validate_selection_set(
365 &fragment_def.selection_set,
366 &fragment_def.type_condition,
367 context,
368 current_depth,
369 path.clone(),
370 )?;
371 max_depth = max_depth.max(nested_depth);
372 total_complexity += nested_complexity;
373 } else {
374 return Err(anyhow!(
375 "Fragment '{}' not found",
376 fragment_spread.fragment_name
377 ));
378 }
379 }
380 }
381 }
382
383 if alias_count > self.config.max_aliases {
385 return Err(anyhow!(
386 "Too many aliases: {} exceeds limit {}",
387 alias_count,
388 self.config.max_aliases
389 ));
390 }
391
392 Ok((max_depth, total_complexity))
393 }
394
395 fn validate_variables(
396 &self,
397 variable_definitions: &[VariableDefinition],
398 _context: &ValidationContext,
399 mut result: ValidationResult,
400 ) -> Result<ValidationResult> {
401 for var_def in variable_definitions {
402 if !self.type_exists(&var_def.type_) {
404 result = result.with_error(ValidationError::new(
405 format!(
406 "Variable type '{}' not found in schema",
407 var_def.type_.name()
408 ),
409 ValidationRule::VariableValidation,
410 ));
411 }
412
413 if let Some(ref default_value) = var_def.default_value {
415 if !self.is_value_compatible_with_type(default_value, &var_def.type_) {
416 result = result.with_error(ValidationError::new(
417 format!(
418 "Default value for variable '{}' is not compatible with type '{}'",
419 var_def.variable.name,
420 var_def.type_.name()
421 ),
422 ValidationRule::VariableValidation,
423 ));
424 }
425 }
426 }
427
428 Ok(result)
429 }
430
431 fn validate_fragments(
432 &self,
433 context: &ValidationContext,
434 mut result: ValidationResult,
435 ) -> Result<ValidationResult> {
436 if context.fragments.len() > self.config.max_fragments {
437 result = result.with_error(ValidationError::new(
438 format!(
439 "Too many fragments: {} exceeds limit {}",
440 context.fragments.len(),
441 self.config.max_fragments
442 ),
443 ValidationRule::MaxFragments,
444 ));
445 }
446
447 for (fragment_name, fragment) in &context.fragments {
449 if !self.schema.types.contains_key(&fragment.type_condition) {
450 result = result.with_error(ValidationError::new(
451 format!(
452 "Fragment '{}' has unknown type condition '{}'",
453 fragment_name, fragment.type_condition
454 ),
455 ValidationRule::FragmentValidation,
456 ));
457 }
458 }
459
460 Ok(result)
461 }
462
463 fn validate_global_limits(
464 &self,
465 document: &Document,
466 mut result: ValidationResult,
467 ) -> Result<ValidationResult> {
468 let mut root_field_count = 0;
469
470 for definition in &document.definitions {
471 if let Definition::Operation(operation) = definition {
472 root_field_count += operation.selection_set.selections.len();
473 }
474 }
475
476 if root_field_count > self.config.max_root_fields {
477 result = result.with_error(ValidationError::new(
478 format!(
479 "Too many root fields: {} exceeds limit {}",
480 root_field_count, self.config.max_root_fields
481 ),
482 ValidationRule::MaxRootFields,
483 ));
484 }
485
486 Ok(result)
487 }
488
489 fn get_field_type<'a>(
490 &self,
491 parent_type: &'a GraphQLType,
492 field_name: &str,
493 ) -> Result<&'a GraphQLType> {
494 match parent_type {
495 GraphQLType::Object(obj) => obj
496 .fields
497 .get(field_name)
498 .map(|field| &field.field_type)
499 .ok_or_else(|| {
500 anyhow!(
501 "Field '{}' not found on object type '{}'",
502 field_name,
503 obj.name
504 )
505 }),
506 GraphQLType::Interface(iface) => iface
507 .fields
508 .get(field_name)
509 .map(|field| &field.field_type)
510 .ok_or_else(|| {
511 anyhow!(
512 "Field '{}' not found on interface type '{}'",
513 field_name,
514 iface.name
515 )
516 }),
517 _ => Err(anyhow!(
518 "Cannot select field '{}' on non-composite type",
519 field_name
520 )),
521 }
522 }
523
524 #[allow(clippy::only_used_in_recursion)]
525 fn get_inner_type_name(&self, graphql_type: &GraphQLType) -> String {
526 match graphql_type {
527 GraphQLType::NonNull(inner) => self.get_inner_type_name(inner),
528 GraphQLType::List(inner) => self.get_inner_type_name(inner),
529 _ => graphql_type.name().to_string(),
530 }
531 }
532
533 fn calculate_field_complexity(&self, field: &Field) -> usize {
534 if !self.config.enable_cost_analysis {
535 return 1;
536 }
537
538 let mut complexity = 1;
539
540 complexity += field.arguments.len();
542
543 if let Some(ref selection_set) = field.selection_set {
545 complexity += selection_set.selections.len();
546 }
547
548 match field.name.as_str() {
550 "sparql" => complexity *= 10, name if name.contains("search") => complexity *= 3,
552 name if name.contains("aggregate") => complexity *= 5,
553 _ => {}
554 }
555
556 complexity
557 }
558
559 fn type_exists(&self, ast_type: &crate::ast::Type) -> bool {
560 match ast_type {
561 crate::ast::Type::NamedType(name) => {
562 self.schema.types.contains_key(name)
563 || matches!(name.as_str(), "String" | "Int" | "Float" | "Boolean" | "ID")
564 }
565 crate::ast::Type::ListType(inner) => self.type_exists(inner),
566 crate::ast::Type::NonNullType(inner) => self.type_exists(inner),
567 }
568 }
569
570 #[allow(clippy::only_used_in_recursion)]
571 fn is_value_compatible_with_type(&self, value: &Value, ast_type: &crate::ast::Type) -> bool {
572 match (value, ast_type) {
573 (Value::NullValue, crate::ast::Type::NonNullType(_)) => false,
574 (Value::NullValue, _) => true,
575 (Value::StringValue(_), crate::ast::Type::NamedType(name)) => {
576 matches!(name.as_str(), "String" | "ID")
577 }
578 (Value::IntValue(_), crate::ast::Type::NamedType(name)) => {
579 matches!(name.as_str(), "Int" | "ID")
580 }
581 (Value::FloatValue(_), crate::ast::Type::NamedType(name)) => {
582 matches!(name.as_str(), "Float")
583 }
584 (Value::BooleanValue(_), crate::ast::Type::NamedType(name)) => {
585 matches!(name.as_str(), "Boolean")
586 }
587 (Value::ListValue(list), crate::ast::Type::ListType(inner_type)) => list
588 .iter()
589 .all(|item| self.is_value_compatible_with_type(item, inner_type)),
590 (_, crate::ast::Type::NonNullType(inner)) => {
591 self.is_value_compatible_with_type(value, inner)
592 }
593 _ => false,
594 }
595 }
596}
597
598struct ValidationContext {
600 #[allow(dead_code)]
601 schema: Schema,
602 fragments: HashMap<String, crate::ast::FragmentDefinition>,
603}
604
605impl ValidationContext {
606 fn new(schema: &Schema) -> Self {
607 Self {
608 schema: schema.clone(),
609 fragments: HashMap::new(),
610 }
611 }
612
613 fn add_fragment(&mut self, name: String, fragment: crate::ast::FragmentDefinition) {
614 self.fragments.insert(name, fragment);
615 }
616
617 fn get_fragment(&self, name: &str) -> Option<&crate::ast::FragmentDefinition> {
618 self.fragments.get(name)
619 }
620}
621
622#[derive(Debug, Clone)]
624pub struct RateLimitConfig {
625 pub max_queries_per_minute: usize,
627 pub max_complexity_per_minute: usize,
629 pub window_duration: Duration,
631}
632
633impl Default for RateLimitConfig {
634 fn default() -> Self {
635 Self {
636 max_queries_per_minute: 60,
637 max_complexity_per_minute: 10000,
638 window_duration: Duration::from_secs(60),
639 }
640 }
641}
642
643#[cfg(test)]
644mod tests {
645 use super::*;
646 use crate::types::{BuiltinScalars, FieldType, ObjectType};
647
648 fn create_test_schema() -> Schema {
649 let mut schema = Schema::new();
650
651 let query_type = ObjectType::new("Query".to_string())
652 .with_field(
653 "hello".to_string(),
654 FieldType::new(
655 "hello".to_string(),
656 GraphQLType::Scalar(BuiltinScalars::string()),
657 ),
658 )
659 .with_field(
660 "__schema".to_string(),
661 FieldType::new(
662 "__schema".to_string(),
663 GraphQLType::Scalar(BuiltinScalars::string()),
664 ),
665 );
666
667 schema.add_type(GraphQLType::Object(query_type));
668 schema.set_query_type("Query".to_string());
669
670 schema
671 }
672
673 #[test]
674 fn test_validation_config_default() {
675 let config = ValidationConfig::default();
676 assert_eq!(config.max_depth, 10);
677 assert_eq!(config.max_complexity, 1000);
678 assert!(!config.disable_introspection);
679 }
680
681 #[test]
682 fn test_validation_result_creation() {
683 let result = ValidationResult::new()
684 .with_error(ValidationError::new(
685 "Test error".to_string(),
686 ValidationRule::MaxDepth,
687 ))
688 .with_warning(ValidationWarning::new("Test warning".to_string()))
689 .with_complexity(100)
690 .with_depth(5);
691
692 assert!(!result.is_valid);
693 assert_eq!(result.errors.len(), 1);
694 assert_eq!(result.warnings.len(), 1);
695 assert_eq!(result.complexity_score, 100);
696 assert_eq!(result.depth, 5);
697 }
698
699 #[test]
700 fn test_query_validator_creation() {
701 let config = ValidationConfig::default();
702 let schema = create_test_schema();
703 let validator = QueryValidator::new(config, schema);
704
705 assert_eq!(validator.config.max_depth, 10);
707 }
708
709 #[test]
710 fn test_validation_error_with_path() {
711 let error = ValidationError::new("Test error".to_string(), ValidationRule::FieldValidation)
712 .with_path(vec![
713 "query".to_string(),
714 "user".to_string(),
715 "name".to_string(),
716 ]);
717
718 assert_eq!(error.message, "Test error");
719 assert_eq!(error.path, vec!["query", "user", "name"]);
720 assert_eq!(error.rule, ValidationRule::FieldValidation);
721 }
722
723 #[test]
724 fn test_validation_warning_with_suggestion() {
725 let warning = ValidationWarning::new("Performance warning".to_string())
726 .with_suggestion("Consider using pagination".to_string());
727
728 assert_eq!(warning.message, "Performance warning");
729 assert_eq!(
730 warning.suggestion,
731 Some("Consider using pagination".to_string())
732 );
733 }
734
735 #[test]
736 fn test_rate_limit_config_default() {
737 let config = RateLimitConfig::default();
738 assert_eq!(config.max_queries_per_minute, 60);
739 assert_eq!(config.max_complexity_per_minute, 10000);
740 assert_eq!(config.window_duration, Duration::from_secs(60));
741 }
742
743 #[test]
744 fn test_validation_context() {
745 let schema = create_test_schema();
746 let mut context = ValidationContext::new(&schema);
747
748 let fragment = crate::ast::FragmentDefinition {
749 name: "TestFragment".to_string(),
750 type_condition: "Query".to_string(),
751 selection_set: crate::ast::SelectionSet { selections: vec![] },
752 directives: vec![],
753 };
754
755 context.add_fragment("TestFragment".to_string(), fragment.clone());
756
757 let retrieved = context.get_fragment("TestFragment");
758 assert!(retrieved.is_some());
759 assert_eq!(retrieved.expect("should succeed").name, "TestFragment");
760 }
761}