use crate::config::environment::Environment;
#[cfg(feature = "otel")]
use crate::util::serde::default_true;
use config::{FileFormat, FileSourceString};
use itertools::Itertools;
use serde_derive::{Deserialize, Serialize};
use serde_with::serde_as;
use std::borrow::Cow;
use std::str::FromStr;
use strum_macros::{EnumString, IntoStaticStr};
use tracing_subscriber::EnvFilter;
#[cfg(feature = "otel")]
use url::Url;
use validator::{Validate, ValidationError};
/// Base configuration shipped with the crate.
///
/// The TOML text of `config/default.toml` is embedded at compile time with `include_str!`
/// and wrapped as a TOML-format `config::File` source.
pub(crate) fn default_config() -> config::File<FileSourceString, FileFormat> {
    const DEFAULT_TOML: &str = include_str!("config/default.toml");
    config::File::from_str(DEFAULT_TOML, FileFormat::Toml)
}
/// Environment-specific configuration overrides embedded at compile time.
///
/// Only [`Environment::Development`] (`config/development.toml`) and [`Environment::Test`]
/// (`config/test.toml`) have an embedded file; every other environment yields `None`.
pub(crate) fn default_config_per_env(
    environment: Environment,
) -> Option<config::File<FileSourceString, FileFormat>> {
    let contents = match environment {
        Environment::Development => include_str!("config/development.toml"),
        Environment::Test => include_str!("config/test.toml"),
        _ => return None,
    };
    Some(config::File::from_str(contents, FileFormat::Toml))
}
/// Tracing configuration: level, formatting and filter directives, plus
/// `otel`-feature-gated export settings.
///
/// Keys are kebab-case (e.g. `trace-filters`), and `None` fields are omitted when serialized.
/// `validate()` checks two things:
/// - `level` parses as a [`tracing::Level`];
/// - every entry in `trace_filters` parses as an [`EnvFilter`].
#[serde_as]
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Validate, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
#[non_exhaustive]
pub struct Tracing {
/// Level name such as `"debug"`; must parse as a [`tracing::Level`] (see `validate_level`).
#[validate(custom(function = "validate_level"))]
pub level: String,
/// Formatting style; see [`Format`].
pub format: Format,
/// Optional service name (`otel` feature only).
/// NOTE(review): presumably reported as the OpenTelemetry service name — confirm at the consumer.
#[cfg(feature = "otel")]
pub service_name: Option<String>,
/// Trace-propagation flag (`otel` feature only). When the key is absent, the value comes
/// from `default_true` (presumably `true`).
#[serde(default = "default_true")]
#[cfg(feature = "otel")]
pub trace_propagation: bool,
/// Optional trace sampling ratio (`otel` feature only). `validate()` does not range-check it.
/// NOTE(review): presumably expected within `0.0..=1.0` — confirm.
#[serde(default)]
#[cfg(feature = "otel")]
pub trace_sampling_ratio: Option<f64>,
/// Optional metrics export interval (`otel` feature only), written in config as an integer
/// number of milliseconds (e.g. `60000`).
#[cfg(feature = "otel")]
#[serde_as(as = "Option<serde_with::DurationMilliSeconds>")]
pub metrics_export_interval: Option<std::time::Duration>,
/// Extra filter directives, empty by default. Each entry must parse as an [`EnvFilter`]
/// (see `validate_env_filter_str`).
#[serde(default)]
#[validate(custom(function = "validate_env_filter_str"))]
pub trace_filters: Vec<String>,
/// Optional OTLP endpoint settings (`otel` feature only), validated as a nested struct.
#[validate(nested)]
#[serde(default)]
#[cfg(feature = "otel")]
pub otlp: Option<Otlp>,
}
/// `validator` custom check for `Tracing::level`.
///
/// # Errors
/// If `level` does not parse as a [`tracing::Level`], returns a `ValidationError` with code
/// `"Invalid level string"` and two params:
/// - `level`: the rejected input;
/// - `error`: the parse error's message.
fn validate_level(level: &str) -> Result<(), ValidationError> {
    tracing::Level::from_str(level)
        .map(|_| ())
        .map_err(|parse_err| {
            let mut invalid = ValidationError::new("Invalid level string");
            invalid.add_param(Cow::from("level"), &level);
            invalid.add_param(Cow::from("error"), &parse_err.to_string());
            invalid
        })
}
/// `validator` custom check for `Tracing::trace_filters`.
///
/// Every entry must parse as an [`EnvFilter`] directive string. All entries are checked,
/// not just the first, so one validation error reports every bad filter.
///
/// # Errors
/// Returns a `ValidationError` with code `"Invalid env filter(s)"`. Its `filters` and `errors`
/// params are parallel lists: the rejected filter strings and their parse error messages.
fn validate_env_filter_str(trace_filters: &[String]) -> Result<(), ValidationError> {
    let invalid_filters = trace_filters
        .iter()
        .filter_map(|filter| {
            filter
                .parse::<EnvFilter>()
                .err()
                .map(|err| (filter, err.to_string()))
        })
        .collect_vec();
    if invalid_filters.is_empty() {
        return Ok(());
    }
    // Split the (filter, error) pairs into two parallel lists for the error params.
    let (filters, errors): (Vec<&String>, Vec<String>) = invalid_filters.into_iter().unzip();
    let mut err = ValidationError::new("Invalid env filter(s)");
    err.add_param(Cow::from("filters"), &filters);
    err.add_param(Cow::from("errors"), &errors);
    Err(err)
}
/// Formatting style for tracing output.
///
/// Serde and [`std::str::FromStr`] (via strum's `EnumString`) use kebab-case names:
/// `none`, `pretty`, `compact`, `json`. `IntoStaticStr` yields the same strings.
#[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize, EnumString, IntoStaticStr)]
#[serde(rename_all = "kebab-case")]
#[strum(serialize_all = "kebab-case")]
#[non_exhaustive]
pub enum Format {
/// `none` — NOTE(review): presumably disables formatted output; confirm at the subscriber setup.
None,
/// `pretty`
Pretty,
/// `compact`
Compact,
/// `json`
Json,
}
/// OTLP endpoint settings (requires the `otel` feature).
///
/// `endpoint` is a shared fallback: [`Otlp::trace_endpoint`] and [`Otlp::metric_endpoint`]
/// return the signal-specific endpoint when it is set, and otherwise return `endpoint`.
/// All three fields default to `None`.
#[derive(Debug, Clone, Validate, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
#[non_exhaustive]
#[cfg(feature = "otel")]
pub struct Otlp {
/// Shared endpoint (`[otlp.endpoint]`), used when a specific one is not configured.
#[serde(default)]
endpoint: Option<OtlpProtocol>,
/// Trace-specific endpoint (`[otlp.trace-endpoint]`).
#[serde(default)]
trace_endpoint: Option<OtlpProtocol>,
/// Metric-specific endpoint (`[otlp.metric-endpoint]`).
#[serde(default)]
metric_endpoint: Option<OtlpProtocol>,
}
/// Transport used to reach an OTLP endpoint (requires the `otel` feature).
///
/// The enum is internally tagged. A `protocol` key selects the variant and sits alongside the
/// [`OtlpEndpoint`] fields, e.g. `protocol = "http"` with `url = "https://example.com:1234"`.
/// The `grpc` variant exists only with the `otel-grpc` feature.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case", tag = "protocol")]
#[non_exhaustive]
#[cfg(feature = "otel")]
pub enum OtlpProtocol {
/// `protocol = "http"`
Http(OtlpEndpoint),
/// `protocol = "grpc"` (requires the `otel-grpc` feature).
#[cfg(feature = "otel-grpc")]
Grpc(OtlpEndpoint),
}
/// Location of an OTLP endpoint (requires the `otel` feature).
#[derive(Debug, Clone, Validate, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
#[non_exhaustive]
#[cfg(feature = "otel")]
pub struct OtlpEndpoint {
/// Endpoint URL; deserialization fails unless it parses as a [`Url`].
pub url: Url,
}
#[cfg(feature = "otel")]
impl Otlp {
    /// Endpoint for trace export: `trace_endpoint` if configured, otherwise the shared
    /// `endpoint`, otherwise `None`.
    pub fn trace_endpoint(&self) -> Option<&OtlpProtocol> {
        Self::prefer_specific(self.trace_endpoint.as_ref(), self.endpoint.as_ref())
    }

    /// Endpoint for metric export: `metric_endpoint` if configured, otherwise the shared
    /// `endpoint`, otherwise `None`.
    pub fn metric_endpoint(&self) -> Option<&OtlpProtocol> {
        Self::prefer_specific(self.metric_endpoint.as_ref(), self.endpoint.as_ref())
    }

    /// Picks the signal-specific endpoint when present, falling back to the shared one.
    fn prefer_specific<'a>(
        specific: Option<&'a OtlpProtocol>,
        shared: Option<&'a OtlpProtocol>,
    ) -> Option<&'a OtlpProtocol> {
        specific.or(shared)
    }
}
// Snapshot tests for deserializing `Tracing` from TOML.
//
// Each `#[case]` is a TOML document that must deserialize into `Tracing`. The result is then
// re-serialized and compared against a stored snapshot via `assert_toml_snapshot!`. The module
// is gated on both `otel` and `otel-grpc` because the cases use `otel`-only keys (e.g.
// `service-name`, `[otlp.*]`) and the `grpc` protocol variant.
#[cfg(all(test, feature = "otel", feature = "otel-grpc"))]
mod deserialize_tests {
use super::*;
use crate::testing::snapshot::TestCase;
use insta::assert_toml_snapshot;
use rstest::{fixture, rstest};
// rstest fixture returning `TestCase::default()`; the test takes a `TestCase` as `_case`.
// NOTE(review): presumably `TestCase` sets up per-case snapshot settings — confirm.
#[fixture]
#[cfg_attr(coverage_nightly, coverage(off))]
fn case() -> TestCase {
Default::default()
}
// The cases cover:
// - each `Format` variant;
// - the optional `otel` scalar keys;
// - every `[otlp.*]` endpoint key with both `http` and `grpc`;
// - `trace-filters`.
#[rstest]
#[case(
r#"
level = "debug"
format = "compact"
"#
)]
#[case(
r#"
level = "info"
format = "json"
service-name = "foo"
"#
)]
#[case(
r#"
level = "error"
format = "pretty"
trace-propagation = false
"#
)]
#[case(
r#"
level = "debug"
format = "none"
metrics-export-interval = 60000
"#
)]
#[case(
r#"
level = "debug"
format = "none"
[otlp.endpoint]
protocol = "http"
url = "https://example.com:1234"
"#
)]
#[case(
r#"
level = "debug"
format = "none"
[otlp.endpoint]
protocol = "grpc"
url = "https://example.com:1234"
"#
)]
#[case(
r#"
level = "debug"
format = "none"
[otlp.trace-endpoint]
protocol = "http"
url = "https://example.com:1234"
"#
)]
#[case(
r#"
level = "debug"
format = "none"
[otlp.trace-endpoint]
protocol = "grpc"
url = "https://example.com:1234"
"#
)]
#[case(
r#"
level = "debug"
format = "none"
[otlp.metric-endpoint]
protocol = "http"
url = "https://example.com:1234"
"#
)]
#[case(
r#"
level = "debug"
format = "none"
[otlp.metric-endpoint]
protocol = "grpc"
url = "https://example.com:1234"
"#
)]
#[case(
r#"
level = "debug"
format = "none"
trace-filters = [ "foo=warn" ]
"#
)]
#[cfg_attr(coverage_nightly, coverage(off))]
fn tracing(_case: TestCase, #[case] config: &str) {
// `unwrap` is intentional: a deserialization failure should fail the test.
let tracing: Tracing = toml::from_str(config).unwrap();
assert_toml_snapshot!(tracing);
}
}
/// Validation tests for `Tracing`: `validate()` must reject unparsable trace filters and
/// unparsable level strings.
#[cfg(test)]
mod tests {
    use crate::testing::snapshot::TestCase;
    use rstest::{fixture, rstest};
    use validator::Validate;

    #[fixture]
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn case() -> TestCase {
        Default::default()
    }

    // Each case is (TOML config, whether `validate()` is expected to return `Err`).
    #[rstest]
    #[case(
        r#"
level = "debug"
format = "none"
trace-filters = [ "foo=warn" ]
"#,
        false
    )]
    #[case(
        r#"
level = "debug"
format = "none"
trace-filters = [ "foo=warn", "invalid filter" ]
"#,
        true
    )]
    // `level` deserializes as any string; it is only rejected by `validate_level`.
    #[case(
        r#"
level = "not-a-level"
format = "none"
"#,
        true
    )]
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn validation(_case: TestCase, #[case] config: &str, #[case] error: bool) {
        let tracing: super::Tracing = toml::from_str(config).unwrap();
        let validate_result = tracing.validate();
        assert_eq!(validate_result.is_err(), error);
    }
}