Skip to main content

codetether_agent/provider/
metrics.rs

1//! Provider metrics wrapper
2//!
3//! Wraps any `Provider` to automatically record latency, throughput,
4//! and tokens-per-second via the global `PROVIDER_METRICS` registry.
5
6use super::{CompletionRequest, CompletionResponse, ModelInfo, Provider, StreamChunk, Usage};
7use crate::telemetry::{PROVIDER_METRICS, ProviderRequestRecord};
8use anyhow::Result;
9use async_trait::async_trait;
10use std::sync::Arc;
11
12/// A provider wrapper that instruments every call with performance metrics.
13pub struct MetricsProvider {
14    // The wrapped provider; completion calls are timed and recorded, then delegated here.
14    inner: Arc<dyn Provider>,
15}
16
17impl MetricsProvider {
18    /// Wrap a provider with automatic metrics collection
19    pub fn wrap(inner: Arc<dyn Provider>) -> Arc<Self> {
20        Arc::new(Self { inner })
21    }
22
23    fn record_request(&self, model: &str, latency_ms: u64, usage: &Usage, success: bool) {
24        let record = ProviderRequestRecord {
25            provider: self.inner.name().to_string(),
26            model: model.to_string(),
27            timestamp: chrono::Utc::now(),
28            prompt_tokens: usage.prompt_tokens as u64,
29            completion_tokens: usage.completion_tokens as u64,
30            input_tokens: usage.prompt_tokens as u64,
31            output_tokens: usage.completion_tokens as u64,
32            latency_ms,
33            ttft_ms: None, // non-streaming: no TTFT distinction
34            success,
35        };
36
37        tracing::info!(
38            provider = %record.provider,
39            model = %record.model,
40            latency_ms = record.latency_ms,
41            input_tokens = record.input_tokens,
42            output_tokens = record.output_tokens,
43            tps = format!("{:.1}", record.tokens_per_second()),
44            "Provider request completed"
45        );
46
47        PROVIDER_METRICS.record(record);
48    }
49}
50
51#[async_trait]
52impl Provider for MetricsProvider {
53    fn name(&self) -> &str {
54        self.inner.name()
55    }
56
57    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
58        self.inner.list_models().await
59    }
60
61    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
62        let model = request.model.clone();
63        let start = std::time::Instant::now();
64
65        match self.inner.complete(request).await {
66            Ok(response) => {
67                let latency_ms = start.elapsed().as_millis() as u64;
68                self.record_request(&model, latency_ms, &response.usage, true);
69                Ok(response)
70            }
71            Err(e) => {
72                let latency_ms = start.elapsed().as_millis() as u64;
73                self.record_request(&model, latency_ms, &Usage::default(), false);
74                Err(e)
75            }
76        }
77    }
78
79    async fn complete_stream(
80        &self,
81        request: CompletionRequest,
82    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
83        let model = request.model.clone();
84        let provider_name = self.inner.name().to_string();
85        let start = std::time::Instant::now();
86
87        match self.inner.complete_stream(request).await {
88            Ok(stream) => {
89                let ttft_ms = start.elapsed().as_millis() as u64;
90
91                // Wrap the stream to capture final usage from Done chunk
92                let stream =
93                    StreamMetricsWrapper::new(stream, provider_name, model, start, ttft_ms);
94
95                Ok(Box::pin(stream))
96            }
97            Err(e) => {
98                let latency_ms = start.elapsed().as_millis() as u64;
99                let record = ProviderRequestRecord {
100                    provider: provider_name,
101                    model,
102                    timestamp: chrono::Utc::now(),
103                    prompt_tokens: 0,
104                    completion_tokens: 0,
105                    input_tokens: 0,
106                    output_tokens: 0,
107                    latency_ms,
108                    ttft_ms: None,
109                    success: false,
110                };
111                PROVIDER_METRICS.record(record);
112                Err(e)
113            }
114        }
115    }
116}
117
118/// Wraps a stream to capture metrics when the `Done` chunk arrives
119struct StreamMetricsWrapper {
120    // The underlying provider stream being observed.
120    inner: futures::stream::BoxStream<'static, StreamChunk>,
121    // Provider/model names captured at stream creation for the metrics record.
121    provider: String,
122    model: String,
123    // Request start time; total latency is measured from here at the terminal chunk.
123    start: std::time::Instant,
124    // Time-to-first-token measured by the caller before wrapping (stream-ready time).
124    ttft_ms: u64,
125    // Guards against emitting more than one metrics record per stream.
125    recorded: bool,
126}
127
128impl StreamMetricsWrapper {
129    fn new(
130        inner: futures::stream::BoxStream<'static, StreamChunk>,
131        provider: String,
132        model: String,
133        start: std::time::Instant,
134        ttft_ms: u64,
135    ) -> Self {
136        Self {
137            inner,
138            provider,
139            model,
140            start,
141            ttft_ms,
142            recorded: false,
143        }
144    }
145}
146
// Passes chunks through unchanged while watching for a terminal chunk
// (`Done` or `Error`), at which point a single metrics record is emitted.
impl futures::Stream for StreamMetricsWrapper {
    type Item = StreamChunk;

    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        use std::task::Poll;

        // Poll the inner stream first; we match on `&result` below so the
        // chunk can be inspected without consuming it — it is returned to
        // the caller untouched at the end.
        let result = std::pin::Pin::new(&mut self.inner).poll_next(cx);

        match &result {
            // Successful completion: record latency, TTFT, and token usage.
            // `recorded` ensures at most one record even if polled again.
            Poll::Ready(Some(StreamChunk::Done { usage })) if !self.recorded => {
                self.recorded = true;
                let latency_ms = self.start.elapsed().as_millis() as u64;
                // Usage is optional on the Done chunk; default to zero tokens.
                let (input_tokens, output_tokens) = usage
                    .as_ref()
                    .map(|u| (u.prompt_tokens as u64, u.completion_tokens as u64))
                    .unwrap_or((0, 0));

                let record = ProviderRequestRecord {
                    provider: self.provider.clone(),
                    model: self.model.clone(),
                    timestamp: chrono::Utc::now(),
                    prompt_tokens: input_tokens,
                    completion_tokens: output_tokens,
                    input_tokens,
                    output_tokens,
                    latency_ms,
                    ttft_ms: Some(self.ttft_ms),
                    success: true,
                };

                tracing::info!(
                    provider = %record.provider,
                    model = %record.model,
                    latency_ms = record.latency_ms,
                    ttft_ms = record.ttft_ms,
                    input_tokens = record.input_tokens,
                    output_tokens = record.output_tokens,
                    tps = format!("{:.1}", record.tokens_per_second()),
                    "Provider streaming request completed"
                );

                PROVIDER_METRICS.record(record);
            }
            // Mid-stream error: record a failed request with zero tokens.
            // (No tracing log here, unlike the Done arm.)
            Poll::Ready(Some(StreamChunk::Error(_))) if !self.recorded => {
                self.recorded = true;
                let latency_ms = self.start.elapsed().as_millis() as u64;
                let record = ProviderRequestRecord {
                    provider: self.provider.clone(),
                    model: self.model.clone(),
                    timestamp: chrono::Utc::now(),
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    input_tokens: 0,
                    output_tokens: 0,
                    latency_ms,
                    ttft_ms: Some(self.ttft_ms),
                    success: false,
                };
                PROVIDER_METRICS.record(record);
            }
            // NOTE(review): a stream that terminates via `Poll::Ready(None)`
            // without ever yielding `Done`/`Error` is never recorded — such
            // requests silently vanish from metrics. Confirm whether inner
            // providers guarantee a terminal chunk.
            _ => {}
        }

        result
    }
}
215}