use async_graphql::{
ServerError,
ServerResult,
ValidationResult,
Variables,
extensions::{
Extension,
ExtensionContext,
ExtensionFactory,
NextParseQuery,
NextValidation,
},
parser::types::ExecutableDocument,
};
use recursion_finder::RecursionFinder;
use std::sync::{
Arc,
Mutex,
};
use visitor::{
RuleError,
VisitorContext,
visit,
};
mod recursion_finder;
mod visitor;
/// Extension factory that installs the query validation pass.
///
/// It only carries configuration: each extension it creates receives a copy of
/// `recursion_limit`.
pub(crate) struct ValidationExtension {
    /// Limit handed to `RecursionFinder::new` for each created extension
    /// (presumably a maximum nesting/recursion depth — confirm in `recursion_finder`).
    recursion_limit: usize,
}

impl ValidationExtension {
    /// Builds a factory that will pass `recursion_limit` to every extension it creates.
    pub fn new(recursion_limit: usize) -> Self {
        ValidationExtension { recursion_limit }
    }
}
impl ExtensionFactory for ValidationExtension {
    /// Creates a new `ValidationInner` configured with this factory's recursion limit.
    ///
    /// Every call yields an independent instance whose error buffer starts empty.
    fn create(&self) -> Arc<dyn Extension> {
        let inner = ValidationInner::new(self.recursion_limit);
        Arc::new(inner)
    }
}
/// Extension instance that performs the recursion check.
///
/// Errors found while parsing are buffered in `errors` and reported later,
/// during the validation phase.
struct ValidationInner {
    /// Limit forwarded to `RecursionFinder::new`.
    recursion_limit: usize,
    /// Rule errors collected in `parse_query` and drained in `validation`.
    /// A `Mutex` is used because the `Extension` hooks only receive `&self`.
    errors: Mutex<Vec<RuleError>>,
}

impl ValidationInner {
    /// Creates an instance with the given limit and an empty error buffer.
    fn new(recursion_limit: usize) -> Self {
        let errors = Mutex::new(Vec::new());
        Self {
            recursion_limit,
            errors,
        }
    }
}
#[async_trait::async_trait]
impl Extension for ValidationInner {
async fn parse_query(
&self,
ctx: &ExtensionContext<'_>,
query: &str,
variables: &Variables,
next: NextParseQuery<'_>,
) -> ServerResult<ExecutableDocument> {
let result = next.run(ctx, query, variables).await?;
let registry = &ctx.schema_env.registry;
let mut visitor = VisitorContext::new(registry, &result, Some(variables));
visit(
&mut RecursionFinder::new(self.recursion_limit),
&mut visitor,
&result,
);
let errors = visitor.errors;
if !errors.is_empty() {
let mut store = self
.errors
.lock()
.expect("Only one instance owns `ValidationInner`; qed");
store.extend(errors);
}
Ok(result)
}
async fn validation(
&self,
ctx: &ExtensionContext<'_>,
next: NextValidation<'_>,
) -> async_graphql::Result<ValidationResult, Vec<ServerError>> {
{
let mut errors = self
.errors
.lock()
.expect("Only one instance owns `ValidationInner`; qed");
if !errors.is_empty() {
return Err(errors.drain(..).map(Into::into).collect())
}
}
next.run(ctx).await
}
}