oxify_connect_llm/
observability.rs

//! Observability middleware for LLM providers
//!
//! This module provides tracing, logging, and monitoring capabilities for LLM requests.
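//!
//! # Example
//!
//! A minimal sketch of composing the two wrappers in this module, assuming a
//! hypothetical `MyProvider` that implements `LlmProvider`; the crate-root
//! import paths mirror the re-export used by the `Metrics` doctests below:
//!
//! ```ignore
//! use oxify_connect_llm::{MetricsProvider, ObservableProvider};
//!
//! // Tracing and logging on the inside, metrics collection on the outside.
//! let provider = MetricsProvider::new(ObservableProvider::new(
//!     MyProvider::default(),
//!     "my-provider".to_string(),
//! ));
//! ```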

use crate::{
    EmbeddingProvider, EmbeddingRequest, EmbeddingResponse, LlmProvider, LlmRequest, LlmResponse,
    LlmStream, Result, StreamingLlmProvider,
};
use async_trait::async_trait;
use std::sync::Arc;
use std::time::Instant;

/// Provider wrapper with tracing and logging
pub struct ObservableProvider<P> {
    inner: P,
    provider_name: String,
}

impl<P> ObservableProvider<P> {
    /// Create a new observable provider wrapper
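    ///
    /// # Example
    ///
    /// A minimal sketch, assuming a hypothetical `MyProvider` that implements
    /// `LlmProvider`:
    ///
    /// ```ignore
    /// let provider = ObservableProvider::new(MyProvider::default(), "my-provider".to_string());
    /// // All `complete`/`embed`/`complete_stream` calls are now traced and logged.
    /// ```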
    pub fn new(inner: P, provider_name: String) -> Self {
        Self {
            inner,
            provider_name,
        }
    }
}

#[async_trait]
impl<P: LlmProvider> LlmProvider for ObservableProvider<P> {
    #[tracing::instrument(
        name = "llm_completion",
        skip(self, request),
        fields(
            provider = %self.provider_name,
            prompt_length = request.prompt.len(),
            has_system_prompt = request.system_prompt.is_some(),
            temperature = ?request.temperature,
            max_tokens = ?request.max_tokens,
            num_tools = request.tools.len(),
            num_images = request.images.len(),
        )
    )]
    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
        let start = Instant::now();

        tracing::debug!(
            provider = %self.provider_name,
            prompt = %request.prompt.chars().take(100).collect::<String>(),
            "Starting LLM completion request"
        );

        let result = self.inner.complete(request).await;
        let duration = start.elapsed();

        match &result {
            Ok(response) => {
                tracing::info!(
                    provider = %self.provider_name,
                    duration_ms = duration.as_millis(),
                    model = %response.model,
                    content_length = response.content.len(),
                    prompt_tokens = response.usage.as_ref().map(|u| u.prompt_tokens),
                    completion_tokens = response.usage.as_ref().map(|u| u.completion_tokens),
                    total_tokens = response.usage.as_ref().map(|u| u.total_tokens),
                    num_tool_calls = response.tool_calls.len(),
                    "LLM completion succeeded"
                );
            }
            Err(e) => {
                tracing::error!(
                    provider = %self.provider_name,
                    duration_ms = duration.as_millis(),
                    error = %e,
                    "LLM completion failed"
                );
            }
        }

        result
    }
}

#[async_trait]
impl<P: EmbeddingProvider> EmbeddingProvider for ObservableProvider<P> {
    #[tracing::instrument(
        name = "embedding_generation",
        skip(self, request),
        fields(
            provider = %self.provider_name,
            num_texts = request.texts.len(),
            model = ?request.model,
        )
    )]
    async fn embed(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse> {
        let start = Instant::now();

        tracing::debug!(
            provider = %self.provider_name,
            num_texts = request.texts.len(),
            "Starting embedding generation"
        );

        let result = self.inner.embed(request).await;
        let duration = start.elapsed();

        match &result {
            Ok(response) => {
                tracing::info!(
                    provider = %self.provider_name,
                    duration_ms = duration.as_millis(),
                    model = %response.model,
                    num_embeddings = response.embeddings.len(),
                    embedding_dim = response.embeddings.first().map(|e| e.len()),
                    prompt_tokens = response.usage.as_ref().map(|u| u.prompt_tokens),
                    "Embedding generation succeeded"
                );
            }
            Err(e) => {
                tracing::error!(
                    provider = %self.provider_name,
                    duration_ms = duration.as_millis(),
                    error = %e,
                    "Embedding generation failed"
                );
            }
        }

        result
    }
}

#[async_trait]
impl<P: StreamingLlmProvider> StreamingLlmProvider for ObservableProvider<P> {
    #[tracing::instrument(
        name = "llm_streaming",
        skip(self, request),
        fields(
            provider = %self.provider_name,
            prompt_length = request.prompt.len(),
        )
    )]
    async fn complete_stream(&self, request: LlmRequest) -> Result<LlmStream> {
        let start = Instant::now();

        tracing::debug!(
            provider = %self.provider_name,
            "Starting streaming LLM completion"
        );

        let result = self.inner.complete_stream(request).await;
        let duration = start.elapsed();

        match &result {
            Ok(_) => {
                tracing::info!(
                    provider = %self.provider_name,
                    duration_ms = duration.as_millis(),
                    "Streaming LLM completion started"
                );
            }
            Err(e) => {
                tracing::error!(
                    provider = %self.provider_name,
                    duration_ms = duration.as_millis(),
                    error = %e,
                    "Streaming LLM completion failed to start"
                );
            }
        }

        result
    }
}

/// Metrics collector for LLM operations
#[derive(Debug, Clone, Default)]
pub struct Metrics {
    /// Total number of requests
    pub total_requests: u64,
    /// Total number of successful requests
    pub successful_requests: u64,
    /// Total number of failed requests
    pub failed_requests: u64,
    /// Total tokens used (prompt + completion)
    pub total_tokens: u64,
    /// Total cost in USD
    pub total_cost_usd: f64,
    /// Total latency in milliseconds
    pub total_latency_ms: u64,
}

impl Metrics {
    /// Create a new metrics collector
    pub fn new() -> Self {
        Self::default()
    }

    /// Get average latency in milliseconds
    pub fn avg_latency_ms(&self) -> f64 {
        if self.total_requests == 0 {
            0.0
        } else {
            self.total_latency_ms as f64 / self.total_requests as f64
        }
    }

    /// Get success rate (0.0 to 1.0)
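    ///
    /// # Example
    /// ```
    /// use oxify_connect_llm::Metrics;
    ///
    /// let mut metrics = Metrics::new();
    /// metrics.total_requests = 10;
    /// metrics.successful_requests = 8;
    /// assert_eq!(metrics.success_rate(), 0.8);
    /// ```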
    pub fn success_rate(&self) -> f64 {
        if self.total_requests == 0 {
            0.0
        } else {
            self.successful_requests as f64 / self.total_requests as f64
        }
    }

    /// Get average cost per request in USD
    pub fn avg_cost_per_request(&self) -> f64 {
        if self.successful_requests == 0 {
            0.0
        } else {
            self.total_cost_usd / self.successful_requests as f64
        }
    }

    /// Export metrics in Prometheus text format
    ///
    /// Returns a string containing all metrics in Prometheus exposition format,
    /// ready to be scraped by a Prometheus server.
    ///
    /// # Example
    /// ```
    /// use oxify_connect_llm::Metrics;
    ///
    /// let metrics = Metrics {
    ///     total_requests: 100,
    ///     successful_requests: 95,
    ///     failed_requests: 5,
    ///     total_tokens: 50000,
    ///     total_cost_usd: 2.5,
    ///     total_latency_ms: 15000,
    /// };
    ///
    /// let prometheus_output = metrics.to_prometheus();
    /// assert!(prometheus_output.contains("llm_requests_total"));
    /// ```
    pub fn to_prometheus(&self) -> String {
        format!(
            "# HELP llm_requests_total Total number of LLM requests\n\
             # TYPE llm_requests_total counter\n\
             llm_requests_total {}\n\
             \n\
             # HELP llm_requests_successful_total Total number of successful LLM requests\n\
             # TYPE llm_requests_successful_total counter\n\
             llm_requests_successful_total {}\n\
             \n\
             # HELP llm_requests_failed_total Total number of failed LLM requests\n\
             # TYPE llm_requests_failed_total counter\n\
             llm_requests_failed_total {}\n\
             \n\
             # HELP llm_tokens_total Total number of tokens processed\n\
             # TYPE llm_tokens_total counter\n\
             llm_tokens_total {}\n\
             \n\
             # HELP llm_cost_usd_total Total cost in USD\n\
             # TYPE llm_cost_usd_total counter\n\
             llm_cost_usd_total {}\n\
             \n\
             # HELP llm_latency_ms_total Total latency in milliseconds\n\
             # TYPE llm_latency_ms_total counter\n\
             llm_latency_ms_total {}\n\
             \n\
             # HELP llm_latency_avg_ms Average latency in milliseconds\n\
             # TYPE llm_latency_avg_ms gauge\n\
             llm_latency_avg_ms {}\n\
             \n\
             # HELP llm_success_rate Success rate (0.0 to 1.0)\n\
             # TYPE llm_success_rate gauge\n\
             llm_success_rate {}\n\
             \n\
             # HELP llm_cost_avg_per_request_usd Average cost per request in USD\n\
             # TYPE llm_cost_avg_per_request_usd gauge\n\
             llm_cost_avg_per_request_usd {}\n",
            self.total_requests,
            self.successful_requests,
            self.failed_requests,
            self.total_tokens,
            self.total_cost_usd,
            self.total_latency_ms,
            self.avg_latency_ms(),
            self.success_rate(),
            self.avg_cost_per_request(),
        )
    }

    /// Export metrics with labels in Prometheus text format
    ///
    /// # Arguments
    /// * `provider_name` - Name of the LLM provider (e.g., "openai", "anthropic")
    /// * `model` - Model name (e.g., "gpt-4", "claude-3-opus")
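    ///
    /// # Example
    /// ```
    /// use oxify_connect_llm::Metrics;
    ///
    /// let metrics = Metrics::new();
    /// let output = metrics.to_prometheus_with_labels("openai", "gpt-4");
    /// assert!(output.contains("llm_requests_total{provider=\"openai\",model=\"gpt-4\"} 0"));
    /// ```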
    pub fn to_prometheus_with_labels(&self, provider_name: &str, model: &str) -> String {
        format!(
            "# HELP llm_requests_total Total number of LLM requests\n\
             # TYPE llm_requests_total counter\n\
             llm_requests_total{{provider=\"{}\",model=\"{}\"}} {}\n\
             \n\
             # HELP llm_requests_successful_total Total number of successful LLM requests\n\
             # TYPE llm_requests_successful_total counter\n\
             llm_requests_successful_total{{provider=\"{}\",model=\"{}\"}} {}\n\
             \n\
             # HELP llm_requests_failed_total Total number of failed LLM requests\n\
             # TYPE llm_requests_failed_total counter\n\
             llm_requests_failed_total{{provider=\"{}\",model=\"{}\"}} {}\n\
             \n\
             # HELP llm_tokens_total Total number of tokens processed\n\
             # TYPE llm_tokens_total counter\n\
             llm_tokens_total{{provider=\"{}\",model=\"{}\"}} {}\n\
             \n\
             # HELP llm_cost_usd_total Total cost in USD\n\
             # TYPE llm_cost_usd_total counter\n\
             llm_cost_usd_total{{provider=\"{}\",model=\"{}\"}} {}\n\
             \n\
             # HELP llm_latency_ms_total Total latency in milliseconds\n\
             # TYPE llm_latency_ms_total counter\n\
             llm_latency_ms_total{{provider=\"{}\",model=\"{}\"}} {}\n\
             \n\
             # HELP llm_latency_avg_ms Average latency in milliseconds\n\
             # TYPE llm_latency_avg_ms gauge\n\
             llm_latency_avg_ms{{provider=\"{}\",model=\"{}\"}} {}\n\
             \n\
             # HELP llm_success_rate Success rate (0.0 to 1.0)\n\
             # TYPE llm_success_rate gauge\n\
             llm_success_rate{{provider=\"{}\",model=\"{}\"}} {}\n\
             \n\
             # HELP llm_cost_avg_per_request_usd Average cost per request in USD\n\
             # TYPE llm_cost_avg_per_request_usd gauge\n\
             llm_cost_avg_per_request_usd{{provider=\"{}\",model=\"{}\"}} {}\n",
            provider_name,
            model,
            self.total_requests,
            provider_name,
            model,
            self.successful_requests,
            provider_name,
            model,
            self.failed_requests,
            provider_name,
            model,
            self.total_tokens,
            provider_name,
            model,
            self.total_cost_usd,
            provider_name,
            model,
            self.total_latency_ms,
            provider_name,
            model,
            self.avg_latency_ms(),
            provider_name,
            model,
            self.success_rate(),
            provider_name,
            model,
            self.avg_cost_per_request(),
        )
    }
}

/// Provider wrapper with metrics collection
pub struct MetricsProvider<P> {
    inner: Arc<P>,
    // A std::sync::Mutex is sufficient here: the lock is only taken after the
    // inner future completes, so it is never held across an `.await` point.
    metrics: Arc<std::sync::Mutex<Metrics>>,
}

impl<P> MetricsProvider<P> {
    /// Create a new metrics provider wrapper
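    ///
    /// # Example
    ///
    /// A minimal sketch, assuming a hypothetical `MyProvider` implementing
    /// `LlmProvider` and an `LlmRequest` value `request`:
    ///
    /// ```ignore
    /// let provider = MetricsProvider::new(MyProvider::default());
    /// let _ = provider.complete(request).await?;
    /// println!("{}", provider.get_metrics().to_prometheus());
    /// ```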
    pub fn new(inner: P) -> Self {
        Self {
            inner: Arc::new(inner),
            metrics: Arc::new(std::sync::Mutex::new(Metrics::new())),
        }
    }

    /// Get current metrics snapshot
    pub fn get_metrics(&self) -> Metrics {
        self.metrics.lock().unwrap().clone()
    }

    /// Reset metrics
    pub fn reset_metrics(&self) {
        let mut metrics = self.metrics.lock().unwrap();
        *metrics = Metrics::new();
    }
}


#[async_trait]
impl<P: LlmProvider> LlmProvider for MetricsProvider<P> {
    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
        let start = Instant::now();
        let result = self.inner.complete(request).await;
        let duration = start.elapsed();

        let mut metrics = self.metrics.lock().unwrap();
        metrics.total_requests += 1;
        metrics.total_latency_ms += duration.as_millis() as u64;

        match &result {
            Ok(response) => {
                metrics.successful_requests += 1;
                if let Some(usage) = &response.usage {
                    metrics.total_tokens += usage.total_tokens as u64;
                }
            }
            Err(_) => {
                metrics.failed_requests += 1;
            }
        }

        result
    }
}

#[async_trait]
impl<P: EmbeddingProvider> EmbeddingProvider for MetricsProvider<P> {
    async fn embed(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse> {
        let start = Instant::now();
        let result = self.inner.embed(request).await;
        let duration = start.elapsed();

        let mut metrics = self.metrics.lock().unwrap();
        metrics.total_requests += 1;
        metrics.total_latency_ms += duration.as_millis() as u64;

        match &result {
            Ok(response) => {
                metrics.successful_requests += 1;
                if let Some(usage) = &response.usage {
                    metrics.total_tokens += usage.total_tokens as u64;
                }
            }
            Err(_) => {
                metrics.failed_requests += 1;
            }
        }

        result
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_metrics_new() {
        let metrics = Metrics::new();
        assert_eq!(metrics.total_requests, 0);
        assert_eq!(metrics.successful_requests, 0);
        assert_eq!(metrics.failed_requests, 0);
    }

    #[test]
    fn test_metrics_avg_latency() {
        let mut metrics = Metrics::new();
        metrics.total_requests = 5;
        metrics.total_latency_ms = 1000;
        assert_eq!(metrics.avg_latency_ms(), 200.0);
    }

    #[test]
    fn test_metrics_success_rate() {
        let mut metrics = Metrics::new();
        metrics.total_requests = 10;
        metrics.successful_requests = 8;
        assert_eq!(metrics.success_rate(), 0.8);
    }

    #[test]
    fn test_metrics_avg_cost() {
        let mut metrics = Metrics::new();
        metrics.successful_requests = 4;
        metrics.total_cost_usd = 2.0;
        assert_eq!(metrics.avg_cost_per_request(), 0.5);
    }

    #[test]
    fn test_metrics_zero_division() {
        let metrics = Metrics::new();
        assert_eq!(metrics.avg_latency_ms(), 0.0);
        assert_eq!(metrics.success_rate(), 0.0);
        assert_eq!(metrics.avg_cost_per_request(), 0.0);
    }

    #[test]
    fn test_prometheus_export() {
        let metrics = Metrics {
            total_requests: 100,
            successful_requests: 95,
            failed_requests: 5,
            total_tokens: 50000,
            total_cost_usd: 2.5,
            total_latency_ms: 15000,
        };

        let prometheus = metrics.to_prometheus();

        // Check that all expected metrics are present
        assert!(prometheus.contains("llm_requests_total 100"));
        assert!(prometheus.contains("llm_requests_successful_total 95"));
        assert!(prometheus.contains("llm_requests_failed_total 5"));
        assert!(prometheus.contains("llm_tokens_total 50000"));
        assert!(prometheus.contains("llm_cost_usd_total 2.5"));
        assert!(prometheus.contains("llm_latency_ms_total 15000"));

        // Check that calculated metrics are present
        assert!(prometheus.contains("llm_latency_avg_ms 150"));
        assert!(prometheus.contains("llm_success_rate 0.95"));

        // Check that HELP and TYPE annotations are present
        assert!(prometheus.contains("# HELP llm_requests_total"));
        assert!(prometheus.contains("# TYPE llm_requests_total counter"));
    }

    #[test]
    fn test_prometheus_export_with_labels() {
        let metrics = Metrics {
            total_requests: 50,
            successful_requests: 48,
            failed_requests: 2,
            total_tokens: 25000,
            total_cost_usd: 1.25,
            total_latency_ms: 7500,
        };

        let prometheus = metrics.to_prometheus_with_labels("openai", "gpt-4");

        // Check that labels are present
        assert!(prometheus.contains("llm_requests_total{provider=\"openai\",model=\"gpt-4\"} 50"));
        assert!(prometheus
            .contains("llm_requests_successful_total{provider=\"openai\",model=\"gpt-4\"} 48"));
        assert!(
            prometheus.contains("llm_requests_failed_total{provider=\"openai\",model=\"gpt-4\"} 2")
        );
        assert!(prometheus.contains("llm_tokens_total{provider=\"openai\",model=\"gpt-4\"} 25000"));

        // Check that HELP and TYPE annotations are still present
        assert!(prometheus.contains("# HELP llm_requests_total"));
        assert!(prometheus.contains("# TYPE llm_requests_total counter"));
    }

    #[test]
    fn test_prometheus_export_empty_metrics() {
        let metrics = Metrics::new();
        let prometheus = metrics.to_prometheus();

        // Should contain all metric names with zero values
        assert!(prometheus.contains("llm_requests_total 0"));
        assert!(prometheus.contains("llm_latency_avg_ms 0"));
        assert!(prometheus.contains("llm_success_rate 0"));
    }
}