use std::collections::HashMap;
use std::collections::HashSet;
use apollo_compiler::ExecutableDocument;
use apollo_compiler::Name;
use apollo_compiler::Node;
use apollo_compiler::ast;
use apollo_compiler::ast::VariableDefinition;
use apollo_compiler::executable;
use apollo_compiler::executable::Operation;
use apollo_compiler::executable::Selection;
use apollo_compiler::executable::SelectionSet;
use apollo_compiler::validation::Valid;
use apollo_compiler::validation::WithErrors;
use serde_json_bytes::ByteString;
use serde_json_bytes::Map;
use super::fetch::SubgraphOperation;
use super::rewrites::DataKeyRenamer;
use super::rewrites::DataRewrite;
use crate::json_ext::Path;
use crate::json_ext::PathElement;
use crate::json_ext::Value;
use crate::json_ext::ValueExt;
use crate::spec::Schema;
/// Describes how a subgraph operation must be aliased when the context values collected
/// for different data paths are not all identical.
#[derive(Debug)]
pub(crate) struct ContextualArguments {
    /// Variable names that are duplicated per path as `<name>_0`, `<name>_1`, ...
    pub(crate) arguments: HashSet<String>,
    /// How many aliased copies of the root field (and of each listed variable) to emit.
    pub(crate) count: usize,
}
/// Per-fetch state used to look up context values in already-fetched data and turn them
/// into operation variables.
pub(crate) struct SubgraphContext<'a> {
    /// Data that context paths are resolved against.
    pub(crate) data: &'a Value,
    /// Schema handed to path lookups and to the key-renaming rewrites.
    pub(crate) schema: &'a Schema,
    /// Rewrites locating each context value; only `KeyRenamer` entries are used.
    pub(crate) context_rewrites: &'a Vec<DataRewrite>,
    /// One `argument name -> value` map per path passed to `execute_on_path`, in call order.
    pub(crate) named_args: Vec<HashMap<String, Value>>,
}
/// Resolves `context_path` relative to `current_dir`.
///
/// Every leading `".."` key in `context_path` removes trailing elements from
/// `current_dir` up to and including the nearest remaining `PathElement::Key`; non-key
/// elements passed on the way are removed too. The rest of `context_path`, starting at
/// its first element that is not a `".."` key, is appended to what remains.
///
/// For example `["t","u"]` merged with `["..","... on T","prop"]` gives
/// `["t","... on T","prop"]`.
///
/// # Errors
///
/// Returns [`ContextBatchingError::InvalidRelativePath`] when a `".."` finds no key left
/// in `current_dir` to remove.
fn merge_context_path(
    current_dir: &Path,
    context_path: &Path,
) -> Result<Path, ContextBatchingError> {
    // Length of the prefix of `current_dir` that is kept.
    let mut kept = current_dir.len();
    // Number of leading `..` elements consumed from `context_path`.
    let mut consumed = 0;

    for element in context_path.iter() {
        let is_parent = matches!(element, PathElement::Key(key, _) if key == "..");
        if !is_parent {
            break;
        }
        // Pop elements off the end until one key has been removed.
        loop {
            if kept == 0 {
                return Err(ContextBatchingError::InvalidRelativePath);
            }
            kept -= 1;
            if matches!(current_dir.0.get(kept), Some(PathElement::Key(_, _))) {
                break;
            }
        }
        consumed += 1;
    }

    let mut merged: Vec<PathElement> = current_dir.iter().take(kept).cloned().collect();
    merged.extend(context_path.iter().skip(consumed).cloned());
    Ok(Path(merged.into_iter().collect()))
}
impl<'a> SubgraphContext<'a> {
    /// Creates a context for a fetch that has context rewrites.
    ///
    /// Returns `None` when `context_rewrites` is `None` or an empty list, since there is
    /// nothing to resolve in that case.
    pub(crate) fn new(
        data: &'a Value,
        schema: &'a Schema,
        context_rewrites: &'a Option<Vec<DataRewrite>>,
    ) -> Option<SubgraphContext<'a>> {
        context_rewrites
            .as_ref()
            .filter(|rewrites| !rewrites.is_empty())
            .map(|rewrites| SubgraphContext {
                data,
                schema,
                context_rewrites: rewrites,
                named_args: Vec::new(),
            })
    }

    /// Resolves every `KeyRenamer` rewrite relative to `path` and records the resulting
    /// `argument name -> value` map as a new entry of `named_args`.
    ///
    /// For each argument name, the first rewrite whose path resolves in `data` wins. Later
    /// rewrites with the same `rename_key_to` are skipped. Rewrites whose path cannot be
    /// merged or resolved contribute nothing, so the map may lack some names. The key
    /// renaming is applied to each element when the resolved value is an array, and to the
    /// value itself otherwise.
    pub(crate) fn execute_on_path(&mut self, path: &Path) {
        let mut found_rewrites: HashSet<String> = HashSet::new();
        let hash_map: HashMap<String, Value> = self
            .context_rewrites
            .iter()
            .filter_map(|rewrite| match rewrite {
                DataRewrite::KeyRenamer(item) => {
                    if found_rewrites.contains(item.rename_key_to.as_str()) {
                        return None;
                    }
                    let data_path = merge_context_path(path, &item.path).ok()?;
                    let mut new_value = self.data.get_path(self.schema, &data_path).ok()?.clone();
                    found_rewrites.insert(item.rename_key_to.to_string());

                    // The renamer is identical for every element, so build it once.
                    let renamer = DataRewrite::KeyRenamer(DataKeyRenamer {
                        path: data_path,
                        rename_key_to: item.rename_key_to.clone(),
                    });
                    if let Some(values) = new_value.as_array_mut() {
                        for value in values {
                            renamer.maybe_apply(self.schema, value);
                        }
                    } else {
                        renamer.maybe_apply(self.schema, &mut new_value);
                    }
                    Some((item.rename_key_to.to_string(), new_value))
                }
                DataRewrite::ValueSetter(_) => None,
            })
            .collect();
        self.named_args.push(hash_map);
    }

    /// Adds the collected context values to `variables` and reports whether the
    /// operation must be aliased.
    ///
    /// * No paths collected: `variables` is untouched and `None` is returned.
    /// * Every path produced the same map: its entries are added under their plain names
    ///   and `None` is returned.
    /// * Otherwise, the value of `name` from the `i`-th path is added as `<name>_<i>`. The
    ///   returned [`ContextualArguments`] lists every name seen on any path, together with
    ///   the number of paths.
    pub(crate) fn add_variables_and_get_args(
        &self,
        variables: &mut Map<ByteString, Value>,
    ) -> Option<ContextualArguments> {
        let first_map = self.named_args.first()?;

        if self.named_args.iter().all(|map| map == first_map) {
            variables.extend(
                first_map
                    .iter()
                    .map(|(key, value)| (key.as_str().into(), value.clone())),
            );
            return None;
        }

        // A path whose context value did not resolve has no entry for that name. Take the
        // union over all paths so that a name missing from the first path is still aliased
        // and the values found on the other paths stay referenced by the operation.
        let arguments: HashSet<String> = self
            .named_args
            .iter()
            .flat_map(|map| map.keys().cloned())
            .collect();

        for (index, map) in self.named_args.iter().enumerate() {
            variables.extend(map.iter().map(|(key, value)| {
                let suffixed = format!("{}_{}", key, index);
                (suffixed.as_str().into(), value.clone())
            }));
        }

        Some(ContextualArguments {
            arguments,
            count: self.named_args.len(),
        })
    }
}
/// Builds a copy of `subgraph_operation` in which the root field of every operation is
/// repeated once per set of contextual argument values, then validates it against
/// `subgraph_schema`.
///
/// Fragment definitions from the source document are copied over unchanged, so fragment
/// spreads nested inside the root field still resolve.
///
/// # Errors
///
/// * [`ContextBatchingError::NoSelectionSet`] if the operation cannot be obtained as a
///   parsed document.
/// * [`ContextBatchingError::UnexpectedSelection`] if an operation's root selection set is
///   not exactly one field.
/// * [`ContextBatchingError::InvalidDocumentGenerated`] if the rewritten document fails
///   validation.
pub(crate) fn build_operation_with_aliasing(
    subgraph_operation: &SubgraphOperation,
    contextual_arguments: &ContextualArguments,
    subgraph_schema: &Valid<apollo_compiler::Schema>,
) -> Result<Valid<ExecutableDocument>, ContextBatchingError> {
    let ContextualArguments { arguments, count } = contextual_arguments;
    // NOTE(review): the underlying error from `as_parsed` is discarded here.
    let document = match subgraph_operation.as_parsed() {
        Ok(document) => document,
        Err(_) => return Err(ContextBatchingError::NoSelectionSet),
    };

    let mut ed = ExecutableDocument::new();
    // `transform_selection_set` leaves named fragment spreads in place, so their
    // definitions must come along or validation fails on an undefined fragment.
    ed.fragments = document.fragments.clone();

    if let Some(anonymous_op) = &document.operations.anonymous {
        let mut cloned = anonymous_op.clone();
        transform_operation(&mut cloned, arguments, count)?;
        ed.operations.insert(cloned);
    }
    for op in document.operations.named.values() {
        let mut cloned = op.clone();
        transform_operation(&mut cloned, arguments, count)?;
        ed.operations.insert(cloned);
    }

    ed.validate(subgraph_schema)
        .map_err(|e| ContextBatchingError::InvalidDocumentGenerated(Box::new(e)))
}
/// Rewrites `operation` so that its single root field is requested `count` times.
///
/// Copy `i` is aliased `_<i>`. Inside copy `i`, every argument that refers to a variable
/// named in `arguments` is rewritten to `$<name>_<i>`. The variable definitions are
/// updated to match: each such variable is replaced, in place, by `count` definitions
/// with the same type, default value and directives. Other variables are kept as-is.
///
/// # Errors
///
/// Returns [`ContextBatchingError::UnexpectedSelection`] unless the root selection set
/// consists of exactly one field. `operation` is left unchanged in that case.
fn transform_operation(
    operation: &mut Node<Operation>,
    arguments: &HashSet<String>,
    count: &usize,
) -> Result<(), ContextBatchingError> {
    let count = *count;

    let mut new_variables: Vec<Node<VariableDefinition>> = Vec::new();
    for variable in &operation.variables {
        if !arguments.contains(variable.name.as_str()) {
            new_variables.push(variable.clone());
            continue;
        }
        new_variables.extend((0..count).map(|index| {
            Node::new(VariableDefinition {
                name: Name::new_unchecked(&format!("{}_{}", variable.name.as_str(), index)),
                ty: variable.ty.clone(),
                default_value: variable.default_value.clone(),
                directives: variable.directives.clone(),
            })
        }));
    }

    let root_field = match operation.selection_set.selections.as_slice() {
        [Selection::Field(field)] => field.clone(),
        _ => return Err(ContextBatchingError::UnexpectedSelection),
    };

    let aliased_fields: Vec<Selection> = (0..count)
        .map(|index| {
            let mut aliased = root_field.clone();
            let field = aliased.make_mut();
            field.alias = Some(Name::new_unchecked(&format!("_{}", index)));
            transform_field_arguments(&mut field.arguments, arguments, index);
            transform_selection_set(&mut field.selection_set, arguments, index);
            Selection::Field(aliased)
        })
        .collect();

    let operation = operation.make_mut();
    operation.variables = new_variables;
    operation.selection_set.selections = aliased_fields;
    Ok(())
}
/// Recursively renames contextual variables inside `selection_set` for alias `index`.
///
/// Field arguments and nested selection sets of fields and inline fragments are
/// rewritten via [`transform_field_arguments`]. Any other selection kind (named fragment
/// spreads) is left untouched.
fn transform_selection_set(
    selection_set: &mut SelectionSet,
    arguments: &HashSet<String>,
    index: usize,
) {
    for selection in selection_set.selections.iter_mut() {
        match selection {
            executable::Selection::Field(field) => {
                let field = field.make_mut();
                transform_field_arguments(&mut field.arguments, arguments, index);
                transform_selection_set(&mut field.selection_set, arguments, index);
            }
            executable::Selection::InlineFragment(fragment) => {
                let fragment = fragment.make_mut();
                transform_selection_set(&mut fragment.selection_set, arguments, index);
            }
            _ => {}
        }
    }
}
/// Rewrites each argument whose value is a variable `$name` with `name` in `arguments`
/// to `$<name>_<index>`. Any other argument value is left as-is.
fn transform_field_arguments(
    arguments_in_selection: &mut [Node<ast::Argument>],
    arguments: &HashSet<String>,
    index: usize,
) {
    for argument in arguments_in_selection.iter_mut() {
        let argument = argument.make_mut();
        let renamed = match argument.value.as_variable() {
            Some(variable) if arguments.contains(variable.as_str()) => {
                Name::new_unchecked(&format!("{}_{}", variable.as_str(), index))
            }
            _ => continue,
        };
        argument.value = Node::new(ast::Value::Variable(renamed));
    }
}
/// Failures while resolving context paths or building the aliased subgraph operation.
#[derive(Debug)]
pub(crate) enum ContextBatchingError {
    /// The subgraph operation could not be obtained as a parsed document.
    NoSelectionSet,
    /// The rewritten document failed validation against the subgraph schema.
    // Nothing in this file reads the payload; it appears to be kept for `Debug` output,
    // hence the `allow(unused)`.
    InvalidDocumentGenerated(#[allow(unused)] Box<WithErrors<ExecutableDocument>>),
    /// A `..` in a context path had no key left in the current path to step over.
    InvalidRelativePath,
    /// An operation's root selection set was not exactly one field.
    UnexpectedSelection,
}
#[cfg(test)]
mod subgraph_context_unit_tests {
    use super::*;

    #[test]
    fn test_merge_context_path() {
        let current_dir: Path = serde_json::from_str(r#"["t","u"]"#).unwrap();
        let relative_path: Path = serde_json::from_str(r#"["..","... on T","prop"]"#).unwrap();
        let expected = r#"["t","... on T","prop"]"#;
        let result = merge_context_path(&current_dir, &relative_path).unwrap();
        assert_eq!(expected, serde_json::to_string(&result).unwrap());
    }

    #[test]
    fn test_merge_context_path_invalid() {
        let current_dir: Path = serde_json::from_str(r#"["t","u"]"#).unwrap();
        let relative_path: Path =
            serde_json::from_str(r#"["..","..","..","... on T","prop"]"#).unwrap();
        let result = merge_context_path(&current_dir, &relative_path);
        assert!(
            matches!(result, Err(ContextBatchingError::InvalidRelativePath)),
            "expected InvalidRelativePath, got {:?}",
            result
        );
    }

    #[test]
    fn test_transform_selection_set() {
        let type_name = Name::new("Hello").unwrap();
        let field_name = Name::new("f").unwrap();
        let field_definition = ast::FieldDefinition {
            description: None,
            name: field_name.clone(),
            arguments: vec![Node::new(ast::InputValueDefinition {
                description: None,
                name: Name::new("param").unwrap(),
                ty: Node::new(ast::Type::Named(Name::new("ParamType").unwrap())),
                default_value: None,
                directives: ast::DirectiveList(vec![]),
            })],
            ty: ast::Type::Named(Name::new("FieldType").unwrap()),
            directives: ast::DirectiveList(vec![]),
        };
        let mut selection_set = SelectionSet::new(type_name);
        let field = executable::Field::new(Name::new("f").unwrap(), Node::new(field_definition))
            .with_argument(
                Name::new("param").unwrap(),
                Node::new(ast::Value::Variable(Name::new("variable").unwrap())),
            );
        selection_set.push(Selection::Field(Node::new(field)));
        assert_eq!(
            "{ f(param: $variable) }",
            selection_set.serialize().no_indent().to_string()
        );

        // Only variables named in the set are renamed; argument names do not matter.
        let mut hash_set = HashSet::new();
        hash_set.insert("one".to_string());
        hash_set.insert("two".to_string());
        hash_set.insert("param".to_string());
        let mut clone = selection_set.clone();
        transform_selection_set(&mut clone, &hash_set, 7);
        assert_eq!(
            "{ f(param: $variable) }",
            clone.serialize().no_indent().to_string()
        );

        hash_set.insert("variable".to_string());
        let mut clone = selection_set.clone();
        transform_selection_set(&mut clone, &hash_set, 7);
        assert_eq!(
            "{ f(param: $variable_7) }",
            clone.serialize().no_indent().to_string()
        );

        let clone = selection_set.clone();
        let mut operation = Node::new(executable::Operation {
            operation_type: executable::OperationType::Query,
            name: None,
            variables: vec![],
            directives: ast::DirectiveList(vec![]),
            selection_set: clone,
        });
        let count = 3;
        transform_operation(&mut operation, &hash_set, &count).unwrap();
        assert_eq!(
            "{ _0: f(param: $variable_0) _1: f(param: $variable_1) _2: f(param: $variable_2) }",
            operation.serialize().no_indent().to_string()
        );
    }
}