1use anyhow::Result;
12use tracing_subscriber::{layer::SubscriberExt, EnvFilter, Layer, Registry};
13mod parse_env;
14
15#[cfg(feature = "distributed-tracing")]
16mod rialo_opentelemetry;
17#[cfg(feature = "distributed-tracing")]
18pub use rialo_opentelemetry::{
19 apply_trace_headers_to_reqwest, extract_and_set_trace_context_axum,
20 extract_and_set_trace_context_env, extract_and_set_trace_context_from_env_map, get_all_baggage,
21 get_baggage, inject_trace_env, inject_trace_env_to_cmd, inject_trace_headers, OtlpConfig,
22 Protocol, Sampling, DEFAULT_OTLP_ENDPOINT,
23};
24
25#[cfg(feature = "prometheus")]
26mod prometheus;
27
28#[cfg(feature = "distributed-tracing")]
29pub use opentelemetry::Context;
30#[cfg(feature = "prometheus")]
31pub use prometheus::{PrometheusConfig, DEFAULT_SPAN_LATENCY_BUCKETS};
32#[cfg(feature = "distributed-tracing")]
33pub use tracing_opentelemetry::OpenTelemetrySpanExt;
34
35use crate::parse_env::parse_bool_env;
36
/// Owns the telemetry resources created by [`init_telemetry`].
///
/// Dropping the handle calls [`TelemetryHandle::shutdown`], which (with the
/// `distributed-tracing` feature) shuts down the OpenTelemetry tracer provider.
/// Keep the handle alive for as long as telemetry should be exported.
#[must_use = "dropping the TelemetryHandle immediately shuts down telemetry export"]
pub struct TelemetryHandle {
    /// Tracer provider produced by OTLP initialization. `None` when OTLP is not
    /// configured, or once [`TelemetryHandle::shutdown`] has taken it.
    #[cfg(feature = "distributed-tracing")]
    provider: Option<opentelemetry_sdk::trace::SdkTracerProvider>,
    /// Placeholder field for builds without `distributed-tracing`. Because it is
    /// private, it also prevents construction outside this crate.
    #[cfg(not(feature = "distributed-tracing"))]
    _marker: std::marker::PhantomData<()>,
}
44
45impl Drop for TelemetryHandle {
46 fn drop(&mut self) {
47 if let Err(e) = self.shutdown() {
48 eprintln!("Error shutting down telemetry: {}", e);
49 }
50 }
51}
52
53impl TelemetryHandle {
54 #[cfg(feature = "distributed-tracing")]
56 pub(crate) fn new(provider: opentelemetry_sdk::trace::SdkTracerProvider) -> Self {
57 Self {
58 provider: Some(provider),
59 }
60 }
61
62 pub(crate) fn empty() -> Self {
64 Self {
65 #[cfg(feature = "distributed-tracing")]
66 provider: None,
67 #[cfg(not(feature = "distributed-tracing"))]
68 _marker: std::marker::PhantomData,
69 }
70 }
71
72 #[allow(unused_mut)]
74 pub fn shutdown(&mut self) -> Result<()> {
75 #[cfg(feature = "distributed-tracing")]
76 {
77 if let Some(provider) = self.provider.take() {
78 tracing::debug!("Shutting down SdkTracerProvider");
79 provider.shutdown()?;
80 drop(provider);
81 }
82 }
83 Ok(())
84 }
85}
86
/// Configuration consumed by [`init_telemetry`].
///
/// Start from [`TelemetryConfig::new`] (same as `Default::default()`) and adjust it
/// with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct TelemetryConfig {
    /// OTLP trace-export settings. `None` skips OpenTelemetry initialization, so no
    /// OpenTelemetry layer is installed.
    #[cfg(feature = "distributed-tracing")]
    pub otlp: Option<rialo_opentelemetry::OtlpConfig>,
    /// Prometheus settings. `None` skips Prometheus initialization, so no
    /// span-latency layer is installed.
    #[cfg(feature = "prometheus")]
    pub prometheus: Option<prometheus::PrometheusConfig>,
    /// Fallback filter directive, used only when `RUST_LOG` is unset or invalid.
    /// `None` is treated as `"info"`.
    pub log_level: Option<String>,
    /// Emit console logs as flattened JSON instead of human-readable text.
    /// The default is read from the `ENABLE_JSON_LOGS` environment variable.
    pub json_log_output: bool,
}
105
106impl Default for TelemetryConfig {
107 fn default() -> Self {
108 Self {
109 #[cfg(feature = "distributed-tracing")]
110 otlp: None, #[cfg(feature = "prometheus")]
112 prometheus: None, log_level: Some("info".to_string()),
114 json_log_output: parse_bool_env("ENABLE_JSON_LOGS", false),
115 }
116 }
117}
118
119impl TelemetryConfig {
120 pub fn new() -> Self {
122 Self::default()
123 }
124
125 pub fn with_log_level(mut self, level: impl Into<String>) -> Self {
127 self.log_level = Some(level.into());
128 self
129 }
130
131 pub fn with_json_log_output(mut self, output: bool) -> Self {
132 self.json_log_output = output;
133 self
134 }
135
136 #[cfg(feature = "prometheus")]
138 pub fn with_prometheus_registry(mut self, registry: ::prometheus::Registry) -> Self {
139 self.prometheus = Some(prometheus::PrometheusConfig::new(registry));
140 self
141 }
142
143 #[cfg(feature = "prometheus")]
145 pub fn with_prometheus_config(
146 mut self,
147 prometheus_config: prometheus::PrometheusConfig,
148 ) -> Self {
149 self.prometheus = Some(prometheus_config);
150 self
151 }
152
153 #[cfg(feature = "distributed-tracing")]
155 pub fn with_otlp(mut self) -> Self {
156 self.otlp = Some(rialo_opentelemetry::OtlpConfig::default());
157 self
158 }
159
160 #[cfg(feature = "distributed-tracing")]
162 pub fn with_otlp_config(mut self, otlp_config: rialo_opentelemetry::OtlpConfig) -> Self {
163 self.otlp = Some(otlp_config);
164 self
165 }
166}
167
168pub async fn init_telemetry(config: TelemetryConfig) -> Result<TelemetryHandle> {
197 #[cfg(feature = "distributed-tracing")]
199 let otel_result = if let Some(ref otlp_config) = config.otlp {
200 rialo_opentelemetry::init_otel(otlp_config).await?
201 } else {
202 rialo_opentelemetry::OtelResult {
203 handle: TelemetryHandle::empty(),
204 tracer: None,
205 }
206 };
207
208 #[cfg(not(feature = "distributed-tracing"))]
209 let otel_result = {
210 struct NoOtelResult {
211 handle: TelemetryHandle,
212 }
213 NoOtelResult {
214 handle: TelemetryHandle::empty(),
215 }
216 };
217
218 #[cfg(feature = "prometheus")]
220 let span_latency_layer = if let Some(ref prometheus_config) = config.prometheus {
221 prometheus::init_prometheus(prometheus_config)?
222 } else {
223 None
224 };
225
226 let log_level = config.log_level.unwrap_or("info".to_string());
228
229 let env_filter =
230 EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(log_level));
231
232 let registry = Registry::default().with(env_filter);
233
234 let enable_console = {
236 #[cfg(feature = "distributed-tracing")]
237 {
238 config.otlp.as_ref().is_none_or(|otlp| otlp.enable_console)
239 }
240 #[cfg(not(feature = "distributed-tracing"))]
241 {
242 true
243 }
244 };
245
246 match (
248 #[cfg(feature = "prometheus")]
249 span_latency_layer.is_some(),
250 #[cfg(not(feature = "prometheus"))]
251 false,
252 #[cfg(feature = "distributed-tracing")]
253 otel_result.tracer.is_some(),
254 #[cfg(not(feature = "distributed-tracing"))]
255 false,
256 enable_console,
257 ) {
258 (true, true, true) => {
259 #[cfg(all(feature = "prometheus", feature = "distributed-tracing"))]
260 set_global_subscriber(
261 registry
262 .with(span_latency_layer.unwrap())
263 .with(tracing_opentelemetry::layer().with_tracer(otel_result.tracer.unwrap()))
264 .with(create_fmt_layer(config.json_log_output)),
265 )?;
266 }
267 (true, true, false) => {
268 #[cfg(all(feature = "prometheus", feature = "distributed-tracing"))]
269 set_global_subscriber(
270 registry
271 .with(span_latency_layer.unwrap())
272 .with(tracing_opentelemetry::layer().with_tracer(otel_result.tracer.unwrap())),
273 )?;
274 }
275 (true, false, true) => {
276 #[cfg(feature = "prometheus")]
277 set_global_subscriber(
278 registry
279 .with(span_latency_layer.unwrap())
280 .with(create_fmt_layer(config.json_log_output)),
281 )?;
282 }
283 (true, false, false) => {
284 #[cfg(feature = "prometheus")]
285 set_global_subscriber(registry.with(span_latency_layer.unwrap()))?;
286 }
287 (false, true, true) => {
288 #[cfg(feature = "distributed-tracing")]
289 set_global_subscriber(
290 registry
291 .with(tracing_opentelemetry::layer().with_tracer(otel_result.tracer.unwrap()))
292 .with(create_fmt_layer(config.json_log_output)),
293 )?;
294 }
295 (false, true, false) => {
296 #[cfg(feature = "distributed-tracing")]
297 set_global_subscriber(
298 registry
299 .with(tracing_opentelemetry::layer().with_tracer(otel_result.tracer.unwrap())),
300 )?;
301 }
302 (false, false, true) => {
303 set_global_subscriber(registry.with(create_fmt_layer(config.json_log_output)))?;
304 }
305 (false, false, false) => {
306 set_global_subscriber(registry)?;
307 }
308 }
309 let handle = otel_result.handle;
310
311 Ok(handle)
312}
313
314fn create_fmt_layer<S>(
316 json_log_output: bool,
317) -> Box<dyn tracing_subscriber::Layer<S> + Send + Sync + 'static>
318where
319 S: tracing::Subscriber + for<'a> tracing_subscriber::registry::LookupSpan<'a>,
320{
321 if json_log_output {
322 tracing_subscriber::fmt::layer()
323 .json()
324 .flatten_event(true)
325 .with_target(true)
326 .boxed()
327 } else {
328 tracing_subscriber::fmt::layer()
329 .with_target(true)
330 .with_thread_ids(true)
331 .with_line_number(true)
332 .boxed()
333 }
334}
335
336fn set_global_subscriber<S>(subscriber: S) -> Result<()>
338where
339 S: tracing::Subscriber + Send + Sync + 'static,
340{
341 tracing::subscriber::set_global_default(subscriber)
342 .map_err(|e| anyhow::anyhow!("Failed to set global subscriber: {}", e))
343}
344
#[cfg(test)]
mod tests {
    use std::env;

    use serial_test::serial;

    use super::*;

    /// Runs [`init_telemetry`], tolerating a global subscriber that another test
    /// already installed.
    ///
    /// Only one global default subscriber can exist per process. Without this,
    /// every test after the first would fail with "already been set".
    async fn init_telemetry_for_test(config: TelemetryConfig) -> Result<TelemetryHandle> {
        init_telemetry(config).await.or_else(|e| {
            let already_set = e
                .to_string()
                .contains("global default trace dispatcher has already been set");
            if already_set {
                Ok(TelemetryHandle::empty())
            } else {
                Err(e)
            }
        })
    }

    /// Fails the test with the full error chain when initialization failed.
    /// A bare `assert!(result.is_ok())` would hide the cause.
    fn assert_initialized(result: &Result<TelemetryHandle>) {
        if let Err(e) = result {
            panic!("init_telemetry failed: {e:#}");
        }
    }

    #[test]
    fn test_telemetry_config_builder() {
        #[cfg(feature = "distributed-tracing")]
        {
            let config = TelemetryConfig::new().with_otlp();
            assert!(config.otlp.is_some());
        }

        #[cfg(feature = "prometheus")]
        {
            let registry = ::prometheus::Registry::new();
            let config = TelemetryConfig::new().with_prometheus_registry(registry);
            assert!(config.prometheus.is_some());
            let prometheus_config = config.prometheus.unwrap();
            assert_eq!(prometheus_config.span_latency_buckets, 15);
            assert!(prometheus_config.enable_span_latency);
        }
    }

    #[tokio::test]
    #[serial]
    async fn test_init_telemetry_console_only() {
        env::remove_var("OTEL_EXPORTER_OTLP_ENDPOINT");
        env::remove_var("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT");

        let config = TelemetryConfig::new();

        let result = init_telemetry_for_test(config).await;
        assert_initialized(&result);
    }

    #[tokio::test]
    #[serial]
    #[cfg(feature = "distributed-tracing")]
    async fn test_init_telemetry_with_otlp() {
        env::remove_var("OTEL_EXPORTER_OTLP_ENDPOINT");
        env::remove_var("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT");

        let otlp_config = rialo_opentelemetry::OtlpConfig::new()
            .with_service_name("test-service")
            .with_exporter_endpoint("http://localhost:9999")
            .with_console_enabled(true);

        let config = TelemetryConfig::new().with_otlp_config(otlp_config);

        let result = init_telemetry_for_test(config).await;
        assert_initialized(&result);
    }

    #[tokio::test]
    #[serial]
    #[cfg(feature = "distributed-tracing")]
    async fn test_init_telemetry_auto_extracts_env_context() {
        env::remove_var("traceparent");
        env::remove_var("tracestate");

        env::set_var(
            "traceparent",
            "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01",
        );
        env::set_var("tracestate", "rojo=00f067aa0ba902b7");

        let otlp_config = rialo_opentelemetry::OtlpConfig::new()
            .with_service_name("test-auto-extract")
            .with_exporter_endpoint("".to_string())
            .with_traces_enabled(true);

        let config = TelemetryConfig::new().with_otlp_config(otlp_config);

        let result = init_telemetry_for_test(config).await;

        // Clean up before asserting, so a failure does not leak the trace context
        // into later tests.
        env::remove_var("traceparent");
        env::remove_var("tracestate");

        assert_initialized(&result);
    }
}