use std::fmt;
use std::fmt::Formatter;
use apollo_compiler::Name;
use apollo_compiler::Node;
use apollo_compiler::Schema;
use apollo_compiler::ast::Directive;
use apollo_compiler::collections::HashMap;
use apollo_compiler::collections::IndexSet;
use apollo_compiler::executable::FieldSet;
use apollo_compiler::executable::Selection;
use apollo_compiler::validation::Valid;
use itertools::Itertools;
use shape::Shape;
use shape::ShapeCase;
use crate::connectors::Connector;
use crate::connectors::Namespace;
use crate::connectors::validation::Code;
use crate::connectors::validation::Message;
use crate::connectors::variable::VariableReference;
use crate::link::federation_spec_definition::FEDERATION_FIELDS_ARGUMENT_NAME;
/// Accumulates `@key` field sets and connector field sets so that keys which no
/// connector covers can be reported by `check_for_missing_entity_connectors`.
#[derive(Default)]
pub(crate) struct EntityKeyChecker<'schema> {
/// One entry per `add_key` call: the cloned key field set, the directive it was
/// registered with, and a reference to that directive's name.
resolvable_keys: Vec<(FieldSet, &'schema Node<Directive>, &'schema Name)>,
/// Connector field sets grouped by type name. `add_connector` files each field set
/// under its declared type and under every concrete type name found in its shape.
entity_connectors: HashMap<Name, Vec<Valid<FieldSet>>>,
}
impl<'schema> EntityKeyChecker<'schema> {
    /// Records a key field set along with the directive that declared it.
    ///
    /// The field set is cloned; the directive and its name are kept by reference.
    pub(crate) fn add_key(&mut self, field_set: &FieldSet, directive: &'schema Node<Directive>) {
        let entry = (field_set.clone(), directive, &directive.name);
        self.resolvable_keys.push(entry);
    }

    /// Registers a connector field set under its declared type, and additionally
    /// under every other concrete type name extracted from `selection_shape`.
    pub(crate) fn add_connector(&mut self, field_set: Valid<FieldSet>, selection_shape: &Shape) {
        let declared_type_name = &field_set.selection_set.ty;

        // The declared type is always registered, whatever the shape says.
        self.entity_connectors
            .entry(declared_type_name.clone())
            .or_default()
            .push(field_set.clone());

        // Skip the declared type here so the same field set is not filed twice
        // under one key.
        let additional_types = extract_concrete_typenames(selection_shape)
            .into_iter()
            .filter(|type_name| type_name != declared_type_name);
        for type_name in additional_types {
            self.entity_connectors
                .entry(type_name)
                .or_default()
                .push(field_set.clone());
        }
    }

    /// Returns one `MissingEntityConnector` message for every recorded key whose
    /// fields are not covered by any connector registered for the key's type.
    ///
    /// Locations are taken from the key's directive, resolved against
    /// `schema.sources`.
    pub(crate) fn check_for_missing_entity_connectors(&self, schema: &Schema) -> Vec<Message> {
        self.resolvable_keys
            .iter()
            .filter_map(|(key, directive, _)| {
                let implemented = self
                    .entity_connectors
                    .get(&key.selection_set.ty)
                    .is_some_and(|candidates| {
                        candidates
                            .iter()
                            .any(|candidate| field_set_fields_are_subset(key, candidate))
                    });
                if implemented {
                    return None;
                }

                // Falls back to an empty string when the `fields` argument is
                // missing or is not a string.
                let key_fields = directive
                    .argument_by_name(&FEDERATION_FIELDS_ARGUMENT_NAME, schema)
                    .ok()
                    .and_then(|arg| arg.as_str())
                    .unwrap_or_default();
                let message = format!(
                    "Entity resolution for `@key(fields: \"{}\")` on `{}` is not implemented by a connector. See https://go.apollo.dev/connectors/entity-rules",
                    key_fields, key.selection_set.ty,
                );
                let locations = directive
                    .line_column_range(&schema.sources)
                    .into_iter()
                    .collect();

                Some(Message {
                    code: Code::MissingEntityConnector,
                    message,
                    locations,
                })
            })
            .collect()
    }
}
impl fmt::Debug for EntityKeyChecker<'_> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.debug_struct("EntityKeyChecker")
.field(
"resolvable_keys",
&self
.resolvable_keys
.iter()
.map(|(fs, _, _)| {
format!(
"... on {} {}",
fs.selection_set.ty,
fs.selection_set.serialize().no_indent()
)
})
.collect_vec(),
)
.field(
"entity_connectors",
&self
.entity_connectors
.values()
.flatten()
.map(|fs| {
format!(
"... on {} {}",
fs.selection_set.ty,
fs.selection_set.serialize().no_indent()
)
})
.collect_vec(),
)
.finish()
}
}
/// Builds the `ConnectorsCannotResolveKey` message for a connector whose variables
/// cannot be turned into a valid `@key`.
///
/// The message lists `variables` separated by "`, `", and the location is taken
/// from `connector.name()` resolved against `schema.sources`.
pub(crate) fn field_set_error(
    variables: &[VariableReference<Namespace>],
    connector: &Connector,
    schema: &Schema,
) -> Message {
    let variable_list = variables.iter().join("`, `");
    let directive_name = connector.id.directive.simple_name();
    let message = format!(
        "Variables used in connector (`{}`) on type `{}` cannot be used to create a valid `@key` directive.",
        variable_list, directive_name
    );
    let locations = connector
        .name()
        .line_column_range(&schema.sources)
        .into_iter()
        .collect();

    Message {
        code: Code::ConnectorsCannotResolveKey,
        message,
        locations,
    }
}
/// Returns `true` when `x` and `y` are the same kind of selection with the same
/// identity, and every child selection of `y` has a match among the children of `x`.
///
/// Fields must agree on name and alias; inline fragments must agree on type
/// condition. Any other pairing is reported as not matching.
fn selection_is_subset(x: &Selection, y: &Selection) -> bool {
    let (x_children, y_children) = match (x, y) {
        (Selection::Field(x_field), Selection::Field(y_field)) => {
            if x_field.name != y_field.name || x_field.alias != y_field.alias {
                return false;
            }
            (
                &x_field.selection_set.selections,
                &y_field.selection_set.selections,
            )
        }
        (Selection::InlineFragment(x_fragment), Selection::InlineFragment(y_fragment)) => {
            if x_fragment.type_condition != y_fragment.type_condition {
                return false;
            }
            (
                &x_fragment.selection_set.selections,
                &y_fragment.selection_set.selections,
            )
        }
        _ => return false,
    };

    vec_includes_as_set(x_children, y_children, selection_is_subset)
}
/// Returns `true` when both field sets are on the same type and every selection of
/// `inner` has a matching selection in `outer`.
pub(crate) fn field_set_is_subset(inner: &FieldSet, outer: &FieldSet) -> bool {
    let inner_set = &inner.selection_set;
    let outer_set = &outer.selection_set;

    if inner_set.ty != outer_set.ty {
        return false;
    }

    vec_includes_as_set(
        &outer_set.selections,
        &inner_set.selections,
        selection_is_subset,
    )
}
/// Like `field_set_is_subset`, but compares only the selections and ignores the
/// type each field set is declared on.
fn field_set_fields_are_subset(inner: &FieldSet, outer: &FieldSet) -> bool {
    let required = &inner.selection_set.selections;
    let available = &outer.selection_set.selections;
    vec_includes_as_set(available, required, selection_is_subset)
}
/// Returns `true` when every element of `other` is matched by at least one element
/// of `this`.
///
/// `item_matches` is called as `item_matches(this_item, other_item)`. An empty
/// `other` is always included.
fn vec_includes_as_set<T>(this: &[T], other: &[T], item_matches: impl Fn(&T, &T) -> bool) -> bool {
    for wanted in other {
        let found = this.iter().any(|candidate| item_matches(candidate, wanted));
        if !found {
            return false;
        }
    }
    true
}
/// Collects the `__typename` string literals found in `shape` that are valid
/// `Name`s, in the order they are first encountered.
///
/// The traversal starts outside of any `__typename` position.
fn extract_concrete_typenames(shape: &Shape) -> IndexSet<Name> {
    let mut typenames: IndexSet<Name> = IndexSet::default();
    extract_concrete_typenames_into(shape, false, &mut typenames);
    typenames
}
/// Recursive worker for `extract_concrete_typenames`.
///
/// `in_typename` is `true` only while walking the shape stored under an object's
/// `__typename` field; string literals are recorded into `result` only in that
/// position. Other object fields are not visited.
fn extract_concrete_typenames_into(shape: &Shape, in_typename: bool, result: &mut IndexSet<Name>) {
    match shape.case() {
        ShapeCase::String(Some(literal)) => {
            if !in_typename {
                return;
            }
            // Literals that are not valid names are silently skipped.
            if let Ok(name) = Name::new(literal.as_str()) {
                result.insert(name);
            }
        }
        ShapeCase::Object { fields, .. } => {
            if let Some(typename_shape) = fields.get("__typename") {
                extract_concrete_typenames_into(typename_shape, true, result);
            }
        }
        ShapeCase::One(shapes) => {
            for member in shapes.iter() {
                extract_concrete_typenames_into(member, in_typename, result);
            }
        }
        ShapeCase::All(shapes) => {
            if in_typename {
                // An intersection whose direct members include two different
                // string literals can never be satisfied, so it contributes nothing.
                let mut literals = shapes.iter().filter_map(|member| match member.case() {
                    ShapeCase::String(Some(literal)) => Some(literal.as_str()),
                    _ => None,
                });
                let conflicting = match literals.next() {
                    Some(first) => literals.any(|other| other != first),
                    None => false,
                };
                if conflicting {
                    return;
                }
            }
            for member in shapes.iter() {
                extract_concrete_typenames_into(member, in_typename, result);
            }
        }
        ShapeCase::Array { prefix, tail } => {
            // Array elements are never themselves in `__typename` position.
            for element in prefix.iter() {
                extract_concrete_typenames_into(element, false, result);
            }
            extract_concrete_typenames_into(tail, false, result);
        }
        ShapeCase::Error(shape::Error { partial, .. }) => {
            if let Some(partial) = partial {
                extract_concrete_typenames_into(partial, in_typename, result);
            }
        }
        // None of these cases contributes a type name.
        ShapeCase::String(None)
        | ShapeCase::Name(_, _)
        | ShapeCase::None
        | ShapeCase::Bool(_)
        | ShapeCase::Int(_)
        | ShapeCase::Float
        | ShapeCase::Null
        | ShapeCase::Unknown => {}
    }
}
// Unit tests for the field-set subset check and for `__typename` extraction from shapes.
#[cfg(test)]
mod tests {
use apollo_compiler::Schema;
use apollo_compiler::executable::FieldSet;
use apollo_compiler::name;
use apollo_compiler::validation::Valid;
use rstest::rstest;
use super::field_set_is_subset;
// Schema fixture shared by the field-set tests: `Query.t: T`, `T { a, b: B, c }`
// and `B { x, y }`. Panics if the fixture does not parse and validate.
fn schema() -> Valid<Schema> {
Schema::parse_and_validate(
r#"
type Query {
t: T
}
type T {
a: String
b: B
c: String
}
type B {
x: String
y: String
}
"#,
"",
)
.unwrap()
}
// Each case is `(inner, outer)` on type `T` where every selection of `inner`
// also appears in `outer`, including nested selections under `b`.
#[rstest]
#[case("a", "a")]
#[case("a b { x } c", "a b { x } c")]
#[case("a", "a c")]
#[case("b { x }", "b { x y }")]
fn test_field_set_is_subset(#[case] inner: &str, #[case] outer: &str) {
let schema = schema();
let inner = FieldSet::parse_and_validate(&schema, name!(T), inner, "inner").unwrap();
let outer = FieldSet::parse_and_validate(&schema, name!(T), outer, "outer").unwrap();
assert!(field_set_is_subset(&inner, &outer));
}
// Each case is `(inner, outer)` on type `T` where `inner` selects something
// that `outer` lacks, at the top level or nested under `b`.
#[rstest]
#[case("a b { x } c", "a")]
#[case("b { x y }", "b { x }")]
fn test_field_set_is_not_subset(#[case] inner: &str, #[case] outer: &str) {
let schema = schema();
let inner = FieldSet::parse_and_validate(&schema, name!(T), inner, "inner").unwrap();
let outer = FieldSet::parse_and_validate(&schema, name!(T), outer, "outer").unwrap();
assert!(!field_set_is_subset(&inner, &outer));
}
// A selection whose `->match` branches produce `__typename: "Anon"` and
// `__typename: "Named"` is expected to yield both names.
#[test]
fn test_extract_concrete_typenames_from_match() {
use crate::connectors::ConnectSpec;
use crate::connectors::JSONSelection;
let selection = JSONSelection::parse_with_spec(
r#"
id
... $(name ?? null)->match(
[null, { __typename: "Anon" }],
[@, { __typename: "Named", name: name }],
)
"#,
ConnectSpec::V0_4,
)
.unwrap();
let shape = selection.shape();
eprintln!("Shape: {}", shape.pretty_print());
let concrete_types = super::extract_concrete_typenames(&shape);
eprintln!("Concrete types: {:?}", concrete_types);
assert!(concrete_types.contains(&name!(Anon)));
assert!(concrete_types.contains(&name!(Named)));
}
// A record whose `__typename` is a union of two string literals yields exactly
// those two names.
#[test]
fn test_extract_typename_union_of_strings() {
use shape::Shape;
let typename_union = Shape::one(
[
Shape::string_value("TypeA", []),
Shape::string_value("TypeB", []),
],
[],
);
let shape = Shape::record(
[
("__typename".to_string(), typename_union),
("id".to_string(), Shape::string([])),
]
.into(),
[],
);
eprintln!("Shape: {}", shape.pretty_print());
let concrete_types = super::extract_concrete_typenames(&shape);
eprintln!("Concrete types: {:?}", concrete_types);
assert!(concrete_types.contains(&name!(TypeA)));
assert!(concrete_types.contains(&name!(TypeB)));
assert_eq!(concrete_types.len(), 2);
}
// Only the nested `author` record carries a `__typename`; the outer record has
// none, so nothing is expected to be extracted.
#[test]
fn test_does_not_extract_nested_typename() {
use shape::Shape;
let nested_object = Shape::record(
[
("__typename".to_string(), Shape::string_value("User", [])),
("id".to_string(), Shape::string([])),
]
.into(),
[],
);
let shape = Shape::record(
[
("id".to_string(), Shape::string([])),
("author".to_string(), nested_object),
]
.into(),
[],
);
eprintln!("Shape: {}", shape.pretty_print());
let concrete_types = super::extract_concrete_typenames(&shape);
eprintln!("Concrete types: {:?}", concrete_types);
assert!(concrete_types.is_empty());
}
// A `__typename` built with `Shape::all` from two different literals ("Cat" and
// "Dog") is expected to yield no names.
#[test]
fn test_does_not_extract_conflicting_typename_intersection() {
use shape::Shape;
let conflicting_typename = Shape::all(
[
Shape::string_value("Cat", []),
Shape::string_value("Dog", []),
],
[],
);
let shape = Shape::record(
[
("__typename".to_string(), conflicting_typename),
("id".to_string(), Shape::string([])),
]
.into(),
[],
);
eprintln!("Shape: {}", shape.pretty_print());
let concrete_types = super::extract_concrete_typenames(&shape);
eprintln!("Concrete types: {:?}", concrete_types);
assert!(
concrete_types.is_empty(),
"Conflicting __typename intersection should not extract any types"
);
}
// A union of three records: two with a single `__typename` literal ("Cat",
// "Bird") and one with a conflicting "Cat"/"Dog" intersection. Only "Cat" and
// "Bird" are expected in the result.
#[test]
fn test_extracts_only_valid_typename_from_mixed_union() {
use shape::Shape;
let valid_cat = Shape::record(
[
("__typename".to_string(), Shape::string_value("Cat", [])),
("id".to_string(), Shape::string([])),
]
.into(),
[],
);
let conflicting = Shape::record(
[
(
"__typename".to_string(),
Shape::all(
[
Shape::string_value("Cat", []),
Shape::string_value("Dog", []),
],
[],
),
),
("id".to_string(), Shape::string([])),
]
.into(),
[],
);
let valid_bird = Shape::record(
[
("__typename".to_string(), Shape::string_value("Bird", [])),
("id".to_string(), Shape::string([])),
]
.into(),
[],
);
let shape = Shape::one([valid_cat, conflicting, valid_bird], []);
eprintln!("Shape: {}", shape.pretty_print());
let concrete_types = super::extract_concrete_typenames(&shape);
eprintln!("Concrete types: {:?}", concrete_types);
assert!(concrete_types.contains(&name!(Cat)));
assert!(concrete_types.contains(&name!(Bird)));
assert!(!concrete_types.contains(&name!(Dog)));
assert_eq!(concrete_types.len(), 2);
}
}