1use std::sync::{Arc, Mutex};
2use std::time::Instant;
3
4use crate::budget::Budget;
5use crate::observer::ExecutionObserver;
6use crate::progress::ProgressInfo;
7use crate::tokens::{estimate_tokens, TokenCount, TokenSource};
8use crate::{BudgetHandle, CustomMetrics, CustomMetricsHandle, LlmQuery, ProgressHandle, QueryId};
9
/// One prompt/response pair captured for the session transcript.
struct TranscriptEntry {
    /// Identifier of the query this entry records (string form of `QueryId`).
    query_id: String,
    /// The user prompt sent to the LLM.
    prompt: String,
    /// Optional system prompt accompanying the query.
    system: Option<String>,
    /// Filled in by `on_response_fed`; `None` until a response arrives.
    response: Option<String>,
}
19
20impl TranscriptEntry {
21 fn to_json(&self) -> serde_json::Value {
22 serde_json::json!({
23 "query_id": self.query_id,
24 "prompt": self.prompt,
25 "system": self.system,
26 "response": self.response,
27 })
28 }
29}
30
/// Automatically collected metrics for one execution session.
///
/// Updated by `MetricsObserver` callbacks; serialized via `to_json` /
/// `snapshot`. Shared behind `Arc<Mutex<_>>` by `ExecutionMetrics`.
pub(crate) struct SessionStatus {
    /// When the session began (set at construction).
    started_at: Instant,
    /// Set once on completion/failure/cancellation; freezes `elapsed_ms`.
    ended_at: Option<Instant>,
    /// Number of LLM queries issued (incremented per query on pause).
    pub(crate) llm_calls: u64,
    /// Number of `on_paused` events observed.
    pauses: u64,
    /// Number of `on_resumed` events observed (completed pause/resume cycles).
    rounds: u64,
    /// Cumulative prompt + system text length (counted via `str::len`).
    total_prompt_chars: u64,
    /// Cumulative response text length (counted via `str::len`).
    total_response_chars: u64,
    /// Estimated token count accumulated over prompts and system messages.
    prompt_tokens: TokenCount,
    /// Estimated token count accumulated over responses.
    response_tokens: TokenCount,
    /// Ordered prompt/response pairs; responses attached as they arrive.
    transcript: Vec<TranscriptEntry>,
    /// Optional budget enforced by `check_budget`.
    pub(crate) budget: Option<Budget>,
    /// Optional progress info included in snapshots.
    pub(crate) progress: Option<ProgressInfo>,
}
98
99impl SessionStatus {
100 fn new() -> Self {
101 Self {
102 started_at: Instant::now(),
103 ended_at: None,
104 llm_calls: 0,
105 pauses: 0,
106 rounds: 0,
107 total_prompt_chars: 0,
108 total_response_chars: 0,
109 prompt_tokens: TokenCount::new(TokenSource::Estimated),
110 response_tokens: TokenCount::new(TokenSource::Estimated),
111 transcript: Vec::new(),
112 budget: None,
113 progress: None,
114 }
115 }
116
117 fn elapsed_ms(&self) -> u64 {
119 self.ended_at
120 .map(|end| end.duration_since(self.started_at).as_millis() as u64)
121 .unwrap_or_else(|| self.started_at.elapsed().as_millis() as u64)
122 }
123
124 fn to_json(&self) -> serde_json::Value {
125 let total_tokens = TokenCount {
126 tokens: self.prompt_tokens.tokens + self.response_tokens.tokens,
127 source: self
128 .prompt_tokens
129 .source
130 .weaker(self.response_tokens.source),
131 };
132 let mut json = serde_json::json!({
133 "elapsed_ms": self.elapsed_ms(),
134 "llm_calls": self.llm_calls,
135 "pauses": self.pauses,
136 "rounds": self.rounds,
137 "total_prompt_chars": self.total_prompt_chars,
138 "total_response_chars": self.total_response_chars,
139 "prompt_tokens": self.prompt_tokens.to_json(),
140 "response_tokens": self.response_tokens.to_json(),
141 "total_tokens": total_tokens.to_json(),
142 });
143 if let Some(ref b) = self.budget {
144 json["budget"] = b.to_json();
145 }
146 json
147 }
148
149 pub(crate) fn check_budget(&self) -> Result<(), String> {
150 match self.budget {
151 Some(ref b) => b.check(self.llm_calls, self.elapsed_ms()),
152 None => Ok(()),
153 }
154 }
155
156 fn snapshot(&self) -> serde_json::Value {
160 let mut json = serde_json::json!({
161 "elapsed_ms": self.elapsed_ms(),
162 "llm_calls": self.llm_calls,
163 "rounds": self.rounds,
164 });
165
166 if let Some(ref p) = self.progress {
167 json["progress"] = serde_json::json!({
168 "step": p.step,
169 "total": p.total,
170 "message": p.message,
171 });
172 }
173
174 if let Some(ref b) = self.budget {
175 json["budget_remaining"] = b.remaining_json(self.llm_calls, self.elapsed_ms());
176 }
177
178 json
179 }
180
181 pub(crate) fn budget_remaining(&self) -> serde_json::Value {
182 match self.budget {
183 None => serde_json::Value::Null,
184 Some(ref b) => b.remaining_json(self.llm_calls, self.elapsed_ms()),
185 }
186 }
187}
188
/// Thread-safe container pairing automatically collected session metrics
/// with user-defined custom metrics. Hands out shared handles/observers
/// that all mutate the same underlying state.
pub struct ExecutionMetrics {
    /// Session counters/transcript, shared with handles and observers.
    auto: Arc<Mutex<SessionStatus>>,
    /// User-recorded metrics, shared with `CustomMetricsHandle`.
    custom: Arc<Mutex<CustomMetrics>>,
}
200
201impl ExecutionMetrics {
202 pub fn new() -> Self {
203 Self {
204 auto: Arc::new(Mutex::new(SessionStatus::new())),
205 custom: Arc::new(Mutex::new(CustomMetrics::new())),
206 }
207 }
208
209 pub fn to_json(&self) -> serde_json::Value {
211 let auto_json = self
212 .auto
213 .lock()
214 .map(|m| m.to_json())
215 .unwrap_or(serde_json::Value::Null);
216
217 let custom_json = self
218 .custom
219 .lock()
220 .map(|m| m.to_json())
221 .unwrap_or(serde_json::Value::Null);
222
223 serde_json::json!({
224 "auto": auto_json,
225 "custom": custom_json,
226 })
227 }
228
229 pub fn transcript_to_json(&self) -> Vec<serde_json::Value> {
231 self.auto
232 .lock()
233 .map(|m| m.transcript.iter().map(|e| e.to_json()).collect())
234 .unwrap_or_default()
235 }
236
237 pub fn custom_metrics_handle(&self) -> CustomMetricsHandle {
239 CustomMetricsHandle::new(Arc::clone(&self.custom))
240 }
241
242 pub fn set_budget(&self, budget: Budget) {
244 if let Ok(mut m) = self.auto.lock() {
245 m.budget = Some(budget);
246 }
247 }
248
249 pub fn budget_handle(&self) -> BudgetHandle {
251 BudgetHandle::new(Arc::clone(&self.auto))
252 }
253
254 pub fn progress_handle(&self) -> ProgressHandle {
256 ProgressHandle::new(Arc::clone(&self.auto))
257 }
258
259 pub fn snapshot(&self) -> serde_json::Value {
262 self.auto
263 .lock()
264 .map(|m| m.snapshot())
265 .unwrap_or(serde_json::Value::Null)
266 }
267
268 pub fn create_observer(&self) -> MetricsObserver {
269 MetricsObserver::new(Arc::clone(&self.auto))
270 }
271}
272
273impl Default for ExecutionMetrics {
274 fn default() -> Self {
275 Self::new()
276 }
277}
278
/// Observer that records execution lifecycle events into the shared
/// automatic session metrics.
pub struct MetricsObserver {
    /// Shared session state, also held by `ExecutionMetrics` and handles.
    auto: Arc<Mutex<SessionStatus>>,
}
283
284impl MetricsObserver {
285 pub(crate) fn new(auto: Arc<Mutex<SessionStatus>>) -> Self {
286 Self { auto }
287 }
288}
289
290impl ExecutionObserver for MetricsObserver {
291 fn on_paused(&self, queries: &[LlmQuery]) {
292 if let Ok(mut m) = self.auto.lock() {
293 m.pauses += 1;
294 m.llm_calls += queries.len() as u64;
295 for q in queries {
296 m.total_prompt_chars += q.prompt.len() as u64;
297 m.prompt_tokens
298 .accumulate(estimate_tokens(&q.prompt), TokenSource::Estimated);
299 if let Some(ref sys) = q.system {
300 m.total_prompt_chars += sys.len() as u64;
301 m.prompt_tokens
302 .accumulate(estimate_tokens(sys), TokenSource::Estimated);
303 }
304 m.transcript.push(TranscriptEntry {
305 query_id: q.id.as_str().to_string(),
306 prompt: q.prompt.clone(),
307 system: q.system.clone(),
308 response: None,
309 });
310 }
311 }
312 }
313
314 fn on_response_fed(&self, query_id: &QueryId, response: &str) {
315 if let Ok(mut m) = self.auto.lock() {
316 m.total_response_chars += response.len() as u64;
317 m.response_tokens
318 .accumulate(estimate_tokens(response), TokenSource::Estimated);
319 if let Some(entry) = m
321 .transcript
322 .iter_mut()
323 .rev()
324 .find(|e| e.query_id == query_id.as_str())
325 {
326 entry.response = Some(response.to_string());
327 }
328 }
329 }
330
331 fn on_resumed(&self) {
332 if let Ok(mut m) = self.auto.lock() {
333 m.rounds += 1;
334 }
335 }
336
337 fn on_completed(&self, _result: &serde_json::Value) {
338 if let Ok(mut m) = self.auto.lock() {
339 m.ended_at = Some(Instant::now());
340 }
341 }
342
343 fn on_failed(&self, _error: &str) {
344 if let Ok(mut m) = self.auto.lock() {
345 m.ended_at = Some(Instant::now());
346 }
347 }
348
349 fn on_cancelled(&self) {
350 if let Ok(mut m) = self.auto.lock() {
351 m.ended_at = Some(Instant::now());
352 }
353 }
354}
355
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{LlmQuery, QueryId};

    #[test]
    fn metrics_to_json_has_auto_and_custom() {
        let metrics = ExecutionMetrics::new();
        let json = metrics.to_json();
        assert!(json.get("auto").is_some());
        assert!(json.get("custom").is_some());
    }

    #[test]
    fn custom_handle_shares_state() {
        let metrics = ExecutionMetrics::new();
        let handle = metrics.custom_metrics_handle();

        handle.record("key".into(), serde_json::json!("value"));

        let json = metrics.to_json();
        let custom = json.get("custom").unwrap();
        assert_eq!(custom.get("key").unwrap(), "value");
    }

    #[test]
    fn observer_updates_auto_metrics() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();

        let queries = vec![LlmQuery {
            id: QueryId::batch(0),
            prompt: "test".into(),
            system: None,
            max_tokens: 100,
            grounded: false,
            underspecified: false,
        }];

        observer.on_paused(&queries);
        observer.on_completed(&serde_json::json!(null));

        let json = metrics.to_json();
        let auto = json.get("auto").unwrap();
        assert_eq!(auto.get("llm_calls").unwrap(), 1);
        assert_eq!(auto.get("pauses").unwrap(), 1);
        assert_eq!(auto.get("rounds").unwrap(), 0);
        // "test" is 4 bytes; no response was fed.
        assert_eq!(auto.get("total_prompt_chars").unwrap(), 4);
        assert_eq!(auto.get("total_response_chars").unwrap(), 0);
    }

    #[test]
    fn observer_tracks_prompt_and_response_chars() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();

        let queries = vec![
            LlmQuery {
                id: QueryId::batch(0),
                prompt: "hello".into(),
                system: Some("sys".into()),
                max_tokens: 100,
                grounded: false,
                underspecified: false,
            },
            LlmQuery {
                id: QueryId::batch(1),
                prompt: "world".into(),
                system: None,
                max_tokens: 100,
                grounded: false,
                underspecified: false,
            },
        ];

        observer.on_paused(&queries);
        observer.on_response_fed(&QueryId::batch(0), &"x".repeat(42));
        observer.on_response_fed(&QueryId::batch(1), &"y".repeat(58));
        observer.on_resumed();
        observer.on_completed(&serde_json::json!(null));

        let json = metrics.to_json();
        let auto = json.get("auto").unwrap();
        // "hello" (5) + "sys" (3) + "world" (5) = 13 prompt bytes.
        assert_eq!(auto.get("total_prompt_chars").unwrap(), 13);
        // 42 + 58 = 100 response bytes.
        assert_eq!(auto.get("total_response_chars").unwrap(), 100);
        assert_eq!(auto.get("rounds").unwrap(), 1);
    }

    #[test]
    fn observer_tracks_multiple_rounds() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();

        let q = vec![LlmQuery {
            id: QueryId::single(),
            prompt: "p".into(),
            system: None,
            max_tokens: 10,
            grounded: false,
            underspecified: false,
        }];

        observer.on_paused(&q);
        observer.on_response_fed(&QueryId::single(), &"x".repeat(10));
        observer.on_resumed();
        observer.on_paused(&q);
        observer.on_response_fed(&QueryId::single(), &"y".repeat(20));
        observer.on_resumed();
        observer.on_paused(&q);
        observer.on_response_fed(&QueryId::single(), &"z".repeat(30));
        observer.on_resumed();

        observer.on_completed(&serde_json::json!(null));

        let json = metrics.to_json();
        let auto = json.get("auto").unwrap();
        assert_eq!(auto.get("rounds").unwrap(), 3);
        assert_eq!(auto.get("pauses").unwrap(), 3);
        assert_eq!(auto.get("llm_calls").unwrap(), 3);
        // "p" sent once per round: 3 bytes total.
        assert_eq!(auto.get("total_prompt_chars").unwrap(), 3);
        // 10 + 20 + 30 = 60 response bytes.
        assert_eq!(auto.get("total_response_chars").unwrap(), 60);
    }

    #[test]
    fn transcript_records_prompt_response_pairs() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();

        let queries = vec![LlmQuery {
            id: QueryId::single(),
            prompt: "What is 2+2?".into(),
            system: Some("You are a calculator.".into()),
            max_tokens: 50,
            grounded: false,
            underspecified: false,
        }];

        observer.on_paused(&queries);
        observer.on_response_fed(&QueryId::single(), "4");
        observer.on_resumed();
        observer.on_completed(&serde_json::json!(null));

        let transcript = metrics.transcript_to_json();
        assert_eq!(transcript.len(), 1);
        assert_eq!(transcript[0]["query_id"], "q-0");
        assert_eq!(transcript[0]["prompt"], "What is 2+2?");
        assert_eq!(transcript[0]["system"], "You are a calculator.");
        assert_eq!(transcript[0]["response"], "4");
    }

    #[test]
    fn transcript_not_in_stats() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();
        observer.on_paused(&[LlmQuery {
            id: QueryId::single(),
            prompt: "p".into(),
            system: None,
            max_tokens: 10,
            grounded: false,
            underspecified: false,
        }]);
        observer.on_response_fed(&QueryId::single(), "r");
        observer.on_resumed();
        observer.on_completed(&serde_json::json!(null));

        let json = metrics.to_json();
        assert!(json["auto"].get("transcript").is_none());
    }

    #[test]
    fn transcript_multi_round() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();

        observer.on_paused(&[LlmQuery {
            id: QueryId::single(),
            prompt: "step1".into(),
            system: None,
            max_tokens: 100,
            grounded: false,
            underspecified: false,
        }]);
        observer.on_response_fed(&QueryId::single(), "answer1");
        observer.on_resumed();

        observer.on_paused(&[LlmQuery {
            id: QueryId::single(),
            prompt: "step2".into(),
            system: Some("expert".into()),
            max_tokens: 100,
            grounded: false,
            underspecified: false,
        }]);
        observer.on_response_fed(&QueryId::single(), "answer2");
        observer.on_resumed();

        observer.on_completed(&serde_json::json!(null));

        let transcript = metrics.transcript_to_json();
        assert_eq!(transcript.len(), 2);

        assert_eq!(transcript[0]["prompt"], "step1");
        assert!(transcript[0]["system"].is_null());
        assert_eq!(transcript[0]["response"], "answer1");

        assert_eq!(transcript[1]["prompt"], "step2");
        assert_eq!(transcript[1]["system"], "expert");
        assert_eq!(transcript[1]["response"], "answer2");
    }

    #[test]
    fn transcript_batch_queries() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();

        let queries = vec![
            LlmQuery {
                id: QueryId::batch(0),
                prompt: "q0".into(),
                system: None,
                max_tokens: 50,
                grounded: false,
                underspecified: false,
            },
            LlmQuery {
                id: QueryId::batch(1),
                prompt: "q1".into(),
                system: None,
                max_tokens: 50,
                grounded: false,
                underspecified: false,
            },
        ];

        observer.on_paused(&queries);
        observer.on_response_fed(&QueryId::batch(0), "r0");
        observer.on_response_fed(&QueryId::batch(1), "r1");
        observer.on_resumed();
        observer.on_completed(&serde_json::json!(null));

        let transcript = metrics.transcript_to_json();
        assert_eq!(transcript.len(), 2);
        assert_eq!(transcript[0]["query_id"], "q-0");
        assert_eq!(transcript[0]["response"], "r0");
        assert_eq!(transcript[1]["query_id"], "q-1");
        assert_eq!(transcript[1]["response"], "r1");
    }
}