use std::collections::HashMap;
use graphql_tools::{
ast::OperationVisitorContext,
static_graphql::query::Definition,
validation::{
rules::ValidationRule,
utils::{ValidationError, ValidationErrorContext},
},
};
use hive_router_config::limits::MaxDirectivesRuleConfig;
use crate::pipeline::validation::shared::{CountableNode, VisitedFragment};
/// Validation rule that caps how many directive usages a single operation may
/// contain.
///
/// Every operation in the document is checked on its own; directives inside
/// fragments are included in the count of each operation that spreads them.
pub struct MaxDirectivesRule {
    /// Limit configuration; an operation whose directive count is strictly
    /// greater than `config.n` is rejected.
    pub config: MaxDirectivesRuleConfig,
}

impl ValidationRule for MaxDirectivesRule {
    fn error_code<'a>(&self) -> &'a str {
        "MAX_DIRECTIVES_EXCEEDED"
    }

    /// Counts the directives of every operation definition in the document and
    /// reports at most one error per operation that goes over the limit.
    ///
    /// Fragment definitions are not checked in isolation; they only count
    /// where an operation spreads them.
    fn validate(
        &self,
        ctx: &mut OperationVisitorContext<'_>,
        error_collector: &mut ValidationErrorContext,
    ) {
        // One cache slot per known fragment is the most the visitor can need.
        let fragment_capacity = ctx.known_fragments.len();
        let mut visitor = MaxDirectivesVisitor {
            config: &self.config,
            visited_fragments: HashMap::with_capacity(fragment_capacity),
            ctx,
        };

        let operations = ctx
            .operation
            .definitions
            .iter()
            .filter_map(|definition| match definition {
                Definition::Operation(operation) => Some(operation),
                _ => None,
            });

        for operation in operations {
            if let Err(error) = visitor.count_directives(operation.into()) {
                error_collector.report_error(error);
            }
        }
    }
}
/// Depth-first directive counter used by `MaxDirectivesRule`.
///
/// A single visitor is reused for every operation of the document, so the
/// per-fragment totals cached in `visited_fragments` are shared between
/// operations.
struct MaxDirectivesVisitor<'a, 'b> {
    /// Limit configuration of the owning rule.
    config: &'b MaxDirectivesRuleConfig,
    /// Per-fragment walk state: `Visiting` while a fragment's body is being
    /// counted (this breaks spread cycles), `Counted(total)` once its total
    /// is known.
    visited_fragments: HashMap<&'a str, VisitedFragment>,
    /// Validation context, used to resolve fragment spreads to definitions.
    ctx: &'b OperationVisitorContext<'a>,
}

impl<'a> MaxDirectivesVisitor<'a, '_> {
    /// Returns `count` unchanged when it is within the configured limit.
    ///
    /// # Errors
    ///
    /// Returns a `MAX_DIRECTIVES_EXCEEDED` [`ValidationError`] when `count` is
    /// greater than `config.n`.
    fn check_limit(&self, count: usize) -> Result<usize, ValidationError> {
        if count > self.config.n {
            Err(ValidationError {
                locations: vec![],
                message: "Directives limit exceeded.".to_string(),
                error_code: "MAX_DIRECTIVES_EXCEEDED",
            })
        } else {
            Ok(count)
        }
    }

    /// Counts the directives on `countable_node` and everything beneath it,
    /// including the bodies of any fragments it spreads.
    ///
    /// # Errors
    ///
    /// Stops and returns the limit error as soon as a running total exceeds
    /// `config.n`.
    fn count_directives(
        &mut self,
        countable_node: CountableNode<'a>,
    ) -> Result<usize, ValidationError> {
        let mut directive_count = match countable_node.get_directives() {
            Some(directives) => self.check_limit(directives.len())?,
            None => 0,
        };

        if let Some(selection_set) = countable_node.selection_set() {
            for selection in &selection_set.items {
                let child: CountableNode<'a> = selection.into();
                let child_count = self.count_directives(child)?;
                directive_count = self.check_limit(directive_count + child_count)?;
            }
        }

        if let CountableNode::FragmentSpread(node) = countable_node {
            // The spread's own directives were taken from `get_directives`
            // above; add the directives found inside the referenced fragment.
            let fragment_count = self.count_fragment(node.fragment_name.as_str())?;
            directive_count = self.check_limit(directive_count + fragment_count)?;
        }

        Ok(directive_count)
    }

    /// Returns the number of directives inside the fragment `fragment_name`
    /// (excluding directives on the spread itself). The total is computed on
    /// first use and cached for later spreads.
    ///
    /// A fragment that is currently being walked (reached again through a
    /// spread cycle) or that has no known definition contributes `0`.
    ///
    /// # Errors
    ///
    /// Propagates the limit error raised while counting the fragment's body.
    /// In that case the fragment's `Visiting` marker is removed. Otherwise a
    /// later operation spreading the same fragment would mistake it for a
    /// cycle, count it as `0`, and miss its own violation.
    fn count_fragment(&mut self, fragment_name: &'a str) -> Result<usize, ValidationError> {
        match self.visited_fragments.get(fragment_name) {
            Some(VisitedFragment::Counted(total)) => return Ok(*total),
            Some(VisitedFragment::Visiting) => return Ok(0),
            None => {}
        }

        self.visited_fragments
            .insert(fragment_name, VisitedFragment::Visiting);

        let Some(fragment_def) = self.ctx.known_fragments.get(fragment_name) else {
            return Ok(0);
        };
        let fragment_node: CountableNode<'a> = fragment_def.into();

        match self.count_directives(fragment_node) {
            Ok(total) => {
                self.visited_fragments
                    .insert(fragment_name, VisitedFragment::Counted(total));
                Ok(total)
            }
            Err(error) => {
                self.visited_fragments.remove(fragment_name);
                Err(error)
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use graphql_tools::parser::{parse_query, parse_schema};
    use graphql_tools::validation::validate::{validate, ValidationPlan};
    use hive_router_config::limits::MaxDirectivesRuleConfig;

    use crate::pipeline::validation::max_directives_rule::MaxDirectivesRule;

    const TYPE_DEFS: &str = r#"
        type Book {
            title: String
            author: String
        }
        type Query {
            books: [Book]
        }
    "#;

    const QUERY: &str = r#"
        query {
            __typename @a @a @a @a
        }
    "#;

    /// Message carried by every error `MaxDirectivesRule` reports.
    const LIMIT_EXCEEDED: &str = "Directives limit exceeded.";

    /// Validates `query` against `TYPE_DEFS` with only `MaxDirectivesRule`
    /// (limit `n`) in the plan and returns the messages of the reported errors.
    fn validate_with_limit(query: &str, n: usize) -> Vec<String> {
        let schema = parse_schema(TYPE_DEFS)
            .expect("Failed to parse schema")
            .into_static();
        let query = parse_query(query)
            .expect("Failed to parse query")
            .into_static();
        let validation_plan = ValidationPlan::from(vec![Box::new(MaxDirectivesRule {
            config: MaxDirectivesRuleConfig { n },
        })]);
        validate(&schema, &query, &validation_plan)
            .into_iter()
            .map(|error| error.message)
            .collect()
    }

    #[test]
    fn works() {
        assert!(validate_with_limit(QUERY, 5).is_empty());
    }

    #[test]
    fn rejects_query_exceeding_max_directives() {
        assert_eq!(validate_with_limit(QUERY, 3), vec![LIMIT_EXCEEDED]);
    }

    #[test]
    fn works_on_fragment() {
        let query = r#"
            query {
                ...DirectivesFragment
            }
            fragment DirectivesFragment on Query {
                __typename @a @a @a @a
            }
        "#;
        assert_eq!(validate_with_limit(query, 3), vec![LIMIT_EXCEEDED]);
    }

    #[test]
    fn counts_each_spread_of_the_same_fragment() {
        // Two spreads of a 2-directive fragment make 4 directives in total;
        // the second spread is served from the fragment cache.
        let query = r#"
            query {
                ...DirectivesFragment
                ...DirectivesFragment
            }
            fragment DirectivesFragment on Query {
                __typename @a @a
            }
        "#;
        assert_eq!(validate_with_limit(query, 3), vec![LIMIT_EXCEEDED]);
        assert!(validate_with_limit(query, 4).is_empty());
    }

    #[test]
    fn not_crash_on_recursive_fragment() {
        let query = r#"
            query {
                ...A
            }
            fragment A on Query {
                ...B
            }
            fragment B on Query {
                ...A
            }
        "#;
        assert!(validate_with_limit(query, 5).is_empty());
    }

    #[test]
    fn count_directives_on_recursive_fragment_spreads() {
        let query = r#"
            query {
                ...A
            }
            fragment A on Query {
                ...B @directive1 @directive2
            }
            fragment B on Query {
                ...A @directive3 @directive4
            }
        "#;
        assert_eq!(validate_with_limit(query, 1), vec![LIMIT_EXCEEDED]);
    }
}