// codetether_agent/provider/metrics.rs
1//! Provider metrics wrapper
2//!
3//! Wraps any `Provider` to automatically record latency, throughput,
4//! and tokens-per-second via the global `PROVIDER_METRICS` registry.
5
6use super::{
7    CompletionRequest, CompletionResponse, EmbeddingRequest, EmbeddingResponse, ModelInfo,
8    Provider, StreamChunk, Usage,
9};
10use crate::telemetry::{PROVIDER_METRICS, ProviderRequestRecord};
11use anyhow::Result;
12use async_trait::async_trait;
13use std::sync::Arc;
14
/// A provider wrapper that instruments every call with performance metrics.
///
/// Created via [`MetricsProvider::wrap`]; every `complete` /
/// `complete_stream` call is timed and recorded to the global
/// `PROVIDER_METRICS` registry.
pub struct MetricsProvider {
    // The wrapped provider that actually performs the requests.
    inner: Arc<dyn Provider>,
}
19
20impl MetricsProvider {
21    /// Wrap a provider with automatic metrics collection
22    pub fn wrap(inner: Arc<dyn Provider>) -> Arc<Self> {
23        Arc::new(Self { inner })
24    }
25
26    async fn record_request(&self, model: &str, latency_ms: u64, usage: &Usage, success: bool) {
27        let record = ProviderRequestRecord {
28            provider: self.inner.name().to_string(),
29            model: model.to_string(),
30            timestamp: chrono::Utc::now(),
31            prompt_tokens: usage.prompt_tokens as u64,
32            completion_tokens: usage.completion_tokens as u64,
33            input_tokens: usage.prompt_tokens as u64,
34            output_tokens: usage.completion_tokens as u64,
35            latency_ms,
36            ttft_ms: None, // non-streaming: no TTFT distinction
37            success,
38        };
39
40        tracing::info!(
41            provider = %record.provider,
42            model = %record.model,
43            latency_ms = record.latency_ms,
44            input_tokens = record.input_tokens,
45            output_tokens = record.output_tokens,
46            tps = format!("{:.1}", record.tokens_per_second()),
47            "Provider request completed"
48        );
49
50        PROVIDER_METRICS.record(record).await;
51    }
52}
53
54#[async_trait]
55impl Provider for MetricsProvider {
56    fn name(&self) -> &str {
57        self.inner.name()
58    }
59
60    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
61        self.inner.list_models().await
62    }
63
64    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
65        let model = request.model.clone();
66        let start = std::time::Instant::now();
67
68        match self.inner.complete(request).await {
69            Ok(response) => {
70                let latency_ms = start.elapsed().as_millis() as u64;
71                self.record_request(&model, latency_ms, &response.usage, true)
72                    .await;
73                Ok(response)
74            }
75            Err(e) => {
76                let latency_ms = start.elapsed().as_millis() as u64;
77                self.record_request(&model, latency_ms, &Usage::default(), false)
78                    .await;
79                Err(e)
80            }
81        }
82    }
83
84    async fn complete_stream(
85        &self,
86        request: CompletionRequest,
87    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
88        let model = request.model.clone();
89        let provider_name = self.inner.name().to_string();
90        let start = std::time::Instant::now();
91
92        match self.inner.complete_stream(request).await {
93            Ok(stream) => {
94                let ttft_ms = start.elapsed().as_millis() as u64;
95
96                // Wrap the stream to capture final usage from Done chunk
97                let stream =
98                    StreamMetricsWrapper::new(stream, provider_name, model, start, ttft_ms);
99
100                Ok(Box::pin(stream))
101            }
102            Err(e) => {
103                let latency_ms = start.elapsed().as_millis() as u64;
104                let record = ProviderRequestRecord {
105                    provider: provider_name,
106                    model,
107                    timestamp: chrono::Utc::now(),
108                    prompt_tokens: 0,
109                    completion_tokens: 0,
110                    input_tokens: 0,
111                    output_tokens: 0,
112                    latency_ms,
113                    ttft_ms: None,
114                    success: false,
115                };
116                PROVIDER_METRICS.record(record).await;
117                Err(e)
118            }
119        }
120    }
121
122    async fn embed(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse> {
123        self.inner.embed(request).await
124    }
125}
126
127/// Wraps a stream to capture metrics when the `Done` chunk arrives
128struct StreamMetricsWrapper {
129    inner: futures::stream::BoxStream<'static, StreamChunk>,
130    provider: String,
131    model: String,
132    start: std::time::Instant,
133    ttft_ms: u64,
134    recorded: bool,
135}
136
137impl StreamMetricsWrapper {
138    fn new(
139        inner: futures::stream::BoxStream<'static, StreamChunk>,
140        provider: String,
141        model: String,
142        start: std::time::Instant,
143        ttft_ms: u64,
144    ) -> Self {
145        Self {
146            inner,
147            provider,
148            model,
149            start,
150            ttft_ms,
151            recorded: false,
152        }
153    }
154}
155
156impl futures::Stream for StreamMetricsWrapper {
157    type Item = StreamChunk;
158
159    fn poll_next(
160        mut self: std::pin::Pin<&mut Self>,
161        cx: &mut std::task::Context<'_>,
162    ) -> std::task::Poll<Option<Self::Item>> {
163        use std::task::Poll;
164
165        let result = std::pin::Pin::new(&mut self.inner).poll_next(cx);
166
167        match &result {
168            Poll::Ready(Some(StreamChunk::Done { usage })) if !self.recorded => {
169                self.recorded = true;
170                let latency_ms = self.start.elapsed().as_millis() as u64;
171                let (input_tokens, output_tokens) = usage
172                    .as_ref()
173                    .map(|u| (u.prompt_tokens as u64, u.completion_tokens as u64))
174                    .unwrap_or((0, 0));
175
176                let record = ProviderRequestRecord {
177                    provider: self.provider.clone(),
178                    model: self.model.clone(),
179                    timestamp: chrono::Utc::now(),
180                    prompt_tokens: input_tokens,
181                    completion_tokens: output_tokens,
182                    input_tokens,
183                    output_tokens,
184                    latency_ms,
185                    ttft_ms: Some(self.ttft_ms),
186                    success: true,
187                };
188
189                tracing::info!(
190                    provider = %record.provider,
191                    model = %record.model,
192                    latency_ms = record.latency_ms,
193                    ttft_ms = record.ttft_ms,
194                    input_tokens = record.input_tokens,
195                    output_tokens = record.output_tokens,
196                    tps = format!("{:.1}", record.tokens_per_second()),
197                    "Provider streaming request completed"
198                );
199
200                let metrics = PROVIDER_METRICS.clone();
201                tokio::spawn(async move { metrics.record(record).await });
202            }
203            Poll::Ready(Some(StreamChunk::Error(_))) if !self.recorded => {
204                self.recorded = true;
205                let latency_ms = self.start.elapsed().as_millis() as u64;
206                let record = ProviderRequestRecord {
207                    provider: self.provider.clone(),
208                    model: self.model.clone(),
209                    timestamp: chrono::Utc::now(),
210                    prompt_tokens: 0,
211                    completion_tokens: 0,
212                    input_tokens: 0,
213                    output_tokens: 0,
214                    latency_ms,
215                    ttft_ms: Some(self.ttft_ms),
216                    success: false,
217                };
218                let metrics = PROVIDER_METRICS.clone();
219                tokio::spawn(async move { metrics.record(record).await });
220            }
221            _ => {}
222        }
223
224        result
225    }
226}