use std::collections::HashMap;
use std::collections::hash_map::Entry;
use std::fs;
use std::net::SocketAddr;
use std::net::TcpListener;
use std::path::PathBuf;
use std::process::Stdio;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::OnceLock;
use std::time::Duration;
use buildstructor::buildstructor;
use fred::clients::RedisClient;
use fred::interfaces::ClientLike;
use fred::interfaces::KeysInterface;
use fred::prelude::RedisConfig;
use fred::types::ScanType;
use fred::types::Scanner;
use futures::StreamExt;
use http::header::ACCEPT;
use http::header::CONTENT_TYPE;
use mediatype::MediaType;
use mediatype::WriteParams;
use mediatype::names::BOUNDARY;
use mediatype::names::FORM_DATA;
use mediatype::names::MULTIPART;
use mime::APPLICATION_JSON;
use opentelemetry::global;
use opentelemetry::propagation::TextMapPropagator;
use opentelemetry::sdk::Resource;
use opentelemetry::sdk::trace::BatchSpanProcessor;
use opentelemetry::sdk::trace::TracerProvider;
use opentelemetry::sdk::trace::config;
use opentelemetry::testing::trace::NoopSpanExporter;
use opentelemetry::trace::TraceContextExt;
use opentelemetry_api::Context;
use opentelemetry_api::KeyValue;
use opentelemetry_api::trace::SpanContext;
use opentelemetry_api::trace::TraceId;
use opentelemetry_api::trace::TracerProvider as OtherTracerProvider;
use opentelemetry_otlp::HttpExporterBuilder;
use opentelemetry_otlp::Protocol;
use opentelemetry_otlp::SpanExporterBuilder;
use opentelemetry_otlp::WithExportConfig;
use opentelemetry_semantic_conventions::resource::SERVICE_NAME;
use regex::Regex;
use reqwest::Request;
use serde_json::Value;
use serde_json::json;
use tokio::io::AsyncBufReadExt;
use tokio::io::AsyncWriteExt;
use tokio::io::BufReader;
use tokio::process::Child;
use tokio::process::Command;
use tokio::task;
use tokio::time::Instant;
use tracing::info_span;
use tracing_core::Dispatch;
use tracing_core::LevelFilter;
use tracing_futures::Instrument;
use tracing_futures::WithSubscriber;
use tracing_opentelemetry::OpenTelemetrySpanExt;
use tracing_subscriber::EnvFilter;
use tracing_subscriber::Layer;
use tracing_subscriber::Registry;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use uuid::Uuid;
use wiremock::Mock;
use wiremock::Respond;
use wiremock::ResponseTemplate;
use wiremock::matchers::method;
/// Global registry of ports handed out by `allocate_port`, mapping each
/// reserved port to the name of the test component that claimed it.
static ALLOCATED_PORTS: OnceLock<Arc<Mutex<HashMap<u16, String>>>> = OnceLock::new();

/// Lazily initialise and return the shared port registry.
fn get_allocated_ports() -> &'static Arc<Mutex<HashMap<u16, String>>> {
    ALLOCATED_PORTS.get_or_init(Default::default)
}

/// Reserve a currently-free TCP port on 127.0.0.1 and record it in the global
/// registry under `name`, so concurrently-running tests cannot claim the same
/// port. Returns `AddrInUse` if no unregistered port is found in 100 tries.
fn allocate_port(name: &str) -> std::io::Result<u16> {
    let registry = get_allocated_ports();
    for _attempt in 0..100 {
        // Ask the OS for an ephemeral port, then release the socket; the
        // registry (not the bound socket) is what prevents reuse by other
        // tests in this process.
        let candidate = {
            let listener = TcpListener::bind("127.0.0.1:0")?;
            listener.local_addr()?.port()
        };
        let mut taken = registry.lock().unwrap();
        if let Entry::Vacant(slot) = taken.entry(candidate) {
            slot.insert(name.to_owned());
            return Ok(candidate);
        }
    }
    Err(std::io::Error::new(
        std::io::ErrorKind::AddrInUse,
        "Could not find available port after 100 attempts",
    ))
}
/// A GraphQL request to run against the router under test, plus client-side
/// options controlling headers and trace propagation.
pub struct Query {
// Inject trace-propagation headers into the outgoing request when true.
traced: bool,
// Optional value for the `x-datadog-sampling-priority` header.
psr: Option<&'static str>,
// Additional HTTP headers to send with the request.
headers: HashMap<String, String>,
// Value of the `content-type` header.
content_type: String,
// JSON request body (query + variables).
body: Value,
}
impl Default for Query {
fn default() -> Self {
// Defer to the buildstructor-generated builder's defaults.
Query::builder().build()
}
}
#[buildstructor::buildstructor]
impl Query {
/// Builder-backed constructor; unspecified fields fall back to: traced,
/// the example `topProducts` query body, and an `application/json`
/// content type.
#[builder]
pub fn new(
traced: Option<bool>,
psr: Option<&'static str>,
body: Option<Value>,
content_type: Option<String>,
headers: HashMap<String, String>,
) -> Self {
Self {
traced: traced.unwrap_or(true),
psr,
body: body.unwrap_or(
json!({"query":"query ExampleQuery {topProducts{name}}","variables":{}}),
),
content_type: content_type
.unwrap_or_else(|| APPLICATION_JSON.essence_str().to_string()),
headers,
}
}
}
impl Query {
/// Use an invalid `content-type` header value.
#[allow(dead_code)]
pub fn with_bad_content_type(mut self) -> Self {
self.content_type = "garbage".to_string();
self
}
/// Replace the body with JSON that is not a valid GraphQL request.
#[allow(dead_code)]
pub fn with_bad_query(mut self) -> Self {
self.body = json!({"garbage":{}});
self
}
/// Use an unnamed (anonymous) operation instead of the default named one.
#[allow(dead_code)]
pub fn with_anonymous(mut self) -> Self {
self.body = json!({"query":"query {topProducts{name}}","variables":{}});
self
}
/// Use a deliberately large selection set (e.g. for size-limit tests).
#[allow(dead_code)]
pub fn with_huge_query(mut self) -> Self {
self.body = json!({"query":"query {topProducts{name, name, name, name, name, name, name, name, name, name, name, name, name, name, name, name, name, name, name, name, name, name, name, name, name, name, name, name, name, name, name}}","variables":{}});
self
}
/// A schema-introspection query.
#[allow(dead_code)]
pub fn introspection() -> Query {
Query::builder()
.body(json!({"query":"{__schema {types {name}}}","variables":{}}))
.build()
}
}
/// Harness that runs a real router binary against a mocked subgraph, with
/// telemetry capture, log inspection and config/schema hot-reload helpers.
pub struct IntegrationTest {
// The spawned router process; None until `start()` is called.
router: Option<Child>,
// Temp file holding the (rewritten) router YAML config.
test_config_location: PathBuf,
// Temp file holding the supergraph schema.
test_schema_location: PathBuf,
// Path to the router binary under test.
router_location: PathBuf,
// Channel carrying the router's stdout lines to the test.
stdio_tx: tokio::sync::mpsc::Sender<String>,
stdio_rx: tokio::sync::mpsc::Receiver<String>,
// When set, collected (version-redacted) stdout is sent on shutdown.
collect_stdio: Option<(tokio::sync::oneshot::Sender<String>, regex::Regex)>,
// Mock subgraph server; kept alive for the test's duration.
_subgraphs: wiremock::MockServer,
// Primary telemetry backend and an extra propagator to chain with it.
telemetry: Telemetry,
extra_propagator: Telemetry,
pub _tracer_provider_client: TracerProvider,
pub _tracer_provider_subgraph: TracerProvider,
// Tracing dispatcher used for client-side request spans.
subscriber_client: Dispatch,
_subgraph_overrides: HashMap<String, String>,
// The router's GraphQL listen address, parsed from its logs.
bind_address: Arc<Mutex<Option<SocketAddr>>>,
// Per-test redis key namespace.
redis_namespace: String,
// `--log` filter passed to the router.
log: String,
// Span context captured from the most recent subgraph request.
subgraph_context: Arc<Mutex<Option<SpanContext>>>,
// Log lines read so far (see `read_logs`).
logs: Vec<String>,
// Extra environment variables for the router process.
env: HashMap<String, String>,
client: reqwest::Client,
// Placeholder-name -> port substitutions applied to the config.
port_replacements: HashMap<String, u16>,
}
impl IntegrationTest {
/// The router's GraphQL listen address. Panics if the router has not yet
/// logged its endpoint (i.e. before `start()`/`assert_started()`).
pub(crate) fn bind_address(&self) -> SocketAddr {
    let guard = self.bind_address.lock().expect("lock poisoned");
    (*guard).expect("no bind address set, router must be started first.")
}
/// Allocate a fresh port, substitute it for `{{placeholder_name}}` in the
/// config file, and return it.
#[allow(dead_code)]
pub fn reserve_address(&mut self, placeholder_name: &str) -> u16 {
let port = allocate_port(placeholder_name).expect("Failed to allocate port");
self.set_address(placeholder_name, port);
port
}
/// Substitute `port` for the `{{placeholder_name}}` placeholder in the config
/// file, remembering the replacement so later config rewrites preserve it.
///
/// Panics if the placeholder does not appear in the config in any accepted
/// form (bare, `:{{name}}`, or `127.0.0.1:{{name}}`).
#[allow(dead_code)]
pub fn set_address(&mut self, placeholder_name: &str, port: u16) {
    let current_config = std::fs::read_to_string(&self.test_config_location)
        .expect("Failed to read config file");
    let placeholder_pattern = format!("{{{{{}}}}}", placeholder_name);
    let port_pattern = format!(":{{{{{}}}}}", placeholder_name);
    let addr_pattern = format!("127.0.0.1:{{{{{}}}}}", placeholder_name);
    if !current_config.contains(&placeholder_pattern)
        && !current_config.contains(&port_pattern)
        && !current_config.contains(&addr_pattern)
    {
        panic!(
            "Placeholder '{}' not found in config file. Expected one of: '{}', '{}', or '{}'",
            placeholder_name, placeholder_pattern, port_pattern, addr_pattern
        );
    }
    self.port_replacements
        .insert(placeholder_name.to_string(), port);
    // Fix: the first argument was garbled to `¤t_config` (mojibake for
    // `&current_config`), which does not compile.
    let updated_config = merge_overrides(
        &current_config,
        &self._subgraph_overrides,
        None,
        &self.redis_namespace,
        Some(&self.port_replacements),
    );
    std::fs::write(&self.test_config_location, updated_config)
        .expect("Failed to write updated config");
}
/// Extract the port from `uri` (the text after its final `:`) and apply it
/// via `set_address`.
#[allow(dead_code)]
pub fn set_address_from_uri(&mut self, placeholder_name: &str, uri: &str) {
    // `rsplit(':').next()` is equivalent to `split(':').last()`.
    let port: u16 = uri
        .rsplit(':')
        .next()
        .expect("URI should contain a port")
        .parse()
        .expect("Port should be a valid u16");
    self.set_address(placeholder_name, port);
}
/// Perform a literal find/replace on the router config file on disk.
#[allow(dead_code)]
pub fn replace_config_string(&mut self, from: &str, to: &str) {
    let contents = std::fs::read_to_string(&self.test_config_location)
        .expect("Failed to read config file");
    std::fs::write(&self.test_config_location, contents.replace(from, to))
        .expect("Failed to write updated config");
}
/// Perform a literal find/replace on the supergraph schema file on disk.
#[allow(dead_code)]
pub fn replace_schema_string(&mut self, from: &str, to: &str) {
    let contents = std::fs::read_to_string(&self.test_schema_location)
        .expect("Failed to read schema file");
    std::fs::write(&self.test_schema_location, contents.replace(from, to))
        .expect("Failed to write updated schema");
}
}
/// A wiremock responder that extracts the propagated trace context from each
/// subgraph request before replying with a fixed template.
struct TracedResponder {
// The canned response returned to every request.
response_template: ResponseTemplate,
// Propagators used to extract the incoming trace context (chained).
telemetry: Telemetry,
extra_propagator: Telemetry,
// Tracing dispatcher the subgraph-side span is recorded under.
subscriber_subgraph: Dispatch,
// Optional hook invoked inside the subgraph span for each request.
subgraph_callback: Option<Box<dyn Fn() + Send + Sync>>,
// Where the extracted span context is published for test assertions.
subgraph_context: Arc<Mutex<Option<SpanContext>>>,
}
impl Respond for TracedResponder {
/// Handle a mocked subgraph request: extract the propagated trace context,
/// record it for later assertions, then reply inside a "subgraph server"
/// span recorded under the subgraph subscriber.
fn respond(&self, request: &wiremock::Request) -> ResponseTemplate {
// Chain extraction: primary telemetry first, then the extra propagator
// refines the resulting context.
let context = self.telemetry.extract_context(request, &Context::new());
let context = self.extra_propagator.extract_context(request, &context);
// Publish the extracted span context so tests can assert trace ids.
*self.subgraph_context.lock().expect("lock poisoned") =
Some(context.span().span_context().clone());
tracing_core::dispatcher::with_default(&self.subscriber_subgraph, || {
// Attach the extracted context so the span below becomes its child.
let _context_guard = context.attach();
let span = info_span!("subgraph server");
let _span_guard = span.enter();
if let Some(callback) = &self.subgraph_callback {
callback();
}
self.response_template.clone()
})
}
}
/// Which tracing backend/propagation format the test harness should use.
#[derive(Debug, Clone, Default)]
#[allow(dead_code)]
pub enum Telemetry {
Jaeger,
Otlp {
// None means "propagate but export nowhere" (noop exporter).
endpoint: Option<String>,
},
Datadog,
Zipkin,
#[default]
None,
}
impl Telemetry {
/// Build a `TracerProvider` for this backend, tagging exported spans with
/// `service.name = service_name`. Batch processors use a 10ms scheduled
/// delay so tests can observe exported spans quickly.
fn tracer_provider(&self, service_name: &str) -> TracerProvider {
let config = config().with_resource(Resource::new(vec![KeyValue::new(
SERVICE_NAME,
service_name.to_string(),
)]));
match self {
// Jaeger: synchronous agent exporter wrapped in a batch processor.
Telemetry::Jaeger => TracerProvider::builder()
.with_config(config)
.with_span_processor(
BatchSpanProcessor::builder(
opentelemetry_jaeger::new_agent_pipeline()
.with_service_name(service_name)
.build_sync_agent_exporter()
.expect("jaeger pipeline failed"),
opentelemetry::runtime::Tokio,
)
.with_scheduled_delay(Duration::from_millis(10))
.build(),
)
.build(),
// OTLP over HTTP/protobuf to an explicit collector endpoint.
Telemetry::Otlp {
endpoint: Some(endpoint),
} => TracerProvider::builder()
.with_config(config)
.with_span_processor(
BatchSpanProcessor::builder(
SpanExporterBuilder::Http(
HttpExporterBuilder::default()
.with_endpoint(endpoint)
.with_protocol(Protocol::HttpBinary),
)
.build_span_exporter()
.expect("otlp pipeline failed"),
opentelemetry::runtime::Tokio,
)
.with_scheduled_delay(Duration::from_millis(10))
.build(),
)
.build(),
Telemetry::Datadog => TracerProvider::builder()
.with_config(config)
.with_span_processor(
BatchSpanProcessor::builder(
opentelemetry_datadog::new_pipeline()
.with_service_name(service_name)
.build_exporter()
.expect("datadog pipeline failed"),
opentelemetry::runtime::Tokio,
)
.with_scheduled_delay(Duration::from_millis(10))
.build(),
)
.build(),
Telemetry::Zipkin => TracerProvider::builder()
.with_config(config)
.with_span_processor(
BatchSpanProcessor::builder(
opentelemetry_zipkin::new_pipeline()
.with_service_name(service_name)
.init_exporter()
.expect("zipkin pipeline failed"),
opentelemetry::runtime::Tokio,
)
.with_scheduled_delay(Duration::from_millis(10))
.build(),
)
.build(),
// No backend (or OTLP without an endpoint): spans go to a noop
// exporter so trace ids still exist but nothing is shipped.
Telemetry::None | Telemetry::Otlp { endpoint: None } => TracerProvider::builder()
.with_config(config)
.with_simple_exporter(NoopSpanExporter::default())
.build(),
}
}
/// Inject the current tracing span's context into `request`'s headers using
/// this backend's propagation format. `Telemetry::None` injects nothing.
fn inject_context(&self, request: &mut Request) {
let ctx = tracing::span::Span::current().context();
match self {
Telemetry::Jaeger => {
let propagator = opentelemetry_jaeger::Propagator::new();
propagator.inject_context(
&ctx,
&mut opentelemetry_http::HeaderInjector(request.headers_mut()),
)
}
Telemetry::Datadog => {
// The Datadog propagator overwrites the sampling-priority header;
// preserve any value the test set explicitly and restore it after.
let psr = request
.headers()
.get("x-datadog-sampling-priority")
.cloned();
let propagator = opentelemetry_datadog::DatadogPropagator::new();
propagator.inject_context(
&ctx,
&mut opentelemetry_http::HeaderInjector(request.headers_mut()),
);
if let Some(psr) = psr {
request
.headers_mut()
.insert("x-datadog-sampling-priority", psr);
}
}
Telemetry::Otlp { .. } => {
// OTLP itself has no propagation format; use W3C TraceContext.
let propagator = opentelemetry::sdk::propagation::TraceContextPropagator::default();
propagator.inject_context(
&ctx,
&mut opentelemetry_http::HeaderInjector(request.headers_mut()),
)
}
Telemetry::Zipkin => {
let propagator = opentelemetry_zipkin::Propagator::new();
propagator.inject_context(
&ctx,
&mut opentelemetry_http::HeaderInjector(request.headers_mut()),
)
}
_ => {}
}
}
/// Extract a trace context from an incoming wiremock request's headers using
/// this backend's propagation format, layered on top of `context`.
/// `Telemetry::None` returns `context` unchanged.
pub(crate) fn extract_context(
&self,
request: &wiremock::Request,
context: &Context,
) -> Context {
// Propagators work on a string map, not wiremock's header type.
let headers: HashMap<String, String> = request
.headers
.iter()
.map(|(name, value)| (name.as_str().to_string(), value.as_str().to_string()))
.collect();
match self {
Telemetry::Jaeger => {
let propagator = opentelemetry_jaeger::Propagator::new();
propagator.extract_with_context(context, &headers)
}
Telemetry::Datadog => {
let span_ref = context.span();
let original_span_context = span_ref.span_context();
let propagator = opentelemetry_datadog::DatadogPropagator::new();
let mut context = propagator.extract_with_context(context, &headers);
// Datadog carries sampling priority out-of-band; stash it in the
// trace state and rebuild the remote span context, keeping the
// pre-existing trace id when it was already valid.
if let Some(psr) = headers.get("x-datadog-sampling-priority") {
let state = context
.span()
.span_context()
.trace_state()
.insert("psr", psr.to_string())
.expect("psr");
let new_trace_id = if original_span_context.is_valid() {
original_span_context.trace_id()
} else {
context.span().span_context().trace_id()
};
context = context.with_remote_span_context(SpanContext::new(
new_trace_id,
context.span().span_context().span_id(),
context.span().span_context().trace_flags(),
true,
state,
));
}
context
}
Telemetry::Otlp { .. } => {
let propagator = opentelemetry::sdk::propagation::TraceContextPropagator::default();
propagator.extract_with_context(context, &headers)
}
Telemetry::Zipkin => {
let propagator = opentelemetry_zipkin::Propagator::new();
propagator.extract_with_context(context, &headers)
}
_ => context.clone(),
}
}
}
#[buildstructor]
impl IntegrationTest {
/// Builder-backed constructor: binds and mounts a wiremock subgraph server,
/// writes temp config/schema files, and prepares telemetry plumbing. The
/// router process itself is not started until `start()` is called.
#[builder]
pub async fn new(
config: String,
telemetry: Option<Telemetry>,
extra_propagator: Option<Telemetry>,
responder: Option<ResponseTemplate>,
collect_stdio: Option<tokio::sync::oneshot::Sender<String>>,
supergraph: Option<PathBuf>,
mut subgraph_overrides: HashMap<String, String>,
log: Option<String>,
subgraph_callback: Option<Box<dyn Fn() + Send + Sync>>,
env: HashMap<String, String>,
) -> Self {
// Namespace redis keys per test run so parallel tests don't collide.
let redis_namespace = Uuid::new_v4().to_string();
let telemetry = telemetry.unwrap_or_default();
let extra_propagator = extra_propagator.unwrap_or_default();
let tracer_provider_client = telemetry.tracer_provider("client");
let subscriber_client = Self::dispatch(&tracer_provider_client);
let tracer_provider_subgraph = telemetry.tracer_provider("subgraph");
// Bind the subgraph listener up-front so its URL can be injected into
// the router config before the mock server starts.
let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0))).unwrap();
let address = listener.local_addr().unwrap();
let url = format!("http://{address}/");
// Default the "products" subgraph to the local mock unless overridden.
subgraph_overrides.entry("products".into()).or_insert(url);
let config_str =
merge_overrides(&config, &subgraph_overrides, None, &redis_namespace, None);
let supergraph = supergraph.unwrap_or(PathBuf::from_iter([
"..",
"examples",
"graphql",
"local.graphql",
]));
let subgraphs = wiremock::MockServer::builder()
.listener(listener)
.start()
.await;
let subgraph_context = Arc::new(Mutex::new(None));
// All POSTs to the mock subgraph go through TracedResponder, which
// records the propagated trace context for later assertions.
Mock::given(method("POST"))
.respond_with(TracedResponder{response_template:responder.unwrap_or_else(||
ResponseTemplate::new(200).set_body_json(json!({"data":{"topProducts":[{"name":"Table"},{"name":"Couch"},{"name":"Chair"}]}}))),
telemetry: telemetry.clone(),
extra_propagator: extra_propagator.clone(),
subscriber_subgraph: Self::dispatch(&tracer_provider_subgraph),
subgraph_callback,
subgraph_context: subgraph_context.clone()
})
.mount(&subgraphs)
.await;
// Unique temp file names per test for config and schema.
let mut test_config_location = std::env::temp_dir();
let mut test_schema_location = test_config_location.clone();
let location = format!("apollo-router-test-{}.yaml", Uuid::new_v4());
test_config_location.push(location);
test_schema_location.push(format!("apollo-router-test-{}.graphql", Uuid::new_v4()));
fs::write(&test_config_location, &config_str).expect("could not write config");
fs::copy(&supergraph, &test_schema_location).expect("could not write schema");
let (stdio_tx, stdio_rx) = tokio::sync::mpsc::channel(2000);
// When stdout collection is requested, the version number is later
// stripped from log lines so snapshots stay stable across releases.
let collect_stdio = collect_stdio.map(|sender| {
let version_line_re = regex::Regex::new("Apollo Router v[^ ]+ ").unwrap();
(sender, version_line_re)
});
Self {
client: reqwest::Client::new(),
router: None,
router_location: Self::router_location(),
test_config_location,
test_schema_location,
stdio_tx,
stdio_rx,
collect_stdio,
_subgraphs: subgraphs,
_subgraph_overrides: subgraph_overrides,
bind_address: Default::default(),
_tracer_provider_client: tracer_provider_client,
subscriber_client,
_tracer_provider_subgraph: tracer_provider_subgraph,
telemetry,
extra_propagator,
redis_namespace,
log: log.unwrap_or_else(|| "error,apollo_router=info".to_owned()),
subgraph_context,
logs: vec![],
env,
port_replacements: HashMap::new(),
}
}
/// Build a tracing `Dispatch` that exports spans through the given provider
/// and also logs compactly to stdout, honouring `RUST_LOG`.
fn dispatch(tracer_provider: &TracerProvider) -> Dispatch {
    let otel_layer = tracing_opentelemetry::layer()
        .with_tracer(tracer_provider.tracer("tracer"))
        .with_filter(LevelFilter::INFO);
    let fmt_layer = tracing_subscriber::fmt::Layer::default()
        .compact()
        .with_filter(EnvFilter::from_default_env());
    Dispatch::new(Registry::default().with(otel_layer).with(fmt_layer))
}
/// The span context captured from the most recent subgraph request.
/// Panics if no subgraph request has been handled yet.
#[allow(dead_code)]
pub fn subgraph_context(&self) -> SpanContext {
    let guard = self.subgraph_context.lock().expect("lock poisoned");
    guard.as_ref().unwrap().clone()
}
/// Path to the router binary under test, resolved by Cargo at compile time.
pub fn router_location() -> PathBuf {
PathBuf::from(env!("CARGO_BIN_EXE_router"))
}
/// Spawn the router binary with the test config/schema and forward its
/// stdout line-by-line into the stdio channel, watching for the GraphQL
/// endpoint log line to learn the bind address.
#[allow(dead_code)]
pub async fn start(&mut self) {
let mut router = Command::new(&self.router_location);
// Forward Apollo credentials from the environment when both are present.
if let (Ok(apollo_key), Ok(apollo_graph_ref)) = (
std::env::var("TEST_APOLLO_KEY"),
std::env::var("TEST_APOLLO_GRAPH_REF"),
) {
router
.env("APOLLO_KEY", apollo_key)
.env("APOLLO_GRAPH_REF", apollo_graph_ref);
}
// NOTE(review): `dbg!` prints the argument list to stderr on every start —
// presumably intentional debug output for tests; confirm before removing.
router
.envs(self.env.iter())
.args(dbg!([
"--hr",
"--config",
&self.test_config_location.to_string_lossy(),
"--supergraph",
&self.test_schema_location.to_string_lossy(),
"--log",
&self.log,
]))
.stdout(Stdio::piped());
let mut router = router.spawn().expect("router should start");
let reader = BufReader::new(router.stdout.take().expect("out"));
let stdio_tx = self.stdio_tx.clone();
let collect_stdio = self.collect_stdio.take();
let bind_address = self.bind_address.clone();
// The router logs its listening address; capture it for bind_address().
let bind_address_regex =
Regex::new(r".*GraphQL endpoint exposed at http://(?<address>[^/]+).*").unwrap();
task::spawn(async move {
let mut collected = Vec::new();
let mut lines = reader.lines();
while let Ok(Some(line)) = lines.next_line().await {
println!("{line}");
if let Some(captures) = bind_address_regex.captures(&line) {
let address = captures.name("address").unwrap().as_str();
let mut bind_address = bind_address.lock().unwrap();
*bind_address = Some(address.parse().unwrap());
}
if let Some((_sender, version_line_re)) = &collect_stdio {
// Stdout collection requires JSON log output from the router.
#[derive(serde::Deserialize)]
struct Log {
#[allow(unused)]
timestamp: String,
level: String,
message: String,
}
let Ok(log) = serde_json::from_str::<Log>(&line) else {
panic!(
"line: '{line}' isn't JSON, might you have some debug output in the logging?"
);
};
// Redact the version number so collected output is stable.
if !log.message.starts_with("RUST_BACKTRACE=full detected") {
collected.push(format!(
"{}: {}",
log.level,
version_line_re
.replace(&log.message, "Apollo Router [version number] ")
))
}
}
let _ = stdio_tx.send(line).await;
}
// Stdout closed (process exited): deliver everything collected.
if let Some((sender, _version_line_re)) = collect_stdio {
let _ = sender.send(collected.join("\n"));
}
});
self.router = Some(router);
}
/// Wait (up to 10s) for the router to log that its GraphQL endpoint is up.
#[allow(dead_code)]
pub async fn assert_started(&mut self) {
self.wait_for_log_message("GraphQL endpoint exposed").await;
}
/// Wait (up to 10s) for the router to log a configuration failure at startup.
#[allow(dead_code)]
pub async fn assert_not_started(&mut self) {
self.wait_for_log_message("no valid configuration").await;
}
/// Append a harmless comment to the config file to bump its mtime and
/// trigger the router's hot-reload watcher.
#[allow(dead_code)]
pub async fn touch_config(&self) {
    let mut config_file = tokio::fs::OpenOptions::new()
        .append(true)
        .open(&self.test_config_location)
        .await
        .expect("must have been able to open config file");
    config_file
        .write_all("\n#touched\n".as_bytes())
        .await
        .expect("must be able to write config file");
}
/// Rewrite the config file from `yaml`, reapplying subgraph overrides, the
/// current bind address, the redis namespace and reserved port placeholders.
#[allow(dead_code)]
pub async fn update_config(&self, yaml: &str) {
    let merged = merge_overrides(
        yaml,
        &self._subgraph_overrides,
        Some(self.bind_address()),
        &self.redis_namespace,
        Some(&self.port_replacements),
    );
    tokio::fs::write(&self.test_config_location, &merged)
        .await
        .expect("must be able to write config");
}
/// Replace the subgraph override map used by later config rewrites.
#[allow(dead_code)]
pub fn update_subgraph_overrides(&mut self, overrides: HashMap<String, String>) {
self._subgraph_overrides = overrides;
}
/// Overwrite the schema file with the supergraph at `supergraph_path`,
/// triggering a schema reload in the running router.
#[allow(dead_code)]
pub async fn update_schema(&self, supergraph_path: &PathBuf) {
fs::copy(supergraph_path, &self.test_schema_location).expect("could not write schema");
}
/// Execute the default example query; resolves to (trace id, response).
#[allow(dead_code)]
pub fn execute_default_query(
&self,
) -> impl std::future::Future<Output = (TraceId, reqwest::Response)> {
self.execute_query(Query::builder().build())
}
/// Send `query` to the router inside a "client_request" span recorded under
/// the client subscriber; resolves to (trace id, response). Returns a future
/// (rather than being `async`) so the subscriber can be attached to it.
#[allow(dead_code)]
pub fn execute_query(
&self,
query: Query,
) -> impl std::future::Future<Output = (TraceId, reqwest::Response)> {
assert!(
self.router.is_some(),
"router was not started, call `router.start().await; router.assert_started().await`"
);
let telemetry = self.telemetry.clone();
let extra_propagator = self.extra_propagator.clone();
let client = self.client.clone();
let url = format!("http://{}", self.bind_address());
let subgraph_context = self.subgraph_context.clone();
async move {
let span = info_span!("client_request");
let trace_id = span.context().span().span_context().trace_id();
async move {
let mut builder = client.post(url).header(CONTENT_TYPE, query.content_type);
for (name, value) in query.headers {
builder = builder.header(name, value);
}
if let Some(psr) = query.psr {
builder = builder.header("x-datadog-sampling-priority", psr);
}
let mut request = builder.json(&query.body).build().unwrap();
// Only propagate trace headers when the query opts in.
if query.traced {
telemetry.inject_context(&mut request);
extra_propagator.inject_context(&mut request);
}
match client.execute(request).await {
Ok(response) => {
if query.traced {
(trace_id, response)
} else {
// Untraced requests: report the trace id the subgraph
// actually observed instead of the client span's id.
(
subgraph_context
.lock()
.expect("poisoned")
.as_ref()
.expect("subgraph context")
.trace_id(),
response,
)
}
}
Err(err) => {
panic!("unable to send successful request to router, {err}")
}
}
}
.instrument(span)
.await
}
.with_subscriber(self.subscriber_client.clone())
}
/// Send a multipart/form-data request (e.g. file upload) to the router,
/// optionally transformed first; resolves to (trace id string, response).
#[allow(dead_code)]
pub fn execute_multipart_request(
&self,
request: reqwest::multipart::Form,
transform: Option<fn(reqwest::Request) -> reqwest::Request>,
) -> impl std::future::Future<Output = (String, reqwest::Response)> {
assert!(
self.router.is_some(),
"router was not started, call `router.start().await; router.assert_started().await`"
);
let url = format!("http://{}", self.bind_address());
async move {
let span = info_span!("client_raw_request");
let span_id = span.context().span().span_context().trace_id().to_string();
async move {
let client = reqwest::Client::new();
// Build multipart/form-data with the form's boundary parameter.
let mime = {
let mut m = MediaType::new(MULTIPART, FORM_DATA);
m.set_param(BOUNDARY, mediatype::Value::new(request.boundary()).unwrap());
m
};
let mut request = client
.post(url)
.header(CONTENT_TYPE, mime.to_string())
.header("apollographql-client-name", "custom_name")
.header("apollographql-client-version", "1.0")
.header("x-my-header", "test")
.multipart(request)
.build()
.unwrap();
// Default transform is a no-op.
let transformer = transform.unwrap_or(core::convert::identity);
global::get_text_map_propagator(|propagator| {
propagator.inject_context(
&tracing::span::Span::current().context(),
&mut opentelemetry_http::HeaderInjector(request.headers_mut()),
);
});
// Drop the accept header reqwest adds by default.
request.headers_mut().remove(ACCEPT);
match client.execute(transformer(request)).await {
Ok(response) => (span_id, response),
Err(err) => {
panic!("unable to send successful request to router, {err}")
}
}
}
.instrument(span)
.await
}
.with_subscriber(self.subscriber_client.clone())
}
/// Start a GraphQL subscription over multipart/mixed; returns a unique id
/// (also attached to the client span) and the streaming response.
#[allow(dead_code)]
pub async fn run_subscription(&self, subscription: &str) -> (String, reqwest::Response) {
assert!(
self.router.is_some(),
"router was not started, call `router.start().await; router.assert_started().await`"
);
let client = reqwest::Client::new();
let id = Uuid::new_v4().to_string();
let span = info_span!("client_request", unit_test = id.as_str());
// NOTE(review): this span guard is held across the `.await` below;
// entered tracing spans across await points can misattribute work —
// confirm this is intended here.
let _span_guard = span.enter();
let mut request = client
.post(format!("http://{}", self.bind_address()))
.header(CONTENT_TYPE, APPLICATION_JSON.essence_str())
.header(ACCEPT, "multipart/mixed;subscriptionSpec=1.0")
.header("apollographql-client-name", "custom_name")
.header("apollographql-client-version", "1.0")
.json(&json!({"query":subscription,"variables":{}}))
.build()
.unwrap();
global::get_text_map_propagator(|propagator| {
propagator.inject_context(
&span.context(),
&mut opentelemetry_http::HeaderInjector(request.headers_mut()),
);
});
match client.execute(request).await {
Ok(response) => (id, response),
Err(err) => {
panic!("unable to send successful request to router, {err}")
}
}
}
/// Fetch the router's Prometheus metrics endpoint.
#[allow(dead_code)]
pub async fn get_metrics_response(&self) -> reqwest::Result<reqwest::Response> {
    let client = reqwest::Client::new();
    let url = format!("http://{}/metrics", self.bind_address());
    let request = client
        .get(url)
        .header("apollographql-client-name", "custom_name")
        .header("apollographql-client-version", "1.0")
        .build()
        .unwrap();
    client.execute(request).await
}
/// Ask the router to shut down gracefully via SIGTERM, then wait for exit.
#[allow(dead_code)]
#[cfg(target_family = "unix")]
pub async fn graceful_shutdown(&mut self) {
// SAFETY: pid() is the id of the child process this harness spawned;
// sending SIGTERM to it is sound.
unsafe {
libc::kill(self.pid(), libc::SIGTERM);
}
self.assert_shutdown().await;
}
/// Windows has no SIGTERM path here, so fall back to a hard kill.
#[cfg(target_os = "windows")]
pub async fn graceful_shutdown(&mut self) {
self.kill().await
}
/// Forcibly terminate the router process and wait for it to exit.
#[allow(dead_code)]
pub async fn kill(&mut self) {
    let router = self.router.as_mut().expect("router not started");
    // Best-effort kill; exit is verified by assert_shutdown below.
    let _ = router.kill().await;
    self.assert_shutdown().await;
}
/// OS process id of the running router child process.
#[allow(dead_code)]
pub(crate) fn pid(&mut self) -> i32 {
    let child = self.router.as_ref().expect("router must have been started");
    child.id().expect("id expected") as i32
}
/// Wait until the router reports a completed configuration/schema reload.
#[allow(dead_code)]
pub async fn assert_reloaded(&mut self) {
self.wait_for_log_message("reload complete").await;
}
/// Wait until the router reports that a change required no reload.
#[allow(dead_code)]
pub async fn assert_no_reload_necessary(&mut self) {
self.wait_for_log_message("no reload necessary").await;
}
/// Wait until the router reports it rejected the new configuration.
#[allow(dead_code)]
pub async fn assert_not_reloaded(&mut self) {
self.wait_for_log_message("continuing with previous configuration")
.await;
}
/// Poll the router's stdout for a line containing `msg`, recording every
/// line into `self.logs`. Panics (after dumping stack traces) if the
/// message does not appear within 10 seconds.
#[allow(dead_code)]
pub async fn wait_for_log_message(&mut self, msg: &str) {
    let now = Instant::now();
    while now.elapsed() < Duration::from_secs(10) {
        // Fix: drain everything currently buffered before sleeping. The
        // previous code read at most one line per 10ms tick, so a burst of
        // log output could outpace polling and time out spuriously.
        while let Ok(line) = self.stdio_rx.try_recv() {
            let found = line.contains(msg);
            self.logs.push(line);
            if found {
                return;
            }
        }
        tokio::time::sleep(Duration::from_millis(10)).await;
    }
    self.dump_stack_traces();
    panic!("'{msg}' not detected in logs");
}
/// Print every log line captured so far to stdout.
#[allow(dead_code)]
pub fn print_logs(&mut self) {
    self.logs.iter().for_each(|line| println!("{}", line));
}
/// Drain all lines currently buffered on the router's stdout channel into
/// `self.logs`.
#[allow(dead_code)]
pub fn read_logs(&mut self) {
    while let Ok(line) = self.stdio_rx.try_recv() {
        // `line` is already a String; push it directly instead of
        // re-allocating via `to_string()`.
        self.logs.push(line);
    }
}
/// Panic unless some already-captured log line contains `msg`.
/// Note: checks `self.logs` only; call `read_logs` first to drain stdout.
#[allow(dead_code)]
pub fn assert_log_contained(&mut self, msg: &str) {
    if self.logs.iter().any(|line| line.contains(msg)) {
        return;
    }
    panic!("'{msg}' not detected in logs");
}
/// Watch the router's stdout for 5 seconds and panic (after dumping stack
/// traces) if any line contains `msg`. Silence for the full window passes.
#[allow(dead_code)]
pub async fn assert_log_not_contains(&mut self, msg: &str) {
    let started = Instant::now();
    let window = Duration::from_secs(5);
    while started.elapsed() < window {
        if let Ok(line) = self.stdio_rx.try_recv() {
            if line.contains(msg) {
                self.dump_stack_traces();
                panic!("'{msg}' detected in logs");
            }
        }
        tokio::time::sleep(Duration::from_millis(10)).await;
    }
}
/// Panic with a full log dump if any already-captured log line contains
/// `msg`. Checks `self.logs` only; call `read_logs` first to drain stdout.
#[allow(dead_code)]
pub fn assert_log_not_contained(&mut self, msg: &str) {
    if self.logs.iter().any(|line| line.contains(msg)) {
        panic!(
            "'{msg}' detected in logs. Log dump below:\n\n{logs}",
            logs = self.logs.join("\n")
        );
    }
}
/// Drain stdout, then return every captured line that looks like an error:
/// JSON ERROR level, panic markers, or a plain-text "ERROR" outside JSON.
#[allow(dead_code)]
pub fn error_logs(&mut self) -> Vec<String> {
    self.read_logs();
    const JSON_ERROR_INDICATORS: [&str; 3] = ["\"level\":\"ERROR\"", "panic", "PANIC"];
    self.logs
        .iter()
        .filter(|line| {
            JSON_ERROR_INDICATORS.iter().any(|err| line.contains(err))
                || (line.contains("ERROR") && !line.contains("level"))
        })
        .cloned()
        .collect()
}
/// Panic with a detailed dump if the router emitted any error-looking logs.
#[allow(dead_code)]
pub fn assert_no_error_logs(&mut self) {
    let error_logs = self.error_logs();
    if error_logs.is_empty() {
        return;
    }
    panic!(
        "Found {} unexpected error(s) in router logs:\n\n{}\n\nFull log dump:\n\n{}",
        error_logs.len(),
        error_logs.join("\n"),
        self.logs.join("\n")
    );
}
/// Like `assert_no_error_logs`, but first discards any error line that
/// contains one of the `exceptions` substrings.
#[allow(dead_code)]
pub fn assert_no_error_logs_with_exceptions(&mut self, exceptions: &[&str]) {
    let error_logs: Vec<String> = self
        .error_logs()
        .into_iter()
        .filter(|line| !exceptions.iter().any(|exception| line.contains(exception)))
        .collect();
    if !error_logs.is_empty() {
        panic!(
            "Found {} unexpected error(s) in router logs (excluding {} exceptions):\n\n{}\n\nFull log dump:\n\n{}",
            error_logs.len(),
            exceptions.len(),
            error_logs.join("\n"),
            self.logs.join("\n")
        );
    }
}
/// Poll the metrics endpoint until a line matching `text` appears (default
/// timeout 15s), panicking with the last scrape on failure. `text` is
/// regex-escaped first; the literal token `<any>` becomes the wildcard `.+`,
/// and the match is anchored to the start of a line.
#[allow(dead_code)]
pub async fn assert_metrics_contains(&self, text: &str, duration: Option<Duration>) {
let now = Instant::now();
let mut last_metrics = String::new();
let text = regex::escape(text).replace("<any>", ".+");
let re = Regex::new(&format!("(?m)^{}", text)).expect("Invalid regex");
while now.elapsed() < duration.unwrap_or_else(|| Duration::from_secs(15)) {
if let Ok(metrics) = self
.get_metrics_response()
.await
.expect("failed to fetch metrics")
.text()
.await
{
if re.is_match(&metrics) {
return;
}
last_metrics = metrics;
}
tokio::time::sleep(Duration::from_millis(10)).await;
}
panic!("'{text}' not detected in metrics\n{last_metrics}");
}
/// Poll the metrics endpoint until every string in `texts` has been seen
/// (default timeout 15s), panicking with the still-missing entries and the
/// last scrape otherwise. Strings found in earlier scrapes stay satisfied.
#[allow(dead_code)]
pub async fn assert_metrics_contains_multiple(
    &self,
    mut texts: Vec<&str>,
    duration: Option<Duration>,
) {
    let now = Instant::now();
    let mut last_metrics = String::new();
    while now.elapsed() < duration.unwrap_or_else(|| Duration::from_secs(15)) {
        if let Ok(metrics) = self
            .get_metrics_response()
            .await
            .expect("failed to fetch metrics")
            .text()
            .await
        {
            // Keep only the strings still missing from this scrape.
            let mut missing = vec![];
            for text in &texts {
                if !metrics.contains(text) {
                    missing.push(*text);
                }
            }
            // Fix: success is when *nothing* is missing. The previous check
            // (`missing.len() == texts.len()`) returned success precisely
            // when NONE of the strings had matched.
            if missing.is_empty() {
                return;
            } else {
                texts = missing;
            }
            last_metrics = metrics;
        }
        tokio::time::sleep(Duration::from_millis(10)).await;
    }
    panic!("'{texts:?}' not detected in metrics\n{last_metrics}");
}
/// Panic if a single scrape of the metrics endpoint contains `text`.
#[allow(dead_code)]
pub async fn assert_metrics_does_not_contain(&self, text: &str) {
    let response = self
        .get_metrics_response()
        .await
        .expect("failed to fetch metrics");
    if let Ok(metrics) = response.text().await {
        if metrics.contains(text) {
            panic!("'{text}' detected in metrics\n{metrics}");
        }
    }
}
/// Wait up to 3 seconds for the router process to exit, then clear the
/// handle. Panics (after dumping stack traces) if it is still running,
/// which usually indicates a hang.
#[allow(dead_code)]
pub async fn assert_shutdown(&mut self) {
let router = self.router.as_mut().expect("router must have been started");
let now = Instant::now();
while now.elapsed() < Duration::from_secs(3) {
match router.try_wait() {
Ok(Some(_)) => {
self.router = None;
return;
}
Ok(None) => tokio::time::sleep(Duration::from_millis(10)).await,
// NOTE(review): the Err arm retries without sleeping, so repeated
// try_wait errors busy-spin until the 3s deadline — confirm intended.
_ => {}
}
}
self.dump_stack_traces();
panic!("unable to shutdown router, this probably means a hang and should be investigated");
}
/// Send SIGHUP to the router process, asking it to reload configuration.
#[allow(dead_code)]
#[cfg(target_family = "unix")]
pub async fn send_sighup(&mut self) {
// SAFETY: pid() is the id of the child process this harness spawned.
unsafe {
libc::kill(self.pid(), libc::SIGHUP);
}
}
#[cfg(target_os = "linux")]
pub fn dump_stack_traces(&mut self) {
if let Ok(trace) = rstack::TraceOptions::new()
.symbols(true)
.thread_names(true)
.trace(self.pid() as u32)
{
println!("dumped stack traces");
for thread in trace.threads() {
println!(
"thread id: {}, name: {}",
thread.id(),
thread.name().unwrap_or("<unknown>")
);
for frame in thread.frames() {
println!(
" {}",
frame.symbol().map(|s| s.name()).unwrap_or("<unknown>")
);
}
}
} else {
println!("failed to dump stack trace");
}
}
#[cfg(not(target_os = "linux"))]
pub fn dump_stack_traces(&mut self) {}
/// Flush both tracer providers (subgraph first, then client) so pending
/// spans are exported before assertions; flush failures are only logged.
#[allow(dead_code)]
pub(crate) fn force_flush(&self) {
    let client_provider = self._tracer_provider_client.clone();
    let subgraph_provider = self._tracer_provider_subgraph.clone();
    for result in subgraph_provider.force_flush() {
        if let Err(e) = result {
            eprintln!("failed to flush subgraph tracer: {e}");
        }
    }
    for result in client_provider.force_flush() {
        if let Err(e) = result {
            eprintln!("failed to flush client tracer: {e}");
        }
    }
}
/// Delete every redis key in this test's namespace from the local redis
/// instance at 127.0.0.1:6379.
#[allow(dead_code)]
pub async fn clear_redis_cache(&self) {
let config = RedisConfig::from_url("redis://127.0.0.1:6379").unwrap();
let client = RedisClient::new(config, None, None, None);
let connection_task = client.connect();
client
.wait_for_connect()
.await
.expect("could not connect to redis");
let namespace = &self.redis_namespace;
// Scan only keys under our namespace prefix.
let mut scan = client.scan(format!("{namespace}:*"), None, Some(ScanType::String));
while let Some(result) = scan.next().await {
if let Some(page) = result.expect("could not scan redis").take_results() {
for key in page {
let key = key.as_str().expect("key should be a string");
// Defensive double-check of the prefix before deleting.
if key.starts_with(&self.redis_namespace) {
client
.del::<usize, _>(key)
.await
.expect("could not delete key");
}
}
}
}
client.quit().await.expect("could not quit redis");
let _ = connection_task.await;
}
/// Assert that `key` exists under this test's Redis namespace and return its
/// value.
///
/// On a miss, prints every other key present in the namespace (except the
/// optional `ignore` key) to help diagnose cache-key drift, then panics.
#[allow(dead_code)]
pub async fn assert_redis_cache_contains(&self, key: &str, ignore: Option<&str>) -> String {
    let config = RedisConfig::from_url("redis://127.0.0.1:6379").unwrap();
    let client = RedisClient::new(config, None, None, None);
    let connection_task = client.connect();
    client.wait_for_connect().await.unwrap();

    let redis_namespace = &self.redis_namespace;
    let namespaced_key = format!("{redis_namespace}:{key}");
    let s = match client.get(&namespaced_key).await {
        Ok(s) => s,
        Err(e) => {
            println!("non-ignored keys in the same namespace in Redis:");
            // Hoist the prefix out of the loop and strip it only when it is a
            // true prefix; the previous `str::replace` re-allocated the
            // pattern per key and would remove the pattern from anywhere
            // inside the key, not just the front.
            let prefix = format!("{redis_namespace}:");
            let mut scan = client.scan(
                format!("{redis_namespace}:*"),
                Some(u32::MAX),
                Some(ScanType::String),
            );
            while let Some(result) = scan.next().await {
                let keys = result.as_ref().unwrap().results().as_ref().unwrap();
                for key in keys {
                    let key = key.as_str().expect("key should be a string");
                    let unnamespaced_key = key.strip_prefix(&prefix).unwrap_or(key);
                    if Some(unnamespaced_key) != ignore {
                        println!("\t{unnamespaced_key}");
                    }
                }
            }
            panic!(
                "key {key} not found: {e}\n This may be caused by a number of things including federation version changes"
            );
        }
    };
    client.quit().await.unwrap();
    let _ = connection_task.await;
    s
}
}
impl Drop for IntegrationTest {
    fn drop(&mut self) {
        // Best-effort cleanup: ask the OS to kill the router child if one is
        // still tracked. Errors are ignored — the process may already be gone.
        if let Some(router) = self.router.as_mut() {
            let _ = router.start_kill();
        }
    }
}
/// Build the final router YAML by applying test-specific overrides to `yaml`.
///
/// - Replaces `{{name}}` port placeholders using `port_replacements`.
/// - Merges `subgraph_overrides` into `override_subgraph_url`.
/// - Forces `supergraph.listen`, the Prometheus exporter listen address (only
///   when one is configured) and `health_check.listen` to `bind_addr`
///   (defaulting to `127.0.0.1:0`, i.e. an ephemeral local port).
/// - Sets the query-plan Redis cache `namespace` to `redis_namespace` when a
///   Redis cache is configured.
fn merge_overrides(
    yaml: &str,
    subgraph_overrides: &HashMap<String, String>,
    bind_addr: Option<SocketAddr>,
    redis_namespace: &str,
    port_replacements: Option<&HashMap<String, u16>>,
) -> String {
    let bind_addr = bind_addr
        .map(|a| a.to_string())
        .unwrap_or_else(|| "127.0.0.1:0".into());

    // Substitute `{{name}}` placeholders with concrete port numbers before
    // parsing the YAML. Replacing the bare `{{name}}` form already covers the
    // `:{{name}}` and `127.0.0.1:{{name}}` occurrences, so the two extra
    // patterns the previous version built afterwards could never match and
    // were dead code.
    let mut yaml_with_ports = yaml.to_string();
    if let Some(port_replacements) = port_replacements {
        for (placeholder, port) in port_replacements {
            let placeholder_pattern = format!("{{{{{placeholder}}}}}");
            yaml_with_ports = yaml_with_ports.replace(&placeholder_pattern, &port.to_string());
        }
    }

    let mut config: Value = serde_yaml::from_str(&yaml_with_ports).unwrap();

    // Merge the subgraph URL overrides, creating the section if missing.
    let overrides = subgraph_overrides
        .iter()
        .map(|(name, url)| (name.clone(), serde_json::Value::String(url.clone())));
    match config
        .as_object_mut()
        .and_then(|o| o.get_mut("override_subgraph_url"))
        .and_then(|o| o.as_object_mut())
    {
        None => {
            if let Some(o) = config.as_object_mut() {
                o.insert("override_subgraph_url".to_string(), overrides.collect());
            }
        }
        Some(override_url) => {
            override_url.extend(overrides);
        }
    }

    // Force the supergraph listen address, creating the section if missing.
    match config
        .as_object_mut()
        .and_then(|o| o.get_mut("supergraph"))
        .and_then(|o| o.as_object_mut())
    {
        None => {
            if let Some(o) = config.as_object_mut() {
                o.insert(
                    "supergraph".to_string(),
                    serde_json::json!({
                        "listen": bind_addr.to_string(),
                    }),
                );
            }
        }
        Some(supergraph_conf) => {
            supergraph_conf.insert(
                "listen".to_string(),
                serde_json::Value::String(bind_addr.to_string()),
            );
        }
    }

    // Override the Prometheus exporter listen address only when the config
    // already defines one — do not create the telemetry section from scratch.
    if let Some(prom_config) = config
        .as_object_mut()
        .and_then(|o| o.get_mut("telemetry"))
        .and_then(|o| o.as_object_mut())
        .and_then(|o| o.get_mut("exporters"))
        .and_then(|o| o.as_object_mut())
        .and_then(|o| o.get_mut("metrics"))
        .and_then(|o| o.as_object_mut())
        .and_then(|o| o.get_mut("prometheus"))
        .and_then(|o| o.as_object_mut())
    {
        prom_config.insert(
            "listen".to_string(),
            serde_json::Value::String(bind_addr.to_string()),
        );
    }

    // Always (re)write the health-check listen address.
    config
        .as_object_mut()
        .expect("config should be an object")
        .insert(
            "health_check".to_string(),
            json!({"listen": bind_addr.to_string()}),
        );

    // Namespace the query-plan Redis cache only when one is configured, so
    // concurrent test runs do not collide on shared keys.
    if let Some(query_plan) = config
        .as_object_mut()
        .and_then(|o| o.get_mut("supergraph"))
        .and_then(|o| o.as_object_mut())
        .and_then(|o| o.get_mut("query_planning"))
        .and_then(|o| o.as_object_mut())
        .and_then(|o| o.get_mut("cache"))
        .and_then(|o| o.as_object_mut())
        .and_then(|o| o.get_mut("redis"))
        .and_then(|o| o.as_object_mut())
    {
        query_plan.insert("namespace".to_string(), redis_namespace.into());
    }

    serde_yaml::to_string(&config).unwrap()
}
#[allow(dead_code)]
pub fn graph_os_enabled() -> bool {
    // GraphOS-gated tests require both credentials to be present in the
    // environment; either one missing disables them.
    std::env::var("TEST_APOLLO_KEY").is_ok() && std::env::var("TEST_APOLLO_GRAPH_REF").is_ok()
}
#[ctor::ctor]
fn init_integration_test_tracing() {
    // Honour RUST_LOG when it parses; otherwise default to info-level logging
    // with debug output for the router crate itself.
    let filter = tracing_subscriber::EnvFilter::try_from_default_env()
        .or_else(|_| tracing_subscriber::EnvFilter::try_new("info,apollo_router=debug"))
        .unwrap();

    // Compact, single-line output without target/thread noise.
    let fmt_layer = tracing_subscriber::fmt::Layer::default()
        .with_target(false)
        .with_thread_ids(false)
        .with_thread_names(false)
        .compact()
        .with_filter(filter);

    // `try_init` errors if a global subscriber is already installed (e.g. by
    // another test binary constructor) — that is fine, so ignore the result.
    let _ = tracing_subscriber::registry().with(fmt_layer).try_init();
}