use crate::{
    EmbeddingProvider, EmbeddingRequest, EmbeddingResponse, LlmProvider, LlmRequest, LlmResponse,
    LlmStream, Result, StreamingLlmProvider,
};
use async_trait::async_trait;
use std::sync::Arc;
use std::time::Instant;

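/// Wraps any provider and emits `tracing` spans and events around each call,
/// leaving the call itself untouched.
///
/// A minimal usage sketch; `MyProvider` is a hypothetical stand-in for any
/// concrete `LlmProvider` implementation:
///
/// ```ignore
/// let provider = ObservableProvider::new(MyProvider::default(), "my-provider".to_string());
/// let response = provider.complete(request).await?;
/// ```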
pub struct ObservableProvider<P> {
    inner: P,
    provider_name: String,
}

impl<P> ObservableProvider<P> {
    /// Wraps `inner`, tagging all emitted telemetry with `provider_name`.
    pub fn new(inner: P, provider_name: String) -> Self {
        Self {
            inner,
            provider_name,
        }
    }
}

#[async_trait]
impl<P: LlmProvider> LlmProvider for ObservableProvider<P> {
    #[tracing::instrument(
        name = "llm_completion",
        skip(self, request),
        fields(
            provider = %self.provider_name,
            prompt_length = request.prompt.len(),
            has_system_prompt = request.system_prompt.is_some(),
            temperature = ?request.temperature,
            max_tokens = ?request.max_tokens,
            num_tools = request.tools.len(),
            num_images = request.images.len(),
        )
    )]
    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
        let start = Instant::now();

        tracing::debug!(
            provider = %self.provider_name,
            prompt = %request.prompt.chars().take(100).collect::<String>(),
            "Starting LLM completion request"
        );

        let result = self.inner.complete(request).await;
        let duration = start.elapsed();

        match &result {
            Ok(response) => {
                tracing::info!(
                    provider = %self.provider_name,
                    duration_ms = duration.as_millis(),
                    model = %response.model,
                    content_length = response.content.len(),
                    prompt_tokens = response.usage.as_ref().map(|u| u.prompt_tokens),
                    completion_tokens = response.usage.as_ref().map(|u| u.completion_tokens),
                    total_tokens = response.usage.as_ref().map(|u| u.total_tokens),
                    num_tool_calls = response.tool_calls.len(),
                    "LLM completion succeeded"
                );
            }
            Err(e) => {
                tracing::error!(
                    provider = %self.provider_name,
                    duration_ms = duration.as_millis(),
                    error = %e,
                    "LLM completion failed"
                );
            }
        }

        result
    }
}

#[async_trait]
impl<P: EmbeddingProvider> EmbeddingProvider for ObservableProvider<P> {
    #[tracing::instrument(
        name = "embedding_generation",
        skip(self, request),
        fields(
            provider = %self.provider_name,
            num_texts = request.texts.len(),
            model = ?request.model,
        )
    )]
    async fn embed(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse> {
        let start = Instant::now();

        tracing::debug!(
            provider = %self.provider_name,
            num_texts = request.texts.len(),
            "Starting embedding generation"
        );

        let result = self.inner.embed(request).await;
        let duration = start.elapsed();

        match &result {
            Ok(response) => {
                tracing::info!(
                    provider = %self.provider_name,
                    duration_ms = duration.as_millis(),
                    model = %response.model,
                    num_embeddings = response.embeddings.len(),
                    embedding_dim = response.embeddings.first().map(|e| e.len()),
                    prompt_tokens = response.usage.as_ref().map(|u| u.prompt_tokens),
                    "Embedding generation succeeded"
                );
            }
            Err(e) => {
                tracing::error!(
                    provider = %self.provider_name,
                    duration_ms = duration.as_millis(),
                    error = %e,
                    "Embedding generation failed"
                );
            }
        }

        result
    }
}

#[async_trait]
impl<P: StreamingLlmProvider> StreamingLlmProvider for ObservableProvider<P> {
    #[tracing::instrument(
        name = "llm_streaming",
        skip(self, request),
        fields(
            provider = %self.provider_name,
            prompt_length = request.prompt.len(),
        )
    )]
    async fn complete_stream(&self, request: LlmRequest) -> Result<LlmStream> {
        let start = Instant::now();

        tracing::debug!(
            provider = %self.provider_name,
            "Starting streaming LLM completion"
        );

        let result = self.inner.complete_stream(request).await;
        let duration = start.elapsed();

        match &result {
            Ok(_) => {
                tracing::info!(
                    provider = %self.provider_name,
                    duration_ms = duration.as_millis(),
                    "Streaming LLM completion started"
                );
            }
            Err(e) => {
                tracing::error!(
                    provider = %self.provider_name,
                    duration_ms = duration.as_millis(),
                    error = %e,
                    "Streaming LLM completion failed to start"
                );
            }
        }

        result
    }
}

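/// Aggregate usage counters plus derived statistics (`avg_latency_ms`,
/// `success_rate`, `avg_cost_per_request`).
///
/// Note: the wrappers in this module never write to `total_cost_usd`; if cost
/// tracking is wanted, the caller has to accumulate into that field itself.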
#[derive(Debug, Clone, Default)]
pub struct Metrics {
    /// Total number of requests attempted.
    pub total_requests: u64,
    /// Number of requests that completed successfully.
    pub successful_requests: u64,
    /// Number of requests that returned an error.
    pub failed_requests: u64,
    /// Total tokens processed across all successful requests.
    pub total_tokens: u64,
    /// Total cost in USD.
    pub total_cost_usd: f64,
    /// Total latency across all requests, in milliseconds.
    pub total_latency_ms: u64,
}

impl Metrics {
    /// Creates an empty metrics record.
    pub fn new() -> Self {
        Self::default()
    }

    /// Mean latency per request (successful or failed) in milliseconds;
    /// 0.0 when no requests have been recorded.
    pub fn avg_latency_ms(&self) -> f64 {
        if self.total_requests == 0 {
            0.0
        } else {
            self.total_latency_ms as f64 / self.total_requests as f64
        }
    }

    /// Fraction of requests that succeeded, from 0.0 to 1.0; 0.0 when no
    /// requests have been recorded.
    pub fn success_rate(&self) -> f64 {
        if self.total_requests == 0 {
            0.0
        } else {
            self.successful_requests as f64 / self.total_requests as f64
        }
    }

    /// Mean cost in USD per successful request; 0.0 when none have succeeded.
    pub fn avg_cost_per_request(&self) -> f64 {
        if self.successful_requests == 0 {
            0.0
        } else {
            self.total_cost_usd / self.successful_requests as f64
        }
    }

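    /// Renders every counter and gauge in the Prometheus text exposition
    /// format. With `total_requests = 100`, for example, the first family
    /// comes out as:
    ///
    /// ```text
    /// # HELP llm_requests_total Total number of LLM requests
    /// # TYPE llm_requests_total counter
    /// llm_requests_total 100
    /// ```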
    pub fn to_prometheus(&self) -> String {
        format!(
            "# HELP llm_requests_total Total number of LLM requests\n\
             # TYPE llm_requests_total counter\n\
             llm_requests_total {}\n\
             \n\
             # HELP llm_requests_successful_total Total number of successful LLM requests\n\
             # TYPE llm_requests_successful_total counter\n\
             llm_requests_successful_total {}\n\
             \n\
             # HELP llm_requests_failed_total Total number of failed LLM requests\n\
             # TYPE llm_requests_failed_total counter\n\
             llm_requests_failed_total {}\n\
             \n\
             # HELP llm_tokens_total Total number of tokens processed\n\
             # TYPE llm_tokens_total counter\n\
             llm_tokens_total {}\n\
             \n\
             # HELP llm_cost_usd_total Total cost in USD\n\
             # TYPE llm_cost_usd_total counter\n\
             llm_cost_usd_total {}\n\
             \n\
             # HELP llm_latency_ms_total Total latency in milliseconds\n\
             # TYPE llm_latency_ms_total counter\n\
             llm_latency_ms_total {}\n\
             \n\
             # HELP llm_latency_avg_ms Average latency in milliseconds\n\
             # TYPE llm_latency_avg_ms gauge\n\
             llm_latency_avg_ms {}\n\
             \n\
             # HELP llm_success_rate Success rate (0.0 to 1.0)\n\
             # TYPE llm_success_rate gauge\n\
             llm_success_rate {}\n\
             \n\
             # HELP llm_cost_avg_per_request_usd Average cost per request in USD\n\
             # TYPE llm_cost_avg_per_request_usd gauge\n\
             llm_cost_avg_per_request_usd {}\n",
            self.total_requests,
            self.successful_requests,
            self.failed_requests,
            self.total_tokens,
            self.total_cost_usd,
            self.total_latency_ms,
            self.avg_latency_ms(),
            self.success_rate(),
            self.avg_cost_per_request(),
        )
    }

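    /// Like [`to_prometheus`](Self::to_prometheus), but attaches `provider`
    /// and `model` labels to every series, e.g.
    /// `llm_requests_total{provider="openai",model="gpt-4"} 50`.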
    pub fn to_prometheus_with_labels(&self, provider_name: &str, model: &str) -> String {
        format!(
            "# HELP llm_requests_total Total number of LLM requests\n\
             # TYPE llm_requests_total counter\n\
             llm_requests_total{{provider=\"{}\",model=\"{}\"}} {}\n\
             \n\
             # HELP llm_requests_successful_total Total number of successful LLM requests\n\
             # TYPE llm_requests_successful_total counter\n\
             llm_requests_successful_total{{provider=\"{}\",model=\"{}\"}} {}\n\
             \n\
             # HELP llm_requests_failed_total Total number of failed LLM requests\n\
             # TYPE llm_requests_failed_total counter\n\
             llm_requests_failed_total{{provider=\"{}\",model=\"{}\"}} {}\n\
             \n\
             # HELP llm_tokens_total Total number of tokens processed\n\
             # TYPE llm_tokens_total counter\n\
             llm_tokens_total{{provider=\"{}\",model=\"{}\"}} {}\n\
             \n\
             # HELP llm_cost_usd_total Total cost in USD\n\
             # TYPE llm_cost_usd_total counter\n\
             llm_cost_usd_total{{provider=\"{}\",model=\"{}\"}} {}\n\
             \n\
             # HELP llm_latency_ms_total Total latency in milliseconds\n\
             # TYPE llm_latency_ms_total counter\n\
             llm_latency_ms_total{{provider=\"{}\",model=\"{}\"}} {}\n\
             \n\
             # HELP llm_latency_avg_ms Average latency in milliseconds\n\
             # TYPE llm_latency_avg_ms gauge\n\
             llm_latency_avg_ms{{provider=\"{}\",model=\"{}\"}} {}\n\
             \n\
             # HELP llm_success_rate Success rate (0.0 to 1.0)\n\
             # TYPE llm_success_rate gauge\n\
             llm_success_rate{{provider=\"{}\",model=\"{}\"}} {}\n\
             \n\
             # HELP llm_cost_avg_per_request_usd Average cost per request in USD\n\
             # TYPE llm_cost_avg_per_request_usd gauge\n\
             llm_cost_avg_per_request_usd{{provider=\"{}\",model=\"{}\"}} {}\n",
            provider_name,
            model,
            self.total_requests,
            provider_name,
            model,
            self.successful_requests,
            provider_name,
            model,
            self.failed_requests,
            provider_name,
            model,
            self.total_tokens,
            provider_name,
            model,
            self.total_cost_usd,
            provider_name,
            model,
            self.total_latency_ms,
            provider_name,
            model,
            self.avg_latency_ms(),
            provider_name,
            model,
            self.success_rate(),
            provider_name,
            model,
            self.avg_cost_per_request(),
        )
    }
}

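/// Decorator that counts requests, successes, failures, tokens, and latency
/// for the wrapped provider behind a `std::sync::Mutex` (locked only briefly
/// after each call completes, never across an `.await`).
///
/// A minimal usage sketch; `MyProvider` is a hypothetical stand-in for any
/// concrete `LlmProvider` implementation:
///
/// ```ignore
/// let provider = MetricsProvider::new(MyProvider::default());
/// let _ = provider.complete(request).await;
/// println!("{}", provider.get_metrics().to_prometheus());
/// ```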
pub struct MetricsProvider<P> {
    inner: Arc<P>,
    metrics: Arc<std::sync::Mutex<Metrics>>,
}

impl<P> MetricsProvider<P> {
    pub fn new(inner: P) -> Self {
        Self {
            inner: Arc::new(inner),
            metrics: Arc::new(std::sync::Mutex::new(Metrics::new())),
        }
    }

    /// Returns a snapshot of the metrics accumulated so far.
    pub fn get_metrics(&self) -> Metrics {
        self.metrics.lock().unwrap().clone()
    }

    /// Resets all counters to zero.
    pub fn reset_metrics(&self) {
        let mut metrics = self.metrics.lock().unwrap();
        *metrics = Metrics::new();
    }
}

#[async_trait]
impl<P: LlmProvider> LlmProvider for MetricsProvider<P> {
    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
        let start = Instant::now();
        let result = self.inner.complete(request).await;
        let duration = start.elapsed();

        let mut metrics = self.metrics.lock().unwrap();
        metrics.total_requests += 1;
        metrics.total_latency_ms += duration.as_millis() as u64;

        match &result {
            Ok(response) => {
                metrics.successful_requests += 1;
                if let Some(usage) = &response.usage {
                    metrics.total_tokens += usage.total_tokens as u64;
                }
            }
            Err(_) => {
                metrics.failed_requests += 1;
            }
        }

        result
    }
}

#[async_trait]
impl<P: EmbeddingProvider> EmbeddingProvider for MetricsProvider<P> {
    async fn embed(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse> {
        let start = Instant::now();
        let result = self.inner.embed(request).await;
        let duration = start.elapsed();

        let mut metrics = self.metrics.lock().unwrap();
        metrics.total_requests += 1;
        metrics.total_latency_ms += duration.as_millis() as u64;

        match &result {
            Ok(response) => {
                metrics.successful_requests += 1;
                if let Some(usage) = &response.usage {
                    metrics.total_tokens += usage.total_tokens as u64;
                }
            }
            Err(_) => {
                metrics.failed_requests += 1;
            }
        }

        result
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_metrics_new() {
        let metrics = Metrics::new();
        assert_eq!(metrics.total_requests, 0);
        assert_eq!(metrics.successful_requests, 0);
        assert_eq!(metrics.failed_requests, 0);
    }

    #[test]
    fn test_metrics_avg_latency() {
        let mut metrics = Metrics::new();
        metrics.total_requests = 5;
        metrics.total_latency_ms = 1000;
        assert_eq!(metrics.avg_latency_ms(), 200.0);
    }

    #[test]
    fn test_metrics_success_rate() {
        let mut metrics = Metrics::new();
        metrics.total_requests = 10;
        metrics.successful_requests = 8;
        assert_eq!(metrics.success_rate(), 0.8);
    }

    #[test]
    fn test_metrics_avg_cost() {
        let mut metrics = Metrics::new();
        metrics.successful_requests = 4;
        metrics.total_cost_usd = 2.0;
        assert_eq!(metrics.avg_cost_per_request(), 0.5);
    }

    #[test]
    fn test_metrics_zero_division() {
        let metrics = Metrics::new();
        assert_eq!(metrics.avg_latency_ms(), 0.0);
        assert_eq!(metrics.success_rate(), 0.0);
        assert_eq!(metrics.avg_cost_per_request(), 0.0);
    }

    #[test]
    fn test_prometheus_export() {
        let metrics = Metrics {
            total_requests: 100,
            successful_requests: 95,
            failed_requests: 5,
            total_tokens: 50000,
            total_cost_usd: 2.5,
            total_latency_ms: 15000,
        };

        let prometheus = metrics.to_prometheus();

        assert!(prometheus.contains("llm_requests_total 100"));
        assert!(prometheus.contains("llm_requests_successful_total 95"));
        assert!(prometheus.contains("llm_requests_failed_total 5"));
        assert!(prometheus.contains("llm_tokens_total 50000"));
        assert!(prometheus.contains("llm_cost_usd_total 2.5"));
        assert!(prometheus.contains("llm_latency_ms_total 15000"));

        assert!(prometheus.contains("llm_latency_avg_ms 150"));
        assert!(prometheus.contains("llm_success_rate 0.95"));

        assert!(prometheus.contains("# HELP llm_requests_total"));
        assert!(prometheus.contains("# TYPE llm_requests_total counter"));
    }

    #[test]
    fn test_prometheus_export_with_labels() {
        let metrics = Metrics {
            total_requests: 50,
            successful_requests: 48,
            failed_requests: 2,
            total_tokens: 25000,
            total_cost_usd: 1.25,
            total_latency_ms: 7500,
        };

        let prometheus = metrics.to_prometheus_with_labels("openai", "gpt-4");

        assert!(prometheus.contains("llm_requests_total{provider=\"openai\",model=\"gpt-4\"} 50"));
        assert!(prometheus
            .contains("llm_requests_successful_total{provider=\"openai\",model=\"gpt-4\"} 48"));
        assert!(
            prometheus.contains("llm_requests_failed_total{provider=\"openai\",model=\"gpt-4\"} 2")
        );
        assert!(prometheus.contains("llm_tokens_total{provider=\"openai\",model=\"gpt-4\"} 25000"));

        assert!(prometheus.contains("# HELP llm_requests_total"));
        assert!(prometheus.contains("# TYPE llm_requests_total counter"));
    }

    #[test]
    fn test_prometheus_export_empty_metrics() {
        let metrics = Metrics::new();
        let prometheus = metrics.to_prometheus();

        assert!(prometheus.contains("llm_requests_total 0"));
        assert!(prometheus.contains("llm_latency_avg_ms 0"));
        assert!(prometheus.contains("llm_success_rate 0"));
    }
}