use opentelemetry::{global, trace::TracerProvider as _, KeyValue};
use opentelemetry_otlp::WithExportConfig;
use opentelemetry_sdk::{
trace::{RandomIdGenerator, Sampler, SdkTracerProvider},
Resource,
};
use tracing_opentelemetry::OpenTelemetryLayer;
use tracing_subscriber::Registry;
use crate::config::TracingConfig;
use crate::error::{ObservabilityError, Result};
/// Handle that keeps the tracer provider alive for the caller.
///
/// When the guard is dropped, the wrapped provider (if any) is shut down.
/// `provider` is `None` when tracing is disabled, in which case dropping the
/// guard does nothing.
#[must_use = "dropping the guard shuts down the tracer provider; bind it to a variable that lives as long as tracing is needed"]
pub struct TracingGuard {
    /// Provider to shut down on drop; `None` when tracing is disabled.
    provider: Option<SdkTracerProvider>,
}
impl Drop for TracingGuard {
fn drop(&mut self) {
if let Some(ref provider) = self.provider {
if let Err(e) = provider.shutdown() {
tracing::warn!("Error shutting down tracer provider: {:?}", e);
}
}
}
}
pub fn init_tracing(config: &TracingConfig) -> Result<TracingGuard> {
if !config.enabled {
tracing::info!("OpenTelemetry tracing disabled");
return Ok(TracingGuard { provider: None });
}
let endpoint = config.otlp_endpoint.as_ref().ok_or_else(|| {
ObservabilityError::TracingInit(
"OTLP endpoint required when tracing is enabled".to_string(),
)
})?;
let exporter = opentelemetry_otlp::SpanExporter::builder()
.with_tonic()
.with_endpoint(endpoint)
.build()
.map_err(|e| ObservabilityError::TracingInit(e.to_string()))?;
let sampler = if config.sampling_ratio >= 1.0 {
Sampler::AlwaysOn
} else if config.sampling_ratio <= 0.0 {
Sampler::AlwaysOff
} else {
Sampler::TraceIdRatioBased(config.sampling_ratio)
};
let provider = SdkTracerProvider::builder()
.with_batch_exporter(exporter)
.with_sampler(sampler)
.with_id_generator(RandomIdGenerator::default())
.with_resource(
Resource::builder_empty()
.with_service_name(config.service_name.clone())
.with_attribute(KeyValue::new("service.version", env!("CARGO_PKG_VERSION")))
.build(),
)
.build();
global::set_tracer_provider(provider.clone());
tracing::info!(
endpoint = %endpoint,
service_name = %config.service_name,
sampling_ratio = config.sampling_ratio,
"OpenTelemetry tracing initialized"
);
Ok(TracingGuard {
provider: Some(provider),
})
}
pub fn create_otel_layer(
config: &TracingConfig,
) -> Result<Option<OpenTelemetryLayer<Registry, opentelemetry_sdk::trace::Tracer>>> {
if !config.enabled {
return Ok(None);
}
let endpoint = config.otlp_endpoint.as_ref().ok_or_else(|| {
ObservabilityError::TracingInit(
"OTLP endpoint required when tracing is enabled".to_string(),
)
})?;
let exporter = opentelemetry_otlp::SpanExporter::builder()
.with_tonic()
.with_endpoint(endpoint)
.build()
.map_err(|e| ObservabilityError::TracingInit(e.to_string()))?;
let sampler = if config.sampling_ratio >= 1.0 {
Sampler::AlwaysOn
} else if config.sampling_ratio <= 0.0 {
Sampler::AlwaysOff
} else {
Sampler::TraceIdRatioBased(config.sampling_ratio)
};
let provider = SdkTracerProvider::builder()
.with_batch_exporter(exporter)
.with_sampler(sampler)
.with_id_generator(RandomIdGenerator::default())
.with_resource(
Resource::builder_empty()
.with_service_name(config.service_name.clone())
.build(),
)
.build();
let tracer = provider.tracer("zlayer");
global::set_tracer_provider(provider);
Ok(Some(OpenTelemetryLayer::new(tracer)))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Config with tracing switched off; every other field is the default.
    fn disabled_config() -> TracingConfig {
        TracingConfig {
            enabled: false,
            ..Default::default()
        }
    }

    /// Config with tracing switched on but no OTLP endpoint set.
    fn enabled_config_without_endpoint() -> TracingConfig {
        TracingConfig {
            enabled: true,
            otlp_endpoint: None,
            ..Default::default()
        }
    }

    /// Disabled tracing yields a guard that holds no provider.
    #[test]
    fn test_disabled_tracing() {
        let cfg = disabled_config();
        let guard = init_tracing(&cfg).unwrap();
        assert!(guard.provider.is_none());
    }

    /// Enabling tracing without an endpoint is rejected by `init_tracing`.
    #[test]
    fn test_enabled_without_endpoint_fails() {
        let cfg = enabled_config_without_endpoint();
        let outcome = init_tracing(&cfg);
        assert!(outcome.is_err());
    }

    /// Disabled tracing produces no layer.
    #[test]
    fn test_create_layer_disabled() {
        let cfg = disabled_config();
        let maybe_layer = create_otel_layer(&cfg).unwrap();
        assert!(maybe_layer.is_none());
    }

    /// Enabling tracing without an endpoint is rejected by `create_otel_layer`.
    #[test]
    fn test_create_layer_without_endpoint_fails() {
        let cfg = enabled_config_without_endpoint();
        let outcome = create_otel_layer(&cfg);
        assert!(outcome.is_err());
    }
}