use std::fmt::Display;
use std::sync::Arc;
use apollo_compiler::Name;
use apollo_compiler::Node;
use apollo_compiler::ast::Directive;
use apollo_compiler::collections::IndexMap;
use apollo_compiler::executable::Value;
use indexmap::map::Entry;
use serde::Serialize;
use crate::bail;
use crate::error::FederationError;
use crate::operation::DirectiveList;
use crate::operation::Selection;
use crate::operation::SelectionMap;
use crate::operation::SelectionMapperReturn;
use crate::operation::SelectionSet;
use crate::query_graph::graph_path::operation::OpPathElement;
/// The conditional directive a variable condition originates from.
///
/// `Skip` stands for `@skip(if: $var)` and `Include` for `@include(if: $var)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
pub(crate) enum ConditionKind {
    /// The condition comes from an `@skip` directive.
    Skip,
    /// The condition comes from an `@include` directive.
    Include,
}

impl ConditionKind {
    /// Returns the directive name for this kind, without the leading `@`.
    fn as_str(self) -> &'static str {
        match self {
            ConditionKind::Include => "include",
            ConditionKind::Skip => "skip",
        }
    }
}

impl Display for ConditionKind {
    /// Writes the bare directive name (`skip` or `include`), honouring width/fill flags the
    /// same way a plain `&str` would.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.as_str())
    }
}
/// The conditions under which a selection is included, derived from `@skip`/`@include`.
///
/// `Boolean(true)` means "always included", `Boolean(false)` means "never included", and
/// `Variables` holds conditions whose outcome depends on runtime variable values.
#[derive(Debug, Clone, PartialEq, Serialize)]
pub(crate) enum Conditions {
    /// Conditions that depend on variables.
    Variables(VariableConditions),
    /// A condition whose outcome is known statically.
    Boolean(bool),
}

impl Display for Conditions {
    /// Formats as a bracketed list, e.g. `[true]`, `[false]` or
    /// `[@skip(if: $a) @include(if: $b)]`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "[")?;
        match self {
            Conditions::Variables(variables) => {
                let mut first = true;
                for (name, kind) in variables.iter() {
                    // Space-separate entries; no separator before the first one.
                    if !first {
                        write!(f, " ")?;
                    }
                    first = false;
                    write!(f, "@{kind}(if: ${name})")?;
                }
            }
            Conditions::Boolean(constant) => write!(f, "{constant:?}")?,
        }
        write!(f, "]")
    }
}
#[derive(Debug, Clone, PartialEq, Serialize)]
pub(crate) struct VariableConditions(
): does it really make sense for this to be an indexmap? we normally only
Arc<IndexMap<Name, ConditionKind>>,
);
impl VariableConditions {
fn new_unchecked(map: IndexMap<Name, ConditionKind>) -> Self {
debug_assert!(!map.is_empty());
Self(Arc::new(map))
}
fn condition_kind(&self, name: &str) -> Option<ConditionKind> {
self.0.get(name).copied()
}
pub(crate) fn iter(&self) -> impl Iterator<Item = (&Name, ConditionKind)> {
self.0.iter().map(|(name, &kind)| (name, kind))
}
fn merge(mut self, other: Self) -> Option<Self> {
let vars = Arc::make_mut(&mut self.0);
for (name, other_kind) in other.0.iter() {
match vars.entry(name.clone()) {
Entry::Occupied(self_kind) if self_kind.get() != other_kind => {
return None;
}
Entry::Occupied(_entry) => {}
Entry::Vacant(entry) => {
entry.insert(*other_kind);
}
}
}
Some(self)
}
}
impl Conditions {
fn from_variables(map: IndexMap<Name, ConditionKind>) -> Self {
if map.is_empty() {
Self::always()
} else {
Self::Variables(VariableConditions::new_unchecked(map))
}
}
pub(crate) const fn always() -> Self {
Self::Boolean(true)
}
pub(crate) const fn never() -> Self {
Self::Boolean(false)
}
pub(crate) fn from_directives(directives: &DirectiveList) -> Result<Self, FederationError> {
let mut variables = IndexMap::default();
if let Some(skip) = directives.get("skip") {
let Some(value) = skip.specified_argument_by_name("if") else {
bail!("missing @skip(if:) argument");
};
match value.as_ref() {
Value::Boolean(true) => return Ok(Self::never()),
Value::Boolean(_) => {}
Value::Variable(name) => {
variables.insert(name.clone(), ConditionKind::Skip);
}
_ => {
bail!("expected boolean or variable `if` argument, got {value}");
}
}
}
if let Some(include) = directives.get("include") {
let Some(value) = include.specified_argument_by_name("if") else {
bail!("missing @include(if:) argument");
};
match value.as_ref() {
Value::Boolean(false) => return Ok(Self::never()),
Value::Boolean(true) => {}
Value::Variable(name) => {
if variables.insert(name.clone(), ConditionKind::Include)
== Some(ConditionKind::Skip)
{
return Ok(Self::never());
}
}
_ => {
bail!("expected boolean or variable `if` argument, got {value}");
}
}
}
Ok(Self::from_variables(variables))
}
pub(crate) fn update_with(&self, handled_conditions: &Self) -> Self {
match (self, handled_conditions) {
(Conditions::Boolean(_), _) | (_, Conditions::Boolean(_)) => self.clone(),
(Conditions::Variables(new_conditions), Conditions::Variables(handled_conditions)) => {
let mut filtered = IndexMap::default();
for (cond_name, &cond_kind) in new_conditions.0.iter() {
match handled_conditions.condition_kind(cond_name) {
Some(handled_cond_kind) if cond_kind != handled_cond_kind => {
return Conditions::never();
}
Some(_) => {}
None => {
filtered.insert(cond_name.clone(), cond_kind);
}
}
}
Self::from_variables(filtered)
}
}
}
pub(crate) fn merge(self, other: Self) -> Self {
match (self, other) {
(Conditions::Boolean(false), _) | (_, Conditions::Boolean(false)) => {
Conditions::never()
}
(Conditions::Boolean(true), x) | (x, Conditions::Boolean(true)) => x,
(Conditions::Variables(self_vars), Conditions::Variables(other_vars)) => {
match self_vars.merge(other_vars) {
Some(vars) => Conditions::Variables(vars),
None => Conditions::never(),
}
}
}
}
}
/// Returns a copy of `selection_set` with the `@skip`/`@include` directives covered by
/// `conditions` stripped from every element, recursing into sub-selections.
///
/// Selections whose element and sub-selections are unaffected are reused as-is.
pub(crate) fn remove_conditions_from_selection_set(
    selection_set: &SelectionSet,
    conditions: &Conditions,
) -> Result<SelectionSet, FederationError> {
    let Conditions::Variables(variable_conditions) = conditions else {
        // Constant `false`: the selection will not be part of the plan in practice, so what we
        // return hardly matters. Constant `true`: there is nothing to remove. Either way, return
        // the input unchanged.
        return Ok(selection_set.clone());
    };

    selection_set.lazy_map(|selection| {
        let element = selection.element();
        let updated_element = remove_conditions_of_element(element.clone(), variable_conditions);
        let element_unchanged = updated_element == element;

        let updated = match selection.selection_set() {
            None if element_unchanged => selection.clone(),
            None => Selection::from_element(updated_element, None)?,
            Some(sub_selections) => {
                let updated_sub =
                    remove_conditions_from_selection_set(sub_selections, conditions)?;
                if !element_unchanged {
                    Selection::from_element(updated_element, Some(updated_sub))?
                } else if *sub_selections == updated_sub {
                    selection.clone()
                } else {
                    selection.with_updated_selection_set(Some(updated_sub))?
                }
            }
        };
        Ok(SelectionMapperReturn::Selection(updated))
    })
}
/// Removes `unneeded_directives` from the "top-level" inline fragments of `selection_set`.
///
/// Only inline fragments that have a type condition are rewritten. The function recurses
/// through nested inline fragments and stops at field selections, which are copied unchanged.
/// Fragments without a type condition are also kept as-is, including their directives.
pub(crate) fn remove_unneeded_top_level_fragment_directives(
    selection_set: &SelectionSet,
    unneeded_directives: &DirectiveList,
) -> Result<SelectionSet, FederationError> {
    let mut selection_map = SelectionMap::new();

    for selection in selection_set.selections.values() {
        let inline_fragment = match selection {
            Selection::Field(_) => {
                selection_map.insert(selection.clone());
                continue;
            }
            Selection::InlineFragment(inline_fragment) => inline_fragment,
        };

        let fragment = &inline_fragment.inline_fragment;
        if fragment.type_condition_position.is_none() {
            // Without a type condition, the fragment's directives are preserved.
            selection_map.insert(selection.clone());
            continue;
        }

        let needed_directives: Vec<Node<Directive>> = fragment
            .directives
            .iter()
            .filter(|directive| !unneeded_directives.contains(directive))
            .cloned()
            .collect();
        let updated_selections = remove_unneeded_top_level_fragment_directives(
            &inline_fragment.selection_set,
            unneeded_directives,
        )?;

        // Only rebuild the directive list when at least one directive was dropped.
        let updated_fragment = if needed_directives.len() != fragment.directives.len() {
            inline_fragment.with_updated_directives_and_selection_set(
                DirectiveList::from_iter(needed_directives),
                updated_selections,
            )
        } else {
            inline_fragment.with_updated_selection_set(updated_selections)
        };
        selection_map.insert(Selection::InlineFragment(Arc::new(updated_fragment)));
    }

    Ok(SelectionSet {
        schema: selection_set.schema.clone(),
        type_position: selection_set.type_position.clone(),
        selections: Arc::new(selection_map),
    })
}
fn remove_conditions_of_element(
element: OpPathElement,
conditions: &VariableConditions,
) -> OpPathElement {
let updated_directives: DirectiveList = element
.directives()
.iter()
.filter(|d| {
!matches_condition_for_kind(d, conditions, ConditionKind::Include)
&& !matches_condition_for_kind(d, conditions, ConditionKind::Skip)
})
.cloned()
.collect();
if updated_directives.len() == element.directives().len() {
element
} else {
element.with_updated_directives(updated_directives)
}
}
/// Tells whether `directive` is a `kind` directive (`@skip` or `@include`) covered by
/// `conditions`.
///
/// A variable `if:` argument matches only when `conditions` records that variable with the
/// same `kind`. A non-variable `if:` argument always matches, and a missing `if:` argument
/// never does.
fn matches_condition_for_kind(
    directive: &Directive,
    conditions: &VariableConditions,
    kind: ConditionKind,
) -> bool {
    if directive.name != kind.as_str() {
        return false;
    }
    // A directive without an `if:` argument shouldn't pass validation.
    let Some(condition) = directive.specified_argument_by_name("if") else {
        return false;
    };
    let Some(variable) = condition.as_variable() else {
        return true;
    };
    conditions.condition_kind(variable) == Some(kind)
}
#[cfg(test)]
mod tests {
    use apollo_compiler::ExecutableDocument;
    use apollo_compiler::Schema;

    use super::*;

    /// Parses `directives` as the directive list of field `a` in `{ a ... }` and returns the
    /// resulting conditions.
    fn parse(directives: &str) -> Conditions {
        let schema =
            Schema::parse_and_validate("type Query { a: String }", "schema.graphql").unwrap();
        let source = format!("{{ a {directives} }}");
        let doc = ExecutableDocument::parse(&schema, source, "query.graphql").unwrap();
        let operation = doc.operations.get(None).unwrap();
        let field_directives = operation.selection_set.selections[0].directives();
        Conditions::from_directives(&DirectiveList::from(field_directives.clone())).unwrap()
    }

    /// Parses each entry and merges the results left to right. `always()` is the identity of
    /// `merge`, so it is used as the starting value.
    fn parse_all(parts: &[&str]) -> Conditions {
        parts
            .iter()
            .copied()
            .map(parse)
            .fold(Conditions::always(), Conditions::merge)
    }

    #[test]
    fn merge_conditions() {
        let rendered_cases = [
            (
                "@skip(if: $a)",
                "@include(if: $b)",
                "[@skip(if: $a) @include(if: $b)]",
                "combine skip/include",
            ),
            (
                "@skip(if: $a)",
                "@skip(if: $b)",
                "[@skip(if: $a) @skip(if: $b)]",
                "combine multiple skips",
            ),
            (
                "@include(if: $a)",
                "@include(if: $b)",
                "[@include(if: $a) @include(if: $b)]",
                "combine multiple includes",
            ),
        ];
        for (lhs, rhs, expected, message) in rendered_cases {
            assert_eq!(parse(lhs).merge(parse(rhs)).to_string(), expected, "{message}");
        }

        let cases = [
            (
                parse("@skip(if: $a)"),
                parse("@include(if: $a)"),
                Conditions::never(),
                "skip/include with same variable conflicts",
            ),
            (
                parse("@skip(if: $a)"),
                Conditions::always(),
                parse("@skip(if: $a)"),
                "merge with `true` returns original",
            ),
            (
                Conditions::always(),
                Conditions::always(),
                Conditions::always(),
                "merge with `true` returns original",
            ),
            (
                parse("@skip(if: $a)"),
                Conditions::never(),
                Conditions::never(),
                "merge with `false` returns `false`",
            ),
            (
                parse("@include(if: $a)"),
                Conditions::never(),
                Conditions::never(),
                "merge with `false` returns `false`",
            ),
            (
                Conditions::always(),
                Conditions::never(),
                Conditions::never(),
                "merge with `false` returns `false`",
            ),
            (
                parse("@skip(if: true)"),
                parse("@include(if: $a)"),
                Conditions::never(),
                "@skip with hardcoded if: true can never evaluate to true",
            ),
            (
                parse("@skip(if: false)"),
                parse("@include(if: $a)"),
                parse("@include(if: $a)"),
                "@skip with hardcoded if: false returns other side",
            ),
            (
                parse("@include(if: true)"),
                parse("@include(if: $a)"),
                parse("@include(if: $a)"),
                "@include with hardcoded if: true returns other side",
            ),
            (
                parse("@include(if: false)"),
                parse("@include(if: $a)"),
                Conditions::never(),
                "@include with hardcoded if: false can never evaluate to true",
            ),
        ];
        for (lhs, rhs, expected, message) in cases {
            assert_eq!(lhs.merge(rhs), expected, "{message}");
        }
    }

    #[test]
    fn update_conditions() {
        let skip_a_include_b = parse_all(&["@skip(if: $a)", "@include(if: $b)"]);
        assert_eq!(
            skip_a_include_b.update_with(&parse("@include(if: $b)")),
            parse("@skip(if: $a)"),
            "trim @include(if:) condition"
        );
        assert_eq!(
            skip_a_include_b.update_with(&parse("@skip(if: $a)")),
            parse("@include(if: $b)"),
            "trim @skip(if:) condition"
        );

        let skips = parse_all(&[
            "@skip(if: $a)",
            "@skip(if: $b)",
            "@skip(if: $c)",
            "@skip(if: $d)",
            "@skip(if: $e)",
        ]);
        assert_eq!(
            skips.update_with(&parse_all(&["@skip(if: $b)", "@skip(if: $e)"])),
            parse_all(&["@skip(if: $a)", "@skip(if: $c)", "@skip(if: $d)"]),
            "trim multiple conditions"
        );

        let includes = parse_all(&[
            "@include(if: $a)",
            "@include(if: $b)",
            "@include(if: $c)",
            "@include(if: $d)",
            "@include(if: $e)",
        ]);
        assert_eq!(
            includes.update_with(&parse_all(&["@include(if: $b)", "@include(if: $e)"])),
            parse_all(&["@include(if: $a)", "@include(if: $c)", "@include(if: $d)"]),
            "trim multiple conditions"
        );

        for constant in [Conditions::never(), Conditions::always()] {
            assert_eq!(
                includes.update_with(&constant),
                includes,
                "update with constant does not affect conditions"
            );
        }
    }
}