use std::future;
use std::ops::ControlFlow;
use std::sync::Arc;
use ahash::HashMap;
use ahash::HashMapExt;
use apollo_compiler::ExecutableDocument;
use apollo_compiler::schema::FieldLookupError;
use apollo_compiler::validation::Valid;
use apollo_compiler::validation::WithErrors;
use apollo_federation::error::FederationError;
use displaydoc::Display;
use futures::StreamExt;
use futures::future::Either;
use futures::stream;
use schemars::JsonSchema;
use serde::Deserialize;
use serde::Serialize;
use thiserror::Error;
use tower::BoxError;
use tower::ServiceBuilder;
use tower::ServiceExt;
use crate::Context;
use crate::error::Error;
use crate::graphql;
use crate::graphql::IntoGraphQLErrors;
use crate::json_ext::Object;
use crate::layers::ServiceBuilderExt;
use crate::plugin::Plugin;
use crate::plugin::PluginInit;
use crate::plugins::demand_control::cost_calculator::schema::DemandControlledSchema;
use crate::plugins::demand_control::strategy::Strategy;
use crate::plugins::demand_control::strategy::StrategyFactory;
use crate::register_plugin;
use crate::services::execution;
use crate::services::execution::BoxService;
use crate::services::subgraph;
pub(crate) mod cost_calculator;
pub(crate) mod strategy;
pub(crate) static COST_ESTIMATED_KEY: &str = "cost.estimated";
pub(crate) static COST_ACTUAL_KEY: &str = "cost.actual";
pub(crate) static COST_DELTA_KEY: &str = "cost.delta";
pub(crate) static COST_RESULT_KEY: &str = "cost.result";
pub(crate) static COST_STRATEGY_KEY: &str = "cost.strategy";
#[derive(Clone, Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields, rename_all = "snake_case")]
pub(crate) enum StrategyConfig {
StaticEstimated {
list_size: u32,
max: f64,
},
#[cfg(test)]
Test {
stage: test::TestStage,
error: test::TestError,
},
}
#[derive(Copy, Clone, Debug, Serialize, Deserialize, JsonSchema, Eq, PartialEq)]
#[serde(deny_unknown_fields, rename_all = "snake_case")]
pub(crate) enum Mode {
Measure,
Enforce,
}
#[derive(Clone, Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub(crate) struct DemandControlConfig {
enabled: bool,
mode: Mode,
strategy: StrategyConfig,
}
#[derive(Debug, Display, Error)]
pub(crate) enum DemandControlError {
EstimatedCostTooExpensive {
estimated_cost: f64,
max_cost: f64,
},
#[allow(dead_code)]
ActualCostTooExpensive {
actual_cost: f64,
max_cost: f64,
},
QueryParseFailure(String),
SubgraphOperationNotInitialized(crate::query_planner::fetch::SubgraphOperationNotInitialized),
ContextSerializationError(String),
FederationError(FederationError),
}
impl IntoGraphQLErrors for DemandControlError {
fn into_graphql_errors(self) -> Result<Vec<Error>, Self> {
match self {
DemandControlError::EstimatedCostTooExpensive {
estimated_cost,
max_cost,
} => {
let mut extensions = Object::new();
extensions.insert("cost.estimated", estimated_cost.into());
extensions.insert("cost.max", max_cost.into());
Ok(vec![
graphql::Error::builder()
.extension_code(self.code())
.extensions(extensions)
.message(self.to_string())
.build(),
])
}
DemandControlError::ActualCostTooExpensive {
actual_cost,
max_cost,
} => {
let mut extensions = Object::new();
extensions.insert("cost.actual", actual_cost.into());
extensions.insert("cost.max", max_cost.into());
Ok(vec![
graphql::Error::builder()
.extension_code(self.code())
.extensions(extensions)
.message(self.to_string())
.build(),
])
}
DemandControlError::QueryParseFailure(_) => Ok(vec![
graphql::Error::builder()
.extension_code(self.code())
.message(self.to_string())
.build(),
]),
DemandControlError::SubgraphOperationNotInitialized(e) => Ok(e.into_graphql_errors()),
DemandControlError::ContextSerializationError(_) => Ok(vec![
graphql::Error::builder()
.extension_code(self.code())
.message(self.to_string())
.build(),
]),
DemandControlError::FederationError(_) => Ok(vec![
graphql::Error::builder()
.extension_code(self.code())
.message(self.to_string())
.build(),
]),
}
}
}
impl DemandControlError {
fn code(&self) -> &'static str {
match self {
DemandControlError::EstimatedCostTooExpensive { .. } => "COST_ESTIMATED_TOO_EXPENSIVE",
DemandControlError::ActualCostTooExpensive { .. } => "COST_ACTUAL_TOO_EXPENSIVE",
DemandControlError::QueryParseFailure(_) => "COST_QUERY_PARSE_FAILURE",
DemandControlError::SubgraphOperationNotInitialized(e) => e.code(),
DemandControlError::ContextSerializationError(_) => "COST_CONTEXT_SERIALIZATION_ERROR",
DemandControlError::FederationError(_) => "FEDERATION_ERROR",
}
}
}
impl<T> From<WithErrors<T>> for DemandControlError {
fn from(value: WithErrors<T>) -> Self {
DemandControlError::QueryParseFailure(format!("{}", value))
}
}
impl From<FieldLookupError<'_>> for DemandControlError {
fn from(value: FieldLookupError) -> Self {
match value {
FieldLookupError::NoSuchType => DemandControlError::QueryParseFailure(
"Attempted to look up a type which does not exist in the schema".to_string(),
),
FieldLookupError::NoSuchField(type_name, _) => {
DemandControlError::QueryParseFailure(format!(
"Attempted to look up a field on type {}, but the field does not exist",
type_name
))
}
}
}
}
impl From<FederationError> for DemandControlError {
fn from(value: FederationError) -> Self {
DemandControlError::FederationError(value)
}
}
#[derive(Clone)]
pub(crate) struct DemandControlContext {
pub(crate) strategy: Strategy,
pub(crate) variables: Object,
}
impl Context {
pub(crate) fn insert_estimated_cost(&self, cost: f64) -> Result<(), DemandControlError> {
self.insert(COST_ESTIMATED_KEY, cost)
.map_err(|e| DemandControlError::ContextSerializationError(e.to_string()))?;
Ok(())
}
pub(crate) fn get_estimated_cost(&self) -> Result<Option<f64>, DemandControlError> {
self.get::<&str, f64>(COST_ESTIMATED_KEY)
.map_err(|e| DemandControlError::ContextSerializationError(e.to_string()))
}
pub(crate) fn insert_actual_cost(&self, cost: f64) -> Result<(), DemandControlError> {
self.insert(COST_ACTUAL_KEY, cost)
.map_err(|e| DemandControlError::ContextSerializationError(e.to_string()))?;
Ok(())
}
pub(crate) fn get_actual_cost(&self) -> Result<Option<f64>, DemandControlError> {
self.get::<&str, f64>(COST_ACTUAL_KEY)
.map_err(|e| DemandControlError::ContextSerializationError(e.to_string()))
}
pub(crate) fn get_cost_delta(&self) -> Result<Option<f64>, DemandControlError> {
let estimated = self.get_estimated_cost()?;
let actual = self.get_actual_cost()?;
Ok(estimated.zip(actual).map(|(est, act)| est - act))
}
pub(crate) fn insert_cost_result(&self, result: String) -> Result<(), DemandControlError> {
self.insert(COST_RESULT_KEY, result)
.map_err(|e| DemandControlError::ContextSerializationError(e.to_string()))?;
Ok(())
}
pub(crate) fn get_cost_result(&self) -> Result<Option<String>, DemandControlError> {
self.get::<&str, String>(COST_RESULT_KEY)
.map_err(|e| DemandControlError::ContextSerializationError(e.to_string()))
}
pub(crate) fn insert_cost_strategy(&self, strategy: String) -> Result<(), DemandControlError> {
self.insert(COST_STRATEGY_KEY, strategy)
.map_err(|e| DemandControlError::ContextSerializationError(e.to_string()))?;
Ok(())
}
pub(crate) fn get_cost_strategy(&self) -> Result<Option<String>, DemandControlError> {
self.get::<&str, String>(COST_STRATEGY_KEY)
.map_err(|e| DemandControlError::ContextSerializationError(e.to_string()))
}
pub(crate) fn insert_demand_control_context(&self, ctx: DemandControlContext) {
self.extensions().with_lock(|mut lock| lock.insert(ctx));
}
pub(crate) fn get_demand_control_context(&self) -> Option<DemandControlContext> {
self.extensions().with_lock(|lock| lock.get().cloned())
}
}
pub(crate) struct DemandControl {
config: DemandControlConfig,
strategy_factory: StrategyFactory,
}
impl DemandControl {
fn report_operation_metric(context: Context) {
let result = context
.get(COST_RESULT_KEY)
.ok()
.flatten()
.unwrap_or("NO_CONTEXT".to_string());
u64_counter!(
"apollo.router.operations.demand_control",
"Total operations with demand control enabled",
1,
"demand_control.result" = result
);
}
}
#[async_trait::async_trait]
impl Plugin for DemandControl {
type Config = DemandControlConfig;
async fn new(init: PluginInit<Self::Config>) -> Result<Self, BoxError> {
if !init.config.enabled {
return Ok(DemandControl {
strategy_factory: StrategyFactory::new(
init.config.clone(),
Arc::new(DemandControlledSchema::empty(
init.supergraph_schema.clone(),
)?),
Arc::new(HashMap::new()),
),
config: init.config,
});
}
let demand_controlled_supergraph_schema =
DemandControlledSchema::new(init.supergraph_schema.clone())?;
let mut demand_controlled_subgraph_schemas = HashMap::new();
for (subgraph_name, subgraph_schema) in init.subgraph_schemas.iter() {
let demand_controlled_subgraph_schema =
DemandControlledSchema::new(subgraph_schema.clone())?;
demand_controlled_subgraph_schemas
.insert(subgraph_name.clone(), demand_controlled_subgraph_schema);
}
Ok(DemandControl {
strategy_factory: StrategyFactory::new(
init.config.clone(),
Arc::new(demand_controlled_supergraph_schema),
Arc::new(demand_controlled_subgraph_schemas),
),
config: init.config,
})
}
fn execution_service(&self, service: BoxService) -> BoxService {
if !self.config.enabled {
service
} else {
let strategy = self.strategy_factory.create();
ServiceBuilder::new()
.checkpoint(move |req: execution::Request| {
req.context
.insert_demand_control_context(DemandControlContext {
strategy: strategy.clone(),
variables: req.supergraph_request.body().variables.clone(),
});
Ok(match strategy.on_execution_request(&req) {
Ok(_) => ControlFlow::Continue(req),
Err(err) => ControlFlow::Break(
execution::Response::builder()
.errors(
err.into_graphql_errors()
.expect("must be able to convert to graphql error"),
)
.context(req.context.clone())
.build()
.expect("Must be able to build response"),
),
})
})
.map_response(|mut resp: execution::Response| {
let req = resp
.context
.unsupported_executable_document()
.expect("must have document");
let strategy = resp
.context
.get_demand_control_context()
.map(|ctx| ctx.strategy)
.expect("must have strategy");
let context = resp.context.clone();
let report_operation_metric =
futures::stream::unfold(resp.context.clone(), |ctx| async move {
Self::report_operation_metric(ctx);
None
});
resp.response = resp.response.map(move |resp| {
resp.flat_map(move |resp| {
match strategy.on_execution_response(&context, req.as_ref(), &resp) {
Ok(_) => Either::Left(stream::once(future::ready(Ok(resp)))),
Err(err) => {
Either::Right(stream::iter(vec![
Ok(graphql::Response::builder()
.errors(
err.into_graphql_errors().expect(
"must be able to convert to graphql error",
),
)
.extensions(crate::json_ext::Object::new())
.build()),
Err(()),
]))
}
}
})
.take_while(|resp| future::ready(resp.is_ok()))
.map(|i| i.expect("error used to terminate stream"))
.chain(report_operation_metric)
.boxed()
});
resp
})
.service(service)
.boxed()
}
}
fn subgraph_service(
&self,
subgraph_name: &str,
service: subgraph::BoxService,
) -> subgraph::BoxService {
if !self.config.enabled {
service
} else {
let subgraph_name = subgraph_name.to_owned();
let subgraph_name_map_fut = subgraph_name.to_owned();
ServiceBuilder::new()
.checkpoint(move |req: subgraph::Request| {
let strategy = req.context.get_demand_control_context().map(|c| c.strategy).expect("must have strategy");
Ok(match strategy.on_subgraph_request(&req) {
Ok(_) => ControlFlow::Continue(req),
Err(err) => ControlFlow::Break(
subgraph::Response::builder()
.errors(
err.into_graphql_errors()
.expect("must be able to convert to graphql error"),
)
.context(req.context.clone())
.extensions(crate::json_ext::Object::new())
.subgraph_name(subgraph_name.clone())
.build(),
),
})
})
.map_future_with_request_data(
move |req: &subgraph::Request| {
(
subgraph_name_map_fut.clone(),
req.executable_document.clone().unwrap_or_else(|| {
Arc::new(Valid::assume_valid(ExecutableDocument::new()))
}),
)
},
|(subgraph_name, req): (String, Arc<Valid<ExecutableDocument>>), fut| async move {
let resp: subgraph::Response = fut.await?;
let strategy = resp.context.get_demand_control_context().map(|c| c.strategy).expect("must have strategy");
Ok(match strategy.on_subgraph_response(req.as_ref(), &resp) {
Ok(_) => resp,
Err(err) => subgraph::Response::builder()
.errors(
err.into_graphql_errors()
.expect("must be able to convert to graphql error"),
)
.subgraph_name(subgraph_name)
.context(resp.context.clone())
.extensions(Object::new())
.build(),
})
},
)
.service(service)
.boxed()
}
}
}
register_plugin!("apollo", "demand_control", DemandControl);
#[cfg(test)]
mod test {
use std::sync::Arc;
use apollo_compiler::ExecutableDocument;
use apollo_compiler::Schema;
use apollo_compiler::ast;
use apollo_compiler::validation::Valid;
use futures::StreamExt;
use schemars::JsonSchema;
use serde::Deserialize;
use crate::Context;
use crate::graphql;
use crate::graphql::Response;
use crate::metrics::FutureMetricsExt;
use crate::plugins::demand_control::DemandControl;
use crate::plugins::demand_control::DemandControlContext;
use crate::plugins::demand_control::DemandControlError;
use crate::plugins::test::PluginTestHarness;
use crate::services::execution;
use crate::services::layers::query_analysis::ParsedDocument;
use crate::services::layers::query_analysis::ParsedDocumentInner;
use crate::services::subgraph;
#[tokio::test]
async fn test_measure_on_execution_request() {
let body = test_on_execution(include_str!(
"fixtures/measure_on_execution_request.router.yaml"
))
.await;
insta::assert_yaml_snapshot!(body);
}
#[tokio::test]
async fn test_enforce_on_execution_request() {
let body = test_on_execution(include_str!(
"fixtures/enforce_on_execution_request.router.yaml"
))
.await;
insta::assert_yaml_snapshot!(body);
}
#[tokio::test]
async fn test_measure_on_execution_response() {
let body = test_on_execution(include_str!(
"fixtures/measure_on_execution_response.router.yaml"
))
.await;
insta::assert_yaml_snapshot!(body);
}
#[tokio::test]
async fn test_enforce_on_execution_response() {
let body = test_on_execution(include_str!(
"fixtures/enforce_on_execution_response.router.yaml"
))
.await;
insta::assert_yaml_snapshot!(body);
}
#[tokio::test]
async fn test_measure_on_subgraph_request() {
let body = test_on_subgraph(include_str!(
"fixtures/measure_on_subgraph_request.router.yaml"
))
.await;
insta::assert_yaml_snapshot!(body);
}
#[tokio::test]
async fn test_enforce_on_subgraph_request() {
let body = test_on_subgraph(include_str!(
"fixtures/enforce_on_subgraph_request.router.yaml"
))
.await;
insta::assert_yaml_snapshot!(body);
}
#[tokio::test]
async fn test_measure_on_subgraph_response() {
let body = test_on_subgraph(include_str!(
"fixtures/measure_on_subgraph_response.router.yaml"
))
.await;
insta::assert_yaml_snapshot!(body);
}
#[tokio::test]
async fn test_enforce_on_subgraph_response() {
let body = test_on_subgraph(include_str!(
"fixtures/enforce_on_subgraph_response.router.yaml"
))
.await;
insta::assert_yaml_snapshot!(body);
}
#[tokio::test]
async fn test_operation_metrics() {
async {
test_on_execution(include_str!(
"fixtures/measure_on_execution_request.router.yaml"
))
.await;
assert_counter!(
"apollo.router.operations.demand_control",
1,
"demand_control.result" = "COST_ESTIMATED_TOO_EXPENSIVE"
);
test_on_execution(include_str!(
"fixtures/enforce_on_execution_response.router.yaml"
))
.await;
assert_counter!(
"apollo.router.operations.demand_control",
2,
"demand_control.result" = "COST_ESTIMATED_TOO_EXPENSIVE"
);
test_on_subgraph(include_str!(
"fixtures/enforce_on_subgraph_request.router.yaml"
))
.await;
test_on_subgraph(include_str!(
"fixtures/enforce_on_subgraph_response.router.yaml"
))
.await;
assert_counter!(
"apollo.router.operations.demand_control",
2,
"demand_control.result" = "COST_ESTIMATED_TOO_EXPENSIVE"
);
}
.with_metrics()
.await
}
async fn test_on_execution(config: &'static str) -> Vec<Response> {
let plugin = PluginTestHarness::<DemandControl>::builder()
.config(config)
.build()
.await;
let ctx = context();
let resp = plugin
.call_execution(
execution::Request::fake_builder().context(ctx).build(),
|req| async {
execution::Response::fake_builder()
.context(req.context)
.build()
},
)
.await
.unwrap();
resp.response
.into_body()
.collect::<Vec<graphql::Response>>()
.await
}
async fn test_on_subgraph(config: &'static str) -> Response {
let plugin = PluginTestHarness::<DemandControl>::builder()
.config(config)
.build()
.await;
let strategy = plugin.strategy_factory.create();
let ctx = context();
ctx.insert_demand_control_context(DemandControlContext {
strategy,
variables: Default::default(),
});
let mut req = subgraph::Request::fake_builder()
.subgraph_name("test")
.context(ctx)
.build();
req.executable_document = Some(Arc::new(Valid::assume_valid(ExecutableDocument::new())));
let resp = plugin
.call_subgraph(req, |req| async {
Ok(subgraph::Response::fake_builder()
.context(req.context)
.build())
})
.await
.unwrap();
resp.response.into_body()
}
fn context() -> Context {
let schema = Schema::parse_and_validate("type Query { f: Int }", "").unwrap();
let ast = ast::Document::parse("{__typename}", "").unwrap();
let doc = ast.to_executable_validate(&schema).unwrap();
let parsed_document =
ParsedDocumentInner::new(ast, doc.into(), None, Default::default()).unwrap();
let ctx = Context::new();
ctx.extensions()
.with_lock(|mut lock| lock.insert::<ParsedDocument>(parsed_document));
ctx
}
#[derive(Clone, Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields, rename_all = "snake_case")]
pub(crate) enum TestStage {
ExecutionRequest,
ExecutionResponse,
SubgraphRequest,
SubgraphResponse,
}
#[derive(Clone, Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields, rename_all = "snake_case")]
pub(crate) enum TestError {
EstimatedCostTooExpensive,
ActualCostTooExpensive,
}
impl From<&TestError> for DemandControlError {
fn from(value: &TestError) -> Self {
match value {
TestError::EstimatedCostTooExpensive => {
DemandControlError::EstimatedCostTooExpensive {
max_cost: 1.0,
estimated_cost: 2.0,
}
}
TestError::ActualCostTooExpensive => DemandControlError::ActualCostTooExpensive {
actual_cost: 1.0,
max_cost: 2.0,
},
}
}
}
}