use std::collections::HashMap;
use std::collections::HashSet;
use apollo_compiler::ExecutableDocument;
use apollo_compiler::Name;
use apollo_compiler::executable;
use serde::Deserialize;
use serde::Serialize;
use crate::Configuration;
/// A set of per-operation complexity values, one `T` per measured metric.
///
/// Within this file it is used as `OperationLimits<u32>` for measured counts,
/// `OperationLimits<Option<u32>>` for configured maxima (`None` meaning that
/// metric is not limited), and `OperationLimits<bool>` for "limit exceeded"
/// flags.
#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)]
pub(crate) struct OperationLimits<T> {
    /// Maximum nesting depth of field selections.
    pub(crate) depth: T,
    /// Number of field selections, counted once per distinct response key
    /// within each selection set and summed across nesting levels.
    pub(crate) height: T,
    /// Number of distinct fields selected directly in the operation's
    /// top-level selection set.
    pub(crate) root_fields: T,
    /// Number of aliased fields anywhere in the operation.
    pub(crate) aliases: T,
}
impl<A> OperationLimits<A> {
    /// Transforms every metric with `f`, yielding a new `OperationLimits`.
    ///
    /// `f` is called exactly once per metric, in declaration order:
    /// depth, height, root fields, aliases.
    fn map<B>(self, mut f: impl FnMut(A) -> B) -> OperationLimits<B> {
        let Self {
            depth,
            height,
            root_fields,
            aliases,
        } = self;
        let depth = f(depth);
        let height = f(height);
        let root_fields = f(root_fields);
        let aliases = f(aliases);
        OperationLimits {
            depth,
            height,
            root_fields,
            aliases,
        }
    }

    /// Zips `self` with `other` metric-by-metric through `f`.
    ///
    /// `f` receives the metric's identifier (`"depth"`, `"height"`,
    /// `"root_fields"`, `"aliases"`) followed by the value from `self` and the
    /// value from `other`. It is called once per metric, in that order.
    fn combine<B, C>(
        self,
        other: OperationLimits<B>,
        mut f: impl FnMut(&'static str, A, B) -> C,
    ) -> OperationLimits<C> {
        let Self {
            depth,
            height,
            root_fields,
            aliases,
        } = self;
        let OperationLimits {
            depth: other_depth,
            height: other_height,
            root_fields: other_root_fields,
            aliases: other_aliases,
        } = other;
        let depth = f("depth", depth, other_depth);
        let height = f("height", height, other_height);
        let root_fields = f("root_fields", root_fields, other_root_fields);
        let aliases = f("aliases", aliases, other_aliases);
        OperationLimits {
            depth,
            height,
            root_fields,
            aliases,
        }
    }
}
impl OperationLimits<bool> {
    /// Returns `true` when at least one of the four flags is set.
    fn any(&self) -> bool {
        // Exhaustive destructuring: adding a field to `OperationLimits` makes
        // this fail to compile until the new flag is taken into account.
        let Self {
            depth,
            height,
            root_fields,
            aliases,
        } = *self;
        [depth, height, root_fields, aliases].contains(&true)
    }
}
/// Measures the operation selected by `operation_name` in `document` and
/// enforces the limits configured in `configuration.limits`.
///
/// When the operation can be resolved, its measured metrics are always written
/// to `query_metrics_in`, even if no limit is configured.
///
/// If any configured maximum is exceeded, a warning naming each offending
/// metric, `query` and `operation_name` is logged. Then, unless `warn_only` is
/// set, `Err` is returned with one flag per metric (`true` = exceeded).
/// Otherwise returns `Ok(())`. This includes the case where the operation
/// cannot be resolved from `document`.
pub(crate) fn check(
    query_metrics_in: &mut OperationLimits<u32>,
    configuration: &Configuration,
    query: &str,
    document: &ExecutableDocument,
    operation_name: Option<&str>,
) -> Result<(), OperationLimits<bool>> {
    let limits_config = &configuration.limits;

    // Nothing to measure if the requested operation cannot be resolved.
    let Ok(operation) = document.operations.get(operation_name) else {
        return Ok(());
    };

    // Record the metrics unconditionally so they are observable even when no
    // limit is configured.
    let measured = {
        let mut fragment_cache = HashMap::new();
        count(document, &mut fragment_cache, &operation.selection_set)
    };
    *query_metrics_in = measured;

    let max = OperationLimits {
        depth: limits_config.max_depth,
        height: limits_config.max_height,
        root_fields: limits_config.max_root_fields,
        aliases: limits_config.max_aliases,
    };
    let nothing_configured = !max.map(|limit| limit.is_some()).any();
    if nothing_configured {
        return Ok(());
    }

    // Single pass: flag each exceeded metric and describe it for the log.
    let mut messages = Vec::new();
    let exceeded = max.combine(measured, |ident, max, measured| match max {
        Some(max) if measured > max => {
            messages.push(format!("{ident}: {measured}, max_{ident}: {max}"));
            true
        }
        _ => false,
    });

    if !exceeded.any() {
        return Ok(());
    }

    let message = messages.join(", ");
    tracing::warn!(
        "request exceeded complexity limits: {message}, \
         query: {query:?}, operation name: {operation_name:?}"
    );
    if limits_config.warn_only {
        Ok(())
    } else {
        Err(exceeded)
    }
}
/// Memoization state of a named fragment in the fragment cache used by `count`.
enum Computation<T> {
    /// The fragment is currently being measured further up the call stack.
    /// Meeting it again means the fragment (directly or indirectly) spreads
    /// itself.
    InProgress,
    /// The fragment has been fully measured; holds its counts.
    Done(T),
}
/// Measures the complexity of `selection_set`, recursing into nested field
/// selections, inline fragments and named fragment spreads.
///
/// How each metric is accumulated (all additions saturate at `u32::MAX`):
/// - `depth`: a field contributes `1 + depth of its sub-selection`. Inline
///   fragments and fragment spreads add no level. The result is the maximum
///   over all selections.
/// - `height`: each distinct response key (the alias if present, otherwise
///   the field name) selected directly in this selection set counts once.
///   Nested heights from sub-selections and fragments are added on top.
///   Deduplication uses a set local to this call, so a key selected both here
///   and inside a fragment is counted twice.
/// - `root_fields`: distinct response keys selected directly in this
///   selection set. Values from nested field selections are discarded, so on
///   an operation's top-level selection set this counts only its immediate
///   fields.
/// - `aliases`: every aliased field, summed over all nesting levels and
///   fragments.
///
/// `fragment_cache` memoizes named-fragment counts keyed by fragment name, so
/// each definition is measured at most once per cache. A cached result still
/// contributes its full counts every time the fragment is spread.
///
/// NOTE(review): `root_fields` from inline fragments and fragment spreads is
/// never added to `counts`. Top-level fields selected only through a fragment
/// (e.g. `{ ...F }`) therefore do not count toward `root_fields` and cannot
/// trip `max_root_fields`. Confirm this is intended.
fn count<'a>(
    document: &'a executable::ExecutableDocument,
    fragment_cache: &mut HashMap<&'a Name, Computation<OperationLimits<u32>>>,
    selection_set: &'a executable::SelectionSet,
) -> OperationLimits<u32> {
    let mut counts = OperationLimits {
        depth: 0,
        height: 0,
        root_fields: 0,
        aliases: 0,
    };
    // Response keys already counted in *this* selection set only; not shared
    // with recursive calls.
    let mut fields_seen = HashSet::new();
    for selection in &selection_set.selections {
        match selection {
            executable::Selection::Field(field) => {
                let nested = count(document, fragment_cache, &field.selection_set);
                // A field adds one level on top of its sub-selection.
                counts.depth = counts.depth.max(nested.depth.saturating_add(1));
                // Nested height and aliases are added even when this field's
                // response key repeats an earlier selection in this set.
                counts.height = counts.height.saturating_add(nested.height);
                counts.aliases = counts.aliases.saturating_add(nested.aliases);
                // Deduplicate by response key: the alias if present, else the
                // field name. Each alias also counts toward `aliases`.
                let used_name = if let Some(alias) = &field.alias {
                    counts.aliases = counts.aliases.saturating_add(1);
                    alias
                } else {
                    &field.name
                };
                let not_seen_before = fields_seen.insert(used_name);
                if not_seen_before {
                    counts.height = counts.height.saturating_add(1);
                    counts.root_fields = counts.root_fields.saturating_add(1);
                }
            }
            executable::Selection::InlineFragment(fragment) => {
                // Inline fragments add no depth level, and the `root_fields`
                // of their contents is not propagated.
                let nested = count(document, fragment_cache, &fragment.selection_set);
                counts.depth = counts.depth.max(nested.depth);
                counts.height = counts.height.saturating_add(nested.height);
                counts.aliases = counts.aliases.saturating_add(nested.aliases);
            }
            executable::Selection::FragmentSpread(fragment) => {
                let name = &fragment.fragment_name;
                let nested;
                match fragment_cache.get(name) {
                    None => {
                        if let Some(definition) = document.fragments.get(name) {
                            // Mark before recursing so that a spread of this same
                            // fragment hits the `InProgress` arm instead of
                            // recursing forever.
                            fragment_cache.insert(name, Computation::InProgress);
                            nested = count(document, fragment_cache, &definition.selection_set);
                            fragment_cache.insert(name, Computation::Done(nested));
                        } else {
                            // No definition with this name in the document:
                            // skip the spread; it contributes nothing.
                            continue;
                        }
                    }
                    Some(Computation::InProgress) => {
                        // The fragment (directly or indirectly) spreads itself:
                        // skip this occurrence; it contributes nothing.
                        continue;
                    }
                    Some(Computation::Done(cached)) => nested = *cached,
                }
                // Same treatment as inline fragments: no depth level added,
                // `root_fields` not propagated.
                counts.depth = counts.depth.max(nested.depth);
                counts.height = counts.height.saturating_add(nested.height);
                counts.aliases = counts.aliases.saturating_add(nested.aliases);
            }
        }
    }
    counts
}