use std::str::FromStr;
use std::sync::OnceLock;
use opentelemetry::trace::TracerProvider as _;
use opentelemetry_otlp::WithTonicConfig;
use opentelemetry_sdk::propagation::TraceContextPropagator;
use opentelemetry_sdk::trace::SdkTracerProvider;
use tracing::subscriber::Subscriber;
use tracing_opentelemetry::OpenTelemetryLayer;
use tracing_subscriber::layer::{Filter, SubscriberExt};
use tracing_subscriber::{Layer, Registry};
use crate::tracing::OpenTelemetrySpanExt;
/// Process-wide handle to the tracer provider created by `setup_tracing`.
///
/// Read by the panic hook installed in `setup_tracing` so it can flush pending spans before the
/// default hook runs. Being a `OnceLock`, it can be set at most once.
static TRACER_PROVIDER: OnceLock<SdkTracerProvider> = OnceLock::new();
/// Selects whether spans are additionally exported through OpenTelemetry.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum OpenTelemetry {
    /// Export spans through an OTLP exporter in addition to stdout logging.
    Enabled,
    /// Log to stdout only.
    Disabled,
}

impl OpenTelemetry {
    /// Returns `true` when OpenTelemetry export is requested.
    fn is_enabled(self) -> bool {
        matches!(self, OpenTelemetry::Enabled)
    }
}
/// Shuts down the OpenTelemetry tracer provider when dropped.
///
/// Returned by `setup_tracing` when OpenTelemetry is enabled. Keep it alive for as long as spans
/// should be exported: dropping it calls `SdkTracerProvider::shutdown`.
#[must_use = "dropping the guard immediately shuts down the tracer provider"]
pub struct OtelGuard {
    tracer_provider: SdkTracerProvider,
}

impl Drop for OtelGuard {
    fn drop(&mut self) {
        // `Drop` cannot return errors, and the tracing pipeline is being torn down, so report the
        // failure on stderr with enough context to identify where it came from.
        if let Err(err) = self.tracer_provider.shutdown() {
            eprintln!("Failed to shut down tracer provider: {err:?}");
        }
    }
}
pub fn setup_tracing(otel: OpenTelemetry) -> anyhow::Result<Option<OtelGuard>> {
if otel.is_enabled() {
opentelemetry::global::set_text_map_propagator(TraceContextPropagator::new());
}
let tracer_provider = if otel.is_enabled() {
let provider = init_tracer_provider()?;
TRACER_PROVIDER
.set(provider.clone())
.expect("setup_tracing should only be called once");
Some(provider)
} else {
None
};
let otel_layer = tracer_provider.as_ref().map(|provider| {
OpenTelemetryLayer::new(provider.tracer("tracing-otel-subscriber")).boxed()
});
let subscriber = Registry::default()
.with(stdout_layer().with_filter(env_or_default_filter()))
.with(otel_layer.with_filter(env_or_default_filter()));
tracing::subscriber::set_global_default(subscriber).map_err(Into::<anyhow::Error>::into)?;
let default_hook = std::panic::take_hook();
std::panic::set_hook(Box::new(move |info| {
tracing::error!(panic = true, info = %info, "panic");
let info_str = info.to_string();
let wrapped = anyhow::Error::msg(info_str);
tracing::Span::current().set_error(wrapped.as_ref());
if let Some(provider) = TRACER_PROVIDER.get() {
if let Err(err) = provider.force_flush() {
eprintln!("Failed to flush traces on panic: {err:?}");
}
}
default_hook(info);
}));
Ok(tracer_provider.map(|tracer_provider| OtelGuard { tracer_provider }))
}
fn init_tracer_provider() -> anyhow::Result<SdkTracerProvider> {
let exporter = opentelemetry_otlp::SpanExporter::builder()
.with_tonic()
.with_tls_config(tonic::transport::ClientTlsConfig::new().with_native_roots())
.build()?;
Ok(opentelemetry_sdk::trace::SdkTracerProvider::builder()
.with_batch_exporter(exporter)
.build())
}
/// Installs a global subscriber whose OpenTelemetry layer exports into the SDK's tokio test
/// exporter. It does this instead of sending spans to a real collector.
///
/// Returns two receivers from `new_tokio_test_exporter`:
/// - exported span data, and
/// - the exporter's shutdown notifications (presumably signalled when the exporter shuts down).
///
/// # Errors
///
/// Fails if a global subscriber is already installed.
#[cfg(feature = "testing")]
pub fn setup_test_tracing() -> anyhow::Result<(
    tokio::sync::mpsc::UnboundedReceiver<opentelemetry_sdk::trace::SpanData>,
    tokio::sync::mpsc::UnboundedReceiver<()>,
)> {
    use opentelemetry_sdk::testing::trace::new_tokio_test_exporter;

    let (test_exporter, spans_rx, shutdown_rx) = new_tokio_test_exporter();
    let provider = SdkTracerProvider::builder()
        .with_batch_exporter(test_exporter)
        .build();
    let tracer = provider.tracer("tracing-otel-subscriber");
    let otel_layer = OpenTelemetryLayer::new(tracer).boxed();

    let stdout = stdout_layer().with_filter(env_or_default_filter());
    let otel = otel_layer.with_filter(env_or_default_filter());
    tracing::subscriber::set_global_default(Registry::default().with(stdout).with(otel))?;

    Ok((spans_rx, shutdown_rx))
}
/// Builds the stdout `fmt` layer used when the `tracing-forest` feature is disabled.
///
/// Output includes level, source file, line number and target. An extra event is emitted when
/// each span closes (`FmtSpan::CLOSE`).
#[cfg(not(feature = "tracing-forest"))]
fn stdout_layer<S>() -> Box<dyn tracing_subscriber::Layer<S> + Send + Sync + 'static>
where
S: Subscriber,
for<'a> S: tracing_subscriber::registry::LookupSpan<'a>,
{
use tracing_subscriber::fmt::format::FmtSpan;
tracing_subscriber::fmt::layer()
// NOTE(review): `.compact()` replaces the event format selected by `.pretty()`, so events use
// the compact format. Presumably `.pretty()` is kept for its field formatter. Confirm this
// combination is intentional before reordering or removing either call.
.pretty()
.compact()
.with_level(true)
.with_file(true)
.with_line_number(true)
.with_target(true)
.with_span_events(FmtSpan::CLOSE)
.boxed()
}
/// Builds the stdout layer used when the `tracing-forest` feature is enabled.
///
/// Rendering is delegated entirely to `tracing_forest::ForestLayer` with its default settings.
#[cfg(feature = "tracing-forest")]
fn stdout_layer<S>() -> Box<dyn tracing_subscriber::Layer<S> + Send + Sync + 'static>
where
    S: Subscriber + for<'a> tracing_subscriber::registry::LookupSpan<'a>,
{
    let forest_layer = tracing_forest::ForestLayer::default();
    forest_layer.boxed()
}
/// Builds a per-layer filter from `RUST_LOG`, falling back to a built-in default.
///
/// When `RUST_LOG` is unset, everything is filtered at `INFO`, except `axum::rejection`, which is
/// enabled down to `TRACE`.
///
/// # Panics
///
/// Panics if `RUST_LOG` is set but is not valid unicode or is not a valid `EnvFilter` directive.
fn env_or_default_filter<S>() -> Box<dyn Filter<S> + Send + Sync + 'static> {
    use std::env::VarError;
    use tracing::level_filters::LevelFilter;
    use tracing_subscriber::filter::{FilterExt, Targets};
    use tracing_subscriber::EnvFilter;

    let directives = match std::env::var(EnvFilter::DEFAULT_ENV) {
        Ok(value) => value,
        Err(VarError::NotUnicode(_)) => panic!("RUST_LOG contained non-unicode"),
        Err(VarError::NotPresent) => {
            let defaults = Targets::new()
                .with_default(LevelFilter::INFO)
                .with_target("axum::rejection", LevelFilter::TRACE);
            return FilterExt::boxed(defaults);
        }
    };
    let parsed = EnvFilter::from_str(&directives)
        .expect("RUST_LOG should contain a valid filter configuration");
    FilterExt::boxed(parsed)
}