use std::collections::HashMap;
use std::sync::OnceLock;
use std::time::Duration;
use opentelemetry::KeyValue;
use opentelemetry::trace::TracerProvider;
use opentelemetry_otlp::WithExportConfig;
use opentelemetry_sdk::Resource;
use opentelemetry_sdk::trace::{RandomIdGenerator, Sampler, SdkTracerProvider};
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
/// Process-wide tracer provider, stored by `init_tracing_with_config` and
/// read by `shutdown_tracing`. Being a `OnceLock`, it can be set at most once.
static TRACER_PROVIDER: OnceLock<SdkTracerProvider> = OnceLock::new();
/// Settings consumed by [`init_tracing_with_config`].
///
/// Start from [`TracingConfig::new`] and adjust with the `with_*` builder
/// methods; the fields are also public for direct construction.
#[derive(Debug, Clone)]
pub struct TracingConfig {
    /// Recorded as the `service.name` resource attribute.
    pub service_name: String,
    /// OTLP collector endpoint handed to the gRPC (tonic) span exporter.
    pub endpoint: String,
    /// Trace sampling ratio. [`TracingConfig::with_sampling_ratio`] keeps it
    /// within `[0.0, 1.0]`; values assigned directly to the field are not
    /// validated.
    pub sampling_ratio: f64,
    /// Extra key/value pairs added to the OpenTelemetry resource.
    pub resource_attributes: HashMap<String, String>,
    /// Timeout passed to the OTLP exporter builder.
    pub export_timeout: Duration,
    /// When `true`, a console `fmt` layer is installed next to the
    /// OpenTelemetry layer.
    pub console_output: bool,
}

impl TracingConfig {
    /// Creates a config for `service_name` with the defaults: endpoint
    /// `http://localhost:4317`, sampling ratio `1.0`, no extra resource
    /// attributes, a 30-second export timeout, and console output disabled.
    #[must_use]
    pub fn new(service_name: impl Into<String>) -> Self {
        Self {
            service_name: service_name.into(),
            endpoint: "http://localhost:4317".to_string(),
            sampling_ratio: 1.0,
            resource_attributes: HashMap::new(),
            export_timeout: Duration::from_secs(30),
            console_output: false,
        }
    }

    /// Sets the OTLP collector endpoint.
    #[must_use]
    pub fn with_endpoint(mut self, endpoint: impl Into<String>) -> Self {
        self.endpoint = endpoint.into();
        self
    }

    /// Sets the trace sampling ratio.
    ///
    /// Finite values are clamped to `[0.0, 1.0]` (infinities saturate to the
    /// nearest bound) and `NaN` is stored as `0.0`, so the stored ratio is
    /// always a valid probability.
    #[must_use]
    pub const fn with_sampling_ratio(mut self, ratio: f64) -> Self {
        // `f64::clamp` returns NaN for a NaN input, which would silently store
        // an out-of-range ratio; map it to the lower bound explicitly.
        self.sampling_ratio = if ratio.is_nan() {
            0.0
        } else {
            ratio.clamp(0.0, 1.0)
        };
        self
    }

    /// Adds a resource attribute; a later call with the same key replaces the
    /// earlier value.
    #[must_use]
    pub fn with_resource_attribute(
        mut self,
        key: impl Into<String>,
        value: impl Into<String>,
    ) -> Self {
        self.resource_attributes.insert(key.into(), value.into());
        self
    }

    /// Sets the timeout passed to the OTLP exporter.
    #[must_use]
    pub const fn with_export_timeout(mut self, timeout: Duration) -> Self {
        self.export_timeout = timeout;
        self
    }

    /// Enables or disables the additional console (`fmt`) output layer.
    #[must_use]
    pub const fn with_console_output(mut self, enabled: bool) -> Self {
        self.console_output = enabled;
        self
    }
}
/// Errors that can occur while setting up tracing.
///
/// Returned by [`init_tracing_with_config`] (and therefore by [`init_tracing`]).
#[derive(Debug, thiserror::Error)]
pub enum TracingError {
/// An OpenTelemetry SDK trace error; `#[from]` lets `?` convert it.
///
/// NOTE(review): no code visible in this file produces this variant — confirm it is still reachable.
#[error("OpenTelemetry trace error: {0}")]
Trace(#[from] opentelemetry_sdk::trace::TraceError),
/// Building the OTLP span exporter failed.
#[error("OTLP exporter build error: {0}")]
ExporterBuild(#[from] opentelemetry_otlp::ExporterBuildError),
/// Tracing was already initialized: the process-wide tracer provider slot
/// can only be filled once.
#[error("Tracing already initialized")]
AlreadyInitialized,
/// Installing the global `tracing` subscriber failed (typically because a
/// global subscriber is already installed). The underlying cause is not kept.
#[error("Failed to initialize tracing subscriber")]
SubscriberInit,
}
/// Initializes tracing for `service_name`, exporting spans to `endpoint`.
/// All other [`TracingConfig`] settings keep their defaults.
///
/// # Errors
///
/// Returns whatever [`init_tracing_with_config`] returns.
pub fn init_tracing(service_name: &str, endpoint: &str) -> Result<(), TracingError> {
    init_tracing_with_config(&TracingConfig::new(service_name).with_endpoint(endpoint))
}
/// Installs an OTLP (gRPC via tonic) span exporter and a global `tracing`
/// subscriber configured from `config`.
///
/// The resource carries `service.name` plus every entry of
/// `config.resource_attributes`. Spans go through a batch exporter, are
/// sampled with [`Sampler::TraceIdRatioBased`] using `config.sampling_ratio`,
/// and are bridged from `tracing` by a `tracing_opentelemetry` layer. When
/// `config.console_output` is set, a `fmt` layer is installed as well.
///
/// On success the provider is stored process-wide so that
/// [`shutdown_tracing`] can reach it.
///
/// # Errors
///
/// - [`TracingError::AlreadyInitialized`] if a previous call already succeeded.
/// - [`TracingError::ExporterBuild`] if the OTLP exporter cannot be built.
/// - [`TracingError::SubscriberInit`] if a global `tracing` subscriber is
///   already installed. In that case nothing is stored, so the state is left
///   as it was before the call.
pub fn init_tracing_with_config(config: &TracingConfig) -> Result<(), TracingError> {
    // Fail fast: building the exporter and provider is real work that would
    // be thrown away if the slot is already taken.
    if TRACER_PROVIDER.get().is_some() {
        return Err(TracingError::AlreadyInitialized);
    }

    let mut attributes = Vec::with_capacity(config.resource_attributes.len() + 1);
    attributes.push(KeyValue::new("service.name", config.service_name.clone()));
    attributes.extend(
        config
            .resource_attributes
            .iter()
            .map(|(key, value)| KeyValue::new(key.clone(), value.clone())),
    );
    let resource = Resource::builder().with_attributes(attributes).build();

    let exporter = opentelemetry_otlp::SpanExporter::builder()
        .with_tonic()
        .with_endpoint(&config.endpoint)
        .with_timeout(config.export_timeout)
        .build()?;

    let provider = SdkTracerProvider::builder()
        .with_batch_exporter(exporter)
        .with_sampler(Sampler::TraceIdRatioBased(config.sampling_ratio))
        .with_id_generator(RandomIdGenerator::default())
        .with_resource(resource)
        .build();

    let otel_layer = tracing_opentelemetry::layer().with_tracer(provider.tracer("rust-expect"));
    // `Option<L>` is itself a layer (`None` is a no-op), so a single registry
    // chain covers both the console and the OTLP-only configurations.
    let fmt_layer = config.console_output.then(|| {
        tracing_subscriber::fmt::layer()
            .with_target(true)
            .with_level(true)
    });

    tracing_subscriber::registry()
        .with(otel_layer)
        .with(fmt_layer)
        .try_init()
        .map_err(|_| TracingError::SubscriberInit)?;

    // Store the provider only once the subscriber is live. A failed install
    // therefore leaves no stored-but-unused provider behind, which would
    // otherwise turn every retry into `AlreadyInitialized`.
    TRACER_PROVIDER
        .set(provider)
        .map_err(|_| TracingError::AlreadyInitialized)
}
/// Flushes and shuts down the tracer provider registered by
/// [`init_tracing_with_config`]. Does nothing if tracing was never initialized.
///
/// Call this once, near process exit. The provider lives in a `static` and is
/// never dropped, so without an explicit shutdown its span processor and
/// exporter are never closed. Per the OpenTelemetry SDK, a shut-down provider
/// stops exporting, so spans created after this call are not sent.
///
/// Failures are ignored. This is a best-effort teardown with no caller to
/// report to, and calling it a second time is harmless.
pub fn shutdown_tracing() {
    let Some(provider) = TRACER_PROVIDER.get() else {
        return;
    };
    // Explicit flush first so already-finished spans are exported regardless
    // of how `shutdown` handles pending work.
    let _ = provider.force_flush();
    let _ = provider.shutdown();
}
/// Creates and enters an `info`-level `session` span with `session.id`,
/// `session.command` and `session.pid` fields, tagged `otel.kind = "client"`.
///
/// The span stays entered until the returned guard is dropped.
pub fn session_span(session_id: &str, command: &str, pid: u32) -> tracing::span::EnteredSpan {
    let span = tracing::info_span!(
        "session",
        session.id = session_id,
        session.command = command,
        session.pid = pid,
        otel.kind = "client"
    );
    span.entered()
}
/// Creates and enters an `info`-level `expect` span recording the
/// human-readable `description` and the `pattern` being waited for, tagged
/// `otel.kind = "internal"`.
///
/// The span stays entered until the returned guard is dropped.
pub fn expect_span(description: &str, pattern: &str) -> tracing::span::EnteredSpan {
    tracing::Span::entered(tracing::info_span!(
        "expect",
        expect.description = description,
        expect.pattern = pattern,
        otel.kind = "internal"
    ))
}
/// Produces the preview of `data` recorded on `send` spans.
///
/// Inputs of at most 50 bytes are returned unchanged. Longer inputs are cut
/// to the longest prefix of at most 47 bytes that ends on a `char` boundary,
/// followed by `"..."`. For ASCII input the preview is therefore at most
/// 50 bytes.
fn send_preview(data: &str) -> String {
    const MAX_BYTES: usize = 50;
    const PREFIX_BYTES: usize = 47;

    if data.len() <= MAX_BYTES {
        return data.to_string();
    }
    // Slicing at a fixed byte offset panics if it lands inside a multi-byte
    // UTF-8 sequence, so back off to the nearest preceding char boundary.
    // Index 0 is always a boundary, so the loop terminates.
    let mut end = PREFIX_BYTES;
    while !data.is_char_boundary(end) {
        end -= 1;
    }
    format!("{}...", &data[..end])
}

/// Creates and enters an `info`-level `send` span for `data`, tagged
/// `otel.kind = "internal"`.
///
/// `send.data` holds a shortened, UTF-8-safe preview of `data`, while
/// `send.bytes` holds its full length in bytes. The span stays entered until
/// the returned guard is dropped.
pub fn send_span(data: &str) -> tracing::span::EnteredSpan {
    let display_data = send_preview(data);
    tracing::info_span!(
        "send",
        send.data = display_data.as_str(),
        send.bytes = data.len(),
        otel.kind = "internal"
    )
    .entered()
}
/// Creates and enters an `info`-level `dialog` span carrying `dialog.name`
/// and the number of steps as `dialog.steps`, tagged `otel.kind = "internal"`.
///
/// The span stays entered until the returned guard is dropped.
pub fn dialog_span(dialog_name: &str, step_count: usize) -> tracing::span::EnteredSpan {
    let span = tracing::info_span!(
        "dialog",
        dialog.name = dialog_name,
        dialog.steps = step_count,
        otel.kind = "internal"
    );
    span.entered()
}
/// Creates and enters an `info`-level `transcript` span recording the owning
/// session (`transcript.session`) and output format (`transcript.format`),
/// tagged `otel.kind = "internal"`.
///
/// The span stays entered until the returned guard is dropped.
pub fn transcript_span(session_id: &str, format: &str) -> tracing::span::EnteredSpan {
    tracing::Span::entered(tracing::info_span!(
        "transcript",
        transcript.session = session_id,
        transcript.format = format,
        otel.kind = "internal"
    ))
}
/// Emits an `error`-level event describing `error`: `exception.type_` holds a
/// type name and `exception.message` holds the error's `Display` output.
pub fn record_error(error: &dyn std::error::Error) {
    // NOTE(review): `error` is a `&dyn Error`, so `type_name_of_val` is
    // instantiated with `dyn Error` and reports the trait-object name, never
    // the concrete error type. Recovering the concrete name would require a
    // generic signature, which would break callers that rely on deref
    // coercion (e.g. passing `&Box<dyn Error>`).
    let type_name = std::any::type_name_of_val(error);
    tracing::error!(
        exception.type_ = type_name,
        exception.message = %error
    );
}
/// Emits an `info`-level event for a pattern match, carrying the pattern
/// (`match.pattern`), the matched text (`match.text`) and the caller-supplied
/// duration in milliseconds (`match.duration_ms`).
pub fn record_match(pattern: &str, matched_text: &str, duration_ms: u64) {
    tracing::info!(
        match.pattern = pattern,
        match.text = matched_text,
        match.duration_ms = duration_ms
    );
}
/// Emits an `info`-level event with the caller-supplied byte counters
/// `bytes.sent` and `bytes.received`.
pub fn record_bytes(sent: u64, received: u64) {
    tracing::info!(
        bytes.sent = sent,
        bytes.received = received
    );
}
/// Field/attribute keys used on the spans and events emitted by the helpers
/// in the parent module. They are exposed so callers and tests can refer to
/// them without repeating string literals.
pub mod attributes {
    /// `session.id` on `session` spans.
    pub const SESSION_ID: &str = "session.id";
    /// `session.command` on `session` spans.
    pub const SESSION_COMMAND: &str = "session.command";
    /// `session.pid` on `session` spans.
    pub const SESSION_PID: &str = "session.pid";
    /// `expect.description` on `expect` spans.
    pub const EXPECT_DESCRIPTION: &str = "expect.description";
    /// `expect.pattern` on `expect` spans.
    pub const EXPECT_PATTERN: &str = "expect.pattern";
    /// `expect.timeout_ms`.
    /// NOTE(review): not recorded by any helper in the parent module;
    /// presumably set elsewhere — confirm.
    pub const EXPECT_TIMEOUT_MS: &str = "expect.timeout_ms";
    /// `send.data` (shortened preview) on `send` spans.
    pub const SEND_DATA: &str = "send.data";
    /// `send.bytes` (full payload length) on `send` spans.
    pub const SEND_BYTES: &str = "send.bytes";
    /// `dialog.name` on `dialog` spans.
    pub const DIALOG_NAME: &str = "dialog.name";
    /// `dialog.steps` on `dialog` spans.
    pub const DIALOG_STEPS: &str = "dialog.steps";
    /// `transcript.session` on `transcript` spans.
    pub const TRANSCRIPT_SESSION: &str = "transcript.session";
    /// `transcript.format` on `transcript` spans.
    pub const TRANSCRIPT_FORMAT: &str = "transcript.format";
    /// `match.pattern` on match events.
    pub const MATCH_PATTERN: &str = "match.pattern";
    /// `match.text` on match events.
    pub const MATCH_TEXT: &str = "match.text";
    /// `match.duration_ms` on match events.
    pub const MATCH_DURATION_MS: &str = "match.duration_ms";
    /// `bytes.sent` on byte-counter events.
    pub const BYTES_SENT: &str = "bytes.sent";
    /// `bytes.received` on byte-counter events.
    pub const BYTES_RECEIVED: &str = "bytes.received";
    /// `error.type`.
    /// NOTE(review): `record_error` in the parent module records
    /// `exception.type_`, not this key.
    pub const ERROR_TYPE: &str = "error.type";
    /// `error.message`.
    /// NOTE(review): `record_error` in the parent module records
    /// `exception.message`, not this key.
    pub const ERROR_MESSAGE: &str = "error.message";
}
#[cfg(test)]
mod tests {
    // Pull in `TracingConfig` and friends from the parent module.
    use super::*;

    /// Tolerance used when comparing stored sampling ratios.
    const EPS: f64 = 0.001;

    /// `true` when `actual` is within [`EPS`] of `expected`.
    fn close_to(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < EPS
    }

    #[test]
    fn tracing_config_default() {
        let cfg = TracingConfig::new("test-service");
        assert_eq!(cfg.service_name, "test-service");
        assert_eq!(cfg.endpoint, "http://localhost:4317");
        assert!(close_to(cfg.sampling_ratio, 1.0));
    }

    #[test]
    fn tracing_config_builder() {
        let cfg = TracingConfig::new("test")
            .with_endpoint("http://custom:4317")
            .with_sampling_ratio(0.5)
            .with_resource_attribute("env", "test")
            .with_console_output(true);

        assert_eq!(cfg.endpoint, "http://custom:4317");
        assert!(close_to(cfg.sampling_ratio, 0.5));
        assert_eq!(
            cfg.resource_attributes.get("env").map(String::as_str),
            Some("test")
        );
        assert!(cfg.console_output);
    }

    #[test]
    fn sampling_ratio_clamped() {
        for (input, expected) in [(2.0, 1.0), (-0.5, 0.0)] {
            let cfg = TracingConfig::new("test").with_sampling_ratio(input);
            assert!(
                close_to(cfg.sampling_ratio, expected),
                "ratio {input} stored as {}",
                cfg.sampling_ratio
            );
        }
    }
}