use std::collections::HashMap;
use std::collections::HashSet;
use std::ops::ControlFlow;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use std::time::Instant;
use bytes::Bytes;
use futures::StreamExt;
use futures::TryStreamExt;
use futures::future::ready;
use futures::stream::once;
use http::HeaderMap;
use http::HeaderName;
use http::HeaderValue;
use http::header;
use http_body_util::BodyExt;
use schemars::JsonSchema;
use serde::Deserialize;
use tower::BoxError;
use tower::Layer;
use tower::Service;
use tower::ServiceBuilder;
use tower::ServiceExt;
use tower::timeout::TimeoutLayer;
use tower::util::MapFutureLayer;
use crate::Context;
use crate::configuration::shared::Client;
use crate::context::context_key_from_deprecated;
use crate::context::context_key_to_deprecated;
use crate::error::Error;
use crate::graphql;
use crate::json_ext::Value;
use crate::layers::ServiceBuilderExt;
use crate::layers::async_checkpoint::AsyncCheckpointLayer;
use crate::plugin::PluginInit;
use crate::plugin::PluginPrivate;
use crate::plugins::telemetry::config_new::conditions::Condition;
use crate::plugins::telemetry::config_new::router::selectors::RouterSelector;
use crate::plugins::telemetry::config_new::subgraph::selectors::SubgraphSelector;
use crate::register_private_plugin;
use crate::services;
use crate::services::PATH_QUERY_PARAM;
use crate::services::external::Control;
use crate::services::external::DEFAULT_EXTERNALIZATION_TIMEOUT;
use crate::services::external::EXTERNALIZABLE_VERSION;
use crate::services::external::Externalizable;
use crate::services::external::PipelineStep;
use crate::services::external::externalize_header_map;
use crate::services::http::HttpRequest;
use crate::services::http::HttpResponse;
use crate::services::router;
use crate::services::router::body::RouterBody;
use crate::services::subgraph;
#[cfg(test)]
mod test;
mod connector;
mod execution;
mod supergraph;
/// Name of the tracing span created around each coprocessor-instrumented service.
pub(crate) const EXTERNAL_SPAN_NAME: &str = "external_plugin";
/// Extension code attached to the GraphQL error that is built when a coprocessor
/// `break` response carries a plain-text body instead of a GraphQL response.
const COPROCESSOR_ERROR_EXTENSION: &str = "ERROR";
/// Extension code for coprocessor payloads that cannot be deserialized.
/// NOTE(review): not referenced in the visible part of this file; presumably used by
/// `deserialize_coprocessor_response` — confirm.
const COPROCESSOR_DESERIALIZATION_ERROR_EXTENSION: &str = "EXTERNAL_DESERIALIZATION_ERROR";
/// Default HTTP client type used to reach the coprocessor: the router's HTTP client
/// service wrapped in a tower timeout.
type HTTPClientService = tower::timeout::Timeout<crate::services::http::HttpClientService>;
#[async_trait::async_trait]
impl PluginPrivate for CoprocessorPlugin<HTTPClientService> {
type Config = Conf;
async fn new(init: PluginInit<Self::Config>) -> Result<Self, BoxError> {
let client_config = init.config.client.clone().unwrap_or_default();
if matches!(
init.config.router.request.context,
ContextConf::Deprecated(true)
) {
tracing::warn!(
"Configuration `coprocessor.router.request.context: true` is deprecated. See https://go.apollo.dev/o/coprocessor-context"
);
}
if matches!(
init.config.router.response.context,
ContextConf::Deprecated(true)
) {
tracing::warn!(
"Configuration `coprocessor.router.response.context: true` is deprecated. See https://go.apollo.dev/o/coprocessor-context"
);
}
if matches!(
init.config.supergraph.request.context,
ContextConf::Deprecated(true)
) {
tracing::warn!(
"Configuration `coprocessor.supergraph.request.context: true` is deprecated. See https://go.apollo.dev/o/coprocessor-context"
);
}
if matches!(
init.config.supergraph.response.context,
ContextConf::Deprecated(true)
) {
tracing::warn!(
"Configuration `coprocessor.supergraph.response.context: true` is deprecated. See https://go.apollo.dev/o/coprocessor-context"
);
}
if matches!(
init.config.execution.request.context,
ContextConf::Deprecated(true)
) {
tracing::warn!(
"Configuration `coprocessor.execution.request.context: true` is deprecated. See https://go.apollo.dev/o/coprocessor-context"
);
}
if matches!(
init.config.execution.response.context,
ContextConf::Deprecated(true)
) {
tracing::warn!(
"Configuration `coprocessor.execution.response.context: true` is deprecated. See https://go.apollo.dev/o/coprocessor-context"
);
}
if matches!(
init.config.subgraph.all.request.context,
ContextConf::Deprecated(true)
) {
tracing::warn!(
"Configuration `coprocessor.subgraph.all.request.context: true` is deprecated. See https://go.apollo.dev/o/coprocessor-context"
);
}
if matches!(
init.config.subgraph.all.response.context,
ContextConf::Deprecated(true)
) {
tracing::warn!(
"Configuration `coprocessor.subgraph.all.response.context: true` is deprecated. See https://go.apollo.dev/o/coprocessor-context"
);
}
if matches!(
init.config.connector.all.request.context,
ContextConf::Deprecated(true)
) {
tracing::warn!(
"Configuration `coprocessor.connector.all.request.context: true` is deprecated. See https://go.apollo.dev/o/coprocessor-context"
);
}
if matches!(
init.config.connector.all.response.context,
ContextConf::Deprecated(true)
) {
tracing::warn!(
"Configuration `coprocessor.connector.all.response.context: true` is deprecated. See https://go.apollo.dev/o/coprocessor-context"
);
}
validate_coprocessor_url(&init.config.url, "coprocessor.url")?;
if let Some(ref url) = init.config.router.request.url {
validate_coprocessor_url(url, "coprocessor.router.request.url")?;
}
if let Some(ref url) = init.config.router.response.url {
validate_coprocessor_url(url, "coprocessor.router.response.url")?;
}
if let Some(ref url) = init.config.supergraph.request.url {
validate_coprocessor_url(url, "coprocessor.supergraph.request.url")?;
}
if let Some(ref url) = init.config.supergraph.response.url {
validate_coprocessor_url(url, "coprocessor.supergraph.response.url")?;
}
if let Some(ref url) = init.config.execution.request.url {
validate_coprocessor_url(url, "coprocessor.execution.request.url")?;
}
if let Some(ref url) = init.config.execution.response.url {
validate_coprocessor_url(url, "coprocessor.execution.response.url")?;
}
if let Some(ref url) = init.config.subgraph.all.request.url {
validate_coprocessor_url(url, "coprocessor.subgraph.all.request.url")?;
}
if let Some(ref url) = init.config.subgraph.all.response.url {
validate_coprocessor_url(url, "coprocessor.subgraph.all.response.url")?;
}
let tls_root_store =
crate::services::http::service::HttpClientService::native_roots_store();
let http_client_service =
crate::services::http::service::HttpClientService::from_config_for_coprocessor(
&tls_root_store,
client_config,
)?;
let client = TimeoutLayer::new(init.config.timeout).layer(http_client_service);
CoprocessorPlugin::new(client, init.config, init.supergraph_sdl)
}
fn router_service(&self, service: router::BoxService) -> router::BoxService {
self.router_service(service)
}
fn supergraph_service(
&self,
service: services::supergraph::BoxService,
) -> services::supergraph::BoxService {
self.supergraph_service(service)
}
fn execution_service(
&self,
service: services::execution::BoxService,
) -> services::execution::BoxService {
self.execution_service(service)
}
fn subgraph_service(&self, name: &str, service: subgraph::BoxService) -> subgraph::BoxService {
self.subgraph_service(name, service)
}
fn connector_request_service(
&self,
service: crate::services::connector::request_service::BoxService,
source_name: String,
) -> crate::services::connector::request_service::BoxService {
self.connector_request_service(&source_name, service)
}
}
// Registers `CoprocessorPlugin` (with the default HTTP client type) through the
// private-plugin macro, using the group "apollo" and the name "coprocessor".
// NOTE(review): presumably this is what ties the plugin to the `coprocessor`
// configuration key — confirm against the macro definition.
register_private_plugin!(
"apollo",
"coprocessor",
CoprocessorPlugin<HTTPClientService>
);
/// Coprocessor plugin state, generic over the HTTP client `C` used to call the
/// coprocessor (the registered plugin uses `HTTPClientService`).
#[derive(Debug)]
struct CoprocessorPlugin<C>
where
C: Service<HttpRequest, Response = HttpResponse, Error = BoxError>
+ Clone
+ Send
+ Sync
+ 'static,
<C as tower::Service<HttpRequest>>::Future: Send + 'static,
{
/// Client cloned into every stage that talks to the coprocessor.
http_client: C,
/// Parsed `coprocessor` configuration (global URL, per-stage settings, ...).
configuration: Conf,
/// Supergraph SDL, handed to the router, supergraph and execution stages.
sdl: Arc<String>,
}
impl<C> CoprocessorPlugin<C>
where
    C: Service<HttpRequest, Response = HttpResponse, Error = BoxError>
        + Clone
        + Send
        + Sync
        + 'static,
    <C as tower::Service<HttpRequest>>::Future: Send + 'static,
{
    /// Assembles the plugin from an already-built HTTP client, the parsed configuration
    /// and the supergraph SDL. Currently infallible; the `Result` is part of the signature.
    fn new(http_client: C, configuration: Conf, sdl: Arc<String>) -> Result<Self, BoxError> {
        let plugin = Self {
            http_client,
            configuration,
            sdl,
        };
        Ok(plugin)
    }

    /// Wraps `service` with the router-stage layers, using the global URL as fallback.
    fn router_service(&self, service: router::BoxService) -> router::BoxService {
        let conf = &self.configuration;
        conf.router.as_service(
            self.http_client.clone(),
            service,
            conf.url.clone(),
            Arc::clone(&self.sdl),
            conf.response_validation,
        )
    }

    /// Wraps `service` with the supergraph-stage layers, using the global URL as fallback.
    fn supergraph_service(
        &self,
        service: services::supergraph::BoxService,
    ) -> services::supergraph::BoxService {
        let conf = &self.configuration;
        conf.supergraph.as_service(
            self.http_client.clone(),
            service,
            conf.url.clone(),
            Arc::clone(&self.sdl),
            conf.response_validation,
        )
    }

    /// Wraps `service` with the execution-stage layers, using the global URL as fallback.
    fn execution_service(
        &self,
        service: services::execution::BoxService,
    ) -> services::execution::BoxService {
        let conf = &self.configuration;
        conf.execution.as_service(
            self.http_client.clone(),
            service,
            conf.url.clone(),
            Arc::clone(&self.sdl),
            conf.response_validation,
        )
    }

    /// Wraps the subgraph service called `name` with the `subgraph.all` stage layers.
    fn subgraph_service(&self, name: &str, service: subgraph::BoxService) -> subgraph::BoxService {
        let conf = &self.configuration;
        conf.subgraph.all.as_service(
            self.http_client.clone(),
            service,
            conf.url.clone(),
            name.to_owned(),
            conf.response_validation,
        )
    }

    /// Wraps the connector request service for `source_name` with the `connector.all`
    /// stage layers.
    fn connector_request_service(
        &self,
        source_name: &str,
        service: crate::services::connector::request_service::BoxService,
    ) -> crate::services::connector::request_service::BoxService {
        let conf = &self.configuration;
        conf.connector.all.as_service(
            self.http_client.clone(),
            service,
            conf.url.clone(),
            source_name.to_owned(),
        )
    }
}
// What the router-request stage sends to the coprocessor.
// When the whole struct equals its default, `RouterStage::as_service` installs no request layer.
// (Plain `//` comments on purpose: this type derives `JsonSchema`.)
#[derive(Clone, Debug, Default, Deserialize, PartialEq, JsonSchema)]
#[serde(default, deny_unknown_fields)]
pub(super) struct RouterRequestConf {
// Optional condition; when set, the coprocessor is called only if it evaluates to `Some(true)`.
pub(super) condition: Option<Condition<RouterSelector>>,
// Send the request headers.
pub(super) headers: bool,
// Which context entries to send (see `ContextConf::get_context`).
pub(super) context: ContextConf,
// Send the request body as a string.
pub(super) body: bool,
// Send the supergraph SDL.
pub(super) sdl: bool,
// Send the request URI.
pub(super) path: bool,
// NOTE(review): `process_router_request_stage` never reads this flag and always sends the
// method — confirm whether that is intended.
pub(super) method: bool,
// Stage-specific coprocessor URL; the global URL is used when this is `None`.
pub(super) url: Option<String>,
}
// What the router-response stage sends to the coprocessor.
// When the whole struct equals its default, `RouterStage::as_service` installs no response layer.
// (Plain `//` comments on purpose: this type derives `JsonSchema`.)
#[derive(Clone, Debug, Default, Deserialize, PartialEq, JsonSchema)]
#[serde(default, deny_unknown_fields)]
pub(super) struct RouterResponseConf {
// The coprocessor is skipped when this condition evaluates to false for the response.
pub(super) condition: Condition<RouterSelector>,
// Send the response headers (first body chunk only).
pub(super) headers: bool,
// Which context entries to send (see `ContextConf::get_context`).
pub(super) context: ContextConf,
// Send each body chunk as a UTF-8 string.
pub(super) body: bool,
// Send the supergraph SDL.
pub(super) sdl: bool,
// Send the HTTP status code (first body chunk only).
pub(super) status_code: bool,
// Stage-specific coprocessor URL; the global URL is used when this is `None`.
pub(super) url: Option<String>,
}
// What the subgraph-request stage sends to the coprocessor.
// When the whole struct equals its default, `SubgraphStage::as_service` installs no request layer.
// (Plain `//` comments on purpose: this type derives `JsonSchema`.)
#[derive(Clone, Debug, Default, Deserialize, PartialEq, JsonSchema)]
#[serde(default, deny_unknown_fields)]
pub(super) struct SubgraphRequestConf {
// The coprocessor is called only when this condition evaluates to `Some(true)`.
pub(super) condition: Condition<SubgraphSelector>,
// Send the subgraph request headers.
pub(super) headers: bool,
// Which context entries to send (see `ContextConf::get_context`).
pub(super) context: ContextConf,
// Send the GraphQL request body as JSON.
pub(super) body: bool,
// Send the subgraph URI.
pub(super) uri: bool,
// NOTE(review): `process_subgraph_request_stage` never reads this flag and always sends the
// method — confirm whether that is intended.
pub(super) method: bool,
// Send the subgraph (service) name.
pub(super) service_name: bool,
// Send the subgraph request id.
pub(super) subgraph_request_id: bool,
// Stage-specific coprocessor URL; the global URL is used when this is `None`.
pub(super) url: Option<String>,
}
// What the subgraph-response stage sends to the coprocessor.
// When the whole struct equals its default, `SubgraphStage::as_service` installs no response layer.
// NOTE(review): the stage function that reads these flags is outside the visible part of the
// file; the flags presumably mirror `SubgraphRequestConf` — confirm.
// (Plain `//` comments on purpose: this type derives `JsonSchema`.)
#[derive(Clone, Debug, Default, Deserialize, PartialEq, JsonSchema)]
#[serde(default, deny_unknown_fields)]
pub(super) struct SubgraphResponseConf {
pub(super) condition: Condition<SubgraphSelector>,
pub(super) headers: bool,
pub(super) context: ContextConf,
pub(super) body: bool,
pub(super) service_name: bool,
pub(super) status_code: bool,
pub(super) subgraph_request_id: bool,
// Stage-specific coprocessor URL; the global URL is used when this is `None`.
pub(super) url: Option<String>,
}
// Top-level `coprocessor` plugin configuration.
// (Plain `//` comments on purpose: this type derives `JsonSchema`.)
#[derive(Clone, Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
#[schemars(rename = "CoprocessorConfig")]
struct Conf {
// Default coprocessor URL, used by every stage that does not set its own `url`.
url: String,
// Optional HTTP client settings; the client defaults are used when absent.
client: Option<Client>,
// Timeout applied to coprocessor HTTP calls; parsed from a human-readable duration and
// defaulting to `DEFAULT_EXTERNALIZATION_TIMEOUT`.
#[serde(deserialize_with = "humantime_serde::deserialize")]
#[schemars(with = "String", default = "default_timeout")]
#[serde(default = "default_timeout")]
timeout: Duration,
// Forwarded to the stages and on to `deserialize_coprocessor_response`; defaults to `true`.
#[serde(default = "default_response_validation")]
response_validation: bool,
// Router-stage settings.
#[serde(default)]
router: RouterStage,
// Supergraph-stage settings.
#[serde(default)]
supergraph: supergraph::SupergraphStage,
// Execution-stage settings.
#[serde(default)]
execution: execution::ExecutionStage,
// Subgraph-stage settings.
#[serde(default)]
subgraph: SubgraphStages,
// Connector-stage settings.
#[serde(default)]
connector: connector::ConnectorStages,
}
// How a stage exposes the request context to the coprocessor.
// Untagged: the configuration value is either a bare boolean (legacy form) or one of the
// `NewContextConf` variants.
// (Plain `//` comments on purpose: this type derives `JsonSchema`.)
#[derive(Clone, Debug, Deserialize, JsonSchema, PartialEq)]
#[serde(deny_unknown_fields, untagged)]
pub(super) enum ContextConf {
// Legacy boolean: `true` sends the context with deprecated key names, `false` sends nothing.
Deprecated(bool),
// Explicit mode (all / deprecated / selective).
NewContextConf(NewContextConf),
}
impl ContextConf {
    /// Returns `true` when context keys travel under their deprecated names, i.e. for the
    /// legacy `true` flag or the explicit `deprecated` mode.
    fn is_deprecated(&self) -> bool {
        matches!(
            self,
            Self::Deprecated(true) | Self::NewContextConf(NewContextConf::Deprecated)
        )
    }
}
impl Default for ContextConf {
fn default() -> Self {
Self::Deprecated(false)
}
}
// Explicit context-sharing modes (deserialized in snake_case).
// (Plain `//` comments on purpose: this type derives `JsonSchema`.)
#[derive(Clone, Debug, Deserialize, JsonSchema, PartialEq)]
#[serde(deny_unknown_fields, rename_all = "snake_case")]
pub(super) enum NewContextConf {
// Send every context entry under its current key name.
All,
// Send every context entry, with keys converted to their deprecated names.
Deprecated,
// Send only the entries whose keys are listed.
Selective(Arc<HashSet<String>>),
}
impl ContextConf {
    /// Builds the context that should be sent to the coprocessor, or `None` when the
    /// context must not be sent at all.
    ///
    /// * `all` — a clone of `ctx`;
    /// * `deprecated` / legacy `true` — every entry, with keys mapped to their deprecated names;
    /// * `selective` — only the entries whose keys are listed;
    /// * legacy `false` — `None`.
    ///
    /// Rebuilt contexts keep the id of `ctx`.
    pub(crate) fn get_context(&self, ctx: &Context) -> Option<Context> {
        let mut outgoing = match self {
            Self::Deprecated(false) => return None,
            Self::NewContextConf(NewContextConf::All) => return Some(ctx.clone()),
            Self::Deprecated(true) | Self::NewContextConf(NewContextConf::Deprecated) => {
                Context::from_iter(ctx.iter().map(|entry| {
                    let legacy_key = context_key_to_deprecated(entry.key().clone());
                    (legacy_key, entry.value().clone())
                }))
            }
            Self::NewContextConf(NewContextConf::Selective(allowed_keys)) => Context::from_iter(
                ctx.iter()
                    .filter(|entry| allowed_keys.contains(entry.key()))
                    .map(|entry| (entry.key().clone(), entry.value().clone())),
            ),
        };
        // A context built from an iterator gets its own id; restore the original one.
        outgoing.id = ctx.id.clone();
        Some(outgoing)
    }
}
/// Serde/schemars default for `Conf::timeout`.
fn default_timeout() -> Duration {
DEFAULT_EXTERNALIZATION_TIMEOUT
}
/// Serde default for `Conf::response_validation`: validation is on unless configured otherwise.
fn default_response_validation() -> bool {
true
}
/// Checks that `url` is usable as a coprocessor endpoint.
///
/// `unix://` URLs must carry a non-empty, absolute socket path; a `?` without the expected
/// path query parameter only triggers a warning. Any other URL must parse as an `http::Uri`.
/// `config_path` is the configuration key being validated and prefixes every message.
///
/// # Errors
/// Returns a descriptive error when the URL is rejected.
pub(crate) fn validate_coprocessor_url(url: &str, config_path: &str) -> Result<(), BoxError> {
    match url.strip_prefix("unix://") {
        None => {
            // Regular URL: it only has to parse.
            url.parse::<http::Uri>()
                .map_err(|e| format!("{config_path}: invalid URL '{url}': {e}"))?;
        }
        Some(path) => {
            if path.is_empty() {
                let message = format!(
                    "{config_path}: Unix socket URL must include a path (e.g., 'unix:///var/run/coprocessor.sock')"
                );
                return Err(message.into());
            }
            if !path.starts_with('/') {
                let message = format!(
                    "{config_path}: Unix socket path should be absolute (e.g., 'unix:///var/run/coprocessor.sock'), got 'unix://{path}'"
                );
                return Err(message.into());
            }
            // A query string without the path parameter is suspicious but not fatal.
            let has_query = path.contains('?');
            if has_query && !path.contains(PATH_QUERY_PARAM) {
                tracing::warn!(
                    "{config_path}: Unix sockets should use valid query parameters if using `?` (e.g., 'unix:///var/run/coprocessor.sock?path=some_path'), got 'unix://{path}'"
                );
            }
        }
    }
    Ok(())
}
/// Applies the context returned by the coprocessor to `target_context`.
///
/// Every returned entry is inserted (keys are mapped back from their deprecated names when
/// `context_config` is a deprecated mode). Entries missing from the returned context are then
/// removed: in selective mode only the listed keys can be removed, in every other mode any
/// key that was not returned is removed.
///
/// # Errors
/// Fails when the returned context cannot be turned into an iterator.
pub(crate) fn update_context_from_coprocessor(
    target_context: &Context,
    context_returned: Context,
    context_config: &ContextConf,
) -> Result<(), BoxError> {
    let uses_deprecated_names = context_config.is_deprecated();
    let mut keys_returned = HashSet::with_capacity(context_returned.len());

    for (returned_key, value) in context_returned.try_into_iter()? {
        let key = if uses_deprecated_names {
            context_key_from_deprecated(returned_key)
        } else {
            returned_key
        };
        keys_returned.insert(key.clone());
        target_context.insert_json_value(key, value);
    }

    if let ContextConf::NewContextConf(NewContextConf::Selective(context_keys)) = context_config {
        // Keys outside the selective list were never sent, so they are never deleted.
        target_context
            .retain(|key, _v| keys_returned.contains(key) || !context_keys.contains(key));
    } else {
        target_context.retain(|key, _v| keys_returned.contains(key));
    }

    Ok(())
}
/// Records `duration` (in seconds) in the `apollo.router.operations.coprocessor.duration`
/// histogram, tagged with the pipeline `stage`.
fn record_coprocessor_duration(stage: PipelineStep, duration: Duration) {
f64_histogram!(
"apollo.router.operations.coprocessor.duration",
"Time spent waiting for the coprocessor to answer, in seconds",
duration.as_secs_f64(),
coprocessor.stage = stage.to_string()
);
}
/// Increments the `apollo.router.operations.coprocessor` counter by one, tagged with the
/// pipeline `stage` and whether the stage `succeeded`.
fn record_coprocessor_operation(stage: PipelineStep, succeeded: bool) {
u64_counter!(
"apollo.router.operations.coprocessor",
"Total run operations with co-processors enabled",
1,
"coprocessor.stage" = stage.to_string(),
"coprocessor.succeeded" = succeeded
);
}
// Router-stage configuration: request and response halves.
// (Plain `//` comments on purpose: this type derives `JsonSchema`.)
#[derive(Clone, Debug, Default, Deserialize, PartialEq, JsonSchema)]
#[serde(default)]
pub(super) struct RouterStage {
// Settings for the router-request coprocessor call.
pub(super) request: RouterRequestConf,
// Settings for the router-response coprocessor call.
pub(super) response: RouterResponseConf,
}
impl RouterStage {
pub(crate) fn as_service<C>(
&self,
http_client: C,
service: router::BoxService,
default_url: String,
sdl: Arc<String>,
response_validation: bool,
) -> router::BoxService
where
C: Service<HttpRequest, Response = HttpResponse, Error = BoxError>
+ Clone
+ Send
+ Sync
+ 'static,
<C as tower::Service<HttpRequest>>::Future: Send + 'static,
{
let request_layer = (self.request != Default::default()).then_some({
let request_config = self.request.clone();
let coprocessor_url = request_config.url.clone().unwrap_or(default_url.clone());
let http_client = http_client.clone();
let sdl = sdl.clone();
AsyncCheckpointLayer::new(move |request: router::Request| {
let request_config = request_config.clone();
let coprocessor_url = coprocessor_url.clone();
let http_client = http_client.clone();
let sdl = sdl.clone();
async move {
let mut succeeded = true;
let mut executed = false;
let result = process_router_request_stage(
http_client,
coprocessor_url,
sdl,
request,
request_config,
response_validation,
&mut executed,
)
.await
.map_err(|error| {
succeeded = false;
tracing::error!("coprocessor: router request stage error: {error}");
error
});
if executed {
record_coprocessor_operation(PipelineStep::RouterRequest, succeeded);
}
result
}
})
});
let response_layer = (self.response != Default::default()).then_some({
let response_config = self.response.clone();
let coprocessor_url = response_config.url.clone().unwrap_or(default_url);
MapFutureLayer::new(move |fut| {
let sdl = sdl.clone();
let coprocessor_url = coprocessor_url.clone();
let http_client = http_client.clone();
let response_config = response_config.clone();
async move {
let response: router::Response = fut.await?;
let mut succeeded = true;
let mut executed = false;
let result = process_router_response_stage(
http_client,
coprocessor_url,
sdl,
response,
response_config,
response_validation,
&mut executed,
)
.await
.map_err(|error| {
succeeded = false;
tracing::error!("coprocessor: router response stage error: {error}");
error
});
if executed {
record_coprocessor_operation(PipelineStep::RouterResponse, succeeded);
};
result
}
})
});
fn external_service_span() -> impl Fn(&router::Request) -> tracing::Span + Clone {
move |_request: &router::Request| {
tracing::info_span!(
EXTERNAL_SPAN_NAME,
"external service" = stringify!(router::Request),
"otel.kind" = "INTERNAL"
)
}
}
ServiceBuilder::new()
.instrument(external_service_span())
.option_layer(request_layer)
.option_layer(response_layer)
.buffered() .service(service)
.boxed()
}
}
// Subgraph-stage configuration container.
// (Plain `//` comments on purpose: this type derives `JsonSchema`.)
#[derive(Clone, Debug, Default, Deserialize, PartialEq, JsonSchema)]
#[serde(default, deny_unknown_fields)]
pub(super) struct SubgraphStages {
// Settings applied to every subgraph service (see `CoprocessorPlugin::subgraph_service`).
#[serde(default)]
pub(super) all: SubgraphStage,
}
// Subgraph-stage configuration: request and response halves.
// (Plain `//` comments on purpose: this type derives `JsonSchema`.)
#[derive(Clone, Debug, Default, Deserialize, PartialEq, JsonSchema)]
#[serde(default, deny_unknown_fields)]
pub(super) struct SubgraphStage {
// Settings for the subgraph-request coprocessor call.
#[serde(default)]
pub(super) request: SubgraphRequestConf,
// Settings for the subgraph-response coprocessor call.
#[serde(default)]
pub(super) response: SubgraphResponseConf,
}
impl SubgraphStage {
pub(crate) fn as_service<C>(
&self,
http_client: C,
service: subgraph::BoxService,
default_url: String,
service_name: String,
response_validation: bool,
) -> subgraph::BoxService
where
C: Service<HttpRequest, Response = HttpResponse, Error = BoxError>
+ Clone
+ Send
+ Sync
+ 'static,
<C as tower::Service<HttpRequest>>::Future: Send + 'static,
{
let request_layer = (self.request != Default::default()).then_some({
let request_config = self.request.clone();
let http_client = http_client.clone();
let coprocessor_url = request_config.url.clone().unwrap_or(default_url.clone());
let service_name = service_name.clone();
AsyncCheckpointLayer::new(move |request: subgraph::Request| {
let http_client = http_client.clone();
let coprocessor_url = coprocessor_url.clone();
let service_name = service_name.clone();
let request_config = request_config.clone();
async move {
let mut succeeded = true;
let mut executed = false;
let result = process_subgraph_request_stage(
http_client,
coprocessor_url,
service_name,
request,
request_config,
response_validation,
&mut executed,
)
.await
.map_err(|error| {
succeeded = false;
tracing::error!("coprocessor: subgraph request stage error: {error}");
error
});
if executed {
record_coprocessor_operation(PipelineStep::SubgraphRequest, succeeded);
}
result
}
})
});
let response_layer = (self.response != Default::default()).then_some({
let response_config = self.response.clone();
let coprocessor_url = response_config.url.clone().unwrap_or(default_url);
MapFutureLayer::new(move |fut| {
let http_client = http_client.clone();
let coprocessor_url = coprocessor_url.clone();
let response_config = response_config.clone();
let service_name = service_name.clone();
async move {
let response: subgraph::Response = fut.await?;
let mut succeeded = true;
let mut executed = false;
let result = process_subgraph_response_stage(
http_client,
coprocessor_url,
service_name,
response,
response_config,
response_validation,
&mut executed,
)
.await
.map_err(|error| {
succeeded = false;
tracing::error!("coprocessor: subgraph response stage error: {error}");
error
});
if executed {
record_coprocessor_operation(PipelineStep::SubgraphResponse, succeeded);
}
result
}
})
});
fn external_service_span() -> impl Fn(&subgraph::Request) -> tracing::Span + Clone {
move |_request: &subgraph::Request| {
tracing::info_span!(
EXTERNAL_SPAN_NAME,
"external service" = stringify!(subgraph::Request),
"otel.kind" = "INTERNAL"
)
}
}
ServiceBuilder::new()
.instrument(external_service_span())
.option_layer(request_layer)
.option_layer(response_layer)
.buffered() .service(service)
.boxed()
}
}
/// Runs the router-request coprocessor stage for one request.
///
/// Skips the coprocessor when a condition is configured and does not evaluate to
/// `Some(true)`. Otherwise the configured parts of the request (headers, body, context,
/// SDL, path, plus the method) are sent to `coprocessor_url`. On a `break` control the
/// coprocessor output is turned into a router response and returned as
/// `ControlFlow::Break`; otherwise the returned body, context and headers are applied to
/// the request, which is returned as `ControlFlow::Continue`.
///
/// `executed` is set to `true` once the coprocessor call has returned (successfully or
/// not) so the caller knows whether to record an operation metric.
///
/// # Errors
/// Fails when the body cannot be read, the coprocessor call fails, its output does not
/// validate, or the returned status/headers/context cannot be applied.
async fn process_router_request_stage<C>(
    http_client: C,
    coprocessor_url: String,
    sdl: Arc<String>,
    mut request: router::Request,
    mut request_config: RouterRequestConf,
    response_validation: bool,
    executed: &mut bool,
) -> Result<ControlFlow<router::Response, router::Request>, BoxError>
where
    C: Service<HttpRequest, Response = HttpResponse, Error = BoxError>
        + Clone
        + Send
        + Sync
        + 'static,
    <C as tower::Service<HttpRequest>>::Future: Send + 'static,
{
    let should_be_executed = request_config
        .condition
        .as_mut()
        .map(|c| c.evaluate_request(&request) == Some(true))
        .unwrap_or(true);
    if !should_be_executed {
        return Ok(ControlFlow::Continue(request));
    }

    // `get_context` sends deprecated key names for BOTH the explicit `deprecated` mode and
    // the legacy `context: true` flag, so returned keys must be mapped back in both cases
    // (this matches what `update_context_from_coprocessor` does on the response path).
    let context_keys_are_deprecated = request_config.context.is_deprecated();

    let (parts, body) = request.router_request.into_parts();
    let bytes = router::body::into_bytes(body).await?;

    let headers_to_send = request_config
        .headers
        .then(|| externalize_header_map(&parts.headers));
    // A body that is not valid UTF-8 is silently not sent rather than failing the request.
    let body_to_send = request_config
        .body
        .then(|| String::from_utf8(bytes.to_vec()))
        .transpose()
        .unwrap_or_default();
    let path_to_send = request_config.path.then(|| parts.uri.to_string());
    let context_to_send = request_config.context.get_context(&request.context);
    let sdl_to_send = request_config.sdl.then(|| sdl.clone().to_string());

    let payload = Externalizable::router_builder()
        .stage(PipelineStep::RouterRequest)
        .control(Control::default())
        .id(request.context.id.clone())
        .and_headers(headers_to_send)
        .and_body(body_to_send)
        .and_context(context_to_send)
        .and_sdl(sdl_to_send)
        .and_path(path_to_send)
        .method(parts.method.to_string())
        .build();

    tracing::debug!(?payload, "externalized output");
    let start = Instant::now();
    let co_processor_result = payload
        .call(http_client, &coprocessor_url, Context::new())
        .await;
    // From here on the coprocessor has been contacted, whatever the outcome.
    *executed = true;
    let duration = start.elapsed();
    record_coprocessor_duration(PipelineStep::RouterRequest, duration);

    tracing::debug!(?co_processor_result, "co-processor returned");
    let mut co_processor_output = co_processor_result?;

    validate_coprocessor_output(&co_processor_output, PipelineStep::RouterRequest)?;
    let control = co_processor_output.control.expect("validated above; qed");

    if matches!(control, Control::Break(_)) {
        let code = control.get_http_status()?;

        // A body that is not JSON is treated as a plain error message.
        let body_as_value = co_processor_output
            .body
            .as_ref()
            .and_then(|b| serde_json::from_str(b).ok())
            .unwrap_or(Value::Null);
        let graphql_response = match body_as_value {
            Value::Null => graphql::Response::builder()
                .errors(vec![
                    Error::builder()
                        .message(co_processor_output.body.take().unwrap_or_default())
                        .extension_code(COPROCESSOR_ERROR_EXTENSION)
                        .build(),
                ])
                .build(),
            _ => deserialize_coprocessor_response(body_as_value, response_validation),
        };

        let res = router::Response::builder()
            .errors(graphql_response.errors)
            .extensions(graphql_response.extensions)
            .status_code(code)
            .context(request.context);
        let mut res = match (graphql_response.label, graphql_response.data) {
            (Some(label), Some(data)) => res.label(label).data(data).build()?,
            (Some(label), None) => res.label(label).build()?,
            (None, Some(data)) => res.data(data).build()?,
            (None, None) => res.build()?,
        };

        if let Some(headers) = co_processor_output.headers {
            *res.response.headers_mut() = internalize_header_map(headers)?;
        }

        if let Some(context) = co_processor_output.context {
            for (mut key, value) in context.try_into_iter()? {
                if context_keys_are_deprecated {
                    key = context_key_from_deprecated(key);
                }
                res.context.upsert_json_value(key, move |_current| value);
            }
        }

        return Ok(ControlFlow::Break(res));
    }

    // Continue: keep the original body unless the coprocessor returned a new one.
    let new_body = match co_processor_output.body {
        Some(bytes) => router::body::from_bytes(bytes),
        None => router::body::from_bytes(bytes),
    };
    request.router_request = http::Request::from_parts(parts, new_body);

    if let Some(context) = co_processor_output.context {
        for (mut key, value) in context.try_into_iter()? {
            if context_keys_are_deprecated {
                key = context_key_from_deprecated(key);
            }
            request
                .context
                .upsert_json_value(key, move |_current| value);
        }
    }

    if let Some(headers) = co_processor_output.headers {
        *request.router_request.headers_mut() = internalize_header_map(headers)?;
    }

    Ok(ControlFlow::Continue(request))
}
/// Runs the router-response coprocessor stage for one response.
///
/// The coprocessor is skipped when the configured condition evaluates to false. Otherwise
/// the first body chunk is sent together with the configured headers, status code, context
/// and SDL, and the returned body, status, context and headers are applied to the response.
/// Every remaining body chunk is sent to the coprocessor on its own (body, context and SDL
/// only) as the returned stream is consumed.
///
/// `executed` is set to `true` once the first coprocessor call has returned (successfully
/// or not). `_response_validation` is accepted for signature parity and is not used.
///
/// # Errors
/// Fails when the body yields no first chunk, a chunk is not valid UTF-8 while `body` is
/// enabled, the coprocessor call fails, or its output does not validate or cannot be applied.
async fn process_router_response_stage<C>(
http_client: C,
coprocessor_url: String,
sdl: Arc<String>,
mut response: router::Response,
response_config: RouterResponseConf,
_response_validation: bool, executed: &mut bool,
) -> Result<router::Response, BoxError>
where
C: Service<HttpRequest, Response = HttpResponse, Error = BoxError>
+ Clone
+ Send
+ Sync
+ 'static,
<C as tower::Service<HttpRequest>>::Future: Send + 'static,
{
if !response_config.condition.evaluate_response(&response) {
return Ok(response);
}
// Split the body into its first chunk (handled now) and the rest (handled lazily below).
let (parts, body) = response.response.into_parts();
let mut stream = body.into_data_stream();
let first = stream.next().await.transpose()?;
let rest = stream;
let bytes = match first {
Some(b) => b,
None => {
tracing::error!(
"Coprocessor cannot convert body into future due to problem with first part"
);
return Err(BoxError::from(
"Coprocessor cannot convert body into future due to problem with first part",
));
}
};
let headers_to_send = response_config
.headers
.then(|| externalize_header_map(&parts.headers));
let body_to_send = response_config
.body
.then(|| std::str::from_utf8(&bytes).map(|s| s.to_string()))
.transpose()?;
let status_to_send = response_config.status_code.then(|| parts.status.as_u16());
let context_to_send = response_config.context.get_context(&response.context);
let sdl_to_send = response_config.sdl.then(|| sdl.clone().to_string());
let payload = Externalizable::router_builder()
.stage(PipelineStep::RouterResponse)
.id(response.context.id.clone())
.and_headers(headers_to_send)
.and_body(body_to_send)
.and_context(context_to_send)
.and_status_code(status_to_send)
.and_sdl(sdl_to_send.clone())
.build();
tracing::debug!(?payload, "externalized output");
let start = Instant::now();
let co_processor_result = payload
.call(http_client.clone(), &coprocessor_url, Context::new())
.await;
// From here on the coprocessor has been contacted, whatever the outcome.
*executed = true;
let duration = start.elapsed();
record_coprocessor_duration(PipelineStep::RouterResponse, duration);
tracing::debug!(?co_processor_result, "co-processor returned");
let co_processor_output = co_processor_result?;
validate_coprocessor_output(&co_processor_output, PipelineStep::RouterResponse)?;
// Keep the original first chunk unless the coprocessor returned a replacement body.
let new_body = match co_processor_output.body {
Some(bytes) => router::body::from_bytes(bytes),
None => router::body::from_bytes(bytes),
};
response.response = http::Response::from_parts(parts, new_body);
if let Some(control) = co_processor_output.control {
*response.response.status_mut() = control.get_http_status()?
}
if let Some(context) = co_processor_output.context {
update_context_from_coprocessor(&response.context, context, &response_config.context)?;
}
if let Some(headers) = co_processor_output.headers {
*response.response.headers_mut() = internalize_header_map(headers)?;
}
let (parts, body) = response.response.into_parts();
let context = response.context.clone();
let map_context = response.context.clone();
// Remaining chunks: each one makes its own coprocessor round trip when the stream is polled.
let mapped_stream = rest
.map_err(BoxError::from)
.and_then(move |deferred_response| {
let generator_client = http_client.clone();
let generator_coprocessor_url = coprocessor_url.clone();
let generator_map_context = map_context.clone();
let generator_sdl_to_send = sdl_to_send.clone();
let generator_id = map_context.id.clone();
let context_conf = response_config.context.clone();
async move {
let bytes = deferred_response.to_vec();
let body_to_send = response_config
.body
.then(|| String::from_utf8(bytes.clone()))
.transpose()?;
let generator_map_context = generator_map_context.clone();
let context_to_send = context_conf.get_context(&generator_map_context);
// Follow-up payloads carry no headers and no status code.
let payload = Externalizable::router_builder()
.stage(PipelineStep::RouterResponse)
.id(generator_id)
.and_body(body_to_send)
.and_context(context_to_send)
.and_sdl(generator_sdl_to_send)
.build();
tracing::debug!(?payload, "externalized output");
let co_processor_result = payload
.call(generator_client, &generator_coprocessor_url, Context::new())
.await;
tracing::debug!(?co_processor_result, "co-processor returned");
let co_processor_output = co_processor_result?;
validate_coprocessor_output(&co_processor_output, PipelineStep::RouterResponse)?;
// Keep the original chunk unless the coprocessor returned a replacement.
let final_bytes: Bytes = match co_processor_output.body {
Some(bytes) => bytes.into(),
None => bytes.into(),
};
if let Some(context) = co_processor_output.context {
update_context_from_coprocessor(
&generator_map_context,
context,
&context_conf,
)?;
}
Ok(final_bytes)
}
});
// Re-assemble the body: the already-processed first chunk followed by the mapped rest.
let bytes = router::body::into_bytes(body).await.map_err(BoxError::from);
let final_stream = RouterBody::new(http_body_util::StreamBody::new(
once(ready(bytes))
.chain(mapped_stream)
.map(|b| b.map(http_body::Frame::data).map_err(axum::Error::new)),
));
router::Response::http_response_builder()
.context(context)
.response(http::Response::from_parts(parts, final_stream))
.build()
}
/// Runs the subgraph-request coprocessor stage for one subgraph request.
///
/// Skips the coprocessor unless the configured condition evaluates to `Some(true)`.
/// Otherwise the configured parts of the request (headers, body, context, URI, service
/// name, subgraph request id, plus the method) are sent to `coprocessor_url`. On a `break`
/// control the coprocessor output is turned into a subgraph response and returned as
/// `ControlFlow::Break`; otherwise the returned body, context, headers and URI are applied
/// to the request, which is returned as `ControlFlow::Continue`.
///
/// `executed` is set to `true` once the coprocessor call has returned (successfully or
/// not) so the caller knows whether to record an operation metric.
///
/// # Errors
/// Fails when the body cannot be serialized, the coprocessor call fails, its output does
/// not validate, or the returned status/body/headers/context/URI cannot be applied.
async fn process_subgraph_request_stage<C>(
    http_client: C,
    coprocessor_url: String,
    service_name: String,
    mut request: subgraph::Request,
    mut request_config: SubgraphRequestConf,
    response_validation: bool,
    executed: &mut bool,
) -> Result<ControlFlow<subgraph::Response, subgraph::Request>, BoxError>
where
    C: Service<HttpRequest, Response = HttpResponse, Error = BoxError>
        + Clone
        + Send
        + Sync
        + 'static,
    <C as tower::Service<HttpRequest>>::Future: Send + 'static,
{
    if request_config.condition.evaluate_request(&request) != Some(true) {
        return Ok(ControlFlow::Continue(request));
    }

    // `get_context` sends deprecated key names for BOTH the explicit `deprecated` mode and
    // the legacy `context: true` flag, so returned keys must be mapped back in both cases
    // (this matches what `update_context_from_coprocessor` does on the response path).
    let context_keys_are_deprecated = request_config.context.is_deprecated();

    let (parts, body) = request.subgraph_request.into_parts();

    let headers_to_send = request_config
        .headers
        .then(|| externalize_header_map(&parts.headers));
    let body_to_send = request_config
        .body
        .then(|| serde_json_bytes::to_value(&body))
        .transpose()?;
    let context_to_send = request_config.context.get_context(&request.context);
    let uri = request_config.uri.then(|| parts.uri.to_string());
    // Keep the name for a possible `break` response even when it is not sent.
    let subgraph_name = service_name.clone();
    let service_name = request_config.service_name.then_some(service_name);
    let subgraph_request_id = request_config
        .subgraph_request_id
        .then_some(request.id.clone());

    let payload = Externalizable::subgraph_builder()
        .stage(PipelineStep::SubgraphRequest)
        .control(Control::default())
        .id(request.context.id.clone())
        .and_headers(headers_to_send)
        .and_body(body_to_send)
        .and_context(context_to_send)
        .method(parts.method.to_string())
        .and_service_name(service_name)
        .and_uri(uri)
        .and_subgraph_request_id(subgraph_request_id)
        .build();

    tracing::debug!(?payload, "externalized output");
    let start = Instant::now();
    let co_processor_result = payload
        .call(http_client, &coprocessor_url, Context::new())
        .await;
    // From here on the coprocessor has been contacted, whatever the outcome.
    *executed = true;
    let duration = start.elapsed();
    record_coprocessor_duration(PipelineStep::SubgraphRequest, duration);

    tracing::debug!(?co_processor_result, "co-processor returned");
    let co_processor_output = co_processor_result?;
    validate_coprocessor_output(&co_processor_output, PipelineStep::SubgraphRequest)?;
    let control = co_processor_output.control.expect("validated above; qed");

    if matches!(control, Control::Break(_)) {
        let code = control.get_http_status()?;

        let res = {
            // A string body is treated as a plain error message.
            let graphql_response = match co_processor_output.body.unwrap_or(Value::Null) {
                Value::String(s) => graphql::Response::builder()
                    .errors(vec![
                        Error::builder()
                            .message(s.as_str().to_owned())
                            .extension_code(COPROCESSOR_ERROR_EXTENSION)
                            .build(),
                    ])
                    .build(),
                value => deserialize_coprocessor_response(value, response_validation),
            };

            let mut http_response = http::Response::builder()
                .status(code)
                .body(graphql_response)?;
            if let Some(headers) = co_processor_output.headers {
                *http_response.headers_mut() = internalize_header_map(headers)?;
            }

            let subgraph_response = subgraph::Response {
                response: http_response,
                context: request.context,
                subgraph_name,
                id: request.id,
            };

            if let Some(context) = co_processor_output.context {
                for (mut key, value) in context.try_into_iter()? {
                    if context_keys_are_deprecated {
                        key = context_key_from_deprecated(key);
                    }
                    subgraph_response
                        .context
                        .upsert_json_value(key, move |_current| value);
                }
            }

            subgraph_response
        };
        return Ok(ControlFlow::Break(res));
    }

    // Continue: keep the original body unless the coprocessor returned a new one.
    let new_body: graphql::Request = match co_processor_output.body {
        Some(value) => serde_json_bytes::from_value(value)?,
        None => body,
    };
    request.subgraph_request = http::Request::from_parts(parts, new_body);

    if let Some(context) = co_processor_output.context {
        for (mut key, value) in context.try_into_iter()? {
            if context_keys_are_deprecated {
                key = context_key_from_deprecated(key);
            }
            request
                .context
                .upsert_json_value(key, move |_current| value);
        }
    }

    if let Some(headers) = co_processor_output.headers {
        *request.subgraph_request.headers_mut() = internalize_header_map(headers)?;
    }

    if let Some(uri) = co_processor_output.uri {
        *request.subgraph_request.uri_mut() = uri.parse()?;
    }

    Ok(ControlFlow::Continue(request))
}
/// Runs the `SubgraphResponse` coprocessor stage for a single subgraph response.
///
/// When `response_config.condition` does not match the response, the response is returned
/// untouched and no coprocessor call is made. Otherwise the pieces of the response selected by
/// `response_config` (headers, status code, body, context, service name, subgraph request id)
/// are externalized and sent to `coprocessor_url` through `http_client`; whatever the
/// coprocessor sends back (body, control/status, context, headers) is then applied to the
/// response.
///
/// `executed` is set to `true` once the coprocessor call has completed, whether it succeeded
/// or failed.
///
/// # Errors
///
/// Returns an error if the body cannot be serialized, if the coprocessor call fails, if the
/// reply is rejected by `validate_coprocessor_output`, or if the returned body, status, context
/// or headers cannot be applied.
async fn process_subgraph_response_stage<C>(
http_client: C,
coprocessor_url: String,
service_name: String,
mut response: subgraph::Response,
response_config: SubgraphResponseConf,
response_validation: bool,
executed: &mut bool,
) -> Result<subgraph::Response, BoxError>
where
C: Service<HttpRequest, Response = HttpResponse, Error = BoxError>
+ Clone
+ Send
+ Sync
+ 'static,
<C as tower::Service<HttpRequest>>::Future: Send + 'static,
{
// Nothing to do when the configured condition does not select this response.
if !response_config.condition.evaluate_response(&response) {
return Ok(response);
}
// Split the HTTP response: `parts` is kept for reassembly below, `body` may be replaced.
let (parts, body) = response.response.into_parts();
// Each piece is only externalized when its flag is enabled in `response_config`.
let headers_to_send = response_config
.headers
.then(|| externalize_header_map(&parts.headers));
let status_to_send = response_config.status_code.then(|| parts.status.as_u16());
let body_to_send = response_config
.body
.then(|| serde_json_bytes::to_value(&body))
.transpose()?;
let context_to_send = response_config.context.get_context(&response.context);
let service_name = response_config.service_name.then_some(service_name);
let subgraph_request_id = response_config
.subgraph_request_id
.then_some(response.id.clone());
let payload = Externalizable::subgraph_builder()
.stage(PipelineStep::SubgraphResponse)
.id(response.context.id.clone())
.and_headers(headers_to_send)
.and_body(body_to_send)
.and_context(context_to_send)
.and_status_code(status_to_send)
.and_service_name(service_name)
.and_subgraph_request_id(subgraph_request_id)
.build();
tracing::debug!(?payload, "externalized output");
// Time the round trip; note the call is given a fresh `Context`, not `response.context`.
let start = Instant::now();
let co_processor_result = payload
.call(http_client, &coprocessor_url, Context::new())
.await;
// Flag the stage as executed before the result is inspected, so it is set even on failure.
*executed = true;
let duration = start.elapsed();
record_coprocessor_duration(PipelineStep::SubgraphResponse, duration);
tracing::debug!(?co_processor_result, "co-processor returned");
let co_processor_output = co_processor_result?;
validate_coprocessor_output(&co_processor_output, PipelineStep::SubgraphResponse)?;
// Must be computed on the original body before it is moved into `handle_graphql_response`.
let incoming_payload_was_valid = was_incoming_payload_valid(&body, response_config.body);
let new_body = handle_graphql_response(
body,
co_processor_output.body,
response_validation,
incoming_payload_was_valid,
)?;
// Reassemble the HTTP response from the original parts and the (possibly replaced) body.
response.response = http::Response::from_parts(parts, new_body);
// A returned `control` overrides the HTTP status of the response.
if let Some(control) = co_processor_output.control {
*response.response.status_mut() = control.get_http_status()?
}
if let Some(context) = co_processor_output.context {
update_context_from_coprocessor(&response.context, context, &response_config.context)?;
}
// Returned headers replace the whole header map; they are not merged with the existing one.
if let Some(headers) = co_processor_output.headers {
*response.response.headers_mut() = internalize_header_map(headers)?;
}
Ok(response)
}
/// Checks that a coprocessor reply is structurally acceptable for `expected_step`.
///
/// The checks run in this order and only the first failure is reported:
/// 1. `version` must equal [`EXTERNALIZABLE_VERSION`];
/// 2. `stage` must equal the string form of `expected_step`;
/// 3. when the stage name ends with `Request`, `control` must be present.
///
/// # Errors
///
/// Returns a [`BoxError`] wrapping a message that describes the first failed check.
fn validate_coprocessor_output<T>(
    co_processor_output: &Externalizable<T>,
    expected_step: PipelineStep,
) -> Result<(), BoxError> {
    let returned_stage = &co_processor_output.stage;

    // Build the diagnostic for the first rule that is violated, if any.
    let problem = if co_processor_output.version != EXTERNALIZABLE_VERSION {
        Some(format!(
            "Coprocessor returned the wrong version: expected `{}` found `{}`",
            EXTERNALIZABLE_VERSION, co_processor_output.version,
        ))
    } else if *returned_stage != expected_step.to_string() {
        Some(format!(
            "Coprocessor returned the wrong stage: expected `{}` found `{}`",
            expected_step, returned_stage,
        ))
    } else if co_processor_output.control.is_none() && returned_stage.ends_with("Request") {
        Some(format!(
            "Coprocessor response is missing the `control` parameter in the `{}` stage. You must specify \"control\": \"Continue\" or \"control\": \"Break\"",
            returned_stage,
        ))
    } else {
        None
    };

    match problem {
        Some(message) => Err(BoxError::from(message)),
        None => Ok(()),
    }
}
/// Converts the header representation exchanged with coprocessors back into a [`HeaderMap`].
///
/// Every `(name, values)` entry yields one header per value; repeated values for the same name
/// are appended rather than overwritten.
///
/// The `Content-Length` header is dropped. HTTP field names are case-insensitive, so the name is
/// matched ignoring ASCII case: `Content-Length`, `content-length` and `CONTENT-LENGTH` are all
/// filtered out.
/// NOTE(review): the header is presumably dropped because the body may have been rewritten,
/// making any returned length stale; confirm against the callers.
///
/// # Errors
///
/// Returns an error if a header name or a header value is not valid.
pub(super) fn internalize_header_map(
    input: HashMap<String, Vec<String>>,
) -> Result<HeaderMap<HeaderValue>, BoxError> {
    // The number of names is only a lower bound: it does not account for multi-valued headers.
    let mut output = HeaderMap::with_capacity(input.len());
    for (k, values) in input
        .into_iter()
        .filter(|(k, _)| !k.eq_ignore_ascii_case(header::CONTENT_LENGTH.as_str()))
    {
        for v in values {
            let key = HeaderName::from_str(&k)?;
            let value = HeaderValue::from_str(&v)?;
            output.append(key, value);
        }
    }
    Ok(output)
}
/// Carries over onto `new_body` the fields that must survive a coprocessor body replacement.
///
/// `subscribed` and `created_at` are always taken from `original_response_body`. In addition,
/// when the original `data` was an explicit `null`, the replacement has no `data` at all, and
/// the response is marked as subscribed, `data` is restored to an explicit `null`.
fn apply_response_post_processing(
    mut new_body: graphql::Response,
    original_response_body: &graphql::Response,
) -> graphql::Response {
    new_body.subscribed = original_response_body.subscribed;
    new_body.created_at = original_response_body.created_at;

    let original_data_was_explicit_null = original_response_body.data == Some(Value::Null);
    let replacement_dropped_data = new_body.data.is_none();
    let is_subscribed = new_body.subscribed == Some(true);

    if original_data_was_explicit_null && replacement_dropped_data && is_subscribed {
        new_body.data = Some(Value::Null);
    }

    new_body
}
/// Returns `true` when `response` carries `data`, at least one error, or both.
///
/// A response with neither is the only shape reported as not minimally valid.
pub(super) fn is_graphql_response_minimally_valid(response: &graphql::Response) -> bool {
    let has_no_data = response.data.is_none();
    let has_no_errors = response.errors.is_empty();
    !(has_no_data && has_no_errors)
}
/// Tells whether the payload handed to the coprocessor counts as valid.
///
/// When the body was not sent (`body_sent == false`) the payload is always reported as valid;
/// otherwise the answer is that of [`is_graphql_response_minimally_valid`] for `response`.
pub(super) fn was_incoming_payload_valid(response: &graphql::Response, body_sent: bool) -> bool {
    !body_sent || is_graphql_response_minimally_valid(response)
}
pub(super) fn deserialize_coprocessor_response(
body_as_value: Value,
response_validation: bool,
) -> graphql::Response {
if response_validation {
graphql::Response::from_value(body_as_value).unwrap_or_else(|error| {
graphql::Response::builder()
.errors(vec![
Error::builder()
.message(format!(
"couldn't deserialize coprocessor output body: {error}"
))
.extension_code(COPROCESSOR_DESERIALIZATION_ERROR_EXTENSION)
.build(),
])
.build()
})
} else {
serde_json_bytes::from_value(body_as_value).unwrap_or_else(|error| {
graphql::Response::builder()
.errors(vec![
Error::builder()
.message(format!(
"couldn't deserialize coprocessor output body: {error}"
))
.extension_code(COPROCESSOR_DESERIALIZATION_ERROR_EXTENSION)
.build(),
])
.build()
})
}
}
/// Chooses the body a response should carry after a coprocessor round trip.
///
/// * No body returned by the coprocessor: the original body is kept.
/// * Validation applies (`response_validation` is on and, because conditional validation is
///   enabled, the incoming payload was itself valid): the returned value must parse with
///   `graphql::Response::from_value`, otherwise the error is propagated.
/// * Validation does not apply: the returned value is deserialized leniently, and if that
///   fails the original body is kept instead of failing.
///
/// Whenever a returned body is accepted it goes through `apply_response_post_processing`
/// together with the original body.
///
/// # Errors
///
/// Returns an error only when validation applies and the returned body does not parse.
pub(super) fn handle_graphql_response(
    original_response_body: graphql::Response,
    copro_response_body: Option<Value>,
    response_validation: bool,
    incoming_payload_was_valid: bool,
) -> Result<graphql::Response, BoxError> {
    const ENABLE_CONDITIONAL_VALIDATION: bool = true;

    let value = match copro_response_body {
        Some(value) => value,
        None => return Ok(original_response_body),
    };

    // With conditional validation on, an already-invalid incoming payload switches validation off.
    let waived_for_invalid_input = ENABLE_CONDITIONAL_VALIDATION && !incoming_payload_was_valid;
    let should_validate = response_validation && !waived_for_invalid_input;

    if should_validate {
        let validated = graphql::Response::from_value(value)?;
        return Ok(apply_response_post_processing(
            validated,
            &original_response_body,
        ));
    }

    // Best effort: a body that cannot be deserialized is ignored in favour of the original.
    let chosen = match serde_json_bytes::from_value::<graphql::Response>(value) {
        Ok(lenient) => apply_response_post_processing(lenient, &original_response_body),
        Err(_) => original_response_body,
    };
    Ok(chosen)
}