use std::cell::RefCell;
use std::collections::HashMap;
use std::collections::HashSet;
use std::num::NonZeroUsize;
use std::ops::ControlFlow;
use std::sync::Arc;
use std::sync::OnceLock;
use std::time::Duration;
use apollo_compiler::Schema;
use apollo_compiler::ast::NamedType;
use apollo_compiler::collections::IndexMap;
use apollo_compiler::parser::Parser;
use apollo_compiler::resolvers;
use apollo_compiler::schema::ObjectType;
use apollo_compiler::validation::Valid;
use apollo_federation::connectors::StringTemplate;
use http::HeaderValue;
use http::header::CACHE_CONTROL;
use itertools::Itertools;
use lru::LruCache;
use multimap::MultiMap;
use opentelemetry::Array;
use opentelemetry::Key;
use opentelemetry::StringValue;
use schemars::JsonSchema;
use serde::Deserialize;
use serde::Serialize;
use serde_json_bytes::ByteString;
use serde_json_bytes::Value;
use tokio::sync::RwLock;
use tokio::sync::broadcast;
use tokio_stream::StreamExt;
use tokio_stream::wrappers::IntervalStream;
use tower::BoxError;
use tower::ServiceBuilder;
use tower::ServiceExt;
use tower_service::Service;
use tracing::Instrument;
use tracing::Level;
use tracing::Span;
use super::cache_control::CacheControl;
use super::invalidation::Invalidation;
use super::invalidation_endpoint::InvalidationEndpointConfig;
use super::invalidation_endpoint::InvalidationService;
use super::invalidation_endpoint::SubgraphInvalidationConfig;
use super::metrics::CacheMetricContextKey;
use super::metrics::record_fetch_error;
use crate::Context;
use crate::Endpoint;
use crate::ListenAddr;
use crate::configuration::subgraph::SubgraphConfiguration;
use crate::context::CONTAINS_GRAPHQL_ERROR;
use crate::error::FetchError;
use crate::graphql;
use crate::graphql::Error;
use crate::json_ext::Object;
use crate::json_ext::Path;
use crate::json_ext::PathElement;
use crate::layers::ServiceBuilderExt;
use crate::plugin::PluginInit;
use crate::plugin::PluginPrivate;
use crate::plugins::authorization::CacheKeyMetadata;
use crate::plugins::response_cache::cache_key::PrimaryCacheKeyEntity;
use crate::plugins::response_cache::cache_key::PrimaryCacheKeyRoot;
use crate::plugins::response_cache::cache_key::hash_additional_data;
use crate::plugins::response_cache::cache_key::hash_query;
use crate::plugins::response_cache::debugger::CacheEntryKind;
use crate::plugins::response_cache::debugger::CacheKeyContext;
use crate::plugins::response_cache::debugger::CacheKeySource;
use crate::plugins::response_cache::debugger::add_cache_key_to_context;
use crate::plugins::response_cache::debugger::add_cache_keys_to_context;
use crate::plugins::response_cache::storage;
use crate::plugins::response_cache::storage::CacheEntry;
use crate::plugins::response_cache::storage::CacheStorage;
use crate::plugins::response_cache::storage::Document;
use crate::plugins::response_cache::storage::redis::Storage;
use crate::plugins::telemetry::LruSizeInstrument;
use crate::plugins::telemetry::dynamic_attribute::SpanDynAttribute;
use crate::plugins::telemetry::span_ext::SpanMarkError;
use crate::query_planner::OperationKind;
use crate::services::subgraph;
use crate::services::subgraph::SubgraphRequestId;
use crate::services::supergraph;
use crate::spec::QueryHash;
use crate::spec::TYPENAME;
/// Response cache format version tag.
pub(crate) const RESPONSE_CACHE_VERSION: &str = "1.2";
/// Federation directive used to declare cache tags in the supergraph schema.
pub(crate) const CACHE_TAG_DIRECTIVE_NAME: &str = "federation__cacheTag";
/// Root field name of federation entity queries.
pub(crate) const ENTITIES: &str = "_entities";
/// Variable carrying entity representations in an `_entities` query.
pub(crate) const REPRESENTATIONS: &str = "representations";
/// Context key for response cache key data.
pub(crate) const CONTEXT_CACHE_KEY: &str = "apollo::response_cache::key";
/// Context key accumulating cache-key debug data for the debugger.
pub(crate) const CONTEXT_DEBUG_CACHE_KEYS: &str = "apollo::response_cache::debug_cached_keys";
/// Request header that must equal "true" to enable per-request cache debugging.
pub(crate) const CACHE_DEBUG_HEADER_NAME: &str = "apollo-cache-debugging";
/// GraphQL response extensions key carrying the debug payload to the client.
pub(crate) const CACHE_DEBUG_EXTENSIONS_KEY: &str = "apolloCacheDebugging";
/// Version reported inside the debug payload.
pub(crate) const CACHE_DEBUGGER_VERSION: &str = "1.0";
/// Response extension through which subgraphs supply root-field cache tags.
pub(crate) const GRAPHQL_RESPONSE_EXTENSION_ROOT_FIELDS_CACHE_TAGS: &str = "apolloCacheTags";
/// Response extension through which subgraphs supply entity cache tags.
pub(crate) const GRAPHQL_RESPONSE_EXTENSION_ENTITY_CACHE_TAGS: &str = "apolloEntityCacheTags";
/// Prefix marking internally-generated cache tags.
pub(crate) const INTERNAL_CACHE_TAG_PREFIX: &str = "__apollo_internal::";
/// Default capacity of the private-queries LRU (see `Config::private_queries_buffer_size`).
const DEFAULT_LRU_PRIVATE_QUERIES_SIZE: NonZeroUsize = NonZeroUsize::new(2048).unwrap();
/// Instrument name for the private-queries LRU size gauge.
const LRU_PRIVATE_QUERIES_INSTRUMENT_NAME: &str =
    "apollo.router.response_cache.private_queries.lru.size";

// Registers this plugin with the router under `apollo.response_cache`.
register_private_plugin!("apollo", "response_cache", ResponseCache);
/// Plugin caching subgraph responses (root fields and `_entities`) in the
/// configured storage backends, with an HTTP invalidation endpoint.
#[derive(Clone)]
pub(crate) struct ResponseCache {
    /// Storage backends: an optional shared one plus per-subgraph overrides.
    pub(super) storage: Arc<StorageInterface>,
    /// Listen address and path of the invalidation endpoint, when configured.
    endpoint_config: Option<Arc<InvalidationEndpointConfig>>,
    /// Per-subgraph cache settings, with an `all` section as fallback.
    subgraphs: Arc<SubgraphConfiguration<Subgraph>>,
    /// Name of the schema's root query type, when defined.
    entity_type: Option<String>,
    /// Master switch for the whole plugin.
    enabled: bool,
    /// Config-level debug flag; still gated per request by the debug header.
    debug: bool,
    /// LRU of queries whose responses were observed to be `private`.
    private_queries: Arc<RwLock<LruCache<PrivateQueryKey, ()>>>,
    /// Handler performing cache invalidation against the storage backends.
    pub(crate) invalidation: Invalidation,
    supergraph_schema: Arc<Valid<Schema>>,
    /// Maps `join__Graph` enum value names to subgraph names.
    subgraph_enums: Arc<HashMap<String, String>>,
    /// Gauge tracking the size of `private_queries`.
    lru_size_instrument: LruSizeInstrument,
    /// Broadcast used on `Drop` to stop background reconnection tasks.
    drop_tx: broadcast::Sender<()>,
}
/// LRU key identifying a query observed to return a `private` cache-control.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct PrivateQueryKey {
    /// Hash of the query (see `hash_query`).
    query_hash: String,
    /// Whether a private id was available when the query was recorded.
    has_private_id: bool,
}
/// Lazily-connected storage backends: an optional shared (`all`) backend plus
/// per-subgraph overrides. Each `OnceLock` is set once its connection succeeds.
#[derive(Clone, Default)]
pub(crate) struct StorageInterface {
    all: Option<Arc<OnceLock<Storage>>>,
    subgraphs: HashMap<String, Arc<OnceLock<Storage>>>,
}
impl StorageInterface {
pub(crate) fn get(&self, subgraph: &str) -> Option<&Storage> {
let storage = self.subgraphs.get(subgraph).or(self.all.as_ref())?;
storage.get()
}
pub(crate) fn activate(&self) {
if let Some(all) = &self.all
&& let Some(storage) = all.get()
{
storage.activate();
}
for storage in self.subgraphs.values() {
if let Some(storage) = storage.get() {
storage.activate();
}
}
}
}
#[cfg(all(
    test,
    any(not(feature = "ci"), all(target_arch = "x86_64", target_os = "linux"))
))]
impl StorageInterface {
    /// Test helper: installs `storage` into the shared `all` slot.
    /// Returns `None` when there is no slot or it was already set.
    pub(crate) fn replace_storage(&self, storage: Storage) -> Option<()> {
        self.all.as_ref()?.set(storage).ok()
    }
}
#[cfg(all(
    test,
    any(not(feature = "ci"), all(target_arch = "x86_64", target_os = "linux"))
))]
impl From<Storage> for StorageInterface {
    /// Test helper: wraps a single pre-connected backend as the shared `all`
    /// storage, with no per-subgraph overrides.
    fn from(storage: Storage) -> Self {
        let all = Arc::new(OnceLock::from(storage));
        StorageInterface {
            all: Some(all),
            subgraphs: HashMap::default(),
        }
    }
}
/// Configuration for the response cache plugin.
#[derive(Clone, Debug, JsonSchema, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub(crate) struct Config {
    /// Master switch (defaults to `false`).
    #[serde(default)]
    pub(crate) enabled: bool,
    /// Allows cache-debugging output (defaults to `false`).
    #[serde(default)]
    debug: bool,
    /// Per-subgraph settings, with an `all` section as fallback.
    pub(crate) subgraph: SubgraphConfiguration<Subgraph>,
    /// Invalidation endpoint settings (listen address and path).
    invalidation: Option<InvalidationEndpointConfig>,
    /// Capacity of the private-queries LRU (defaults to 2048).
    #[serde(default = "default_lru_private_queries_size")]
    private_queries_buffer_size: NonZeroUsize,
}
/// serde default for `Config::private_queries_buffer_size`.
const fn default_lru_private_queries_size() -> NonZeroUsize {
    DEFAULT_LRU_PRIVATE_QUERIES_SIZE
}
/// Per-subgraph (or global, via the `all` section) cache settings.
#[derive(Clone, Debug, JsonSchema, Deserialize, Serialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields, default)]
pub(crate) struct Subgraph {
    /// Redis connection settings for this subgraph's cache storage.
    pub(crate) redis: Option<storage::redis::Config>,
    /// Default TTL applied to cached entries for this subgraph.
    pub(crate) ttl: Option<Ttl>,
    /// Whether caching is enabled for this subgraph (default: `Some(true)`).
    pub(crate) enabled: Option<bool>,
    /// Request-context key whose (hashed) value scopes `private` cache entries.
    pub(crate) private_id: Option<String>,
    /// Invalidation settings (enable flag and shared key).
    pub(crate) invalidation: Option<SubgraphInvalidationConfig>,
}
impl Default for Subgraph {
fn default() -> Self {
Self {
redis: None,
enabled: Some(true),
ttl: Default::default(),
private_id: Default::default(),
invalidation: Default::default(),
}
}
}
/// Cache entry time-to-live, deserialized from a human-readable duration
/// string (e.g. `30s`, `5m`) via `humantime_serde`.
#[derive(Clone, Debug, JsonSchema, Deserialize, Serialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub(crate) struct Ttl(
    #[serde(deserialize_with = "humantime_serde::deserialize")]
    #[schemars(with = "String")]
    pub(crate) Duration,
);
/// Per-subgraph cache statistics, keyed by GraphQL type name (e.g. "Query").
#[derive(Default, Serialize, Deserialize, Debug)]
#[serde(default)]
pub(crate) struct CacheSubgraph(pub(crate) HashMap<String, CacheHitMiss>);
/// Cache hit/miss counters for one type.
#[derive(Default, Serialize, Deserialize, Debug)]
#[serde(default)]
pub(crate) struct CacheHitMiss {
    // Number of cache hits recorded.
    pub(crate) hit: usize,
    // Number of cache misses recorded.
    pub(crate) miss: usize,
}
#[async_trait::async_trait]
impl PluginPrivate for ResponseCache {
const HIDDEN_FROM_CONFIG_JSON_SCHEMA: bool = true;
type Config = Config;
/// Creates the plugin: validates configuration, establishes Redis
/// connections (or spawns reconnection tasks), and builds the
/// invalidation handler.
///
/// # Errors
/// - when no global TTL exists and some subgraph lacks its own TTL
/// - when the global invalidation `shared_key` is empty
/// - when a subgraph has no Redis config and no shared one exists
async fn new(init: PluginInit<Self::Config>) -> Result<Self, BoxError>
where
    Self: Sized,
{
    // Root query type name from the supergraph schema (e.g. "Query").
    let entity_type = init
        .supergraph_schema
        .schema_definition
        .query
        .as_ref()
        .map(|q| q.name.to_string());
    // Every subgraph must get a TTL, either from its own section or from `all`.
    if init.config.subgraph.all.ttl.is_none()
        && init
            .config
            .subgraph
            .subgraphs
            .values()
            .any(|s| s.ttl.is_none())
    {
        return Err("a TTL must be configured for all subgraphs or globally"
            .to_string()
            .into());
    }
    // Reject an empty global invalidation shared key.
    if init
        .config
        .subgraph
        .all
        .invalidation
        .as_ref()
        .map(|i| i.shared_key.is_empty())
        .unwrap_or_default()
    {
        return Err(
            "you must set a default shared_key invalidation for all subgraphs"
                .to_string()
                .into(),
        );
    }
    let mut storage_interface = StorageInterface::default();
    // `drop_tx` is kept on the plugin; `Drop` broadcasts on it so pending
    // reconnection tasks shut down with the plugin instance.
    let (drop_tx, drop_rx) = tokio::sync::broadcast::channel(2);
    // Shared backend, used by subgraphs without a dedicated Redis config.
    if init.config.enabled
        && init.config.subgraph.all.enabled.unwrap_or_default()
        && let Some(config) = init.config.subgraph.all.redis.clone()
    {
        let storage = Arc::new(OnceLock::new());
        storage_interface.all = Some(storage.clone());
        connect_or_spawn_reconnection_task(config, storage, drop_rx).await?;
    }
    for (subgraph, subgraph_config) in &init.config.subgraph.subgraphs {
        if Self::static_subgraph_enabled(init.config.enabled, &init.config.subgraph, subgraph) {
            match subgraph_config.redis.clone() {
                Some(config) => {
                    // Create a dedicated backend only when this Redis config
                    // differs from the shared one (or no shared backend exists).
                    if Some(&config) != init.config.subgraph.all.redis.as_ref()
                        || storage_interface.all.is_none()
                    {
                        let storage = Arc::new(OnceLock::new());
                        storage_interface
                            .subgraphs
                            .insert(subgraph.clone(), storage.clone());
                        connect_or_spawn_reconnection_task(
                            config,
                            storage,
                            drop_tx.subscribe(),
                        )
                        .await?;
                    }
                }
                None => {
                    if storage_interface.all.is_none() {
                        return Err(
                            format!("you must have a redis configured either for all subgraphs or for subgraph {subgraph:?}")
                                .into(),
                        );
                    }
                }
            }
        }
    }
    let storage_interface = Arc::new(storage_interface);
    let invalidation = Invalidation::new(storage_interface.clone()).await?;
    Ok(Self {
        storage: storage_interface,
        entity_type,
        enabled: init.config.enabled,
        debug: init.config.debug,
        endpoint_config: init.config.invalidation.clone().map(Arc::new),
        subgraphs: Arc::new(init.config.subgraph),
        private_queries: Arc::new(RwLock::new(LruCache::new(
            init.config.private_queries_buffer_size,
        ))),
        invalidation,
        subgraph_enums: Arc::new(get_subgraph_enums(&init.supergraph_schema)),
        supergraph_schema: init.supergraph_schema,
        lru_size_instrument: LruSizeInstrument::new(LRU_PRIVATE_QUERIES_INSTRUMENT_NAME),
        drop_tx,
    })
}
/// Plugin activation hook: marks all connected storage backends as active.
fn activate(&self) {
    self.storage.activate();
}
/// Wraps the supergraph service to surface caching results to the client:
/// writes the aggregated `Cache-Control` header (downgraded to `no-store`
/// when the response carried GraphQL errors) and, in debug mode, injects
/// the collected cache-key data into the response extensions.
fn supergraph_service(&self, service: supergraph::BoxService) -> supergraph::BoxService {
    let debug = self.debug;
    ServiceBuilder::new()
        .map_response(move |mut response: supergraph::Response| {
            if let Some(mut cache_control) = response
                .context
                .extensions()
                .with_lock(|lock| lock.get::<CacheControl>().cloned())
            {
                let has_errors = response
                    .context
                    .get_json_value(CONTAINS_GRAPHQL_ERROR)
                    .and_then(|v| v.as_bool())
                    .unwrap_or(false);
                // Never advertise a cacheable response when it has errors.
                if has_errors {
                    cache_control = CacheControl::no_store();
                }
                let _ = cache_control.to_headers(response.response.headers_mut());
            }
            // Debug data is attached to every chunk of the response stream.
            if debug
                && let Some(debug_data) =
                    response.context.get_json_value(CONTEXT_DEBUG_CACHE_KEYS)
            {
                return response.map_stream(move |mut body| {
                    body.extensions.insert(
                        CACHE_DEBUG_EXTENSIONS_KEY,
                        serde_json_bytes::json!({
                            "version": CACHE_DEBUGGER_VERSION,
                            "data": debug_data.clone()
                        }),
                    );
                    body
                });
            }
            response
        })
        .service(service)
        .boxed()
}
/// Wraps a subgraph service with the caching layer when caching is enabled
/// for that subgraph; otherwise only records the response's cache-control
/// into the shared context. Falls back to a 24h TTL when none is
/// configured (note: `new` rejects configurations without any TTL).
fn subgraph_service(&self, name: &str, service: subgraph::BoxService) -> subgraph::BoxService {
    let subgraph_ttl = self
        .subgraph_ttl(name)
        .unwrap_or_else(|| Duration::from_secs(60 * 60 * 24));
    let subgraph_enabled = self.subgraph_enabled(name);
    let private_id = self.subgraphs.get(name).private_id.clone();
    let name = name.to_string();
    if subgraph_enabled {
        let private_queries = self.private_queries.clone();
        let inner = ServiceBuilder::new()
            .map_response(move |response: subgraph::Response| {
                // Merge this subgraph's cache-control into the context so the
                // supergraph layer can emit an aggregated header.
                update_cache_control(
                    &response.context,
                    &CacheControl::new(response.response.headers(), subgraph_ttl.into())
                        .ok()
                        .unwrap_or_else(CacheControl::no_store),
                );
                response
            })
            .service(CacheService {
                service: ServiceBuilder::new()
                    .buffered()
                    .service(service)
                    .boxed_clone(),
                entity_type: self.entity_type.clone(),
                name: name.to_string(),
                storage: self.storage.clone(),
                subgraph_ttl,
                private_queries,
                private_id_key_name: private_id,
                debug: self.debug,
                supergraph_schema: self.supergraph_schema.clone(),
                subgraph_enums: self.subgraph_enums.clone(),
                lru_size_instrument: self.lru_size_instrument.clone(),
            });
        tower::util::BoxService::new(inner)
    } else {
        // Caching disabled for this subgraph: still record cache-control.
        ServiceBuilder::new()
            .map_response(move |response: subgraph::Response| {
                update_cache_control(
                    &response.context,
                    &CacheControl::new(response.response.headers(), subgraph_ttl.into())
                        .ok()
                        .unwrap_or_else(CacheControl::no_store),
                );
                response
            })
            .service(service)
            .boxed()
    }
}
/// Registers the HTTP invalidation endpoint when the plugin is enabled,
/// at least one subgraph caches, and invalidation is enabled globally or
/// for some caching subgraph. Warns when invalidation is enabled but no
/// endpoint (listen address and path) is configured.
fn web_endpoints(&self) -> MultiMap<ListenAddr, Endpoint> {
    let mut map = MultiMap::new();
    let any_caching_enabled = self
        .subgraphs
        .subgraphs
        .iter()
        .any(|(subgraph_name, _cfg)| self.subgraph_enabled(subgraph_name))
        || self.subgraphs.all.enabled.unwrap_or_default();
    let global_invalidation_enabled = self
        .subgraphs
        .all
        .invalidation
        .as_ref()
        .map(|i| i.enabled)
        .unwrap_or_default();
    // A subgraph only counts when it both caches and enables invalidation.
    let any_subgraph_invalidation_enabled =
        self.subgraphs.subgraphs.iter().any(|(subgraph_name, cfg)| {
            self.subgraph_enabled(subgraph_name)
                && cfg
                    .invalidation
                    .as_ref()
                    .map(|i| i.enabled)
                    .unwrap_or_default()
        });
    if self.enabled
        && any_caching_enabled
        && (global_invalidation_enabled || any_subgraph_invalidation_enabled)
    {
        match &self.endpoint_config {
            Some(endpoint_config) => {
                let endpoint = Endpoint::from_router_service(
                    endpoint_config.path.clone(),
                    InvalidationService::new(self.subgraphs.clone(), self.invalidation.clone())
                        .boxed(),
                );
                tracing::info!(
                    "Response cache invalidation endpoint listening on: {}{}",
                    endpoint_config.listen,
                    endpoint_config.path
                );
                map.insert(endpoint_config.listen.clone(), endpoint);
            }
            None => {
                tracing::warn!(
                    "Cannot start response cache invalidation endpoint because the listen address and endpoint is not configured"
                );
            }
        }
    }
    map
}
}
#[cfg(all(
    test,
    any(not(feature = "ci"), all(target_arch = "x86_64", target_os = "linux"))
))]
// Shared key used by tests against the invalidation endpoint.
pub(super) const INVALIDATION_SHARED_KEY: &str = "supersecret";
impl ResponseCache {
#[cfg(all(
    test,
    any(not(feature = "ci"), all(target_arch = "x86_64", target_os = "linux"))
))]
/// Test constructor: wraps a pre-connected `storage`, optionally truncating
/// its namespace first, with debug mode on and a fixed localhost
/// invalidation endpoint.
pub(crate) async fn for_test(
    storage: Storage,
    subgraphs: SubgraphConfiguration<Subgraph>,
    supergraph_schema: Arc<Valid<Schema>>,
    truncate_namespace: bool,
    drop_tx: broadcast::Sender<()>,
) -> Result<Self, BoxError>
where
    Self: Sized,
{
    use std::net::IpAddr;
    use std::net::Ipv4Addr;
    use std::net::SocketAddr;
    // Start from a clean keyspace when requested.
    if truncate_namespace {
        storage.truncate_namespace().await?;
    }
    let storage = Arc::new(StorageInterface {
        all: Some(Arc::new(storage.into())),
        subgraphs: HashMap::new(),
    });
    let invalidation = Invalidation::new(storage.clone()).await?;
    Ok(Self {
        storage,
        entity_type: None,
        enabled: true,
        debug: true,
        subgraphs: Arc::new(subgraphs),
        private_queries: Arc::new(RwLock::new(LruCache::new(DEFAULT_LRU_PRIVATE_QUERIES_SIZE))),
        endpoint_config: Some(Arc::new(InvalidationEndpointConfig {
            path: String::from("/invalidation"),
            listen: ListenAddr::SocketAddr(SocketAddr::new(
                IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
                4000,
            )),
        })),
        invalidation,
        subgraph_enums: Arc::new(get_subgraph_enums(&supergraph_schema)),
        supergraph_schema,
        lru_size_instrument: LruSizeInstrument::new(LRU_PRIVATE_QUERIES_INSTRUMENT_NAME),
        drop_tx,
    })
}
#[cfg(all(
    test,
    any(not(feature = "ci"), all(target_arch = "x86_64", target_os = "linux"))
))]
/// Test constructor for failure-mode tests: the shared storage slot exists
/// but is never connected, so lookups take the no-storage path.
pub(crate) async fn without_storage_for_failure_mode(
    subgraphs: HashMap<String, Subgraph>,
    supergraph_schema: Arc<Valid<Schema>>,
) -> Result<Self, BoxError>
where
    Self: Sized,
{
    use std::net::IpAddr;
    use std::net::Ipv4Addr;
    use std::net::SocketAddr;
    // `Default::default()` yields an empty `OnceLock`: present but never set.
    let storage = Arc::new(StorageInterface {
        all: Some(Default::default()),
        subgraphs: HashMap::new(),
    });
    let invalidation = Invalidation::new(storage.clone()).await?;
    let (drop_tx, _drop_rx) = broadcast::channel(2);
    Ok(Self {
        storage,
        entity_type: None,
        enabled: true,
        debug: true,
        subgraphs: Arc::new(SubgraphConfiguration {
            all: Subgraph {
                invalidation: Some(SubgraphInvalidationConfig {
                    enabled: true,
                    shared_key: INVALIDATION_SHARED_KEY.to_string(),
                }),
                ..Default::default()
            },
            subgraphs,
        }),
        private_queries: Arc::new(RwLock::new(LruCache::new(DEFAULT_LRU_PRIVATE_QUERIES_SIZE))),
        endpoint_config: Some(Arc::new(InvalidationEndpointConfig {
            path: String::from("/invalidation"),
            listen: ListenAddr::SocketAddr(SocketAddr::new(
                IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
                4000,
            )),
        })),
        invalidation,
        subgraph_enums: Arc::new(get_subgraph_enums(&supergraph_schema)),
        supergraph_schema,
        lru_size_instrument: LruSizeInstrument::new(LRU_PRIVATE_QUERIES_INSTRUMENT_NAME),
        drop_tx,
    })
}
/// Whether caching is enabled for `subgraph_name` on this plugin instance.
fn subgraph_enabled(&self, subgraph_name: &str) -> bool {
    Self::static_subgraph_enabled(self.enabled, &self.subgraphs, subgraph_name)
}
/// Whether response caching is active for `subgraph_name`, given the
/// plugin-wide switch and the per-subgraph / global `enabled` flags.
///
/// Precedence: plugin disabled → off; an explicit per-subgraph flag wins;
/// otherwise the global (`all`) flag applies, defaulting to enabled.
fn static_subgraph_enabled(
    plugin_enabled: bool,
    subgraph_config: &SubgraphConfiguration<Subgraph>,
    subgraph_name: &str,
) -> bool {
    plugin_enabled
        && subgraph_config
            .get(subgraph_name)
            .enabled
            .or(subgraph_config.all.enabled)
            .unwrap_or(true)
}
/// TTL for `subgraph_name`: the per-subgraph setting when present,
/// otherwise the global (`all`) setting; `None` when neither is set.
fn subgraph_ttl(&self, subgraph_name: &str) -> Option<Duration> {
    // `Ttl` wraps a `Duration` (which is `Copy`), so borrowing with
    // `as_ref()` avoids cloning the `Option<Ttl>` just to read the value.
    self.subgraphs
        .get(subgraph_name)
        .ttl
        .as_ref()
        .map(|t| t.0)
        .or_else(|| self.subgraphs.all.ttl.as_ref().map(|ttl| ttl.0))
}
}
impl Drop for ResponseCache {
    fn drop(&mut self) {
        // Signal background reconnection tasks to stop; ignore the error
        // returned when no receiver is listening.
        let _ = self.drop_tx.send(());
    }
}
/// Builds a map from `join__Graph` enum value name to the subgraph name
/// declared by its `join__graph(name:)` directive in the supergraph schema.
/// Returns an empty map when the enum is absent.
fn get_subgraph_enums(supergraph_schema: &Valid<Schema>) -> HashMap<String, String> {
    let mut subgraph_enums = HashMap::new();
    let Some(graph_enum) = supergraph_schema.get_enum("join__Graph") else {
        return subgraph_enums;
    };
    for (enum_name, enum_value_def) in &graph_enum.values {
        // Skip values without a `join__graph` directive or a string `name`.
        let subgraph_name = enum_value_def
            .directives
            .get("join__graph")
            .and_then(|dir| dir.specified_argument_by_name("name"))
            .and_then(|arg| arg.as_str());
        if let Some(subgraph_name) = subgraph_name {
            subgraph_enums.insert(enum_name.to_string(), subgraph_name.to_string());
        }
    }
    subgraph_enums
}
/// Tower service performing cache lookup/store around one subgraph service.
#[derive(Clone)]
struct CacheService {
    /// The wrapped subgraph service.
    service: subgraph::BoxCloneService,
    /// Subgraph name.
    name: String,
    /// Root query type name, when known.
    entity_type: Option<String>,
    /// Storage backends (resolved per subgraph at call time).
    storage: Arc<StorageInterface>,
    /// TTL applied to entries stored for this subgraph.
    subgraph_ttl: Duration,
    /// LRU of queries known to be `private`, shared with the plugin.
    private_queries: Arc<RwLock<LruCache<PrivateQueryKey, ()>>>,
    /// Context key to read the private id from, when configured.
    private_id_key_name: Option<String>,
    /// Config-level debug flag; narrowed per request by the debug header.
    debug: bool,
    supergraph_schema: Arc<Valid<Schema>>,
    /// `join__Graph` enum value name → subgraph name.
    subgraph_enums: Arc<HashMap<String, String>>,
    /// Gauge tracking the size of `private_queries`.
    lru_size_instrument: LruSizeInstrument,
}
impl Service<subgraph::Request> for CacheService {
    type Response = subgraph::Response;
    type Error = BoxError;
    type Future = <subgraph::BoxService as Service<subgraph::Request>>::Future;

    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        self.service.poll_ready(cx)
    }

    fn call(&mut self, request: subgraph::Request) -> Self::Future {
        // Standard tower pattern: move the instance that was driven to
        // readiness into the future (which owns it across `.await` points)
        // and leave a fresh clone behind in `self`.
        let clone = self.clone();
        let inner = std::mem::replace(self, clone);
        Box::pin(inner.call_inner(request))
    }
}
impl CacheService {
/// Entry point for one subgraph request. Routes between: pass-through
/// (no storage, batched request, or `no-cache`+`no-store`), the private
/// query path, entity caching, and root-field caching.
async fn call_inner(
    mut self,
    request: subgraph::Request,
) -> Result<subgraph::Response, BoxError> {
    // Resolve the backend for this subgraph; when unavailable (e.g. not
    // connected yet), record the error and forward without caching.
    let storage = match self
        .storage
        .get(&self.name)
        .ok_or(storage::Error::NoStorage)
    {
        Ok(storage) => storage.clone(),
        Err(err) => {
            record_fetch_error(&err, &self.name);
            return self
                .service
                .map_response(move |response: subgraph::Response| {
                    update_cache_control(
                        &response.context,
                        &CacheControl::new(response.response.headers(), None)
                            .ok()
                            .unwrap_or_else(CacheControl::no_store),
                    );
                    response
                })
                .call(request)
                .await;
        }
    };
    // Debug output requires both the config flag and the per-request header.
    self.debug = self.debug
        && (request
            .supergraph_request
            .headers()
            .get(CACHE_DEBUG_HEADER_NAME)
            == Some(&HeaderValue::from_static("true")));
    // Batched subgraph requests bypass the cache entirely.
    if request.is_part_of_batch() {
        return self.service.call(request).await;
    }
    // Honor a cache-control header already present on the subgraph request.
    let cache_control = if request
        .subgraph_request
        .headers()
        .contains_key(&CACHE_CONTROL)
    {
        let cache_control = match CacheControl::new(request.subgraph_request.headers(), None) {
            Ok(cache_control) => cache_control,
            Err(err) => {
                // Unparseable header: answer with a GraphQL error response.
                return Ok(subgraph::Response::builder()
                    .subgraph_name(request.subgraph_name)
                    .id(request.id)
                    .context(request.context)
                    .error(
                        graphql::Error::builder()
                            .message(format!("cannot get cache-control header: {err}"))
                            .extension_code("INVALID_CACHE_CONTROL_HEADER")
                            .build(),
                    )
                    .extensions(Object::default())
                    .build());
            }
        };
        // `no-cache` + `no-store`: skip both lookup and store.
        if cache_control.is_no_cache() && cache_control.is_no_store() {
            let mut resp = self.service.call(request).await?;
            cache_control.to_headers(resp.response.headers_mut())?;
            return Ok(resp);
        }
        Some(cache_control)
    } else {
        None
    };
    let private_id = self.get_private_id(&request.context);
    let private_query_key = PrivateQueryKey {
        query_hash: hash_query(&request.query_hash),
        has_private_id: private_id.is_some(),
    };
    // Was this query previously marked `private` by a subgraph response?
    let is_known_private = {
        self.private_queries
            .read()
            .await
            .contains(&private_query_key)
    };
    // `_entities` queries carry a `representations` variable.
    let is_entity = request
        .subgraph_request
        .body()
        .variables
        .contains_key(REPRESENTATIONS);
    if is_known_private && private_id.is_none() {
        // Private query with no private id available: cannot be cached.
        self.call_service_for_private_query_without_id(request, is_entity)
            .await
    } else if is_entity {
        self.call_service_for_entities_query(
            request,
            storage,
            is_known_private,
            private_id,
            private_query_key,
            cache_control,
        )
        .await
    } else {
        self.call_service_for_root_fields_operation(
            request,
            storage,
            is_known_private,
            private_id,
            private_query_key,
            cache_control,
        )
        .await
    }
}
/// Forwards a known-private query that has no private id (so it can never
/// be cached). In debug mode, records a placeholder, non-stored cache-key
/// entry so the debugger still reports the request.
async fn call_service_for_private_query_without_id(
    mut self,
    request: subgraph::Request,
    is_entity: bool,
) -> Result<subgraph::Response, BoxError> {
    let mut debug_subgraph_request = None;
    let mut root_operation_fields = Vec::new();
    if self.debug {
        // Capture request details before the request is consumed below.
        root_operation_fields = request.root_operation_fields();
        debug_subgraph_request = Some(request.subgraph_request.body().clone());
    }
    let resp = self.service.call(request).await?;
    if self.debug {
        let cache_control =
            CacheControl::new(resp.response.headers(), self.subgraph_ttl.into())?;
        let kind = if is_entity {
            CacheEntryKind::Entity {
                typename: "".to_string(),
                entity_key: Default::default(),
            }
        } else {
            CacheEntryKind::RootFields {
                root_fields: root_operation_fields,
            }
        };
        // "-" key and `should_store: false` mark this as a non-cacheable entry.
        let cache_key_context = CacheKeyContext {
            key: "-".to_string(),
            invalidation_keys: vec![],
            kind,
            hashed_private_id: None,
            subgraph_name: self.name.clone(),
            subgraph_request: debug_subgraph_request.unwrap_or_default(),
            source: CacheKeySource::Subgraph,
            cache_control,
            data: serde_json_bytes::to_value(resp.response.body().clone()).unwrap_or_default(),
            warnings: Vec::new(),
            should_store: false,
        }
        .update_metadata();
        add_cache_key_to_context(&resp.context, cache_key_context)?;
    }
    Ok(resp)
}
/// Root-fields path: looks the whole operation up under a single cache key;
/// on miss, calls the subgraph, tracks `private` responses, and stores the
/// response when cache-control allows it.
async fn call_service_for_root_fields_operation(
    mut self,
    request: subgraph::Request,
    storage: Storage,
    is_known_private: bool,
    private_id: Option<String>,
    private_query_key: PrivateQueryKey,
    request_cache_control: Option<CacheControl>,
) -> Result<subgraph::Response, BoxError> {
    // Only queries are cached; other operation kinds pass through.
    if request.operation_kind != OperationKind::Query {
        return self.service.call(request).await;
    }
    let mut cache_hit: HashMap<String, CacheHitMiss> = HashMap::new();
    match cache_lookup_root(
        self.name.clone(),
        self.entity_type.as_deref(),
        storage.clone(),
        is_known_private,
        private_id.as_deref(),
        self.debug,
        request,
        self.supergraph_schema.clone(),
        &self.subgraph_enums,
        request_cache_control.as_ref(),
    )
    .instrument(tracing::info_span!(
        "response_cache.lookup",
        kind = "root",
        subgraph.name = self.name.clone(),
        "graphql.type" = self.entity_type.as_deref().unwrap_or_default(),
        debug = self.debug,
        private = is_known_private,
        contains_private_id = private_id.is_some(),
        "cache.key" = ::tracing::field::Empty,
    ))
    .await?
    {
        // Cache hit: the lookup already built the full response.
        ControlFlow::Break(response) => {
            cache_hit.insert("Query".to_string(), CacheHitMiss { hit: 1, miss: 0 });
            let _ = response.context.insert(
                CacheMetricContextKey::new(response.subgraph_name.clone()),
                CacheSubgraph(cache_hit),
            );
            Ok(response)
        }
        // Cache miss: call the subgraph, then possibly store the response.
        ControlFlow::Continue((request, mut root_cache_key, mut invalidation_keys)) => {
            cache_hit.insert("Query".to_string(), CacheHitMiss { hit: 0, miss: 1 });
            let _ = request.context.insert(
                CacheMetricContextKey::new(request.subgraph_name.clone()),
                CacheSubgraph(cache_hit),
            );
            let mut root_operation_fields: Vec<String> = Vec::new();
            let mut debug_subgraph_request = None;
            if self.debug {
                root_operation_fields = request.root_operation_fields();
                debug_subgraph_request = Some(request.subgraph_request.body().clone());
            }
            let response = self.service.call(request).await?;
            let mut cache_control =
                response.subgraph_cache_control(self.subgraph_ttl.into())?;
            // The subgraph may add extra invalidation tags via an extension.
            if let Some(Value::Array(cache_tags)) =
                response.get_from_extensions(GRAPHQL_RESPONSE_EXTENSION_ROOT_FIELDS_CACHE_TAGS)
            {
                invalidation_keys.extend(
                    cache_tags
                        .iter()
                        .filter_map(|v| v.as_str())
                        .map(|s| s.to_owned()),
                );
            }
            save_original_cache_control(
                response.id.clone(),
                &response.context,
                cache_control.clone(),
            );
            if cache_control.private() {
                // First time this query is seen as private: remember it and
                // scope the cache key to the private id when available.
                if !is_known_private {
                    let size = {
                        let mut private_queries = self.private_queries.write().await;
                        private_queries.put(private_query_key.clone(), ());
                        private_queries.len()
                    };
                    self.lru_size_instrument.update(size as u64);
                    if let Some(s) = private_id.as_ref() {
                        root_cache_key = format!("{root_cache_key}:{s}");
                    }
                }
            }
            // A request-level `no-store` always wins over the response's.
            if let Some(request_cache_control) = request_cache_control {
                cache_control.no_store |= request_cache_control.no_store;
            }
            if self.debug {
                let cache_key_context = CacheKeyContext {
                    key: root_cache_key.clone(),
                    hashed_private_id: private_id.clone(),
                    invalidation_keys: external_invalidation_keys(invalidation_keys.clone()),
                    kind: CacheEntryKind::RootFields {
                        root_fields: root_operation_fields,
                    },
                    subgraph_name: self.name.clone(),
                    subgraph_request: debug_subgraph_request.unwrap_or_default(),
                    source: CacheKeySource::Subgraph,
                    cache_control: cache_control.clone(),
                    data: serde_json_bytes::to_value(response.response.body().clone())
                        .unwrap_or_default(),
                    warnings: Vec::new(),
                    should_store: true,
                }
                .update_metadata();
                add_cache_key_to_context(&response.context, cache_key_context)?;
            }
            // A private response without a private id cannot be keyed safely.
            let unstorable_private_response = cache_control.private() && private_id.is_none();
            if !unstorable_private_response && cache_control.should_store() {
                cache_store_root_from_response(
                    storage,
                    self.subgraph_ttl,
                    &response,
                    cache_control,
                    root_cache_key,
                    invalidation_keys,
                    self.debug,
                )
                .await?;
            }
            Ok(response)
        }
    }
}
/// Entity (`_entities`) path: looks representations up individually; misses
/// are fetched from the subgraph, merged with cached entries, and stored.
/// On a subgraph transport error, a response is assembled from the cached
/// entries plus per-entity errors and marked `no-store`.
async fn call_service_for_entities_query(
    mut self,
    request: subgraph::Request,
    storage: Storage,
    is_known_private: bool,
    private_id: Option<String>,
    private_query_key: PrivateQueryKey,
    request_cache_control: Option<CacheControl>,
) -> Result<subgraph::Response, BoxError> {
    match cache_lookup_entities(
        self.name.clone(),
        self.supergraph_schema.clone(),
        &self.subgraph_enums,
        storage.clone(),
        is_known_private,
        private_id.as_deref(),
        request,
        self.debug,
        request_cache_control.as_ref(),
    )
    .instrument(tracing::info_span!(
        "response_cache.lookup",
        kind = "entity",
        subgraph.name = self.name.clone(),
        debug = self.debug,
        private = is_known_private,
        contains_private_id = private_id.is_some()
    ))
    .await?
    {
        // All entities were served from cache.
        ControlFlow::Break(response) => Ok(response),
        // At least one miss: `cache_result.0` holds per-entity intermediate
        // results, `cache_result.1` the control merged from cached entries.
        ControlFlow::Continue((request, mut cache_result)) => {
            let context = request.context.clone();
            let mut debug_subgraph_request = None;
            if self.debug {
                debug_subgraph_request = Some(request.subgraph_request.body().clone());
                // Report every entity served from cache to the debugger.
                let debug_cache_keys_ctx = cache_result.0.iter().filter_map(|ir| {
                    ir.cache_entry.as_ref().map(|cache_entry| CacheKeyContext {
                        hashed_private_id: private_id.clone(),
                        key: cache_entry.key.clone(),
                        invalidation_keys: external_invalidation_keys(ir.invalidation_keys.clone()),
                        kind: CacheEntryKind::Entity {
                            typename: ir.typename.clone(),
                            entity_key: ir.entity_key.clone().unwrap_or_default(),
                        },
                        subgraph_name: self.name.clone(),
                        subgraph_request: request.subgraph_request.body().clone(),
                        source: CacheKeySource::Cache,
                        cache_control: cache_entry.control.clone(),
                        data: serde_json_bytes::json!({
                            "data": serde_json_bytes::to_value(cache_entry.data.clone()).unwrap_or_default()
                        }),
                        warnings: Vec::new(),
                        should_store: false,
                    }.update_metadata())
                });
                add_cache_keys_to_context(&request.context, debug_cache_keys_ctx)?;
            }
            let req_id = request.id.clone();
            let mut response = match self.service.call(request).await {
                Ok(response) => response,
                Err(e) => {
                    // Normalize the error into a SubrequestHttpError...
                    let e = match e.downcast::<FetchError>() {
                        Ok(inner) => match *inner {
                            FetchError::SubrequestHttpError { .. } => *inner,
                            _ => FetchError::SubrequestHttpError {
                                status_code: None,
                                service: self.name.to_string(),
                                reason: inner.to_string(),
                            },
                        },
                        Err(e) => FetchError::SubrequestHttpError {
                            status_code: None,
                            service: self.name.to_string(),
                            reason: e.to_string(),
                        },
                    };
                    let graphql_error = e.to_graphql_error(None);
                    // ...then answer with cached entities plus errors for
                    // the entities that could not be fetched.
                    let (new_entities, new_errors) =
                        assemble_response_from_errors(&[graphql_error], &mut cache_result.0);
                    let mut data = Object::default();
                    data.insert(ENTITIES, new_entities.into());
                    let mut response = subgraph::Response::builder()
                        .context(context)
                        .data(Value::Object(data))
                        .id(req_id)
                        .errors(new_errors)
                        .subgraph_name(self.name)
                        .extensions(Object::new())
                        .build();
                    // Partially failed responses must not be cached downstream.
                    CacheControl::no_store().to_headers(response.response.headers_mut())?;
                    return Ok(response);
                }
            };
            let mut cache_control =
                response.subgraph_cache_control(self.subgraph_ttl.into())?;
            save_original_cache_control(
                response.id.clone(),
                &response.context,
                cache_control.clone(),
            );
            // Merge with the control of the entries served from cache.
            if let Some(control_from_cached) = cache_result.1 {
                cache_control = cache_control.merge(&control_from_cached);
            }
            // A request-level `no-store` always wins over the response's.
            if let Some(request_cache_control) = request_cache_control {
                cache_control.no_store |= request_cache_control.no_store;
            }
            if !is_known_private && cache_control.private() {
                self.private_queries
                    .write()
                    .await
                    .put(private_query_key, ());
            }
            cache_store_entities_from_response(
                storage,
                self.subgraph_ttl,
                &mut response,
                cache_control.clone(),
                cache_result.0,
                is_known_private,
                private_id,
                debug_subgraph_request,
            )
            .await?;
            cache_control.to_headers(response.response.headers_mut())?;
            Ok(response)
        }
    }
}
/// Returns the BLAKE3-hashed (hex) private-id value read from the request
/// context under the configured key; `None` when no key is configured or
/// the context holds no string value for it.
fn get_private_id(&self, context: &Context) -> Option<String> {
    let key_name = self.private_id_key_name.as_ref()?;
    let value = context.get_json_value(key_name)?;
    let private_id = value.as_str()?;
    Some(blake3::hash(private_id.as_bytes()).to_hex().to_string())
}
}
/// Looks up a root-fields operation in the cache.
///
/// Returns `Break(response)` on a usable cache hit, or
/// `Continue((request, cache_key, invalidation_keys))` so the caller can
/// fetch from the subgraph and store the result under `cache_key`.
/// Storage errors other than "row not found" are recorded on the span but
/// still fall through to a miss.
#[allow(clippy::too_many_arguments)]
async fn cache_lookup_root(
    name: String,
    entity_type_opt: Option<&str>,
    cache: Storage,
    is_known_private: bool,
    private_id: Option<&str>,
    debug: bool,
    mut request: subgraph::Request,
    supergraph_schema: Arc<Valid<Schema>>,
    subgraph_enums: &HashMap<String, String>,
    cache_control: Option<&CacheControl>,
) -> Result<ControlFlow<subgraph::Response, (subgraph::Request, String, Vec<String>)>, BoxError> {
    // Invalidation tags declared through directives in the supergraph schema.
    let invalidation_cache_keys =
        get_invalidation_root_keys_from_schema(&request, subgraph_enums, supergraph_schema)?;
    let body = request.subgraph_request.body_mut();
    // Sort variables so the cache key is stable across key orderings.
    body.variables.sort_keys();
    let (key, mut invalidation_keys) = extract_cache_key_root(
        &name,
        entity_type_opt,
        &request.query_hash,
        body,
        &request.context,
        &request.authorization,
        is_known_private,
        private_id,
    );
    invalidation_keys.extend(invalidation_cache_keys);
    Span::current().record("cache.key", key.clone());
    // `no-cache` on the request: skip the lookup but keep the key for storing.
    if cache_control.is_some_and(|c| c.is_no_cache()) {
        return Ok(ControlFlow::Continue((request, key, invalidation_keys)));
    }
    match cache.fetch(&key, &request.subgraph_name).await {
        Ok(value) => {
            if value.control.can_use() {
                let control = value.control.clone();
                save_original_cache_control(request.id.clone(), &request.context, control.clone());
                update_cache_control(&request.context, &control);
                if debug {
                    // Root field names of the first operation, for the debugger.
                    let root_operation_fields: Vec<String> = request
                        .executable_document
                        .as_ref()
                        .and_then(|executable_document| {
                            Some(
                                executable_document
                                    .operations
                                    .iter()
                                    .next()?
                                    .root_fields(executable_document)
                                    .map(|f| f.name.to_string())
                                    .collect(),
                            )
                        })
                        .unwrap_or_default();
                    let cache_key_context = CacheKeyContext {
                        key: value.key.clone(),
                        hashed_private_id: private_id.map(ToString::to_string),
                        invalidation_keys: value
                            .cache_tags
                            .map(external_invalidation_keys)
                            .unwrap_or_default(),
                        kind: CacheEntryKind::RootFields {
                            root_fields: root_operation_fields,
                        },
                        subgraph_name: request.subgraph_name.clone(),
                        subgraph_request: request.subgraph_request.body().clone(),
                        source: CacheKeySource::Cache,
                        cache_control: value.control.clone(),
                        data: serde_json_bytes::json!({"data": value.data.clone()}),
                        warnings: Vec::new(),
                        should_store: false,
                    }
                    .update_metadata();
                    add_cache_key_to_context(&request.context, cache_key_context)?;
                }
                Span::current().set_span_dyn_attribute(
                    opentelemetry::Key::new("cache.status"),
                    opentelemetry::Value::String("hit".into()),
                );
                // Build the response directly from the cached entry.
                let mut response = subgraph::Response::builder()
                    .data(value.data)
                    .extensions(Object::new())
                    .id(request.id)
                    .context(request.context)
                    .subgraph_name(request.subgraph_name.clone())
                    .build();
                value.control.to_headers(response.response.headers_mut())?;
                Ok(ControlFlow::Break(response))
            } else {
                // Entry exists but is not usable: treat as a miss.
                Span::current().set_span_dyn_attribute(
                    opentelemetry::Key::new("cache.status"),
                    opentelemetry::Value::String("miss".into()),
                );
                Ok(ControlFlow::Continue((request, key, invalidation_keys)))
            }
        }
        Err(err) => {
            let span = Span::current();
            // "Row not found" is an expected miss; anything else is an error.
            if !err.is_row_not_found() {
                span.mark_as_error(format!("cannot get cache entry: {err}"));
            }
            span.set_span_dyn_attribute(
                opentelemetry::Key::new("cache.status"),
                opentelemetry::Value::String("miss".into()),
            );
            Ok(ControlFlow::Continue((request, key, invalidation_keys)))
        }
    }
}
/// Collect invalidation cache tags for the root fields of a subgraph request.
///
/// The supergraph schema carries cache-tag declarations as `join__directive`
/// applications on `Query` fields; each declaration has a `format` string
/// template. This function "executes" the request's operation against the
/// schema with a resolver that never produces data — it only visits each
/// selected root field, finds the cache-tag directives that apply to the
/// current subgraph, and interpolates their templates with the field's
/// arguments (exposed to the template as `$args`).
///
/// Returns the set of interpolated invalidation keys, or an error when the
/// request has no executable document, the schema lookup fails, or a template
/// fails to interpolate.
fn get_invalidation_root_keys_from_schema(
    request: &subgraph::Request,
    subgraph_enums: &HashMap<String, String>,
    supergraph_schema: Arc<Valid<Schema>>,
) -> Result<HashSet<String>, anyhow::Error> {
    // Resolver state: accumulates keys (or the first error) while the
    // operation is walked field by field.
    struct Root<'a> {
        // Subgraph targeted by this request.
        subgraph_name: &'a str,
        // Maps `join__Graph` enum values to subgraph names.
        subgraph_enums: &'a HashMap<String, String>,
        // The supergraph's root query object type.
        query_object_type: &'a ObjectType,
        // Accumulated keys, or the first error encountered during the walk.
        result: RefCell<Result<HashSet<String>, anyhow::Error>>,
    }
    impl resolvers::ObjectValue for Root<'_> {
        fn type_name(&self) -> &str {
            "Query"
        }
        fn resolve_field<'a>(
            &'a self,
            info: &'a resolvers::ResolveInfo<'a>,
        ) -> Result<resolvers::ResolvedValue<'a>, resolvers::FieldError> {
            let mut result = self.result.borrow_mut();
            // Once an error is recorded, skip all remaining fields.
            let Ok(keys) = &mut *result else {
                return Ok(resolvers::ResolvedValue::SkipForPartialExecution);
            };
            let Some(field_def) = self.query_object_type.fields.get(info.field_name()) else {
                *result = Err(FetchError::MalformedRequest {
                    reason: "cannot get the field definition from supergraph schema".to_string(),
                }
                .into());
                return Ok(resolvers::ResolvedValue::SkipForPartialExecution);
            };
            // Cache-tag templates declared on this field for the current subgraph.
            let templates = field_def
                .directives
                .get_all("join__directive")
                .filter_map(|dir| {
                    let name = dir.argument_by_name("name", info.schema()).ok()?;
                    if name.as_str()? != CACHE_TAG_DIRECTIVE_NAME {
                        return None;
                    }
                    // The directive lists the graphs it applies to; only keep
                    // it when one of them resolves to our subgraph name.
                    let is_current_subgraph =
                        dir.argument_by_name("graphs", info.schema())
                            .ok()
                            .and_then(|f| {
                                Some(f.as_list()?.iter().filter_map(|graph| graph.as_enum()).any(
                                    |g| {
                                        self.subgraph_enums.get(g.as_str()).map(|s| s.as_str())
                                            == Some(self.subgraph_name)
                                    },
                                ))
                            })
                            .unwrap_or_default();
                    if !is_current_subgraph {
                        return None;
                    }
                    // Extract the `format` template from the directive args.
                    let mut format = None;
                    for (field_name, value) in dir
                        .argument_by_name("args", info.schema())
                        .ok()?
                        .as_object()?
                    {
                        if field_name.as_str() == "format" {
                            format = value
                                .as_str()
                                .and_then(|v| v.parse::<StringTemplate>().ok())
                        }
                    }
                    format
                });
            // Expose the field's arguments to the template as `$args`.
            let mut vars = IndexMap::default();
            vars.insert("$args".to_string(), Value::Object(info.arguments().clone()));
            for template in templates {
                match template.interpolate(&vars) {
                    Ok((key, _)) => {
                        keys.insert(key);
                    }
                    Err(e) => {
                        // Record the first interpolation error and stop.
                        *result = Err(e.into());
                        break;
                    }
                }
            }
            Ok(resolvers::ResolvedValue::SkipForPartialExecution)
        }
    }
    let executable_document =
        request
            .executable_document
            .as_ref()
            .ok_or_else(|| FetchError::MalformedRequest {
                reason: "cannot get the executable document for subgraph request".to_string(),
            })?;
    let root_query_type = supergraph_schema
        .root_operation(apollo_compiler::ast::OperationType::Query)
        .ok_or_else(|| FetchError::MalformedRequest {
            reason: "cannot get the root operation from supergraph schema".to_string(),
        })?;
    let query_object_type = supergraph_schema
        .get_object(root_query_type.as_str())
        .ok_or_else(|| FetchError::MalformedRequest {
            reason: "cannot get the root query type from supergraph schema".to_string(),
        })?;
    let root = Root {
        subgraph_name: &request.subgraph_name,
        subgraph_enums,
        query_object_type,
        result: RefCell::new(Ok(HashSet::new())),
    };
    let subgraph_request = request.subgraph_request.body();
    // Walk the operation; the resolver above does all the collecting.
    resolvers::Execution::new(&supergraph_schema, executable_document)
        .operation_name(subgraph_request.operation_name.as_deref())
        .unwrap()
        .raw_variable_values(&subgraph_request.variables)
        .execute_sync(&root)
        .map_err(|e| anyhow::Error::msg(e.message().to_string()))?;
    root.result.into_inner()
}
/// Outcome of an entity cache lookup: per-representation intermediate results
/// plus the merged cache-control of all cache hits (if any).
#[derive(Default)]
struct ResponseCacheResults(Vec<IntermediateResult>, Option<CacheControl>);
/// Look up each entity representation in the cache before calling the subgraph.
///
/// Returns `Break(response)` when every entity was served from cache (the
/// subgraph call is skipped entirely), or `Continue((request, results))` with
/// the request's `representations` variable trimmed down to only the entities
/// that still need to be fetched.
#[allow(clippy::too_many_arguments)]
async fn cache_lookup_entities(
    name: String,
    supergraph_schema: Arc<Valid<Schema>>,
    subgraph_enums: &HashMap<String, String>,
    cache: Storage,
    is_known_private: bool,
    private_id: Option<&str>,
    mut request: subgraph::Request,
    debug: bool,
    cache_control: Option<&CacheControl>,
) -> Result<ControlFlow<subgraph::Response, (subgraph::Request, ResponseCacheResults)>, BoxError> {
    let is_no_cache = cache_control.is_some_and(|c| c.is_no_cache());
    // One cache key (plus invalidation keys) per representation in the request.
    let cache_metadata = extract_cache_keys(
        &name,
        supergraph_schema,
        subgraph_enums,
        &mut request,
        is_known_private,
        private_id,
        debug,
    )?;
    let keys_len = cache_metadata.len();
    let cache_keys = cache_metadata
        .iter()
        .map(|k| k.cache_key.as_str())
        .collect::<Vec<&str>>();
    Span::current().set_span_dyn_attribute(
        "cache.keys".into(),
        opentelemetry::Value::Array(Array::String(
            cache_keys
                .iter()
                .map(|ck| StringValue::from(ck.to_string()))
                .collect(),
        )),
    );
    // With `no-cache` we skip the lookup but still continue so the keys can be
    // used when storing the fresh subgraph response.
    let cache_result: Vec<Option<CacheEntry>> = if is_no_cache {
        vec![None; keys_len]
    } else {
        match cache.fetch_multiple(&cache_keys, &name).await {
            Ok(res) => res
                .into_iter()
                .map(|v| match v {
                    // Only entries whose cache-control still allows use count as hits.
                    Some(v) if v.control.can_use() => Some(v),
                    _ => None,
                })
                .collect(),
            Err(err) => {
                if !err.is_row_not_found() {
                    let span = Span::current();
                    span.mark_as_error(format!("cannot get cache entry: {err}"));
                }
                // Treat storage failures as a full miss.
                vec![None; keys_len]
            }
        }
    };
    let body = request.subgraph_request.body_mut();
    let representations = body
        .variables
        .get_mut(REPRESENTATIONS)
        .and_then(|value| value.as_array_mut())
        .expect("we already checked that representations exist");
    // Remove the representations served from cache from the outgoing request.
    let (new_representations, cache_result, cache_control) = filter_representations(
        &name,
        &request.id,
        representations,
        cache_metadata,
        cache_result,
        &request.context,
        !is_no_cache,
    )?;
    if !new_representations.is_empty() {
        // Partial (or full) miss: forward only the remaining representations.
        body.variables
            .insert(REPRESENTATIONS, new_representations.into());
        let cache_status = if cache_result.is_empty() {
            opentelemetry::Value::String("miss".into())
        } else {
            opentelemetry::Value::String("partial_hit".into())
        };
        Span::current()
            .set_span_dyn_attribute(opentelemetry::Key::new("cache.status"), cache_status);
        Ok(ControlFlow::Continue((
            request,
            ResponseCacheResults(cache_result, cache_control),
        )))
    } else {
        // Full hit: every entity came from cache, build the response locally.
        if debug {
            // Record one debug context entry per cached entity.
            let debug_cache_keys_ctx = cache_result.iter().filter_map(|ir| {
                ir.cache_entry.as_ref().map(|cache_entry| {
                    CacheKeyContext {
                        key: ir.key.clone(),
                        hashed_private_id: private_id.map(ToString::to_string),
                        invalidation_keys: cache_entry
                            .cache_tags
                            .clone()
                            .map(external_invalidation_keys)
                            .unwrap_or_default(),
                        kind: CacheEntryKind::Entity {
                            typename: ir.typename.clone(),
                            entity_key: ir.entity_key.clone().unwrap_or_default(),
                        },
                        subgraph_name: name.clone(),
                        subgraph_request: request.subgraph_request.body().clone(),
                        source: CacheKeySource::Cache,
                        cache_control: cache_entry.control.clone(),
                        data: serde_json_bytes::json!({"data": cache_entry.data.clone()}),
                        warnings: Vec::new(),
                        should_store: false,
                    }
                    .update_metadata()
                })
            });
            add_cache_keys_to_context(&request.context, debug_cache_keys_ctx)?;
        }
        Span::current().set_span_dyn_attribute(
            opentelemetry::Key::new("cache.status"),
            opentelemetry::Value::String("hit".into()),
        );
        let entities = cache_result
            .into_iter()
            .filter_map(|res| res.cache_entry)
            .map(|entry| entry.data)
            .collect::<Vec<_>>();
        let mut data = Object::default();
        data.insert(ENTITIES, entities.into());
        let mut response = subgraph::Response::builder()
            .data(data)
            .id(request.id.clone())
            .extensions(Object::new())
            .subgraph_name(request.subgraph_name)
            .context(request.context)
            .build();
        // Reflect the merged cache-control of the hits on the synthetic response.
        cache_control
            .unwrap_or_default()
            .to_headers(response.response.headers_mut())?;
        Ok(ControlFlow::Break(response))
    }
}
/// Fold `cache_control` into the aggregated `CacheControl` stored in the
/// request context extensions, creating the entry when it does not exist yet.
fn update_cache_control(context: &Context, cache_control: &CacheControl) {
    context
        .extensions()
        .with_lock(|lock| match lock.get_mut::<CacheControl>() {
            Some(existing) => *existing = existing.merge(cache_control),
            None => {
                // Merging the value with itself normalizes it before storing.
                lock.insert(cache_control.merge(cache_control));
            }
        })
}
/// Remember the original cache-control for this subgraph request, keyed by
/// request id, in the context extensions' `CacheControls` map.
fn save_original_cache_control(
    req_id: SubgraphRequestId,
    context: &Context,
    cache_control: CacheControl,
) {
    context.extensions().with_lock(|extensions| {
        let controls = extensions.get_or_default_mut::<CacheControls>();
        controls.insert(req_id, cache_control);
    });
}
/// Store a root-level subgraph response in the cache, in the background.
///
/// The entry is only written when the response has data, no errors, and its
/// cache-control allows storing; the TTL comes from the cache-control when
/// present, otherwise from `default_subgraph_ttl`.
async fn cache_store_root_from_response(
    cache: Storage,
    default_subgraph_ttl: Duration,
    response: &subgraph::Response,
    cache_control: CacheControl,
    cache_key: String,
    invalidation_keys: Vec<String>,
    debug: bool,
) -> Result<(), BoxError> {
    let body = response.response.body();
    // Nothing to store without data.
    let Some(data) = body.data.as_ref() else {
        return Ok(());
    };
    // Only cache clean, storable responses.
    if !body.errors.is_empty() || !cache_control.should_store() {
        return Ok(());
    }
    let ttl = cache_control
        .ttl()
        .map(Duration::from_secs)
        .unwrap_or(default_subgraph_ttl);
    let document = Document {
        key: cache_key,
        data: data.clone(),
        control: cache_control,
        invalidation_keys,
        expire: ttl,
        debug,
    };
    let subgraph_name = response.subgraph_name.clone();
    let span = tracing::info_span!("response_cache.store", "kind" = "root", "subgraph.name" = subgraph_name.clone(), "ttl" = ?ttl);
    // Write asynchronously so the response is not delayed by the cache backend.
    tokio::spawn(async move {
        let _ = cache
            .insert(document, &subgraph_name)
            .instrument(span)
            .await;
    });
    Ok(())
}
/// Merge cached entities back into the subgraph response and store the freshly
/// fetched ones.
///
/// When the response carries an `_entities` array, it is interleaved with the
/// cached entries from `result_from_cache` (and the new entries are written to
/// the cache in the background). When the response has no entities at all, the
/// result is assembled purely from the cache plus the subgraph's errors.
#[allow(clippy::too_many_arguments)]
async fn cache_store_entities_from_response(
    cache: Storage,
    default_subgraph_ttl: Duration,
    response: &mut subgraph::Response,
    cache_control: CacheControl,
    mut result_from_cache: Vec<IntermediateResult>,
    is_known_private: bool,
    private_id: Option<String>,
    subgraph_request: Option<graphql::Request>,
) -> Result<(), BoxError> {
    // Take the body data so it can be rebuilt and re-inserted below.
    let mut data = response.response.body_mut().data.take();
    if let Some(mut entities) = data
        .as_mut()
        .and_then(|v| v.as_object_mut())
        .and_then(|o| o.remove(ENTITIES))
    {
        // Private responses are only cacheable when we have a private id.
        let should_cache_private = !cache_control.private() || private_id.is_some();
        // If we just learned the response is private, keys must be re-suffixed
        // with the private id.
        let update_key_private = if !is_known_private && cache_control.private() {
            private_id.clone()
        } else {
            None
        };
        // Optional per-entity cache tags provided by the subgraph via a
        // response extension.
        let per_entity_surrogate_keys = response
            .response
            .body()
            .extensions
            .get(GRAPHQL_RESPONSE_EXTENSION_ENTITY_CACHE_TAGS)
            .and_then(|value| value.as_array())
            .map(|vec| vec.as_slice())
            .unwrap_or_default();
        let (new_entities, new_errors) = insert_entities_in_result(
            entities
                .as_array_mut()
                .ok_or_else(|| FetchError::MalformedResponse {
                    reason: "expected an array of entities".to_string(),
                })?,
            &response.response.body().errors,
            cache,
            default_subgraph_ttl,
            cache_control,
            &mut result_from_cache,
            private_id,
            update_key_private,
            should_cache_private,
            &response.subgraph_name,
            per_entity_surrogate_keys,
            response.context.clone(),
            subgraph_request,
        )
        .await?;
        // Put the merged entity list back into the response body.
        data.as_mut()
            .and_then(|v| v.as_object_mut())
            .map(|o| o.insert(ENTITIES, new_entities.into()));
        response.response.body_mut().data = data;
        response.response.body_mut().errors = new_errors;
    } else {
        // No entities in the subgraph response: answer from cache entries and
        // re-path the subgraph errors onto the missing positions.
        let (new_entities, new_errors) =
            assemble_response_from_errors(&response.response.body().errors, &mut result_from_cache);
        let mut data = Object::default();
        data.insert(ENTITIES, new_entities.into());
        response.response.body_mut().data = Some(Value::Object(data));
        response.response.body_mut().errors = new_errors;
    }
    Ok(())
}
/// Build the primary cache key and the baseline invalidation key for a
/// root-level (non-entity) subgraph fetch.
///
/// The private id only participates in the key once the response is known to
/// be private; the invalidation key is the internal per-subgraph/per-type tag.
#[allow(clippy::too_many_arguments)]
fn extract_cache_key_root(
    subgraph_name: &str,
    entity_type_opt: Option<&str>,
    query_hash: &QueryHash,
    body: &graphql::Request,
    context: &Context,
    cache_key: &CacheKeyMetadata,
    is_known_private: bool,
    private_id: Option<&str>,
) -> (String, Vec<String>) {
    // Root fetches default to the `Query` type when no entity type is given.
    let entity_type = entity_type_opt.unwrap_or("Query");
    // Only include the private id once we know the response is private.
    let effective_private_id = is_known_private.then_some(private_id).flatten();
    let primary_key = PrimaryCacheKeyRoot {
        subgraph_name,
        graphql_type: entity_type,
        subgraph_query_hash: query_hash,
        body,
        context,
        auth_cache_key_metadata: cache_key,
        private_id: effective_private_id,
    }
    .hash();
    // Internal tag used to invalidate every entry of this type for this subgraph.
    let type_tag = format!(
        "{INTERNAL_CACHE_TAG_PREFIX}version:{RESPONSE_CACHE_VERSION}:subgraph:{subgraph_name}:type:{entity_type}"
    );
    (primary_key, vec![type_tag])
}
/// Cache key material computed for a single entity representation.
struct CacheMetadata {
    // Primary cache key for this entity.
    cache_key: String,
    // Cache tags that can invalidate this entry.
    invalidation_keys: Vec<String>,
    // Entity key fields, populated in debug mode only (for diagnostics).
    entity_key: Option<serde_json_bytes::Map<ByteString, Value>>,
}
/// Compute one `CacheMetadata` (primary cache key + invalidation keys) per
/// entity representation in the request's `representations` variable.
///
/// Each representation's `__typename` is temporarily removed so it does not
/// participate in the key hash, then re-inserted. Also records the set of
/// requested typenames on the current span and an entity-count histogram.
///
/// # Panics
/// Panics if the `representations` variable is absent — callers must have
/// validated its presence beforehand (see the `expect` below).
#[allow(clippy::too_many_arguments)]
fn extract_cache_keys(
    subgraph_name: &str,
    supergraph_schema: Arc<Valid<Schema>>,
    subgraph_enums: &HashMap<String, String>,
    request: &mut subgraph::Request,
    is_known_private: bool,
    private_id: Option<&str>,
    debug: bool,
) -> Result<Vec<CacheMetadata>, BoxError> {
    let context = &request.context;
    let authorization = &request.authorization;
    let query_hash = hash_query(&request.query_hash);
    let additional_data_hash = hash_additional_data(
        subgraph_name,
        request.subgraph_request.body_mut(),
        context,
        authorization,
    );
    let representations = request
        .subgraph_request
        .body_mut()
        .variables
        .get_mut(REPRESENTATIONS)
        .and_then(|value| value.as_array_mut())
        .expect("we already checked that representations exist");
    let mut res = Vec::with_capacity(representations.len());
    let entities = representations.len() as u64;
    let mut typenames = HashSet::new();
    for representation in representations {
        let representation =
            representation
                .as_object_mut()
                .ok_or_else(|| FetchError::MalformedRequest {
                    reason: "representation variable should be an array of object".to_string(),
                })?;
        // Remove __typename before hashing; it is re-inserted further down.
        let typename_value =
            representation
                .remove(TYPENAME)
                .ok_or_else(|| FetchError::MalformedRequest {
                    reason: "missing __typename in representation".to_string(),
                })?;
        let typename = typename_value
            .as_str()
            .ok_or_else(|| FetchError::MalformedRequest {
                reason: "__typename in representation is not a string".to_string(),
            })?;
        typenames.insert(typename.to_string());
        // In debug mode, also capture the entity's key fields for diagnostics.
        let representation_entity_key = if debug {
            let selection_set = find_matching_key_field_set(
                representation,
                typename,
                subgraph_name,
                &supergraph_schema,
                subgraph_enums,
            )?;
            get_entity_key_from_selection_set(representation, &selection_set).into()
        } else {
            None
        };
        let key = PrimaryCacheKeyEntity {
            subgraph_name,
            entity_type: typename,
            representation,
            subgraph_query_hash: &query_hash,
            additional_data_hash: &additional_data_hash,
            // The private id only participates once known to be private.
            private_id: if is_known_private { private_id } else { None },
        }
        .hash();
        // Internal per-subgraph/per-type tag, always present.
        let mut invalidation_keys = vec![format!(
            "{INTERNAL_CACHE_TAG_PREFIX}version:{RESPONSE_CACHE_VERSION}:subgraph:{subgraph_name}:type:{typename}"
        )];
        // Schema-declared cache tags interpolated with this entity's key fields.
        let invalidation_cache_keys = get_invalidation_entity_keys_from_schema(
            &supergraph_schema,
            subgraph_name,
            subgraph_enums,
            typename,
            representation,
        )?;
        // Restore __typename now that hashing/interpolation is done.
        representation.insert(TYPENAME, typename_value);
        invalidation_keys.extend(invalidation_cache_keys);
        let cache_key_metadata = CacheMetadata {
            cache_key: key,
            invalidation_keys,
            entity_key: representation_entity_key,
        };
        res.push(cache_key_metadata);
    }
    Span::current().set_span_dyn_attribute(
        Key::from_static_str("graphql.types"),
        opentelemetry::Value::Array(
            typenames
                .into_iter()
                .map(StringValue::from)
                .collect::<Vec<StringValue>>()
                .into(),
        ),
    );
    u64_histogram_with_unit!(
        "apollo.router.operations.response_cache.fetch.entity",
        "Number of entities per subgraph fetch node",
        "{entity}",
        entities,
        "subgraph.name" = subgraph_name.to_string()
    );
    Ok(res)
}
/// Compute the invalidation cache tags declared on an entity type (or on the
/// interfaces it implements) for the current subgraph.
///
/// Cache tags are carried as `join__directive` applications on the type in the
/// supergraph schema; each selected directive's `format` template is
/// interpolated with the entity's key fields (exposed as `$key`).
fn get_invalidation_entity_keys_from_schema(
    supergraph_schema: &Arc<Valid<Schema>>,
    subgraph_name: &str,
    subgraph_enums: &HashMap<String, String>,
    typename: &str,
    representations: &serde_json_bytes::Map<ByteString, Value>,
) -> Result<HashSet<String>, anyhow::Error> {
    // Keep only cache-tag directives that list the current subgraph in their
    // `graphs` argument.
    let filter_dir = |dir: &apollo_compiler::ast::Directive| {
        let Ok(name) = dir.argument_by_name("name", supergraph_schema) else {
            return false;
        };
        let Some(name) = name.as_str() else {
            return false;
        };
        if *name != *CACHE_TAG_DIRECTIVE_NAME {
            return false;
        }
        dir.argument_by_name("graphs", supergraph_schema)
            .ok()
            .and_then(|f| {
                Some(
                    f.as_list()?
                        .iter()
                        .filter_map(|graph| graph.as_enum())
                        .any(|g| {
                            subgraph_enums.get(g.as_str()).map(|s| s.as_str())
                                == Some(subgraph_name)
                        }),
                )
            })
            .unwrap_or_default()
    };
    // `typename` may name an interface directly, or an object type whose
    // implemented interfaces also contribute directives.
    let all_directives: Vec<_> = match supergraph_schema.get_interface(typename) {
        Some(iface_type) => {
            iface_type
                .directives
                .get_all("join__directive")
                .filter(|dir| filter_dir(dir))
                .cloned()
                .collect()
        }
        None => {
            let obj_type = supergraph_schema.get_object(typename).ok_or_else(|| {
                FetchError::MalformedRequest {
                    reason: format!("can't find corresponding type for __typename {typename:?}"),
                }
            })?;
            let obj_directives: Vec<_> = obj_type
                .directives
                .get_all("join__directive")
                .filter(|dir| filter_dir(dir))
                .cloned()
                .collect();
            // Directives inherited from all implemented interfaces.
            let iface_directives: Vec<_> = obj_type
                .implements_interfaces
                .iter()
                .flat_map(|iface_name| {
                    supergraph_schema
                        .get_interface(iface_name)
                        .iter()
                        .flat_map(|iface| iface.directives.get_all("join__directive").cloned())
                        .collect::<Vec<_>>()
                })
                .filter(|dir| filter_dir(dir))
                .collect();
            obj_directives.into_iter().chain(iface_directives).collect()
        }
    };
    // Extract each directive's `format` template from its `args` object.
    let cache_keys = all_directives.into_iter().filter_map(|dir| {
        dir.argument_by_name("args", supergraph_schema)
            .ok()?
            .as_object()?
            .iter()
            .find_map(|(field_name, value)| {
                if field_name.as_str() == "format" {
                    value.as_str()?.parse::<StringTemplate>().ok()
                } else {
                    None
                }
            })
    });
    // Expose the entity key fields to the template as `$key`.
    let mut vars = IndexMap::default();
    vars.insert("$key".to_string(), Value::Object(representations.clone()));
    let invalidation_cache_keys = cache_keys
        .map(|ck| ck.interpolate(&vars).map(|(res, _)| res))
        .collect::<Result<_, _>>()?;
    Ok(invalidation_cache_keys)
}
/// Find the `@key` field set declared for `typename` in `subgraph_name` whose
/// shape matches the given representation, and return its selection set.
///
/// # Errors
/// Returns `MalformedRequest` when no declared key field set matches.
pub(in crate::plugins) fn find_matching_key_field_set(
    representation: &serde_json_bytes::Map<ByteString, Value>,
    typename: &str,
    subgraph_name: &str,
    supergraph_schema: &Valid<Schema>,
    subgraph_enums: &HashMap<String, String>,
) -> Result<apollo_compiler::executable::SelectionSet, FetchError> {
    let candidates =
        collect_key_field_sets(typename, subgraph_name, supergraph_schema, subgraph_enums)?;
    for field_set in candidates {
        if matches_selection_set(representation, &field_set.selection_set) {
            return Ok(field_set.selection_set);
        }
    }
    tracing::trace!("representation does not match any key field set for typename {typename} in subgraph {subgraph_name}");
    Err(FetchError::MalformedRequest {
        reason: format!("unexpected critical internal error for typename {typename} in subgraph {subgraph_name}"),
    })
}
/// Iterate over the `@key` field sets declared for `typename` in
/// `subgraph_name`, parsed from the supergraph's `join__type` directives.
///
/// Directives whose `graph` argument resolves to a different subgraph, or
/// whose `key` argument fails to parse, are silently skipped.
///
/// # Errors
/// Returns `MalformedRequest` when `typename` is not in the schema at all.
fn collect_key_field_sets(
    typename: &str,
    subgraph_name: &str,
    supergraph_schema: &Valid<Schema>,
    subgraph_enums: &HashMap<String, String>,
) -> Result<impl Iterator<Item = apollo_compiler::executable::FieldSet>, FetchError> {
    Ok(supergraph_schema
        .types
        .get(typename)
        .ok_or_else(|| FetchError::MalformedRequest {
            reason: format!("unknown typename {typename:?} in representations"),
        })?
        .directives()
        .get_all("join__type")
        .filter_map(move |directive| {
            // Map the join__Graph enum value back to a subgraph name.
            let schema_subgraph_name = directive
                .specified_argument_by_name("graph")
                .and_then(|arg| arg.as_enum())
                .and_then(|arg| subgraph_enums.get(arg.as_str()))?;
            if schema_subgraph_name == subgraph_name {
                let mut parser = Parser::new();
                directive
                    .specified_argument_by_name("key")
                    .and_then(|arg| arg.as_str())
                    .and_then(|arg| {
                        parser
                            .parse_field_set(
                                supergraph_schema,
                                NamedType::new(typename).ok()?,
                                arg,
                                "entity_caching.graphql",
                            )
                            .ok()
                    })
            } else {
                None
            }
        }))
}
/// Check whether a representation object structurally matches a key selection
/// set: every required (non-null) field is present, leaf fields hold scalars,
/// and composite fields recursively match their sub-selections.
pub(in crate::plugins) fn matches_selection_set(
    representation: &serde_json_bytes::Map<ByteString, Value>,
    selection_set: &apollo_compiler::executable::SelectionSet,
) -> bool {
    for field in selection_set.root_fields(&Default::default()) {
        let Some(value) = representation.get(field.name.as_str()) else {
            // Absent field: only acceptable when the schema allows null.
            if field.definition.ty.is_non_null() {
                return false;
            } else {
                continue;
            }
        };
        if field.selection_set.is_empty() {
            // Leaf field: the value must not contain any object.
            if !is_scalar_or_array_of_scalar(value) {
                return false;
            }
            continue;
        }
        let result = match value {
            Value::Object(obj) => {
                matches_selection_set(obj, &field.selection_set)
            }
            Value::Array(arr) => {
                // Null list items are ignored when the item type is nullable.
                let list_item_is_nullable = !field.definition.ty.item_type().is_non_null();
                let exclude_value = |value: &&Value| list_item_is_nullable && value.is_null();
                let arr = arr.iter().filter(|value| !exclude_value(value));
                matches_array_of_objects(arr, &field.selection_set)
            }
            Value::Null => {
                // NOTE(review): a null composite field makes the WHOLE
                // representation match immediately, skipping all remaining
                // fields — confirm this short-circuit is intended rather than
                // `continue`-ing to the next field.
                return true;
            }
            _other => {
                // Scalar where a sub-selection was expected: no match.
                false
            }
        };
        if !result {
            return false;
        }
    }
    true
}
/// Returns `true` when `value` contains no object at any depth: scalars and
/// (arbitrarily nested) arrays of scalars qualify, objects never do.
fn is_scalar_or_array_of_scalar(value: &Value) -> bool {
    match value {
        Value::Array(items) => items.iter().all(is_scalar_or_array_of_scalar),
        Value::Object(_) => false,
        _ => true,
    }
}
/// Returns `true` when every element of `arr` (recursing through nested
/// arrays) matches `selection_set`; scalar elements never match.
fn matches_array_of_objects<'a, I: Iterator<Item = &'a Value>>(
    mut arr: I,
    selection_set: &apollo_compiler::executable::SelectionSet,
) -> bool {
    arr.all(|item| match item {
        Value::Object(fields) => matches_selection_set(fields, selection_set),
        Value::Array(nested) => matches_array_of_objects(nested.iter(), selection_set),
        _ => false,
    })
}
/// Project a representation down to just the fields named by a key selection
/// set, recursing through nested objects and arrays.
///
/// Fields are emitted in name-sorted order so the result is deterministic;
/// fields present in the selection set but absent from the representation are
/// skipped.
fn get_entity_key_from_selection_set(
    representation: &serde_json_bytes::Map<ByteString, Value>,
    selection_set: &apollo_compiler::executable::SelectionSet,
) -> serde_json_bytes::Map<ByteString, Value> {
    // Copy the selected fields of `fields` into `state`.
    fn traverse_object(
        state: &mut serde_json_bytes::Map<ByteString, Value>,
        fields: &serde_json_bytes::Map<ByteString, Value>,
        selection_set: &apollo_compiler::executable::SelectionSet,
    ) {
        let default_document = Default::default();
        // Sort by field name for a stable, order-independent projection.
        let sorted_selections = selection_set
            .root_fields(&default_document)
            .sorted_by(|a, b| a.name.cmp(&b.name));
        for field in sorted_selections {
            let key = field.name.as_str();
            let Some(val) = fields.get(key) else {
                continue;
            };
            match val {
                serde_json_bytes::Value::Object(obj) => {
                    let mut obj_state = serde_json_bytes::Map::new();
                    traverse_object(&mut obj_state, obj, &field.selection_set);
                    state.insert(ByteString::from(key), Value::Object(obj_state));
                }
                Value::Array(arr) => {
                    let mut arr_state = Vec::new();
                    traverse_array(&mut arr_state, arr, &field.selection_set);
                    state.insert(ByteString::from(key), Value::Array(arr_state));
                }
                val => {
                    // Scalar leaf: copy as-is.
                    state.insert(ByteString::from(key), val.clone());
                }
            }
        }
    }
    // Project each array element with the same selection set.
    fn traverse_array(
        state: &mut Vec<Value>,
        items: &[Value],
        selection_set: &apollo_compiler::executable::SelectionSet,
    ) {
        items.iter().for_each(|v| {
            match v {
                serde_json_bytes::Value::Object(obj) => {
                    let mut obj_state = serde_json_bytes::Map::new();
                    traverse_object(&mut obj_state, obj, selection_set);
                    state.push(Value::Object(obj_state));
                }
                Value::Array(arr) => {
                    let mut arr_state = Vec::new();
                    traverse_array(&mut arr_state, arr, selection_set);
                    state.push(Value::Array(arr_state));
                }
                _ => {
                    state.push(v.clone());
                }
            }
        });
    }
    let mut state = serde_json_bytes::Map::new();
    traverse_object(&mut state, representation, selection_set);
    state
}
/// Per-representation state carried from the cache lookup phase to response
/// assembly.
struct IntermediateResult {
    // Primary cache key for this entity.
    key: String,
    // Cache tags that can invalidate this entry.
    invalidation_keys: Vec<String>,
    // The entity's `__typename` ("-" when it was not a string).
    typename: String,
    // Entity key fields, populated in debug mode only.
    entity_key: Option<serde_json_bytes::Map<ByteString, Value>>,
    // The cached entry, when the lookup was a usable hit.
    cache_entry: Option<CacheEntry>,
}
/// Split representations into cache hits and misses.
///
/// Drains `representations` and `cache_result` in lockstep with `keys`:
/// misses are collected into the returned list of representations to forward
/// to the subgraph, and every entity (hit or miss) produces an
/// `IntermediateResult` preserving the original order. Also merges the
/// cache-control of all hits (both the updated and the original flavor) and
/// records per-typename hit/miss metrics.
#[allow(clippy::type_complexity)]
fn filter_representations(
    subgraph_name: &str,
    subgraph_req_id: &SubgraphRequestId,
    representations: &mut Vec<Value>,
    keys: Vec<CacheMetadata>,
    mut cache_result: Vec<Option<CacheEntry>>,
    context: &Context,
    record_metrics: bool,
) -> Result<(Vec<Value>, Vec<IntermediateResult>, Option<CacheControl>), BoxError> {
    let mut new_representations: Vec<Value> = Vec::new();
    let mut result = Vec::new();
    let mut cache_hit: HashMap<String, CacheHitMiss> = HashMap::new();
    // Merged cache-control across hits, with and without age updates.
    let mut cache_control = None;
    let mut non_updated_cache_control = None;
    for (
        (
            mut representation,
            CacheMetadata {
                cache_key: key,
                invalidation_keys,
                entity_key,
                ..
            },
        ),
        mut cache_entry,
    ) in representations
        .drain(..)
        .zip(keys)
        .zip(cache_result.drain(..))
    {
        // __typename is removed here and restored only for misses below, so
        // the forwarded representations keep their original shape.
        let opt_type = representation
            .as_object_mut()
            .and_then(|o| o.remove(TYPENAME))
            .ok_or_else(|| FetchError::MalformedRequest {
                reason: "missing __typename in representation".to_string(),
            })?;
        let typename = opt_type.as_str().unwrap_or("-").to_string();
        // Entries whose cache-control no longer allows use are demoted to misses.
        if let Some(false) = cache_entry.as_ref().map(|c| c.control.can_use()) {
            cache_entry = None;
        }
        match cache_entry.as_ref() {
            None => {
                cache_hit.entry(typename.clone()).or_default().miss += 1;
                representation
                    .as_object_mut()
                    .map(|o| o.insert(TYPENAME, opt_type));
                new_representations.push(representation);
            }
            Some(entry) => {
                cache_hit.entry(typename.clone()).or_default().hit += 1;
                match cache_control.as_mut() {
                    None => cache_control = Some(entry.control.clone()),
                    Some(c) => *c = c.merge(&entry.control),
                }
                match non_updated_cache_control.as_mut() {
                    None => non_updated_cache_control = Some(entry.control.clone()),
                    Some(c) => *c = c.merge_without_update(&entry.control),
                }
            }
        }
        result.push(IntermediateResult {
            key,
            invalidation_keys,
            typename,
            cache_entry,
            entity_key,
        });
    }
    if let Some(non_updated_cache_control) = non_updated_cache_control {
        save_original_cache_control(subgraph_req_id.clone(), context, non_updated_cache_control);
    }
    if record_metrics {
        let _ = context.insert(
            CacheMetricContextKey::new(subgraph_name.to_string()),
            CacheSubgraph(cache_hit),
        );
    }
    Ok((new_representations, result, cache_control))
}
/// Interleave freshly fetched entities with cached ones and store the new
/// entries.
///
/// `result` preserves the original representation order; cached entries keep
/// their data while misses consume the next element of `entities` (which only
/// contains the entities the subgraph actually fetched). Errors pointing at a
/// fetched entity are re-indexed to its position in the merged list. New
/// entries that are error-free and storable are written to the cache in a
/// background batch.
#[allow(clippy::too_many_arguments)]
async fn insert_entities_in_result(
    entities: &mut Vec<Value>,
    errors: &[Error],
    cache: Storage,
    default_subgraph_ttl: Duration,
    cache_control: CacheControl,
    result: &mut Vec<IntermediateResult>,
    private_id_for_dbg: Option<String>,
    update_key_private: Option<String>,
    should_cache_private: bool,
    subgraph_name: &str,
    per_entity_surrogate_keys: &[Value],
    context: Context,
    subgraph_request: Option<graphql::Request>,
) -> Result<(Vec<Value>, Vec<Error>), BoxError> {
    // Debug mode is signalled by the presence of the original request.
    let debug = subgraph_request.is_some();
    let ttl = cache_control
        .ttl()
        .map(Duration::from_secs)
        .unwrap_or(default_subgraph_ttl);
    let mut new_entities = Vec::new();
    let mut new_errors = Vec::new();
    let mut inserted_types: HashMap<String, usize> = HashMap::new();
    let mut to_insert: Vec<_> = Vec::new();
    let mut debug_ctx_entries = Vec::new();
    // Fetched entities and per-entity tags advance only on cache misses.
    let mut entities_it = entities.drain(..).enumerate();
    let mut per_entity_surrogate_keys_it = per_entity_surrogate_keys.iter();
    for (
        new_entity_idx,
        IntermediateResult {
            mut key,
            mut invalidation_keys,
            typename,
            cache_entry,
            entity_key,
        },
    ) in result.drain(..).enumerate()
    {
        match cache_entry {
            Some(v) => {
                // Cache hit: reuse the cached data unchanged.
                new_entities.push(v.data);
            }
            None => {
                // Cache miss: take the next fetched entity.
                let (entity_idx, value) =
                    entities_it
                        .next()
                        .ok_or_else(|| FetchError::MalformedResponse {
                            reason: "invalid number of entities".to_string(),
                        })?;
                let specific_surrogate_keys = per_entity_surrogate_keys_it.next();
                *inserted_types.entry(typename.clone()).or_default() += 1;
                // Suffix the key with the private id when privacy was just discovered.
                if let Some(ref id) = update_key_private {
                    key = format!("{key}:{id}");
                }
                // Re-index errors from the subgraph's entity position to the
                // merged list's position.
                let mut has_errors = false;
                for error in errors.iter().filter(|e| {
                    e.path
                        .as_ref()
                        .map(|path| {
                            path.starts_with(&Path(vec![
                                PathElement::Key(ENTITIES.to_string(), None),
                                PathElement::Index(entity_idx),
                            ]))
                        })
                        .unwrap_or(false)
                }) {
                    let mut e = error.clone();
                    if let Some(path) = e.path.as_mut() {
                        path.0[1] = PathElement::Index(new_entity_idx);
                    }
                    new_errors.push(e);
                    has_errors = true;
                }
                // Merge the subgraph-provided per-entity cache tags.
                if let Some(Value::Array(keys)) = specific_surrogate_keys {
                    invalidation_keys
                        .extend(keys.iter().filter_map(|v| v.as_str()).map(|s| s.to_owned()));
                }
                if let Some(subgraph_request) = &subgraph_request {
                    debug_ctx_entries.push(
                        CacheKeyContext {
                            key: key.clone(),
                            hashed_private_id: private_id_for_dbg.clone(),
                            invalidation_keys: external_invalidation_keys(
                                invalidation_keys.clone(),
                            ),
                            kind: CacheEntryKind::Entity {
                                typename: typename.clone(),
                                entity_key: entity_key.clone().unwrap_or_default(),
                            },
                            subgraph_name: subgraph_name.to_string(),
                            subgraph_request: subgraph_request.clone(),
                            source: CacheKeySource::Subgraph,
                            cache_control: cache_control.clone(),
                            data: serde_json_bytes::json!({"data": value.clone()}),
                            warnings: Vec::new(),
                            should_store: false,
                        }
                        .update_metadata(),
                    );
                }
                // Only store clean, storable, privacy-compatible entities.
                if !has_errors && cache_control.should_store() && should_cache_private {
                    to_insert.push(Document {
                        control: cache_control.clone(),
                        data: value.clone(),
                        key,
                        invalidation_keys,
                        expire: ttl,
                        debug,
                    });
                }
                new_entities.push(value);
            }
        }
    }
    if !debug_ctx_entries.is_empty() {
        add_cache_keys_to_context(&context, debug_ctx_entries.into_iter())?;
    }
    if !to_insert.is_empty() {
        // Batch-write in the background so the response is not delayed.
        let batch_size = to_insert.len();
        let span = tracing::info_span!("response_cache.store", "kind" = "entity", "subgraph.name" = subgraph_name, "ttl" = ?ttl, "batch.size" = %batch_size);
        let subgraph_name = subgraph_name.to_string();
        tokio::spawn(async move {
            let _ = cache
                .insert_in_batch(to_insert, &subgraph_name)
                .instrument(span)
                .await;
        });
    }
    for (ty, nb) in inserted_types {
        tracing::event!(Level::TRACE, entity_type = ty.as_str(), cache_insert = nb,);
    }
    Ok((new_entities, new_errors))
}
/// Keep only user-facing invalidation keys, dropping router-internal cache
/// tags (those prefixed with `INTERNAL_CACHE_TAG_PREFIX`).
fn external_invalidation_keys<I: IntoIterator<Item = String>>(invalidation_keys: I) -> Vec<String> {
    let mut external = Vec::new();
    for key in invalidation_keys {
        if !key.starts_with(INTERNAL_CACHE_TAG_PREFIX) {
            external.push(key);
        }
    }
    external
}
/// Rebuild the `_entities` array purely from cached entries when the subgraph
/// returned no data: cached entities keep their data, missing ones become
/// `null` with every subgraph error re-pathed to that entity's position.
fn assemble_response_from_errors(
    graphql_errors: &[Error],
    result: &mut Vec<IntermediateResult>,
) -> (Vec<Value>, Vec<Error>) {
    let mut entities = Vec::new();
    let mut errors = Vec::new();
    for (idx, intermediate) in result.drain(..).enumerate() {
        if let Some(entry) = intermediate.cache_entry {
            entities.push(entry.data);
            continue;
        }
        // No cached value: emit a null placeholder and point each subgraph
        // error at this entity's slot in the `_entities` list.
        entities.push(Value::Null);
        for mut error in graphql_errors.iter().cloned() {
            error.path = Some(Path(vec![
                PathElement::Key(ENTITIES.to_string(), None),
                PathElement::Index(idx),
            ]));
            errors.push(error);
        }
    }
    (entities, errors)
}
/// Try to connect to Redis once. On failure, either fail startup (when Redis
/// is `required_to_start`) or keep retrying in a detached background task.
async fn connect_or_spawn_reconnection_task(
    config: storage::redis::Config,
    storage: Arc<OnceLock<Storage>>,
    abort_signal: broadcast::Receiver<()>,
) -> Result<(), BoxError> {
    let first_attempt =
        attempt_connection(&config, storage.clone(), abort_signal.resubscribe()).await;
    if let Err(err) = first_attempt {
        if config.required_to_start {
            return Err(err);
        }
        // Connection is optional: retry periodically in the background.
        tokio::spawn(reattempt_connection(config, storage, abort_signal));
    }
    Ok(())
}
/// Open the Redis-backed storage and publish it into the shared `OnceLock`.
/// Failures are logged before being propagated to the caller.
async fn attempt_connection(
    config: &storage::redis::Config,
    cache_storage: Arc<OnceLock<Storage>>,
    abort_signal: broadcast::Receiver<()>,
) -> Result<(), BoxError> {
    match Storage::new(config, abort_signal).await {
        Ok(storage) => {
            // A concurrent attempt may already have set the lock; that's fine.
            let _ = cache_storage.set(storage);
            Ok(())
        }
        Err(err) => {
            tracing::error!(
                cache = "response",
                error = %err,
                "could not open connection to Redis for response caching",
            );
            Err(err)
        }
    }
}
/// Background task retrying the Redis connection every 30 seconds until a
/// connection succeeds or the abort signal fires.
async fn reattempt_connection(
    config: storage::redis::Config,
    cache_storage: Arc<OnceLock<Storage>>,
    mut abort_signal: broadcast::Receiver<()>,
) {
    let mut interval = IntervalStream::new(tokio::time::interval(Duration::from_secs(30)));
    loop {
        tokio::select! {
            // `biased` makes the abort signal win over a pending tick.
            biased;
            _ = abort_signal.recv() => {
                break;
            }
            _ = interval.next() => {
                if attempt_connection(&config, cache_storage.clone(), abort_signal.resubscribe()).await.is_ok() {
                    break;
                }
            }
        }
    }
}
/// Original per-request cache-control values, keyed by subgraph request id.
pub(crate) type CacheControls = HashMap<SubgraphRequestId, CacheControl>;
#[cfg(all(
test,
any(not(feature = "ci"), all(target_arch = "x86_64", target_os = "linux"))
))]
mod tests {
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use apollo_compiler::Schema;
use apollo_compiler::parser::Parser;
use serde_json_bytes::json;
use tokio::sync::broadcast;
use uuid::Uuid;
use super::Subgraph;
use super::Ttl;
use crate::configuration::subgraph::SubgraphConfiguration;
use crate::plugin::PluginInit;
use crate::plugin::PluginPrivate;
use crate::plugins::response_cache::plugin::ResponseCache;
use crate::plugins::response_cache::plugin::get_entity_key_from_selection_set;
use crate::plugins::response_cache::plugin::get_invalidation_entity_keys_from_schema;
use crate::plugins::response_cache::plugin::get_invalidation_root_keys_from_schema;
use crate::plugins::response_cache::plugin::matches_selection_set;
use crate::plugins::response_cache::storage::redis::Config;
use crate::plugins::response_cache::storage::redis::Storage;
use crate::plugins::response_cache::tests::create_subgraph_conf;
use crate::services::OperationKind;
use crate::services::subgraph;
const SCHEMA: &str = include_str!("../../testdata/orga_supergraph_cache_key.graphql");
// Verifies per-subgraph `enabled` resolution: an explicit per-subgraph flag wins
// over the `all` default.
#[tokio::test]
async fn test_subgraph_enabled() {
let valid_schema = Arc::new(Schema::parse_and_validate(SCHEMA, "test.graphql").unwrap());
let (drop_tx, drop_rx) = broadcast::channel(2);
let storage = Storage::new(&Config::test(false, "test_subgraph_enabled"), drop_rx)
.await
.unwrap();
// `orga` is explicitly enabled, `archive` explicitly disabled, `user` has no flag.
let map = serde_json_bytes::from_value(serde_json_bytes::json!({
"user": {
"private_id": "sub"
},
"orga": {
"private_id": "sub",
"enabled": true
},
"archive": {
"private_id": "sub",
"enabled": false
}
}))
.unwrap();
let subgraphs_conf = create_subgraph_conf(map);
let mut response_cache = ResponseCache::for_test(
storage.clone(),
subgraphs_conf,
valid_schema.clone(),
true,
drop_tx,
)
.await
.unwrap();
// With the default global config: unflagged `user` is on, `archive` is off.
assert!(response_cache.subgraph_enabled("user"));
assert!(!response_cache.subgraph_enabled("archive"));
// Flip the global default to disabled while keeping the same per-subgraph map.
let subgraph_config = serde_json_bytes::json!({
"all": {
"enabled": false
},
"subgraphs": response_cache.subgraphs.subgraphs.clone()
});
response_cache.subgraphs = Arc::new(serde_json_bytes::from_value(subgraph_config).unwrap());
assert!(!response_cache.subgraph_enabled("archive"));
// NOTE(review): `user` has no explicit `enabled` flag yet stays enabled although
// the global default is now false — presumably having a per-subgraph entry at all
// supplies its own default; confirm against `subgraph_enabled`'s implementation.
assert!(response_cache.subgraph_enabled("user"));
assert!(response_cache.subgraph_enabled("orga"));
}
/// Builds a `ResponseCache` test instance whose global (`all`) section and
/// `user`-subgraph section are toggled by the four flags, backed by a
/// throwaway redis namespace so parallel tests never share cache entries.
async fn get_response_cache_plugin(
    all_enabled: bool,
    all_invalidation_enabled: bool,
    subgraph_enabled: bool,
    subgraph_invalidation_enabled: bool,
) -> ResponseCache {
    let valid_schema = Arc::new(Schema::parse_and_validate(SCHEMA, "test.graphql").unwrap());
    let (drop_tx, drop_rx) = broadcast::channel(2);
    // A fresh random namespace per call isolates this plugin's cache entries.
    let namespace = Uuid::new_v4().to_string();
    let storage = Storage::new(&Config::test(false, &namespace), drop_rx)
        .await
        .unwrap();
    let subgraphs_conf = serde_json_bytes::from_value(serde_json_bytes::json!({
        "all": {
            "enabled": all_enabled,
            "ttl": "10s",
            "invalidation": {
                "enabled": all_invalidation_enabled,
                "shared_key": "test"
            }
        },
        "subgraphs": {
            "user": {
                "enabled": subgraph_enabled,
                "invalidation": {
                    "enabled": subgraph_invalidation_enabled,
                    "shared_key": "test"
                }
            }
        }
    }))
    .unwrap();
    ResponseCache::for_test(storage, subgraphs_conf, valid_schema, true, drop_tx)
        .await
        .unwrap()
}
// No redis storage may be wired up when caching is disabled, even though a redis
// config with `required_to_start: true` is present.
#[tokio::test]
async fn test_redis_connection_disabled() {
let valid_schema = Arc::new(Schema::parse_and_validate(SCHEMA, "test.graphql").unwrap());
// Case 1: plugin enabled globally, but `all` and every subgraph disabled.
let config: super::Config = serde_json_bytes::from_value(serde_json_bytes::json!({
"enabled": true,
"subgraph": {
"all": {
"enabled": false,
"ttl": "10s",
"redis": {
"urls": ["redis://127.0.0.1:6379"],
"namespace": Uuid::new_v4().to_string(),
"pool_size": 1,
"required_to_start": true,
}
},
"subgraphs": {
"user": {
"enabled": false
}
}
}
}))
.unwrap();
let response_cache = ResponseCache::new(PluginInit::fake_new(
config,
Arc::new(valid_schema.to_string()),
))
.await
.unwrap();
assert!(
response_cache.storage.all.is_none()
|| response_cache.storage.all.as_ref().unwrap().get().is_none(),
"Redis storage is set globally"
);
assert!(
response_cache.storage.subgraphs.is_empty(),
"Redis storage is set for a subgraph"
);
// Case 2: plugin disabled globally; subgraph-level `enabled: true` must not
// open any redis connection either.
let config: super::Config = serde_json_bytes::from_value(serde_json_bytes::json!({
"enabled": false,
"subgraph": {
"all": {
"enabled": true,
"ttl": "10s",
"redis": {
"urls": ["redis://127.0.0.1:6379"],
"namespace": Uuid::new_v4().to_string(),
"pool_size": 1,
"required_to_start": true,
}
},
"subgraphs": {
"user": {
"enabled": true
}
}
}
}))
.unwrap();
let response_cache = ResponseCache::new(PluginInit::fake_new(
config,
Arc::new(valid_schema.to_string()),
))
.await
.unwrap();
assert!(
response_cache.storage.all.is_none()
|| response_cache.storage.all.as_ref().unwrap().get().is_none(),
"Redis storage is set globally"
);
assert!(
response_cache.storage.subgraphs.is_empty(),
"Redis storage is set for a subgraph"
);
}
/// Caching enabled everywhere but no redis configuration anywhere: plugin
/// initialization must fail instead of silently running without a backend.
#[tokio::test]
async fn test_no_redis_conf_provided_should_fail() {
    let valid_schema = Arc::new(Schema::parse_and_validate(SCHEMA, "test.graphql").unwrap());
    let config: super::Config = serde_json_bytes::from_value(serde_json_bytes::json!({
        "enabled": true,
        "subgraph": {
            "all": {
                "enabled": true,
                "ttl": "10s",
            },
            "subgraphs": {
                "user": {
                    "enabled": true
                },
                "inventory": {
                    "enabled": true
                }
            }
        }
    }))
    .unwrap();
    let init = PluginInit::fake_new(config, Arc::new(valid_schema.to_string()));
    let result = ResponseCache::new(init).await;
    assert!(
        result.is_err(),
        "The plugin should not start properly if caching is enabled but no redis provided"
    );
}
// A missing redis configuration is acceptable as long as caching is disabled at
// some level (plugin-wide, or `all` + every subgraph).
#[tokio::test]
#[rstest::rstest]
#[case(false, true, true)]
#[case(true, false, false)]
async fn test_no_redis_conf_provided_but_disabled_should_succeed(
#[case] global_enabled: bool,
#[case] all_enabled: bool,
#[case] subgraph_enabled: bool,
) {
let valid_schema = Arc::new(Schema::parse_and_validate(SCHEMA, "test.graphql").unwrap());
let config: super::Config = serde_json_bytes::from_value(serde_json_bytes::json!({
"enabled": global_enabled,
"subgraph": {
"all": {
"enabled": all_enabled,
"ttl": "10s",
},
"subgraphs": {
"user": {
"enabled": subgraph_enabled
},
"inventory": {
"enabled": subgraph_enabled
}
}
}
}))
.unwrap();
// Initialization must succeed despite the absent redis section.
let response_cache = ResponseCache::new(PluginInit::fake_new(
config,
Arc::new(valid_schema.to_string()),
))
.await
.unwrap();
if !global_enabled {
assert!(!response_cache.enabled);
}
// And no storage of any kind may have been created.
assert!(
response_cache.storage.all.is_none()
|| response_cache.storage.all.as_ref().unwrap().get().is_none(),
"Redis storage is set globally"
);
assert!(
response_cache.storage.subgraphs.is_empty(),
"Redis storage is set for a subgraph"
);
}
// With `all` disabled and only `inventory` opted in, exactly one per-subgraph
// redis storage must be created and no global one.
#[tokio::test]
async fn test_redis_connection_enabled_multiple_subgraphs() {
let valid_schema = Arc::new(Schema::parse_and_validate(SCHEMA, "test.graphql").unwrap());
let config: super::Config = serde_json_bytes::from_value(serde_json_bytes::json!({
"enabled": true,
"subgraph": {
"all": {
"enabled": false,
"ttl": "10s",
"redis": {
"urls": ["redis://127.0.0.1:6379"],
"namespace": Uuid::new_v4().to_string(),
"pool_size": 1,
"required_to_start": true,
}
},
"subgraphs": {
"user": {
"enabled": false
},
"inventory": {
"enabled": true
}
}
}
}))
.unwrap();
let response_cache = ResponseCache::new(PluginInit::fake_new(
config,
Arc::new(valid_schema.to_string()),
))
.await
.unwrap();
assert!(
response_cache.storage.all.is_none()
|| response_cache.storage.all.as_ref().unwrap().get().is_none(),
"Redis storage is set globally"
);
// Only the enabled `inventory` subgraph gets its own storage entry.
assert_eq!(
response_cache.storage.subgraphs.len(),
1,
"Redis storage is not set for a subgraph"
);
}
/// Redis storage wiring per (all, subgraph) enablement:
/// - global storage exists only when `all` is enabled;
/// - a dedicated per-subgraph storage exists only when the subgraph is enabled
///   while `all` is disabled (otherwise it uses the global one).
#[tokio::test]
#[rstest::rstest]
#[case(true, true)]
#[case(false, true)]
#[case(true, false)]
async fn test_redis_connection_enabled(
    #[case] all_enabled: bool,
    #[case] subgraph_enabled: bool,
) {
    let valid_schema = Arc::new(Schema::parse_and_validate(SCHEMA, "test.graphql").unwrap());
    let config: super::Config = serde_json_bytes::from_value(serde_json_bytes::json!({
        "enabled": true,
        "subgraph": {
            "all": {
                "enabled": all_enabled,
                "ttl": "10s",
                "redis": {
                    "urls": ["redis://127.0.0.1:6379"],
                    "namespace": Uuid::new_v4().to_string(),
                    "pool_size": 1,
                    "required_to_start": true,
                }
            },
            "subgraphs": {
                "user": {
                    "enabled": subgraph_enabled
                }
            }
        }
    }))
    .unwrap();
    let response_cache = ResponseCache::new(PluginInit::fake_new(
        config,
        Arc::new(valid_schema.to_string()),
    ))
    .await
    .unwrap();
    if all_enabled {
        assert!(
            response_cache.storage.all.is_some()
                && response_cache.storage.all.as_ref().unwrap().get().is_some(),
            "Redis storage is not set globally"
        );
    } else {
        assert!(
            response_cache.storage.all.is_none()
                || response_cache.storage.all.as_ref().unwrap().get().is_none(),
            "Redis storage is set globally"
        );
    }
    if subgraph_enabled && !all_enabled {
        // Fixed message: a failure here means the per-subgraph storage was NOT
        // created (matches the convention of the sibling multiple-subgraphs test).
        assert_eq!(
            response_cache.storage.subgraphs.len(),
            1,
            "Redis storage is not set for a subgraph"
        );
    } else {
        // Fixed message: a failure here means a per-subgraph storage WAS created
        // unexpectedly.
        assert!(
            response_cache.storage.subgraphs.is_empty(),
            "Redis storage is set for a subgraph"
        );
    }
}
/// Every combination where invalidation is effectively enabled somewhere
/// (globally or for the `user` subgraph) must expose the invalidation endpoint.
#[tokio::test]
#[rstest::rstest]
#[case(true, true, true, true)]
#[case(true, true, true, false)]
#[case(true, false, true, true)]
#[case(false, false, true, true)]
async fn test_invalidation_endpoint_enabled(
    #[case] all_enabled: bool,
    #[case] all_invalidation_enabled: bool,
    #[case] subgraph_enabled: bool,
    #[case] subgraph_invalidation_enabled: bool,
) {
    let plugin = get_response_cache_plugin(
        all_enabled,
        all_invalidation_enabled,
        subgraph_enabled,
        subgraph_invalidation_enabled,
    )
    .await;
    let endpoints = plugin.web_endpoints();
    assert!(!endpoints.is_empty());
}
/// Every combination where invalidation is effectively disabled everywhere
/// must NOT expose the invalidation endpoint.
#[tokio::test]
#[rstest::rstest]
#[case(false, false, false, false)]
#[case(false, true, false, false)]
#[case(false, true, false, true)]
#[case(true, false, true, false)]
#[case(true, false, false, false)]
#[case(true, false, false, true)]
async fn test_invalidation_endpoint_disabled(
    #[case] all_enabled: bool,
    #[case] all_invalidation_enabled: bool,
    #[case] subgraph_enabled: bool,
    #[case] subgraph_invalidation_enabled: bool,
) {
    let plugin = get_response_cache_plugin(
        all_enabled,
        all_invalidation_enabled,
        subgraph_enabled,
        subgraph_invalidation_enabled,
    )
    .await;
    let endpoints = plugin.web_endpoints();
    assert!(endpoints.is_empty());
}
// A single subgraph with invalidation enabled is enough to expose the endpoint,
// even if invalidation is disabled globally and on every other subgraph.
#[tokio::test]
async fn test_invalidation_endpoint_enabled_multiple_subgraphs() {
// Start from a config whose endpoint is disabled...
let mut response_cache = get_response_cache_plugin(false, false, true, false).await;
// ...then swap in a config where only `posts` enables invalidation.
response_cache.subgraphs = Arc::new(
serde_json_bytes::from_value(serde_json_bytes::json!({
"all": {
"enabled": false,
"ttl": "10s",
"invalidation": {
"enabled": false,
"shared_key": "test"
}
},
"subgraphs": {
"user": {
"enabled": true,
"invalidation": {
"enabled": false,
"shared_key": "test"
}
},
"posts": {
"enabled": true,
"invalidation": {
"enabled": true,
"shared_key": "test"
}
}
}
}))
.unwrap(),
);
assert!(
!response_cache.web_endpoints().is_empty(),
"Disable invalidation globally with one specific subgraph configuration with invalidation disabled and another one enabled should enable invalidation endpoint"
);
}
// TTL resolution: a per-subgraph `ttl` always wins; a subgraph without one falls
// back to the `all` TTL (or to none when `all` has no TTL either).
#[tokio::test]
async fn test_subgraph_ttl() {
let valid_schema = Arc::new(Schema::parse_and_validate(SCHEMA, "test.graphql").unwrap());
let (drop_tx, drop_rx) = broadcast::channel(2);
let storage = Storage::new(&Config::test(false, "test_subgraph_ttl"), drop_rx)
.await
.unwrap();
// `user` and `archive` carry their own TTLs (in different units); `orga` has none.
let map = serde_json_bytes::from_value(serde_json_bytes::json!({
"user": {
"private_id": "sub",
"ttl": "2s"
},
"orga": {
"private_id": "sub",
"enabled": true
},
"archive": {
"private_id": "sub",
"enabled": false,
"ttl": "5000ms"
}
}))
.unwrap();
let mut response_cache = ResponseCache::for_test(
storage.clone(),
create_subgraph_conf(map),
valid_schema.clone(),
true,
drop_tx,
)
.await
.unwrap();
// No global TTL yet: only explicit per-subgraph TTLs resolve.
assert_eq!(
response_cache.subgraph_ttl("user"),
Some(Duration::from_secs(2))
);
assert!(response_cache.subgraph_ttl("orga").is_none());
assert_eq!(
response_cache.subgraph_ttl("archive"),
Some(Duration::from_millis(5000))
);
// Add a 25s global TTL: `orga` now inherits it, explicit TTLs are untouched.
response_cache.subgraphs = Arc::new(SubgraphConfiguration {
all: Subgraph {
ttl: Some(Ttl(Duration::from_secs(25))),
..Default::default()
},
subgraphs: response_cache.subgraphs.subgraphs.clone(),
});
assert_eq!(
response_cache.subgraph_ttl("user"),
Some(Duration::from_secs(2))
);
assert_eq!(
response_cache.subgraph_ttl("orga"),
Some(Duration::from_secs(25))
);
assert_eq!(
response_cache.subgraph_ttl("archive"),
Some(Duration::from_millis(5000))
);
// Change the global TTL again: only the inheriting subgraph follows.
response_cache.subgraphs = Arc::new(SubgraphConfiguration {
all: Subgraph {
ttl: Some(Ttl(Duration::from_secs(42))),
..Default::default()
},
subgraphs: response_cache.subgraphs.subgraphs.clone(),
});
assert_eq!(
response_cache.subgraph_ttl("user"),
Some(Duration::from_secs(2))
);
assert_eq!(
response_cache.subgraph_ttl("orga"),
Some(Duration::from_secs(42))
);
assert_eq!(
response_cache.subgraph_ttl("archive"),
Some(Duration::from_millis(5000))
);
}
// A representation whose list fields carry every selected subfield (including
// multi-element arrays) must satisfy the selection set.
#[test]
fn test_matches_selection_set_handles_arrays() {
let schema_text = r#"
type Query {
test: Test
}
type Test {
id: ID!
locale: String!
lists: [List!]!
list: [List!]!
}
type List {
id: ID!
date: Int!
quantity: Int!
location: String!
}
"#;
let schema = Schema::parse_and_validate(schema_text, "test.graphql").unwrap();
let mut parser = Parser::new();
// Selection covers every field of both list fields.
let field_set = parser
.parse_field_set(
&schema,
apollo_compiler::ast::NamedType::new("Test").unwrap(),
"id locale lists { id date quantity location } list { id date quantity location }",
"test.graphql",
)
.unwrap();
let representation = json!({
"id": "TEST123",
"locale": "en_US",
"lists": [
{
"id": "LIST1",
"date": 20240101,
"quantity": 50,
"location": "WAREHOUSE_A"
}
],
"list": [
{
"id": "LIST2",
"date": 20240101,
"quantity": 100,
"location": "WAREHOUSE_A"
},
{
"id": "LIST3",
"date": 20240102,
"quantity": 75,
"location": "WAREHOUSE_B"
}
]
})
.as_object()
.unwrap()
.clone();
assert!(
matches_selection_set(&representation, &field_set.selection_set),
"complex nested arrays should match"
);
}
/// Parses `schema` and `selection_text` (a field set on `named_type`), then
/// reports whether `representation` satisfies that selection set.
/// `representation` must be a JSON object; panics otherwise.
fn repr_matches_selection_set_for_schema(
    schema: &str,
    named_type: &str,
    selection_text: &str,
    representation: serde_json_bytes::Value,
) -> bool {
    let parsed_schema = Schema::parse_and_validate(schema, "test.graphql")
        .expect("should be able to parse schema");
    let type_name = apollo_compiler::ast::NamedType::new(named_type).unwrap();
    let field_set = Parser::new()
        .parse_field_set(&parsed_schema, type_name, selection_text, "test.graphql")
        .expect("should be able to parse field set");
    let repr_object = representation.as_object().expect("must provide an object");
    matches_selection_set(repr_object, &field_set.selection_set)
}
// With a nullable element type (`[NullableListElement]`), null lists, null
// elements, and missing/null nullable subfields must all still match.
#[rstest::rstest]
#[case::null_list(json!(null))]
#[case::null_element(json!([null]))]
#[case::null_element(json!([{"id": "TEST1"}, null]))]
#[case::null_value_for_nullable_field(json!([{"id": "TEST1"}]))]
#[case::null_value_for_nullable_field(json!([{"id": "TEST1", "quantity": 5}]))]
#[case::multiple_differently_null_objects(json!([{"id": "TEST1"}, null, {"id": "TEST3", "quantity": null}]))]
fn test_matches_selection_set_handles_arrays_with_nullable_elements(
#[case] list_repr: serde_json_bytes::Value,
) {
let schema_text = r#"
type Query {
test: Test
}
type Test {
id: ID!
list: [NullableListElement]
}
type NullableListElement {
id: ID!
quantity: Int
inStock: Boolean
}
"#;
let named_type = "Test";
let selection_text = "id list { id quantity inStock }";
let representation = json!({
"id": "TEST123",
"list": list_repr
});
let matches_selection_set = repr_matches_selection_set_for_schema(
schema_text,
named_type,
selection_text,
representation,
);
assert!(matches_selection_set);
}
// With a non-nullable element type (`[NonNullableListElement!]`), a null element
// or a missing non-nullable `id` on any element must make the match fail.
#[rstest::rstest]
#[case::null_element(json!([null]))]
#[case::null_element(json!([{"id": "TEST1"}, null]))]
#[case::null_value_for_nonnullable_field(json!([{}]))]
#[case::null_value_for_nonnullable_field(json!([{"quantity": 5}]))]
#[case::null_value_for_nonnullable_field(json!([{"id": "TEST1"}, {}]))]
#[case::null_value_for_nonnullable_field(json!([{"id": "TEST1"}, {"quantity": 5}]))]
fn test_matches_selection_set_handles_arrays_with_non_nullable_elements(
#[case] list_repr: serde_json_bytes::Value,
) {
let schema_text = r#"
type Query {
test: Test
}
type Test {
id: ID!
list: [NonNullableListElement!]
}
type NonNullableListElement {
id: ID!
quantity: Int
inStock: Boolean
}
"#;
let named_type = "Test";
let selection_text = "id list { id quantity inStock }";
let representation = json!({
"id": "TEST123",
"list": list_repr
});
let matches_selection_set = repr_matches_selection_set_for_schema(
schema_text,
named_type,
selection_text,
representation,
);
assert!(!matches_selection_set);
}
// A representation missing one selected subfield (`lists[0].location`) must fail
// the full selection set, but match once that subfield is dropped from the
// selection.
#[test]
fn test_matches_selection_subset_handles_arrays() {
let schema_text = r#"
type Query {
test: Test
}
type Test {
id: ID!
locale: String!
lists: [List!]!
list: [List!]!
}
type List {
id: ID!
date: Int!
quantity: Int!
location: String!
}
"#;
let schema = Schema::parse_and_validate(schema_text, "test.graphql").unwrap();
let mut parser = Parser::new();
let field_set = parser
.parse_field_set(
&schema,
apollo_compiler::ast::NamedType::new("Test").unwrap(),
"id locale lists { id date quantity location } list { id date quantity location }",
"test.graphql",
)
.unwrap();
// Note: the single `lists` element has no "location" value.
let representation = json!({
"id": "TEST123",
"locale": "en_US",
"lists": [
{
"id": "LIST1",
"date": 20240101,
"quantity": 50
}
],
"list": [
{
"id": "LIST2",
"date": 20240101,
"quantity": 100,
"location": "WAREHOUSE_A"
},
{
"id": "LIST3",
"date": 20240102,
"quantity": 75,
"location": "WAREHOUSE_B"
}
]
})
.as_object()
.unwrap()
.clone();
// Full selection requires `lists { ... location }`, which is absent -> no match.
assert!(!matches_selection_set(
&representation,
&field_set.selection_set
),);
// Re-parse without `location` under `lists`: the same representation matches.
let field_set = parser
.parse_field_set(
&schema,
apollo_compiler::ast::NamedType::new("Test").unwrap(),
"id locale lists { id date quantity } list { id date quantity location }",
"test.graphql",
)
.unwrap();
assert!(
matches_selection_set(&representation, &field_set.selection_set),
"complex nested arrays should match"
);
}
/// A nullable object field set to `null` in the representation must still
/// satisfy a selection set that selects into that field.
#[test]
fn test_matches_selection_set_handles_null() {
    let schema_text = r#"
type Query {
test: Test
}
type Test {
id: ID!
nullable: Nullable
}
type Nullable {
id: ID!
}
"#;
    let schema = Schema::parse_and_validate(schema_text, "test.graphql").unwrap();
    let mut parser = Parser::new();
    let field_set = parser
        .parse_field_set(
            &schema,
            apollo_compiler::ast::NamedType::new("Test").unwrap(),
            "id nullable { id }",
            "test.graphql",
        )
        .unwrap();
    // `nullable` is explicitly null even though the selection drills into it.
    let representation = json!({
        "id": "TEST123",
        "nullable": null,
    })
    .as_object()
    .unwrap()
    .clone();
    assert!(
        matches_selection_set(&representation, &field_set.selection_set),
        // Fixed message: this test is about a null nullable field, not arrays
        // (the old text was copy-pasted from the arrays test).
        "null value for a nullable field should match"
    );
}
// When the selection covers every field, `get_entity_key_from_selection_set`
// must return the representation unchanged, arrays included.
#[test]
fn test_take_selection_set_handles_arrays() {
let schema_text = r#"
type Query {
test: Test
}
type Test {
id: ID!
locale: String!
lists: [List!]!
list: [List!]!
}
type List {
id: ID!
date: Int!
quantity: Int!
location: String!
}
"#;
let schema = Schema::parse_and_validate(schema_text, "test.graphql").unwrap();
let mut parser = Parser::new();
let field_set = parser
.parse_field_set(
&schema,
apollo_compiler::ast::NamedType::new("Test").unwrap(),
"id locale lists { id date quantity location } list { id date quantity location }",
"test.graphql",
)
.unwrap();
let representation = json!({
"id": "TEST123",
"locale": "en_US",
"lists": [
{
"id": "LIST1",
"date": 20240101,
"quantity": 50,
"location": "WAREHOUSE_A"
}
],
"list": [
{
"id": "LIST2",
"date": 20240101,
"quantity": 100,
"location": "WAREHOUSE_A"
},
{
"id": "LIST3",
"date": 20240102,
"quantity": 75,
"location": "WAREHOUSE_B"
}
]
})
.as_object()
.unwrap()
.clone();
// Sanity check: the representation satisfies the selection set.
assert!(matches_selection_set(
&representation,
&field_set.selection_set
));
// Full-coverage selection -> the extracted entity key equals the representation.
let entity_key =
get_entity_key_from_selection_set(&representation, &field_set.selection_set);
assert_eq!(
&entity_key,
json!({
"id": "TEST123",
"locale": "en_US",
"lists": [
{
"id": "LIST1",
"date": 20240101,
"quantity": 50,
"location": "WAREHOUSE_A"
}
],
"list": [
{
"id": "LIST2",
"date": 20240101,
"quantity": 100,
"location": "WAREHOUSE_A"
},
{
"id": "LIST3",
"date": 20240102,
"quantity": 75,
"location": "WAREHOUSE_B"
}
]
})
.as_object()
.unwrap()
);
}
// When the selection covers only a subset of fields, the extracted entity key
// must keep exactly the selected subfields inside each array element.
#[test]
fn test_take_selection_subset_handles_arrays() {
let schema_text = r#"
type Query {
test: Test
}
type Test {
id: ID!
locale: String!
lists: [List!]!
list: [List!]!
}
type List {
id: ID!
date: Int!
quantity: Int!
location: String!
}
"#;
let schema = Schema::parse_and_validate(schema_text, "test.graphql").unwrap();
let mut parser = Parser::new();
// `lists` drops `location`; `list` drops `date`.
let field_set = parser
.parse_field_set(
&schema,
apollo_compiler::ast::NamedType::new("Test").unwrap(),
"id locale lists { id date quantity } list { id quantity location }",
"test.graphql",
)
.unwrap();
let representation = json!({
"id": "TEST123",
"locale": "en_US",
"lists": [
{
"id": "LIST1",
"date": 20240101,
"quantity": 50,
"location": "WAREHOUSE_A"
}
],
"list": [
{
"id": "LIST2",
"date": 20240101,
"quantity": 100,
"location": "WAREHOUSE_A"
},
{
"id": "LIST3",
"date": 20240102,
"quantity": 75,
"location": "WAREHOUSE_B"
}
]
})
.as_object()
.unwrap()
.clone();
assert!(matches_selection_set(
&representation,
&field_set.selection_set
));
let entity_key =
get_entity_key_from_selection_set(&representation, &field_set.selection_set);
// Unselected subfields (`lists[].location`, `list[].date`) are stripped.
assert_eq!(
&entity_key,
json!({
"id": "TEST123",
"locale": "en_US",
"lists": [
{
"id": "LIST1",
"date": 20240101,
"quantity": 50
}
],
"list": [
{
"id": "LIST2",
"quantity": 100,
"location": "WAREHOUSE_A"
},
{
"id": "LIST3",
"quantity": 75,
"location": "WAREHOUSE_B"
}
]
})
.as_object()
.unwrap()
);
}
// Each repeated `federation__cacheTag` join directive on the queried root field
// yields one invalidation key, with `{$args.*}` placeholders filled from the
// operation's arguments.
#[test]
fn test_get_invalidation_root_keys_from_schema() {
let schema_text = r#"
directive @join__directive(graphs: [join__Graph!], name: String!, args: join__DirectiveArguments) repeatable on SCHEMA | OBJECT | INTERFACE | FIELD_DEFINITION
directive @join__enumValue(graph: join__Graph!) repeatable on ENUM_VALUE
directive @join__field(graph: join__Graph, requires: join__FieldSet, provides: join__FieldSet, type: String, external: Boolean, override: String, usedOverridden: Boolean, overrideLabel: String, contextArguments: [join__ContextArgument!]) repeatable on FIELD_DEFINITION | INPUT_FIELD_DEFINITION
directive @join__graph(name: String!, url: String!) on ENUM_VALUE
directive @join__implements(graph: join__Graph!, interface: String!) repeatable on OBJECT | INTERFACE
directive @join__type(graph: join__Graph!, key: join__FieldSet, extension: Boolean! = false, resolvable: Boolean! = true, isInterfaceObject: Boolean! = false) repeatable on OBJECT | INTERFACE | UNION | ENUM | INPUT_OBJECT | SCALAR
directive @join__unionMember(graph: join__Graph!, member: String!) repeatable on UNION
directive @link(url: String, as: String, for: link__Purpose, import: [link__Import]) repeatable on SCHEMA
input join__ContextArgument {
name: String!
type: String!
context: String!
selection: join__FieldValue!
}
scalar join__DirectiveArguments
scalar join__FieldSet
scalar join__FieldValue
enum join__Graph {
USER @join__graph(name: "USER", url: "none")
TEST @join__graph(name: "TEST", url: "none")
}
scalar link__Import
enum link__Purpose {
"""
`SECURITY` features provide metadata necessary to securely resolve fields.
"""
SECURITY
"""
`EXECUTION` features provide metadata necessary for operation execution.
"""
EXECUTION
}
type Query {
test: Test
testByCountry(id: ID!, country: Country!): Test @join__directive(
graphs: [USER]
name: "federation__cacheTag"
args: { format: "test-{$args.id}-{$args.country}" }
)
@join__directive(
graphs: [USER]
name: "federation__cacheTag"
args: { format: "test-{$args.country}" }
)
@join__directive(
graphs: [USER]
name: "federation__cacheTag"
args: { format: "test" }
)
}
enum Country {
BE
FR
}
type Test {
id: ID!
locale: String!
lists: [List!]!
list: [List!]!
}
type List {
id: ID!
date: Int!
quantity: Int!
location: String!
}
"#;
let schema = Arc::new(Schema::parse_and_validate(schema_text, "test.graphql").unwrap());
// Operation supplies id="2" and country=BE for the placeholder substitution.
let query = r#"query Test {
testByCountry(id: "2", country: BE) {
locale
}
}"#;
let mut sub_request = subgraph::Request::fake_builder()
.subgraph_request(
http::Request::builder()
.body(
crate::graphql::Request::builder()
.query(query)
.operation_name("Test")
.build(),
)
.unwrap(),
)
.operation_kind(OperationKind::Query)
.subgraph_name("USER")
.build();
sub_request.executable_document = Some(Arc::new(
apollo_compiler::ExecutableDocument::parse_and_validate(&schema, query, "test.graphql")
.unwrap(),
));
// Maps the join__Graph enum value to the subgraph name used by the request.
let subgraph_enums: HashMap<String, String> = [("USER".to_string(), "USER".to_string())]
.into_iter()
.collect();
let cache_tags =
get_invalidation_root_keys_from_schema(&sub_request, &subgraph_enums, schema.clone())
.unwrap();
// One key per repeated cacheTag directive, with arguments substituted.
assert_eq!(
cache_tags,
[
"test".to_string(),
"test-BE".to_string(),
"test-2-BE".to_string()
]
.into_iter()
.collect()
);
}
// A representation typed as a concrete implementer ("Book") arriving at a
// subgraph where the interface is an interface object must resolve the
// interface's cacheTag ("Item-{$key.id}").
#[test]
fn test_interface_object_typename_lookup_inbound() {
let schema_text = r#"
directive @join__type(graph: join__Graph!, key: join__FieldSet, isInterfaceObject: Boolean! = false) repeatable on
OBJECT | INTERFACE
directive @join__graph(name: String!, url: String!) on ENUM_VALUE
directive @join__implements(graph: join__Graph!, interface: String!) repeatable on OBJECT | INTERFACE
directive @join__directive(graphs: [join__Graph!], name: String!, args: join__DirectiveArguments) repeatable on SCHEMA | OBJECT | INTERFACE | FIELD_DEFINITION
scalar join__FieldSet
scalar join__DirectiveArguments
enum join__Graph {
SEARCH @join__graph(name: "search", url: "http://search")
INVENTORY @join__graph(name: "inventory", url: "http://inventory")
}
type Query { dummy: String }
interface Item
@join__type(graph: SEARCH, key: "id")
@join__type(graph: INVENTORY, key: "id", isInterfaceObject: true)
@join__directive(graphs: [INVENTORY], name: "federation__cacheTag", args: {format: "Item-{$key.id}"})
{
id: ID!
}
type Book implements Item
@join__implements(graph: SEARCH, interface: "Item")
@join__type(graph: SEARCH, key: "id")
{
id: ID!
isbn: String!
}
"#;
let schema = Arc::new(Schema::parse_and_validate(schema_text, "schema.graphql").unwrap());
let subgraph_enums = HashMap::from([
("SEARCH".into(), "search".into()),
("INVENTORY".into(), "inventory".into()),
]);
let representation = serde_json_bytes::json!({"__typename": "Book", "id": "123"})
.as_object()
.unwrap()
.clone();
// Lookup uses typename "Book" against the "inventory" subgraph, where Item is
// an interface object carrying the cacheTag.
let result = get_invalidation_entity_keys_from_schema(
&schema,
"inventory",
&subgraph_enums,
"Book",
&representation,
)
.expect("should handle interface object typename");
assert_eq!(result.into_iter().collect::<Vec<_>>(), [r#"Item-123"#]);
}
// A representation typed as the interface ("Item") where the cacheTag lives only
// on the concrete implementer ("Book") yields no entity keys for the interface
// object subgraph.
#[test]
fn test_interface_object_typename_lookup_outbound() {
let schema_text = r#"
directive @join__type(graph: join__Graph!, key: join__FieldSet, isInterfaceObject: Boolean! = false) repeatable on
OBJECT | INTERFACE
directive @join__graph(name: String!, url: String!) on ENUM_VALUE
directive @join__implements(graph: join__Graph!, interface: String!) repeatable on OBJECT | INTERFACE
directive @join__directive(graphs: [join__Graph!], name: String!, args: join__DirectiveArguments) repeatable on SCHEMA | OBJECT | INTERFACE | FIELD_DEFINITION
scalar join__FieldSet
scalar join__DirectiveArguments
enum join__Graph {
SEARCH @join__graph(name: "search", url: "http://search")
INVENTORY @join__graph(name: "inventory", url: "http://inventory")
}
type Query { dummy: String }
interface Item
@join__type(graph: SEARCH, key: "id")
@join__type(graph: INVENTORY, key: "id", isInterfaceObject: true)
{
id: ID!
}
type Book implements Item
@join__implements(graph: SEARCH, interface: "Item")
@join__type(graph: SEARCH, key: "id")
@join__directive(graphs: [SEARCH], name: "federation__cacheTag", args: {format: "Book-{$key.id}"})
{
id: ID!
isbn: String!
}
"#;
let schema = Arc::new(Schema::parse_and_validate(schema_text, "schema.graphql").unwrap());
let subgraph_enums = HashMap::from([
("SEARCH".into(), "search".into()),
("INVENTORY".into(), "inventory".into()),
]);
let representation = serde_json_bytes::json!({"__typename": "Item", "id": "123"})
.as_object()
.unwrap()
.clone();
let result = get_invalidation_entity_keys_from_schema(
&schema,
"inventory",
&subgraph_enums,
"Item",
&representation,
)
.expect("should handle interface object typename");
// The cacheTag is scoped to SEARCH's Book, so inventory/Item produces nothing.
assert_eq!(result.len(), 0);
}
// A representation typed directly as the interface object ("Item") in a subgraph
// where the interface carries the cacheTag resolves that tag.
#[test]
fn test_interface_object_typename_into_interface_object() {
let schema_text = r#"
directive @join__type(graph: join__Graph!, key: join__FieldSet, isInterfaceObject: Boolean! = false) repeatable on
OBJECT | INTERFACE
directive @join__graph(name: String!, url: String!) on ENUM_VALUE
directive @join__implements(graph: join__Graph!, interface: String!) repeatable on OBJECT | INTERFACE
directive @join__directive(graphs: [join__Graph!], name: String!, args: join__DirectiveArguments) repeatable on SCHEMA | OBJECT | INTERFACE | FIELD_DEFINITION
scalar join__FieldSet
scalar join__DirectiveArguments
enum join__Graph {
SEARCH @join__graph(name: "search", url: "http://search")
INVENTORY @join__graph(name: "inventory", url: "http://inventory")
IRRELEVANT @join__graph(name: "irrelevant", url: "http://irrelevant")
}
type Query { dummy: String }
interface Item
@join__type(graph: SEARCH, key: "id", isInterfaceObject: true)
@join__type(graph: INVENTORY, key: "id", isInterfaceObject: true)
@join__type(graph: IRRELEVANT, key: "id")
@join__directive(graphs: [INVENTORY], name: "federation__cacheTag", args: {format: "Item-{$key.id}"})
{
id: ID!
}
type Book implements Item
@join__implements(graph: IRRELEVANT, interface: "Item")
@join__type(graph: IRRELEVANT, key: "id")
{
id: ID!
isbn: String!
}
"#;
let schema = Arc::new(Schema::parse_and_validate(schema_text, "schema.graphql").unwrap());
let subgraph_enums = HashMap::from([
("INVENTORY".into(), "inventory".into()),
("SEARCH".into(), "search".into()),
("IRRELEVANT".into(), "irrelevant".into()),
]);
let representation = serde_json_bytes::json!({"__typename": "Item", "id": "123"})
.as_object()
.unwrap()
.clone();
let result = get_invalidation_entity_keys_from_schema(
&schema,
"inventory",
&subgraph_enums,
"Item",
&representation,
)
.expect("should handle interface object typename");
// The INVENTORY-scoped cacheTag on Item applies: one key with $key.id filled in.
assert_eq!(result.into_iter().collect::<Vec<_>>(), [r#"Item-123"#]);
}
// A concrete type implementing a plain interface (no interface object involved)
// must be handled without error, even with no cacheTag directives present.
#[test]
fn test_concrete_type_when_interface_object_is_false() {
let schema_text = r#"
directive @join__type(graph: join__Graph!, key: join__FieldSet, isInterfaceObject: Boolean! = false) repeatable on OBJECT | INTERFACE
directive @join__graph(name: String!, url: String!) on ENUM_VALUE
scalar join__FieldSet
enum join__Graph {
PRODUCTS @join__graph(name: "products", url: "http://products")
}
type Query { dummy: String }
# Regular interface (not an interface object)
interface Item {
id: ID!
}
# Concrete type that implements the interface
type Product implements Item @join__type(graph: PRODUCTS, key: "id") {
id: ID!
name: String
}
"#;
let schema = Arc::new(Schema::parse_and_validate(schema_text, "schema.graphql").unwrap());
let subgraph_enums = HashMap::from([("PRODUCTS".into(), "products".into())]);
let representation = serde_json_bytes::json!({
"__typename": "Product", "id": "123"
})
.as_object()
.unwrap()
.clone();
let result = get_invalidation_entity_keys_from_schema(
&schema,
"products",
&subgraph_enums,
"Product", &representation,
);
// Only the Ok/Err outcome matters here, not the (empty) key set.
assert!(
result.is_ok(),
"should handle concrete type (isInterfaceObject: false)"
);
}
}