use apollo_configuration::ErrorCollector;
use apollo_configuration::Validate;
use apollo_configuration::configuration;
// Trace sampler selection, read from `tracer_provider.sampler` in the tests
// of this file. Each variant is translated to the OpenTelemetry SDK sampler
// of the same name by the `From<&Sampler> for SdkSampler` impl in this file.
#[configuration]
#[derive(Default)]
pub(crate) enum Sampler {
// Value produced by `Sampler::default()`; maps to `SdkSampler::AlwaysOn`.
#[default]
AlwaysOn,
// Maps to `SdkSampler::AlwaysOff`.
AlwaysOff,
// Maps to `SdkSampler::TraceIdRatioBased` using the configured ratio.
TraceIdRatioBased(TraceIdRatioBasedSamplerConfig),
// Maps to `SdkSampler::ParentBased`, wrapping the sampler built from the
// nested configuration's `root` field.
ParentBased(ParentBasedSamplerConfig),
}
// Settings carried by the `TraceIdRatioBased` variants of both `Sampler` and
// `RootSampler`.
#[configuration]
pub(crate) struct TraceIdRatioBasedSamplerConfig {
// Ratio handed to `SdkSampler::TraceIdRatioBased`. Marked as required; the
// `Validate` impl on `SamplingRatio` rejects values outside `0.0..=1.0`.
#[config(required)]
pub(crate) ratio: SamplingRatio,
}
// Newtype around the raw `f64` sampling ratio. `#[serde(transparent)]` makes
// it (de)serialize exactly like the inner number, so the configuration holds
// a bare value such as `ratio: 0.5`. The range check lives in the `Validate`
// impl for this type, not in deserialization.
#[derive(Debug, Clone, Copy, serde::Deserialize, schemars::JsonSchema)]
#[serde(transparent)]
pub(crate) struct SamplingRatio(f64);
impl Validate for SamplingRatio {
    /// Reports a validation error unless the wrapped ratio lies within the
    /// inclusive range `0.0..=1.0`.
    ///
    /// `NaN` is not contained in any range, so it is reported as well.
    fn validate(&self, mut errors: ErrorCollector<'_>) {
        let Self(ratio) = self;
        let allowed = 0.0_f64..=1.0;
        if allowed.contains(ratio) {
            return;
        }
        errors.report_simple("sampling ratio must be between 0.0 and 1.0");
    }
}
// Settings for `Sampler::ParentBased`. Every field is optional and holds a
// boxed `RootSampler`.
//
// NOTE(review): the `From<&Sampler> for SdkSampler` conversion in this file
// reads only `root`; the four `*_parent_*` fields are not consumed anywhere
// in the visible source. Confirm whether they are used elsewhere or are
// currently accepted but ignored.
#[configuration]
pub(crate) struct ParentBasedSamplerConfig {
// Sampler wrapped by `SdkSampler::ParentBased`; the conversion falls back to
// `SdkSampler::AlwaysOn` when this is `None`.
pub(crate) root: Option<Box<RootSampler>>,
// Not read by the conversion in this file (see note above).
pub(crate) remote_parent_sampled: Option<Box<RootSampler>>,
// Not read by the conversion in this file (see note above).
pub(crate) remote_parent_not_sampled: Option<Box<RootSampler>>,
// Not read by the conversion in this file (see note above).
pub(crate) local_parent_sampled: Option<Box<RootSampler>>,
// Not read by the conversion in this file (see note above).
pub(crate) local_parent_not_sampled: Option<Box<RootSampler>>,
}
// Samplers that may be nested inside `ParentBasedSamplerConfig`. This mirrors
// `Sampler` minus the `ParentBased` variant, so a parent-based sampler cannot
// contain another parent-based sampler.
#[configuration]
pub(crate) enum RootSampler {
// Maps to `SdkSampler::AlwaysOn`.
AlwaysOn,
// Maps to `SdkSampler::AlwaysOff`.
AlwaysOff,
// Maps to `SdkSampler::TraceIdRatioBased` using the configured ratio.
TraceIdRatioBased(TraceIdRatioBasedSamplerConfig),
}
use opentelemetry_sdk::trace::Sampler as SdkSampler;
impl From<&Sampler> for SdkSampler {
fn from(config: &Sampler) -> Self {
match config {
Sampler::AlwaysOn => SdkSampler::AlwaysOn,
Sampler::AlwaysOff => SdkSampler::AlwaysOff,
Sampler::TraceIdRatioBased(ratio_config) => {
SdkSampler::TraceIdRatioBased(ratio_config.ratio.0)
}
Sampler::ParentBased(parent_config) => {
let root = parent_config
.root
.as_ref()
.map(|s| SdkSampler::from(s.as_ref()))
.unwrap_or(SdkSampler::AlwaysOn);
SdkSampler::ParentBased(Box::new(root))
}
}
}
}
impl From<&RootSampler> for SdkSampler {
fn from(config: &RootSampler) -> Self {
match config {
RootSampler::AlwaysOn => SdkSampler::AlwaysOn,
RootSampler::AlwaysOff => SdkSampler::AlwaysOff,
RootSampler::TraceIdRatioBased(ratio_config) => {
SdkSampler::TraceIdRatioBased(ratio_config.ratio.0)
}
}
}
}
#[cfg(test)]
mod tests {
    //! Tests for parsing sampler configuration and converting it to the SDK
    //! sampler type.
    //!
    //! The YAML fixtures rely on indentation for nesting. `indoc!` removes
    //! only the common leading whitespace, so the relative indentation inside
    //! each fixture must be kept.

    use apollo_configuration::parse_yaml;

    use super::*;
    use crate::config::OpenTelemetryConfig;

    /// `always_on` converts to the SDK's `AlwaysOn` sampler.
    #[test]
    fn always_on_conversion() {
        let config: OpenTelemetryConfig = parse_yaml(
            indoc::indoc! {"
                tracer_provider:
                  sampler: always_on
            "},
            &Default::default(),
        )
        .unwrap();
        let sampler: SdkSampler = (&config.tracer_provider.sampler).into();
        assert!(matches!(sampler, SdkSampler::AlwaysOn));
    }

    /// `always_off` converts to the SDK's `AlwaysOff` sampler.
    #[test]
    fn always_off_conversion() {
        let config: OpenTelemetryConfig = parse_yaml(
            indoc::indoc! {"
                tracer_provider:
                  sampler: always_off
            "},
            &Default::default(),
        )
        .unwrap();
        let sampler: SdkSampler = (&config.tracer_provider.sampler).into();
        assert!(matches!(sampler, SdkSampler::AlwaysOff));
    }

    /// The configured ratio is passed through unchanged.
    #[test]
    fn trace_id_ratio_conversion() {
        let config: OpenTelemetryConfig = parse_yaml(
            indoc::indoc! {"
                tracer_provider:
                  sampler:
                    trace_id_ratio_based:
                      ratio: 0.5
            "},
            &Default::default(),
        )
        .unwrap();
        let sampler: SdkSampler = (&config.tracer_provider.sampler).into();
        let SdkSampler::TraceIdRatioBased(ratio) = sampler else {
            panic!("Expected TraceIdRatioBased sampler");
        };
        assert!((ratio - 0.5).abs() < f64::EPSILON);
    }

    /// A parent-based sampler wraps the sampler built from `root`.
    #[test]
    fn parent_based_conversion() {
        let config: OpenTelemetryConfig = parse_yaml(
            indoc::indoc! {"
                tracer_provider:
                  sampler:
                    parent_based:
                      root:
                        trace_id_ratio_based:
                          ratio: 0.1
            "},
            &Default::default(),
        )
        .unwrap();
        let sampler: SdkSampler = (&config.tracer_provider.sampler).into();
        let SdkSampler::ParentBased(root) = sampler else {
            panic!("Expected ParentBased sampler");
        };
        // The wrapped sampler is a trait object, so inspect it through its
        // `Debug` output to make sure `root` was not silently dropped.
        let root_debug = format!("{root:?}");
        assert!(
            root_debug.contains("TraceIdRatioBased"),
            "unexpected root sampler: {root_debug}"
        );
    }

    /// The unit-variant form of the sampler parses successfully.
    #[test]
    fn parse_sampler_always_on() {
        let _config: OpenTelemetryConfig = parse_yaml(
            indoc::indoc! {"
                tracer_provider:
                  sampler: always_on
            "},
            &Default::default(),
        )
        .unwrap();
    }

    /// A ratio inside `0.0..=1.0` parses successfully.
    #[test]
    fn parse_sampler_ratio() {
        let _config: OpenTelemetryConfig = parse_yaml(
            indoc::indoc! {"
                tracer_provider:
                  sampler:
                    trace_id_ratio_based:
                      ratio: 0.5
            "},
            &Default::default(),
        )
        .unwrap();
    }

    /// A ratio above 1.0 is rejected with the `SamplingRatio` validation
    /// message.
    #[test]
    fn parse_sampler_ratio_invalid() {
        let result: Result<OpenTelemetryConfig, _> = parse_yaml(
            indoc::indoc! {"
                tracer_provider:
                  sampler:
                    trace_id_ratio_based:
                      ratio: 1.5
            "},
            &Default::default(),
        );
        assert!(result.is_err());
        let err = format!("{:?}", result.unwrap_err());
        assert!(
            err.contains("sampling ratio must be between 0.0 and 1.0"),
            "unexpected error: {err}"
        );
    }
}