use std::collections::HashMap;
use std::collections::HashSet;
use std::default::Default;
use std::str::FromStr;
use std::sync::Arc;
use serde::de::Error as DeserializeError;
use serde::ser::Error as SerializeError;
use tower::BoxError;
use tower::ServiceBuilder;
use tower::ServiceExt;
use tower_http::trace::MakeSpan;
use tracing_futures::Instrument;
use crate::AllowedFeature;
use crate::axum_factory::span_mode;
use crate::axum_factory::utils::PropagatingMakeSpan;
use crate::configuration::Configuration;
use crate::configuration::ConfigurationError;
use crate::graphql;
use crate::plugin::DynPlugin;
use crate::plugin::Plugin;
use crate::plugin::PluginInit;
use crate::plugin::PluginPrivate;
use crate::plugin::PluginUnstable;
use crate::plugin::test::MockSubgraph;
use crate::plugin::test::canned;
use crate::plugins::telemetry::reload::otel::init_telemetry;
use crate::router_factory::YamlRouterFactory;
use crate::services::HasSchema;
use crate::services::SupergraphCreator;
use crate::services::execution;
use crate::services::layers::persisted_queries::PersistedQueryLayer;
use crate::services::layers::query_analysis::QueryAnalysisLayer;
use crate::services::router;
use crate::services::router::service::RouterCreator;
use crate::services::subgraph;
use crate::services::supergraph;
use crate::spec::Schema;
use crate::uplink::license_enforcement::LicenseLimits;
use crate::uplink::license_enforcement::LicenseState;
pub mod mocks;
#[cfg(test)]
pub(crate) mod http_client;
#[cfg(any(test, feature = "snapshot"))]
pub(crate) mod http_snapshot;
pub struct TestHarness<'a> {
schema: Option<&'a str>,
configuration: Option<Arc<Configuration>>,
extra_plugins: Vec<(String, Box<dyn DynPlugin>)>,
subgraph_network_requests: bool,
license: Option<Arc<LicenseState>>,
}
impl<'a> TestHarness<'a> {
pub fn builder() -> Self {
Self {
schema: None,
configuration: None,
extra_plugins: Vec::new(),
subgraph_network_requests: false,
license: None,
}
}
pub fn log_level(self, log_level: &'a str) -> Self {
let log_level = format!("{log_level},salsa=error");
init_telemetry(&log_level).expect("failed to setup logging");
self
}
pub fn try_log_level(self, log_level: &'a str) -> Self {
let log_level = format!("{log_level},salsa=error");
let _ = init_telemetry(&log_level);
self
}
pub fn schema(mut self, schema: &'a str) -> Self {
assert!(self.schema.is_none(), "schema was specified twice");
self.schema = Some(schema);
self
}
pub fn configuration(mut self, configuration: Arc<Configuration>) -> Self {
assert!(
self.configuration.is_none(),
"configuration was specified twice"
);
self.configuration = Some(configuration);
self
}
pub fn configuration_json(
self,
configuration: serde_json::Value,
) -> Result<Self, serde_json::Error> {
let yaml = serde_yaml::to_string(&configuration).map_err(SerializeError::custom)?;
let configuration: Configuration =
Configuration::from_str(&yaml).map_err(DeserializeError::custom)?;
Ok(self.configuration(Arc::new(configuration)))
}
pub fn configuration_yaml(self, configuration: &'a str) -> Result<Self, ConfigurationError> {
let configuration: Configuration = Configuration::from_str(configuration)?;
Ok(self.configuration(Arc::new(configuration)))
}
pub fn license_from_allowed_features(mut self, allowed_features: Vec<AllowedFeature>) -> Self {
assert!(self.license.is_none(), "license was specified twice");
self.license = Some(Arc::new(LicenseState::Licensed {
limits: {
Some(
LicenseLimits::builder()
.allowed_features(HashSet::from_iter(allowed_features))
.build(),
)
},
}));
self
}
pub fn extra_plugin<P: Plugin>(mut self, plugin: P) -> Self {
let type_id = std::any::TypeId::of::<P>();
let name = match crate::plugin::plugins().find(|factory| factory.type_id == type_id) {
Some(factory) => factory.name.clone(),
None => format!(
"extra_plugins.{}.{}",
self.extra_plugins.len(),
std::any::type_name::<P>(),
),
};
self.extra_plugins.push((name, plugin.into()));
self
}
pub fn extra_unstable_plugin<P: PluginUnstable>(mut self, plugin: P) -> Self {
let type_id = std::any::TypeId::of::<P>();
let name = match crate::plugin::plugins().find(|factory| factory.type_id == type_id) {
Some(factory) => factory.name.clone(),
None => format!(
"extra_plugins.{}.{}",
self.extra_plugins.len(),
std::any::type_name::<P>(),
),
};
self.extra_plugins.push((name, Box::new(plugin)));
self
}
#[allow(dead_code)]
pub(crate) fn extra_private_plugin<P: PluginPrivate>(mut self, plugin: P) -> Self {
let type_id = std::any::TypeId::of::<P>();
let name = match crate::plugin::plugins().find(|factory| factory.type_id == type_id) {
Some(factory) => factory.name.clone(),
None => format!(
"extra_plugins.{}.{}",
self.extra_plugins.len(),
std::any::type_name::<P>(),
),
};
self.extra_plugins.push((name, Box::new(plugin)));
self
}
pub fn router_hook(
self,
callback: impl Fn(router::BoxService) -> router::BoxService + Send + Sync + 'static,
) -> Self {
self.extra_plugin(RouterServicePlugin(callback))
}
pub fn supergraph_hook(
self,
callback: impl Fn(supergraph::BoxService) -> supergraph::BoxService + Send + Sync + 'static,
) -> Self {
self.extra_plugin(SupergraphServicePlugin(callback))
}
pub fn execution_hook(
self,
callback: impl Fn(execution::BoxService) -> execution::BoxService + Send + Sync + 'static,
) -> Self {
self.extra_plugin(ExecutionServicePlugin(callback))
}
pub fn subgraph_hook(
self,
callback: impl Fn(&str, subgraph::BoxService) -> subgraph::BoxService + Send + Sync + 'static,
) -> Self {
self.extra_plugin(SubgraphServicePlugin(callback))
}
pub fn with_subgraph_network_requests(mut self) -> Self {
self.subgraph_network_requests = true;
self
}
pub(crate) async fn build_common(
self,
) -> Result<(Arc<Configuration>, Arc<Schema>, SupergraphCreator), BoxError> {
let mut config = self.configuration.unwrap_or_default();
let has_legacy_mock_subgraphs_plugin = self.extra_plugins.iter().any(|(_, dyn_plugin)| {
dyn_plugin.name() == *crate::plugins::mock_subgraphs::PLUGIN_NAME
});
if self.schema.is_none() && !has_legacy_mock_subgraphs_plugin {
Arc::make_mut(&mut config)
.apollo_plugins
.plugins
.entry("experimental_mock_subgraphs")
.or_insert_with(canned::mock_subgraphs);
}
if !self.subgraph_network_requests {
Arc::make_mut(&mut config)
.apollo_plugins
.plugins
.entry("experimental_mock_subgraphs")
.or_insert(serde_json::json!({}));
}
let canned_schema = include_str!("../testing_schema.graphql");
let schema = self.schema.unwrap_or(canned_schema);
let schema = Arc::new(Schema::parse(schema, &config)?);
let license = self.license.unwrap_or(Arc::new(LicenseState::Licensed {
limits: Default::default(),
}));
let supergraph_creator = YamlRouterFactory
.inner_create_supergraph(
config.clone(),
schema.clone(),
None,
Some(self.extra_plugins),
license,
None,
)
.await?;
Ok((config, schema, supergraph_creator))
}
pub async fn build_supergraph(self) -> Result<supergraph::BoxCloneService, BoxError> {
let (config, schema, supergraph_creator) = self.build_common().await?;
Ok(tower::service_fn(move |request: supergraph::Request| {
let router = supergraph_creator.make();
let body = request.supergraph_request.body();
if let Some(query_str) = body.query.as_deref() {
let operation_name = body.operation_name.as_deref();
if !request.context.extensions().with_lock(|lock| {
lock.contains_key::<crate::services::layers::query_analysis::ParsedDocument>()
}) {
let doc = crate::spec::Query::parse_document(
query_str,
operation_name,
&schema,
&config,
)
.expect("parse error in test");
request.context.extensions().with_lock(|lock| {
lock.insert::<crate::services::layers::query_analysis::ParsedDocument>(doc)
});
}
}
async move { router.oneshot(request).await }
})
.boxed_clone())
}
pub async fn build_router(self) -> Result<router::BoxCloneService, BoxError> {
let (config, _schema, supergraph_creator) = self.build_common().await?;
let router_creator = RouterCreator::new(
QueryAnalysisLayer::new(supergraph_creator.schema(), Arc::clone(&config)).await,
Arc::new(PersistedQueryLayer::new(&config).await.unwrap()),
Arc::new(supergraph_creator),
config.clone(),
)
.await
.unwrap();
Ok(tower::service_fn(move |request: router::Request| {
let router = ServiceBuilder::new().service(router_creator.make()).boxed();
let span = PropagatingMakeSpan {
license: Default::default(),
span_mode: span_mode(&config),
}
.make_span(&request.router_request);
async move { router.oneshot(request).await }.instrument(span)
})
.boxed_clone())
}
pub async fn build_http_service(self) -> Result<HttpService, BoxError> {
use crate::axum_factory::ListenAddrAndRouter;
use crate::axum_factory::axum_http_server_factory::make_axum_router;
use crate::router_factory::RouterFactory;
let (config, _schema, supergraph_creator) = self.build_common().await?;
let router_creator = RouterCreator::new(
QueryAnalysisLayer::new(supergraph_creator.schema(), Arc::clone(&config)).await,
Arc::new(PersistedQueryLayer::new(&config).await.unwrap()),
Arc::new(supergraph_creator),
config.clone(),
)
.await?;
let web_endpoints = router_creator.web_endpoints();
let routers = make_axum_router(
router_creator,
&config,
web_endpoints,
Arc::new(LicenseState::Licensed {
limits: Default::default(),
}),
)?;
let ListenAddrAndRouter(_listener, router) = routers.main;
Ok(router.boxed())
}
}
pub type HttpService = tower::util::BoxService<
http::Request<crate::services::router::Body>,
http::Response<axum::body::Body>,
std::convert::Infallible,
>;
/// Plugin adapter whose `router_service` hook calls the wrapped closure.
struct RouterServicePlugin<F>(F);
/// Plugin adapter whose `supergraph_service` hook calls the wrapped closure.
struct SupergraphServicePlugin<F>(F);
/// Plugin adapter whose `execution_service` hook calls the wrapped closure.
struct ExecutionServicePlugin<F>(F);
/// Plugin adapter whose `subgraph_service` hook calls the wrapped closure with the subgraph name.
struct SubgraphServicePlugin<F>(F);
/// Lets a `TestHarness::router_hook` closure act as a [`Plugin`].
#[async_trait::async_trait]
impl<F> Plugin for RouterServicePlugin<F>
where
    F: Fn(router::BoxService) -> router::BoxService + Send + Sync + 'static,
{
    type Config = ();

    /// Never called: the harness constructs this adapter directly, not from configuration.
    async fn new(_init: PluginInit<Self::Config>) -> Result<Self, BoxError> {
        unreachable!()
    }

    fn router_service(&self, service: router::BoxService) -> router::BoxService {
        let Self(hook) = self;
        hook(service)
    }
}
/// Lets a `TestHarness::supergraph_hook` closure act as a [`Plugin`].
#[async_trait::async_trait]
impl<F> Plugin for SupergraphServicePlugin<F>
where
    F: Fn(supergraph::BoxService) -> supergraph::BoxService + Send + Sync + 'static,
{
    type Config = ();

    /// Never called: the harness constructs this adapter directly, not from configuration.
    async fn new(_init: PluginInit<Self::Config>) -> Result<Self, BoxError> {
        unreachable!()
    }

    fn supergraph_service(&self, service: supergraph::BoxService) -> supergraph::BoxService {
        let Self(hook) = self;
        hook(service)
    }
}
/// Lets a `TestHarness::execution_hook` closure act as a [`Plugin`].
#[async_trait::async_trait]
impl<F> Plugin for ExecutionServicePlugin<F>
where
    F: Fn(execution::BoxService) -> execution::BoxService + Send + Sync + 'static,
{
    type Config = ();

    /// Never called: the harness constructs this adapter directly, not from configuration.
    async fn new(_init: PluginInit<Self::Config>) -> Result<Self, BoxError> {
        unreachable!()
    }

    fn execution_service(&self, service: execution::BoxService) -> execution::BoxService {
        let Self(hook) = self;
        hook(service)
    }
}
/// Lets a `TestHarness::subgraph_hook` closure act as a [`Plugin`].
#[async_trait::async_trait]
impl<F> Plugin for SubgraphServicePlugin<F>
where
    F: Fn(&str, subgraph::BoxService) -> subgraph::BoxService + Send + Sync + 'static,
{
    type Config = ();

    /// Never called: the harness constructs this adapter directly, not from configuration.
    async fn new(_init: PluginInit<Self::Config>) -> Result<Self, BoxError> {
        unreachable!()
    }

    /// Forwards both the subgraph name and its service to the wrapped closure.
    fn subgraph_service(
        &self,
        subgraph_name: &str,
        service: subgraph::BoxService,
    ) -> subgraph::BoxService {
        let Self(hook) = self;
        hook(subgraph_name, service)
    }
}
/// Canned subgraph mocks, keyed by subgraph name.
#[derive(Default, Clone)]
pub struct MockedSubgraphs(pub(crate) HashMap<&'static str, MockSubgraph>);

impl MockedSubgraphs {
    /// Registers `subgraph` as the mock for the subgraph called `name`, replacing any mock
    /// previously registered under that name.
    pub fn insert(&mut self, name: &'static str, subgraph: MockSubgraph) {
        let Self(mocks) = self;
        mocks.insert(name, subgraph);
    }
}
/// Installs the registered mocks as subgraph services.
#[async_trait::async_trait]
impl Plugin for MockedSubgraphs {
    type Config = ();

    /// Never called: `MockedSubgraphs` values are built directly, not from configuration.
    async fn new(_init: PluginInit<Self::Config>) -> Result<Self, BoxError> {
        unreachable!()
    }

    /// Uses the mock registered for `subgraph_name` if there is one, otherwise keeps `default`.
    fn subgraph_service(
        &self,
        subgraph_name: &str,
        default: subgraph::BoxService,
    ) -> subgraph::BoxService {
        match self.0.get(subgraph_name) {
            Some(mock) => mock.clone().boxed(),
            None => default,
        }
    }
}
pub fn make_fake_batch(
input: http::Request<graphql::Request>,
op_from_to: Option<(&str, &str)>,
) -> http::Request<crate::services::router::Body> {
input.map(|req| {
let mut new_req = req.clone();
if let Some((from, to)) = op_from_to
&& let Some(operation_name) = &req.operation_name
&& operation_name == from
{
new_req.query = req.query.clone().map(|q| q.replace(from, to));
new_req.operation_name = Some(to.to_string());
}
let mut json_bytes_req = serde_json::to_vec(&req).unwrap();
let mut json_bytes_new_req = serde_json::to_vec(&new_req).unwrap();
let mut result = vec![b'['];
result.append(&mut json_bytes_req);
result.push(b',');
result.append(&mut json_bytes_new_req);
result.push(b']');
router::body::from_bytes(result)
})
}
/// With an explicit schema and no mocks, subgraph calls are intercepted by the default
/// `experimental_mock_subgraphs` setup and reported as "mock not configured" errors.
#[tokio::test]
async fn test_intercept_subgraph_network_requests() {
    use futures::StreamExt;

    let request = crate::services::supergraph::Request::canned_builder()
        .build()
        .unwrap();

    let harness = TestHarness::builder()
        .schema(include_str!("../testing_schema.graphql"))
        .configuration_json(serde_json::json!({
            "include_subgraph_errors": {
                "all": true
            }
        }))
        .unwrap();
    let router = harness.build_router().await.unwrap();

    let mut responses = router
        .oneshot(request.try_into().unwrap())
        .await
        .unwrap()
        .into_graphql_response_stream()
        .await;
    let first_response = responses.next().await.unwrap().unwrap();

    insta::assert_json_snapshot!(first_response, @r###"
{
"data": {
"topProducts": null
},
"errors": [
{
"message": "subgraph mock not configured",
"path": [],
"extensions": {
"code": "SUBGRAPH_MOCK_NOT_CONFIGURED",
"service": "products"
}
}
]
}
"###);
}
/// Test helpers that capture and inspect log output through `tracing_test`'s internal API.
#[cfg(test)]
pub(crate) mod tracing_test {
    use tracing_core::dispatcher::DefaultGuard;

    /// Scope used by the helpers that do not take an explicit scope.
    const ROUTER_SCOPE: &str = "apollo_router";

    /// Installs, via `tracing::dispatcher::set_default`, a subscriber that writes
    /// `apollo_router=trace` events into `tracing_test`'s global buffer. The returned guard
    /// keeps it installed.
    pub(crate) fn dispatcher_guard() -> DefaultGuard {
        let writer =
            ::tracing_test::internal::MockWriter::new(::tracing_test::internal::global_buf());
        tracing::dispatcher::set_default(&::tracing_test::internal::get_subscriber(
            writer,
            "apollo_router=trace",
        ))
    }

    /// Returns whether captured logs for `scope` contain `value`.
    pub(crate) fn logs_with_scope_contain(scope: &str, value: &str) -> bool {
        ::tracing_test::internal::logs_with_scope_contain(scope, value)
    }

    /// Returns whether captured `apollo_router` logs contain `value`.
    pub(crate) fn logs_contain(value: &str) -> bool {
        logs_with_scope_contain(ROUTER_SCOPE, value)
    }

    /// Runs `f` over the captured log lines for `scope` and returns its verdict.
    pub(crate) fn logs_with_scope_assert<F: Fn(&[&str]) -> Result<(), String>>(
        scope: &str,
        f: F,
    ) -> Result<(), String> {
        ::tracing_test::internal::logs_assert(scope, f)
    }

    /// Runs `f` over the captured `apollo_router` log lines and returns its verdict.
    pub(crate) fn logs_assert<F: Fn(&[&str]) -> Result<(), String>>(
        f: F,
    ) -> Result<(), String> {
        logs_with_scope_assert(ROUTER_SCOPE, f)
    }
}