embacle/metrics.rs

// ABOUTME: Decorator wrapping any LlmProvider to measure latency, token usage, and call counts
// ABOUTME: Provides MetricsReport snapshots for cost and performance normalization
//
// SPDX-License-Identifier: Apache-2.0
// Copyright (c) 2026 dravr.ai

//! # Cost/Latency Normalization
//!
//! [`MetricsProvider`] is a decorator that wraps any `Box<dyn LlmProvider>`,
//! measuring per-call latency and token usage. Callers can retrieve a
//! [`MetricsReport`] snapshot at any time via [`MetricsProvider::report()`].
//!
//! Token estimation: when `TokenUsage` is not provided by the inner provider,
//! tokens are estimated at ~4 characters per token.
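//! For example, a 100-character ASCII reply with no reported usage counts as
//! 25 completion tokens.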

use std::sync::{Arc, Mutex};
use std::time::Instant;

use async_trait::async_trait;
use tracing::info;

use crate::types::{
    ChatRequest, ChatResponse, ChatStream, LlmCapabilities, LlmProvider, RunnerError,
};

/// Characters-per-token estimate used when the provider does not report usage
const CHARS_PER_TOKEN_ESTIMATE: u32 = 4;

/// Accumulated metrics state protected by a mutex
#[derive(Debug, Default)]
struct MetricsState {
    call_count: u64,
    total_latency_ms: u64,
    total_prompt_tokens: u64,
    total_completion_tokens: u64,
    total_tokens: u64,
    errors_count: u64,
}

/// Snapshot of accumulated metrics for a provider
#[derive(Debug, Clone)]
pub struct MetricsReport {
    /// Name of the wrapped provider
    pub provider_name: String,
    /// Total number of `complete()` calls
    pub call_count: u64,
    /// Total latency across all calls (milliseconds)
    pub total_latency_ms: u64,
    /// Average latency per call (milliseconds)
    pub avg_latency_ms: u64,
    /// Total prompt tokens consumed
    pub total_prompt_tokens: u64,
    /// Total completion tokens generated
    pub total_completion_tokens: u64,
    /// Total tokens (prompt + completion)
    pub total_tokens: u64,
    /// Number of calls that returned an error
    pub errors_count: u64,
}

/// Decorator wrapping any `Box<dyn LlmProvider>` to collect latency and token metrics.
///
/// # Usage
///
/// ```rust,no_run
/// # use embacle::metrics::MetricsProvider;
/// # use embacle::types::LlmProvider;
/// # fn example(provider: Box<dyn LlmProvider>) {
/// let metered = MetricsProvider::new(provider);
/// // ... use metered as LlmProvider ...
/// let report = metered.report();
/// println!("calls={} avg_latency={}ms", report.call_count, report.avg_latency_ms);
/// # }
/// ```
pub struct MetricsProvider {
    inner: Box<dyn LlmProvider>,
    state: Arc<Mutex<MetricsState>>,
}

impl MetricsProvider {
    /// Wrap a provider with metrics collection
    pub fn new(inner: Box<dyn LlmProvider>) -> Self {
        Self {
            inner,
            state: Arc::new(Mutex::new(MetricsState::default())),
        }
    }

    /// Return a snapshot of the current metrics
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn report(&self) -> MetricsReport {
        let state = self.state.lock().expect("metrics lock poisoned");
        let divisor = state.call_count.max(1);
        MetricsReport {
            provider_name: self.inner.name().to_owned(),
            call_count: state.call_count,
            total_latency_ms: state.total_latency_ms,
            avg_latency_ms: state.total_latency_ms / divisor,
            total_prompt_tokens: state.total_prompt_tokens,
            total_completion_tokens: state.total_completion_tokens,
            total_tokens: state.total_tokens,
            errors_count: state.errors_count,
        }
    }

    /// Reset all counters to zero
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
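    ///
    /// # Usage
    ///
    /// ```rust,no_run
    /// # use embacle::metrics::MetricsProvider;
    /// # use embacle::types::LlmProvider;
    /// # fn example(provider: Box<dyn LlmProvider>) {
    /// let metered = MetricsProvider::new(provider);
    /// // ... first measurement window ...
    /// let report = metered.report();
    /// println!("window: calls={}", report.call_count);
    /// metered.reset(); // counters for the next window start from zero
    /// # }
    /// ```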
    pub fn reset(&self) {
        let mut state = self.state.lock().expect("metrics lock poisoned");
        *state = MetricsState::default();
    }
}

/// Estimate token count from byte length (~4 chars per token; byte length and
/// character count coincide for ASCII text)
fn estimate_tokens(text: &str) -> u32 {
    #[allow(clippy::cast_possible_truncation)]
    let len = text.len() as u32;
    len / CHARS_PER_TOKEN_ESTIMATE.max(1)
}

#[async_trait]
impl LlmProvider for MetricsProvider {
    fn name(&self) -> &'static str {
        self.inner.name()
    }

    fn display_name(&self) -> &'static str {
        self.inner.display_name()
    }

    fn capabilities(&self) -> LlmCapabilities {
        self.inner.capabilities()
    }

    fn default_model(&self) -> &str {
        self.inner.default_model()
    }

    fn available_models(&self) -> &[String] {
        self.inner.available_models()
    }

    async fn complete(&self, request: &ChatRequest) -> Result<ChatResponse, RunnerError> {
        let start = Instant::now();
        let result = self.inner.complete(request).await;
        #[allow(clippy::cast_possible_truncation)]
        let elapsed_ms = start.elapsed().as_millis() as u64;

        let mut state = self.state.lock().expect("metrics lock poisoned");
        state.call_count += 1;
        state.total_latency_ms += elapsed_ms;

        if let Ok(response) = &result {
            let usage = response.usage.as_ref();
            let prompt_tokens = u64::from(
                usage.map_or_else(|| estimate_prompt_tokens(request), |u| u.prompt_tokens),
            );
            let completion_tokens = u64::from(usage.map_or_else(
                || estimate_tokens(&response.content),
                |u| u.completion_tokens,
            ));
            let total = prompt_tokens + completion_tokens;

            state.total_prompt_tokens += prompt_tokens;
            state.total_completion_tokens += completion_tokens;
            state.total_tokens += total;

            info!(
                provider = self.inner.name(),
                elapsed_ms, prompt_tokens, completion_tokens, "metrics: complete() succeeded"
            );
        } else {
            state.errors_count += 1;
            info!(
                provider = self.inner.name(),
                elapsed_ms, "metrics: complete() failed"
            );
        }

        drop(state);
        result
    }

    /// Delegate streaming directly; streaming calls are not metered at all
    /// (documented limitation)
    async fn complete_stream(&self, request: &ChatRequest) -> Result<ChatStream, RunnerError> {
        self.inner.complete_stream(request).await
    }

    async fn health_check(&self) -> Result<bool, RunnerError> {
        self.inner.health_check().await
    }
}

/// Estimate prompt tokens from the total byte length of request messages
fn estimate_prompt_tokens(request: &ChatRequest) -> u32 {
    let total_chars: usize = request.messages.iter().map(|m| m.content.len()).sum();
    #[allow(clippy::cast_possible_truncation)]
    let len = total_chars as u32;
    len / CHARS_PER_TOKEN_ESTIMATE.max(1)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{
        ChatMessage, ChatRequest, ChatResponse, ChatStream, LlmCapabilities, LlmProvider,
        RunnerError, TokenUsage,
    };
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicU32, Ordering};

    struct TestProvider {
        responses: Mutex<Vec<Result<ChatResponse, RunnerError>>>,
        call_count: AtomicU32,
    }

    impl TestProvider {
        fn new(responses: Vec<Result<ChatResponse, RunnerError>>) -> Self {
            Self {
                responses: Mutex::new(responses),
                call_count: AtomicU32::new(0),
            }
        }
    }

    #[async_trait]
    impl LlmProvider for TestProvider {
        fn name(&self) -> &'static str {
            "test"
        }
        fn display_name(&self) -> &'static str {
            "Test Provider"
        }
        fn capabilities(&self) -> LlmCapabilities {
            LlmCapabilities::text_only()
        }
        fn default_model(&self) -> &'static str {
            "test-model"
        }
        fn available_models(&self) -> &[String] {
            &[]
        }

        async fn complete(&self, _request: &ChatRequest) -> Result<ChatResponse, RunnerError> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            let mut responses = self.responses.lock().expect("test lock");
            if responses.is_empty() {
                Ok(ChatResponse {
                    content: "default".to_owned(),
                    model: "test-model".to_owned(),
                    usage: None,
                    finish_reason: Some("stop".to_owned()),
                    warnings: None,
                })
            } else {
                responses.remove(0)
            }
        }

        async fn complete_stream(&self, _request: &ChatRequest) -> Result<ChatStream, RunnerError> {
            Err(RunnerError::internal("streaming not supported in test"))
        }

        async fn health_check(&self) -> Result<bool, RunnerError> {
            Ok(true)
        }
    }

    #[test]
    fn fresh_report_is_zeroed() {
        let provider = TestProvider::new(vec![]);
        let metered = MetricsProvider::new(Box::new(provider));
        let report = metered.report();
        assert_eq!(report.call_count, 0);
        assert_eq!(report.total_latency_ms, 0);
        assert_eq!(report.avg_latency_ms, 0);
        assert_eq!(report.total_prompt_tokens, 0);
        assert_eq!(report.total_completion_tokens, 0);
        assert_eq!(report.total_tokens, 0);
        assert_eq!(report.errors_count, 0);
        assert_eq!(report.provider_name, "test");
    }

    #[tokio::test]
    async fn call_count_increments() {
        let provider = TestProvider::new(vec![
            Ok(ChatResponse {
                content: "hello world".to_owned(),
                model: "test-model".to_owned(),
                usage: Some(TokenUsage {
                    prompt_tokens: 10,
                    completion_tokens: 5,
                    total_tokens: 15,
                }),
                finish_reason: Some("stop".to_owned()),
                warnings: None,
            }),
            Ok(ChatResponse {
                content: "second".to_owned(),
                model: "test-model".to_owned(),
                usage: Some(TokenUsage {
                    prompt_tokens: 8,
                    completion_tokens: 3,
                    total_tokens: 11,
                }),
                finish_reason: Some("stop".to_owned()),
                warnings: None,
            }),
        ]);
        let metered = MetricsProvider::new(Box::new(provider));
        let request = ChatRequest::new(vec![ChatMessage::user("hi")]);

        metered.complete(&request).await.expect("first call");
        metered.complete(&request).await.expect("second call");

        let report = metered.report();
        assert_eq!(report.call_count, 2);
        assert_eq!(report.total_prompt_tokens, 18);
        assert_eq!(report.total_completion_tokens, 8);
        assert_eq!(report.total_tokens, 26);
        assert_eq!(report.errors_count, 0);
    }

    #[tokio::test]
    async fn errors_count_on_failure() {
        let provider = TestProvider::new(vec![Err(RunnerError::external_service("test", "boom"))]);
        let metered = MetricsProvider::new(Box::new(provider));
        let request = ChatRequest::new(vec![ChatMessage::user("hi")]);

        let result = metered.complete(&request).await;
        assert!(result.is_err());

        let report = metered.report();
        assert_eq!(report.call_count, 1);
        assert_eq!(report.errors_count, 1);
    }

    #[tokio::test]
    async fn token_estimation_when_no_usage() {
        let provider = TestProvider::new(vec![Ok(ChatResponse {
            content: "abcdefghijklmnop".to_owned(), // 16 chars => 4 tokens
            model: "test-model".to_owned(),
            usage: None,
            finish_reason: Some("stop".to_owned()),
            warnings: None,
        })]);
        let metered = MetricsProvider::new(Box::new(provider));
        let request = ChatRequest::new(vec![ChatMessage::user("12345678")]); // 8 chars => 2 tokens

        metered.complete(&request).await.expect("call");

        let report = metered.report();
        assert_eq!(report.total_prompt_tokens, 2);
        assert_eq!(report.total_completion_tokens, 4);
        assert_eq!(report.total_tokens, 6);
    }
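
    // The estimator divides byte length by CHARS_PER_TOKEN_ESTIMATE using
    // integer division, so inputs shorter than four bytes round down to zero.
    #[test]
    fn estimate_tokens_rounds_down() {
        assert_eq!(estimate_tokens(""), 0);
        assert_eq!(estimate_tokens("abc"), 0); // 3 bytes / 4 => 0
        assert_eq!(estimate_tokens("abcd"), 1); // 4 bytes / 4 => 1
        assert_eq!(estimate_tokens("abcdefg"), 1); // 7 bytes / 4 => 1
    }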

    #[test]
    fn div_by_zero_guard_on_avg_latency() {
        let provider = TestProvider::new(vec![]);
        let metered = MetricsProvider::new(Box::new(provider));
        // No calls made: call_count.max(1) clamps the divisor to 1, so avg is 0 / 1 = 0, not a panic
        let report = metered.report();
        assert_eq!(report.avg_latency_ms, 0);
    }

    #[tokio::test]
    async fn reset_zeroes_counters() {
        let provider = TestProvider::new(vec![Ok(ChatResponse {
            content: "hello".to_owned(),
            model: "test-model".to_owned(),
            usage: Some(TokenUsage {
                prompt_tokens: 5,
                completion_tokens: 2,
                total_tokens: 7,
            }),
            finish_reason: Some("stop".to_owned()),
            warnings: None,
        })]);
        let metered = MetricsProvider::new(Box::new(provider));
        let request = ChatRequest::new(vec![ChatMessage::user("hi")]);

        metered.complete(&request).await.expect("call");
        assert_eq!(metered.report().call_count, 1);

        metered.reset();
        let report = metered.report();
        assert_eq!(report.call_count, 0);
        assert_eq!(report.total_tokens, 0);
        assert_eq!(report.errors_count, 0);
    }
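
    // complete_stream delegates without recording metrics (the documented
    // limitation above), so a failed stream call must leave the report untouched.
    #[tokio::test]
    async fn streaming_is_not_metered() {
        let provider = TestProvider::new(vec![]);
        let metered = MetricsProvider::new(Box::new(provider));
        let request = ChatRequest::new(vec![ChatMessage::user("hi")]);

        let result = metered.complete_stream(&request).await;
        assert!(result.is_err());

        let report = metered.report();
        assert_eq!(report.call_count, 0);
        assert_eq!(report.errors_count, 0);
    }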
}