1use std::collections::HashMap;
4use std::collections::HashSet;
5use std::default::Default;
6use std::str::FromStr;
7use std::sync::Arc;
8
9use serde::de::Error as DeserializeError;
10use serde::ser::Error as SerializeError;
11use tower::BoxError;
12use tower::ServiceBuilder;
13use tower::ServiceExt;
14use tower_http::trace::MakeSpan;
15use tracing_futures::Instrument;
16
17use crate::AllowedFeature;
18use crate::axum_factory::span_mode;
19use crate::axum_factory::utils::PropagatingMakeSpan;
20use crate::configuration::Configuration;
21use crate::configuration::ConfigurationError;
22use crate::graphql;
23use crate::plugin::DynPlugin;
24use crate::plugin::Plugin;
25use crate::plugin::PluginInit;
26use crate::plugin::PluginPrivate;
27use crate::plugin::PluginUnstable;
28use crate::plugin::test::MockSubgraph;
29use crate::plugin::test::canned;
30use crate::plugins::telemetry::reload::otel::init_telemetry;
31use crate::router_factory::YamlRouterFactory;
32use crate::services::HasSchema;
33use crate::services::SupergraphCreator;
34use crate::services::execution;
35use crate::services::layers::persisted_queries::PersistedQueryLayer;
36use crate::services::layers::query_analysis::QueryAnalysisLayer;
37use crate::services::router;
38use crate::services::router::service::RouterCreator;
39use crate::services::subgraph;
40use crate::services::supergraph;
41use crate::spec::Schema;
42use crate::uplink::license_enforcement::LicenseLimits;
43use crate::uplink::license_enforcement::LicenseState;
44
45pub mod mocks;
47
48#[cfg(test)]
49pub(crate) mod http_client;
50
51#[cfg(any(test, feature = "snapshot"))]
52pub(crate) mod http_snapshot;
53
/// Builder used by tests to assemble a router, or one of its inner services.
///
/// Create one with [`TestHarness::builder`], configure it with the chained setters, then
/// consume it with one of the `build_*` methods. Settings left unset are defaulted at build
/// time.
pub struct TestHarness<'a> {
    /// Supergraph SDL text; `None` means the canned `testing_schema.graphql` is used.
    schema: Option<&'a str>,
    /// Router configuration; `None` means `Configuration`'s default is used.
    configuration: Option<Arc<Configuration>>,
    /// Caller-supplied plugins, each paired with the name it is registered under.
    extra_plugins: Vec<(String, Box<dyn DynPlugin>)>,
    /// Set by `with_subgraph_network_requests`; when `false`, an empty
    /// `experimental_mock_subgraphs` plugin entry is ensured in the configuration.
    subgraph_network_requests: bool,
    /// License override; `None` means `LicenseState::Licensed` with default limits.
    license: Option<Arc<LicenseState>>,
}
102
103impl<'a> TestHarness<'a> {
105 pub fn builder() -> Self {
107 Self {
108 schema: None,
109 configuration: None,
110 extra_plugins: Vec::new(),
111 subgraph_network_requests: false,
112 license: None,
113 }
114 }
115
116 pub fn log_level(self, log_level: &'a str) -> Self {
119 let log_level = format!("{log_level},salsa=error");
121 init_telemetry(&log_level).expect("failed to setup logging");
122 self
123 }
124
125 pub fn try_log_level(self, log_level: &'a str) -> Self {
128 let log_level = format!("{log_level},salsa=error");
130 let _ = init_telemetry(&log_level);
131 self
132 }
133
134 pub fn schema(mut self, schema: &'a str) -> Self {
142 assert!(self.schema.is_none(), "schema was specified twice");
143 self.schema = Some(schema);
144 self
145 }
146
147 pub fn configuration(mut self, configuration: Arc<Configuration>) -> Self {
149 assert!(
150 self.configuration.is_none(),
151 "configuration was specified twice"
152 );
153 self.configuration = Some(configuration);
154 self
155 }
156
157 pub fn configuration_json(
160 self,
161 configuration: serde_json::Value,
162 ) -> Result<Self, serde_json::Error> {
163 let yaml = serde_yaml::to_string(&configuration).map_err(SerializeError::custom)?;
166 let configuration: Configuration =
167 Configuration::from_str(&yaml).map_err(DeserializeError::custom)?;
168 Ok(self.configuration(Arc::new(configuration)))
169 }
170
171 pub fn configuration_yaml(self, configuration: &'a str) -> Result<Self, ConfigurationError> {
173 let configuration: Configuration = Configuration::from_str(configuration)?;
174 Ok(self.configuration(Arc::new(configuration)))
175 }
176
177 pub fn license_from_allowed_features(mut self, allowed_features: Vec<AllowedFeature>) -> Self {
183 assert!(self.license.is_none(), "license was specified twice");
184 self.license = Some(Arc::new(LicenseState::Licensed {
185 limits: {
186 Some(
187 LicenseLimits::builder()
188 .allowed_features(HashSet::from_iter(allowed_features))
189 .build(),
190 )
191 },
192 }));
193 self
194 }
195
196 pub fn extra_plugin<P: Plugin>(mut self, plugin: P) -> Self {
201 let type_id = std::any::TypeId::of::<P>();
202 let name = match crate::plugin::plugins().find(|factory| factory.type_id == type_id) {
203 Some(factory) => factory.name.clone(),
204 None => format!(
205 "extra_plugins.{}.{}",
206 self.extra_plugins.len(),
207 std::any::type_name::<P>(),
208 ),
209 };
210
211 self.extra_plugins.push((name, plugin.into()));
212 self
213 }
214
215 pub fn extra_unstable_plugin<P: PluginUnstable>(mut self, plugin: P) -> Self {
220 let type_id = std::any::TypeId::of::<P>();
221 let name = match crate::plugin::plugins().find(|factory| factory.type_id == type_id) {
222 Some(factory) => factory.name.clone(),
223 None => format!(
224 "extra_plugins.{}.{}",
225 self.extra_plugins.len(),
226 std::any::type_name::<P>(),
227 ),
228 };
229
230 self.extra_plugins.push((name, Box::new(plugin)));
231 self
232 }
233
234 #[allow(dead_code)]
239 pub(crate) fn extra_private_plugin<P: PluginPrivate>(mut self, plugin: P) -> Self {
240 let type_id = std::any::TypeId::of::<P>();
241 let name = match crate::plugin::plugins().find(|factory| factory.type_id == type_id) {
242 Some(factory) => factory.name.clone(),
243 None => format!(
244 "extra_plugins.{}.{}",
245 self.extra_plugins.len(),
246 std::any::type_name::<P>(),
247 ),
248 };
249
250 self.extra_plugins.push((name, Box::new(plugin)));
251 self
252 }
253
254 pub fn router_hook(
256 self,
257 callback: impl Fn(router::BoxService) -> router::BoxService + Send + Sync + 'static,
258 ) -> Self {
259 self.extra_plugin(RouterServicePlugin(callback))
260 }
261
262 pub fn supergraph_hook(
264 self,
265 callback: impl Fn(supergraph::BoxService) -> supergraph::BoxService + Send + Sync + 'static,
266 ) -> Self {
267 self.extra_plugin(SupergraphServicePlugin(callback))
268 }
269
270 pub fn execution_hook(
272 self,
273 callback: impl Fn(execution::BoxService) -> execution::BoxService + Send + Sync + 'static,
274 ) -> Self {
275 self.extra_plugin(ExecutionServicePlugin(callback))
276 }
277
278 pub fn subgraph_hook(
280 self,
281 callback: impl Fn(&str, subgraph::BoxService) -> subgraph::BoxService + Send + Sync + 'static,
282 ) -> Self {
283 self.extra_plugin(SubgraphServicePlugin(callback))
284 }
285
286 pub fn with_subgraph_network_requests(mut self) -> Self {
293 self.subgraph_network_requests = true;
294 self
295 }
296
297 pub(crate) async fn build_common(
298 self,
299 ) -> Result<(Arc<Configuration>, Arc<Schema>, SupergraphCreator), BoxError> {
300 let mut config = self.configuration.unwrap_or_default();
301 let has_legacy_mock_subgraphs_plugin = self.extra_plugins.iter().any(|(_, dyn_plugin)| {
302 dyn_plugin.name() == *crate::plugins::mock_subgraphs::PLUGIN_NAME
303 });
304 if self.schema.is_none() && !has_legacy_mock_subgraphs_plugin {
305 Arc::make_mut(&mut config)
306 .apollo_plugins
307 .plugins
308 .entry("experimental_mock_subgraphs")
309 .or_insert_with(canned::mock_subgraphs);
310 }
311 if !self.subgraph_network_requests {
312 Arc::make_mut(&mut config)
313 .apollo_plugins
314 .plugins
315 .entry("experimental_mock_subgraphs")
316 .or_insert(serde_json::json!({}));
317 }
318 let canned_schema = include_str!("../testing_schema.graphql");
319 let schema = self.schema.unwrap_or(canned_schema);
320 let schema = Arc::new(Schema::parse(schema, &config)?);
321 let license = self.license.unwrap_or(Arc::new(LicenseState::Licensed {
323 limits: Default::default(),
324 }));
325 let supergraph_creator = YamlRouterFactory
326 .inner_create_supergraph(
327 config.clone(),
328 schema.clone(),
329 None,
330 Some(self.extra_plugins),
331 license,
332 None,
333 )
334 .await?;
335
336 Ok((config, schema, supergraph_creator))
337 }
338
339 pub async fn build_supergraph(self) -> Result<supergraph::BoxCloneService, BoxError> {
341 let (config, schema, supergraph_creator) = self.build_common().await?;
342
343 Ok(tower::service_fn(move |request: supergraph::Request| {
344 let router = supergraph_creator.make();
345
346 let body = request.supergraph_request.body();
352 if let Some(query_str) = body.query.as_deref() {
354 let operation_name = body.operation_name.as_deref();
355 if !request.context.extensions().with_lock(|lock| {
356 lock.contains_key::<crate::services::layers::query_analysis::ParsedDocument>()
357 }) {
358 let doc = crate::spec::Query::parse_document(
359 query_str,
360 operation_name,
361 &schema,
362 &config,
363 )
364 .expect("parse error in test");
365 request.context.extensions().with_lock(|lock| {
366 lock.insert::<crate::services::layers::query_analysis::ParsedDocument>(doc)
367 });
368 }
369 }
370
371 async move { router.oneshot(request).await }
372 })
373 .boxed_clone())
374 }
375
376 pub async fn build_router(self) -> Result<router::BoxCloneService, BoxError> {
378 let (config, _schema, supergraph_creator) = self.build_common().await?;
379 let router_creator = RouterCreator::new(
380 QueryAnalysisLayer::new(supergraph_creator.schema(), Arc::clone(&config)).await,
381 Arc::new(PersistedQueryLayer::new(&config).await.unwrap()),
382 Arc::new(supergraph_creator),
383 config.clone(),
384 )
385 .await
386 .unwrap();
387
388 Ok(tower::service_fn(move |request: router::Request| {
389 let router = ServiceBuilder::new().service(router_creator.make()).boxed();
390 let span = PropagatingMakeSpan {
391 license: Default::default(),
392 span_mode: span_mode(&config),
393 }
394 .make_span(&request.router_request);
395 async move { router.oneshot(request).await }.instrument(span)
396 })
397 .boxed_clone())
398 }
399
400 pub async fn build_http_service(self) -> Result<HttpService, BoxError> {
402 use crate::axum_factory::ListenAddrAndRouter;
403 use crate::axum_factory::axum_http_server_factory::make_axum_router;
404 use crate::router_factory::RouterFactory;
405
406 let (config, _schema, supergraph_creator) = self.build_common().await?;
407 let router_creator = RouterCreator::new(
408 QueryAnalysisLayer::new(supergraph_creator.schema(), Arc::clone(&config)).await,
409 Arc::new(PersistedQueryLayer::new(&config).await.unwrap()),
410 Arc::new(supergraph_creator),
411 config.clone(),
412 )
413 .await?;
414
415 let web_endpoints = router_creator.web_endpoints();
416
417 let routers = make_axum_router(
418 router_creator,
419 &config,
420 web_endpoints,
421 Arc::new(LicenseState::Licensed {
422 limits: Default::default(),
423 }),
424 )?;
425 let ListenAddrAndRouter(_listener, router) = routers.main;
426 Ok(router.boxed())
427 }
428}
429
430pub type HttpService = tower::util::BoxService<
432 http::Request<crate::services::router::Body>,
433 http::Response<axum::body::Body>,
434 std::convert::Infallible,
435>;
436
/// Extra plugin that replaces the router service with whatever a test closure returns.
struct RouterServicePlugin<F>(F);
/// Extra plugin that replaces the supergraph service with whatever a test closure returns.
struct SupergraphServicePlugin<F>(F);
/// Extra plugin that replaces the execution service with whatever a test closure returns.
struct ExecutionServicePlugin<F>(F);
/// Extra plugin that replaces each subgraph service with whatever a test closure returns;
/// the closure also receives the subgraph name.
struct SubgraphServicePlugin<F>(F);
441
442#[async_trait::async_trait]
443impl<F> Plugin for RouterServicePlugin<F>
444where
445 F: 'static + Send + Sync + Fn(router::BoxService) -> router::BoxService,
446{
447 type Config = ();
448
449 async fn new(_: PluginInit<Self::Config>) -> Result<Self, BoxError> {
450 unreachable!()
451 }
452
453 fn router_service(&self, service: router::BoxService) -> router::BoxService {
454 (self.0)(service)
455 }
456}
457
458#[async_trait::async_trait]
459impl<F> Plugin for SupergraphServicePlugin<F>
460where
461 F: 'static + Send + Sync + Fn(supergraph::BoxService) -> supergraph::BoxService,
462{
463 type Config = ();
464
465 async fn new(_: PluginInit<Self::Config>) -> Result<Self, BoxError> {
466 unreachable!()
467 }
468
469 fn supergraph_service(&self, service: supergraph::BoxService) -> supergraph::BoxService {
470 (self.0)(service)
471 }
472}
473
474#[async_trait::async_trait]
475impl<F> Plugin for ExecutionServicePlugin<F>
476where
477 F: 'static + Send + Sync + Fn(execution::BoxService) -> execution::BoxService,
478{
479 type Config = ();
480
481 async fn new(_: PluginInit<Self::Config>) -> Result<Self, BoxError> {
482 unreachable!()
483 }
484
485 fn execution_service(&self, service: execution::BoxService) -> execution::BoxService {
486 (self.0)(service)
487 }
488}
489
490#[async_trait::async_trait]
491impl<F> Plugin for SubgraphServicePlugin<F>
492where
493 F: 'static + Send + Sync + Fn(&str, subgraph::BoxService) -> subgraph::BoxService,
494{
495 type Config = ();
496
497 async fn new(_: PluginInit<Self::Config>) -> Result<Self, BoxError> {
498 unreachable!()
499 }
500
501 fn subgraph_service(
502 &self,
503 subgraph_name: &str,
504 service: subgraph::BoxService,
505 ) -> subgraph::BoxService {
506 (self.0)(subgraph_name, service)
507 }
508}
509
/// Test plugin mapping subgraph names to [`MockSubgraph`] services.
///
/// A subgraph with an entry is served by a clone of its mock; every other subgraph keeps
/// its default service.
#[derive(Default, Clone)]
pub struct MockedSubgraphs(pub(crate) HashMap<&'static str, MockSubgraph>);
513
514impl MockedSubgraphs {
515 pub fn insert(&mut self, name: &'static str, subgraph: MockSubgraph) {
517 self.0.insert(name, subgraph);
518 }
519}
520
521#[async_trait::async_trait]
522impl Plugin for MockedSubgraphs {
523 type Config = ();
524
525 async fn new(_: PluginInit<Self::Config>) -> Result<Self, BoxError> {
526 unreachable!()
527 }
528
529 fn subgraph_service(
530 &self,
531 subgraph_name: &str,
532 default: subgraph::BoxService,
533 ) -> subgraph::BoxService {
534 self.0
535 .get(subgraph_name)
536 .map(|service| service.clone().boxed())
537 .unwrap_or(default)
538 }
539}
540
541pub fn make_fake_batch(
568 input: http::Request<graphql::Request>,
569 op_from_to: Option<(&str, &str)>,
570) -> http::Request<crate::services::router::Body> {
571 input.map(|req| {
572 let mut new_req = req.clone();
574
575 if let Some((from, to)) = op_from_to
580 && let Some(operation_name) = &req.operation_name
581 && operation_name == from
582 {
583 new_req.query = req.query.clone().map(|q| q.replace(from, to));
584 new_req.operation_name = Some(to.to_string());
585 }
586
587 let mut json_bytes_req = serde_json::to_vec(&req).unwrap();
588 let mut json_bytes_new_req = serde_json::to_vec(&new_req).unwrap();
589
590 let mut result = vec![b'['];
591 result.append(&mut json_bytes_req);
592 result.push(b',');
593 result.append(&mut json_bytes_new_req);
594 result.push(b']');
595 router::body::from_bytes(result)
596 })
597}
598
599#[tokio::test]
600async fn test_intercept_subgraph_network_requests() {
601 use futures::StreamExt;
602 let request = crate::services::supergraph::Request::canned_builder()
603 .build()
604 .unwrap();
605 let response = TestHarness::builder()
606 .schema(include_str!("../testing_schema.graphql"))
607 .configuration_json(serde_json::json!({
608 "include_subgraph_errors": {
609 "all": true
610 }
611 }))
612 .unwrap()
613 .build_router()
614 .await
615 .unwrap()
616 .oneshot(request.try_into().unwrap())
617 .await
618 .unwrap()
619 .into_graphql_response_stream()
620 .await
621 .next()
622 .await
623 .unwrap()
624 .unwrap();
625 insta::assert_json_snapshot!(response, @r###"
626 {
627 "data": {
628 "topProducts": null
629 },
630 "errors": [
631 {
632 "message": "subgraph mock not configured",
633 "path": [],
634 "extensions": {
635 "code": "SUBGRAPH_MOCK_NOT_CONFIGURED",
636 "service": "products"
637 }
638 }
639 ]
640 }
641 "###);
642}
643
#[cfg(test)]
pub(crate) mod tracing_test {
    use tracing_core::dispatcher::DefaultGuard;

    /// Installs a `tracing_test` capture subscriber as the current thread's default
    /// dispatcher.
    ///
    /// The subscriber writes into `tracing_test`'s global buffer using the filter directive
    /// `apollo_router=trace`. It stays the default for as long as the returned guard is
    /// alive.
    pub(crate) fn dispatcher_guard() -> DefaultGuard {
        let buffer = ::tracing_test::internal::global_buf();
        let writer = ::tracing_test::internal::MockWriter::new(buffer);
        let dispatch = ::tracing_test::internal::get_subscriber(writer, "apollo_router=trace");
        tracing::dispatcher::set_default(&dispatch)
    }

    /// Delegates to `tracing_test`'s internal `logs_with_scope_contain` for the given scope
    /// and value.
    pub(crate) fn logs_with_scope_contain(scope: &str, value: &str) -> bool {
        ::tracing_test::internal::logs_with_scope_contain(scope, value)
    }

    /// Same as [`logs_with_scope_contain`] with the scope fixed to `"apollo_router"`.
    pub(crate) fn logs_contain(value: &str) -> bool {
        const SCOPE: &str = "apollo_router";
        ::tracing_test::internal::logs_with_scope_contain(SCOPE, value)
    }
}
683}