use std::collections::HashMap;
use std::ops::ControlFlow;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use std::time::Instant;
use bytes::Bytes;
use futures::StreamExt;
use futures::TryStreamExt;
use futures::future::ready;
use futures::stream::once;
use http::HeaderMap;
use http::HeaderName;
use http::HeaderValue;
use http::header;
use hyper::client::HttpConnector;
use hyper_rustls::ConfigBuilderExt;
use hyper_rustls::HttpsConnector;
use schemars::JsonSchema;
use serde::Deserialize;
use serde::Serialize;
use tower::BoxError;
use tower::Service;
use tower::ServiceBuilder;
use tower::ServiceExt;
use tower::timeout::TimeoutLayer;
use tower::util::MapFutureLayer;
use crate::configuration::shared::Client;
use crate::error::Error;
use crate::graphql;
use crate::json_ext::Value;
use crate::layers::ServiceBuilderExt;
use crate::layers::async_checkpoint::OneShotAsyncCheckpointLayer;
use crate::plugin::Plugin;
use crate::plugin::PluginInit;
use crate::plugins::telemetry::config_new::conditions::Condition;
use crate::plugins::telemetry::config_new::selectors::RouterSelector;
use crate::plugins::telemetry::config_new::selectors::SubgraphSelector;
use crate::plugins::traffic_shaping::Http2Config;
use crate::register_plugin;
use crate::services;
use crate::services::external::Control;
use crate::services::external::DEFAULT_EXTERNALIZATION_TIMEOUT;
use crate::services::external::EXTERNALIZABLE_VERSION;
use crate::services::external::Externalizable;
use crate::services::external::PipelineStep;
use crate::services::external::externalize_header_map;
use crate::services::hickory_dns_connector::AsyncHyperResolver;
use crate::services::hickory_dns_connector::new_async_http_connector;
use crate::services::router;
use crate::services::router::body::RouterBody;
use crate::services::router::body::RouterBodyConverter;
use crate::services::router::body::get_body_bytes;
use crate::services::subgraph;
#[cfg(test)]
mod test;
mod execution;
mod supergraph;
/// Name of the tracing span created around every service wrapped by a coprocessor stage.
pub(crate) const EXTERNAL_SPAN_NAME: &str = "external_plugin";
/// Value passed to the hyper client builder's `pool_idle_timeout` when the coprocessor client is built.
const POOL_IDLE_TIMEOUT_DURATION: Option<Duration> = Some(Duration::from_secs(5));
/// Extension code set on the GraphQL error built from a `Break` reply whose body is treated as a
/// plain message rather than as a GraphQL response.
const COPROCESSOR_ERROR_EXTENSION: &str = "ERROR";
/// Extension code set on the GraphQL error returned when a coprocessor body cannot be deserialized.
const COPROCESSOR_DESERIALIZATION_ERROR_EXTENSION: &str = "EXTERNAL_DESERIALIZATION_ERROR";
/// Concrete client type used by the registered plugin: a hyper client over an `HttpsConnector`,
/// wrapped in a tower `Timeout` and then in a `RouterBodyConverter`.
type HTTPClientService = RouterBodyConverter<
tower::timeout::Timeout<
hyper::Client<HttpsConnector<HttpConnector<AsyncHyperResolver>>, RouterBody>,
>,
>;
#[async_trait::async_trait]
impl Plugin for CoprocessorPlugin<HTTPClientService> {
type Config = Conf;
async fn new(init: PluginInit<Self::Config>) -> Result<Self, BoxError> {
let client_config = init.config.client.clone().unwrap_or_default();
let mut http_connector =
new_async_http_connector(client_config.dns_resolution_strategy.unwrap_or_default())?;
http_connector.set_nodelay(true);
http_connector.set_keepalive(Some(std::time::Duration::from_secs(60)));
http_connector.enforce_http(false);
let tls_config = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_native_roots()
.with_no_client_auth();
let builder = hyper_rustls::HttpsConnectorBuilder::new()
.with_tls_config(tls_config)
.https_or_http()
.enable_http1();
let experimental_http2 = client_config.experimental_http2.unwrap_or_default();
let connector = if experimental_http2 != Http2Config::Disable {
builder.enable_http2().wrap_connector(http_connector)
} else {
builder.wrap_connector(http_connector)
};
let http_client = RouterBodyConverter {
inner: ServiceBuilder::new()
.layer(TimeoutLayer::new(init.config.timeout))
.service(
hyper::Client::builder()
.http2_only(experimental_http2 == Http2Config::Http2Only)
.pool_idle_timeout(POOL_IDLE_TIMEOUT_DURATION)
.build(connector),
),
};
CoprocessorPlugin::new(http_client, init.config, init.supergraph_sdl)
}
fn router_service(&self, service: router::BoxService) -> router::BoxService {
self.router_service(service)
}
fn supergraph_service(
&self,
service: services::supergraph::BoxService,
) -> services::supergraph::BoxService {
self.supergraph_service(service)
}
fn execution_service(
&self,
service: services::execution::BoxService,
) -> services::execution::BoxService {
self.execution_service(service)
}
fn subgraph_service(&self, name: &str, service: subgraph::BoxService) -> subgraph::BoxService {
self.subgraph_service(name, service)
}
}
// Registers `CoprocessorPlugin<HTTPClientService>` through `register_plugin!` with the
// identifiers "apollo" and "coprocessor".
register_plugin!(
"apollo",
"coprocessor",
CoprocessorPlugin<HTTPClientService>
);
/// Coprocessor plugin state, generic over the HTTP client service `C` used to call the coprocessor.
#[derive(Debug)]
struct CoprocessorPlugin<C>
where
C: Service<http::Request<RouterBody>, Response = http::Response<RouterBody>, Error = BoxError>
+ Clone
+ Send
+ Sync
+ 'static,
<C as tower::Service<http::Request<RouterBody>>>::Future: Send + 'static,
{
/// Client cloned into every stage service built by this plugin.
http_client: C,
/// Plugin configuration: coprocessor URL, response validation flag and per-stage settings.
configuration: Conf,
/// SDL handed to the router, supergraph and execution stages.
sdl: Arc<String>,
}
impl<C> CoprocessorPlugin<C>
where
    C: Service<http::Request<RouterBody>, Response = http::Response<RouterBody>, Error = BoxError>
        + Clone
        + Send
        + Sync
        + 'static,
    <C as tower::Service<http::Request<RouterBody>>>::Future: Send + 'static,
{
    /// Assembles a plugin from an already-built client, its configuration and the SDL.
    ///
    /// This constructor never returns an error.
    fn new(http_client: C, configuration: Conf, sdl: Arc<String>) -> Result<Self, BoxError> {
        let plugin = Self {
            http_client,
            configuration,
            sdl,
        };
        Ok(plugin)
    }

    /// Wraps `service` with the router stage configured for this plugin.
    fn router_service(&self, service: router::BoxService) -> router::BoxService {
        let conf = &self.configuration;
        conf.router.as_service(
            self.http_client.clone(),
            service,
            conf.url.clone(),
            Arc::clone(&self.sdl),
            conf.response_validation,
        )
    }

    /// Wraps `service` with the supergraph stage configured for this plugin.
    fn supergraph_service(
        &self,
        service: services::supergraph::BoxService,
    ) -> services::supergraph::BoxService {
        let conf = &self.configuration;
        conf.supergraph.as_service(
            self.http_client.clone(),
            service,
            conf.url.clone(),
            Arc::clone(&self.sdl),
            conf.response_validation,
        )
    }

    /// Wraps `service` with the execution stage configured for this plugin.
    fn execution_service(
        &self,
        service: services::execution::BoxService,
    ) -> services::execution::BoxService {
        let conf = &self.configuration;
        conf.execution.as_service(
            self.http_client.clone(),
            service,
            conf.url.clone(),
            Arc::clone(&self.sdl),
            conf.response_validation,
        )
    }

    /// Wraps the subgraph `service` called `name` with the `all` subgraph stage configuration.
    fn subgraph_service(&self, name: &str, service: subgraph::BoxService) -> subgraph::BoxService {
        let conf = &self.configuration;
        conf.subgraph.all.as_service(
            self.http_client.clone(),
            service,
            conf.url.clone(),
            name.to_string(),
            conf.response_validation,
        )
    }
}
/// Selects which parts of a router request are sent to the coprocessor.
///
/// `RouterStage::as_service` installs no request layer when this equals `Default::default()`.
#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, JsonSchema)]
#[serde(default, deny_unknown_fields)]
pub(super) struct RouterRequestConf {
/// Optional condition; when set, the coprocessor is only called if it evaluates to `Some(true)`
/// for the request. Never serialized.
#[serde(skip_serializing)]
pub(super) condition: Option<Condition<RouterSelector>>,
/// Send the request headers.
pub(super) headers: bool,
/// Send the request context.
pub(super) context: bool,
/// Send the request body (as a UTF-8 string).
pub(super) body: bool,
/// Send the SDL.
pub(super) sdl: bool,
/// Send the request URI.
pub(super) path: bool,
/// NOTE(review): not read by the router request stage in this file, which always puts the
/// method in the payload — confirm whether this flag is consulted elsewhere.
pub(super) method: bool,
}
/// Selects which parts of a router response are sent to the coprocessor.
///
/// `RouterStage::as_service` installs no response layer when this equals `Default::default()`.
#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, JsonSchema)]
#[serde(default, deny_unknown_fields)]
pub(super) struct RouterResponseConf {
/// Optional condition; when set, the coprocessor is only called if it evaluates to `true`
/// for the response. Never serialized.
#[serde(skip_serializing)]
pub(super) condition: Option<Condition<RouterSelector>>,
/// Send the response headers (first chunk only).
pub(super) headers: bool,
/// Send the response context.
pub(super) context: bool,
/// Send the response body (as a UTF-8 string).
pub(super) body: bool,
/// Send the SDL.
pub(super) sdl: bool,
/// Send the HTTP status code (first chunk only).
pub(super) status_code: bool,
}
/// Selects which parts of a subgraph request are sent to the coprocessor.
///
/// `SubgraphStage::as_service` installs no request layer when this equals `Default::default()`.
#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, JsonSchema)]
#[serde(default, deny_unknown_fields)]
pub(super) struct SubgraphRequestConf {
/// Optional condition; when set, the coprocessor is only called if it evaluates to `Some(true)`
/// for the request. Never serialized.
#[serde(skip_serializing)]
pub(super) condition: Option<Condition<SubgraphSelector>>,
/// Send the request headers.
pub(super) headers: bool,
/// Send the request context.
pub(super) context: bool,
/// Send the GraphQL request body.
pub(super) body: bool,
/// Send the subgraph URI.
pub(super) uri: bool,
/// NOTE(review): not read by the subgraph request stage in this file, which always puts the
/// method in the payload — confirm whether this flag is consulted elsewhere.
pub(super) method: bool,
/// Send the subgraph (service) name.
pub(super) service_name: bool,
/// Send the id of the subgraph request.
pub(super) subgraph_request_id: bool,
}
/// Selects which parts of a subgraph response are sent to the coprocessor.
///
/// `SubgraphStage::as_service` installs no response layer when this equals `Default::default()`.
#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, JsonSchema)]
#[serde(default, deny_unknown_fields)]
pub(super) struct SubgraphResponseConf {
/// Optional condition; when set, the coprocessor is only called if it evaluates to `true`
/// for the response. Never serialized.
#[serde(skip_serializing)]
pub(super) condition: Option<Condition<SubgraphSelector>>,
/// Send the response headers.
pub(super) headers: bool,
/// Send the response context.
pub(super) context: bool,
/// Send the GraphQL response body.
pub(super) body: bool,
/// Send the subgraph (service) name.
pub(super) service_name: bool,
/// Send the HTTP status code.
pub(super) status_code: bool,
/// Send the id of the subgraph request this response belongs to.
pub(super) subgraph_request_id: bool,
}
/// Configuration of the coprocessor plugin.
#[derive(Clone, Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
struct Conf {
/// URL of the coprocessor; passed to every stage as the address to call.
url: String,
/// Optional HTTP client settings; its DNS resolution strategy and HTTP/2 mode are read when
/// the client is built.
client: Option<Client>,
/// Timeout applied to coprocessor calls through the client's timeout layer. Parsed with
/// `humantime_serde`; defaults to `default_timeout()`.
#[serde(deserialize_with = "humantime_serde::deserialize")]
#[schemars(with = "String", default = "default_timeout")]
#[serde(default = "default_timeout")]
timeout: Duration,
/// Whether coprocessor response bodies are validated; defaults to `default_response_validation()`.
#[serde(default = "default_response_validation")]
response_validation: bool,
/// Router stage settings.
#[serde(default)]
router: RouterStage,
/// Supergraph stage settings.
#[serde(default)]
supergraph: supergraph::SupergraphStage,
/// Execution stage settings.
#[serde(default)]
execution: execution::ExecutionStage,
/// Subgraph stage settings.
#[serde(default)]
subgraph: SubgraphStages,
}
/// Default for `Conf::timeout`: `DEFAULT_EXTERNALIZATION_TIMEOUT`.
fn default_timeout() -> Duration {
DEFAULT_EXTERNALIZATION_TIMEOUT
}
/// Default for `Conf::response_validation`: validation is enabled.
fn default_response_validation() -> bool {
true
}
/// Records `duration` (in seconds) in the `apollo.router.operations.coprocessor.duration`
/// histogram, tagged with the stringified `stage` as `coprocessor.stage`.
fn record_coprocessor_duration(stage: PipelineStep, duration: Duration) {
f64_histogram!(
"apollo.router.operations.coprocessor.duration",
"Time spent waiting for the coprocessor to answer, in seconds",
duration.as_secs_f64(),
coprocessor.stage = stage.to_string()
);
}
/// Request and response settings of the router stage.
#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, JsonSchema)]
#[serde(default)]
pub(super) struct RouterStage {
/// What to send to the coprocessor for router requests.
pub(super) request: RouterRequestConf,
/// What to send to the coprocessor for router responses.
pub(super) response: RouterResponseConf,
}
impl RouterStage {
    /// Wraps `service` with the coprocessor layers required by this stage configuration.
    ///
    /// A request layer is installed when `self.request` differs from its default, and a response
    /// layer when `self.response` differs from its default. Each layer calls the coprocessor at
    /// `coprocessor_url` through a clone of `http_client`, logs a failure, and increments the
    /// `apollo.router.operations.coprocessor` counter with the stage and whether it succeeded.
    /// The resulting service is instrumented with an `EXTERNAL_SPAN_NAME` span.
    pub(crate) fn as_service<C>(
        &self,
        http_client: C,
        service: router::BoxService,
        coprocessor_url: String,
        sdl: Arc<String>,
        response_validation: bool,
    ) -> router::BoxService
    where
        C: Service<
                http::Request<RouterBody>,
                Response = http::Response<RouterBody>,
                Error = BoxError,
            > + Clone
            + Send
            + Sync
            + 'static,
        <C as tower::Service<http::Request<RouterBody>>>::Future: Send + 'static,
    {
        // Span factory used to instrument the wrapped service.
        fn external_service_span() -> impl Fn(&router::Request) -> tracing::Span + Clone {
            move |_request: &router::Request| {
                tracing::info_span!(
                    EXTERNAL_SPAN_NAME,
                    "external service" = stringify!(router::Request),
                    "otel.kind" = "INTERNAL"
                )
            }
        }

        let request_is_configured = self.request != Default::default();
        let request_layer = request_is_configured.then_some({
            let stage_config = self.request.clone();
            let stage_url = coprocessor_url.clone();
            let stage_client = http_client.clone();
            let stage_sdl = sdl.clone();
            OneShotAsyncCheckpointLayer::new(move |request: router::Request| {
                // Every invocation needs its own owned copies for the `async move` block.
                let stage_config = stage_config.clone();
                let stage_url = stage_url.clone();
                let stage_client = stage_client.clone();
                let stage_sdl = stage_sdl.clone();
                async move {
                    let outcome = process_router_request_stage(
                        stage_client,
                        stage_url,
                        stage_sdl,
                        request,
                        stage_config,
                        response_validation,
                    )
                    .await;
                    if let Err(error) = &outcome {
                        tracing::error!(
                            "external extensibility: router request stage error: {error}"
                        );
                    }
                    let succeeded = outcome.is_ok();
                    u64_counter!(
                        "apollo.router.operations.coprocessor",
                        "Total operations with co-processors enabled",
                        1,
                        "coprocessor.stage" = PipelineStep::RouterRequest,
                        "coprocessor.succeeded" = succeeded
                    );
                    outcome
                }
            })
        });

        let response_is_configured = self.response != Default::default();
        let response_layer = response_is_configured.then_some({
            let stage_config = self.response.clone();
            MapFutureLayer::new(move |fut| {
                let stage_sdl = sdl.clone();
                let stage_url = coprocessor_url.clone();
                let stage_client = http_client.clone();
                let stage_config = stage_config.clone();
                async move {
                    // A failure of the wrapped service is returned as-is, without calling the
                    // coprocessor or touching the counter.
                    let response: router::Response = fut.await?;
                    let outcome = process_router_response_stage(
                        stage_client,
                        stage_url,
                        stage_sdl,
                        response,
                        stage_config,
                        response_validation,
                    )
                    .await;
                    if let Err(error) = &outcome {
                        tracing::error!(
                            "external extensibility: router response stage error: {error}"
                        );
                    }
                    let succeeded = outcome.is_ok();
                    u64_counter!(
                        "apollo.router.operations.coprocessor",
                        "Total operations with co-processors enabled",
                        1,
                        "coprocessor.stage" = PipelineStep::RouterResponse,
                        "coprocessor.succeeded" = succeeded
                    );
                    outcome
                }
            })
        });

        ServiceBuilder::new()
            .instrument(external_service_span())
            .option_layer(request_layer)
            .option_layer(response_layer)
            .service(service)
            .boxed()
    }
}
/// Subgraph-level coprocessor settings.
#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, JsonSchema)]
#[serde(default, deny_unknown_fields)]
pub(super) struct SubgraphStages {
/// Stage settings used for every subgraph service built by the plugin.
#[serde(default)]
pub(super) all: SubgraphStage,
}
/// Request and response settings of a subgraph stage.
#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, JsonSchema)]
#[serde(default, deny_unknown_fields)]
pub(super) struct SubgraphStage {
/// What to send to the coprocessor for subgraph requests.
#[serde(default)]
pub(super) request: SubgraphRequestConf,
/// What to send to the coprocessor for subgraph responses.
#[serde(default)]
pub(super) response: SubgraphResponseConf,
}
impl SubgraphStage {
    /// Wraps the subgraph `service` with the coprocessor layers required by this configuration.
    ///
    /// A request layer is installed when `self.request` differs from its default, and a response
    /// layer when `self.response` differs from its default. Each layer calls the coprocessor at
    /// `coprocessor_url` through a clone of `http_client`, logs a failure, and increments the
    /// `apollo.router.operations.coprocessor` counter with the stage and whether it succeeded.
    /// The resulting service is instrumented with an `EXTERNAL_SPAN_NAME` span.
    pub(crate) fn as_service<C>(
        &self,
        http_client: C,
        service: subgraph::BoxService,
        coprocessor_url: String,
        service_name: String,
        response_validation: bool,
    ) -> subgraph::BoxService
    where
        C: Service<
                http::Request<RouterBody>,
                Response = http::Response<RouterBody>,
                Error = BoxError,
            > + Clone
            + Send
            + Sync
            + 'static,
        <C as tower::Service<http::Request<RouterBody>>>::Future: Send + 'static,
    {
        // Span factory used to instrument the wrapped service.
        fn external_service_span() -> impl Fn(&subgraph::Request) -> tracing::Span + Clone {
            move |_request: &subgraph::Request| {
                tracing::info_span!(
                    EXTERNAL_SPAN_NAME,
                    "external service" = stringify!(subgraph::Request),
                    "otel.kind" = "INTERNAL"
                )
            }
        }

        let request_is_configured = self.request != Default::default();
        let request_layer = request_is_configured.then_some({
            let stage_config = self.request.clone();
            let stage_client = http_client.clone();
            let stage_url = coprocessor_url.clone();
            let stage_subgraph = service_name.clone();
            OneShotAsyncCheckpointLayer::new(move |request: subgraph::Request| {
                // Every invocation needs its own owned copies for the `async move` block.
                let stage_client = stage_client.clone();
                let stage_url = stage_url.clone();
                let stage_subgraph = stage_subgraph.clone();
                let stage_config = stage_config.clone();
                async move {
                    let outcome = process_subgraph_request_stage(
                        stage_client,
                        stage_url,
                        stage_subgraph,
                        request,
                        stage_config,
                        response_validation,
                    )
                    .await;
                    if let Err(error) = &outcome {
                        tracing::error!(
                            "external extensibility: subgraph request stage error: {error}"
                        );
                    }
                    let succeeded = outcome.is_ok();
                    u64_counter!(
                        "apollo.router.operations.coprocessor",
                        "Total operations with co-processors enabled",
                        1,
                        "coprocessor.stage" = PipelineStep::SubgraphRequest,
                        "coprocessor.succeeded" = succeeded
                    );
                    outcome
                }
            })
        });

        let response_is_configured = self.response != Default::default();
        let response_layer = response_is_configured.then_some({
            let stage_config = self.response.clone();
            MapFutureLayer::new(move |fut| {
                let stage_client = http_client.clone();
                let stage_url = coprocessor_url.clone();
                let stage_config = stage_config.clone();
                let stage_subgraph = service_name.clone();
                async move {
                    // A failure of the wrapped service is returned as-is, without calling the
                    // coprocessor or touching the counter.
                    let response: subgraph::Response = fut.await?;
                    let outcome = process_subgraph_response_stage(
                        stage_client,
                        stage_url,
                        stage_subgraph,
                        response,
                        stage_config,
                        response_validation,
                    )
                    .await;
                    if let Err(error) = &outcome {
                        tracing::error!(
                            "external extensibility: subgraph response stage error: {error}"
                        );
                    }
                    let succeeded = outcome.is_ok();
                    u64_counter!(
                        "apollo.router.operations.coprocessor",
                        "Total operations with co-processors enabled",
                        1,
                        "coprocessor.stage" = PipelineStep::SubgraphResponse,
                        "coprocessor.succeeded" = succeeded
                    );
                    outcome
                }
            })
        });

        ServiceBuilder::new()
            .instrument(external_service_span())
            .option_layer(request_layer)
            .option_layer(response_layer)
            .service(service)
            .boxed()
    }
}
/// Runs the `RouterRequest` coprocessor stage for one router request.
///
/// The request is returned untouched when `request_config.condition` is set and does not
/// evaluate to `Some(true)`. Otherwise the parts selected by `request_config` are externalized,
/// sent to `coprocessor_url` through `http_client`, and the validated reply is applied.
///
/// Returns `ControlFlow::Break` with a router response built from the reply when the coprocessor
/// answers with `Control::Break`, and `ControlFlow::Continue` with the (possibly updated)
/// request otherwise.
///
/// # Errors
/// Propagates failures from reading the body, externalizing or internalizing headers, calling
/// the coprocessor, validating its reply, resolving the HTTP status, building the response and
/// iterating the returned context.
async fn process_router_request_stage<C>(
http_client: C,
coprocessor_url: String,
sdl: Arc<String>,
mut request: router::Request,
mut request_config: RouterRequestConf,
response_validation: bool,
) -> Result<ControlFlow<router::Response, router::Request>, BoxError>
where
C: Service<http::Request<RouterBody>, Response = http::Response<RouterBody>, Error = BoxError>
+ Clone
+ Send
+ Sync
+ 'static,
<C as tower::Service<http::Request<RouterBody>>>::Future: Send + 'static,
{
// No condition means "always run"; a condition must evaluate to exactly `Some(true)`.
let should_be_executed = request_config
.condition
.as_mut()
.map(|c| c.evaluate_request(&request) == Some(true))
.unwrap_or(true);
if !should_be_executed {
return Ok(ControlFlow::Continue(request));
}
// Split the request and read the whole body so it can be sent and, if needed, reused.
let (parts, body) = request.router_request.into_parts();
let bytes = get_body_bytes(body).await?;
let headers_to_send = request_config
.headers
.then(|| externalize_header_map(&parts.headers))
.transpose()?;
// A body that is not valid UTF-8 is not an error here: `unwrap_or_default` turns the
// conversion failure into `None`, so no body is sent.
let body_to_send = request_config
.body
.then(|| String::from_utf8(bytes.to_vec()))
.transpose()
.unwrap_or_default();
let path_to_send = request_config.path.then(|| parts.uri.to_string());
let context_to_send = request_config.context.then(|| request.context.clone());
let sdl_to_send = request_config.sdl.then(|| sdl.clone().to_string());
// The method is always part of the payload; `request_config.method` is not consulted.
let payload = Externalizable::router_builder()
.stage(PipelineStep::RouterRequest)
.control(Control::default())
.id(request.context.id.clone())
.and_headers(headers_to_send)
.and_body(body_to_send)
.and_context(context_to_send)
.and_sdl(sdl_to_send)
.and_path(path_to_send)
.method(parts.method.to_string())
.build();
tracing::debug!(?payload, "externalized output");
// NOTE(review): the guard presumably marks the request as active while the coprocessor is
// awaited — confirm against `enter_active_request`. It is released before recording metrics.
let guard = request.context.enter_active_request();
let start = Instant::now();
let co_processor_result = payload.call(http_client, &coprocessor_url).await;
let duration = start.elapsed();
drop(guard);
record_coprocessor_duration(PipelineStep::RouterRequest, duration);
tracing::debug!(?co_processor_result, "co-processor returned");
let mut co_processor_output = co_processor_result?;
validate_coprocessor_output(&co_processor_output, PipelineStep::RouterRequest)?;
// `validate_coprocessor_output` rejects a missing `control` when the stage name ends with
// "Request", which is what this `expect` relies on.
let control = co_processor_output.control.expect("validated above; qed");
if matches!(control, Control::Break(_)) {
let code = control.get_http_status()?;
// Try to read the returned body as JSON; a missing or unparsable body becomes `Null`.
let body_as_value = co_processor_output
.body
.as_ref()
.and_then(|b| serde_json::from_str(b).ok())
.unwrap_or(Value::Null);
// `Null` means the body is not usable as a GraphQL response: the raw body text (or an empty
// string) is reported as the message of a single error instead.
let graphql_response = match body_as_value {
Value::Null => graphql::Response::builder()
.errors(vec![
Error::builder()
.message(co_processor_output.body.take().unwrap_or_default())
.extension_code(COPROCESSOR_ERROR_EXTENSION)
.build(),
])
.build(),
_ => deserialize_coprocessor_response(body_as_value, response_validation),
};
let res = router::Response::builder()
.errors(graphql_response.errors)
.extensions(graphql_response.extensions)
.status_code(code)
.context(request.context);
// `label` and `data` are only set on the builder when present.
let mut res = match (graphql_response.label, graphql_response.data) {
(Some(label), Some(data)) => res.label(label).data(data).build()?,
(Some(label), None) => res.label(label).build()?,
(None, Some(data)) => res.data(data).build()?,
(None, None) => res.build()?,
};
// Returned headers replace the response headers wholesale.
if let Some(headers) = co_processor_output.headers {
*res.response.headers_mut() = internalize_header_map(headers)?;
}
if let Some(context) = co_processor_output.context {
for (key, value) in context.try_into_iter()? {
res.context.upsert_json_value(key, move |_current| value);
}
}
return Ok(ControlFlow::Break(res));
}
// Continue: use the coprocessor's body when it returned one, otherwise the original bytes.
let new_body = match co_processor_output.body {
Some(bytes) => RouterBody::from(bytes),
None => RouterBody::from(bytes),
};
request.router_request = http::Request::from_parts(parts, new_body.into_inner());
if let Some(context) = co_processor_output.context {
for (key, value) in context.try_into_iter()? {
request
.context
.upsert_json_value(key, move |_current| value);
}
}
// Returned headers replace the request headers wholesale.
if let Some(headers) = co_processor_output.headers {
*request.router_request.headers_mut() = internalize_header_map(headers)?;
}
Ok(ControlFlow::Continue(request))
}
/// Runs the `RouterResponse` coprocessor stage for one router response.
///
/// The response is returned untouched when `response_config.condition` is set and evaluates to
/// `false`. Otherwise the first body chunk is sent to the coprocessor together with the parts
/// selected by `response_config`, the reply is applied, and every remaining chunk of the body
/// stream is lazily sent through the coprocessor as well.
///
/// # Errors
/// Fails when the first body chunk is missing or errored, when the first chunk is not valid
/// UTF-8 while the body is requested, and on failures from header conversion, the coprocessor
/// call, reply validation, status resolution and context iteration. Errors on later chunks
/// surface through the returned body stream.
async fn process_router_response_stage<C>(
http_client: C,
coprocessor_url: String,
sdl: Arc<String>,
mut response: router::Response,
response_config: RouterResponseConf,
// Unused in this function: the validation flag is accepted but never read here.
_response_validation: bool, ) -> Result<router::Response, BoxError>
where
C: Service<http::Request<RouterBody>, Response = http::Response<RouterBody>, Error = BoxError>
+ Clone
+ Send
+ Sync
+ 'static,
<C as tower::Service<http::Request<RouterBody>>>::Future: Send + 'static,
{
// No condition means "always run".
let should_be_executed = response_config
.condition
.as_ref()
.map(|c| c.evaluate_response(&response))
.unwrap_or(true);
if !should_be_executed {
return Ok(response);
}
let (parts, body) = response.response.into_parts();
// Split the body stream into its first chunk and the stream of remaining chunks.
let (first, rest): (
Option<Result<Bytes, hyper::Error>>,
crate::services::router::Body,
) = body.into_future().await;
// A missing first chunk and a first chunk that errored are treated the same way.
let opt_first: Option<Bytes> = first.and_then(|f| f.ok());
let bytes = match opt_first {
Some(b) => b,
None => {
tracing::error!(
"Coprocessor cannot convert body into future due to problem with first part"
);
return Err(BoxError::from(
"Coprocessor cannot convert body into future due to problem with first part",
));
}
};
// Build the payload for the first chunk from the parts selected by the configuration.
let headers_to_send = response_config
.headers
.then(|| externalize_header_map(&parts.headers))
.transpose()?;
let body_to_send = response_config
.body
.then(|| std::str::from_utf8(&bytes).map(|s| s.to_string()))
.transpose()?;
let status_to_send = response_config.status_code.then(|| parts.status.as_u16());
let context_to_send = response_config.context.then(|| response.context.clone());
let sdl_to_send = response_config.sdl.then(|| sdl.clone().to_string());
let payload = Externalizable::router_builder()
.stage(PipelineStep::RouterResponse)
.id(response.context.id.clone())
.and_headers(headers_to_send)
.and_body(body_to_send)
.and_context(context_to_send)
.and_status_code(status_to_send)
.and_sdl(sdl_to_send.clone())
.build();
tracing::debug!(?payload, "externalized output");
// NOTE(review): the guard presumably marks the request as active while the coprocessor is
// awaited — confirm against `enter_active_request`.
let guard = response.context.enter_active_request();
let start = Instant::now();
let co_processor_result = payload.call(http_client.clone(), &coprocessor_url).await;
let duration = start.elapsed();
drop(guard);
record_coprocessor_duration(PipelineStep::RouterResponse, duration);
tracing::debug!(?co_processor_result, "co-processor returned");
let co_processor_output = co_processor_result?;
validate_coprocessor_output(&co_processor_output, PipelineStep::RouterResponse)?;
// Use the coprocessor's body when it returned one, otherwise the original first chunk.
let new_body = match co_processor_output.body {
Some(bytes) => RouterBody::from(bytes),
None => RouterBody::from(bytes),
};
response.response = http::Response::from_parts(parts, new_body.into_inner());
// `control` is optional for this stage; when present it sets the HTTP status.
if let Some(control) = co_processor_output.control {
*response.response.status_mut() = control.get_http_status()?
}
if let Some(context) = co_processor_output.context {
for (key, value) in context.try_into_iter()? {
response
.context
.upsert_json_value(key, move |_current| value);
}
}
// Returned headers replace the response headers wholesale.
if let Some(headers) = co_processor_output.headers {
*response.response.headers_mut() = internalize_header_map(headers)?;
}
// Split the updated response again so its body can be chained with the remaining chunks.
let (parts, body) = response.response.into_parts();
let context = response.context.clone();
let map_context = response.context.clone();
// Each remaining chunk is sent to the coprocessor when the stream is polled.
let mapped_stream = rest
.map_err(BoxError::from)
.and_then(move |deferred_response| {
let generator_client = http_client.clone();
let generator_coprocessor_url = coprocessor_url.clone();
let generator_map_context = map_context.clone();
let generator_sdl_to_send = sdl_to_send.clone();
let generator_id = map_context.id.clone();
async move {
let bytes = deferred_response.to_vec();
let body_to_send = response_config
.body
.then(|| String::from_utf8(bytes.clone()))
.transpose()?;
let context_to_send = response_config
.context
.then(|| generator_map_context.clone());
// Headers and status code are not included in the payload for these later chunks,
// whatever the configuration says.
let payload = Externalizable::router_builder()
.stage(PipelineStep::RouterResponse)
.id(generator_id)
.and_body(body_to_send)
.and_context(context_to_send)
.and_sdl(generator_sdl_to_send)
.build();
tracing::debug!(?payload, "externalized output");
// No duration is recorded for these per-chunk calls.
let guard = generator_map_context.enter_active_request();
let co_processor_result = payload
.call(generator_client, &generator_coprocessor_url)
.await;
drop(guard);
tracing::debug!(?co_processor_result, "co-processor returned");
let co_processor_output = co_processor_result?;
validate_coprocessor_output(&co_processor_output, PipelineStep::RouterResponse)?;
// Use the coprocessor's body when it returned one, otherwise the original chunk.
let final_bytes: Bytes = match co_processor_output.body {
Some(bytes) => bytes.into(),
None => bytes.into(),
};
if let Some(context) = co_processor_output.context {
for (key, value) in context.try_into_iter()? {
generator_map_context.upsert_json_value(key, move |_current| value);
}
}
Ok(final_bytes)
}
});
// Read the first (possibly replaced) body back into bytes and emit it ahead of the mapped
// remaining chunks.
let bytes = get_body_bytes(body).await.map_err(BoxError::from);
let final_stream = once(ready(bytes)).chain(mapped_stream).boxed();
Ok(router::Response {
context,
response: http::Response::from_parts(
parts,
RouterBody::wrap_stream(final_stream).into_inner(),
),
})
}
/// Runs the `SubgraphRequest` coprocessor stage for one subgraph request.
///
/// The request is returned untouched when `request_config.condition` is set and does not
/// evaluate to `Some(true)`. Otherwise the parts selected by `request_config` are externalized,
/// sent to `coprocessor_url` through `http_client`, and the validated reply is applied.
///
/// Returns `ControlFlow::Break` with a subgraph response built from the reply when the
/// coprocessor answers with `Control::Break`, and `ControlFlow::Continue` with the (possibly
/// updated) request otherwise.
///
/// # Errors
/// Propagates failures from header conversion, body (de)serialization, the coprocessor call,
/// reply validation, status resolution, response building, context iteration and URI parsing.
async fn process_subgraph_request_stage<C>(
http_client: C,
coprocessor_url: String,
service_name: String,
mut request: subgraph::Request,
mut request_config: SubgraphRequestConf,
response_validation: bool,
) -> Result<ControlFlow<subgraph::Response, subgraph::Request>, BoxError>
where
C: Service<http::Request<RouterBody>, Response = http::Response<RouterBody>, Error = BoxError>
+ Clone
+ Send
+ Sync
+ 'static,
<C as tower::Service<http::Request<RouterBody>>>::Future: Send + 'static,
{
// No condition means "always run"; a condition must evaluate to exactly `Some(true)`.
let should_be_executed = request_config
.condition
.as_mut()
.map(|c| c.evaluate_request(&request) == Some(true))
.unwrap_or(true);
if !should_be_executed {
return Ok(ControlFlow::Continue(request));
}
let (parts, body) = request.subgraph_request.into_parts();
let headers_to_send = request_config
.headers
.then(|| externalize_header_map(&parts.headers))
.transpose()?;
let body_to_send = request_config
.body
.then(|| serde_json_bytes::to_value(&body))
.transpose()?;
let context_to_send = request_config.context.then(|| request.context.clone());
let uri = request_config.uri.then(|| parts.uri.to_string());
// Keep the name for a possible `Break` response; `service_name` itself is only sent when
// the configuration asks for it.
let subgraph_name = service_name.clone();
let service_name = request_config.service_name.then_some(service_name);
let subgraph_request_id = request_config
.subgraph_request_id
.then_some(request.id.clone());
// The method is always part of the payload; `request_config.method` is not consulted.
let payload = Externalizable::subgraph_builder()
.stage(PipelineStep::SubgraphRequest)
.control(Control::default())
.id(request.context.id.clone())
.and_headers(headers_to_send)
.and_body(body_to_send)
.and_context(context_to_send)
.method(parts.method.to_string())
.and_service_name(service_name)
.and_uri(uri)
.and_subgraph_request_id(subgraph_request_id)
.build();
tracing::debug!(?payload, "externalized output");
// NOTE(review): the guard presumably marks the request as active while the coprocessor is
// awaited — confirm against `enter_active_request`.
let guard = request.context.enter_active_request();
let start = Instant::now();
let co_processor_result = payload.call(http_client, &coprocessor_url).await;
let duration = start.elapsed();
drop(guard);
record_coprocessor_duration(PipelineStep::SubgraphRequest, duration);
tracing::debug!(?co_processor_result, "co-processor returned");
let co_processor_output = co_processor_result?;
validate_coprocessor_output(&co_processor_output, PipelineStep::SubgraphRequest)?;
// `validate_coprocessor_output` rejects a missing `control` when the stage name ends with
// "Request", which is what this `expect` relies on.
let control = co_processor_output.control.expect("validated above; qed");
if matches!(control, Control::Break(_)) {
let code = control.get_http_status()?;
let res = {
// A JSON string body becomes the message of a single error; any other value (including
// a missing body, mapped to `Null`) is deserialized as a GraphQL response.
let graphql_response = match co_processor_output.body.unwrap_or(Value::Null) {
Value::String(s) => graphql::Response::builder()
.errors(vec![
Error::builder()
.message(s.as_str().to_owned())
.extension_code(COPROCESSOR_ERROR_EXTENSION)
.build(),
])
.build(),
value => deserialize_coprocessor_response(value, response_validation),
};
let mut http_response = http::Response::builder()
.status(code)
.body(graphql_response)?;
// Returned headers replace the response headers wholesale.
if let Some(headers) = co_processor_output.headers {
*http_response.headers_mut() = internalize_header_map(headers)?;
}
let subgraph_response = subgraph::Response {
response: http_response,
context: request.context,
subgraph_name: Some(subgraph_name),
id: request.id,
};
if let Some(context) = co_processor_output.context {
for (key, value) in context.try_into_iter()? {
subgraph_response
.context
.upsert_json_value(key, move |_current| value);
}
}
subgraph_response
};
return Ok(ControlFlow::Break(res));
}
// Continue: use the coprocessor's body when it returned one, otherwise the original body.
let new_body: graphql::Request = match co_processor_output.body {
Some(value) => serde_json_bytes::from_value(value)?,
None => body,
};
request.subgraph_request = http::Request::from_parts(parts, new_body);
if let Some(context) = co_processor_output.context {
for (key, value) in context.try_into_iter()? {
request
.context
.upsert_json_value(key, move |_current| value);
}
}
// Returned headers replace the request headers wholesale.
if let Some(headers) = co_processor_output.headers {
*request.subgraph_request.headers_mut() = internalize_header_map(headers)?;
}
if let Some(uri) = co_processor_output.uri {
*request.subgraph_request.uri_mut() = uri.parse()?;
}
Ok(ControlFlow::Continue(request))
}
/// Runs the `SubgraphResponse` coprocessor stage for one subgraph response.
///
/// The response is returned untouched when `response_config.condition` is set and evaluates to
/// `false`. Otherwise the parts selected by `response_config` are externalized, sent to
/// `coprocessor_url` through `http_client`, and the validated reply is applied to the response.
///
/// # Errors
/// Propagates failures from header conversion, body serialization, the coprocessor call, reply
/// validation, `handle_graphql_response`, status resolution and context iteration.
async fn process_subgraph_response_stage<C>(
http_client: C,
coprocessor_url: String,
service_name: String,
mut response: subgraph::Response,
response_config: SubgraphResponseConf,
response_validation: bool,
) -> Result<subgraph::Response, BoxError>
where
C: Service<http::Request<RouterBody>, Response = http::Response<RouterBody>, Error = BoxError>
+ Clone
+ Send
+ Sync
+ 'static,
<C as tower::Service<http::Request<RouterBody>>>::Future: Send + 'static,
{
// No condition means "always run".
let should_be_executed = response_config
.condition
.as_ref()
.map(|c| c.evaluate_response(&response))
.unwrap_or(true);
if !should_be_executed {
return Ok(response);
}
let (parts, body) = response.response.into_parts();
let headers_to_send = response_config
.headers
.then(|| externalize_header_map(&parts.headers))
.transpose()?;
let status_to_send = response_config.status_code.then(|| parts.status.as_u16());
let body_to_send = response_config
.body
.then(|| serde_json_bytes::to_value(&body))
.transpose()?;
let context_to_send = response_config.context.then(|| response.context.clone());
let service_name = response_config.service_name.then_some(service_name);
let subgraph_request_id = response_config
.subgraph_request_id
.then_some(response.id.clone());
let payload = Externalizable::subgraph_builder()
.stage(PipelineStep::SubgraphResponse)
.id(response.context.id.clone())
.and_headers(headers_to_send)
.and_body(body_to_send)
.and_context(context_to_send)
.and_status_code(status_to_send)
.and_service_name(service_name)
.and_subgraph_request_id(subgraph_request_id)
.build();
tracing::debug!(?payload, "externalized output");
// NOTE(review): the guard presumably marks the request as active while the coprocessor is
// awaited — confirm against `enter_active_request`.
let guard = response.context.enter_active_request();
let start = Instant::now();
let co_processor_result = payload.call(http_client, &coprocessor_url).await;
let duration = start.elapsed();
drop(guard);
record_coprocessor_duration(PipelineStep::SubgraphResponse, duration);
tracing::debug!(?co_processor_result, "co-processor returned");
let co_processor_output = co_processor_result?;
validate_coprocessor_output(&co_processor_output, PipelineStep::SubgraphResponse)?;
// Whether the original body was minimally valid decides how strictly the reply is checked.
let incoming_payload_was_valid = was_incoming_payload_valid(&body, response_config.body);
let new_body = handle_graphql_response(
body,
co_processor_output.body,
response_validation,
incoming_payload_was_valid,
)?;
response.response = http::Response::from_parts(parts, new_body);
// `control` is optional for this stage; when present it sets the HTTP status.
if let Some(control) = co_processor_output.control {
*response.response.status_mut() = control.get_http_status()?
}
if let Some(context) = co_processor_output.context {
for (key, value) in context.try_into_iter()? {
response
.context
.upsert_json_value(key, move |_current| value);
}
}
// Returned headers replace the response headers wholesale.
if let Some(headers) = co_processor_output.headers {
*response.response.headers_mut() = internalize_header_map(headers)?;
}
Ok(response)
}
/// Checks that a coprocessor reply is usable for `expected_step`.
///
/// The checks run in order and the first failure wins:
/// 1. the reply's version must equal `EXTERNALIZABLE_VERSION`;
/// 2. the reply's stage must equal the string form of `expected_step`;
/// 3. a stage whose name ends with `Request` must carry a `control` value.
///
/// # Errors
/// Returns a `BoxError` whose message describes the first failed check.
fn validate_coprocessor_output<T>(
    co_processor_output: &Externalizable<T>,
    expected_step: PipelineStep,
) -> Result<(), BoxError> {
    let expected_stage = expected_step.to_string();

    let problem = if co_processor_output.version != EXTERNALIZABLE_VERSION {
        Some(format!(
            "Coprocessor returned the wrong version: expected `{}` found `{}`",
            EXTERNALIZABLE_VERSION, co_processor_output.version,
        ))
    } else if co_processor_output.stage != expected_stage {
        Some(format!(
            "Coprocessor returned the wrong stage: expected `{}` found `{}`",
            expected_stage, co_processor_output.stage,
        ))
    } else if co_processor_output.stage.ends_with("Request")
        && co_processor_output.control.is_none()
    {
        Some(format!(
            "Coprocessor response is missing the `control` parameter in the `{}` stage. You must specify \"control\": \"Continue\" or \"control\": \"Break\"",
            co_processor_output.stage,
        ))
    } else {
        None
    };

    match problem {
        Some(message) => Err(BoxError::from(message)),
        None => Ok(()),
    }
}
/// Converts the header map returned by a coprocessor into an `http::HeaderMap`.
///
/// Every value of an entry is appended under its header name, so multi-valued headers are
/// preserved. `content-length` entries are dropped. The name is matched case-insensitively
/// because `HeaderName::from_str` normalises case, so `Content-Length` would otherwise slip
/// past the filter and end up in the output as the very header being filtered.
///
/// Entries with an empty value list contribute nothing and their name is not parsed.
///
/// # Errors
/// Returns an error when a header name (of an entry with at least one value) or a header value
/// is not valid.
pub(super) fn internalize_header_map(
    input: HashMap<String, Vec<String>>,
) -> Result<HeaderMap<HeaderValue>, BoxError> {
    // Better than nothing, even though it does not account for the number of values per name.
    let mut output = HeaderMap::with_capacity(input.len());
    for (name, values) in input {
        if name.eq_ignore_ascii_case(header::CONTENT_LENGTH.as_str()) || values.is_empty() {
            continue;
        }
        // Parse the name once per entry rather than once per value.
        let key = HeaderName::from_str(&name)?;
        for value in values {
            output.append(key.clone(), HeaderValue::from_str(&value)?);
        }
    }
    Ok(output)
}
/// Carries over onto `new_body` the state of the original response that must survive a
/// coprocessor round-trip.
///
/// `subscribed` and `created_at` are always copied from `original_response_body`. In addition,
/// when the original `data` was an explicit JSON null, the new body has no `data` and the
/// response is marked as subscribed, `data` is restored to an explicit null.
fn apply_response_post_processing(
    mut new_body: graphql::Response,
    original_response_body: &graphql::Response,
) -> graphql::Response {
    let original_data_was_null = original_response_body.data == Some(Value::Null);

    new_body.created_at = original_response_body.created_at;
    new_body.subscribed = original_response_body.subscribed;

    let is_subscribed = new_body.subscribed == Some(true);
    if original_data_was_null && new_body.data.is_none() && is_subscribed {
        new_body.data = Some(Value::Null);
    }

    new_body
}
/// Returns `true` when `response` carries `data` or at least one error, and `false` when it
/// has neither.
pub(super) fn is_graphql_response_minimally_valid(response: &graphql::Response) -> bool {
    let has_neither_data_nor_errors = response.data.is_none() && response.errors.is_empty();
    !has_neither_data_nor_errors
}
/// Tells whether the payload that was handed to the coprocessor counts as valid.
///
/// When the body was not sent (`body_sent == false`) the payload is always considered valid;
/// otherwise the result is `is_graphql_response_minimally_valid(response)`.
pub(super) fn was_incoming_payload_valid(response: &graphql::Response, body_sent: bool) -> bool {
    !body_sent || is_graphql_response_minimally_valid(response)
}
pub(super) fn deserialize_coprocessor_response(
body_as_value: Value,
response_validation: bool,
) -> graphql::Response {
if response_validation {
graphql::Response::from_value(body_as_value).unwrap_or_else(|error| {
graphql::Response::builder()
.errors(vec![
Error::builder()
.message(format!(
"couldn't deserialize coprocessor output body: {error}"
))
.extension_code(COPROCESSOR_DESERIALIZATION_ERROR_EXTENSION)
.build(),
])
.build()
})
} else {
serde_json_bytes::from_value(body_as_value).unwrap_or_else(|error| {
graphql::Response::builder()
.errors(vec![
Error::builder()
.message(format!(
"couldn't deserialize coprocessor output body: {error}"
))
.extension_code(COPROCESSOR_DESERIALIZATION_ERROR_EXTENSION)
.build(),
])
.build()
})
}
}
/// Chooses the GraphQL body to use after a coprocessor response stage.
///
/// * No coprocessor body: the original body is kept.
/// * Validation applies (`response_validation` is on and the incoming payload was valid): the
///   coprocessor body must pass `graphql::Response::from_value`, otherwise an error is returned.
/// * Validation does not apply: the coprocessor body is deserialized leniently, and the
///   original body is kept if that fails.
///
/// A body taken from the coprocessor always goes through `apply_response_post_processing`.
///
/// # Errors
/// Only when validation applies and the coprocessor body is rejected.
pub(super) fn handle_graphql_response(
    original_response_body: graphql::Response,
    copro_response_body: Option<Value>,
    response_validation: bool,
    incoming_payload_was_valid: bool,
) -> Result<graphql::Response, BoxError> {
    // Compile-time switch: when `true`, validation is skipped for payloads that were already
    // invalid before reaching the coprocessor.
    const ENABLE_CONDITIONAL_VALIDATION: bool = true;

    let skip_validation =
        !response_validation || (ENABLE_CONDITIONAL_VALIDATION && !incoming_payload_was_valid);

    let value = match copro_response_body {
        Some(value) => value,
        None => return Ok(original_response_body),
    };

    let replacement = if skip_validation {
        match serde_json_bytes::from_value::<graphql::Response>(value) {
            Ok(parsed) => parsed,
            // Lenient mode: an unusable coprocessor body falls back to the original one.
            Err(_) => return Ok(original_response_body),
        }
    } else {
        graphql::Response::from_value(value)?
    };

    Ok(apply_response_post_processing(
        replacement,
        &original_response_body,
    ))
}