use std::collections::BTreeMap;
use std::collections::HashMap;
use std::collections::HashSet;
use apollo_compiler::ast;
use apollo_compiler::schema::FieldLookupError;
use tower::BoxError;
/// Transforms `document` with `visitor` and returns the rewritten document.
///
/// Fragment definitions are visited first, in dependency order (a fragment only after every
/// fragment it spreads), so that when a spread is visited the transformed target is already
/// available through [`TransformState::fragments`]. Operations are visited next. For each
/// retained operation, the fragments it reaches (directly or through other fragments) are
/// collected, and variable definitions that neither the operation nor any of those fragments
/// still reference are removed.
///
/// The output lists the retained operations in source order. They are followed by the
/// retained fragment definitions used by at least one retained operation, in name order.
///
/// # Errors
///
/// Returns an error if the schema has no root type for an operation's kind, or if any
/// visitor method fails.
pub(crate) fn document(
    visitor: &mut impl Visitor,
    document: &ast::Document,
) -> Result<ast::Document, BoxError> {
    let mut new = ast::Document {
        sources: document.sources.clone(),
        definitions: Vec::new(),
    };

    let mut fragment_visitor = FragmentOrderVisitor::new();
    fragment_visitor.visit_document(document);
    let ordered_fragments = fragment_visitor.ordered_fragments();

    // The visitor may be reused across documents; start from an empty state.
    visitor.state().reset();

    for def in ordered_fragments {
        // Spreads and variables are tracked per definition, so start each one from scratch.
        visitor.state().used_fragments.clear();
        visitor.state().used_variables.clear();
        if let Some(new_def) = visitor.fragment_definition(def)? {
            let state = visitor.state();
            let defined = DefinedFragment {
                fragment: new_def,
                used_variables: state.used_variables.clone(),
                used_fragments: state.used_fragments.clone(),
            };
            state
                .defined_fragments
                .insert(def.name.as_str().to_string(), defined);
        }
    }

    // Names of every fragment reached by at least one retained operation.
    let mut used_fragments: HashSet<String> = HashSet::new();
    for definition in &document.definitions {
        let ast::Definition::OperationDefinition(def) = definition else {
            continue;
        };
        let root_type = visitor
            .schema()
            .root_operation(def.operation_type)
            .ok_or("missing root operation definition")?
            .clone();

        visitor.state().used_fragments.clear();
        visitor.state().used_variables.clear();
        let Some(mut new_def) = visitor.operation(&root_type, def)? else {
            continue;
        };

        let state = visitor.state();

        // Follow spreads transitively. Each fragment name is expanded at most once, and names
        // without a retained definition are kept in the set but contribute no further spreads.
        let mut local_used_fragments = state.used_fragments.clone();
        let mut to_expand: Vec<String> = local_used_fragments.iter().cloned().collect();
        while let Some(name) = to_expand.pop() {
            if let Some(defined_fragment) = state.defined_fragments.get(name.as_str()) {
                for spread in &defined_fragment.used_fragments {
                    if local_used_fragments.insert(spread.clone()) {
                        to_expand.push(spread.clone());
                    }
                }
            }
        }

        // Variables referenced inside reachable fragments count as used by the operation.
        for name in &local_used_fragments {
            if let Some(defined_fragment) = state.defined_fragments.get(name.as_str()) {
                state
                    .used_variables
                    .extend(defined_fragment.used_variables.iter().cloned());
            }
        }

        // Drop variable definitions that nothing references anymore.
        new_def
            .variables
            .retain(|var| state.used_variables.contains(var.name.as_str()));

        used_fragments.extend(local_used_fragments);
        new.definitions
            .push(ast::Definition::OperationDefinition(new_def.into()));
    }

    // Emit only the fragments some retained operation still needs; clone just those.
    for (name, defined_fragment) in &visitor.state().defined_fragments {
        if used_fragments.contains(name.as_str()) {
            new.definitions.push(ast::Definition::FragmentDefinition(
                defined_fragment.fragment.clone().into(),
            ));
        }
    }
    Ok(new)
}
/// Bookkeeping shared by the visitor methods while a document is transformed.
pub(crate) struct TransformState {
    // Fragment names spread within the definition currently being transformed.
    used_fragments: HashSet<String>,
    // Variable names referenced by argument values in the definition currently being transformed.
    used_variables: HashSet<String>,
    // Transformed fragment definitions retained so far, keyed by fragment name.
    defined_fragments: BTreeMap<String, DefinedFragment>,
}

/// A transformed fragment definition, together with what its transformed body still uses.
#[derive(Clone)]
pub(crate) struct DefinedFragment {
    /// The fragment definition after transformation.
    pub(crate) fragment: ast::FragmentDefinition,
    /// Names of variables referenced by the transformed fragment.
    pub(crate) used_variables: HashSet<String>,
    /// Names of fragments spread by the transformed fragment.
    pub(crate) used_fragments: HashSet<String>,
}

impl TransformState {
    /// Creates an empty state.
    pub(crate) fn new() -> Self {
        Self {
            used_fragments: Default::default(),
            used_variables: Default::default(),
            defined_fragments: Default::default(),
        }
    }

    /// Empties every collection so the state can be reused for another document.
    fn reset(&mut self) {
        let Self {
            used_fragments,
            used_variables,
            defined_fragments,
        } = self;
        used_fragments.clear();
        used_variables.clear();
        defined_fragments.clear();
    }

    /// Fragment definitions that have been transformed and retained so far, by name.
    pub(crate) fn fragments(&self) -> &BTreeMap<String, DefinedFragment> {
        &self.defined_fragments
    }
}
/// Hooks for rewriting an executable document, driven by [`document`].
///
/// Each traversal method defaults to the free function of the same name. That function
/// transforms the node's children and rebuilds the node. Override a method to alter or drop
/// nodes. Returning `Ok(None)` removes the node from its parent.
pub(crate) trait Visitor: Sized {
    /// Schema used to look up field definitions and root operation types.
    fn schema(&self) -> &apollo_compiler::Schema;

    /// Mutable bookkeeping shared across the traversal.
    fn state(&mut self) -> &mut TransformState;

    /// Transforms an operation whose root type is named `root_type`.
    fn operation(
        &mut self,
        root_type: &str,
        def: &ast::OperationDefinition,
    ) -> Result<Option<ast::OperationDefinition>, BoxError> {
        operation(self, root_type, def)
    }

    /// Transforms a fragment definition.
    fn fragment_definition(
        &mut self,
        def: &ast::FragmentDefinition,
    ) -> Result<Option<ast::FragmentDefinition>, BoxError> {
        fragment_definition(self, def)
    }

    /// Transforms a field selected on `_parent_type`; `field_def` is its schema definition.
    fn field(
        &mut self,
        _parent_type: &str,
        field_def: &ast::FieldDefinition,
        def: &ast::Field,
    ) -> Result<Option<ast::Field>, BoxError> {
        field(self, field_def, def)
    }

    /// Transforms a fragment spread and, when it is kept, records its target fragment name
    /// in [`TransformState`].
    fn fragment_spread(
        &mut self,
        def: &ast::FragmentSpread,
    ) -> Result<Option<ast::FragmentSpread>, BoxError> {
        let transformed = fragment_spread(self, def);
        if let Ok(Some(spread)) = &transformed {
            let target = spread.fragment_name.as_str().to_string();
            self.state().used_fragments.insert(target);
        }
        transformed
    }

    /// Transforms an inline fragment whose selections apply to `parent_type`.
    fn inline_fragment(
        &mut self,
        parent_type: &str,
        def: &ast::InlineFragment,
    ) -> Result<Option<ast::InlineFragment>, BoxError> {
        inline_fragment(self, parent_type, def)
    }
}
/// Default transformation of an operation definition.
///
/// Transforms the selection set against `root_type`. If every selection is removed, the
/// operation is dropped (`Ok(None)`). Otherwise the variables referenced by the operation's
/// directives are recorded, and a copy of the operation with the new selection set is returned.
pub(crate) fn operation(
    visitor: &mut impl Visitor,
    root_type: &str,
    def: &ast::OperationDefinition,
) -> Result<Option<ast::OperationDefinition>, BoxError> {
    let new_selection_set = match selection_set(visitor, root_type, &def.selection_set)? {
        Some(selections) => selections,
        None => return Ok(None),
    };

    let directive_arguments = def
        .directives
        .iter()
        .flat_map(|directive| directive.arguments.iter());
    for argument in directive_arguments {
        used_variables_from_value(visitor, &argument.value);
    }

    Ok(Some(ast::OperationDefinition {
        operation_type: def.operation_type,
        name: def.name.clone(),
        variables: def.variables.clone(),
        directives: def.directives.clone(),
        selection_set: new_selection_set,
    }))
}
/// Default transformation of a fragment definition.
///
/// Transforms the selection set against the fragment's type condition. If every selection is
/// removed, the fragment is dropped (`Ok(None)`). Otherwise the variables referenced by the
/// fragment's directives are recorded, and the rebuilt fragment is returned.
pub(crate) fn fragment_definition(
    visitor: &mut impl Visitor,
    def: &ast::FragmentDefinition,
) -> Result<Option<ast::FragmentDefinition>, BoxError> {
    let transformed = selection_set(visitor, &def.type_condition, &def.selection_set)?;
    if let Some(new_selection_set) = transformed {
        let directive_arguments = def
            .directives
            .iter()
            .flat_map(|directive| directive.arguments.iter());
        for argument in directive_arguments {
            used_variables_from_value(visitor, &argument.value);
        }
        Ok(Some(ast::FragmentDefinition {
            name: def.name.clone(),
            type_condition: def.type_condition.clone(),
            selection_set: new_selection_set,
            directives: def.directives.clone(),
        }))
    } else {
        Ok(None)
    }
}
/// Default transformation of a field.
///
/// The sub-selection is transformed against the field's named output type (from
/// `field_def`). If it had selections and all were removed, the field is dropped (`Ok(None)`).
/// Otherwise the variables referenced by the field's arguments and directive arguments are
/// recorded, and the rebuilt field is returned.
pub(crate) fn field(
    visitor: &mut impl Visitor,
    field_def: &ast::FieldDefinition,
    def: &ast::Field,
) -> Result<Option<ast::Field>, BoxError> {
    let output_type = field_def.ty.inner_named_type();
    let Some(new_selection_set) = selection_set(visitor, output_type, &def.selection_set)? else {
        return Ok(None);
    };

    // Field arguments first, then the arguments of each directive on the field.
    let own_arguments = def.arguments.iter();
    let directive_arguments = def
        .directives
        .iter()
        .flat_map(|directive| directive.arguments.iter());
    for argument in own_arguments.chain(directive_arguments) {
        used_variables_from_value(visitor, &argument.value);
    }

    Ok(Some(ast::Field {
        name: def.name.clone(),
        alias: def.alias.clone(),
        directives: def.directives.clone(),
        arguments: def.arguments.clone(),
        selection_set: new_selection_set,
    }))
}
/// Default transformation of a fragment spread: always keeps it.
///
/// Records the spread's target name in the visitor state's used fragments and the variables
/// referenced by its directive arguments, then returns an unchanged copy of the spread.
pub(crate) fn fragment_spread(
    visitor: &mut impl Visitor,
    def: &ast::FragmentSpread,
) -> Result<Option<ast::FragmentSpread>, BoxError> {
    let target = def.fragment_name.as_str().to_string();
    visitor.state().used_fragments.insert(target);

    def.directives
        .iter()
        .flat_map(|directive| directive.arguments.iter())
        .for_each(|argument| used_variables_from_value(visitor, &argument.value));

    Ok(Some(def.clone()))
}
/// Default transformation of an inline fragment.
///
/// `parent_type` is the type its selections are resolved against. If every selection is
/// removed, the inline fragment is dropped (`Ok(None)`). Otherwise the variables referenced
/// by its directive arguments are recorded, and the rebuilt inline fragment is returned.
pub(crate) fn inline_fragment(
    visitor: &mut impl Visitor,
    parent_type: &str,
    def: &ast::InlineFragment,
) -> Result<Option<ast::InlineFragment>, BoxError> {
    match selection_set(visitor, parent_type, &def.selection_set)? {
        None => Ok(None),
        Some(new_selection_set) => {
            let directive_arguments = def
                .directives
                .iter()
                .flat_map(|directive| directive.arguments.iter());
            for argument in directive_arguments {
                used_variables_from_value(visitor, &argument.value);
            }
            Ok(Some(ast::InlineFragment {
                selection_set: new_selection_set,
                directives: def.directives.clone(),
                type_condition: def.type_condition.clone(),
            }))
        }
    }
}
/// Transforms each selection of `set`, whose selections apply to the type named `parent_type`.
///
/// Fields are looked up in the schema and passed to [`Visitor::field`]. Fragment spreads go
/// to [`Visitor::fragment_spread`]. Inline fragments go to [`Visitor::inline_fragment`] with
/// their type condition, or with `parent_type` when they have none.
///
/// Returns `Ok(Some(vec![]))` for an empty input and `Ok(None)` when a non-empty input had
/// every selection removed. Otherwise it returns the kept selections in their original order.
///
/// # Errors
///
/// Fails when `parent_type` is not defined or lacks a selected field, or when a visitor
/// method fails.
pub(crate) fn selection_set(
    visitor: &mut impl Visitor,
    parent_type: &str,
    set: &[ast::Selection],
) -> Result<Option<Vec<ast::Selection>>, BoxError> {
    // An empty input is kept as-is rather than treated as "everything was removed".
    if set.is_empty() {
        return Ok(Some(Vec::new()));
    }

    let mut kept = Vec::new();
    for selection in set {
        let transformed = match selection {
            ast::Selection::Field(field_node) => {
                // Cloned so the schema borrow ends before the visitor is borrowed mutably.
                let field_def = visitor
                    .schema()
                    .type_field(parent_type, &field_node.name)
                    .map_err(|err| match err {
                        FieldLookupError::NoSuchType => format!("type `{parent_type}` not defined"),
                        FieldLookupError::NoSuchField(_, _) => {
                            format!("no field `{}` in type `{parent_type}`", &field_node.name)
                        }
                    })?
                    .clone();
                visitor
                    .field(parent_type, &field_def, field_node)?
                    .map(|new_field| ast::Selection::Field(new_field.into()))
            }
            ast::Selection::FragmentSpread(spread) => visitor
                .fragment_spread(spread)?
                .map(|new_spread| ast::Selection::FragmentSpread(new_spread.into())),
            ast::Selection::InlineFragment(inline) => {
                let condition_type = match &inline.type_condition {
                    Some(type_condition) => type_condition.as_str(),
                    None => parent_type,
                };
                visitor
                    .inline_fragment(condition_type, inline)?
                    .map(|new_inline| ast::Selection::InlineFragment(new_inline.into()))
            }
        };
        kept.extend(transformed);
    }

    Ok(if kept.is_empty() { None } else { Some(kept) })
}
/// Adds to the visitor state's used variables every variable referenced by `argument_value`.
/// It descends into list items and input-object field values.
fn used_variables_from_value<V: Visitor>(
    visitor: &mut V,
    argument_value: &apollo_compiler::ast::Value,
) {
    use apollo_compiler::ast::Value;

    match argument_value {
        Value::List(items) => {
            items
                .iter()
                .for_each(|item| used_variables_from_value(visitor, item));
        }
        Value::Object(fields) => {
            fields
                .iter()
                .for_each(|(_, field_value)| used_variables_from_value(visitor, field_value));
        }
        Value::Variable(variable) => {
            let variable_name = variable.as_str().to_string();
            visitor.state().used_variables.insert(variable_name);
        }
        // Scalars, enums and null cannot contain variables.
        _ => {}
    }
}
/// Orders a document's fragment definitions so that every fragment comes after all the
/// fragments it spreads. This is a topological sort of the "spreads" relation.
///
/// Fragments that never become ready are left out of [`Self::ordered_fragments`]. These are
/// fragments that spread a fragment with no definition, that are part of a spread cycle
/// through other fragments, or that depend (transitively) on such a fragment.
struct FragmentOrderVisitor<'a> {
    // Fragment names, in the order they became ready.
    ordered_fragments: Vec<String>,
    // Every fragment definition seen so far, keyed by name.
    fragments: HashMap<String, &'a apollo_compiler::ast::FragmentDefinition>,
    // Spread target name -> fragments waiting for that target to be ordered
    // (one entry per spread occurrence).
    dependencies: HashMap<String, Vec<String>>,
    // Fragment whose selection set is currently being visited.
    current: Option<String>,
    // Per fragment, how many recorded spread occurrences are still waiting on an
    // unordered fragment.
    rank: HashMap<String, usize>,
}

impl<'a> FragmentOrderVisitor<'a> {
    fn new() -> Self {
        Self {
            ordered_fragments: Vec::new(),
            fragments: HashMap::new(),
            dependencies: HashMap::new(),
            current: None,
            rank: HashMap::new(),
        }
    }

    /// Called once `name` has been ordered. Decrements the rank of every fragment waiting on
    /// it, and orders (then propagates from) each one whose rank reaches zero.
    ///
    /// Uses an explicit stack rather than recursion. The propagation depth equals the length
    /// of the longest chain of fragments spreading one another, which is chosen by whoever
    /// wrote the document, so recursion could exhaust the thread stack. Fragments are still
    /// ordered depth-first, exactly as the recursive formulation would order them.
    fn rerank(&mut self, name: &str) {
        let mut stack = Vec::new();
        if let Some(waiting) = self.dependencies.remove(name) {
            stack.push(waiting.into_iter());
        }
        while let Some(waiting) = stack.last_mut() {
            let Some(dep) = waiting.next() else {
                stack.pop();
                continue;
            };
            if let Some(rank) = self.rank.get_mut(&dep) {
                *rank -= 1;
                if *rank == 0 {
                    self.ordered_fragments.push(dep.clone());
                    // Descend into `dep`'s dependents before the remaining siblings, as the
                    // recursive version did.
                    if let Some(next) = self.dependencies.remove(&dep) {
                        stack.push(next.into_iter());
                    }
                }
            }
        }
    }

    /// Consumes the visitor and returns the ready fragment definitions in order.
    fn ordered_fragments(self) -> Vec<&'a ast::FragmentDefinition> {
        self.ordered_fragments
            .iter()
            .filter_map(|name| self.fragments.get(name.as_str()).copied())
            .collect()
    }

    /// Registers every fragment definition of `doc`, in document order.
    fn visit_document(&mut self, doc: &'a ast::Document) {
        for definition in &doc.definitions {
            if let ast::Definition::FragmentDefinition(def) = definition {
                self.visit_fragment_definition(def);
            }
        }
    }

    /// Registers `def` and counts the spreads it must wait on. If it waits on nothing, it is
    /// ordered immediately and readiness propagates to fragments waiting on it.
    fn visit_fragment_definition(&mut self, def: &'a ast::FragmentDefinition) {
        let name = def.name.as_str().to_string();
        self.fragments.insert(name.clone(), def);
        self.current = Some(name.clone());
        self.rank.insert(name.clone(), 0);
        self.visit_selection_set(&def.selection_set);
        if self.rank.get(&name) == Some(&0) {
            self.ordered_fragments.push(name.clone());
            self.rerank(&name);
        }
    }

    /// Walks a selection set of the current fragment. Each spread of a fragment that is not
    /// yet ordered is recorded as a dependency of the current fragment.
    fn visit_selection_set(&mut self, selection_set: &[apollo_compiler::ast::Selection]) {
        for selection in selection_set {
            match selection {
                ast::Selection::Field(def) => self.visit_selection_set(&def.selection_set),
                ast::Selection::InlineFragment(def) => self.visit_selection_set(&def.selection_set),
                ast::Selection::FragmentSpread(def) => {
                    let name = def.fragment_name.as_str().to_string();
                    // Rank zero means the target is already ordered (or is the current
                    // fragment with nothing pending yet): nothing to wait for.
                    if self.rank.get(name.as_str()) == Some(&0) {
                        continue;
                    }
                    if let Some(current) = self.current.as_ref() {
                        if let Some(rank) = self.rank.get_mut(current.as_str()) {
                            *rank += 1;
                        }
                        self.dependencies
                            .entry(name)
                            .or_default()
                            .push(current.clone());
                    }
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
// Checks that a visitor overriding only `field` reaches fields in the operation, in inline
// fragments and in fragment definitions, while everything else is copied through.
#[test]
fn test_add_directive_to_fields() {
// Appends an `@added` directive to every field returned by the default `field` transformation.
struct AddDirective {
schema: apollo_compiler::Schema,
state: TransformState,
}
impl Visitor for AddDirective {
fn field(
&mut self,
_parent_type: &str,
field_def: &ast::FieldDefinition,
def: &ast::Field,
) -> Result<Option<ast::Field>, BoxError> {
Ok(field(self, field_def, def)?.map(|mut new| {
new.directives.push(ast::Directive {
name: apollo_compiler::name!("added"),
arguments: Vec::new(),
});
new
}))
}
fn schema(&self) -> &apollo_compiler::Schema {
&self.schema
}
fn state(&mut self) -> &mut TransformState {
&mut self.state
}
}
// Schema and executable definitions parsed from one mixed document.
let graphql = "
type Query {
a(id: ID): String
b: Int
next: Query
}
directive @defer(label: String, if: Boolean! = true) on FRAGMENT_SPREAD | INLINE_FRAGMENT
query($id: ID = null) {
a(id: $id)
... @defer {
b
}
... F
}
fragment F on Query {
next {
a
}
}
";
let ast = apollo_compiler::ast::Document::parse(graphql, "").unwrap();
let (schema, _doc) = ast.to_mixed_validate().unwrap();
let schema = schema.into_inner();
let mut visitor = AddDirective {
schema,
state: TransformState::new(),
};
// Operations are emitted before fragments; `$id` stays because `a(id: $id)` is kept.
let expected = "query($id: ID = null) {
a(id: $id) @added
... @defer {
b @added
}
...F
}
fragment F on Query {
next @added {
a @added
}
}
";
assert_eq!(document(&mut visitor, &ast).unwrap().to_string(), expected)
}
// Test visitor that drops any field, fragment spread or inline fragment carrying `@remove`.
// It also drops spreads of fragments missing from the transform state, e.g. fragments whose
// whole selection set was removed.
struct RemoveDirective {
schema: apollo_compiler::Schema,
state: TransformState,
}
impl RemoveDirective {
fn new(schema: apollo_compiler::Schema) -> Self {
Self {
schema,
state: TransformState::new(),
}
}
}
impl Visitor for RemoveDirective {
fn field(
&mut self,
_parent_type: &str,
field_def: &ast::FieldDefinition,
def: &ast::Field,
) -> Result<Option<ast::Field>, BoxError> {
if def.directives.iter().any(|d| d.name == "remove") {
return Ok(None);
}
field(self, field_def, def)
}
fn fragment_spread(
&mut self,
def: &ast::FragmentSpread,
) -> Result<Option<ast::FragmentSpread>, BoxError> {
if def.directives.iter().any(|d| d.name == "remove") {
return Ok(None);
}
// Fragments are transformed before operations, in dependency order, so a missing
// entry here means the target fragment itself was dropped.
if !self
.state()
.fragments()
.contains_key(def.fragment_name.as_str())
{
return Ok(None);
}
fragment_spread(self, def)
}
fn inline_fragment(
&mut self,
_parent_type: &str,
def: &ast::InlineFragment,
) -> Result<Option<ast::InlineFragment>, BoxError> {
if def.directives.iter().any(|d| d.name == "remove") {
return Ok(None);
}
inline_fragment(self, _parent_type, def)
}
fn schema(&self) -> &apollo_compiler::Schema {
&self.schema
}
fn state(&mut self) -> &mut TransformState {
&mut self.state
}
}
// Snapshot payload: the input query followed by the transformed document.
struct TestResult<'a> {
query: &'a str,
result: ast::Document,
}
impl std::fmt::Display for TestResult<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "query:\n{}\nfiltered:\n{}", self.query, self.result,)
}
}
// Schema for `remove_directive`: declares `@remove` and `@hasArg`, plus fields taking scalar,
// list, nested-list, input-object and list-of-input-object arguments.
static TRANSFORM_REMOVE_SCHEMA: &str = r#"
schema
@link(url: "https://specs.apollo.dev/link/v1.0")
@link(url: "https://specs.apollo.dev/join/v0.3", for: EXECUTION)
@link(url: "https://specs.apollo.dev/authenticated/v0.1", for: SECURITY)
{
query: Query
}
directive @link(url: String, as: String, for: link__Purpose, import: [link__Import]) repeatable on SCHEMA
directive @remove on FIELD | INLINE_FRAGMENT | FRAGMENT_SPREAD
directive @hasArg(arg: String!) on QUERY | FRAGMENT_DEFINITION | INLINE_FRAGMENT | FRAGMENT_SPREAD
scalar link__Import
enum link__Purpose {
"""
`SECURITY` features provide metadata necessary to securely resolve fields.
"""
SECURITY
"""
`EXECUTION` features provide metadata necessary for operation execution.
"""
EXECUTION
}
type Query {
a(arg: String): String
b: Obj
c: Int
d(arg: [String]): String
e(arg: Inp): String
f(arg: [[String]]): String
g(arg: [Inp]): String
}
input Inp {
a: String
b: String
c: [String]
}
type Obj {
a: String
}
"#;
// Each case is snapshotted with insta. Unnamed snapshots are numbered by call order within
// this test, so keep the cases in order and append new ones at the end. The same visitor is
// reused; `document` resets its transform state on every call.
#[test]
fn remove_directive() {
let ast = apollo_compiler::ast::Document::parse(TRANSFORM_REMOVE_SCHEMA, "").unwrap();
let (schema, _doc) = ast.to_mixed_validate().unwrap();
let schema = schema.into_inner();
let mut visitor = RemoveDirective::new(schema.clone());
// Removed fragment spread: fragment `F` is no longer used by the operation.
let query = r#"
query {
a
... F @remove
}
fragment F on Query {
b {
a
}
}"#;
let doc = ast::Document::parse(query, "query.graphql").unwrap();
let result = document(&mut visitor, &doc).unwrap();
insta::assert_snapshot!(TestResult { query, result });
// Removed field: the variable it referenced is no longer used.
let query = r#"
query($a: String) {
a(arg: $a) @remove
c
}"#;
let doc = ast::Document::parse(query, "query.graphql").unwrap();
let result = document(&mut visitor, &doc).unwrap();
insta::assert_snapshot!(TestResult { query, result });
// Fragment whose only field is removed: the fragment and the spread of it disappear.
let query = r#"
query($a: String) {
... F
c
}
fragment F on Query {
a(arg: $a) @remove
}"#;
let doc = ast::Document::parse(query, "query.graphql").unwrap();
let result = document(&mut visitor, &doc).unwrap();
insta::assert_snapshot!(TestResult { query, result });
// Removed spread of a fragment that uses a variable.
let query = r#"
query($a: String) {
... F @remove
c
}
fragment F on Query {
a(arg: $a)
}"#;
let doc = ast::Document::parse(query, "query.graphql").unwrap();
let result = document(&mut visitor, &doc).unwrap();
insta::assert_snapshot!(TestResult { query, result });
// Removed spread of a fragment that uses the variable only through another fragment.
let query = r#"
query($a: String) {
... F @remove
c
}
fragment F on Query {
... G
}
fragment G on Query {
a(arg: $a)
}
"#;
let doc = ast::Document::parse(query, "query.graphql").unwrap();
let result = document(&mut visitor, &doc).unwrap();
insta::assert_snapshot!(TestResult { query, result });
// Removal inside a nested fragment empties `G` and, through it, `F`.
let query = r#"
query($a: String) {
... F
c
}
fragment F on Query {
... G
}
fragment G on Query {
a(arg: $a) @remove
}
"#;
let doc = ast::Document::parse(query, "query.graphql").unwrap();
let result = document(&mut visitor, &doc).unwrap();
insta::assert_snapshot!(TestResult { query, result });
// Variables inside list arguments.
let query = r#"
query($a: String, $b: String) {
c
d(arg: ["a", $a, "b"]) @remove
aliased: d(arg: [$b])
}
"#;
let doc = ast::Document::parse(query, "query.graphql").unwrap();
let result = document(&mut visitor, &doc).unwrap();
insta::assert_snapshot!(TestResult { query, result });
// Variables inside input-object arguments.
let query = r#"
query($a: String, $b: String) {
c
e(arg: {a: $a, b: "b"}) @remove
aliased: e(arg: {a: "a", b: $b})
}
"#;
let doc = ast::Document::parse(query, "query.graphql").unwrap();
let result = document(&mut visitor, &doc).unwrap();
insta::assert_snapshot!(TestResult { query, result });
// Variables used by directives on the operation and on fragment definitions.
let query = r#"
query Test($a: String, $b: String, $c: String) @hasArg(arg: $a) {
...TestFragment
...TestFragment2
c
}
fragment TestFragment on Query @hasArg(arg: $b) {
__typename @remove
}
fragment TestFragment2 on Query @hasArg(arg: $c) {
__typename
}
"#;
let doc = ast::Document::parse(query, "query.graphql").unwrap();
let result = document(&mut visitor, &doc).unwrap();
insta::assert_snapshot!(TestResult { query, result });
// Variables used by directives on fragment spreads.
let query = r#"
query Test($a: String, $b: String) {
...TestFragment @hasArg(arg: $a)
...TestFragment2 @hasArg(arg: $b)
c
}
fragment TestFragment on Query {
__typename @remove
}
fragment TestFragment2 on Query {
__typename
}
"#;
let doc = ast::Document::parse(query, "query.graphql").unwrap();
let result = document(&mut visitor, &doc).unwrap();
insta::assert_snapshot!(TestResult { query, result });
// Variables used by directives on inline fragments.
let query = r#"
query Test($a: String, $b: String) {
... @hasArg(arg: $a) {
c @remove
}
... @hasArg(arg: $b) {
test: c
}
c
}
"#;
let doc = ast::Document::parse(query, "query.graphql").unwrap();
let result = document(&mut visitor, &doc).unwrap();
insta::assert_snapshot!(TestResult { query, result });
// Variables inside nested list arguments.
let query = r#"
query($a: String, $b: String) {
c
f(arg: [["a"], [$a], ["b"]]) @remove
aliased: f(arg: [["a"], [$b]])
}
"#;
let doc = ast::Document::parse(query, "query.graphql").unwrap();
let result = document(&mut visitor, &doc).unwrap();
insta::assert_snapshot!(TestResult { query, result });
// Variables inside lists of input objects.
let query = r#"
query($a: String, $b: String) {
c
g(arg: [{a: $a}, {a: "a"}]) @remove
aliased: g(arg: [{a: "a"}, {a: $b}])
}
"#;
let doc = ast::Document::parse(query, "query.graphql").unwrap();
let result = document(&mut visitor, &doc).unwrap();
insta::assert_snapshot!(TestResult { query, result });
// Variables inside lists nested in input objects.
let query = r#"
query($a: String, $b: String) {
c
e(arg: {c: [$a]}) @remove
aliased: e(arg: {c: [$b]})
}
"#;
let doc = ast::Document::parse(query, "query.graphql").unwrap();
let result = document(&mut visitor, &doc).unwrap();
insta::assert_snapshot!(TestResult { query, result });
}
}