use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
/// Settings for the observability layer.
///
/// Within this module only `enable_metrics` is actually read (by
/// `MetricsRecorder`); the remaining fields are carried along for callers.
#[derive(Clone, Debug)]
pub struct ObservabilityConfig {
    /// Flag for request tracing. NOTE(review): not consulted by this module.
    pub enable_tracing: bool,
    /// When `false`, `MetricsRecorder` ignores every recorded request.
    pub enable_metrics: bool,
    /// Intended trace sampling fraction. NOTE(review): presumably in `[0, 1]`;
    /// nothing here validates or reads it.
    pub trace_sample_rate: f64,
    /// Logical name of the emitting service.
    pub service_name: String,
}

impl Default for ObservabilityConfig {
    /// Tracing and metrics on, a sample rate of `1.0`, service name `"llmkit"`.
    fn default() -> Self {
        let service_name = String::from("llmkit");
        ObservabilityConfig {
            enable_tracing: true,
            enable_metrics: true,
            trace_sample_rate: 1.0,
            service_name,
        }
    }
}
/// A single timed operation, tagged with a freshly generated request id.
#[derive(Debug, Clone)]
pub struct RequestSpan {
    /// Random v4 UUID rendered as a string.
    pub request_id: String,
    /// Optional parent span id; [`RequestSpan::new`] leaves it `None`.
    pub parent_span_id: Option<String>,
    /// Name of the operation being timed.
    pub operation: String,
    /// Instant captured at construction; [`RequestSpan::elapsed`] measures from here.
    pub start_time: Instant,
    /// Key/value annotations in insertion order (duplicate keys are kept).
    pub metadata: Vec<(String, String)>,
}

impl RequestSpan {
    /// Starts a new root span for `operation`, timing from "now".
    pub fn new(operation: impl Into<String>) -> Self {
        let request_id = uuid::Uuid::new_v4().to_string();
        let operation = operation.into();
        let start_time = Instant::now();
        RequestSpan {
            request_id,
            parent_span_id: None,
            operation,
            start_time,
            metadata: Vec::new(),
        }
    }

    /// Builder-style: appends one `(key, value)` annotation and returns the span.
    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let entry = (key.into(), value.into());
        self.metadata.push(entry);
        self
    }

    /// Wall-clock time since the span was created.
    pub fn elapsed(&self) -> Duration {
        self.start_time.elapsed()
    }

    /// Same as [`RequestSpan::elapsed`], expressed as fractional milliseconds.
    pub fn elapsed_ms(&self) -> f64 {
        let since_start = self.elapsed();
        since_start.as_secs_f64() * 1000.0
    }
}
/// Lock-free aggregate request counters.
///
/// Every counter lives behind an `Arc`, so cloning a recorder is cheap and the
/// clone shares the same totals. Hand clones to concurrent tasks and they all
/// feed one set of numbers.
#[derive(Debug, Clone)]
pub struct MetricsRecorder {
    total_requests: Arc<AtomicU64>,
    total_errors: Arc<AtomicU64>,
    /// Summed latency in whole microseconds. Accumulating in µs, rather than
    /// truncating each sample to whole milliseconds, keeps sub-millisecond
    /// latencies from being recorded as zero.
    total_latency_us: Arc<AtomicU64>,
    config: ObservabilityConfig,
}

impl MetricsRecorder {
    /// Creates a recorder with all counters at zero.
    pub fn new(config: ObservabilityConfig) -> Self {
        Self {
            total_requests: Arc::new(AtomicU64::new(0)),
            total_errors: Arc::new(AtomicU64::new(0)),
            total_latency_us: Arc::new(AtomicU64::new(0)),
            config,
        }
    }

    /// Records one successful request that took `latency_ms` milliseconds.
    ///
    /// No-op when `config.enable_metrics` is `false`. Negative or non-finite
    /// latencies still count as a request but add nothing to the latency sum.
    pub fn record_success(&self, latency_ms: f64) {
        self.record(latency_ms, false);
    }

    /// Records one failed request that took `latency_ms` milliseconds.
    ///
    /// Counts toward both the request total and the error total. Same no-op
    /// and latency rules as [`MetricsRecorder::record_success`].
    pub fn record_error(&self, latency_ms: f64) {
        self.record(latency_ms, true);
    }

    /// Shared implementation of `record_success` / `record_error`.
    fn record(&self, latency_ms: f64, is_error: bool) {
        if !self.config.enable_metrics {
            return;
        }
        // These are independent statistics counters that publish no other data,
        // so `Relaxed` is sufficient; atomicity alone prevents lost updates.
        self.total_requests.fetch_add(1, Ordering::Relaxed);
        if is_error {
            self.total_errors.fetch_add(1, Ordering::Relaxed);
        }
        self.total_latency_us
            .fetch_add(Self::ms_to_us(latency_ms), Ordering::Relaxed);
    }

    /// Converts milliseconds to rounded whole microseconds.
    ///
    /// Non-finite input maps to 0. Without this guard, `inf` would saturate to
    /// `u64::MAX` and wrap the running sum. Negative input also maps to 0,
    /// because float-to-int `as` casts saturate.
    fn ms_to_us(latency_ms: f64) -> u64 {
        if latency_ms.is_finite() {
            (latency_ms * 1000.0).round() as u64
        } else {
            0
        }
    }

    /// Returns the current totals plus derived error rate and mean latency.
    ///
    /// The three counters are read one after another, not as a single atomic
    /// unit. Under concurrent recording, the snapshot can therefore mix slightly
    /// different moments. `error_rate` is clamped to `1.0` so such a mix never
    /// reports more errors than requests.
    pub fn snapshot(&self) -> MetricsSnapshot {
        let total_requests = self.total_requests.load(Ordering::Relaxed);
        let total_errors = self.total_errors.load(Ordering::Relaxed);
        let total_latency_us = self.total_latency_us.load(Ordering::Relaxed);
        let (error_rate, average_latency_ms) = if total_requests == 0 {
            (0.0, 0.0)
        } else {
            let requests = total_requests as f64;
            (
                (total_errors as f64 / requests).min(1.0),
                total_latency_us as f64 / 1000.0 / requests,
            )
        };
        MetricsSnapshot {
            total_requests,
            total_errors,
            error_rate,
            average_latency_ms,
        }
    }

    /// Zeroes all counters for this recorder and every clone sharing them.
    ///
    /// The counters are cleared one at a time. A request recorded concurrently
    /// with a reset may therefore be partially kept.
    pub fn reset(&self) {
        self.total_requests.store(0, Ordering::Relaxed);
        self.total_errors.store(0, Ordering::Relaxed);
        self.total_latency_us.store(0, Ordering::Relaxed);
    }
}

/// Point-in-time view of a [`MetricsRecorder`]'s counters.
#[derive(Debug, Clone, Copy)]
pub struct MetricsSnapshot {
    /// Requests recorded (successes plus errors).
    pub total_requests: u64,
    /// Requests recorded as errors.
    pub total_errors: u64,
    /// `total_errors / total_requests`, or `0.0` when nothing was recorded.
    pub error_rate: f64,
    /// Mean recorded latency in milliseconds, or `0.0` when nothing was recorded.
    pub average_latency_ms: f64,
}
#[derive(Debug, Clone)]
pub struct TracingContext {
pub trace_id: String,
pub span_id: String,
pub parent_span_id: Option<String>,
pub baggage: Vec<(String, String)>,
}
impl Default for TracingContext {
fn default() -> Self {
Self {
trace_id: uuid::Uuid::new_v4().to_string(),
span_id: uuid::Uuid::new_v4().to_string(),
parent_span_id: None,
baggage: Vec::new(),
}
}
}
/// Facade bundling the configuration with a metrics recorder built from it.
#[derive(Debug)]
pub struct Observability {
    config: ObservabilityConfig,
    metrics: MetricsRecorder,
}

impl Observability {
    /// Builds the facade; the recorder receives its own copy of `config`.
    pub fn new(config: ObservabilityConfig) -> Self {
        let metrics = MetricsRecorder::new(config.clone());
        Observability { config, metrics }
    }

    /// Starts timing `operation`.
    ///
    /// NOTE(review): `config.enable_tracing` and `config.trace_sample_rate` are
    /// not consulted here; a span is always returned.
    pub fn start_span(&self, operation: impl Into<String>) -> RequestSpan {
        RequestSpan::new(operation)
    }

    /// Records `span`'s elapsed time as one success (`success == true`) or one error.
    pub fn record_request(&self, span: &RequestSpan, success: bool) {
        let elapsed = span.elapsed_ms();
        let recorder = &self.metrics;
        if success {
            recorder.record_success(elapsed)
        } else {
            recorder.record_error(elapsed)
        }
    }

    /// Current metric totals.
    pub fn metrics(&self) -> MetricsSnapshot {
        let recorder = &self.metrics;
        recorder.snapshot()
    }

    /// Clears all recorded metrics.
    pub fn reset_metrics(&self) {
        let recorder = &self.metrics;
        recorder.reset()
    }

    /// Produces a brand-new root tracing context (independent of `self`'s config).
    pub fn create_context(&self) -> TracingContext {
        Default::default()
    }

    /// Borrow of the configuration this instance was built with.
    pub fn config(&self) -> &ObservabilityConfig {
        &self.config
    }
}

impl Default for Observability {
    /// Equivalent to `Observability::new(ObservabilityConfig::default())`.
    fn default() -> Self {
        Observability::new(Default::default())
    }
}
/// Unit tests for spans, the metrics recorder, and the `Observability` facade.
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn test_request_span_creation() {
        let span = RequestSpan::new("test_operation");
        assert_eq!(span.operation, "test_operation");
        assert!(!span.request_id.is_empty());
        assert!(span.elapsed_ms() >= 0.0);
    }

    #[test]
    fn test_request_span_metadata() {
        let span = RequestSpan::new("test")
            .with_metadata("provider", "openai")
            .with_metadata("model", "gpt-4");
        assert_eq!(span.metadata.len(), 2);
        assert_eq!(span.metadata[0].0, "provider");
        assert_eq!(span.metadata[0].1, "openai");
    }

    #[test]
    fn test_metrics_recorder_success() {
        let config = ObservabilityConfig::default();
        let recorder = MetricsRecorder::new(config);
        recorder.record_success(50.0);
        recorder.record_success(75.0);
        let snapshot = recorder.snapshot();
        assert_eq!(snapshot.total_requests, 2);
        assert_eq!(snapshot.total_errors, 0);
        assert_eq!(snapshot.error_rate, 0.0);
        assert!((snapshot.average_latency_ms - 62.5).abs() < 0.1);
    }

    #[test]
    fn test_metrics_recorder_errors() {
        let config = ObservabilityConfig::default();
        let recorder = MetricsRecorder::new(config);
        recorder.record_success(50.0);
        recorder.record_error(100.0);
        recorder.record_error(150.0);
        let snapshot = recorder.snapshot();
        assert_eq!(snapshot.total_requests, 3);
        assert_eq!(snapshot.total_errors, 2);
        assert!((snapshot.error_rate - 2.0 / 3.0).abs() < 0.01);
    }

    #[test]
    fn test_metrics_recorder_disabled() {
        let config = ObservabilityConfig {
            enable_metrics: false,
            ..Default::default()
        };
        let recorder = MetricsRecorder::new(config);
        recorder.record_success(50.0);
        recorder.record_success(75.0);
        let snapshot = recorder.snapshot();
        assert_eq!(snapshot.total_requests, 0);
        assert_eq!(snapshot.total_errors, 0);
    }

    #[test]
    fn test_metrics_recorder_clone() {
        let config = ObservabilityConfig::default();
        let recorder1 = MetricsRecorder::new(config);
        let recorder2 = recorder1.clone();
        recorder1.record_success(50.0);
        let snapshot = recorder2.snapshot();
        assert_eq!(snapshot.total_requests, 1);
    }

    #[test]
    fn test_metrics_recorder_reset() {
        let config = ObservabilityConfig::default();
        let recorder = MetricsRecorder::new(config);
        recorder.record_success(50.0);
        recorder.record_success(75.0);
        let snapshot = recorder.snapshot();
        assert_eq!(snapshot.total_requests, 2);
        recorder.reset();
        let snapshot = recorder.snapshot();
        assert_eq!(snapshot.total_requests, 0);
    }

    #[test]
    fn test_observability_integration() {
        let obs = Observability::default();
        let span = obs.start_span("test_operation");
        thread::sleep(Duration::from_millis(10));
        obs.record_request(&span, true);
        let metrics = obs.metrics();
        assert_eq!(metrics.total_requests, 1);
        assert_eq!(metrics.total_errors, 0);
        assert!(metrics.average_latency_ms >= 10.0);
    }

    #[test]
    fn test_tracing_context() {
        let ctx = TracingContext::default();
        assert!(!ctx.trace_id.is_empty());
        assert!(!ctx.span_id.is_empty());
        assert!(ctx.parent_span_id.is_none());
        assert!(ctx.baggage.is_empty());
    }

    #[test]
    fn test_observability_disabled_tracing() {
        let config = ObservabilityConfig {
            enable_tracing: false,
            ..Default::default()
        };
        let obs = Observability::new(config);
        let span = obs.start_span("test");
        assert_eq!(span.operation, "test");
    }

    #[tokio::test]
    async fn test_concurrent_metrics() {
        let config = ObservabilityConfig::default();
        let recorder = MetricsRecorder::new(config);
        let mut set = tokio::task::JoinSet::new();
        for i in 0..10 {
            let rec = recorder.clone();
            set.spawn(async move {
                for j in 0..10 {
                    let latency = ((i * 10 + j) as f64) * 1.5;
                    if (i + j) % 3 == 0 {
                        rec.record_error(latency);
                    } else {
                        rec.record_success(latency);
                    }
                }
            });
        }
        // Propagate panics from spawned tasks. Discarding the `JoinError` would
        // hide the real failure behind a confusing count mismatch below.
        while let Some(joined) = set.join_next().await {
            joined.expect("metrics task panicked");
        }
        let snapshot = recorder.snapshot();
        assert_eq!(snapshot.total_requests, 100);
        assert!(snapshot.total_errors > 0);
        assert!(snapshot.average_latency_ms > 0.0);
    }
}