// synaptic_callbacks/metrics.rs
//! Metrics collection callback — records latency, token counts, and errors.

use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;

use async_trait::async_trait;
use synaptic_core::{CallbackHandler, RunEvent, SynapticError};
use tokio::sync::RwLock;
/// Snapshot of collected metrics.
///
/// Returned by [`MetricsCallback::snapshot`]; all values are cumulative
/// since the callback was created or last reset.
#[derive(Debug, Clone, Default)]
pub struct MetricsSnapshot {
    /// Total completed model calls.
    pub total_model_calls: u64,
    /// Total completed tool calls.
    pub total_tool_calls: u64,
    /// Total failed runs observed.
    pub total_errors: u64,
    /// Total input tokens across all requests (fed in via `record_tokens`).
    pub total_input_tokens: u64,
    /// Total output tokens across all requests (fed in via `record_tokens`).
    pub total_output_tokens: u64,
    /// Average model call latency in milliseconds (`0.0` when no calls yet).
    pub avg_model_latency_ms: f64,
    /// Per-tool metrics, keyed by tool name.
    pub per_tool: HashMap<String, ToolMetrics>,
}
29
/// Metrics for a specific tool, keyed by tool name in
/// [`MetricsSnapshot::per_tool`].
#[derive(Debug, Clone, Default)]
pub struct ToolMetrics {
    // Number of completed calls to this tool.
    pub calls: u64,
    // Errors attributed to this tool.
    // NOTE(review): nothing in this file ever increments `errors` — confirm
    // whether the event handler is expected to populate it.
    pub errors: u64,
    // Sum of call latencies in milliseconds; divide by `calls` for the mean.
    pub total_latency_ms: u64,
}
37
// Mutable metrics state, shared behind `Arc<RwLock<_>>` by `MetricsCallback`.
struct MetricsState {
    // Completed model calls (incremented on AfterMessage).
    total_model_calls: u64,
    // Completed tool calls (incremented on AfterToolCall).
    total_tool_calls: u64,
    // Failed runs (incremented on RunFailed).
    total_errors: u64,
    // Input tokens recorded via `record_tokens`.
    total_input_tokens: u64,
    // Output tokens recorded via `record_tokens`.
    total_output_tokens: u64,
    // Sum of model-call latencies; divided by `total_model_calls` to
    // produce the snapshot's average.
    total_model_latency_ms: u64,
    // Per-tool call/latency breakdown, keyed by tool name.
    per_tool: HashMap<String, ToolMetrics>,
    /// Pending model call start times (keyed by run_id).
    model_start_times: HashMap<String, Instant>,
    /// Pending tool call start times (keyed by run_id + tool_name).
    tool_start_times: HashMap<String, Instant>,
}
51
/// Callback that collects latency, token, and error metrics.
///
/// Uses the standard `RunEvent` lifecycle events to measure model and tool
/// call latency, accumulate token usage, and count errors.
///
/// Call [`snapshot()`](MetricsCallback::snapshot) at any time to read the
/// current metrics, or [`reset()`](MetricsCallback::reset) to zero them out.
///
/// State lives behind `Arc<RwLock<_>>`, so the callback can be shared and
/// used from concurrent tasks.
pub struct MetricsCallback {
    // All counters and pending timers, guarded by one async RwLock.
    state: Arc<RwLock<MetricsState>>,
}
62
63impl MetricsCallback {
64    pub fn new() -> Self {
65        Self {
66            state: Arc::new(RwLock::new(MetricsState {
67                total_model_calls: 0,
68                total_tool_calls: 0,
69                total_errors: 0,
70                total_input_tokens: 0,
71                total_output_tokens: 0,
72                total_model_latency_ms: 0,
73                per_tool: HashMap::new(),
74                model_start_times: HashMap::new(),
75                tool_start_times: HashMap::new(),
76            })),
77        }
78    }
79
80    /// Take a snapshot of the current metrics.
81    pub async fn snapshot(&self) -> MetricsSnapshot {
82        let state = self.state.read().await;
83        let avg = if state.total_model_calls > 0 {
84            state.total_model_latency_ms as f64 / state.total_model_calls as f64
85        } else {
86            0.0
87        };
88        MetricsSnapshot {
89            total_model_calls: state.total_model_calls,
90            total_tool_calls: state.total_tool_calls,
91            total_errors: state.total_errors,
92            total_input_tokens: state.total_input_tokens,
93            total_output_tokens: state.total_output_tokens,
94            avg_model_latency_ms: avg,
95            per_tool: state.per_tool.clone(),
96        }
97    }
98
99    /// Record token usage externally (e.g. from a `ChatResponse`).
100    ///
101    /// This allows callers that have access to the actual `TokenUsage` from
102    /// model responses to feed it into the metrics.
103    pub async fn record_tokens(&self, input_tokens: u64, output_tokens: u64) {
104        let mut state = self.state.write().await;
105        state.total_input_tokens += input_tokens;
106        state.total_output_tokens += output_tokens;
107    }
108
109    /// Reset all metrics.
110    pub async fn reset(&self) {
111        let mut state = self.state.write().await;
112        state.total_model_calls = 0;
113        state.total_tool_calls = 0;
114        state.total_errors = 0;
115        state.total_input_tokens = 0;
116        state.total_output_tokens = 0;
117        state.total_model_latency_ms = 0;
118        state.per_tool.clear();
119    }
120}
121
122impl Default for MetricsCallback {
123    fn default() -> Self {
124        Self::new()
125    }
126}
127
#[async_trait]
impl CallbackHandler for MetricsCallback {
    // Fold a single run lifecycle event into the shared metrics state.
    // Metrics collection itself never fails: this always returns Ok(()).
    async fn on_event(&self, event: RunEvent) -> Result<(), SynapticError> {
        // One short-lived write lock per event.
        let mut state = self.state.write().await;
        match event {
            // Model call lifecycle: BeforeMessage → AfterMessage
            RunEvent::BeforeMessage { run_id, .. } => {
                // NOTE(review): if the matching AfterMessage never arrives,
                // this entry is never removed — confirm upstream always
                // pairs these events (reset() is the only other cleanup).
                state.model_start_times.insert(run_id, Instant::now());
            }
            RunEvent::AfterMessage { run_id, .. } => {
                // A missing start entry (After without Before) still counts
                // the call, with zero latency, rather than erroring out.
                let elapsed = state
                    .model_start_times
                    .remove(&run_id)
                    .map(|start| start.elapsed().as_millis() as u64)
                    .unwrap_or(0);

                state.total_model_calls += 1;
                state.total_model_latency_ms += elapsed;
            }
            // Tool call lifecycle: BeforeToolCall → AfterToolCall
            RunEvent::BeforeToolCall {
                run_id, tool_name, ..
            } => {
                // Composite key: one run may invoke several tools.
                let key = format!("{}:{}", run_id, tool_name);
                state.tool_start_times.insert(key, Instant::now());
            }
            RunEvent::AfterToolCall {
                run_id, tool_name, ..
            } => {
                let key = format!("{}:{}", run_id, tool_name);
                let elapsed = state
                    .tool_start_times
                    .remove(&key)
                    .map(|start| start.elapsed().as_millis() as u64)
                    .unwrap_or(0);

                state.total_tool_calls += 1;
                // NOTE(review): `ToolMetrics::errors` is never incremented
                // here — confirm whether tool failures should be tracked
                // per tool (e.g. by inspecting the event's result).
                let tool_metrics = state.per_tool.entry(tool_name).or_default();
                tool_metrics.calls += 1;
                tool_metrics.total_latency_ms += elapsed;
            }
            // Error tracking
            RunEvent::RunFailed { .. } => {
                state.total_errors += 1;
            }
            // Other lifecycle events carry no metrics of interest.
            _ => {}
        }
        Ok(())
    }
}
178
179#[cfg(test)]
180mod tests {
181    use super::*;
182
183    #[tokio::test]
184    async fn test_metrics_snapshot_empty() {
185        let cb = MetricsCallback::new();
186        let snap = cb.snapshot().await;
187        assert_eq!(snap.total_model_calls, 0);
188        assert_eq!(snap.total_tool_calls, 0);
189        assert_eq!(snap.total_errors, 0);
190    }
191
192    #[tokio::test]
193    async fn test_metrics_model_call() {
194        let cb = MetricsCallback::new();
195        cb.on_event(RunEvent::BeforeMessage {
196            run_id: "r1".to_string(),
197            message_count: 3,
198        })
199        .await
200        .unwrap();
201
202        // Simulate some latency
203        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
204
205        cb.on_event(RunEvent::AfterMessage {
206            run_id: "r1".to_string(),
207            response_length: 42,
208        })
209        .await
210        .unwrap();
211
212        // Also record token usage externally
213        cb.record_tokens(10, 5).await;
214
215        let snap = cb.snapshot().await;
216        assert_eq!(snap.total_model_calls, 1);
217        assert_eq!(snap.total_input_tokens, 10);
218        assert_eq!(snap.total_output_tokens, 5);
219        assert!(snap.avg_model_latency_ms >= 5.0); // should be >= 10ms but allow some slack
220    }
221
222    #[tokio::test]
223    async fn test_metrics_tool_call() {
224        let cb = MetricsCallback::new();
225        cb.on_event(RunEvent::BeforeToolCall {
226            run_id: "r1".to_string(),
227            tool_name: "read_file".to_string(),
228            arguments: "{}".to_string(),
229        })
230        .await
231        .unwrap();
232
233        cb.on_event(RunEvent::AfterToolCall {
234            run_id: "r1".to_string(),
235            tool_name: "read_file".to_string(),
236            result: "ok".to_string(),
237        })
238        .await
239        .unwrap();
240
241        let snap = cb.snapshot().await;
242        assert_eq!(snap.total_tool_calls, 1);
243        assert!(snap.per_tool.contains_key("read_file"));
244        assert_eq!(snap.per_tool["read_file"].calls, 1);
245    }
246
247    #[tokio::test]
248    async fn test_metrics_error_counting() {
249        let cb = MetricsCallback::new();
250        cb.on_event(RunEvent::RunFailed {
251            run_id: "r1".to_string(),
252            error: "oops".to_string(),
253        })
254        .await
255        .unwrap();
256
257        assert_eq!(cb.snapshot().await.total_errors, 1);
258    }
259
260    #[tokio::test]
261    async fn test_metrics_reset() {
262        let cb = MetricsCallback::new();
263        cb.on_event(RunEvent::RunFailed {
264            run_id: "r1".to_string(),
265            error: "oops".to_string(),
266        })
267        .await
268        .unwrap();
269
270        assert_eq!(cb.snapshot().await.total_errors, 1);
271        cb.reset().await;
272        assert_eq!(cb.snapshot().await.total_errors, 0);
273    }
274
275    #[tokio::test]
276    async fn test_metrics_multiple_tools() {
277        let cb = MetricsCallback::new();
278
279        // Two calls to "read_file"
280        for i in 0..2 {
281            let run_id = format!("r{}", i);
282            cb.on_event(RunEvent::BeforeToolCall {
283                run_id: run_id.clone(),
284                tool_name: "read_file".to_string(),
285                arguments: "{}".to_string(),
286            })
287            .await
288            .unwrap();
289            cb.on_event(RunEvent::AfterToolCall {
290                run_id,
291                tool_name: "read_file".to_string(),
292                result: "ok".to_string(),
293            })
294            .await
295            .unwrap();
296        }
297
298        // One call to "write_file"
299        cb.on_event(RunEvent::BeforeToolCall {
300            run_id: "r2".to_string(),
301            tool_name: "write_file".to_string(),
302            arguments: "{}".to_string(),
303        })
304        .await
305        .unwrap();
306        cb.on_event(RunEvent::AfterToolCall {
307            run_id: "r2".to_string(),
308            tool_name: "write_file".to_string(),
309            result: "ok".to_string(),
310        })
311        .await
312        .unwrap();
313
314        let snap = cb.snapshot().await;
315        assert_eq!(snap.total_tool_calls, 3);
316        assert_eq!(snap.per_tool["read_file"].calls, 2);
317        assert_eq!(snap.per_tool["write_file"].calls, 1);
318    }
319}