apollo_router/
router_factory.rs

1use std::collections::HashMap;
2use std::collections::HashSet;
3use std::io;
4use std::sync::Arc;
5
6use apollo_compiler::validation::Valid;
7use axum::response::IntoResponse;
8use http::StatusCode;
9use indexmap::IndexMap;
10use multimap::MultiMap;
11use rustls::RootCertStore;
12use rustls::pki_types::CertificateDer;
13use serde_json::Map;
14use serde_json::Value;
15use tower::BoxError;
16use tower::ServiceExt;
17use tower::service_fn;
18use tower_service::Service;
19use tracing::Instrument;
20
21use crate::AllowedFeature;
22use crate::ListenAddr;
23use crate::configuration::APOLLO_PLUGIN_PREFIX;
24use crate::configuration::Configuration;
25use crate::configuration::ConfigurationError;
26use crate::configuration::TlsClient;
27use crate::plugin::DynPlugin;
28use crate::plugin::Handler;
29use crate::plugin::PluginFactory;
30use crate::plugin::PluginInit;
31use crate::plugins::subscription::notification::Notify;
32use crate::plugins::telemetry::reload::otel::apollo_opentelemetry_initialized;
33use crate::plugins::traffic_shaping::APOLLO_TRAFFIC_SHAPING;
34use crate::plugins::traffic_shaping::TrafficShaping;
35use crate::query_planner::QueryPlannerService;
36use crate::services::HasSchema;
37use crate::services::PluggableSupergraphServiceBuilder;
38use crate::services::Plugins;
39use crate::services::SubgraphService;
40use crate::services::SupergraphCreator;
41use crate::services::apollo_graph_reference;
42use crate::services::apollo_key;
43use crate::services::http::HttpClientServiceFactory;
44use crate::services::layers::persisted_queries::PersistedQueryLayer;
45use crate::services::layers::query_analysis::QueryAnalysisLayer;
46use crate::services::new_service::ServiceFactory;
47use crate::services::router;
48use crate::services::router::pipeline_handle::PipelineRef;
49use crate::services::router::service::RouterCreator;
50use crate::spec::Schema;
51use crate::uplink::license_enforcement::LicenseState;
52
/// Name of the tracing span under which `YamlRouterFactory::create` runs router creation.
pub(crate) const STARTING_SPAN_NAME: &str = "starting";
54
55#[derive(Clone)]
56/// A path and a handler to be exposed as a web_endpoint for plugins
57pub struct Endpoint {
58    pub(crate) path: String,
59    // Plugins need to be Send + Sync
60    // BoxCloneService isn't enough
61    handler: EndpointHandler,
62}
63
64#[derive(Clone)]
65enum EndpointHandler {
66    /// Legacy handler wrapping a router service
67    Service(Handler),
68    /// Direct axum router (bypasses service conversion)
69    Router(axum::Router),
70}
71
72impl std::fmt::Debug for Endpoint {
73    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74        f.debug_struct("Endpoint")
75            .field("path", &self.path)
76            .finish()
77    }
78}
79
80impl Endpoint {
81    /// Creates an Endpoint given a path and a Boxed Service
82    pub fn from_router_service(path: String, handler: router::BoxService) -> Self {
83        Self {
84            path,
85            handler: EndpointHandler::Service(Handler::new(handler)),
86        }
87    }
88
89    /// Creates an Endpoint given a path and an axum Router
90    ///
91    /// This is the preferred method for plugins that use axum internally,
92    /// as it avoids unnecessary service wrapping and path manipulation.
93    ///
94    /// The router will be automatically nested at the specified path, allowing
95    /// it to handle all sub-routes. For example, a router registered at `/diagnostics`
96    /// will handle `/diagnostics/`, `/diagnostics/memory/status`, etc.
97    ///
98    /// # Example
99    ///
100    /// ```rust,ignore
101    /// use axum::{Router, routing::get};
102    ///
103    /// let router = Router::new()
104    ///     .route("/", get(handle_dashboard))
105    ///     .route("/status", get(handle_status));
106    ///
107    /// let endpoint = Endpoint::from_router("/diagnostics".to_string(), router);
108    /// // This will handle:
109    /// // - /diagnostics/
110    /// // - /diagnostics/status
111    /// ```
112    pub(crate) fn from_router(path: String, router: axum::Router) -> Self {
113        Self {
114            path,
115            handler: EndpointHandler::Router(router),
116        }
117    }
118
119    pub(crate) fn into_router(self) -> axum::Router {
120        match self.handler {
121            // If we already have a router, just nest it at the path
122            EndpointHandler::Router(router) => axum::Router::new().nest(&self.path, router),
123            // Legacy service handling with path-based routing
124            EndpointHandler::Service(handler) => {
125                let handler_clone = handler.clone();
126                let handler = move |req: http::Request<axum::body::Body>| {
127                    let endpoint = handler_clone.clone();
128                    async move {
129                        Ok(endpoint
130                            .oneshot(req.into())
131                            .await
132                            .map(|res| res.response)
133                            .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))
134                            .into_response())
135                    }
136                };
137
138                axum::Router::new().route_service(self.path.as_str(), service_fn(handler))
139            }
140        }
141    }
142}
143/// Factory for creating a RouterService
144///
145/// Instances of this traits are used by the HTTP server to generate a new
146/// RouterService on each request
147pub(crate) trait RouterFactory:
148    ServiceFactory<router::Request, Service = Self::RouterService> + Clone + Send + Sync + 'static
149{
150    type RouterService: Service<
151            router::Request,
152            Response = router::Response,
153            Error = BoxError,
154            Future = Self::Future,
155        > + Send;
156    type Future: Send;
157
158    fn web_endpoints(&self) -> MultiMap<ListenAddr, Endpoint>;
159
160    fn pipeline_ref(&self) -> Arc<PipelineRef>;
161}
162
163/// Factory for creating a RouterFactory
164///
165/// Instances of this traits are used by the StateMachine to generate a new
166/// RouterFactory from configuration when it changes
167#[async_trait::async_trait]
168pub(crate) trait RouterSuperServiceFactory: Send + Sync + 'static {
169    type RouterFactory: RouterFactory;
170
171    async fn create<'a>(
172        &'a mut self,
173        is_telemetry_disabled: bool,
174        configuration: Arc<Configuration>,
175        schema: Arc<Schema>,
176        previous_router: Option<&'a Self::RouterFactory>,
177        extra_plugins: Option<Vec<(String, Box<dyn DynPlugin>)>>,
178        license: Arc<LicenseState>,
179    ) -> Result<Self::RouterFactory, BoxError>;
180}
181
/// Main [`RouterSuperServiceFactory`] implementation, supporting the extensions
/// (plugin) system; it builds a [`RouterCreator`].
#[derive(Default)]
pub(crate) struct YamlRouterFactory;
185
186#[async_trait::async_trait]
187impl RouterSuperServiceFactory for YamlRouterFactory {
188    type RouterFactory = RouterCreator;
189
190    async fn create<'a>(
191        &'a mut self,
192        _is_telemetry_disabled: bool,
193        configuration: Arc<Configuration>,
194        schema: Arc<Schema>,
195        previous_router: Option<&'a Self::RouterFactory>,
196        extra_plugins: Option<Vec<(String, Box<dyn DynPlugin>)>>,
197        license: Arc<LicenseState>,
198    ) -> Result<Self::RouterFactory, BoxError> {
199        // we have to create a telemetry plugin before creating everything else, to generate a trace
200        // of router and plugin creation
201        let plugin_registry = &*crate::plugin::PLUGINS;
202        let mut initial_telemetry_plugin = None;
203
204        if previous_router.is_none()
205            && apollo_opentelemetry_initialized()
206            && let Some(factory) = plugin_registry
207                .iter()
208                .find(|factory| factory.name == "apollo.telemetry")
209        {
210            let mut telemetry_config = configuration
211                .apollo_plugins
212                .plugins
213                .get("telemetry")
214                .cloned();
215            if let Some(plugin_config) = &mut telemetry_config {
216                inject_schema_id(schema.schema_id.as_str(), plugin_config);
217                // Extract previous telemetry config for hot reload comparison
218                let previous_telemetry_config = previous_router.and_then(|router| {
219                    router
220                        .configuration
221                        .apollo_plugins
222                        .plugins
223                        .get("telemetry")
224                        .cloned()
225                });
226
227                let telemetry_init = PluginInit::builder()
228                    .config(plugin_config.clone())
229                    .and_previous_config(previous_telemetry_config)
230                    .supergraph_sdl(schema.raw_sdl.clone())
231                    .supergraph_schema_id(schema.schema_id.clone().into_inner())
232                    .supergraph_schema(Arc::new(schema.supergraph_schema().clone()))
233                    .notify(configuration.notify.clone())
234                    .license(license.clone())
235                    .full_config(configuration.validated_yaml.clone())
236                    .and_original_config_yaml(configuration.raw_yaml.clone())
237                    .build();
238
239                match factory.create_instance(telemetry_init).await {
240                    Ok(plugin) => {
241                        if let Some(telemetry) = plugin
242                            .as_any()
243                            .downcast_ref::<crate::plugins::telemetry::Telemetry>()
244                        {
245                            telemetry.activate();
246                        }
247                        initial_telemetry_plugin = Some(plugin);
248                    }
249                    Err(e) => return Err(e),
250                }
251            }
252        }
253
254        let router_span = tracing::info_span!(STARTING_SPAN_NAME);
255        Self.inner_create(
256            configuration,
257            schema,
258            previous_router,
259            initial_telemetry_plugin,
260            extra_plugins,
261            license,
262        )
263        .instrument(router_span)
264        .await
265    }
266}
267
268impl YamlRouterFactory {
269    async fn inner_create<'a>(
270        &'a mut self,
271        configuration: Arc<Configuration>,
272        schema: Arc<Schema>,
273        previous_router: Option<&'a RouterCreator>,
274        initial_telemetry_plugin: Option<Box<dyn DynPlugin>>,
275        extra_plugins: Option<Vec<(String, Box<dyn DynPlugin>)>>,
276        license: Arc<LicenseState>,
277    ) -> Result<RouterCreator, BoxError> {
278        let mut supergraph_creator = self
279            .inner_create_supergraph(
280                configuration.clone(),
281                schema,
282                initial_telemetry_plugin,
283                extra_plugins,
284                license,
285                previous_router,
286            )
287            .await?;
288
289        // Instantiate the parser here so we can use it to warm up the planner below
290        let query_analysis_layer =
291            QueryAnalysisLayer::new(supergraph_creator.schema(), Arc::clone(&configuration)).await;
292
293        let persisted_query_layer = Arc::new(PersistedQueryLayer::new(&configuration).await?);
294
295        if let Some(previous_router) = previous_router {
296            let previous_cache = previous_router.previous_cache();
297
298            supergraph_creator
299                .warm_up_query_planner(
300                    &query_analysis_layer,
301                    &persisted_query_layer,
302                    Some(previous_cache),
303                    configuration.supergraph.query_planning.warmed_up_queries,
304                    configuration
305                        .supergraph
306                        .query_planning
307                        .experimental_reuse_query_plans,
308                    &configuration
309                        .persisted_queries
310                        .experimental_prewarm_query_plan_cache,
311                )
312                .await;
313        } else {
314            supergraph_creator
315                .warm_up_query_planner(
316                    &query_analysis_layer,
317                    &persisted_query_layer,
318                    None,
319                    configuration.supergraph.query_planning.warmed_up_queries,
320                    configuration
321                        .supergraph
322                        .query_planning
323                        .experimental_reuse_query_plans,
324                    &configuration
325                        .persisted_queries
326                        .experimental_prewarm_query_plan_cache,
327                )
328                .await;
329        };
330        RouterCreator::new(
331            query_analysis_layer,
332            persisted_query_layer,
333            Arc::new(supergraph_creator),
334            configuration,
335        )
336        .await
337    }
338
339    pub(crate) async fn inner_create_supergraph(
340        &mut self,
341        configuration: Arc<Configuration>,
342        schema: Arc<Schema>,
343        initial_telemetry_plugin: Option<Box<dyn DynPlugin>>,
344        extra_plugins: Option<Vec<(String, Box<dyn DynPlugin>)>>,
345        license: Arc<LicenseState>,
346        previous_router: Option<&crate::services::router::service::RouterCreator>,
347    ) -> Result<SupergraphCreator, BoxError> {
348        let query_planner_span = tracing::info_span!("query_planner_creation");
349        // QueryPlannerService takes an UnplannedRequest and outputs PlannedRequest
350        let planner = QueryPlannerService::new(schema.clone(), configuration.clone())
351            .instrument(query_planner_span)
352            .await?;
353
354        let span = tracing::info_span!("plugins");
355
356        // Process the plugins.
357        let subgraph_schemas = Arc::new(
358            planner
359                .subgraph_schemas()
360                .iter()
361                .map(|(k, v)| (k.clone(), v.schema.clone()))
362                .collect(),
363        );
364
365        let plugins: Arc<Plugins> = Arc::new(
366            create_plugins(
367                &configuration,
368                &schema,
369                subgraph_schemas,
370                initial_telemetry_plugin,
371                extra_plugins,
372                license,
373                previous_router,
374            )
375            .instrument(span)
376            .await?
377            .into_iter()
378            .collect(),
379        );
380
381        async {
382            let mut builder = PluggableSupergraphServiceBuilder::new(planner);
383            builder = builder.with_configuration(configuration.clone());
384            let http_service_factory =
385                create_http_services(&plugins, &schema, &configuration).await?;
386            let subgraph_services =
387                create_subgraph_services(&http_service_factory, &configuration).await?;
388            builder = builder.with_http_service_factory(http_service_factory);
389            for (name, subgraph_service) in subgraph_services {
390                builder = builder.with_subgraph_service(&name, subgraph_service);
391            }
392
393            // Final creation after this line we must NOT fail to go live with the new router from this point as some plugins may interact with globals.
394            let supergraph_creator = builder.with_plugins(plugins).build().await?;
395
396            Ok(supergraph_creator)
397        }
398        .instrument(tracing::info_span!("supergraph_creation"))
399        .await
400    }
401}
402
403pub(crate) async fn create_subgraph_services(
404    http_service_factory: &IndexMap<String, HttpClientServiceFactory>,
405    configuration: &Configuration,
406) -> Result<IndexMap<String, SubgraphService>, BoxError> {
407    let mut subgraph_services = IndexMap::default();
408    for (name, http_service_factory) in http_service_factory.iter() {
409        let subgraph_service = SubgraphService::from_config(
410            name.clone(),
411            configuration,
412            http_service_factory.clone(),
413        )?;
414        subgraph_services.insert(name.clone(), subgraph_service);
415    }
416
417    Ok(subgraph_services)
418}
419
420pub(crate) async fn create_http_services(
421    plugins: &Arc<Plugins>,
422    schema: &Schema,
423    configuration: &Configuration,
424) -> Result<IndexMap<String, HttpClientServiceFactory>, BoxError> {
425    // Note we are grabbing these root stores once and then reusing it for each subgraph. Why?
426    // When TLS was not configured for subgraphs, the OS provided list of certificates was parsed once per subgraph, which resulted in long loading times on OSX.
427    // This generates the native root store once, and reuses it across subgraphs
428    let subgraph_tls_root_store: RootCertStore = configuration
429        .tls
430        .subgraph
431        .all
432        .create_certificate_store()
433        .transpose()?
434        .unwrap_or_else(crate::services::http::HttpClientService::native_roots_store);
435    let connector_tls_root_store: RootCertStore = configuration
436        .tls
437        .connector
438        .all
439        .create_certificate_store()
440        .transpose()?
441        .unwrap_or_else(crate::services::http::HttpClientService::native_roots_store);
442
443    let shaping = plugins
444        .iter()
445        .find(|i| i.0.as_str() == APOLLO_TRAFFIC_SHAPING)
446        .and_then(|plugin| (*plugin.1).as_any().downcast_ref::<TrafficShaping>())
447        .expect("traffic shaping should always be part of the plugin list");
448
449    let connector_subgraphs: HashSet<String> = schema
450        .connectors
451        .as_ref()
452        .map(|c| {
453            c.by_service_name
454                .iter()
455                .map(|(_, connector)| connector.id.subgraph_name.clone())
456                .collect()
457        })
458        .unwrap_or_default();
459
460    let mut http_services = IndexMap::new();
461    for (name, _) in schema.subgraphs() {
462        if connector_subgraphs.contains(name) {
463            continue; // Avoid adding services for subgraphs that are actually connectors since we'll separately add them below per source
464        }
465        let http_service = crate::services::http::HttpClientService::from_config_for_subgraph(
466            name,
467            configuration,
468            &subgraph_tls_root_store,
469            shaping.subgraph_client_config(name),
470        )?;
471
472        let http_service_factory = HttpClientServiceFactory::new(http_service, plugins.clone());
473        http_services.insert(name.clone(), http_service_factory);
474    }
475
476    // Also create client service factories for connector sources
477    let connector_sources = schema
478        .connectors
479        .as_ref()
480        .map(|c| c.source_config_keys.clone())
481        .unwrap_or_default();
482
483    for name in connector_sources.iter() {
484        let http_service = crate::services::http::HttpClientService::from_config_for_connector(
485            name,
486            configuration,
487            &connector_tls_root_store,
488            shaping.connector_client_config(name),
489        )?;
490
491        let http_service_factory = HttpClientServiceFactory::new(http_service, plugins.clone());
492        http_services.insert(name.clone(), http_service_factory);
493    }
494
495    Ok(http_services)
496}
497
498impl TlsClient {
499    pub(crate) fn create_certificate_store(
500        &self,
501    ) -> Option<Result<RootCertStore, ConfigurationError>> {
502        self.certificate_authorities
503            .as_deref()
504            .map(create_certificate_store)
505    }
506}
507
508pub(crate) fn create_certificate_store(
509    certificate_authorities: &str,
510) -> Result<RootCertStore, ConfigurationError> {
511    let mut store = RootCertStore::empty();
512    let certificates = load_certs(certificate_authorities).map_err(|e| {
513        ConfigurationError::CertificateAuthorities {
514            error: format!("could not parse the certificate list: {e}"),
515        }
516    })?;
517    for certificate in certificates {
518        store
519            .add(certificate)
520            .map_err(|e| ConfigurationError::CertificateAuthorities {
521                error: format!("could not add certificate to root store: {e}"),
522            })?;
523    }
524    if store.is_empty() {
525        Err(ConfigurationError::CertificateAuthorities {
526            error: "the certificate list is empty".to_string(),
527        })
528    } else {
529        Ok(store)
530    }
531}
532
533fn load_certs(certificates: &str) -> io::Result<Vec<CertificateDer<'static>>> {
534    tracing::debug!("loading root certificates");
535
536    // Load and return certificate.
537    rustls_pemfile::certs(&mut certificates.as_bytes())
538        .collect::<Result<Vec<_>, _>>()
539        // XXX(@goto-bus-stop): the error type here is already io::Error. Should we wrap it,
540        // instead of replacing it with this generic error message?
541        .map_err(|_| io::Error::other("failed to load certificate"))
542}
543
544/// test only helper method to create a router factory in integration tests
545///
546/// not meant to be used directly
547pub async fn create_test_service_factory_from_yaml(schema: &str, configuration: &str) {
548    let config: Configuration = serde_yaml::from_str(configuration).unwrap();
549    let schema = Arc::new(Schema::parse(schema, &config).unwrap());
550
551    let is_telemetry_disabled = false;
552    let service = YamlRouterFactory
553        .create(
554            is_telemetry_disabled,
555            Arc::new(config),
556            schema,
557            None,
558            None,
559            Default::default(),
560        )
561        .await;
562    assert_eq!(
563        service.map(|_| ()).unwrap_err().to_string().as_str(),
564        r#"failed to initialize the query planner: An internal error has occurred, please report this bug to Apollo.
565
566Details: Object field "Product.reviews"'s inner type "Review" does not refer to an existing output type."#
567    );
568}
569
570#[allow(clippy::too_many_arguments)]
571pub(crate) async fn add_plugin(
572    name: String,
573    factory: &PluginFactory,
574    plugin_config: &Value,
575    previous_plugin_config: Option<&Value>,
576    schema: Arc<String>,
577    schema_id: Arc<String>,
578    supergraph_schema: Arc<Valid<apollo_compiler::Schema>>,
579    subgraph_schemas: Arc<HashMap<String, Arc<Valid<apollo_compiler::Schema>>>>,
580    launch_id: Option<Arc<String>>,
581    notify: &Notify<String, crate::graphql::Response>,
582    plugin_instances: &mut Plugins,
583    errors: &mut Vec<ConfigurationError>,
584    license: Arc<LicenseState>,
585    full_config: Option<Value>,
586    original_config_yaml: Option<Arc<str>>,
587) {
588    let plugin_init = PluginInit::builder()
589        .config(plugin_config.clone())
590        .and_previous_config(previous_plugin_config.cloned())
591        .supergraph_sdl(schema)
592        .supergraph_schema_id(schema_id)
593        .supergraph_schema(supergraph_schema)
594        .subgraph_schemas(subgraph_schemas)
595        .launch_id(launch_id)
596        .notify(notify.clone())
597        .license(license)
598        .and_full_config(full_config)
599        .and_original_config_yaml(original_config_yaml)
600        .build();
601
602    match factory.create_instance(plugin_init).await {
603        Ok(plugin) => {
604            let _ = plugin_instances.insert(name, plugin);
605        }
606        Err(err) => errors.push(ConfigurationError::PluginConfiguration {
607            plugin: name,
608            error: err.to_string(),
609        }),
610    }
611}
612
613pub(crate) async fn create_plugins(
614    configuration: &Configuration,
615    schema: &Schema,
616    subgraph_schemas: Arc<HashMap<String, Arc<Valid<apollo_compiler::Schema>>>>,
617    initial_telemetry_plugin: Option<Box<dyn DynPlugin>>,
618    extra_plugins: Option<Vec<(String, Box<dyn DynPlugin>)>>,
619    license: Arc<LicenseState>,
620    previous_router: Option<&crate::services::router::service::RouterCreator>,
621) -> Result<Plugins, BoxError> {
622    let supergraph_schema = Arc::new(schema.supergraph_schema().clone());
623    let supergraph_schema_id = schema.schema_id.clone().into_inner();
624    let mut apollo_plugins_config = configuration.apollo_plugins.clone().plugins;
625    let user_plugins_config = configuration.plugins.clone().plugins.unwrap_or_default();
626
627    // Extract previous plugin configurations for hot reload previous config detection
628    let (previous_apollo_plugins_config, previous_user_plugins_config) = match previous_router {
629        Some(router) => {
630            // Extract apollo plugin configs from the previous router's stored configuration
631            let prev_apollo_configs: HashMap<&str, &Value> = router
632                .configuration
633                .apollo_plugins
634                .plugins
635                .iter()
636                .map(|(k, v)| (k.as_str(), v))
637                .collect();
638
639            // Extract user plugin configs from the previous router's stored configuration
640            let prev_user_configs: HashMap<String, &Value> = router
641                .configuration
642                .plugins
643                .plugins
644                .as_ref()
645                .map(|plugins| plugins.iter().map(|(k, v)| (k.clone(), v)).collect())
646                .unwrap_or_default();
647
648            (prev_apollo_configs, prev_user_configs)
649        }
650        None => (HashMap::new(), HashMap::new()),
651    };
652    let extra = extra_plugins.unwrap_or_default();
653    let plugin_registry = &*crate::plugin::PLUGINS;
654    let apollo_telemetry_plugin_mandatory = apollo_opentelemetry_initialized();
655    let mut apollo_plugin_factories: HashMap<&str, &PluginFactory> = plugin_registry
656        .iter()
657        .filter(|factory| {
658            // the name starts with apollo
659            factory.name.starts_with(APOLLO_PLUGIN_PREFIX)
660                && (
661                    // the plugin is mandatory
662                    apollo_telemetry_plugin_mandatory ||
663                    // the name isn't apollo.telemetry
664                    factory.name != "apollo.telemetry"
665                )
666        })
667        .map(|factory| (factory.name.as_str(), &**factory))
668        .collect();
669    let mut errors = Vec::new();
670    let mut plugin_instances = Plugins::default();
671
672    // Use function-like macros to avoid borrow conflicts of captures
673    macro_rules! add_plugin {
674        ($name: expr, $factory: expr, $plugin_config: expr, $maybe_full_config: expr, $previous_plugin_config: expr) => {{
675            add_plugin(
676                $name,
677                $factory,
678                &$plugin_config,
679                $previous_plugin_config,
680                schema.as_string().clone(),
681                supergraph_schema_id.clone(),
682                supergraph_schema.clone(),
683                subgraph_schemas.clone(),
684                schema.launch_id.clone(),
685                &configuration.notify.clone(),
686                &mut plugin_instances,
687                &mut errors,
688                license.clone(),
689                $maybe_full_config,
690                configuration.raw_yaml.clone(),
691            )
692            .await;
693        }};
694    }
695
696    macro_rules! add_mandatory_apollo_plugin_inner {
697        ($name: literal, $opt_plugin_config: expr) => {{
698            let name = concat!("apollo.", $name);
699            let span = tracing::info_span!(concat!("plugin: ", "apollo.", $name));
700            async {
701                let factory = apollo_plugin_factories
702                    .remove(name)
703                    .unwrap_or_else(|| panic!("Apollo plugin not registered: {name}"));
704                if let Some(mut plugin_config) = $opt_plugin_config {
705                    let mut full_config = None;
706                    if name == "apollo.telemetry" {
707                        // The apollo.telemetry" plugin isn't happy with empty config, so we
708                        // give it some. If any of the other mandatory plugins need special
709                        // treatment, then we'll have to perform it here
710                        inject_schema_id(&supergraph_schema_id, &mut plugin_config);
711
712                        // Only the telemetry plugin should have access to the full configuration
713                        full_config = configuration.validated_yaml.clone();
714                    }
715                    let previous_config = previous_apollo_plugins_config.get($name).copied();
716                    add_plugin!(
717                        name.to_string(),
718                        factory,
719                        plugin_config,
720                        full_config,
721                        previous_config
722                    );
723                }
724            }
725            .instrument(span)
726            .await;
727        }};
728    }
729
730    macro_rules! add_optional_apollo_plugin_inner {
731        ($name: literal, $opt_plugin_config: expr, $license: expr) => {{
732            let name = concat!("apollo.", $name);
733            let span = tracing::info_span!(concat!("plugin: ", "apollo.", $name));
734            async {
735                let factory = apollo_plugin_factories
736                    .remove(name)
737                    .unwrap_or_else(|| panic!("Apollo plugin not registered: {name}"));
738                if let Some(plugin_config) = $opt_plugin_config {
739                    let allowed_features = $license.get_allowed_features();
740
741                    match AllowedFeature::from_plugin_name($name) {
742                        Some(allowed_feature) => {
743                            if allowed_features.contains(&allowed_feature) {
744                                let previous_config = previous_apollo_plugins_config.get($name).copied();
745                                add_plugin!(name.to_string(), factory, plugin_config, None, previous_config);
746                            } else {
747                                tracing::warn!(
748                                    "{name} plugin is not registered, {name} is a restricted feature that requires a license"
749                                );
750                            }
751                        }
752                        None => {
753                            // If the plugin name did not map to an allowed feature we add it
754                            let previous_config = previous_apollo_plugins_config.get($name).copied();
755                            add_plugin!(name.to_string(), factory, plugin_config, None, previous_config);
756                        }
757                    }
758                }
759            }
760            .instrument(span)
761            .await;
762        }};
763    }
764
765    macro_rules! add_oss_apollo_plugin_inner {
766        ($name: literal, $opt_plugin_config: expr) => {{
767            let name = concat!("apollo.", $name);
768            let span = tracing::info_span!(concat!("plugin: ", "apollo.", $name));
769            async {
770                let factory = apollo_plugin_factories
771                    .remove(name)
772                    .unwrap_or_else(|| panic!("Apollo plugin not registered: {name}"));
773                if let Some(plugin_config) = $opt_plugin_config {
774                    // We add oss plugins without a license check
775                    let previous_config = previous_apollo_plugins_config.get($name).copied();
776                    add_plugin!(
777                        name.to_string(),
778                        factory,
779                        plugin_config,
780                        None,
781                        previous_config
782                    );
783                    return;
784                }
785            }
786            .instrument(span)
787            .await;
788        }};
789    }
790
791    macro_rules! add_mandatory_apollo_plugin {
792        ($name: literal) => {
793            add_mandatory_apollo_plugin_inner!(
794                $name,
795                Some(
796                    apollo_plugins_config
797                        .remove($name)
798                        .unwrap_or(Value::Object(Map::new()))
799                )
800            );
801        };
802    }
803
804    macro_rules! add_optional_apollo_plugin {
805        ($name: literal) => {
806            add_optional_apollo_plugin_inner!($name, apollo_plugins_config.remove($name), &license);
807        };
808    }
809
810    macro_rules! add_oss_apollo_plugin {
811        ($name: literal) => {
812            add_oss_apollo_plugin_inner!($name, apollo_plugins_config.remove($name));
813        };
814    }
815
816    macro_rules! add_user_plugins {
817        () => {
818            for (name, plugin_config) in user_plugins_config {
819                let user_span = tracing::info_span!("user_plugin", "name" = &name);
820
821                async {
822                    if let Some(factory) =
823                        plugin_registry.iter().find(|factory| factory.name == name)
824                    {
825                        let previous_config = previous_user_plugins_config.get(&name).copied();
826                        add_plugin!(name, factory, plugin_config, None, previous_config);
827                    } else {
828                        errors.push(ConfigurationError::PluginUnknown(name))
829                    }
830                }
831                .instrument(user_span)
832                .await;
833            }
834
835            plugin_instances.extend(extra);
836        };
837    }
838
839    // Be careful with this list! Moving things around can have subtle consequences.
840    // Requests flow through this list multiple times in two directions. First, they go "down"
841    // through the list several times as requests at the different services. Then, they go
842    // "up" through the list as a response several times, once for each service.
843    //
844    // The order of this list determines the relative order of plugin hooks executing at each
845    // service. This is *not* the same as the order a request flows through the router.
846    // For example, assume these three plugins:
847    // 1. header propagation (has a hook at the subgraph service)
848    // 2. telemetry (has hooks at router, supergraph, and subgraph services)
849    // 3. rate limiting (has a hook at the router service)
850    // The order here means that header propagation happens before telemetry *at the subgraph
851    // service*. Depending on the requirements of plugins, it may have to be in this order. The
852    // *router service* hook for telemetry still happens well before header propagation. Similarly,
853    // header propagation being first does not mean that it's exempt from rate limiting, for the
854    // same reason. Rate limiting must be after telemetry, though, because telemetry and rate
855    // limiting both work at the router service, and requests rejected from the router service must
856    // flow through telemetry so we can record errors.
857    //
858    // Broadly, for telemetry to work, we must make sure that the telemetry plugin is the first
859    // plugin in this list *that adds a router service hook*. Other plugins can be before the
860    // telemetry plugin if they must do work *before* telemetry at specific services.
861    add_mandatory_apollo_plugin!("include_subgraph_errors");
862    add_mandatory_apollo_plugin!("headers");
863    if apollo_telemetry_plugin_mandatory {
864        match initial_telemetry_plugin {
865            None => {
866                add_mandatory_apollo_plugin!("telemetry");
867            }
868            Some(plugin) => {
869                let _ = plugin_instances.insert("apollo.telemetry".to_string(), plugin);
870                apollo_plugins_config.remove("apollo.telemetry");
871                apollo_plugin_factories.remove("apollo.telemetry");
872            }
873        }
874    }
875    add_mandatory_apollo_plugin!("license_enforcement");
876    add_mandatory_apollo_plugin!("health_check");
877    add_mandatory_apollo_plugin!("traffic_shaping");
878    add_mandatory_apollo_plugin!("limits");
879    add_mandatory_apollo_plugin!("csrf");
880    add_mandatory_apollo_plugin!("fleet_detector");
881    add_mandatory_apollo_plugin!("enhanced_client_awareness");
882    add_mandatory_apollo_plugin!("experimental_diagnostics");
883
884    add_oss_apollo_plugin!("forbid_mutations");
885    add_optional_apollo_plugin!("subscription");
886    add_oss_apollo_plugin!("override_subgraph_url");
887    add_optional_apollo_plugin!("authorization");
888    add_optional_apollo_plugin!("authentication");
889    add_oss_apollo_plugin!("preview_file_uploads");
890    add_optional_apollo_plugin!("preview_response_cache");
891    add_optional_apollo_plugin!("preview_entity_cache");
892    add_mandatory_apollo_plugin!("progressive_override");
893    add_optional_apollo_plugin!("demand_control");
894
895    // This relative ordering is documented in `docs/source/customizations/native.mdx`:
896    add_oss_apollo_plugin!("connectors");
897    add_oss_apollo_plugin!("rhai");
898    add_optional_apollo_plugin!("coprocessor");
899    add_user_plugins!();
900
901    // Because this plugin intercepts subgraph requests
902    // and does not forward them to the next service in the chain,
    // it needs to intervene after user plugins for user plugins to run at all.
904    add_optional_apollo_plugin!("experimental_mock_subgraphs");
905
906    // Macros above remove from `apollo_plugin_factories`, so anything left at the end
907    // indicates a missing macro call.
908    let unused_apollo_plugin_names = apollo_plugin_factories.keys().copied().collect::<Vec<_>>();
909    if !unused_apollo_plugin_names.is_empty() {
910        panic!(
911            "Apollo plugins without their ordering specified in `fn create_plugins`: {}",
912            unused_apollo_plugin_names.join(", ")
913        )
914    }
915
916    let plugin_details = plugin_instances
917        .iter()
918        .map(|(name, plugin)| (name, plugin.name()))
919        .collect::<Vec<(&String, &str)>>();
920    tracing::debug!(
921        "plugins list: {:?}",
922        plugin_details
923            .iter()
924            .map(|(name, _)| name)
925            .collect::<Vec<&&String>>()
926    );
927
928    if !errors.is_empty() {
929        for error in &errors {
930            tracing::error!("{:#}", error);
931        }
932
933        let errors_list = errors
934            .iter()
935            .map(ToString::to_string)
936            .collect::<Vec<String>>()
937            .join("\n");
938
939        Err(BoxError::from(format!(
940            "there were {} configuration errors\n{}",
941            errors.len(),
942            errors_list
943        )))
944    } else {
945        Ok(plugin_instances)
946    }
947}
948
949fn inject_schema_id(
950    // Ideally we'd use &SchemaHash, but we'll need to update a bunch of tests to do so
951    schema_id: &str,
952    configuration: &mut Value,
953) {
954    if configuration.get("apollo").is_none() {
955        // Warning: this must be done here, otherwise studio reporting will not work
956        if apollo_key().is_some() && apollo_graph_reference().is_some() {
957            if let Some(telemetry) = configuration.as_object_mut() {
958                telemetry.insert("apollo".to_string(), Value::Object(Default::default()));
959            }
960        } else {
961            return;
962        }
963    }
964    if let Some(apollo) = configuration.get_mut("apollo")
965        && let Some(apollo) = apollo.as_object_mut()
966    {
967        apollo.insert(
968            "schema_id".to_string(),
969            Value::String(schema_id.to_string()),
970        );
971    }
972}
973
974#[cfg(test)]
975mod test {
976    use std::collections::HashSet;
977    use std::sync::Arc;
978
979    use rstest::rstest;
980    use schemars::JsonSchema;
981    use serde::Deserialize;
982    use serde_json::json;
983    use tower_http::BoxError;
984
985    use crate::AllowedFeature;
986    use crate::configuration::Configuration;
987    use crate::plugin::Plugin;
988    use crate::plugin::PluginInit;
989    use crate::register_plugin;
990    use crate::router_factory::RouterSuperServiceFactory;
991    use crate::router_factory::YamlRouterFactory;
992    use crate::router_factory::inject_schema_id;
993    use crate::services::supergraph::service::HasPlugins;
994    use crate::spec::Schema;
995    use crate::uplink::license_enforcement::LicenseLimits;
996    use crate::uplink::license_enforcement::LicenseState;
997
    /// Full (`apollo.`-prefixed) names of plugins that the tests in this module expect to be
    /// present in the built router regardless of the license's allowed features.
    // NOTE(review): `create_plugins` also registers `experimental_diagnostics` (and, conditionally,
    // `telemetry`) through `add_mandatory_apollo_plugin!`; they are not listed here — confirm
    // whether that omission is intentional.
    const MANDATORY_PLUGINS: &[&str] = &[
        "apollo.include_subgraph_errors",
        "apollo.headers",
        "apollo.license_enforcement",
        "apollo.health_check",
        "apollo.traffic_shaping",
        "apollo.limits",
        "apollo.csrf",
        "apollo.fleet_detector",
        "apollo.enhanced_client_awareness",
        "apollo.progressive_override",
    ];
1010
    /// Full (`apollo.`-prefixed) names of the OSS plugins that `test_oss_plugins_added`
    /// configures and then expects to find in the built router.
    const OSS_PLUGINS: &[&str] = &[
        "apollo.forbid_mutations",
        "apollo.override_subgraph_url",
        "apollo.connectors",
    ];
1016
1017    // Always starts and stops plugin
1018
1019    #[derive(Debug)]
1020    struct AlwaysStartsAndStopsPlugin {}
1021
1022    /// Configuration for the test plugin
1023    #[derive(Debug, Default, Deserialize, JsonSchema)]
1024    struct Conf {
1025        /// The name of the test
1026        name: String,
1027    }
1028
1029    #[async_trait::async_trait]
1030    impl Plugin for AlwaysStartsAndStopsPlugin {
1031        type Config = Conf;
1032
1033        async fn new(init: PluginInit<Self::Config>) -> Result<Self, BoxError> {
1034            tracing::debug!("{}", init.config.name);
1035            Ok(AlwaysStartsAndStopsPlugin {})
1036        }
1037    }
1038
1039    register_plugin!(
1040        "test",
1041        "always_starts_and_stops",
1042        AlwaysStartsAndStopsPlugin
1043    );
1044
1045    // Always fails to start plugin
1046
1047    #[derive(Debug)]
1048    struct AlwaysFailsToStartPlugin {}
1049
1050    #[async_trait::async_trait]
1051    impl Plugin for AlwaysFailsToStartPlugin {
1052        type Config = Conf;
1053
1054        async fn new(init: PluginInit<Self::Config>) -> Result<Self, BoxError> {
1055            tracing::debug!("{}", init.config.name);
1056            Err(BoxError::from("Error"))
1057        }
1058    }
1059
1060    register_plugin!("test", "always_fails_to_start", AlwaysFailsToStartPlugin);
1061
1062    async fn create_service(config: Configuration) -> Result<(), BoxError> {
1063        let schema = include_str!("testdata/supergraph.graphql");
1064        let schema = Schema::parse(schema, &config)?;
1065
1066        let is_telemetry_disabled = false;
1067        let service = YamlRouterFactory
1068            .create(
1069                is_telemetry_disabled,
1070                Arc::new(config),
1071                Arc::new(schema),
1072                None,
1073                None,
1074                Arc::new(LicenseState::default()),
1075            )
1076            .await;
1077        service.map(|_| ())
1078    }
1079
1080    #[tokio::test]
1081    async fn test_yaml_no_extras() {
1082        let config = Configuration::builder().build().unwrap();
1083        let service = create_service(config).await;
1084        assert!(service.is_ok())
1085    }
1086
1087    #[tokio::test]
1088    async fn test_yaml_plugins_always_starts_and_stops() {
1089        let config: Configuration = serde_yaml::from_str(
1090            r#"
1091            plugins:
1092                test.always_starts_and_stops:
1093                    name: albert
1094        "#,
1095        )
1096        .unwrap();
1097        let service = create_service(config).await;
1098        assert!(service.is_ok())
1099    }
1100
1101    #[tokio::test]
1102    async fn test_yaml_plugins_always_fails_to_start() {
1103        let config: Configuration = serde_yaml::from_str(
1104            r#"
1105            plugins:
1106                test.always_fails_to_start:
1107                    name: albert
1108        "#,
1109        )
1110        .unwrap();
1111        let service = create_service(config).await;
1112        assert!(service.is_err())
1113    }
1114
1115    #[tokio::test]
1116    async fn test_yaml_plugins_combo_start_and_fail() {
1117        let config: Configuration = serde_yaml::from_str(
1118            r#"
1119            plugins:
1120                test.always_starts_and_stops:
1121                    name: albert
1122                test.always_fails_to_start:
1123                    name: albert
1124        "#,
1125        )
1126        .unwrap();
1127        let service = create_service(config).await;
1128        assert!(service.is_err())
1129    }
1130
1131    #[test]
1132    fn test_inject_schema_id() {
1133        let mut config = json!({ "apollo": {} });
1134        inject_schema_id(
1135            "8e2021d131b23684671c3b85f82dfca836908c6a541bbd5c3772c66e7f8429d8",
1136            &mut config,
1137        );
1138        let config =
1139            serde_json::from_value::<crate::plugins::telemetry::config::Conf>(config).unwrap();
1140        assert_eq!(
1141            &config.apollo.schema_id,
1142            "8e2021d131b23684671c3b85f82dfca836908c6a541bbd5c3772c66e7f8429d8"
1143        );
1144    }
1145
1146    fn get_plugin_config(plugin: &str) -> &str {
1147        match plugin {
1148            "subscription" => {
1149                r#"
1150                enabled: true
1151                "#
1152            }
1153            "authentication" => {
1154                r#"
1155                connector:
1156                  sources: {}
1157                "#
1158            }
1159            "authorization" => {
1160                r#"
1161                require_authentication: false
1162                "#
1163            }
1164            "preview_file_uploads" => {
1165                r#"
1166                enabled: true
1167                protocols:
1168                  multipart:
1169                    enabled: false
1170                "#
1171            }
1172            "preview_entity_cache" => {
1173                r#"
1174                enabled: true
1175                subgraph:
1176                  all:
1177                    enabled: true
1178                "#
1179            }
1180            "preview_response_cache" => {
1181                r#"
1182                enabled: true
1183                subgraph:
1184                  all:
1185                    enabled: true
1186                "#
1187            }
1188            "demand_control" => {
1189                r#"
1190                enabled: true
1191                mode: measure
1192                strategy:
1193                  static_estimated:
1194                    list_size: 0
1195                    max: 0.0
1196                "#
1197            }
1198            "coprocessor" => {
1199                r#"
1200                url: http://service.example.com/url
1201                "#
1202            }
1203            "connectors" => {
1204                r#"
1205                debug_extensions: false
1206                "#
1207            }
1208            "experimental_mock_subgraphs" => {
1209                r#"
1210               subgraphs: {}
1211                "#
1212            }
1213            "forbid_mutations" => {
1214                r#"
1215                false
1216                "#
1217            }
1218            "override_subgraph_url" => {
1219                r#"
1220                {}
1221                "#
1222            }
1223            _ => panic!("This function does not contain config for plugin: {plugin}"),
1224        }
1225    }
1226
1227    #[tokio::test]
1228    #[rstest]
1229    #[case::empty_allowed_features_set(HashSet::new())]
1230    #[case::nonempty_allowed_features_set(HashSet::from_iter(vec![AllowedFeature::Coprocessors]))]
1231    async fn test_mandatory_plugins_added(#[case] allowed_features: HashSet<AllowedFeature>) {
1232        /*
1233         * GIVEN
1234         *  - a valid license
1235         *  - a valid config
1236         *  - a valid schema
1237         * */
1238        let license = LicenseState::Licensed {
1239            limits: Some(LicenseLimits {
1240                tps: None,
1241                allowed_features,
1242            }),
1243        };
1244
1245        let router_config = Configuration::builder().build().unwrap();
1246        let schema = include_str!("testdata/supergraph.graphql");
1247        let schema = Schema::parse(schema, &router_config).unwrap();
1248
1249        /*
1250         * WHEN
1251         *  - the router factory runs (including the plugin inits gated by the license)
1252         * */
1253        let is_telemetry_disabled = false;
1254        let service = YamlRouterFactory
1255            .create(
1256                is_telemetry_disabled,
1257                Arc::new(router_config),
1258                Arc::new(schema),
1259                None,
1260                None,
1261                Arc::new(license),
1262            )
1263            .await
1264            .unwrap();
1265
1266        /*
1267         * THEN
1268         *  - the mandatory plugins are added
1269         * */
1270        assert!(
1271            MANDATORY_PLUGINS
1272                .iter()
1273                .all(|plugin| { service.supergraph_creator.plugins().contains_key(*plugin) })
1274        );
1275    }
1276
1277    #[tokio::test]
1278    #[rstest]
1279    #[case::allowed_features_empty(HashSet::new())]
1280    #[case::allowed_features_nonempty(HashSet::from_iter(vec![
1281        AllowedFeature::Coprocessors,
1282        AllowedFeature::DemandControl
1283    ]))]
1284    async fn test_oss_plugins_added(#[case] allowed_features: HashSet<AllowedFeature>) {
1285        /*
1286         * GIVEN
1287         *  - a valid license
1288         *  - a valid config that contains configuration for oss plugins
1289         *  - a valid schema
1290         * */
1291        let license = LicenseState::Licensed {
1292            limits: Some(LicenseLimits {
1293                tps: None,
1294                allowed_features,
1295            }),
1296        };
1297
1298        // Create config for oss plugins
1299        let forbid_mutations_config =
1300            serde_yaml::from_str::<serde_json::Value>(get_plugin_config("forbid_mutations"))
1301                .unwrap();
1302        let override_subgraph_url_config =
1303            serde_yaml::from_str::<serde_json::Value>(get_plugin_config("override_subgraph_url"))
1304                .unwrap();
1305        let connectors_config =
1306            serde_yaml::from_str::<serde_json::Value>(get_plugin_config("connectors")).unwrap();
1307
1308        let router_config = Configuration::builder()
1309            .apollo_plugin("forbid_mutations", forbid_mutations_config)
1310            .apollo_plugin("override_subgraph_url", override_subgraph_url_config)
1311            .apollo_plugin("connectors", connectors_config)
1312            .build()
1313            .unwrap();
1314
1315        let schema = include_str!("testdata/supergraph.graphql");
1316        let schema = Schema::parse(schema, &router_config).unwrap();
1317
1318        /*
1319         * WHEN
1320         *  - the router factory runs (including the plugin inits gated by the license)
1321         * */
1322        let is_telemetry_disabled = false;
1323        let service = YamlRouterFactory
1324            .create(
1325                is_telemetry_disabled,
1326                Arc::new(router_config),
1327                Arc::new(schema),
1328                None,
1329                None,
1330                Arc::new(license),
1331            )
1332            .await
1333            .unwrap();
1334
1335        /*
1336         * THEN
1337         *  - all oss plugins should have been added
1338         * */
1339        assert!(
1340            OSS_PLUGINS
1341                .iter()
1342                .all(|plugin| { service.supergraph_creator.plugins().contains_key(*plugin) })
1343        );
1344    }
1345
1346    #[tokio::test]
1347    #[rstest]
1348    #[case::subscripions(
1349        "subscription",
1350        HashSet::from_iter(vec![AllowedFeature::DemandControl, AllowedFeature::Subscriptions]))
1351    ]
1352    #[case::authorization(
1353        "authorization",
1354        HashSet::from_iter(vec![AllowedFeature::Authorization, AllowedFeature::Subscriptions]))
1355    ]
1356    #[case::authentication(
1357        "authentication",
1358        HashSet::from_iter(vec![AllowedFeature::DemandControl, AllowedFeature::Authentication, AllowedFeature::Subscriptions]))
1359    ]
1360    #[case::entity_caching(
1361        "preview_entity_cache",
1362        HashSet::from_iter(vec![AllowedFeature::EntityCaching, AllowedFeature::DemandControl]))
1363    ]
1364    #[case::response_cache(
1365        "preview_response_cache",
1366        HashSet::from_iter(vec![AllowedFeature::DemandControl, AllowedFeature::ResponseCaching]))
1367    ]
1368    #[case::authorization(
1369        "demand_control",
1370        HashSet::from_iter(vec![AllowedFeature::Authorization, AllowedFeature::Subscriptions, AllowedFeature::DemandControl]))
1371    ]
1372    #[case::coprocessor(
1373        "coprocessor",
1374        HashSet::from_iter(vec![AllowedFeature::Coprocessors, AllowedFeature::DemandControl]))
1375    ]
1376    async fn test_optional_plugin_added_with_restricted_allowed_features(
1377        #[case] plugin: &str,
1378        #[case] allowed_features: HashSet<AllowedFeature>,
1379    ) {
1380        /*
1381         * GIVEN
1382         *  - a restricted license with allowed feature set containing the given `plugin`
1383         *  - a valid config including valid config for the given `plugin`
1384         *  - a valid schema
1385         * */
1386        let license = LicenseState::Licensed {
1387            limits: Some(LicenseLimits {
1388                tps: None,
1389                allowed_features,
1390            }),
1391        };
1392
1393        let plugin_config =
1394            serde_yaml::from_str::<serde_json::Value>(get_plugin_config(plugin)).unwrap();
1395        dbg!(&plugin_config);
1396        let router_config = Configuration::builder()
1397            .apollo_plugin(plugin, plugin_config)
1398            .build()
1399            .unwrap();
1400
1401        let schema = include_str!("testdata/supergraph.graphql");
1402        let schema = Schema::parse(schema, &router_config).unwrap();
1403
1404        /*
1405         * WHEN
1406         *  - the router factory runs (including the plugin inits gated by the license)
1407         * */
1408        let is_telemetry_disabled = false;
1409        let service = YamlRouterFactory
1410            .create(
1411                is_telemetry_disabled,
1412                Arc::new(router_config),
1413                Arc::new(schema),
1414                None,
1415                None,
1416                Arc::new(license),
1417            )
1418            .await
1419            .unwrap();
1420
1421        /*
1422         * THEN
1423         *  - since the plugin is part of the `allowed_features` set
1424         *    the plugin should have been added.
1425         * - mandatory plugins should have been added.
1426         * */
1427        assert!(
1428            service
1429                .supergraph_creator
1430                .plugins()
1431                .contains_key(&format!("apollo.{plugin}")),
1432            "Plugin {plugin} should have been added"
1433        );
1434        assert!(
1435            MANDATORY_PLUGINS
1436                .iter()
1437                .all(|plugin| { service.supergraph_creator.plugins().contains_key(*plugin) })
1438        );
1439    }
1440
1441    #[tokio::test]
1442    #[rstest]
1443    #[case::subscripions(
1444        "subscription",
1445        HashSet::from_iter(vec![]))
1446    ]
1447    #[case::authorization(
1448        "authorization",
1449        HashSet::from_iter(vec![AllowedFeature::Authentication, AllowedFeature::Subscriptions]))
1450    ]
1451    #[case::authentication(
1452        "authentication",
1453        HashSet::from_iter(vec![AllowedFeature::DemandControl,AllowedFeature::Subscriptions]))
1454    ]
1455    #[case::entity_caching(
1456        "preview_entity_cache",
1457        HashSet::from_iter(vec![AllowedFeature::DemandControl]))
1458    ]
1459    #[case::response_cache(
1460        "preview_response_cache",
1461        HashSet::from_iter(vec![AllowedFeature::EntityCaching]))
1462    ]
1463    #[case::authorization(
1464        "demand_control",
1465        HashSet::from_iter(vec![AllowedFeature::Authorization, AllowedFeature::Subscriptions, AllowedFeature::Experimental]))
1466    ]
1467    #[case::coprocessor(
1468        "coprocessor",
1469        HashSet::from_iter(vec![AllowedFeature::DemandControl]))
1470    ]
1471    async fn test_optional_plugin_not_added_with_restricted_allowed_features(
1472        #[case] plugin: &str,
1473        #[case] allowed_features: HashSet<AllowedFeature>,
1474    ) {
1475        /*
1476         * GIVEN
1477         *  - a restricted license whose allowed feature set does not contain the given `plugin`
1478         *  - a valid config including valid config for the given `plugin`
1479         *  - a valid schema
1480         * */
1481        let license = LicenseState::Licensed {
1482            limits: Some(LicenseLimits {
1483                tps: None,
1484                allowed_features,
1485            }),
1486        };
1487
1488        let plugin_config =
1489            serde_yaml::from_str::<serde_json::Value>(get_plugin_config(plugin)).unwrap();
1490        let router_config = Configuration::builder()
1491            .apollo_plugin(plugin, plugin_config)
1492            .build()
1493            .unwrap();
1494
1495        let schema = include_str!("testdata/supergraph.graphql");
1496        let schema = Schema::parse(schema, &router_config).unwrap();
1497
1498        /*
1499         * WHEN
1500         *  - the router factory runs (including the plugin inits gated by the license)
1501         * */
1502        let is_telemetry_disabled = false;
1503        let service = YamlRouterFactory
1504            .create(
1505                is_telemetry_disabled,
1506                Arc::new(router_config),
1507                Arc::new(schema),
1508                None,
1509                None,
1510                Arc::new(license),
1511            )
1512            .await
1513            .unwrap();
1514
1515        /*
1516         * THEN
1517         *  - since the plugin is not part of the `allowed_features` set
1518         *    the plugin should not have been added.
1519         * - mandatory plugins should have been added.
1520         * */
1521        assert!(
1522            !service
1523                .supergraph_creator
1524                .plugins()
1525                .contains_key(&format!("apollo.{plugin}")),
1526            "Plugin {plugin} should not have been added"
1527        );
1528        assert!(
1529            MANDATORY_PLUGINS
1530                .iter()
1531                .all(|plugin| { service.supergraph_creator.plugins().contains_key(*plugin) })
1532        );
1533    }
1534
1535    #[tokio::test]
1536    #[rstest]
1537    #[case::mock_subgraphs_non_empty_allowed_features(
1538        "experimental_mock_subgraphs",
1539        HashSet::from_iter(vec![AllowedFeature::DemandControl])
1540    )]
1541    #[case::mock_subgraphs_empty_allowed_features(
1542        "experimental_mock_subgraphs",
1543        HashSet::from_iter(vec![])
1544    )]
1545    async fn test_optional_plugin_that_does_not_map_to_an_allowed_feature_is_added(
1546        #[case] plugin: &str,
1547        #[case] allowed_features: HashSet<AllowedFeature>,
1548    ) {
1549        /*
1550         * GIVEN
1551         *  - a valid license
1552         *  - a valid config including valid config for the optional plugin that does
1553         *    not map to an allowed feature
1554         *  - a valid schema
1555         * */
1556        let license = LicenseState::Licensed {
1557            limits: Some(LicenseLimits {
1558                tps: None,
1559                allowed_features,
1560            }),
1561        };
1562
1563        let plugin_config =
1564            serde_yaml::from_str::<serde_json::Value>(get_plugin_config(plugin)).unwrap();
1565        let router_config = Configuration::builder()
1566            .apollo_plugin(plugin, plugin_config)
1567            .build()
1568            .unwrap();
1569
1570        let schema = include_str!("testdata/supergraph.graphql");
1571        let schema = Schema::parse(schema, &router_config).unwrap();
1572
1573        /*
1574         * WHEN
1575         *  - the router factory runs (including the plugin inits gated by the license)
1576         * */
1577        let is_telemetry_disabled = false;
1578        let service = YamlRouterFactory
1579            .create(
1580                is_telemetry_disabled,
1581                Arc::new(router_config),
1582                Arc::new(schema),
1583                None,
1584                None,
1585                Arc::new(license),
1586            )
1587            .await
1588            .unwrap();
1589
1590        /*
1591         * THEN
1592         * - the plugin should be added
1593         * - mandatory plugins should have been added.
1594         * - coprocessors and subscritions (both gated features) should not have been added.
1595         * */
1596        assert!(
1597            service
1598                .supergraph_creator
1599                .plugins()
1600                .contains_key(&format!("apollo.{plugin}")),
1601            "Plugin {plugin} should have been added"
1602        );
1603        assert!(
1604            MANDATORY_PLUGINS
1605                .iter()
1606                .all(|plugin| { service.supergraph_creator.plugins().contains_key(*plugin) })
1607        );
1608        // These gated features should not have been added
1609        assert!(
1610            !service
1611                .supergraph_creator
1612                .plugins()
1613                .contains_key("apollo.subscription"),
1614            "Plugin {plugin} should not have been added"
1615        );
1616        assert!(
1617            !service
1618                .supergraph_creator
1619                .plugins()
1620                .contains_key("apollo.coprocessor"),
1621            "Plugin {plugin} should not have been added"
1622        );
1623    }
1624
1625    #[tokio::test]
1626    #[rstest]
1627    // NB: this is temporary behavior and will change once the `allowed_features` claim is in all licenses
1628    #[case::forbid_mutations("forbid_mutations")]
1629    #[case::subscriptions("subscription")]
1630    #[case::override_subgraph_url("override_subgraph_url")]
1631    #[case::authorization("authorization")]
1632    #[case::authentication("authentication")]
1633    #[case::file_upload("preview_file_uploads")]
1634    #[case::entity_cache("preview_entity_cache")]
1635    #[case::response_cache("preview_response_cache")]
1636    #[case::demand_control("demand_control")]
1637    #[case::connectors("connectors")]
1638    #[case::coprocessor("coprocessor")]
1639    #[case::mock_subgraphs("experimental_mock_subgraphs")]
1640    async fn test_optional_plugin_with_unrestricted_allowed_features(#[case] plugin: &str) {
1641        /*
1642         * GIVEN
1643         *  - a license with unrestricted limits (includes allowing all features)
1644         *  - a valid config including valid config for the given `plugin`
1645         *  - a valid schema
1646         * */
1647        let license = LicenseState::Licensed {
1648            limits: Default::default(),
1649        };
1650
1651        let plugin_config =
1652            serde_yaml::from_str::<serde_json::Value>(get_plugin_config(plugin)).unwrap();
1653        let router_config = Configuration::builder()
1654            .apollo_plugin(plugin, plugin_config)
1655            .build()
1656            .unwrap();
1657
1658        let schema = include_str!("testdata/supergraph.graphql");
1659        let schema = Schema::parse(schema, &router_config).unwrap();
1660
1661        /*
1662         * WHEN
1663         *  - the router factory runs (including the plugin inits gated by the license)
1664         * */
1665        let is_telemetry_disabled = false;
1666        let service = YamlRouterFactory
1667            .create(
1668                is_telemetry_disabled,
1669                Arc::new(router_config),
1670                Arc::new(schema),
1671                None,
1672                None,
1673                Arc::new(license),
1674            )
1675            .await
1676            .unwrap();
1677
1678        /*
1679         * THEN
1680         *  - since `allowed_features` is unrestricted plugin should have been added.
1681         * */
1682        assert!(
1683            service
1684                .supergraph_creator
1685                .plugins()
1686                .contains_key(&format!("apollo.{plugin}")),
1687            "Plugin {plugin} should have been added"
1688        );
1689        assert!(
1690            MANDATORY_PLUGINS
1691                .iter()
1692                .all(|plugin| { service.supergraph_creator.plugins().contains_key(*plugin) })
1693        );
1694    }
1695
1696    #[tokio::test]
1697    #[rstest]
1698    // NB: this is temporary behavior and will change once the `allowed_features` claim is in all licenses
1699    #[case::forbid_mutations("forbid_mutations")]
1700    #[case::subscriptions("subscription")]
1701    #[case::override_subgraph_url("override_subgraph_url")]
1702    #[case::authorization("authorization")]
1703    #[case::authentication("authentication")]
1704    #[case::file_upload("preview_file_uploads")]
1705    #[case::response_cache("preview_response_cache")]
1706    #[case::demand_control("demand_control")]
1707    #[case::connectors("connectors")]
1708    #[case::coprocessor("coprocessor")]
1709    #[case::mock_subgraphs("experimental_mock_subgraphs")]
1710    async fn test_optional_plugin_with_default_license_limits(#[case] plugin: &str) {
1711        /*
1712         * GIVEN
1713         *  - a license with license limits None
1714         *  - a valid config including valid config for the given `plugin`
1715         *  - a valid schema
1716         * */
1717        let license = LicenseState::Licensed {
1718            limits: Default::default(),
1719        };
1720
1721        // Create config for the given `plugin`
1722        let plugin_config =
1723            serde_yaml::from_str::<serde_json::Value>(get_plugin_config(plugin)).unwrap();
1724
1725        // Create config for oss plugins
1726        // Create config for oss plugins
1727        let forbid_mutations_config =
1728            serde_yaml::from_str::<serde_json::Value>(get_plugin_config("forbid_mutations"))
1729                .unwrap();
1730        let override_subgraph_url_config =
1731            serde_yaml::from_str::<serde_json::Value>(get_plugin_config("override_subgraph_url"))
1732                .unwrap();
1733        let connectors_config =
1734            serde_yaml::from_str::<serde_json::Value>(get_plugin_config("connectors")).unwrap();
1735        let response_cache_config =
1736            serde_yaml::from_str::<serde_json::Value>(get_plugin_config("preview_response_cache"))
1737                .unwrap();
1738
1739        let router_config = Configuration::builder()
1740            .apollo_plugin("forbid_mutations", forbid_mutations_config)
1741            .apollo_plugin("override_subgraph_url", override_subgraph_url_config)
1742            .apollo_plugin("connectors", connectors_config)
1743            .apollo_plugin("preview_response_cache", response_cache_config)
1744            .apollo_plugin(plugin, plugin_config)
1745            .build()
1746            .unwrap();
1747
1748        let schema = include_str!("testdata/supergraph.graphql");
1749        let schema = Schema::parse(schema, &router_config).unwrap();
1750
1751        /*
1752         * WHEN
1753         *  - the router factory runs (including the plugin inits gated by the license)
1754         * */
1755        let is_telemetry_disabled = false;
1756        let service = YamlRouterFactory
1757            .create(
1758                is_telemetry_disabled,
1759                Arc::new(router_config),
1760                Arc::new(schema),
1761                None,
1762                None,
1763                Arc::new(license),
1764            )
1765            .await
1766            .unwrap();
1767
1768        /*
1769         * THEN
1770         *  // NB: this behavior may change once all licenses have an `allowed_features` claim
1771         *  - when license limits are None we default to unrestricted allowed features
1772         *  - the given `plugin` should have been added
1773         *  - all mandatory plugins should have been added
1774         *  - all oss plugins in the config should have been added
1775         * */
1776        assert!(
1777            service
1778                .supergraph_creator
1779                .plugins()
1780                .contains_key(&format!("apollo.{plugin}")),
1781            "Plugin {plugin} should have been added"
1782        );
1783        assert!(
1784            MANDATORY_PLUGINS
1785                .iter()
1786                .all(|plugin| { service.supergraph_creator.plugins().contains_key(*plugin) })
1787        );
1788        assert!(
1789            OSS_PLUGINS
1790                .iter()
1791                .all(|plugin| { service.supergraph_creator.plugins().contains_key(*plugin) })
1792        );
1793    }
1794}