use std::sync::{Arc, Mutex};
use std::time::Instant;

use async_trait::async_trait;
use tracing::info;

use crate::types::{
    ChatRequest, ChatResponse, ChatStream, LlmCapabilities, LlmProvider, RunnerError,
};

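/// Fallback heuristic used when a provider reports no token usage:
/// roughly four characters of English text per token.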
const CHARS_PER_TOKEN_ESTIMATE: u32 = 4;

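/// Raw counters accumulated across calls, shared behind a mutex.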
#[derive(Debug, Default)]
struct MetricsState {
    call_count: u64,
    total_latency_ms: u64,
    total_prompt_tokens: u64,
    total_completion_tokens: u64,
    total_tokens: u64,
    errors_count: u64,
}

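/// Point-in-time snapshot of the accumulated metrics.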
#[derive(Debug, Clone)]
pub struct MetricsReport {
    /// `name()` of the wrapped provider.
    pub provider_name: String,
    /// Number of `complete()` calls observed, including failures.
    pub call_count: u64,
    /// Wall-clock latency summed across all calls, in milliseconds.
    pub total_latency_ms: u64,
    /// `total_latency_ms / call_count`, or 0 when no calls have been made.
    pub avg_latency_ms: u64,
    /// Prompt tokens summed across successful calls.
    pub total_prompt_tokens: u64,
    /// Completion tokens summed across successful calls.
    pub total_completion_tokens: u64,
    /// Prompt plus completion tokens across successful calls.
    pub total_tokens: u64,
    /// Number of calls that returned an error.
    pub errors_count: u64,
}

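/// Decorator that wraps any [`LlmProvider`] and records per-call metrics
/// without changing the provider's behavior. Only `complete()` is
/// instrumented; streaming calls pass through unrecorded.
///
/// A minimal usage sketch (`SomeProvider` is a placeholder for any
/// concrete `LlmProvider` implementation):
///
/// ```ignore
/// let metered = MetricsProvider::new(Box::new(SomeProvider::new()));
/// let response = metered.complete(&request).await?;
/// let report = metered.report();
/// println!("{} calls, avg {} ms", report.call_count, report.avg_latency_ms);
/// ```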
pub struct MetricsProvider {
    inner: Box<dyn LlmProvider>,
    state: Arc<Mutex<MetricsState>>,
}

impl MetricsProvider {
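    /// Wraps `inner` with all counters starting at zero.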
    pub fn new(inner: Box<dyn LlmProvider>) -> Self {
        Self {
            inner,
            state: Arc::new(Mutex::new(MetricsState::default())),
        }
    }

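    /// Returns a snapshot of the counters collected so far.
    ///
    /// # Panics
    ///
    /// Panics if the metrics mutex has been poisoned.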
    pub fn report(&self) -> MetricsReport {
        let state = self.state.lock().expect("metrics lock poisoned");
        let divisor = state.call_count.max(1);
        MetricsReport {
            provider_name: self.inner.name().to_owned(),
            call_count: state.call_count,
            total_latency_ms: state.total_latency_ms,
            avg_latency_ms: state.total_latency_ms / divisor,
            total_prompt_tokens: state.total_prompt_tokens,
            total_completion_tokens: state.total_completion_tokens,
            total_tokens: state.total_tokens,
            errors_count: state.errors_count,
        }
    }

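    /// Clears all counters back to zero.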
    pub fn reset(&self) {
        let mut state = self.state.lock().expect("metrics lock poisoned");
        *state = MetricsState::default();
    }
}

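/// Estimates a token count from text length; used only when the provider
/// response carries no usage data.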
fn estimate_tokens(text: &str) -> u32 {
    #[allow(clippy::cast_possible_truncation)]
    let len = text.len() as u32;
    len / CHARS_PER_TOKEN_ESTIMATE.max(1)
}

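// Metadata accessors delegate straight to the wrapped provider; only
// `complete()` below adds instrumentation.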
#[async_trait]
impl LlmProvider for MetricsProvider {
    fn name(&self) -> &'static str {
        self.inner.name()
    }

    fn display_name(&self) -> &'static str {
        self.inner.display_name()
    }

    fn capabilities(&self) -> LlmCapabilities {
        self.inner.capabilities()
    }

    fn default_model(&self) -> &str {
        self.inner.default_model()
    }

    fn available_models(&self) -> &[String] {
        self.inner.available_models()
    }

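    // Instrumented path: time the wrapped call, then record latency, call
    // count, token usage (reported by the provider, or estimated), and errors.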
    async fn complete(&self, request: &ChatRequest) -> Result<ChatResponse, RunnerError> {
        let start = Instant::now();
        let result = self.inner.complete(request).await;
        #[allow(clippy::cast_possible_truncation)]
        let elapsed_ms = start.elapsed().as_millis() as u64;

        let mut state = self.state.lock().expect("metrics lock poisoned");
        state.call_count += 1;
        state.total_latency_ms += elapsed_ms;

        if let Ok(response) = &result {
            let usage = response.usage.as_ref();
            let prompt_tokens = u64::from(
                usage.map_or_else(|| estimate_prompt_tokens(request), |u| u.prompt_tokens),
            );
            let completion_tokens = u64::from(usage.map_or_else(
                || estimate_tokens(&response.content),
                |u| u.completion_tokens,
            ));
            let total = prompt_tokens + completion_tokens;

            state.total_prompt_tokens += prompt_tokens;
            state.total_completion_tokens += completion_tokens;
            state.total_tokens += total;

            info!(
                provider = self.inner.name(),
                elapsed_ms, prompt_tokens, completion_tokens, "metrics: complete() succeeded"
            );
        } else {
            state.errors_count += 1;
            info!(
                provider = self.inner.name(),
                elapsed_ms, "metrics: complete() failed"
            );
        }

        drop(state);
        result
    }

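    // Streaming passes through unrecorded; chunk-level usage is not aggregated.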
    async fn complete_stream(&self, request: &ChatRequest) -> Result<ChatStream, RunnerError> {
        self.inner.complete_stream(request).await
    }

    async fn health_check(&self) -> Result<bool, RunnerError> {
        self.inner.health_check().await
    }
}

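/// Estimates prompt tokens by summing the character length of every request
/// message and applying the same chars-per-token heuristic.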
fn estimate_prompt_tokens(request: &ChatRequest) -> u32 {
    let total_chars: usize = request.messages.iter().map(|m| m.content.len()).sum();
    #[allow(clippy::cast_possible_truncation)]
    let len = total_chars as u32;
    len / CHARS_PER_TOKEN_ESTIMATE.max(1)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{
        ChatMessage, ChatRequest, ChatResponse, ChatStream, LlmCapabilities, LlmProvider,
        RunnerError, TokenUsage,
    };
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicU32, Ordering};

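    /// Scripted fake provider: pops queued responses in order and falls back
    /// to a canned success once the queue is empty.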
    struct TestProvider {
        responses: Mutex<Vec<Result<ChatResponse, RunnerError>>>,
        call_count: AtomicU32,
    }

    impl TestProvider {
        fn new(responses: Vec<Result<ChatResponse, RunnerError>>) -> Self {
            Self {
                responses: Mutex::new(responses),
                call_count: AtomicU32::new(0),
            }
        }
    }

    #[async_trait]
    impl LlmProvider for TestProvider {
        fn name(&self) -> &'static str {
            "test"
        }
        fn display_name(&self) -> &'static str {
            "Test Provider"
        }
        fn capabilities(&self) -> LlmCapabilities {
            LlmCapabilities::text_only()
        }
        fn default_model(&self) -> &str {
            "test-model"
        }
        fn available_models(&self) -> &[String] {
            &[]
        }

        async fn complete(&self, _request: &ChatRequest) -> Result<ChatResponse, RunnerError> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            let mut responses = self.responses.lock().expect("test lock");
            if responses.is_empty() {
                Ok(ChatResponse {
                    content: "default".to_owned(),
                    model: "test-model".to_owned(),
                    usage: None,
                    finish_reason: Some("stop".to_owned()),
                    warnings: None,
                })
            } else {
                responses.remove(0)
            }
        }

        async fn complete_stream(&self, _request: &ChatRequest) -> Result<ChatStream, RunnerError> {
            Err(RunnerError::internal("streaming not supported in test"))
        }

        async fn health_check(&self) -> Result<bool, RunnerError> {
            Ok(true)
        }
    }

    #[test]
    fn fresh_report_is_zeroed() {
        let provider = TestProvider::new(vec![]);
        let metered = MetricsProvider::new(Box::new(provider));
        let report = metered.report();
        assert_eq!(report.call_count, 0);
        assert_eq!(report.total_latency_ms, 0);
        assert_eq!(report.avg_latency_ms, 0);
        assert_eq!(report.total_prompt_tokens, 0);
        assert_eq!(report.total_completion_tokens, 0);
        assert_eq!(report.total_tokens, 0);
        assert_eq!(report.errors_count, 0);
        assert_eq!(report.provider_name, "test");
    }

    #[tokio::test]
    async fn call_count_increments() {
        let provider = TestProvider::new(vec![
            Ok(ChatResponse {
                content: "hello world".to_owned(),
                model: "test-model".to_owned(),
                usage: Some(TokenUsage {
                    prompt_tokens: 10,
                    completion_tokens: 5,
                    total_tokens: 15,
                }),
                finish_reason: Some("stop".to_owned()),
                warnings: None,
            }),
            Ok(ChatResponse {
                content: "second".to_owned(),
                model: "test-model".to_owned(),
                usage: Some(TokenUsage {
                    prompt_tokens: 8,
                    completion_tokens: 3,
                    total_tokens: 11,
                }),
                finish_reason: Some("stop".to_owned()),
                warnings: None,
            }),
        ]);
        let metered = MetricsProvider::new(Box::new(provider));
        let request = ChatRequest::new(vec![ChatMessage::user("hi")]);

        metered.complete(&request).await.expect("first call");
        metered.complete(&request).await.expect("second call");

        let report = metered.report();
        assert_eq!(report.call_count, 2);
        assert_eq!(report.total_prompt_tokens, 18);
        assert_eq!(report.total_completion_tokens, 8);
        assert_eq!(report.total_tokens, 26);
        assert_eq!(report.errors_count, 0);
    }

    #[tokio::test]
    async fn errors_count_on_failure() {
        let provider = TestProvider::new(vec![Err(RunnerError::external_service("test", "boom"))]);
        let metered = MetricsProvider::new(Box::new(provider));
        let request = ChatRequest::new(vec![ChatMessage::user("hi")]);

        let result = metered.complete(&request).await;
        assert!(result.is_err());

        let report = metered.report();
        assert_eq!(report.call_count, 1);
        assert_eq!(report.errors_count, 1);
    }

    #[tokio::test]
    async fn token_estimation_when_no_usage() {
        let provider = TestProvider::new(vec![Ok(ChatResponse {
            content: "abcdefghijklmnop".to_owned(),
            model: "test-model".to_owned(),
            usage: None,
            finish_reason: Some("stop".to_owned()),
            warnings: None,
        })]);
        let metered = MetricsProvider::new(Box::new(provider));
        // Prompt "12345678" is 8 chars -> 2 estimated tokens; the 16-char
        // response content -> 4 estimated tokens; 6 total.
        let request = ChatRequest::new(vec![ChatMessage::user("12345678")]);
        metered.complete(&request).await.expect("call");

        let report = metered.report();
        assert_eq!(report.total_prompt_tokens, 2);
        assert_eq!(report.total_completion_tokens, 4);
        assert_eq!(report.total_tokens, 6);
    }

    #[test]
    fn div_by_zero_guard_on_avg_latency() {
        let provider = TestProvider::new(vec![]);
        let metered = MetricsProvider::new(Box::new(provider));
        let report = metered.report();
        assert_eq!(report.avg_latency_ms, 0);
    }

    #[tokio::test]
    async fn reset_zeroes_counters() {
        let provider = TestProvider::new(vec![Ok(ChatResponse {
            content: "hello".to_owned(),
            model: "test-model".to_owned(),
            usage: Some(TokenUsage {
                prompt_tokens: 5,
                completion_tokens: 2,
                total_tokens: 7,
            }),
            finish_reason: Some("stop".to_owned()),
            warnings: None,
        })]);
        let metered = MetricsProvider::new(Box::new(provider));
        let request = ChatRequest::new(vec![ChatMessage::user("hi")]);

        metered.complete(&request).await.expect("call");
        assert_eq!(metered.report().call_count, 1);

        metered.reset();
        let report = metered.report();
        assert_eq!(report.call_count, 0);
        assert_eq!(report.total_tokens, 0);
        assert_eq!(report.errors_count, 0);
    }
}