1use std::sync::{Arc, Mutex};
2use std::time::Instant;
3
4use crate::observer::ExecutionObserver;
5use crate::{CustomMetrics, LlmQuery, QueryId};
6
/// One prompt/response exchange captured for the execution transcript.
///
/// An entry is created when a query is observed at pause time with
/// `response: None`; the response is filled in later when it is fed back.
struct TranscriptEntry {
    /// Stringified id of the query this entry belongs to.
    query_id: String,
    /// The prompt text sent to the LLM.
    prompt: String,
    /// Optional system prompt accompanying the query, if one was set.
    system: Option<String>,
    /// The LLM's response; `None` until the response is fed back.
    response: Option<String>,
}
14
15impl TranscriptEntry {
16 fn to_json(&self) -> serde_json::Value {
17 serde_json::json!({
18 "query_id": self.query_id,
19 "prompt": self.prompt,
20 "system": self.system,
21 "response": self.response,
22 })
23 }
24}
25
/// Metrics collected automatically from execution events, plus the full
/// prompt/response transcript.
///
/// Mutated by `MetricsObserver` under a `Mutex`; summarized (without the
/// transcript) by `to_json`.
pub(crate) struct AutoMetrics {
    /// When metric collection began (set at construction).
    started_at: Instant,
    /// When execution completed/failed/was cancelled; `None` while running.
    ended_at: Option<Instant>,
    /// Total number of LLM queries observed across all pauses.
    llm_calls: u64,
    /// Number of times execution paused awaiting LLM responses.
    pauses: u64,
    /// Number of completed pause/resume rounds.
    rounds: u64,
    /// Accumulated size of all prompts (including system prompts).
    total_prompt_chars: u64,
    /// Accumulated size of all fed responses.
    total_response_chars: u64,
    /// Ordered record of every prompt/response pair observed.
    transcript: Vec<TranscriptEntry>,
}
37
38impl AutoMetrics {
39 fn new() -> Self {
40 Self {
41 started_at: Instant::now(),
42 ended_at: None,
43 llm_calls: 0,
44 pauses: 0,
45 rounds: 0,
46 total_prompt_chars: 0,
47 total_response_chars: 0,
48 transcript: Vec::new(),
49 }
50 }
51
52 fn to_json(&self) -> serde_json::Value {
53 let elapsed_ms = self
54 .ended_at
55 .map(|end| end.duration_since(self.started_at).as_millis() as u64)
56 .unwrap_or_else(|| self.started_at.elapsed().as_millis() as u64);
57
58 serde_json::json!({
59 "elapsed_ms": elapsed_ms,
60 "llm_calls": self.llm_calls,
61 "pauses": self.pauses,
62 "rounds": self.rounds,
63 "total_prompt_chars": self.total_prompt_chars,
64 "total_response_chars": self.total_response_chars,
65 })
66 }
67}
68
/// Aggregates automatically-collected execution metrics and user-recorded
/// custom metrics behind shared, thread-safe handles.
///
/// `Arc` wrapping lets observers (`create_observer`) and user code
/// (`custom_handle`) write concurrently while this owner snapshots via
/// `to_json`.
pub struct ExecutionMetrics {
    /// Metrics driven by execution events (see `MetricsObserver`).
    auto: Arc<Mutex<AutoMetrics>>,
    /// Free-form metrics recorded directly by user code.
    custom: Arc<Mutex<CustomMetrics>>,
}
74
75impl ExecutionMetrics {
76 pub fn new() -> Self {
77 Self {
78 auto: Arc::new(Mutex::new(AutoMetrics::new())),
79 custom: Arc::new(Mutex::new(CustomMetrics::new())),
80 }
81 }
82
83 pub fn to_json(&self) -> serde_json::Value {
85 let auto_json = self
86 .auto
87 .lock()
88 .map(|m| m.to_json())
89 .unwrap_or(serde_json::Value::Null);
90
91 let custom_json = self
92 .custom
93 .lock()
94 .map(|m| m.to_json())
95 .unwrap_or(serde_json::Value::Null);
96
97 serde_json::json!({
98 "auto": auto_json,
99 "custom": custom_json,
100 })
101 }
102
103 pub fn transcript_to_json(&self) -> Vec<serde_json::Value> {
105 self.auto
106 .lock()
107 .map(|m| m.transcript.iter().map(|e| e.to_json()).collect())
108 .unwrap_or_default()
109 }
110
111 pub fn custom_handle(&self) -> Arc<Mutex<CustomMetrics>> {
113 Arc::clone(&self.custom)
114 }
115
116 pub fn create_observer(&self) -> MetricsObserver {
117 MetricsObserver::new(Arc::clone(&self.auto))
118 }
119}
120
121impl Default for ExecutionMetrics {
122 fn default() -> Self {
123 Self::new()
124 }
125}
126
/// An `ExecutionObserver` implementation that records execution events
/// into the `AutoMetrics` shared with an `ExecutionMetrics` instance.
pub struct MetricsObserver {
    /// Shared handle to the same store that `ExecutionMetrics` snapshots.
    auto: Arc<Mutex<AutoMetrics>>,
}
131
132impl MetricsObserver {
133 pub(crate) fn new(auto: Arc<Mutex<AutoMetrics>>) -> Self {
134 Self { auto }
135 }
136}
137
138impl ExecutionObserver for MetricsObserver {
139 fn on_paused(&self, queries: &[LlmQuery]) {
140 if let Ok(mut m) = self.auto.lock() {
141 m.pauses += 1;
142 m.llm_calls += queries.len() as u64;
143 for q in queries {
144 m.total_prompt_chars += q.prompt.len() as u64;
145 if let Some(ref sys) = q.system {
146 m.total_prompt_chars += sys.len() as u64;
147 }
148 m.transcript.push(TranscriptEntry {
149 query_id: q.id.as_str().to_string(),
150 prompt: q.prompt.clone(),
151 system: q.system.clone(),
152 response: None,
153 });
154 }
155 }
156 }
157
158 fn on_response_fed(&self, query_id: &QueryId, response: &str) {
159 if let Ok(mut m) = self.auto.lock() {
160 m.total_response_chars += response.len() as u64;
161 if let Some(entry) = m
163 .transcript
164 .iter_mut()
165 .rev()
166 .find(|e| e.query_id == query_id.as_str())
167 {
168 entry.response = Some(response.to_string());
169 }
170 }
171 }
172
173 fn on_resumed(&self) {
174 if let Ok(mut m) = self.auto.lock() {
175 m.rounds += 1;
176 }
177 }
178
179 fn on_completed(&self, _result: &serde_json::Value) {
180 if let Ok(mut m) = self.auto.lock() {
181 m.ended_at = Some(Instant::now());
182 }
183 }
184
185 fn on_failed(&self, _error: &str) {
186 if let Ok(mut m) = self.auto.lock() {
187 m.ended_at = Some(Instant::now());
188 }
189 }
190
191 fn on_cancelled(&self) {
192 if let Ok(mut m) = self.auto.lock() {
193 m.ended_at = Some(Instant::now());
194 }
195 }
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{LlmQuery, QueryId};

    // The top-level snapshot always carries both sections, even before
    // anything has been recorded.
    #[test]
    fn metrics_to_json_has_auto_and_custom() {
        let metrics = ExecutionMetrics::new();
        let json = metrics.to_json();
        assert!(json.get("auto").is_some());
        assert!(json.get("custom").is_some());
    }

    // Values recorded through the Arc handle must be visible in the
    // owner's snapshot (the handle shares state, not a copy).
    #[test]
    fn custom_handle_shares_state() {
        let metrics = ExecutionMetrics::new();
        let handle = metrics.custom_handle();

        handle
            .lock()
            .unwrap()
            .record("key".into(), serde_json::json!("value"));

        let json = metrics.to_json();
        let custom = json.get("custom").unwrap();
        assert_eq!(custom.get("key").unwrap(), "value");
    }

    // One pause with one query bumps calls/pauses/prompt size; rounds
    // stays 0 (no resume) and no response was fed.
    #[test]
    fn observer_updates_auto_metrics() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();

        let queries = vec![LlmQuery {
            id: QueryId::batch(0),
            prompt: "test".into(),
            system: None,
            max_tokens: 100,
            grounded: false,
            underspecified: false,
        }];

        observer.on_paused(&queries);
        observer.on_completed(&serde_json::json!(null));

        let json = metrics.to_json();
        let auto = json.get("auto").unwrap();
        assert_eq!(auto.get("llm_calls").unwrap(), 1);
        assert_eq!(auto.get("pauses").unwrap(), 1);
        assert_eq!(auto.get("rounds").unwrap(), 0);
        assert_eq!(auto.get("total_prompt_chars").unwrap(), 4); assert_eq!(auto.get("total_response_chars").unwrap(), 0);
    }

    // Prompt chars sum over prompt + system across all queries
    // ("hello"+"sys"+"world" = 13); responses sum to 42+58 = 100.
    #[test]
    fn observer_tracks_prompt_and_response_chars() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();

        let queries = vec![
            LlmQuery {
                id: QueryId::batch(0),
                prompt: "hello".into(), system: Some("sys".into()), max_tokens: 100,
                grounded: false,
                underspecified: false,
            },
            LlmQuery {
                id: QueryId::batch(1),
                prompt: "world".into(), system: None,
                max_tokens: 100,
                grounded: false,
                underspecified: false,
            },
        ];

        observer.on_paused(&queries);
        observer.on_response_fed(&QueryId::batch(0), &"x".repeat(42));
        observer.on_response_fed(&QueryId::batch(1), &"y".repeat(58));
        observer.on_resumed();
        observer.on_completed(&serde_json::json!(null));

        let json = metrics.to_json();
        let auto = json.get("auto").unwrap();
        assert_eq!(auto.get("total_prompt_chars").unwrap(), 13); assert_eq!(auto.get("total_response_chars").unwrap(), 100); assert_eq!(auto.get("rounds").unwrap(), 1);
    }

    // Three full pause/feed/resume cycles: counters accumulate per round
    // (3 rounds, 3 pauses, 3 calls, 1+1+1 prompt chars, 10+20+30 response chars).
    #[test]
    fn observer_tracks_multiple_rounds() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();

        let q = vec![LlmQuery {
            id: QueryId::single(),
            prompt: "p".into(),
            system: None,
            max_tokens: 10,
            grounded: false,
            underspecified: false,
        }];

        observer.on_paused(&q);
        observer.on_response_fed(&QueryId::single(), &"x".repeat(10));
        observer.on_resumed();
        observer.on_paused(&q);
        observer.on_response_fed(&QueryId::single(), &"y".repeat(20));
        observer.on_resumed();
        observer.on_paused(&q);
        observer.on_response_fed(&QueryId::single(), &"z".repeat(30));
        observer.on_resumed();

        observer.on_completed(&serde_json::json!(null));

        let json = metrics.to_json();
        let auto = json.get("auto").unwrap();
        assert_eq!(auto.get("rounds").unwrap(), 3);
        assert_eq!(auto.get("pauses").unwrap(), 3);
        assert_eq!(auto.get("llm_calls").unwrap(), 3);
        assert_eq!(auto.get("total_prompt_chars").unwrap(), 3); assert_eq!(auto.get("total_response_chars").unwrap(), 60); }

    // The transcript keeps the full prompt/system/response text keyed by
    // the stringified query id ("q-0" for a single query).
    #[test]
    fn transcript_records_prompt_response_pairs() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();

        let queries = vec![LlmQuery {
            id: QueryId::single(),
            prompt: "What is 2+2?".into(),
            system: Some("You are a calculator.".into()),
            max_tokens: 50,
            grounded: false,
            underspecified: false,
        }];

        observer.on_paused(&queries);
        observer.on_response_fed(&QueryId::single(), "4");
        observer.on_resumed();
        observer.on_completed(&serde_json::json!(null));

        let transcript = metrics.transcript_to_json();
        assert_eq!(transcript.len(), 1);
        assert_eq!(transcript[0]["query_id"], "q-0");
        assert_eq!(transcript[0]["prompt"], "What is 2+2?");
        assert_eq!(transcript[0]["system"], "You are a calculator.");
        assert_eq!(transcript[0]["response"], "4");
    }

    // The stats snapshot (`to_json`) deliberately excludes the transcript;
    // it is only exposed via `transcript_to_json`.
    #[test]
    fn transcript_not_in_stats() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();
        observer.on_paused(&[LlmQuery {
            id: QueryId::single(),
            prompt: "p".into(),
            system: None,
            max_tokens: 10,
            grounded: false,
            underspecified: false,
        }]);
        observer.on_response_fed(&QueryId::single(), "r");
        observer.on_resumed();
        observer.on_completed(&serde_json::json!(null));

        let json = metrics.to_json();
        assert!(json["auto"].get("transcript").is_none());
    }

    // With the same query id across rounds, each response must pair with
    // its own round's transcript entry (back-to-front search).
    #[test]
    fn transcript_multi_round() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();

        observer.on_paused(&[LlmQuery {
            id: QueryId::single(),
            prompt: "step1".into(),
            system: None,
            max_tokens: 100,
            grounded: false,
            underspecified: false,
        }]);
        observer.on_response_fed(&QueryId::single(), "answer1");
        observer.on_resumed();

        observer.on_paused(&[LlmQuery {
            id: QueryId::single(),
            prompt: "step2".into(),
            system: Some("expert".into()),
            max_tokens: 100,
            grounded: false,
            underspecified: false,
        }]);
        observer.on_response_fed(&QueryId::single(), "answer2");
        observer.on_resumed();

        observer.on_completed(&serde_json::json!(null));

        let transcript = metrics.transcript_to_json();
        assert_eq!(transcript.len(), 2);

        assert_eq!(transcript[0]["prompt"], "step1");
        assert!(transcript[0]["system"].is_null());
        assert_eq!(transcript[0]["response"], "answer1");

        assert_eq!(transcript[1]["prompt"], "step2");
        assert_eq!(transcript[1]["system"], "expert");
        assert_eq!(transcript[1]["response"], "answer2");
    }

    // Batch queries in one pause get distinct ids ("q-0", "q-1") and each
    // response lands on its own entry.
    #[test]
    fn transcript_batch_queries() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();

        let queries = vec![
            LlmQuery {
                id: QueryId::batch(0),
                prompt: "q0".into(),
                system: None,
                max_tokens: 50,
                grounded: false,
                underspecified: false,
            },
            LlmQuery {
                id: QueryId::batch(1),
                prompt: "q1".into(),
                system: None,
                max_tokens: 50,
                grounded: false,
                underspecified: false,
            },
        ];

        observer.on_paused(&queries);
        observer.on_response_fed(&QueryId::batch(0), "r0");
        observer.on_response_fed(&QueryId::batch(1), "r1");
        observer.on_resumed();
        observer.on_completed(&serde_json::json!(null));

        let transcript = metrics.transcript_to_json();
        assert_eq!(transcript.len(), 2);
        assert_eq!(transcript[0]["query_id"], "q-0");
        assert_eq!(transcript[0]["response"], "r0");
        assert_eq!(transcript[1]["query_id"], "q-1");
        assert_eq!(transcript[1]["response"], "r1");
    }
}