1use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::Instant;
6
7use async_trait::async_trait;
8use synaptic_core::{CallbackHandler, RunEvent, SynapticError};
9use tokio::sync::RwLock;
10
/// Point-in-time copy of the counters collected by [`MetricsCallback`].
///
/// Produced by `MetricsCallback::snapshot`; it owns its data, so it can be
/// inspected after the internal lock has been released.
#[derive(Debug, Clone, Default)]
pub struct MetricsSnapshot {
    /// Number of model calls completed (incremented on `RunEvent::AfterMessage`).
    pub total_model_calls: u64,
    /// Number of tool calls completed (incremented on `RunEvent::AfterToolCall`).
    pub total_tool_calls: u64,
    /// Number of `RunEvent::RunFailed` events observed.
    pub total_errors: u64,
    /// Sum of input tokens reported through `MetricsCallback::record_tokens`.
    pub total_input_tokens: u64,
    /// Sum of output tokens reported through `MetricsCallback::record_tokens`.
    pub total_output_tokens: u64,
    /// Accumulated model latency divided by `total_model_calls`, in
    /// milliseconds; `0.0` when no model call has been recorded.
    pub avg_model_latency_ms: f64,
    /// Per-tool counters, keyed by tool name.
    pub per_tool: HashMap<String, ToolMetrics>,
}
29
/// Counters tracked for a single tool name.
#[derive(Debug, Clone, Default)]
pub struct ToolMetrics {
    /// Number of completed calls to this tool.
    pub calls: u64,
    /// Number of failed calls to this tool.
    // NOTE(review): nothing in the visible callback code increments this
    // field, so it currently stays at 0 — confirm whether it is meant to be
    // populated elsewhere.
    pub errors: u64,
    /// Sum of the measured latencies of this tool's calls, in milliseconds.
    pub total_latency_ms: u64,
}
37
/// Mutable state shared behind the callback's lock: the running totals plus
/// the start instants of calls that have begun but not yet finished.
struct MetricsState {
    // Running totals exposed through `MetricsSnapshot`.
    total_model_calls: u64,
    total_tool_calls: u64,
    total_errors: u64,
    total_input_tokens: u64,
    total_output_tokens: u64,
    // Sum of model-call latencies in milliseconds; divided by
    // `total_model_calls` to obtain the average reported in a snapshot.
    total_model_latency_ms: u64,
    // Per-tool counters, keyed by tool name.
    per_tool: HashMap<String, ToolMetrics>,
    // Start instants of in-flight model calls, keyed by run id.
    model_start_times: HashMap<String, Instant>,
    // Start instants of in-flight tool calls, keyed by "{run_id}:{tool_name}".
    tool_start_times: HashMap<String, Instant>,
}
51
/// Callback handler that accumulates call counts, token totals and latencies
/// from `RunEvent`s.
///
/// All state lives behind an `Arc<RwLock<..>>`; read it with `snapshot()`.
pub struct MetricsCallback {
    // Shared, lock-protected counters and in-flight start times.
    state: Arc<RwLock<MetricsState>>,
}
62
63impl MetricsCallback {
64 pub fn new() -> Self {
65 Self {
66 state: Arc::new(RwLock::new(MetricsState {
67 total_model_calls: 0,
68 total_tool_calls: 0,
69 total_errors: 0,
70 total_input_tokens: 0,
71 total_output_tokens: 0,
72 total_model_latency_ms: 0,
73 per_tool: HashMap::new(),
74 model_start_times: HashMap::new(),
75 tool_start_times: HashMap::new(),
76 })),
77 }
78 }
79
80 pub async fn snapshot(&self) -> MetricsSnapshot {
82 let state = self.state.read().await;
83 let avg = if state.total_model_calls > 0 {
84 state.total_model_latency_ms as f64 / state.total_model_calls as f64
85 } else {
86 0.0
87 };
88 MetricsSnapshot {
89 total_model_calls: state.total_model_calls,
90 total_tool_calls: state.total_tool_calls,
91 total_errors: state.total_errors,
92 total_input_tokens: state.total_input_tokens,
93 total_output_tokens: state.total_output_tokens,
94 avg_model_latency_ms: avg,
95 per_tool: state.per_tool.clone(),
96 }
97 }
98
99 pub async fn record_tokens(&self, input_tokens: u64, output_tokens: u64) {
104 let mut state = self.state.write().await;
105 state.total_input_tokens += input_tokens;
106 state.total_output_tokens += output_tokens;
107 }
108
109 pub async fn reset(&self) {
111 let mut state = self.state.write().await;
112 state.total_model_calls = 0;
113 state.total_tool_calls = 0;
114 state.total_errors = 0;
115 state.total_input_tokens = 0;
116 state.total_output_tokens = 0;
117 state.total_model_latency_ms = 0;
118 state.per_tool.clear();
119 }
120}
121
122impl Default for MetricsCallback {
123 fn default() -> Self {
124 Self::new()
125 }
126}
127
128#[async_trait]
129impl CallbackHandler for MetricsCallback {
130 async fn on_event(&self, event: RunEvent) -> Result<(), SynapticError> {
131 let mut state = self.state.write().await;
132 match event {
133 RunEvent::BeforeMessage { run_id, .. } => {
135 state.model_start_times.insert(run_id, Instant::now());
136 }
137 RunEvent::AfterMessage { run_id, .. } => {
138 let elapsed = state
139 .model_start_times
140 .remove(&run_id)
141 .map(|start| start.elapsed().as_millis() as u64)
142 .unwrap_or(0);
143
144 state.total_model_calls += 1;
145 state.total_model_latency_ms += elapsed;
146 }
147 RunEvent::BeforeToolCall {
149 run_id, tool_name, ..
150 } => {
151 let key = format!("{}:{}", run_id, tool_name);
152 state.tool_start_times.insert(key, Instant::now());
153 }
154 RunEvent::AfterToolCall {
155 run_id, tool_name, ..
156 } => {
157 let key = format!("{}:{}", run_id, tool_name);
158 let elapsed = state
159 .tool_start_times
160 .remove(&key)
161 .map(|start| start.elapsed().as_millis() as u64)
162 .unwrap_or(0);
163
164 state.total_tool_calls += 1;
165 let tool_metrics = state.per_tool.entry(tool_name).or_default();
166 tool_metrics.calls += 1;
167 tool_metrics.total_latency_ms += elapsed;
168 }
169 RunEvent::RunFailed { .. } => {
171 state.total_errors += 1;
172 }
173 _ => {}
174 }
175 Ok(())
176 }
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    /// Delivers `event` to `cb`, panicking if the handler reports an error.
    async fn emit(cb: &MetricsCallback, event: RunEvent) {
        cb.on_event(event).await.unwrap();
    }

    /// Emits a matching before/after pair for one tool invocation.
    async fn run_tool(cb: &MetricsCallback, run_id: &str, tool_name: &str) {
        emit(
            cb,
            RunEvent::BeforeToolCall {
                run_id: run_id.to_string(),
                tool_name: tool_name.to_string(),
                arguments: "{}".to_string(),
            },
        )
        .await;
        emit(
            cb,
            RunEvent::AfterToolCall {
                run_id: run_id.to_string(),
                tool_name: tool_name.to_string(),
                result: "ok".to_string(),
            },
        )
        .await;
    }

    /// Emits a `RunFailed` event for `run_id`.
    async fn fail_run(cb: &MetricsCallback, run_id: &str) {
        emit(
            cb,
            RunEvent::RunFailed {
                run_id: run_id.to_string(),
                error: "oops".to_string(),
            },
        )
        .await;
    }

    // A fresh callback reports zero for every counter.
    #[tokio::test]
    async fn test_metrics_snapshot_empty() {
        let metrics = MetricsCallback::new().snapshot().await;
        assert_eq!(metrics.total_model_calls, 0);
        assert_eq!(metrics.total_tool_calls, 0);
        assert_eq!(metrics.total_errors, 0);
    }

    // One before/after message pair counts as one model call with a
    // non-trivial latency; tokens are recorded separately.
    #[tokio::test]
    async fn test_metrics_model_call() {
        let cb = MetricsCallback::new();
        emit(
            &cb,
            RunEvent::BeforeMessage {
                run_id: "r1".to_string(),
                message_count: 3,
            },
        )
        .await;

        // Let some measurable time pass between the two events.
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        emit(
            &cb,
            RunEvent::AfterMessage {
                run_id: "r1".to_string(),
                response_length: 42,
            },
        )
        .await;

        cb.record_tokens(10, 5).await;

        let metrics = cb.snapshot().await;
        assert_eq!(metrics.total_model_calls, 1);
        assert_eq!(metrics.total_input_tokens, 10);
        assert_eq!(metrics.total_output_tokens, 5);
        assert!(metrics.avg_model_latency_ms >= 5.0);
    }

    // A single tool invocation shows up in both the total and per-tool counts.
    #[tokio::test]
    async fn test_metrics_tool_call() {
        let cb = MetricsCallback::new();
        run_tool(&cb, "r1", "read_file").await;

        let metrics = cb.snapshot().await;
        assert_eq!(metrics.total_tool_calls, 1);
        assert!(metrics.per_tool.contains_key("read_file"));
        assert_eq!(metrics.per_tool["read_file"].calls, 1);
    }

    // A failed run increments the error counter.
    #[tokio::test]
    async fn test_metrics_error_counting() {
        let cb = MetricsCallback::new();
        fail_run(&cb, "r1").await;

        assert_eq!(cb.snapshot().await.total_errors, 1);
    }

    // `reset` brings the error counter back to zero.
    #[tokio::test]
    async fn test_metrics_reset() {
        let cb = MetricsCallback::new();
        fail_run(&cb, "r1").await;

        assert_eq!(cb.snapshot().await.total_errors, 1);
        cb.reset().await;
        assert_eq!(cb.snapshot().await.total_errors, 0);
    }

    // Calls are tallied per tool name across different runs.
    #[tokio::test]
    async fn test_metrics_multiple_tools() {
        let cb = MetricsCallback::new();

        // Two `read_file` calls in runs "r0" and "r1".
        for i in 0..2 {
            let run_id = format!("r{}", i);
            run_tool(&cb, &run_id, "read_file").await;
        }

        // One `write_file` call in run "r2".
        run_tool(&cb, "r2", "write_file").await;

        let metrics = cb.snapshot().await;
        assert_eq!(metrics.total_tool_calls, 3);
        assert_eq!(metrics.per_tool["read_file"].calls, 2);
        assert_eq!(metrics.per_tool["write_file"].calls, 1);
    }
}