use std::sync::{Arc, Mutex};
use std::time::Instant;

use crate::observer::ExecutionObserver;
use crate::{CustomMetrics, LlmQuery, QueryId};
6
/// One prompt/response exchange captured for the execution transcript.
struct TranscriptEntry {
    // Stringified query id, used later to match a fed response to this entry.
    query_id: String,
    // The user prompt sent to the LLM.
    prompt: String,
    // Optional system prompt that accompanied the query.
    system: Option<String>,
    // Set when the response is fed back; `None` while the query is pending.
    response: Option<String>,
}
14
15impl TranscriptEntry {
16 fn to_json(&self) -> serde_json::Value {
17 serde_json::json!({
18 "query_id": self.query_id,
19 "prompt": self.prompt,
20 "system": self.system,
21 "response": self.response,
22 })
23 }
24}
25
/// Counters and transcript collected automatically by the observer;
/// no user code writes these directly.
pub(crate) struct AutoMetrics {
    // Wall-clock start of the execution (set at construction).
    started_at: Instant,
    // Set by on_completed/on_failed/on_cancelled; `None` while still running.
    ended_at: Option<Instant>,
    // Total number of LLM queries issued across all pauses.
    llm_calls: u64,
    // Number of times execution paused to wait on LLM queries.
    pauses: u64,
    // Number of completed pause/resume cycles.
    rounds: u64,
    // Byte lengths (via str::len) of prompts plus system prompts.
    total_prompt_chars: u64,
    // Byte lengths (via str::len) of all responses fed back.
    total_response_chars: u64,
    // Ordered prompt/response pairs; exposed separately, not in to_json().
    transcript: Vec<TranscriptEntry>,
}
37
38impl AutoMetrics {
39 fn new() -> Self {
40 Self {
41 started_at: Instant::now(),
42 ended_at: None,
43 llm_calls: 0,
44 pauses: 0,
45 rounds: 0,
46 total_prompt_chars: 0,
47 total_response_chars: 0,
48 transcript: Vec::new(),
49 }
50 }
51
52 fn to_json(&self) -> serde_json::Value {
53 let elapsed_ms = self
54 .ended_at
55 .map(|end| end.duration_since(self.started_at).as_millis() as u64)
56 .unwrap_or_else(|| self.started_at.elapsed().as_millis() as u64);
57
58 serde_json::json!({
59 "elapsed_ms": elapsed_ms,
60 "llm_calls": self.llm_calls,
61 "pauses": self.pauses,
62 "rounds": self.rounds,
63 "total_prompt_chars": self.total_prompt_chars,
64 "total_response_chars": self.total_response_chars,
65 })
66 }
67}
68
/// Public handle bundling automatic metrics (observer-driven) and custom
/// metrics (user-recorded). Both halves are behind `Arc<Mutex<_>>` so they
/// can be shared with observers and user code independently.
pub struct ExecutionMetrics {
    // Updated by MetricsObserver callbacks.
    auto: Arc<Mutex<AutoMetrics>>,
    // Updated by user code through custom_handle().
    custom: Arc<Mutex<CustomMetrics>>,
}
74
75impl ExecutionMetrics {
76 pub fn new() -> Self {
77 Self {
78 auto: Arc::new(Mutex::new(AutoMetrics::new())),
79 custom: Arc::new(Mutex::new(CustomMetrics::new())),
80 }
81 }
82
83 pub fn to_json(&self) -> serde_json::Value {
85 let auto_json = self
86 .auto
87 .lock()
88 .map(|m| m.to_json())
89 .unwrap_or(serde_json::Value::Null);
90
91 let custom_json = self
92 .custom
93 .lock()
94 .map(|m| m.to_json())
95 .unwrap_or(serde_json::Value::Null);
96
97 serde_json::json!({
98 "auto": auto_json,
99 "custom": custom_json,
100 })
101 }
102
103 pub fn transcript_to_json(&self) -> Vec<serde_json::Value> {
105 self.auto
106 .lock()
107 .map(|m| m.transcript.iter().map(|e| e.to_json()).collect())
108 .unwrap_or_default()
109 }
110
111 pub fn custom_handle(&self) -> Arc<Mutex<CustomMetrics>> {
113 Arc::clone(&self.custom)
114 }
115
116 pub fn create_observer(&self) -> MetricsObserver {
117 MetricsObserver::new(Arc::clone(&self.auto))
118 }
119}
120
121impl Default for ExecutionMetrics {
122 fn default() -> Self {
123 Self::new()
124 }
125}
126
/// Observer that translates execution lifecycle callbacks into updates on
/// a shared [`AutoMetrics`] instance.
pub struct MetricsObserver {
    // Shared with the owning ExecutionMetrics.
    auto: Arc<Mutex<AutoMetrics>>,
}
131
132impl MetricsObserver {
133 pub(crate) fn new(auto: Arc<Mutex<AutoMetrics>>) -> Self {
134 Self { auto }
135 }
136}
137
138impl ExecutionObserver for MetricsObserver {
139 fn on_paused(&self, queries: &[LlmQuery]) {
140 if let Ok(mut m) = self.auto.lock() {
141 m.pauses += 1;
142 m.llm_calls += queries.len() as u64;
143 for q in queries {
144 m.total_prompt_chars += q.prompt.len() as u64;
145 if let Some(ref sys) = q.system {
146 m.total_prompt_chars += sys.len() as u64;
147 }
148 m.transcript.push(TranscriptEntry {
149 query_id: q.id.as_str().to_string(),
150 prompt: q.prompt.clone(),
151 system: q.system.clone(),
152 response: None,
153 });
154 }
155 }
156 }
157
158 fn on_response_fed(&self, query_id: &QueryId, response: &str) {
159 if let Ok(mut m) = self.auto.lock() {
160 m.total_response_chars += response.len() as u64;
161 if let Some(entry) = m
163 .transcript
164 .iter_mut()
165 .rev()
166 .find(|e| e.query_id == query_id.as_str())
167 {
168 entry.response = Some(response.to_string());
169 }
170 }
171 }
172
173 fn on_resumed(&self) {
174 if let Ok(mut m) = self.auto.lock() {
175 m.rounds += 1;
176 }
177 }
178
179 fn on_completed(&self, _result: &serde_json::Value) {
180 if let Ok(mut m) = self.auto.lock() {
181 m.ended_at = Some(Instant::now());
182 }
183 }
184
185 fn on_failed(&self, _error: &str) {
186 if let Ok(mut m) = self.auto.lock() {
187 m.ended_at = Some(Instant::now());
188 }
189 }
190
191 fn on_cancelled(&self) {
192 if let Ok(mut m) = self.auto.lock() {
193 m.ended_at = Some(Instant::now());
194 }
195 }
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{LlmQuery, QueryId};

    #[test]
    fn metrics_to_json_has_auto_and_custom() {
        let metrics = ExecutionMetrics::new();
        let json = metrics.to_json();
        assert!(json.get("auto").is_some());
        assert!(json.get("custom").is_some());
    }

    #[test]
    fn custom_handle_shares_state() {
        let metrics = ExecutionMetrics::new();
        let handle = metrics.custom_handle();

        handle
            .lock()
            .unwrap()
            .record("key".into(), serde_json::json!("value"));

        let json = metrics.to_json();
        let custom = json.get("custom").unwrap();
        assert_eq!(custom.get("key").unwrap(), "value");
    }

    #[test]
    fn observer_updates_auto_metrics() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();

        let queries = vec![LlmQuery {
            id: QueryId::batch(0),
            prompt: "test".into(),
            system: None,
            max_tokens: 100,
        }];

        observer.on_paused(&queries);
        observer.on_completed(&serde_json::json!(null));

        let json = metrics.to_json();
        let auto = json.get("auto").unwrap();
        assert_eq!(auto.get("llm_calls").unwrap(), 1);
        assert_eq!(auto.get("pauses").unwrap(), 1);
        assert_eq!(auto.get("rounds").unwrap(), 0);
        assert_eq!(auto.get("total_prompt_chars").unwrap(), 4);
        assert_eq!(auto.get("total_response_chars").unwrap(), 0);
    }

    #[test]
    fn observer_tracks_prompt_and_response_chars() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();

        let queries = vec![
            LlmQuery {
                id: QueryId::batch(0),
                prompt: "hello".into(),
                system: Some("sys".into()),
                max_tokens: 100,
            },
            LlmQuery {
                id: QueryId::batch(1),
                prompt: "world".into(),
                system: None,
                max_tokens: 100,
            },
        ];

        observer.on_paused(&queries);
        observer.on_response_fed(&QueryId::batch(0), &"x".repeat(42));
        observer.on_response_fed(&QueryId::batch(1), &"y".repeat(58));
        observer.on_resumed();
        observer.on_completed(&serde_json::json!(null));

        let json = metrics.to_json();
        let auto = json.get("auto").unwrap();
        // "hello" (5) + "sys" (3) + "world" (5) = 13
        assert_eq!(auto.get("total_prompt_chars").unwrap(), 13);
        assert_eq!(auto.get("total_response_chars").unwrap(), 100);
        assert_eq!(auto.get("rounds").unwrap(), 1);
    }

    #[test]
    fn observer_tracks_multiple_rounds() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();

        let q = vec![LlmQuery {
            id: QueryId::single(),
            prompt: "p".into(),
            system: None,
            max_tokens: 10,
        }];

        observer.on_paused(&q);
        observer.on_response_fed(&QueryId::single(), &"x".repeat(10));
        observer.on_resumed();
        observer.on_paused(&q);
        observer.on_response_fed(&QueryId::single(), &"y".repeat(20));
        observer.on_resumed();
        observer.on_paused(&q);
        observer.on_response_fed(&QueryId::single(), &"z".repeat(30));
        observer.on_resumed();

        observer.on_completed(&serde_json::json!(null));

        let json = metrics.to_json();
        let auto = json.get("auto").unwrap();
        assert_eq!(auto.get("rounds").unwrap(), 3);
        assert_eq!(auto.get("pauses").unwrap(), 3);
        assert_eq!(auto.get("llm_calls").unwrap(), 3);
        assert_eq!(auto.get("total_prompt_chars").unwrap(), 3);
        assert_eq!(auto.get("total_response_chars").unwrap(), 60);
    }

    #[test]
    fn transcript_records_prompt_response_pairs() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();

        let queries = vec![LlmQuery {
            id: QueryId::single(),
            prompt: "What is 2+2?".into(),
            system: Some("You are a calculator.".into()),
            max_tokens: 50,
        }];

        observer.on_paused(&queries);
        observer.on_response_fed(&QueryId::single(), "4");
        observer.on_resumed();
        observer.on_completed(&serde_json::json!(null));

        let transcript = metrics.transcript_to_json();
        assert_eq!(transcript.len(), 1);
        assert_eq!(transcript[0]["query_id"], "q-0");
        assert_eq!(transcript[0]["prompt"], "What is 2+2?");
        assert_eq!(transcript[0]["system"], "You are a calculator.");
        assert_eq!(transcript[0]["response"], "4");
    }

    #[test]
    fn transcript_not_in_stats() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();
        observer.on_paused(&[LlmQuery {
            id: QueryId::single(),
            prompt: "p".into(),
            system: None,
            max_tokens: 10,
        }]);
        observer.on_response_fed(&QueryId::single(), "r");
        observer.on_resumed();
        observer.on_completed(&serde_json::json!(null));

        let json = metrics.to_json();
        assert!(json["auto"].get("transcript").is_none());
    }

    #[test]
    fn transcript_multi_round() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();

        observer.on_paused(&[LlmQuery {
            id: QueryId::single(),
            prompt: "step1".into(),
            system: None,
            max_tokens: 100,
        }]);
        observer.on_response_fed(&QueryId::single(), "answer1");
        observer.on_resumed();

        observer.on_paused(&[LlmQuery {
            id: QueryId::single(),
            prompt: "step2".into(),
            system: Some("expert".into()),
            max_tokens: 100,
        }]);
        observer.on_response_fed(&QueryId::single(), "answer2");
        observer.on_resumed();

        observer.on_completed(&serde_json::json!(null));

        let transcript = metrics.transcript_to_json();
        assert_eq!(transcript.len(), 2);

        assert_eq!(transcript[0]["prompt"], "step1");
        assert!(transcript[0]["system"].is_null());
        assert_eq!(transcript[0]["response"], "answer1");

        assert_eq!(transcript[1]["prompt"], "step2");
        assert_eq!(transcript[1]["system"], "expert");
        assert_eq!(transcript[1]["response"], "answer2");
    }

    #[test]
    fn transcript_batch_queries() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();

        let queries = vec![
            LlmQuery {
                id: QueryId::batch(0),
                prompt: "q0".into(),
                system: None,
                max_tokens: 50,
            },
            LlmQuery {
                id: QueryId::batch(1),
                prompt: "q1".into(),
                system: None,
                max_tokens: 50,
            },
        ];

        observer.on_paused(&queries);
        observer.on_response_fed(&QueryId::batch(0), "r0");
        observer.on_response_fed(&QueryId::batch(1), "r1");
        observer.on_resumed();
        observer.on_completed(&serde_json::json!(null));

        let transcript = metrics.transcript_to_json();
        assert_eq!(transcript.len(), 2);
        assert_eq!(transcript[0]["query_id"], "q-0");
        assert_eq!(transcript[0]["response"], "r0");
        assert_eq!(transcript[1]["query_id"], "q-1");
        assert_eq!(transcript[1]["response"], "r1");
    }
}