use anyhow::Result;
use tracing_subscriber::{layer::SubscriberExt, EnvFilter, Layer, Registry};
mod parse_env;
#[cfg(feature = "distributed-tracing")]
mod rialo_opentelemetry;
#[cfg(feature = "axum-headers")]
pub use opentelemetry::extract_and_set_trace_context_axum;
#[cfg(feature = "reqwest-headers")]
pub use opentelemetry::{apply_trace_headers_to_reqwest, inject_trace_headers};
#[cfg(feature = "env-context")]
pub use opentelemetry::{
extract_and_set_trace_context_env, extract_and_set_trace_context_from_env_map,
inject_trace_env, inject_trace_env_to_cmd,
};
#[cfg(feature = "distributed-tracing")]
pub use rialo_opentelemetry::{
clear_baggage, get_all_baggage, get_baggage, remove_baggage, set_baggage, OtlpConfig, Protocol,
Sampling, DEFAULT_OTLP_ENDPOINT,
};
#[cfg(feature = "prometheus")]
mod prometheus;
#[cfg(feature = "distributed-tracing")]
pub use opentelemetry::Context;
#[cfg(feature = "prometheus")]
pub use prometheus::{PrometheusConfig, DEFAULT_SPAN_LATENCY_BUCKETS};
#[cfg(feature = "distributed-tracing")]
pub use tracing_opentelemetry::OpenTelemetrySpanExt;
use crate::parse_env::parse_bool_env;
/// Owns the telemetry resources created by [`init_telemetry`].
///
/// With the `distributed-tracing` feature this holds the OpenTelemetry tracer
/// provider (when one was created). Dropping the handle calls
/// [`TelemetryHandle::shutdown`], so the handle must be kept alive for as long
/// as telemetry is wanted; discarding it right after initialization shuts the
/// tracer provider down immediately.
#[must_use = "telemetry is shut down when the TelemetryHandle is dropped; bind it to a variable that lives as long as telemetry is needed"]
pub struct TelemetryHandle {
    /// Tracer provider to shut down; `None` for an empty or already shut-down handle.
    #[cfg(feature = "distributed-tracing")]
    provider: Option<opentelemetry_sdk::trace::SdkTracerProvider>,
    /// Placeholder so the struct has a private field without the tracing feature.
    #[cfg(not(feature = "distributed-tracing"))]
    _marker: std::marker::PhantomData<()>,
}
impl Drop for TelemetryHandle {
    /// Shuts telemetry down when the handle goes out of scope.
    ///
    /// `drop` cannot return an error, so a failed shutdown is reported on
    /// stderr instead of being propagated.
    fn drop(&mut self) {
        self.shutdown()
            .unwrap_or_else(|err| eprintln!("Error shutting down telemetry: {}", err));
    }
}
impl TelemetryHandle {
    /// Wraps an initialized tracer provider so that it is shut down together
    /// with this handle.
    #[cfg(feature = "distributed-tracing")]
    pub(crate) fn new(provider: opentelemetry_sdk::trace::SdkTracerProvider) -> Self {
        Self {
            provider: Some(provider),
        }
    }

    /// Creates a handle that owns nothing; shutting it down is a no-op.
    pub(crate) fn empty() -> Self {
        Self {
            #[cfg(feature = "distributed-tracing")]
            provider: None,
            #[cfg(not(feature = "distributed-tracing"))]
            _marker: std::marker::PhantomData,
        }
    }

    /// Shuts down the owned tracer provider, if any.
    ///
    /// The provider is taken out of the handle on the first call, so later
    /// calls (including the one made from `Drop`) do nothing.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the tracer provider's own `shutdown`.
    pub fn shutdown(&mut self) -> Result<()> {
        #[cfg(feature = "distributed-tracing")]
        {
            if let Some(provider) = self.provider.take() {
                tracing::debug!("Shutting down SdkTracerProvider");
                // `provider` is dropped at the end of this scope; no explicit drop needed.
                provider.shutdown()?;
            }
        }
        Ok(())
    }
}
/// Settings consumed by [`init_telemetry`].
///
/// Start from [`TelemetryConfig::new`] (or `Default`) and adjust it with the
/// `with_*` builder methods.
#[derive(Clone, Debug)]
pub struct TelemetryConfig {
    /// OTLP exporter settings. `None` skips OpenTelemetry initialization; when
    /// set, its `enable_console` flag also decides whether console logs are
    /// emitted.
    #[cfg(feature = "distributed-tracing")]
    pub otlp: Option<rialo_opentelemetry::OtlpConfig>,
    /// Prometheus settings. `None` skips the span-latency layer.
    #[cfg(feature = "prometheus")]
    pub prometheus: Option<prometheus::PrometheusConfig>,
    /// Filter directive used when `RUST_LOG` is absent or invalid; `None`
    /// means `"info"`.
    pub log_level: Option<String>,
    /// Emit console logs as JSON (`true`) rather than human-readable text.
    pub json_log_output: bool,
}
impl Default for TelemetryConfig {
    /// Console-only configuration.
    ///
    /// The defaults are: no OTLP exporter, no Prometheus registry, and a log
    /// level of `"info"`. JSON output is taken from the `ENABLE_JSON_LOGS`
    /// environment variable via `parse_bool_env`, with `false` as the fallback
    /// value passed to it. Because of that environment read, the default is
    /// not a pure constant.
    fn default() -> Self {
        let json_log_output = parse_bool_env("ENABLE_JSON_LOGS", false);
        Self {
            #[cfg(feature = "distributed-tracing")]
            otlp: None,
            #[cfg(feature = "prometheus")]
            prometheus: None,
            log_level: Some(String::from("info")),
            json_log_output,
        }
    }
}
impl TelemetryConfig {
    /// Returns the default configuration; identical to `TelemetryConfig::default()`.
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Sets the fallback filter directive (for example `"debug"` or
    /// `"my_crate=trace"`).
    ///
    /// `init_telemetry` only uses it when `RUST_LOG` is absent or fails to parse.
    pub fn with_log_level(self, level: impl Into<String>) -> Self {
        Self {
            log_level: Some(level.into()),
            ..self
        }
    }

    /// Chooses JSON (`true`) or human-readable (`false`) console output.
    pub fn with_json_log_output(self, output: bool) -> Self {
        Self {
            json_log_output: output,
            ..self
        }
    }

    /// Enables Prometheus metrics with a config built by
    /// `PrometheusConfig::new(registry)`.
    #[cfg(feature = "prometheus")]
    pub fn with_prometheus_registry(self, registry: ::prometheus::Registry) -> Self {
        self.with_prometheus_config(prometheus::PrometheusConfig::new(registry))
    }

    /// Enables Prometheus metrics with an explicit configuration.
    #[cfg(feature = "prometheus")]
    pub fn with_prometheus_config(
        self,
        prometheus_config: prometheus::PrometheusConfig,
    ) -> Self {
        Self {
            prometheus: Some(prometheus_config),
            ..self
        }
    }

    /// Enables OTLP export with `OtlpConfig::default()`.
    #[cfg(feature = "distributed-tracing")]
    pub fn with_otlp(self) -> Self {
        self.with_otlp_config(rialo_opentelemetry::OtlpConfig::default())
    }

    /// Enables OTLP export with an explicit configuration.
    #[cfg(feature = "distributed-tracing")]
    pub fn with_otlp_config(self, otlp_config: rialo_opentelemetry::OtlpConfig) -> Self {
        Self {
            otlp: Some(otlp_config),
            ..self
        }
    }
}
pub async fn init_telemetry(config: TelemetryConfig) -> Result<TelemetryHandle> {
#[cfg(feature = "distributed-tracing")]
let otel_result = if let Some(ref otlp_config) = config.otlp {
rialo_opentelemetry::init_otel(otlp_config).await?
} else {
rialo_opentelemetry::OtelResult {
handle: TelemetryHandle::empty(),
tracer: None,
}
};
#[cfg(not(feature = "distributed-tracing"))]
let otel_result = {
struct NoOtelResult {
handle: TelemetryHandle,
}
NoOtelResult {
handle: TelemetryHandle::empty(),
}
};
#[cfg(feature = "prometheus")]
let span_latency_layer = if let Some(ref prometheus_config) = config.prometheus {
prometheus::init_prometheus(prometheus_config)?
} else {
None
};
let log_level = config.log_level.unwrap_or("info".to_string());
let env_filter =
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(log_level));
let registry = Registry::default().with(env_filter);
let enable_console = {
#[cfg(feature = "distributed-tracing")]
{
config.otlp.as_ref().is_none_or(|otlp| otlp.enable_console)
}
#[cfg(not(feature = "distributed-tracing"))]
{
true
}
};
match (
#[cfg(feature = "prometheus")]
span_latency_layer.is_some(),
#[cfg(not(feature = "prometheus"))]
false,
#[cfg(feature = "distributed-tracing")]
otel_result.tracer.is_some(),
#[cfg(not(feature = "distributed-tracing"))]
false,
enable_console,
) {
(true, true, true) => {
#[cfg(all(feature = "prometheus", feature = "distributed-tracing"))]
set_global_subscriber(
registry
.with(span_latency_layer.unwrap())
.with(tracing_opentelemetry::layer().with_tracer(otel_result.tracer.unwrap()))
.with(create_fmt_layer(config.json_log_output)),
)?;
}
(true, true, false) => {
#[cfg(all(feature = "prometheus", feature = "distributed-tracing"))]
set_global_subscriber(
registry
.with(span_latency_layer.unwrap())
.with(tracing_opentelemetry::layer().with_tracer(otel_result.tracer.unwrap())),
)?;
}
(true, false, true) => {
#[cfg(feature = "prometheus")]
set_global_subscriber(
registry
.with(span_latency_layer.unwrap())
.with(create_fmt_layer(config.json_log_output)),
)?;
}
(true, false, false) => {
#[cfg(feature = "prometheus")]
set_global_subscriber(registry.with(span_latency_layer.unwrap()))?;
}
(false, true, true) => {
#[cfg(feature = "distributed-tracing")]
set_global_subscriber(
registry
.with(tracing_opentelemetry::layer().with_tracer(otel_result.tracer.unwrap()))
.with(create_fmt_layer(config.json_log_output)),
)?;
}
(false, true, false) => {
#[cfg(feature = "distributed-tracing")]
set_global_subscriber(
registry
.with(tracing_opentelemetry::layer().with_tracer(otel_result.tracer.unwrap())),
)?;
}
(false, false, true) => {
set_global_subscriber(registry.with(create_fmt_layer(config.json_log_output)))?;
}
(false, false, false) => {
set_global_subscriber(registry)?;
}
}
let handle = otel_result.handle;
Ok(handle)
}
/// Builds the console `fmt` layer, boxed so that both output formats share one type.
///
/// * `json_log_output == true`: JSON lines with event fields flattened into
///   the top-level object, including the event target.
/// * otherwise: human-readable lines with target, thread id and line number.
fn create_fmt_layer<S>(
    json_log_output: bool,
) -> Box<dyn tracing_subscriber::Layer<S> + Send + Sync + 'static>
where
    S: tracing::Subscriber + for<'a> tracing_subscriber::registry::LookupSpan<'a>,
{
    let base = tracing_subscriber::fmt::layer();
    if !json_log_output {
        return base
            .with_target(true)
            .with_thread_ids(true)
            .with_line_number(true)
            .boxed();
    }
    base.json().flatten_event(true).with_target(true).boxed()
}
/// Installs `subscriber` as the process-wide default `tracing` subscriber.
///
/// # Errors
///
/// Fails if a global default has already been installed. The underlying
/// error's message is embedded in the returned error.
fn set_global_subscriber<S>(subscriber: S) -> Result<()>
where
    S: tracing::Subscriber + Send + Sync + 'static,
{
    match tracing::subscriber::set_global_default(subscriber) {
        Ok(()) => Ok(()),
        Err(err) => Err(anyhow::anyhow!("Failed to set global subscriber: {}", err)),
    }
}
#[cfg(test)]
mod tests {
    use std::env;

    use serial_test::serial;

    use super::*;

    /// Clears the standard OTLP endpoint variables so a value exported in the
    /// surrounding shell cannot leak into a test.
    fn clear_otlp_endpoint_env() {
        for var in ["OTEL_EXPORTER_OTLP_ENDPOINT", "OTEL_EXPORTER_OTLP_TRACES_ENDPOINT"] {
            env::remove_var(var);
        }
    }

    /// Calls [`init_telemetry`], treating "global subscriber already installed"
    /// as success. All tests share one process, so only the first call can
    /// install the global default subscriber.
    async fn init_telemetry_for_test(config: TelemetryConfig) -> Result<TelemetryHandle> {
        init_telemetry(config).await.or_else(|err| {
            let already_installed = err
                .to_string()
                .contains("global default trace dispatcher has already been set");
            if already_installed {
                Ok(TelemetryHandle::empty())
            } else {
                Err(err)
            }
        })
    }

    #[test]
    fn test_telemetry_config_builder() {
        #[cfg(feature = "distributed-tracing")]
        {
            let with_otlp = TelemetryConfig::new().with_otlp();
            assert!(with_otlp.otlp.is_some());
        }
        #[cfg(feature = "prometheus")]
        {
            let with_prometheus =
                TelemetryConfig::new().with_prometheus_registry(::prometheus::Registry::new());
            assert!(with_prometheus.prometheus.is_some());
            let prometheus_config = with_prometheus.prometheus.unwrap();
            // Values expected from `PrometheusConfig::new`.
            assert_eq!(prometheus_config.span_latency_buckets, 15);
            assert!(prometheus_config.enable_span_latency);
        }
    }

    #[tokio::test]
    #[serial]
    async fn test_init_telemetry_console_only() {
        clear_otlp_endpoint_env();
        let outcome = init_telemetry_for_test(TelemetryConfig::new()).await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    #[serial]
    #[cfg(feature = "distributed-tracing")]
    async fn test_init_telemetry_with_otlp() {
        clear_otlp_endpoint_env();
        // Only initialization is asserted; the test does not check that spans
        // reach this endpoint.
        let otlp_config = rialo_opentelemetry::OtlpConfig::new()
            .with_service_name("test-service")
            .with_exporter_endpoint("http://localhost:9999")
            .with_console_enabled(true);
        let outcome =
            init_telemetry_for_test(TelemetryConfig::new().with_otlp_config(otlp_config)).await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    #[serial]
    #[cfg(all(feature = "distributed-tracing", feature = "env-context"))]
    async fn test_init_telemetry_auto_extracts_env_context() {
        const TRACE_VARS: [&str; 2] = ["traceparent", "tracestate"];
        for var in TRACE_VARS {
            env::remove_var(var);
        }
        env::set_var("traceparent", "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01");
        env::set_var("tracestate", "rojo=00f067aa0ba902b7");

        let otlp_config = rialo_opentelemetry::OtlpConfig::new()
            .with_service_name("test-auto-extract")
            .with_exporter_endpoint("".to_string())
            .with_traces_enabled(true);
        let outcome =
            init_telemetry_for_test(TelemetryConfig::new().with_otlp_config(otlp_config)).await;
        // NOTE(review): this only checks that initialization succeeds. It does
        // not assert that the `traceparent` context was actually extracted.
        assert!(outcome.is_ok());

        for var in TRACE_VARS {
            env::remove_var(var);
        }
    }
}