use std::collections::BTreeMap;
use axum::headers::HeaderName;
use opentelemetry::sdk::resource::EnvResourceDetector;
use opentelemetry::sdk::resource::ResourceDetector;
use opentelemetry::sdk::trace::SpanLimits;
use opentelemetry::sdk::Resource;
use opentelemetry::Array;
use opentelemetry::KeyValue;
use opentelemetry::Value;
use regex::Regex;
use schemars::JsonSchema;
use serde::Deserialize;
use serde::Serialize;
use super::metrics::MetricsAttributesConf;
use super::*;
use crate::configuration::ConfigurationError;
use crate::plugin::serde::deserialize_option_header_name;
use crate::plugin::serde::deserialize_regex;
use crate::plugins::telemetry::metrics;
/// Errors produced while interpreting the telemetry configuration.
#[derive(thiserror::Error, Debug)]
pub(crate) enum Error {
/// Returned by [`Conf::calculate_field_level_instrumentation_ratio`] when the Apollo
/// field-level instrumentation sampler would sample more often than the tracing sampler.
#[error("field level instrumentation sampler must sample less frequently than tracing level sampler")]
InvalidFieldLevelInstrumentationSampler,
}
/// Builder helper: conditionally applies a transformation depending on an optional setting.
///
/// Blanket-implemented for every sized type, so any builder value can call
/// `.with(&config.some_option, |builder, value| ...)` and only be transformed when the
/// option is `Some`.
///
/// NOTE(review): the type parameter `T` is not used by either method; the blanket impl
/// instantiates it as `Self`.
pub(crate) trait GenericWith<T>: Sized {
    /// Returns `apply(self, value)` when `option` is `Some(value)`, otherwise `self` unchanged.
    fn with<B>(self, option: &Option<B>, apply: fn(Self, &B) -> Self) -> Self {
        match option {
            Some(value) => apply(self, value),
            None => self,
        }
    }

    /// Fallible variant of [`GenericWith::with`].
    ///
    /// Returns `apply(self, value)` when `option` is `Some(value)`, otherwise `Ok(self)`.
    ///
    /// # Errors
    ///
    /// Propagates whatever error `apply` returns; never fails when `option` is `None`.
    fn try_with<B>(
        self,
        option: &Option<B>,
        apply: fn(Self, &B) -> Result<Self, BoxError>,
    ) -> Result<Self, BoxError> {
        match option {
            Some(value) => apply(self, value),
            None => Ok(self),
        }
    }
}

impl<T> GenericWith<T> for T {}
// Top-level telemetry configuration. Unknown keys are rejected (`deny_unknown_fields`).
// NOTE: plain `//` comments are used on `JsonSchema` types because `///` doc comments would be
// emitted as `description`s in the generated configuration schema.
#[derive(Clone, Default, Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields, rename_all = "snake_case")]
pub(crate) struct Conf {
// Read from the `experimental_logging` key; falls back to `Logging::default()` when absent.
#[serde(rename = "experimental_logging", default)]
pub(crate) logging: Logging,
// Metrics settings; `None` when the `metrics` section is absent.
pub(crate) metrics: Option<Metrics>,
// Tracing settings; `None` when the `tracing` section is absent.
pub(crate) tracing: Option<Tracing>,
// Apollo-specific settings (see `apollo::Config`); `None` when the section is absent.
pub(crate) apollo: Option<apollo::Config>,
}
// Metrics configuration: common settings plus the optional OTLP and Prometheus sections.
#[derive(Clone, Default, Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields, rename_all = "snake_case")]
// NOTE(review): the reason `dead_code` is allowed is not visible here — presumably some fields
// are only read on certain code paths; confirm and document it, or remove the allow.
#[allow(dead_code)]
pub(crate) struct Metrics {
// Presumably settings shared across metrics exporters (see `MetricsCommon`).
pub(crate) common: Option<MetricsCommon>,
// OTLP section; `None` when absent.
pub(crate) otlp: Option<otlp::Config>,
// Prometheus section; `None` when absent.
pub(crate) prometheus: Option<metrics::prometheus::Config>,
}
// Common metrics settings: attributes and service/resource identification.
#[derive(Clone, Default, Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields, rename_all = "snake_case")]
pub(crate) struct MetricsCommon {
// Attribute configuration (see `MetricsAttributesConf`); `None` when absent.
pub(crate) attributes: Option<MetricsAttributesConf>,
// Optional service name; `None` when absent.
pub(crate) service_name: Option<String>,
// Optional service namespace; `None` when absent.
pub(crate) service_namespace: Option<String>,
// Additional resource key/value pairs; an empty map when omitted.
#[serde(default)]
pub(crate) resources: HashMap<String, String>,
}
// Tracing configuration: response trace-id exposure, propagation, trace settings and the
// per-backend sections (OTLP, Jaeger, Zipkin, Datadog).
#[derive(Clone, Default, Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields, rename_all = "snake_case")]
pub(crate) struct Tracing {
// Read from `experimental_response_trace_id`; uses `ExposeTraceId::default()` when absent.
#[serde(default, rename = "experimental_response_trace_id")]
pub(crate) response_trace_id: ExposeTraceId,
// Propagation settings; `None` when absent.
pub(crate) propagation: Option<Propagation>,
// Sampling, span limits and resource attributes (see `Trace`); `None` when absent.
pub(crate) trace_config: Option<Trace>,
// Backend sections; each is `None` when its key is absent.
pub(crate) otlp: Option<otlp::Config>,
pub(crate) jaeger: Option<tracing::jaeger::Config>,
pub(crate) zipkin: Option<tracing::zipkin::Config>,
pub(crate) datadog: Option<tracing::datadog::Config>,
}
// Logging configuration. Every field falls back to its `Default` value when omitted
// (container-level `#[serde(default)]`); unknown keys are rejected.
#[derive(Clone, Debug, Deserialize, JsonSchema, Default)]
#[serde(deny_unknown_fields, default)]
pub(crate) struct Logging {
// Output format; by default pretty when stdout is a terminal, JSON otherwise.
pub(crate) format: LoggingFormat,
// Display toggles; all `false` by default.
pub(crate) display_target: bool,
pub(crate) display_filename: bool,
pub(crate) display_line_number: bool,
// Header-based conditions; `Logging::should_log` ORs the results of all of them.
pub(crate) when_header: Vec<HeaderLoggingCondition>,
}
impl Logging {
    /// Checks that every `when_header` condition enables at least one kind of logging.
    ///
    /// # Errors
    ///
    /// Returns [`ConfigurationError::InvalidConfiguration`] when any condition has both
    /// `headers` and `body` set to `false`, because such a condition can never log anything.
    pub(crate) fn validate(&self) -> Result<(), ConfigurationError> {
        let has_inert_condition = self.when_header.iter().any(|condition| {
            matches!(
                condition,
                HeaderLoggingCondition::Matching {
                    headers: false,
                    body: false,
                    ..
                } | HeaderLoggingCondition::Value {
                    headers: false,
                    body: false,
                    ..
                }
            )
        });

        if !has_inert_condition {
            return Ok(());
        }

        Err(ConfigurationError::InvalidConfiguration {
            message: "'when_header' configuration for logging is invalid",
            error: String::from(
                "body and headers must not be both false because it doesn't enable any logs",
            ),
        })
    }

    /// Decides whether to log the request headers and/or body for `req`.
    ///
    /// Every `when_header` condition is evaluated. The result is `(log_headers, log_body)`,
    /// where each flag is `true` if at least one condition enabled it. With no conditions,
    /// nothing is logged.
    pub(crate) fn should_log(&self, req: &SupergraphRequest) -> (bool, bool) {
        let mut log_headers = false;
        let mut log_body = false;
        for condition in &self.when_header {
            let (headers, body) = condition.should_log(req);
            log_headers |= headers;
            log_body |= body;
        }
        (log_headers, log_body)
    }
}
// A request-header condition that, when satisfied, enables header and/or body logging.
// `untagged`: serde tries `Matching` first, then `Value`, and keeps the first variant that
// deserializes successfully.
#[derive(Clone, Debug, Deserialize, JsonSchema)]
#[serde(untagged, deny_unknown_fields, rename_all = "snake_case")]
pub(crate) enum HeaderLoggingCondition {
// Satisfied when the value of header `name` can be read as a string and matches the regex.
Matching {
name: String,
// Regex read from the `match` key (`match` is a Rust keyword, hence the field rename).
#[schemars(with = "String", rename = "match")]
#[serde(deserialize_with = "deserialize_regex", rename = "match")]
matching: Regex,
// Enables header logging when satisfied; `false` when omitted.
#[serde(default)]
headers: bool,
// Enables body logging when satisfied; `false` when omitted.
#[serde(default)]
body: bool,
},
// Satisfied when the value of header `name` can be read as a string and equals `value` exactly.
Value {
name: String,
value: String,
// Enables header logging when satisfied; `false` when omitted.
#[serde(default)]
headers: bool,
// Enables body logging when satisfied; `false` when omitted.
#[serde(default)]
body: bool,
},
}
impl HeaderLoggingCondition {
    /// Evaluates this condition against the headers of `req`.
    ///
    /// Returns the configured `(headers, body)` flags when header `name` is present, its value
    /// can be read as a string, and it satisfies the condition (regex match for `Matching`,
    /// exact equality for `Value`). Otherwise returns `(false, false)`.
    pub(crate) fn should_log(&self, req: &SupergraphRequest) -> (bool, bool) {
        let (name, headers, body) = match self {
            HeaderLoggingCondition::Matching {
                name,
                headers,
                body,
                ..
            }
            | HeaderLoggingCondition::Value {
                name,
                headers,
                body,
                ..
            } => (name, *headers, *body),
        };

        // A missing header, or one whose value is not representable as a string, never matches.
        let header_value = req
            .supergraph_request
            .headers()
            .get(name)
            .and_then(|raw| raw.to_str().ok());

        let satisfied = match (self, header_value) {
            (HeaderLoggingCondition::Matching { matching, .. }, Some(actual)) => {
                matching.is_match(actual)
            }
            (HeaderLoggingCondition::Value { value, .. }, Some(actual)) => {
                value.as_str() == actual
            }
            (_, None) => false,
        };

        if satisfied {
            (headers, body)
        } else {
            (false, false)
        }
    }
}
// Log output format, written as `pretty` or `json` in configuration (`snake_case` rename).
#[derive(Clone, Debug, Deserialize, JsonSchema, Copy)]
#[serde(deny_unknown_fields, rename_all = "snake_case")]
pub(crate) enum LoggingFormat {
Pretty,
Json,
}
impl Default for LoggingFormat {
fn default() -> Self {
if atty::is(atty::Stream::Stdout) {
Self::Pretty
} else {
Self::Json
}
}
}
// Settings read from the `experimental_response_trace_id` key of `Tracing`; presumably controls
// exposing the trace id in responses — confirm at the consuming code.
// Omitted fields fall back to their defaults (container-level `default`).
#[derive(Clone, Default, Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields, rename_all = "snake_case", default)]
pub(crate) struct ExposeTraceId {
// `false` by default.
pub(crate) enabled: bool,
// Optional header name, parsed by `deserialize_option_header_name`; `None` by default.
#[schemars(with = "Option<String>")]
#[serde(deserialize_with = "deserialize_option_header_name")]
pub(crate) header_name: Option<HeaderName>,
}
// Propagation settings. Omitted fields use their defaults (container-level `default`), so every
// flag below is `false` unless set explicitly.
#[derive(Clone, Default, Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields, rename_all = "snake_case", default)]
pub(crate) struct Propagation {
// Request-header settings (see `RequestPropagation`).
pub(crate) request: RequestPropagation,
// NOTE(review): each flag presumably enables the propagation format of the same name; the
// consuming code is not visible here.
pub(crate) baggage: bool,
pub(crate) trace_context: bool,
pub(crate) jaeger: bool,
pub(crate) datadog: bool,
pub(crate) zipkin: bool,
}
// Request-header propagation settings.
// NOTE(review): `header_name` uses `deserialize_with`, and neither the field nor the container
// has `#[serde(default)]`. Serde therefore treats it as required whenever a `request` section is
// present (e.g. `request: {}` fails), even though the type is an `Option`. The schema
// (`schemars(with = "String")`) also describes it as a required string — confirm this is intended.
#[derive(Clone, Debug, Deserialize, JsonSchema, Default)]
#[serde(deny_unknown_fields, rename_all = "snake_case")]
pub(crate) struct RequestPropagation {
#[schemars(with = "String")]
#[serde(deserialize_with = "deserialize_option_header_name")]
pub(crate) header_name: Option<HeaderName>,
}
// Trace settings: sampling, span limits and resource attributes.
// Omitted fields take their values from `Trace::default()` (container-level `default`).
// NOTE: `#[non_exhaustive]` has no effect inside the defining crate, and this type is `pub(crate)`.
#[derive(Debug, Clone, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields, default)]
#[non_exhaustive]
pub(crate) struct Trace {
// Reported as the `service.name` resource attribute; defaults to "router".
pub(crate) service_name: String,
// Reported as the `service.namespace` resource attribute; defaults to an empty string.
pub(crate) service_namespace: String,
// Sampler; defaults to always on.
pub(crate) sampler: SamplerOption,
// When `true` (the default), the sampler is wrapped in a parent-based sampler.
pub(crate) parent_based_sampler: bool,
// Span limits; defaults come from the OpenTelemetry SDK's `SpanLimits::default()`.
pub(crate) max_events_per_span: u32,
pub(crate) max_attributes_per_span: u32,
pub(crate) max_links_per_span: u32,
pub(crate) max_attributes_per_event: u32,
pub(crate) max_attributes_per_link: u32,
// Extra resource attributes, merged last into the SDK resource; empty by default.
pub(crate) attributes: BTreeMap<String, AttributeValue>,
}
/// Default for `Trace::parent_based_sampler`: parent-based sampling is enabled.
fn default_parent_based_sampler() -> bool {
true
}
/// Default for `Trace::sampler`: the always-on sampler.
fn default_sampler() -> SamplerOption {
SamplerOption::Always(Sampler::AlwaysOn)
}
impl Default for Trace {
fn default() -> Self {
Self {
service_name: "router".to_string(),
service_namespace: Default::default(),
sampler: default_sampler(),
parent_based_sampler: default_parent_based_sampler(),
max_events_per_span: default_max_events_per_span(),
max_attributes_per_span: default_max_attributes_per_span(),
max_links_per_span: default_max_links_per_span(),
max_attributes_per_event: default_max_attributes_per_event(),
max_attributes_per_link: default_max_attributes_per_link(),
attributes: Default::default(),
}
}
}
fn default_max_events_per_span() -> u32 {
SpanLimits::default().max_events_per_span
}
fn default_max_attributes_per_span() -> u32 {
SpanLimits::default().max_attributes_per_span
}
fn default_max_links_per_span() -> u32 {
SpanLimits::default().max_links_per_span
}
fn default_max_attributes_per_event() -> u32 {
SpanLimits::default().max_attributes_per_event
}
fn default_max_attributes_per_link() -> u32 {
SpanLimits::default().max_attributes_per_link
}
// An attribute value as written in configuration; converts into `opentelemetry::Value`.
// `untagged`: variants are tried in declaration order, so JSON integers become `I64` before
// `F64` is attempted.
#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema, PartialEq)]
#[serde(untagged, deny_unknown_fields)]
pub(crate) enum AttributeValue {
Bool(bool),
I64(i64),
F64(f64),
String(String),
// A homogeneous array (see `AttributeArray`).
Array(AttributeArray),
}
impl TryFrom<serde_json::Value> for AttributeValue {
    type Error = ();

    /// Converts a JSON value into an attribute value.
    ///
    /// Booleans map to `Bool`, integers representable as `i64` to `I64`, floating-point
    /// numbers to `F64` and strings to `String`. Arrays are accepted only when homogeneous
    /// (all booleans, all floats, all `i64` integers or all strings); an empty array becomes
    /// an empty `Bool` array.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` for `null`, objects, integers outside the `i64` range, and arrays that
    /// mix element kinds (including arrays mixing integers and floats).
    fn try_from(value: serde_json::Value) -> Result<Self, Self::Error> {
        // Applies `extract` to every element; yields `None` as soon as one element is rejected.
        fn collect_all<T>(
            items: &[serde_json::Value],
            extract: impl Fn(&serde_json::Value) -> Option<T>,
        ) -> Option<Vec<T>> {
            items.iter().map(extract).collect()
        }

        let items = match value {
            serde_json::Value::Bool(flag) => return Ok(AttributeValue::Bool(flag)),
            serde_json::Value::String(text) => return Ok(AttributeValue::String(text)),
            serde_json::Value::Number(number) => {
                return match number.as_i64() {
                    Some(int) => Ok(AttributeValue::I64(int)),
                    None if number.is_f64() => {
                        Ok(AttributeValue::F64(number.as_f64().expect("f64 checked")))
                    }
                    None => Err(()),
                };
            }
            serde_json::Value::Array(items) => items,
            serde_json::Value::Null | serde_json::Value::Object(_) => return Err(()),
        };

        // Element kinds are tried in a fixed order; `as_f64` is gated on `is_f64` because it
        // would otherwise also accept integers.
        let array = if let Some(flags) = collect_all(&items, serde_json::Value::as_bool) {
            AttributeArray::Bool(flags)
        } else if let Some(floats) = collect_all(&items, |item| {
            if item.is_f64() {
                item.as_f64()
            } else {
                None
            }
        }) {
            AttributeArray::F64(floats)
        } else if let Some(ints) = collect_all(&items, serde_json::Value::as_i64) {
            AttributeArray::I64(ints)
        } else if let Some(texts) = collect_all(&items, |item| item.as_str().map(str::to_owned)) {
            AttributeArray::String(texts)
        } else {
            return Err(());
        };

        Ok(AttributeValue::Array(array))
    }
}
impl From<AttributeValue> for opentelemetry::Value {
    /// Maps each configured attribute variant onto the OpenTelemetry value of the same kind;
    /// arrays go through `From<AttributeArray> for opentelemetry::Array`.
    fn from(value: AttributeValue) -> Self {
        match value {
            AttributeValue::Bool(flag) => Self::Bool(flag),
            AttributeValue::I64(int) => Self::I64(int),
            AttributeValue::F64(float) => Self::F64(float),
            AttributeValue::String(text) => Self::String(text.into()),
            AttributeValue::Array(items) => Self::Array(items.into()),
        }
    }
}
// A homogeneous array attribute value; converts into `opentelemetry::Array`.
// `untagged`: variants are tried in declaration order (bool, i64, f64, string).
#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema, PartialEq)]
#[serde(untagged, deny_unknown_fields)]
pub(crate) enum AttributeArray {
Bool(Vec<bool>),
I64(Vec<i64>),
F64(Vec<f64>),
String(Vec<String>),
}
impl From<AttributeArray> for opentelemetry::Array {
    /// Maps each homogeneous array onto the OpenTelemetry array of the same element kind;
    /// string elements are converted individually into the SDK's string type.
    fn from(array: AttributeArray) -> Self {
        match array {
            AttributeArray::Bool(flags) => Self::Bool(flags),
            AttributeArray::I64(ints) => Self::I64(ints),
            AttributeArray::F64(floats) => Self::F64(floats),
            AttributeArray::String(texts) => {
                Self::String(texts.into_iter().map(Into::into).collect())
            }
        }
    }
}
// Sampler setting. `untagged`: a number deserializes as `TraceIdRatioBased`; otherwise the
// value is read as a `Sampler` (`always_on` / `always_off`).
#[derive(Clone, Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields, untagged)]
pub(crate) enum SamplerOption {
// Converted to the SDK's `TraceIdRatioBased` sampler with this ratio.
TraceIdRatioBased(f64),
// A fixed always-on / always-off sampler.
Always(Sampler),
}
// Fixed sampling decision, written as `always_on` or `always_off` (`snake_case` rename).
#[derive(Clone, Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields, rename_all = "snake_case")]
pub(crate) enum Sampler {
AlwaysOn,
AlwaysOff,
}
impl From<Sampler> for opentelemetry::sdk::trace::Sampler {
    /// Maps the configured fixed sampler onto the SDK sampler with the same name.
    fn from(s: Sampler) -> Self {
        match s {
            Sampler::AlwaysOn => Self::AlwaysOn,
            Sampler::AlwaysOff => Self::AlwaysOff,
        }
    }
}
impl From<SamplerOption> for opentelemetry::sdk::trace::Sampler {
    /// Converts a configured sampler option into the SDK sampler: ratios become
    /// `TraceIdRatioBased`, fixed samplers go through `From<Sampler>`.
    fn from(s: SamplerOption) -> Self {
        match s {
            SamplerOption::TraceIdRatioBased(ratio) => Self::TraceIdRatioBased(ratio),
            SamplerOption::Always(fixed) => Self::from(fixed),
        }
    }
}
impl From<&Trace> for opentelemetry::sdk::trace::Config {
    /// Builds the OpenTelemetry SDK trace configuration from the router's trace settings.
    ///
    /// - The sampler comes from `config.sampler`, wrapped in a parent-based sampler when
    ///   `config.parent_based_sampler` is set.
    /// - Span limits are copied verbatim from `config`.
    /// - The resource is built by merging, in order: attributes detected by
    ///   `EnvResourceDetector`, then the router defaults (service name, namespace and version,
    ///   plus the executable name when it can be determined), then `config.attributes`. Per
    ///   `Resource::merge`, values from the resource passed in win on key conflicts, so later
    ///   sources override earlier ones.
    fn from(config: &Trace) -> Self {
        let mut sampler: opentelemetry::sdk::trace::Sampler = config.sampler.clone().into();
        if config.parent_based_sampler {
            sampler = parent_based(sampler);
        }

        let mut resource_defaults = vec![
            KeyValue::new(
                opentelemetry_semantic_conventions::resource::SERVICE_NAME,
                config.service_name.clone(),
            ),
            KeyValue::new(
                opentelemetry_semantic_conventions::resource::SERVICE_NAMESPACE,
                config.service_namespace.clone(),
            ),
            KeyValue::new(
                opentelemetry_semantic_conventions::resource::SERVICE_VERSION,
                std::env!("CARGO_PKG_VERSION"),
            ),
        ];
        // Best effort: skipped when the executable path is unavailable or not valid UTF-8.
        if let Some(executable_name) = std::env::current_exe()
            .ok()
            .and_then(|path| path.file_name()?.to_str().map(str::to_owned))
        {
            resource_defaults.push(KeyValue::new(
                opentelemetry_semantic_conventions::resource::PROCESS_EXECUTABLE_NAME,
                executable_name,
            ));
        }

        let configured_attributes: Vec<KeyValue> = config
            .attributes
            .iter()
            .map(|(key, value)| {
                KeyValue::new(
                    opentelemetry::Key::from(key.clone()),
                    opentelemetry::Value::from(value.clone()),
                )
            })
            .collect();

        let resource = EnvResourceDetector::default()
            .detect(Duration::from_secs(0))
            .merge(&Resource::new(resource_defaults))
            .merge(&Resource::new(configured_attributes));

        opentelemetry::sdk::trace::config()
            .with_sampler(sampler)
            .with_max_events_per_span(config.max_events_per_span)
            .with_max_attributes_per_span(config.max_attributes_per_span)
            .with_max_links_per_span(config.max_links_per_span)
            .with_max_attributes_per_event(config.max_attributes_per_event)
            .with_max_attributes_per_link(config.max_attributes_per_link)
            .with_resource(resource)
    }
}
/// Wraps `sampler` in the SDK's `ParentBased` sampler. Per the OpenTelemetry SDK, that sampler
/// honours the parent span's sampling decision and only consults `sampler` for root spans.
fn parent_based(sampler: opentelemetry::sdk::trace::Sampler) -> opentelemetry::sdk::trace::Sampler {
    let delegate = Box::new(sampler);
    opentelemetry::sdk::trace::Sampler::ParentBased(delegate)
}
impl Conf {
pub(crate) fn calculate_field_level_instrumentation_ratio(&self) -> Result<f64, Error> {
Ok(
match (
self.tracing
.clone()
.unwrap_or_default()
.trace_config
.unwrap_or_default()
.sampler,
self.apollo
.clone()
.unwrap_or_default()
.field_level_instrumentation_sampler,
) {
(
SamplerOption::TraceIdRatioBased(global_ratio),
SamplerOption::TraceIdRatioBased(field_ratio),
) if field_ratio > global_ratio => {
Err(Error::InvalidFieldLevelInstrumentationSampler)?
}
(
SamplerOption::Always(Sampler::AlwaysOff),
SamplerOption::Always(Sampler::AlwaysOn),
) => Err(Error::InvalidFieldLevelInstrumentationSampler)?,
(
SamplerOption::Always(Sampler::AlwaysOff),
SamplerOption::TraceIdRatioBased(ratio),
) if ratio != 0.0 => Err(Error::InvalidFieldLevelInstrumentationSampler)?,
(
SamplerOption::TraceIdRatioBased(ratio),
SamplerOption::Always(Sampler::AlwaysOn),
) if ratio != 1.0 => Err(Error::InvalidFieldLevelInstrumentationSampler)?,
(_, SamplerOption::TraceIdRatioBased(ratio)) if ratio == 0.0 => 0.0,
(SamplerOption::TraceIdRatioBased(ratio), _) if ratio == 0.0 => 0.0,
(_, SamplerOption::Always(Sampler::AlwaysOn)) => 1.0,
(
SamplerOption::TraceIdRatioBased(global_ratio),
SamplerOption::TraceIdRatioBased(field_ratio),
) => field_ratio / global_ratio,
(
SamplerOption::Always(Sampler::AlwaysOn),
SamplerOption::TraceIdRatioBased(field_ratio),
) => field_ratio,
(_, _) => 0.0,
},
)
}
}
// Unit tests for logging-condition validation/evaluation and JSON → attribute conversion.
#[cfg(test)]
mod tests {
use serde_json::json;
use super::*;
// A condition that enables headers passes validation; one with both flags `false` is rejected
// with the combined "message: error" text.
#[test]
fn test_logging_conf_validation() {
let logging_conf = Logging {
format: LoggingFormat::default(),
display_target: false,
display_filename: false,
display_line_number: false,
when_header: vec![HeaderLoggingCondition::Value {
name: "test".to_string(),
value: String::new(),
headers: true,
body: false,
}],
};
logging_conf.validate().unwrap();
let logging_conf = Logging {
format: LoggingFormat::default(),
display_target: false,
display_filename: false,
display_line_number: false,
when_header: vec![HeaderLoggingCondition::Value {
name: "test".to_string(),
value: String::new(),
headers: false,
body: false,
}],
};
let validate_res = logging_conf.validate();
assert!(validate_res.is_err());
assert_eq!(validate_res.unwrap_err().to_string(), "'when_header' configuration for logging is invalid: body and headers must not be both false because it doesn't enable any logs");
}
// Conditions are evaluated against a request carrying `test: foobar`; results of multiple
// conditions are OR-ed, and a condition on an absent header never matches.
#[test]
fn test_logging_conf_should_log() {
let logging_conf = Logging {
format: LoggingFormat::default(),
display_target: false,
display_filename: false,
display_line_number: false,
when_header: vec![HeaderLoggingCondition::Matching {
name: "test".to_string(),
matching: Regex::new("^foo*").unwrap(),
headers: true,
body: false,
}],
};
let req = SupergraphRequest::fake_builder()
.header("test", "foobar")
.build()
.unwrap();
assert_eq!(logging_conf.should_log(&req), (true, false));
// Exact value match.
let logging_conf = Logging {
format: LoggingFormat::default(),
display_target: false,
display_filename: false,
display_line_number: false,
when_header: vec![HeaderLoggingCondition::Value {
name: "test".to_string(),
value: String::from("foobar"),
headers: true,
body: false,
}],
};
assert_eq!(logging_conf.should_log(&req), (true, false));
// Two matching conditions: one enables headers, the other body, so both flags end up set.
// NOTE(review): `^*` repeats the start anchor zero or more times, so `^*bar$` effectively
// matches any value ending in `bar`; `.*bar$` may have been intended.
let logging_conf = Logging {
format: LoggingFormat::default(),
display_target: false,
display_filename: false,
display_line_number: false,
when_header: vec![
HeaderLoggingCondition::Matching {
name: "test".to_string(),
matching: Regex::new("^foo*").unwrap(),
headers: true,
body: false,
},
HeaderLoggingCondition::Matching {
name: "test".to_string(),
matching: Regex::new("^*bar$").unwrap(),
headers: false,
body: true,
},
],
};
assert_eq!(logging_conf.should_log(&req), (true, true));
// The request has no `testtest` header, so nothing is logged.
let logging_conf = Logging {
format: LoggingFormat::default(),
display_target: false,
display_filename: false,
display_line_number: false,
when_header: vec![HeaderLoggingCondition::Matching {
name: "testtest".to_string(),
matching: Regex::new("^foo*").unwrap(),
headers: true,
body: false,
}],
};
assert_eq!(logging_conf.should_log(&req), (false, false));
}
// Scalars and homogeneous arrays convert to the matching variant; mixed arrays are rejected.
#[test]
fn test_attribute_value_from_json() {
assert_eq!(
AttributeValue::try_from(json!("foo")),
Ok(AttributeValue::String("foo".to_string()))
);
assert_eq!(
AttributeValue::try_from(json!(1)),
Ok(AttributeValue::I64(1))
);
assert_eq!(
AttributeValue::try_from(json!(1.1)),
Ok(AttributeValue::F64(1.1))
);
assert_eq!(
AttributeValue::try_from(json!(true)),
Ok(AttributeValue::Bool(true))
);
assert_eq!(
AttributeValue::try_from(json!(["foo", "bar"])),
Ok(AttributeValue::Array(AttributeArray::String(vec![
"foo".to_string(),
"bar".to_string()
])))
);
assert_eq!(
AttributeValue::try_from(json!([1, 2])),
Ok(AttributeValue::Array(AttributeArray::I64(vec![1, 2])))
);
assert_eq!(
AttributeValue::try_from(json!([1.1, 1.5])),
Ok(AttributeValue::Array(AttributeArray::F64(vec![1.1, 1.5])))
);
assert_eq!(
AttributeValue::try_from(json!([true, false])),
Ok(AttributeValue::Array(AttributeArray::Bool(vec![
true, false
])))
);
AttributeValue::try_from(json!(["foo", true])).expect_err("mixed conversion must fail");
AttributeValue::try_from(json!([1, true])).expect_err("mixed conversion must fail");
AttributeValue::try_from(json!([1.1, true])).expect_err("mixed conversion must fail");
AttributeValue::try_from(json!([true, "bar"])).expect_err("mixed conversion must fail");
}
}