use std::future::ready;
use std::pin::Pin;
use std::sync::Arc;
use std::task::Context;
use std::task::Poll;
use std::time::SystemTime;
use std::time::UNIX_EPOCH;
use futures::Stream;
use futures::StreamExt;
use futures::future::BoxFuture;
use futures::stream::once;
use serde_json_bytes::Value;
use tokio::sync::broadcast;
use tokio::sync::mpsc;
use tokio::sync::mpsc::Receiver;
use tokio::sync::mpsc::Sender;
use tokio::sync::mpsc::error::SendError;
use tokio::sync::mpsc::error::TryRecvError;
use tokio_stream::wrappers::ReceiverStream;
use tower::BoxError;
use tower::ServiceBuilder;
use tower::ServiceExt;
use tower_service::Service;
use tracing::Instrument;
use tracing::Span;
use tracing::event;
use tracing_core::Level;
use crate::apollo_studio_interop::ReferencedEnums;
use crate::apollo_studio_interop::extract_enums_from_response;
use crate::graphql::Error;
use crate::graphql::IncrementalResponse;
use crate::graphql::Response;
use crate::json_ext::Object;
use crate::json_ext::Path;
use crate::json_ext::PathElement;
use crate::json_ext::ValueExt;
use crate::plugins::authentication::APOLLO_AUTHENTICATION_JWT_CLAIMS;
use crate::plugins::subscription::APOLLO_SUBSCRIPTION_PLUGIN;
use crate::plugins::subscription::Subscription;
use crate::plugins::subscription::SubscriptionConfig;
use crate::plugins::telemetry::Telemetry;
use crate::plugins::telemetry::apollo::Config as ApolloTelemetryConfig;
use crate::plugins::telemetry::config::ApolloMetricsReferenceMode;
use crate::query_planner::fetch::SubgraphSchemas;
use crate::query_planner::subscription::SubscriptionHandle;
use crate::services::ExecutionRequest;
use crate::services::ExecutionResponse;
use crate::services::Plugins;
use crate::services::SubgraphServiceFactory;
use crate::services::execution;
use crate::services::new_service::ServiceFactory;
use crate::spec::Query;
use crate::spec::Schema;
use crate::spec::query::subselections::BooleanValues;
/// Tower service that runs a request's query plan and produces the
/// (possibly streamed) execution response.
#[derive(Clone)]
pub(crate) struct ExecutionService {
    /// Schema passed to query-plan execution and used to format responses.
    pub(crate) schema: Arc<Schema>,
    /// Subgraph schemas passed to query-plan execution.
    pub(crate) subgraph_schemas: Arc<SubgraphSchemas>,
    /// Factory passed to query-plan execution (presumably used to create
    /// subgraph services — confirm in `QueryPlan::execute`).
    pub(crate) subgraph_service_factory: Arc<SubgraphServiceFactory>,
    /// Subscription plugin configuration, when that plugin is registered.
    subscription_config: Option<SubscriptionConfig>,
    /// Apollo telemetry configuration, when the telemetry plugin is
    /// registered; only its metrics reference mode is read by this service.
    apollo_telemetry_config: Option<ApolloTelemetryConfig>,
}
/// Sender half of the broadcast channel used to notify a subscription that
/// the client-facing response stream has been dropped.
type CloseSignal = broadcast::Sender<()>;
/// Response stream returned for deferred and subscription requests.
///
/// Field 0 is the stream of responses. Field 1, when present, is signalled
/// when the wrapper is dropped (see the `Drop` impl below).
pub(crate) struct StreamWrapper(pub(crate) ReceiverStream<Response>, Option<CloseSignal>);
impl Stream for StreamWrapper {
    type Item = Response;

    /// Polls the wrapped receiver stream; the wrapper adds no buffering of
    /// its own. `StreamWrapper` is `Unpin`, so the pin can be unwrapped.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let wrapper = self.get_mut();
        wrapper.0.poll_next_unpin(cx)
    }
}
impl Drop for StreamWrapper {
    /// Fires the close signal (at most once) so a subscription learns the
    /// client stream is gone, then closes the underlying receiver.
    fn drop(&mut self) {
        // The taken sender is consumed (and dropped) inside the closure,
        // before the receiver is closed below.
        let outcome = self.1.take().map(|signal| signal.send(()));
        if let Some(Err(err)) = outcome {
            tracing::trace!("cannot close the subscription: {err:?}");
        }
        self.0.close();
    }
}
impl Service<ExecutionRequest> for ExecutionService {
type Response = ExecutionResponse;
type Error = BoxError;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(
&mut self,
_cx: &mut std::task::Context<'_>,
) -> Poll<std::result::Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: ExecutionRequest) -> Self::Future {
let clone = self.clone();
let mut this = std::mem::replace(self, clone);
let fut = async move { Ok(this.call_inner(req).await) }.in_current_span();
Box::pin(fut)
}
}
impl ExecutionService {
async fn call_inner(&mut self, req: ExecutionRequest) -> ExecutionResponse {
let context = req.context;
let ctx = context.clone();
let variables = req.supergraph_request.body().variables.clone();
let (sender, receiver) = mpsc::channel(10);
let is_deferred = req.query_plan.is_deferred(&variables);
let is_subscription = req.query_plan.is_subscription();
let mut claims = None;
if is_deferred {
claims = context.get(APOLLO_AUTHENTICATION_JWT_CLAIMS).ok().flatten()
}
let (tx_close_signal, subscription_handle) = if is_subscription {
let (tx_close_signal, rx_close_signal) = broadcast::channel(1);
(
Some(tx_close_signal),
Some(SubscriptionHandle::new(
rx_close_signal,
req.subscription_tx,
)),
)
} else {
(None, None)
};
let has_initial_data = req.source_stream_value.is_some();
let mut first = req
.query_plan
.execute(
&context,
&self.subgraph_service_factory,
&Arc::new(req.supergraph_request),
&self.schema,
&self.subgraph_schemas,
sender,
subscription_handle.clone(),
&self.subscription_config,
req.source_stream_value,
)
.await;
let query = req.query_plan.query.clone();
let stream = if (is_deferred || is_subscription) && !has_initial_data {
let stream_mode = if is_deferred {
StreamMode::Defer
} else {
first.subscribed = Some(first.errors.is_empty());
StreamMode::Subscription
};
let stream = filter_stream(first, receiver, stream_mode);
StreamWrapper(stream, tx_close_signal).boxed()
} else if has_initial_data {
once(ready(first)).boxed()
} else {
once(ready(first))
.chain(ReceiverStream::new(receiver))
.boxed()
};
if has_initial_data {
return ExecutionResponse::new_from_response(http::Response::new(stream as _), ctx);
}
let schema = self.schema.clone();
let mut nullified_paths: Vec<Path> = vec![];
let metrics_ref_mode = match &self.apollo_telemetry_config {
Some(conf) => conf.metrics_reference_mode,
_ => ApolloMetricsReferenceMode::default(),
};
let execution_span = Span::current();
let stream = stream
.map(move |mut response: Response| {
if is_deferred {
let ts_opt = claims.as_ref().and_then(|x: &Value| {
if !x.is_object() {
tracing::error!("JWT claims should be an object");
return None;
}
let claims = x.as_object().expect("claims should be an object");
let exp = claims.get("exp")?;
if !exp.is_number() {
tracing::error!("JWT 'exp' (expiry) claim should be a number");
return None;
}
exp.as_i64()
});
if let Some(ts) = ts_opt {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("we should not run before EPOCH")
.as_secs() as i64;
if ts < now {
tracing::debug!("token has expired, shut down the subscription");
response = Response::builder()
.has_next(false)
.error(
Error::builder()
.message(
"deferred response closed because the JWT has expired",
)
.extension_code("DEFERRED_RESPONSE_JWT_EXPIRED")
.build(),
)
.build()
}
}
}
response
})
.filter_map(move |response: Response| {
ready(execution_span.in_scope(|| {
Self::process_graphql_response(
&query,
&variables,
is_deferred,
&schema,
&mut nullified_paths,
metrics_ref_mode,
&context,
response,
)
}))
})
.boxed();
ExecutionResponse::new_from_response(http::Response::new(stream as _), ctx)
}
#[allow(clippy::too_many_arguments)]
fn process_graphql_response(
query: &Arc<Query>,
variables: &Object,
is_deferred: bool,
schema: &Arc<Schema>,
nullified_paths: &mut Vec<Path>,
metrics_ref_mode: ApolloMetricsReferenceMode,
context: &crate::Context,
mut response: Response,
) -> Option<Response> {
if response
.path
.as_ref()
.map(|response_path| {
nullified_paths
.iter()
.any(|path| response_path.starts_with(path))
})
.unwrap_or(false)
{
if response.has_next == Some(false) {
return Some(Response::builder().has_next(false).build());
} else {
return None;
}
}
if response.subscribed == Some(false)
&& response.data.is_none()
&& response.errors.is_empty()
{
return response.into();
}
let has_next = response.has_next.unwrap_or(true);
let variables_set = query.defer_variables_set(variables);
tracing::debug_span!("format_response").in_scope(|| {
let mut paths = Vec::new();
if !query.unauthorized.paths.is_empty() {
if query.unauthorized.errors.log {
let unauthorized_paths = query.unauthorized.paths.iter().map(|path| path.to_string()).collect::<Vec<_>>();
event!(Level::ERROR, unauthorized_query_paths = ?unauthorized_paths, "Authorization error",);
}
match query.unauthorized.errors.response {
crate::plugins::authorization::ErrorLocation::Errors => for path in &query.unauthorized.paths {
response.errors.push(Error::builder()
.message("Unauthorized field or type")
.path(path.clone())
.extension_code("UNAUTHORIZED_FIELD_OR_TYPE").build());
},
crate::plugins::authorization::ErrorLocation::Extensions =>{
if !query.unauthorized.paths.is_empty() {
let mut v = vec![];
for path in &query.unauthorized.paths{
v.push(serde_json_bytes::to_value(Error::builder()
.message("Unauthorized field or type")
.path(path.clone())
.extension_code("UNAUTHORIZED_FIELD_OR_TYPE").build()).expect("error serialization should not fail"));
}
response.extensions.insert("authorizationErrors", Value::Array(v));
}
},
crate::plugins::authorization::ErrorLocation::Disabled => {},
}
}
if let Some(filtered_query) = query.filtered_query.as_ref() {
paths = filtered_query.format_response(
&mut response,
variables.clone(),
schema.api_schema(),
variables_set,
);
}
paths.extend(
query
.format_response(
&mut response,
variables.clone(),
schema.api_schema(),
variables_set,
)
,
);
for error in response.errors.iter_mut() {
if let Some(path) = &mut error.path {
let matching_len = query.matching_error_path_length(path);
if path.len() != matching_len {
path.0.drain(matching_len..);
if path.is_empty() {
error.path = None;
}
error.locations.clear();
}
}
}
nullified_paths.extend(paths);
let mut referenced_enums = context
.extensions()
.with_lock(|lock| lock.get::<ReferencedEnums>().cloned())
.unwrap_or_default();
if let (ApolloMetricsReferenceMode::Extended, Some(Value::Object(response_body))) = (metrics_ref_mode, &response.data) {
extract_enums_from_response(
query.clone(),
schema.api_schema(),
response_body,
&mut referenced_enums,
)
};
context
.extensions()
.with_lock(|mut lock| lock.insert::<ReferencedEnums>(referenced_enums));
});
match (response.path.as_ref(), response.data.as_ref()) {
(None, _) | (_, None) => {
if is_deferred {
response.has_next = Some(has_next);
}
response.errors.retain(|error| match &error.path {
None => true,
Some(error_path) => {
query.contains_error_path(&response.label, error_path, variables_set)
}
});
response.label = rewrite_defer_label(&response);
Some(response)
}
(Some(response_path), Some(response_data)) => {
let mut sub_responses = Vec::new();
response_data.select_values_and_paths(schema, response_path, |path, value| {
if let Value::Array(array) = value {
let mut parent = path.clone();
for (i, value) in array.iter().enumerate() {
parent.push(PathElement::Index(i));
sub_responses.push((parent.clone(), value.clone()));
parent.pop();
}
} else {
sub_responses.push((path.clone(), value.clone()));
}
});
Self::split_incremental_response(
query,
has_next,
variables_set,
response,
sub_responses,
)
}
}
}
fn split_incremental_response(
query: &Arc<Query>,
has_next: bool,
variables_set: BooleanValues,
response: Response,
sub_responses: Vec<(Path, Value)>,
) -> Option<Response> {
let query = query.clone();
let rewritten_label = rewrite_defer_label(&response);
let incremental = sub_responses
.into_iter()
.filter_map(move |(path, data)| {
let errors = response
.errors
.iter()
.filter(|error| match &error.path {
None => false,
Some(error_path) => {
query.contains_error_path(&response.label, error_path, variables_set)
&& error_path.starts_with(&path)
}
})
.cloned()
.collect::<Vec<_>>();
let extensions: Object = response
.extensions
.iter()
.map(|(key, value)| {
if key.as_str() == "valueCompletion" {
let value = match value.as_array() {
None => Value::Null,
Some(v) => Value::Array(
v.iter()
.filter(|ext| {
ext.as_object()
.and_then(|ext| ext.get("path"))
.and_then(|v| {
serde_json_bytes::from_value::<Path>(v.clone())
.ok()
})
.map(|ext_path| ext_path.starts_with(&path))
.unwrap_or(false)
})
.cloned()
.collect(),
),
};
(key.clone(), value)
} else {
(key.clone(), value.clone())
}
})
.collect();
if !data.is_null() || !errors.is_empty() || !extensions.is_empty() {
Some(
IncrementalResponse::builder()
.and_label(rewritten_label.clone())
.data(data)
.path(path)
.errors(errors)
.extensions(extensions)
.build(),
)
} else {
None
}
})
.collect();
Some(
Response::builder()
.has_next(has_next)
.incremental(incremental)
.build(),
)
}
}
/// Returns the response's defer label with its leading `_` removed.
///
/// Returns `None` when the response has no label, or when the label does
/// not start with `_`.
/// NOTE(review): presumably the `_` prefix marks router-generated labels
/// whose original form the client should see — confirm against the planner.
fn rewrite_defer_label(response: &Response) -> Option<String> {
    response
        .label
        .as_deref()
        .and_then(|label| label.strip_prefix('_'))
        .map(str::to_owned)
}
/// Kind of multi-part stream being produced. It decides how the final
/// message is marked: `has_next = false` for `@defer`, `subscribed = false`
/// for subscriptions.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum StreamMode {
    /// Incremental delivery of `@defer`red fragments.
    Defer,
    /// Events of a GraphQL subscription.
    Subscription,
}
fn filter_stream(
first: Response,
mut stream: Receiver<Response>,
stream_mode: StreamMode,
) -> ReceiverStream<Response> {
let (mut sender, receiver) = mpsc::channel(10);
tokio::task::spawn(async move {
let mut seen_last_message =
consume_responses(first, &mut stream, &mut sender, stream_mode).await?;
while let Some(current_response) = stream.recv().await {
seen_last_message =
consume_responses(current_response, &mut stream, &mut sender, stream_mode).await?;
}
if !seen_last_message {
let res = match stream_mode {
StreamMode::Defer => Response::builder().has_next(false).build(),
StreamMode::Subscription => Response::builder().subscribed(false).build(),
};
sender.send(res).await?;
}
Ok::<_, SendError<Response>>(())
});
receiver.into()
}
/// Forwards `current_response`, plus every response already buffered in
/// `stream`, without waiting for new ones to arrive.
///
/// - Returns `Ok(true)` when the source channel turned out to be
///   disconnected. In that case the last forwarded response was first marked
///   as final (`has_next` or `subscribed` set to `Some(false)`, per
///   `stream_mode`).
/// - Returns `Ok(false)` when the buffer simply ran dry.
/// - Returns an error when `sender` can no longer deliver.
async fn consume_responses(
    mut current_response: Response,
    stream: &mut Receiver<Response>,
    sender: &mut Sender<Response>,
    stream_mode: StreamMode,
) -> Result<bool, SendError<Response>> {
    loop {
        let next = match stream.try_recv() {
            Ok(next) => next,
            Err(TryRecvError::Empty) => {
                sender.send(current_response).await?;
                return Ok(false);
            }
            Err(TryRecvError::Disconnected) => {
                // Nothing else can arrive: flag this one as the last message.
                match stream_mode {
                    StreamMode::Defer => current_response.has_next = Some(false),
                    StreamMode::Subscription => current_response.subscribed = Some(false),
                }
                sender.send(current_response).await?;
                return Ok(true);
            }
        };
        // Another response is queued, so the current one is not the last.
        sender.send(current_response).await?;
        current_response = next;
    }
}
/// Creates `ExecutionService` instances, each wrapped by every registered
/// plugin's execution-service hook.
#[derive(Clone)]
pub(crate) struct ExecutionServiceFactory {
    /// Schema given to each created service.
    pub(crate) schema: Arc<Schema>,
    /// Subgraph schemas given to each created service.
    pub(crate) subgraph_schemas: Arc<SubgraphSchemas>,
    /// Registered plugins. They are searched for subscription/telemetry
    /// configuration and folded over to wrap the created service.
    pub(crate) plugins: Arc<Plugins>,
    /// Subgraph service factory given to each created service.
    pub(crate) subgraph_service_factory: Arc<SubgraphServiceFactory>,
}
impl ServiceFactory<ExecutionRequest> for ExecutionServiceFactory {
type Service = execution::BoxService;
fn create(&self) -> Self::Service {
let subscription_plugin_conf = self
.plugins
.iter()
.find(|i| i.0.as_str() == APOLLO_SUBSCRIPTION_PLUGIN)
.and_then(|plugin| (*plugin.1).as_any().downcast_ref::<Subscription>())
.map(|p| p.config.clone());
let apollo_telemetry_conf = self
.plugins
.iter()
.find(|i| i.0.as_str() == "apollo.telemetry")
.and_then(|plugin| (*plugin.1).as_any().downcast_ref::<Telemetry>())
.map(|t| t.config.apollo.clone());
ServiceBuilder::new()
.service(
self.plugins.iter().rev().fold(
crate::services::execution::service::ExecutionService {
schema: self.schema.clone(),
subgraph_service_factory: self.subgraph_service_factory.clone(),
subscription_config: subscription_plugin_conf,
subgraph_schemas: self.subgraph_schemas.clone(),
apollo_telemetry_config: apollo_telemetry_conf,
}
.boxed(),
|acc, (_, e)| e.execution_service(acc),
),
)
.boxed()
}
}