pub(crate) mod cors;
pub(crate) mod expansion;
mod experimental;
pub(crate) mod metrics;
mod persisted_queries;
mod schema;
pub(crate) mod subgraph;
#[cfg(test)]
mod tests;
mod upgrade;
mod yaml;
use std::fmt;
use std::io;
use std::io::BufReader;
use std::iter;
use std::net::IpAddr;
use std::net::SocketAddr;
use std::num::NonZeroUsize;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use derivative::Derivative;
use displaydoc::Display;
use itertools::Itertools;
use once_cell::sync::Lazy;
pub(crate) use persisted_queries::PersistedQueries;
#[cfg(test)]
pub(crate) use persisted_queries::PersistedQueriesSafelist;
use regex::Regex;
use rustls::Certificate;
use rustls::PrivateKey;
use rustls::ServerConfig;
use rustls_pemfile::certs;
use rustls_pemfile::read_one;
use rustls_pemfile::Item;
use schemars::gen::SchemaGenerator;
use schemars::schema::ObjectValidation;
use schemars::schema::Schema;
use schemars::schema::SchemaObject;
use schemars::JsonSchema;
use serde::Deserialize;
use serde::Deserializer;
use serde::Serialize;
use serde_json::Map;
use serde_json::Value;
use thiserror::Error;
use self::cors::Cors;
use self::expansion::Expansion;
pub(crate) use self::experimental::Discussed;
pub(crate) use self::schema::generate_config_schema;
pub(crate) use self::schema::generate_upgrade;
use self::subgraph::SubgraphConfiguration;
use crate::cache::DEFAULT_CACHE_CAPACITY;
use crate::configuration::schema::Mode;
use crate::graphql;
use crate::notification::Notify;
#[cfg(not(test))]
use crate::notification::RouterBroadcasts;
use crate::plugin::plugins;
#[cfg(not(test))]
use crate::plugins::subscription::SubscriptionConfig;
#[cfg(not(test))]
use crate::plugins::subscription::APOLLO_SUBSCRIPTION_PLUGIN;
#[cfg(not(test))]
use crate::plugins::subscription::APOLLO_SUBSCRIPTION_PLUGIN_NAME;
use crate::uplink::UplinkConfig;
use crate::ApolloRouterError;
#[cfg(not(test))]
static HEARTBEAT_TIMEOUT_DURATION_SECONDS: u64 = 15;
static SUPERGRAPH_ENDPOINT_REGEX: Lazy<Regex> = Lazy::new(|| {
Regex::new(r"(?P<first_path>.*/)(?P<sub_path>.+)\*$")
.expect("this regex to check the path is valid")
});
#[derive(Debug, Error, Display)]
#[non_exhaustive]
pub enum ConfigurationError {
CannotExpandVariable { key: String, cause: String },
UnknownExpansionMode {
key: String,
supported_modes: String,
},
PluginUnknown(String),
PluginConfiguration { plugin: String, error: String },
InvalidConfiguration {
message: &'static str,
error: String,
},
DeserializeConfigError(serde_json::Error),
InvalidExpansionModeConfig,
MigrationFailure { error: String },
CertificateAuthorities { error: String },
}
#[derive(Clone, Derivative, Serialize, JsonSchema)]
#[derivative(Debug)]
pub struct Configuration {
#[serde(skip)]
pub(crate) validated_yaml: Option<Value>,
#[serde(default)]
pub(crate) health_check: HealthCheck,
#[serde(default)]
pub(crate) sandbox: Sandbox,
#[serde(default)]
pub(crate) homepage: Homepage,
#[serde(default)]
pub(crate) supergraph: Supergraph,
#[serde(default)]
pub(crate) cors: Cors,
#[serde(default)]
pub(crate) tls: Tls,
#[serde(default)]
pub(crate) apq: Apq,
#[serde(default)]
pub persisted_queries: PersistedQueries,
#[serde(default)]
pub(crate) limits: Limits,
#[serde(default)]
pub(crate) experimental_chaos: Chaos,
#[serde(default)]
pub(crate) experimental_graphql_validation_mode: GraphQLValidationMode,
#[serde(default)]
pub(crate) experimental_api_schema_generation_mode: ApiSchemaMode,
#[serde(default)]
pub(crate) plugins: UserPlugins,
#[serde(default)]
#[serde(flatten)]
pub(crate) apollo_plugins: ApolloPlugins,
#[serde(skip)]
pub uplink: Option<UplinkConfig>,
#[serde(default, skip_serializing, skip_deserializing)]
pub(crate) notify: Notify<String, graphql::Response>,
#[serde(default)]
pub(crate) experimental_batching: Batching,
}
impl PartialEq for Configuration {
fn eq(&self, other: &Self) -> bool {
self.validated_yaml == other.validated_yaml
}
}
#[derive(Clone, PartialEq, Eq, Default, Derivative, Serialize, Deserialize, JsonSchema)]
#[derivative(Debug)]
#[serde(rename_all = "lowercase")]
pub(crate) enum GraphQLValidationMode {
New,
Legacy,
#[default]
Both,
}
#[derive(Clone, PartialEq, Eq, Default, Derivative, Serialize, Deserialize, JsonSchema)]
#[derivative(Debug)]
#[serde(rename_all = "lowercase")]
pub(crate) enum ApiSchemaMode {
New,
#[default]
Legacy,
Both,
}
impl<'de> serde::Deserialize<'de> for Configuration {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
#[derive(Deserialize, Default)]
#[serde(default)]
struct AdHocConfiguration {
health_check: HealthCheck,
sandbox: Sandbox,
homepage: Homepage,
supergraph: Supergraph,
cors: Cors,
plugins: UserPlugins,
#[serde(flatten)]
apollo_plugins: ApolloPlugins,
tls: Tls,
apq: Apq,
persisted_queries: PersistedQueries,
#[serde(skip)]
uplink: UplinkConfig,
limits: Limits,
experimental_chaos: Chaos,
experimental_graphql_validation_mode: GraphQLValidationMode,
experimental_batching: Batching,
}
let ad_hoc: AdHocConfiguration = serde::Deserialize::deserialize(deserializer)?;
Configuration::builder()
.health_check(ad_hoc.health_check)
.sandbox(ad_hoc.sandbox)
.homepage(ad_hoc.homepage)
.supergraph(ad_hoc.supergraph)
.cors(ad_hoc.cors)
.plugins(ad_hoc.plugins.plugins.unwrap_or_default())
.apollo_plugins(ad_hoc.apollo_plugins.plugins)
.tls(ad_hoc.tls)
.apq(ad_hoc.apq)
.persisted_query(ad_hoc.persisted_queries)
.operation_limits(ad_hoc.limits)
.chaos(ad_hoc.experimental_chaos)
.uplink(ad_hoc.uplink)
.graphql_validation_mode(ad_hoc.experimental_graphql_validation_mode)
.experimental_batching(ad_hoc.experimental_batching)
.build()
.map_err(|e| serde::de::Error::custom(e.to_string()))
}
}
pub(crate) const APOLLO_PLUGIN_PREFIX: &str = "apollo.";
fn default_graphql_listen() -> ListenAddr {
SocketAddr::from_str("127.0.0.1:4000").unwrap().into()
}
#[allow(dead_code)]
fn test_listen() -> ListenAddr {
SocketAddr::from_str("127.0.0.1:0").unwrap().into()
}
#[buildstructor::buildstructor]
impl Configuration {
#[builder]
pub(crate) fn new(
supergraph: Option<Supergraph>,
health_check: Option<HealthCheck>,
sandbox: Option<Sandbox>,
homepage: Option<Homepage>,
cors: Option<Cors>,
plugins: Map<String, Value>,
apollo_plugins: Map<String, Value>,
tls: Option<Tls>,
notify: Option<Notify<String, graphql::Response>>,
apq: Option<Apq>,
persisted_query: Option<PersistedQueries>,
operation_limits: Option<Limits>,
chaos: Option<Chaos>,
uplink: Option<UplinkConfig>,
graphql_validation_mode: Option<GraphQLValidationMode>,
experimental_api_schema_generation_mode: Option<ApiSchemaMode>,
experimental_batching: Option<Batching>,
) -> Result<Self, ConfigurationError> {
#[cfg(not(test))]
let notify_queue_cap = match apollo_plugins.get(APOLLO_SUBSCRIPTION_PLUGIN_NAME) {
Some(plugin_conf) => {
let conf = serde_json::from_value::<SubscriptionConfig>(plugin_conf.clone())
.map_err(|err| ConfigurationError::PluginConfiguration {
plugin: APOLLO_SUBSCRIPTION_PLUGIN.to_string(),
error: format!("{err:?}"),
})?;
conf.queue_capacity
}
None => None,
};
let conf = Self {
validated_yaml: Default::default(),
supergraph: supergraph.unwrap_or_default(),
health_check: health_check.unwrap_or_default(),
sandbox: sandbox.unwrap_or_default(),
homepage: homepage.unwrap_or_default(),
cors: cors.unwrap_or_default(),
apq: apq.unwrap_or_default(),
persisted_queries: persisted_query.unwrap_or_default(),
limits: operation_limits.unwrap_or_default(),
experimental_chaos: chaos.unwrap_or_default(),
experimental_graphql_validation_mode: graphql_validation_mode.unwrap_or_default(),
experimental_api_schema_generation_mode: experimental_api_schema_generation_mode.unwrap_or_default(),
plugins: UserPlugins {
plugins: Some(plugins),
},
apollo_plugins: ApolloPlugins {
plugins: apollo_plugins,
},
tls: tls.unwrap_or_default(),
uplink,
experimental_batching: experimental_batching.unwrap_or_default(),
#[cfg(test)]
notify: notify.unwrap_or_default(),
#[cfg(not(test))]
notify: notify.map(|n| n.set_queue_size(notify_queue_cap))
.unwrap_or_else(|| Notify::builder().and_queue_size(notify_queue_cap).ttl(Duration::from_secs(HEARTBEAT_TIMEOUT_DURATION_SECONDS)).router_broadcasts(Arc::new(RouterBroadcasts::new())).heartbeat_error_message(graphql::Response::builder().errors(vec![graphql::Error::builder().message("the connection has been closed because it hasn't heartbeat for a while").extension_code("SUBSCRIPTION_HEARTBEAT_ERROR").build()]).build()).build()),
};
conf.validate()
}
}
impl Default for Configuration {
fn default() -> Self {
Configuration::from_str("").expect("default configuration must be valid")
}
}
#[cfg(test)]
#[buildstructor::buildstructor]
impl Configuration {
#[builder]
pub(crate) fn fake_new(
supergraph: Option<Supergraph>,
health_check: Option<HealthCheck>,
sandbox: Option<Sandbox>,
homepage: Option<Homepage>,
cors: Option<Cors>,
plugins: Map<String, Value>,
apollo_plugins: Map<String, Value>,
tls: Option<Tls>,
notify: Option<Notify<String, graphql::Response>>,
apq: Option<Apq>,
persisted_query: Option<PersistedQueries>,
operation_limits: Option<Limits>,
chaos: Option<Chaos>,
uplink: Option<UplinkConfig>,
graphql_validation_mode: Option<GraphQLValidationMode>,
experimental_batching: Option<Batching>,
experimental_api_schema_generation_mode: Option<ApiSchemaMode>,
) -> Result<Self, ConfigurationError> {
let configuration = Self {
validated_yaml: Default::default(),
supergraph: supergraph.unwrap_or_else(|| Supergraph::fake_builder().build()),
health_check: health_check.unwrap_or_else(|| HealthCheck::fake_builder().build()),
sandbox: sandbox.unwrap_or_else(|| Sandbox::fake_builder().build()),
homepage: homepage.unwrap_or_else(|| Homepage::fake_builder().build()),
cors: cors.unwrap_or_default(),
limits: operation_limits.unwrap_or_default(),
experimental_chaos: chaos.unwrap_or_default(),
experimental_graphql_validation_mode: graphql_validation_mode.unwrap_or_default(),
experimental_api_schema_generation_mode: experimental_api_schema_generation_mode
.unwrap_or_default(),
plugins: UserPlugins {
plugins: Some(plugins),
},
apollo_plugins: ApolloPlugins {
plugins: apollo_plugins,
},
tls: tls.unwrap_or_default(),
notify: notify.unwrap_or_default(),
apq: apq.unwrap_or_default(),
persisted_queries: persisted_query.unwrap_or_default(),
uplink,
experimental_batching: experimental_batching.unwrap_or_default(),
};
configuration.validate()
}
}
impl Configuration {
pub(crate) fn validate(self) -> Result<Self, ConfigurationError> {
if self.sandbox.enabled && self.homepage.enabled {
return Err(ConfigurationError::InvalidConfiguration {
message: "sandbox and homepage cannot be enabled at the same time",
error: "disable the homepage if you want to enable sandbox".to_string(),
});
}
if self.sandbox.enabled && !self.supergraph.introspection {
return Err(ConfigurationError::InvalidConfiguration {
message: "sandbox requires introspection",
error: "sandbox needs introspection to be enabled".to_string(),
});
}
if !self.supergraph.path.starts_with('/') {
return Err(ConfigurationError::InvalidConfiguration {
message: "invalid 'server.graphql_path' configuration",
error: format!(
"'{}' is invalid, it must be an absolute path and start with '/', you should try with '/{}'",
self.supergraph.path,
self.supergraph.path
),
});
}
if self.supergraph.path.ends_with('*')
&& !self.supergraph.path.ends_with("/*")
&& !SUPERGRAPH_ENDPOINT_REGEX.is_match(&self.supergraph.path)
{
return Err(ConfigurationError::InvalidConfiguration {
message: "invalid 'server.graphql_path' configuration",
error: format!(
"'{}' is invalid, you can only set a wildcard after a '/'",
self.supergraph.path
),
});
}
if self.supergraph.path.contains("/*/") {
return Err(
ConfigurationError::InvalidConfiguration {
message: "invalid 'server.graphql_path' configuration",
error: format!(
"'{}' is invalid, if you need to set a path like '/*/graphql' then specify it as a path parameter with a name, for example '/:my_project_key/graphql'",
self.supergraph.path
),
},
);
}
if self.persisted_queries.enabled {
if self.persisted_queries.safelist.enabled && self.apq.enabled {
return Err(ConfigurationError::InvalidConfiguration {
message: "apqs must be disabled to enable safelisting",
error: "either set persisted_queries.safelist.enabled: false or apq.enabled: false in your router yaml configuration".into()
});
} else if !self.persisted_queries.safelist.enabled
&& self.persisted_queries.safelist.require_id
{
return Err(ConfigurationError::InvalidConfiguration {
message: "safelist must be enabled to require IDs",
error: "either set persisted_queries.safelist.enabled: true or persisted_queries.safelist.require_id: false in your router yaml configuration".into()
});
}
} else {
if self.persisted_queries.safelist.enabled {
return Err(ConfigurationError::InvalidConfiguration {
message: "persisted queries must be enabled to enable safelisting",
error: "either set persisted_queries.safelist.enabled: false or persisted_queries.enabled: true in your router yaml configuration".into()
});
} else if self.persisted_queries.log_unknown {
return Err(ConfigurationError::InvalidConfiguration {
message: "persisted queries must be enabled to enable logging unknown operations",
error: "either set persisted_queries.log_unknown: false or persisted_queries.enabled: true in your router yaml configuration".into()
});
}
}
Ok(self)
}
}
impl FromStr for Configuration {
type Err = ConfigurationError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
schema::validate_yaml_configuration(s, Expansion::default()?, Mode::Upgrade)?.validate()
}
}
fn gen_schema(plugins: schemars::Map<String, Schema>) -> Schema {
let plugins_object = SchemaObject {
object: Some(Box::new(ObjectValidation {
properties: plugins,
additional_properties: Option::Some(Box::new(Schema::Bool(false))),
..Default::default()
})),
..Default::default()
};
Schema::Object(plugins_object)
}
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[serde(transparent)]
pub(crate) struct ApolloPlugins {
pub(crate) plugins: Map<String, Value>,
}
impl JsonSchema for ApolloPlugins {
fn schema_name() -> String {
stringify!(Plugins).to_string()
}
fn json_schema(gen: &mut SchemaGenerator) -> Schema {
let plugins = crate::plugin::plugins()
.sorted_by_key(|factory| factory.name.clone())
.filter(|factory| factory.name.starts_with(APOLLO_PLUGIN_PREFIX))
.map(|factory| {
(
factory.name[APOLLO_PLUGIN_PREFIX.len()..].to_string(),
factory.create_schema(gen),
)
})
.collect::<schemars::Map<String, Schema>>();
gen_schema(plugins)
}
}
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[serde(transparent)]
pub(crate) struct UserPlugins {
pub(crate) plugins: Option<Map<String, Value>>,
}
impl JsonSchema for UserPlugins {
fn schema_name() -> String {
stringify!(Plugins).to_string()
}
fn json_schema(gen: &mut SchemaGenerator) -> Schema {
let plugins = crate::plugin::plugins()
.sorted_by_key(|factory| factory.name.clone())
.filter(|factory| !factory.name.starts_with(APOLLO_PLUGIN_PREFIX))
.map(|factory| (factory.name.to_string(), factory.create_schema(gen)))
.collect::<schemars::Map<String, Schema>>();
gen_schema(plugins)
}
}
#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema)]
#[serde(deny_unknown_fields)]
#[serde(default)]
pub(crate) struct Supergraph {
pub(crate) listen: ListenAddr,
pub(crate) path: String,
pub(crate) introspection: bool,
#[serde(rename = "experimental_reuse_query_fragments")]
pub(crate) reuse_query_fragments: Option<bool>,
pub(crate) defer_support: bool,
pub(crate) query_planning: QueryPlanning,
}
fn default_defer_support() -> bool {
true
}
#[buildstructor::buildstructor]
impl Supergraph {
#[builder]
pub(crate) fn new(
listen: Option<ListenAddr>,
path: Option<String>,
introspection: Option<bool>,
defer_support: Option<bool>,
query_planning: Option<QueryPlanning>,
reuse_query_fragments: Option<bool>,
) -> Self {
Self {
listen: listen.unwrap_or_else(default_graphql_listen),
path: path.unwrap_or_else(default_graphql_path),
introspection: introspection.unwrap_or_else(default_graphql_introspection),
defer_support: defer_support.unwrap_or_else(default_defer_support),
query_planning: query_planning.unwrap_or_default(),
reuse_query_fragments,
}
}
}
#[cfg(test)]
#[buildstructor::buildstructor]
impl Supergraph {
#[builder]
pub(crate) fn fake_new(
listen: Option<ListenAddr>,
path: Option<String>,
introspection: Option<bool>,
defer_support: Option<bool>,
query_planning: Option<QueryPlanning>,
reuse_query_fragments: Option<bool>,
) -> Self {
Self {
listen: listen.unwrap_or_else(test_listen),
path: path.unwrap_or_else(default_graphql_path),
introspection: introspection.unwrap_or_else(default_graphql_introspection),
defer_support: defer_support.unwrap_or_else(default_defer_support),
query_planning: query_planning.unwrap_or_default(),
reuse_query_fragments,
}
}
}
impl Default for Supergraph {
fn default() -> Self {
Self::builder().build()
}
}
impl Supergraph {
pub(crate) fn sanitized_path(&self) -> String {
let mut path = self.path.clone();
if self.path.ends_with("/*") {
path = format!("{}router_extra_path", self.path);
} else if SUPERGRAPH_ENDPOINT_REGEX.is_match(&self.path) {
let new_path = SUPERGRAPH_ENDPOINT_REGEX
.replace(&self.path, "${first_path}${sub_path}:supergraph_route");
path = new_path.to_string();
}
path
}
}
#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema)]
#[serde(deny_unknown_fields, default)]
pub(crate) struct Limits {
pub(crate) max_depth: Option<u32>,
pub(crate) max_height: Option<u32>,
pub(crate) max_root_fields: Option<u32>,
pub(crate) max_aliases: Option<u32>,
pub(crate) warn_only: bool,
pub(crate) parser_max_recursion: usize,
pub(crate) parser_max_tokens: usize,
pub(crate) http_max_request_bytes: usize,
}
impl Default for Limits {
fn default() -> Self {
Self {
max_depth: None,
max_height: None,
max_root_fields: None,
max_aliases: None,
warn_only: false,
http_max_request_bytes: 2_000_000,
parser_max_tokens: 15_000,
parser_max_recursion: 500,
}
}
}
#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema, Default)]
#[serde(deny_unknown_fields)]
pub(crate) struct Router {
#[serde(default)]
pub(crate) cache: Cache,
}
#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema)]
#[serde(deny_unknown_fields, default)]
pub(crate) struct Apq {
pub(crate) enabled: bool,
pub(crate) router: Router,
pub(crate) subgraph: SubgraphConfiguration<SubgraphApq>,
}
#[cfg(test)]
#[buildstructor::buildstructor]
impl Apq {
#[builder]
pub(crate) fn fake_new(enabled: Option<bool>) -> Self {
Self {
enabled: enabled.unwrap_or_else(default_apq),
..Default::default()
}
}
}
#[derive(Debug, Clone, Default, Deserialize, Serialize, JsonSchema)]
#[serde(deny_unknown_fields, default)]
pub(crate) struct SubgraphApq {
pub(crate) enabled: bool,
}
fn default_apq() -> bool {
true
}
impl Default for Apq {
fn default() -> Self {
Self {
enabled: default_apq(),
router: Default::default(),
subgraph: Default::default(),
}
}
}
#[derive(Debug, Clone, Default, Deserialize, Serialize, JsonSchema)]
#[serde(deny_unknown_fields, default)]
pub(crate) struct QueryPlanning {
pub(crate) cache: Cache,
#[serde(default)]
pub(crate) warmed_up_queries: Option<usize>,
pub(crate) experimental_plans_limit: Option<u32>,
pub(crate) experimental_paths_limit: Option<u32>,
}
#[derive(Debug, Clone, Default, Deserialize, Serialize, JsonSchema)]
#[serde(deny_unknown_fields, default)]
pub(crate) struct Cache {
pub(crate) in_memory: InMemoryCache,
pub(crate) redis: Option<RedisCache>,
}
#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub(crate) struct InMemoryCache {
pub(crate) limit: NonZeroUsize,
}
impl Default for InMemoryCache {
fn default() -> Self {
Self {
limit: DEFAULT_CACHE_CAPACITY,
}
}
}
#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub(crate) struct RedisCache {
pub(crate) urls: Vec<url::Url>,
pub(crate) username: Option<String>,
pub(crate) password: Option<String>,
#[serde(deserialize_with = "humantime_serde::deserialize", default)]
#[schemars(with = "Option<String>", default)]
pub(crate) timeout: Option<Duration>,
#[serde(deserialize_with = "humantime_serde::deserialize", default)]
#[schemars(with = "Option<String>", default)]
pub(crate) ttl: Option<Duration>,
pub(crate) namespace: Option<String>,
#[serde(default)]
pub(crate) tls: Option<TlsClient>,
#[serde(default = "default_required_to_start")]
pub(crate) required_to_start: bool,
}
fn default_required_to_start() -> bool {
false
}
#[derive(Debug, Clone, Default, Deserialize, Serialize, JsonSchema)]
#[serde(deny_unknown_fields)]
#[serde(default)]
pub(crate) struct Tls {
pub(crate) supergraph: Option<TlsSupergraph>,
pub(crate) subgraph: SubgraphConfiguration<TlsClient>,
}
#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub(crate) struct TlsSupergraph {
#[serde(deserialize_with = "deserialize_certificate", skip_serializing)]
#[schemars(with = "String")]
pub(crate) certificate: Certificate,
#[serde(deserialize_with = "deserialize_key", skip_serializing)]
#[schemars(with = "String")]
pub(crate) key: PrivateKey,
#[serde(deserialize_with = "deserialize_certificate_chain", skip_serializing)]
#[schemars(with = "String")]
pub(crate) certificate_chain: Vec<Certificate>,
}
impl TlsSupergraph {
pub(crate) fn tls_config(&self) -> Result<Arc<rustls::ServerConfig>, ApolloRouterError> {
let mut certificates = vec![self.certificate.clone()];
certificates.extend(self.certificate_chain.iter().cloned());
let mut config = ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(certificates, self.key.clone())
.map_err(ApolloRouterError::Rustls)?;
config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
Ok(Arc::new(config))
}
}
fn deserialize_certificate<'de, D>(deserializer: D) -> Result<Certificate, D::Error>
where
D: Deserializer<'de>,
{
let data = String::deserialize(deserializer)?;
load_certs(&data)
.map_err(serde::de::Error::custom)
.and_then(|mut certs| {
if certs.len() > 1 {
Err(serde::de::Error::custom("expected exactly one certificate"))
} else {
certs
.pop()
.ok_or(serde::de::Error::custom("expected exactly one certificate"))
}
})
}
fn deserialize_certificate_chain<'de, D>(deserializer: D) -> Result<Vec<Certificate>, D::Error>
where
D: Deserializer<'de>,
{
let data = String::deserialize(deserializer)?;
load_certs(&data).map_err(serde::de::Error::custom)
}
fn deserialize_key<'de, D>(deserializer: D) -> Result<PrivateKey, D::Error>
where
D: Deserializer<'de>,
{
let data = String::deserialize(deserializer)?;
load_key(&data).map_err(serde::de::Error::custom)
}
pub(crate) fn load_certs(data: &str) -> io::Result<Vec<Certificate>> {
certs(&mut BufReader::new(data.as_bytes()))
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid cert"))
.map(|mut certs| certs.drain(..).map(Certificate).collect())
}
pub(crate) fn load_key(data: &str) -> io::Result<PrivateKey> {
let mut reader = BufReader::new(data.as_bytes());
let mut key_iterator = iter::from_fn(|| read_one(&mut reader).transpose());
let private_key = match key_iterator.next() {
Some(Ok(Item::RSAKey(key))) => PrivateKey(key),
Some(Ok(Item::PKCS8Key(key))) => PrivateKey(key),
Some(Ok(Item::ECKey(key))) => PrivateKey(key),
Some(Err(e)) => {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
format!("could not parse the key: {e}"),
))
}
Some(_) => {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"expected a private key",
))
}
None => {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"could not find a private key",
))
}
};
if key_iterator.next().is_some() {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"expected exactly one private key",
));
}
Ok(private_key)
}
#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema)]
#[serde(deny_unknown_fields)]
#[serde(default)]
pub(crate) struct TlsClient {
pub(crate) certificate_authorities: Option<String>,
pub(crate) client_authentication: Option<TlsClientAuth>,
}
#[buildstructor::buildstructor]
impl TlsClient {
#[builder]
pub(crate) fn new(
certificate_authorities: Option<String>,
client_authentication: Option<TlsClientAuth>,
) -> Self {
Self {
certificate_authorities,
client_authentication,
}
}
}
impl Default for TlsClient {
fn default() -> Self {
Self::builder().build()
}
}
#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub(crate) struct TlsClientAuth {
#[serde(deserialize_with = "deserialize_certificate_chain", skip_serializing)]
#[schemars(with = "String")]
pub(crate) certificate_chain: Vec<Certificate>,
#[serde(deserialize_with = "deserialize_key", skip_serializing)]
#[schemars(with = "String")]
pub(crate) key: PrivateKey,
}
#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema)]
#[serde(deny_unknown_fields)]
#[serde(default)]
pub(crate) struct Sandbox {
pub(crate) enabled: bool,
}
fn default_sandbox() -> bool {
false
}
#[buildstructor::buildstructor]
impl Sandbox {
#[builder]
pub(crate) fn new(enabled: Option<bool>) -> Self {
Self {
enabled: enabled.unwrap_or_else(default_sandbox),
}
}
}
#[cfg(test)]
#[buildstructor::buildstructor]
impl Sandbox {
#[builder]
pub(crate) fn fake_new(enabled: Option<bool>) -> Self {
Self {
enabled: enabled.unwrap_or_else(default_sandbox),
}
}
}
impl Default for Sandbox {
fn default() -> Self {
Self::builder().build()
}
}
#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema)]
#[serde(deny_unknown_fields)]
#[serde(default)]
pub(crate) struct Homepage {
pub(crate) enabled: bool,
pub(crate) graph_ref: Option<String>,
}
fn default_homepage() -> bool {
true
}
#[buildstructor::buildstructor]
impl Homepage {
#[builder]
pub(crate) fn new(enabled: Option<bool>) -> Self {
Self {
enabled: enabled.unwrap_or_else(default_homepage),
graph_ref: None,
}
}
}
#[cfg(test)]
#[buildstructor::buildstructor]
impl Homepage {
#[builder]
pub(crate) fn fake_new(enabled: Option<bool>) -> Self {
Self {
enabled: enabled.unwrap_or_else(default_homepage),
graph_ref: None,
}
}
}
impl Default for Homepage {
fn default() -> Self {
Self::builder().enabled(default_homepage()).build()
}
}
#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema)]
#[serde(deny_unknown_fields)]
#[serde(default)]
pub(crate) struct HealthCheck {
pub(crate) listen: ListenAddr,
pub(crate) enabled: bool,
pub(crate) path: String,
}
fn default_health_check_listen() -> ListenAddr {
SocketAddr::from_str("127.0.0.1:8088").unwrap().into()
}
fn default_health_check_enabled() -> bool {
true
}
fn default_health_check_path() -> String {
"/health".to_string()
}
#[buildstructor::buildstructor]
impl HealthCheck {
#[builder]
pub(crate) fn new(
listen: Option<ListenAddr>,
enabled: Option<bool>,
path: Option<String>,
) -> Self {
let mut path = path.unwrap_or_else(default_health_check_path);
if !path.starts_with('/') {
path = format!("/{path}").to_string();
}
Self {
listen: listen.unwrap_or_else(default_health_check_listen),
enabled: enabled.unwrap_or_else(default_health_check_enabled),
path,
}
}
}
#[cfg(test)]
#[buildstructor::buildstructor]
impl HealthCheck {
#[builder]
pub(crate) fn fake_new(
listen: Option<ListenAddr>,
enabled: Option<bool>,
path: Option<String>,
) -> Self {
let mut path = path.unwrap_or_else(default_health_check_path);
if !path.starts_with('/') {
path = format!("/{path}");
}
Self {
listen: listen.unwrap_or_else(test_listen),
enabled: enabled.unwrap_or_else(default_health_check_enabled),
path,
}
}
}
impl Default for HealthCheck {
fn default() -> Self {
Self::builder().build()
}
}
#[derive(Debug, Clone, Default, Deserialize, Serialize, JsonSchema)]
#[serde(deny_unknown_fields)]
#[serde(default)]
pub(crate) struct Chaos {
#[serde(with = "humantime_serde")]
#[schemars(with = "Option<String>")]
pub(crate) force_reload: Option<std::time::Duration>,
}
#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize, Serialize, JsonSchema)]
#[serde(untagged)]
pub enum ListenAddr {
SocketAddr(SocketAddr),
#[cfg(unix)]
UnixSocket(std::path::PathBuf),
}
impl ListenAddr {
pub(crate) fn ip_and_port(&self) -> Option<(IpAddr, u16)> {
#[cfg_attr(not(unix), allow(irrefutable_let_patterns))]
if let Self::SocketAddr(addr) = self {
Some((addr.ip(), addr.port()))
} else {
None
}
}
}
impl From<SocketAddr> for ListenAddr {
fn from(addr: SocketAddr) -> Self {
Self::SocketAddr(addr)
}
}
#[allow(clippy::from_over_into)]
impl Into<serde_json::Value> for ListenAddr {
fn into(self) -> serde_json::Value {
match self {
Self::SocketAddr(addr) => serde_json::Value::String(addr.to_string()),
#[cfg(unix)]
Self::UnixSocket(path) => serde_json::Value::String(
path.as_os_str()
.to_str()
.expect("unsupported non-UTF-8 path")
.to_string(),
),
}
}
}
#[cfg(unix)]
impl From<tokio_util::either::Either<std::net::SocketAddr, tokio::net::unix::SocketAddr>>
for ListenAddr
{
fn from(
addr: tokio_util::either::Either<std::net::SocketAddr, tokio::net::unix::SocketAddr>,
) -> Self {
match addr {
tokio_util::either::Either::Left(addr) => Self::SocketAddr(addr),
tokio_util::either::Either::Right(addr) => Self::UnixSocket(
addr.as_pathname()
.map(ToOwned::to_owned)
.unwrap_or_default(),
),
}
}
}
impl fmt::Display for ListenAddr {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Self::SocketAddr(addr) => write!(f, "http://{addr}"),
#[cfg(unix)]
Self::UnixSocket(path) => write!(f, "{}", path.display()),
}
}
}
fn default_graphql_path() -> String {
String::from("/")
}
fn default_graphql_introspection() -> bool {
false
}
#[derive(Clone, Debug, Default, Error, Display, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields, rename_all = "snake_case")]
pub(crate) enum BatchingMode {
#[default]
BatchHttpLink,
}
#[derive(Debug, Clone, Default, Deserialize, Serialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub(crate) struct Batching {
#[serde(default)]
pub(crate) enabled: bool,
pub(crate) mode: BatchingMode,
}