1use std::sync::{Arc, Mutex};
2use std::time::Instant;
3
4use crate::budget::Budget;
5use crate::observer::ExecutionObserver;
6use crate::progress::ProgressInfo;
7use crate::tokens::{estimate_tokens, TokenCount, TokenSource};
8use crate::{BudgetHandle, CustomMetrics, CustomMetricsHandle, LlmQuery, ProgressHandle, QueryId};
9
/// One prompt/response exchange recorded in the session transcript.
struct TranscriptEntry {
    // Stringified id of the query this entry belongs to.
    query_id: String,
    // Raw prompt text sent to the LLM.
    prompt: String,
    // Optional system prompt that accompanied the query.
    system: Option<String>,
    // Filled in by `on_response_fed`; `None` until a response arrives.
    response: Option<String>,
    // Token count for the prompt (plus system prompt, when present).
    prompt_tokens: u64,
    // Whether `prompt_tokens` is estimated or provider-reported.
    prompt_source: TokenSource,
    // Token count for the response; 0 until a response is fed.
    response_tokens: u64,
    // Whether `response_tokens` is estimated or provider-reported.
    response_source: TokenSource,
}
31
32impl TranscriptEntry {
33 fn to_json(&self) -> serde_json::Value {
34 serde_json::json!({
35 "query_id": self.query_id,
36 "prompt": self.prompt,
37 "system": self.system,
38 "response": self.response,
39 })
40 }
41}
42
/// Automatically-tracked metrics for one execution session.
pub(crate) struct SessionStatus {
    // When the session was created.
    started_at: Instant,
    // Set on completion/failure/cancellation; `None` while running.
    ended_at: Option<Instant>,
    pub(crate) llm_calls: u64,
    // Number of times execution paused to issue LLM queries.
    pauses: u64,
    // Number of completed pause/response/resume cycles.
    rounds: u64,
    total_prompt_chars: u64,
    total_response_chars: u64,
    // Ordered record of every prompt/response pair.
    transcript: Vec<TranscriptEntry>,
    pub(crate) budget: Option<Budget>,
    pub(crate) progress: Option<ProgressInfo>,
}
108
109impl SessionStatus {
110 fn new() -> Self {
111 Self {
112 started_at: Instant::now(),
113 ended_at: None,
114 llm_calls: 0,
115 pauses: 0,
116 rounds: 0,
117 total_prompt_chars: 0,
118 total_response_chars: 0,
119 transcript: Vec::new(),
120 budget: None,
121 progress: None,
122 }
123 }
124
125 fn prompt_token_count(&self) -> TokenCount {
127 let mut tc = TokenCount::new(TokenSource::Definite);
128 for e in &self.transcript {
129 tc.accumulate(e.prompt_tokens, e.prompt_source);
130 }
131 tc
132 }
133
134 fn response_token_count(&self) -> TokenCount {
136 let mut tc = TokenCount::new(TokenSource::Definite);
137 for e in &self.transcript {
138 tc.accumulate(e.response_tokens, e.response_source);
139 }
140 tc
141 }
142
143 fn total_tokens(&self) -> u64 {
145 self.transcript
146 .iter()
147 .map(|e| e.prompt_tokens + e.response_tokens)
148 .sum()
149 }
150
151 fn elapsed_ms(&self) -> u64 {
153 self.ended_at
154 .map(|end| end.duration_since(self.started_at).as_millis() as u64)
155 .unwrap_or_else(|| self.started_at.elapsed().as_millis() as u64)
156 }
157
158 fn to_json(&self) -> serde_json::Value {
159 let prompt_tc = self.prompt_token_count();
160 let response_tc = self.response_token_count();
161 let total_tc = TokenCount {
162 tokens: prompt_tc.tokens + response_tc.tokens,
163 source: prompt_tc.source.weaker(response_tc.source),
164 };
165 let mut json = serde_json::json!({
166 "elapsed_ms": self.elapsed_ms(),
167 "llm_calls": self.llm_calls,
168 "pauses": self.pauses,
169 "rounds": self.rounds,
170 "total_prompt_chars": self.total_prompt_chars,
171 "total_response_chars": self.total_response_chars,
172 "prompt_tokens": prompt_tc.to_json(),
173 "response_tokens": response_tc.to_json(),
174 "total_tokens": total_tc.to_json(),
175 });
176 if let Some(ref b) = self.budget {
177 json["budget"] = b.to_json();
178 }
179 json
180 }
181
182 pub(crate) fn check_budget(&self) -> Result<(), String> {
183 match self.budget {
184 Some(ref b) => b.check(self.llm_calls, self.elapsed_ms(), self.total_tokens()),
185 None => Ok(()),
186 }
187 }
188
189 fn snapshot(&self) -> serde_json::Value {
193 let mut json = serde_json::json!({
194 "elapsed_ms": self.elapsed_ms(),
195 "llm_calls": self.llm_calls,
196 "rounds": self.rounds,
197 });
198
199 if let Some(ref p) = self.progress {
200 json["progress"] = serde_json::json!({
201 "step": p.step,
202 "total": p.total,
203 "message": p.message,
204 });
205 }
206
207 if let Some(ref b) = self.budget {
208 json["budget_remaining"] =
209 b.remaining_json(self.llm_calls, self.elapsed_ms(), self.total_tokens());
210 }
211
212 json
213 }
214
215 pub(crate) fn budget_remaining(&self) -> serde_json::Value {
216 match self.budget {
217 None => serde_json::Value::Null,
218 Some(ref b) => b.remaining_json(self.llm_calls, self.elapsed_ms(), self.total_tokens()),
219 }
220 }
221}
222
/// Aggregates automatically-tracked session metrics and user-recorded
/// custom metrics behind shared, thread-safe handles.
pub struct ExecutionMetrics {
    // Session counters/transcript, updated via observers and handles.
    auto: Arc<Mutex<SessionStatus>>,
    // Free-form metrics recorded through `CustomMetricsHandle`.
    custom: Arc<Mutex<CustomMetrics>>,
}
234
235impl ExecutionMetrics {
236 pub fn new() -> Self {
237 Self {
238 auto: Arc::new(Mutex::new(SessionStatus::new())),
239 custom: Arc::new(Mutex::new(CustomMetrics::new())),
240 }
241 }
242
243 pub fn to_json(&self) -> serde_json::Value {
245 let auto_json = self
246 .auto
247 .lock()
248 .map(|m| m.to_json())
249 .unwrap_or(serde_json::Value::Null);
250
251 let custom_json = self
252 .custom
253 .lock()
254 .map(|m| m.to_json())
255 .unwrap_or(serde_json::Value::Null);
256
257 serde_json::json!({
258 "auto": auto_json,
259 "custom": custom_json,
260 })
261 }
262
263 pub fn transcript_to_json(&self) -> Vec<serde_json::Value> {
265 self.auto
266 .lock()
267 .map(|m| m.transcript.iter().map(|e| e.to_json()).collect())
268 .unwrap_or_default()
269 }
270
271 pub fn custom_metrics_handle(&self) -> CustomMetricsHandle {
273 CustomMetricsHandle::new(Arc::clone(&self.custom))
274 }
275
276 pub fn set_budget(&self, budget: Budget) {
278 if let Ok(mut m) = self.auto.lock() {
279 m.budget = Some(budget);
280 }
281 }
282
283 pub fn budget_handle(&self) -> BudgetHandle {
285 BudgetHandle::new(Arc::clone(&self.auto))
286 }
287
288 pub fn progress_handle(&self) -> ProgressHandle {
290 ProgressHandle::new(Arc::clone(&self.auto))
291 }
292
293 pub fn snapshot(&self) -> serde_json::Value {
296 self.auto
297 .lock()
298 .map(|m| m.snapshot())
299 .unwrap_or(serde_json::Value::Null)
300 }
301
302 pub fn create_observer(&self) -> MetricsObserver {
303 MetricsObserver::new(Arc::clone(&self.auto))
304 }
305}
306
307impl Default for ExecutionMetrics {
308 fn default() -> Self {
309 Self::new()
310 }
311}
312
313impl serde::Serialize for ExecutionMetrics {
314 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
315 self.to_json().serialize(serializer)
316 }
317}
318
/// `ExecutionObserver` implementation that records execution events into
/// the `SessionStatus` shared with the owning `ExecutionMetrics`.
pub struct MetricsObserver {
    // Same allocation as `ExecutionMetrics::auto`.
    auto: Arc<Mutex<SessionStatus>>,
}
323
324impl MetricsObserver {
325 pub(crate) fn new(auto: Arc<Mutex<SessionStatus>>) -> Self {
326 Self { auto }
327 }
328}
329
330impl ExecutionObserver for MetricsObserver {
331 fn on_paused(&self, queries: &[LlmQuery]) {
332 if let Ok(mut m) = self.auto.lock() {
333 m.pauses += 1;
334 m.llm_calls += queries.len() as u64;
335 for q in queries {
336 m.total_prompt_chars += q.prompt.len() as u64;
337 let mut est = estimate_tokens(&q.prompt);
338 if let Some(ref sys) = q.system {
339 m.total_prompt_chars += sys.len() as u64;
340 est += estimate_tokens(sys);
341 }
342 m.transcript.push(TranscriptEntry {
343 query_id: q.id.as_str().to_string(),
344 prompt: q.prompt.clone(),
345 system: q.system.clone(),
346 response: None,
347 prompt_tokens: est,
348 prompt_source: TokenSource::Estimated,
349 response_tokens: 0,
350 response_source: TokenSource::Estimated,
351 });
352 }
353 }
354 }
355
356 fn on_response_fed(
357 &self,
358 query_id: &QueryId,
359 response: &str,
360 usage: Option<&crate::TokenUsage>,
361 ) {
362 if let Ok(mut m) = self.auto.lock() {
363 m.total_response_chars += response.len() as u64;
364
365 if let Some(entry) = m
366 .transcript
367 .iter_mut()
368 .rev()
369 .find(|e| e.query_id == query_id.as_str())
370 {
371 entry.response = Some(response.to_string());
372
373 if let Some(pt) = usage.and_then(|u| u.prompt_tokens) {
375 entry.prompt_tokens = pt;
376 entry.prompt_source = TokenSource::Provided;
377 }
378
379 match usage.and_then(|u| u.completion_tokens) {
381 Some(ct) => {
382 entry.response_tokens = ct;
383 entry.response_source = TokenSource::Provided;
384 }
385 None => {
386 entry.response_tokens = estimate_tokens(response);
387 entry.response_source = TokenSource::Estimated;
388 }
389 }
390 }
391 }
392 }
393
394 fn on_resumed(&self) {
395 if let Ok(mut m) = self.auto.lock() {
396 m.rounds += 1;
397 }
398 }
399
400 fn on_completed(&self, _result: &serde_json::Value) {
401 if let Ok(mut m) = self.auto.lock() {
402 m.ended_at = Some(Instant::now());
403 }
404 }
405
406 fn on_failed(&self, _error: &str) {
407 if let Ok(mut m) = self.auto.lock() {
408 m.ended_at = Some(Instant::now());
409 }
410 }
411
412 fn on_cancelled(&self) {
413 if let Ok(mut m) = self.auto.lock() {
414 m.ended_at = Some(Instant::now());
415 }
416 }
417}
418
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{LlmQuery, QueryId};

    // The report always contains both sections, even for a fresh session.
    #[test]
    fn metrics_to_json_has_auto_and_custom() {
        let metrics = ExecutionMetrics::new();
        let json = metrics.to_json();
        assert!(json.get("auto").is_some());
        assert!(json.get("custom").is_some());
    }

    // A handle writes into the same underlying custom-metrics store.
    #[test]
    fn custom_handle_shares_state() {
        let metrics = ExecutionMetrics::new();
        let handle = metrics.custom_metrics_handle();

        handle.record("key".into(), serde_json::json!("value"));

        let json = metrics.to_json();
        let custom = json.get("custom").unwrap();
        assert_eq!(custom.get("key").unwrap(), "value");
    }

    #[test]
    fn observer_updates_auto_metrics() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();

        let queries = vec![LlmQuery {
            id: QueryId::batch(0),
            prompt: "test".into(),
            system: None,
            max_tokens: 100,
            grounded: false,
            underspecified: false,
        }];

        observer.on_paused(&queries);
        observer.on_completed(&serde_json::json!(null));

        let json = metrics.to_json();
        let auto = json.get("auto").unwrap();
        assert_eq!(auto.get("llm_calls").unwrap(), 1);
        assert_eq!(auto.get("pauses").unwrap(), 1);
        // No on_resumed call, so no round completed.
        assert_eq!(auto.get("rounds").unwrap(), 0);
        // "test" is 4 chars; no responses were fed.
        assert_eq!(auto.get("total_prompt_chars").unwrap(), 4);
        assert_eq!(auto.get("total_response_chars").unwrap(), 0);
    }

    #[test]
    fn observer_tracks_prompt_and_response_chars() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();

        let queries = vec![
            LlmQuery {
                id: QueryId::batch(0),
                prompt: "hello".into(),
                system: Some("sys".into()),
                max_tokens: 100,
                grounded: false,
                underspecified: false,
            },
            LlmQuery {
                id: QueryId::batch(1),
                prompt: "world".into(),
                system: None,
                max_tokens: 100,
                grounded: false,
                underspecified: false,
            },
        ];

        observer.on_paused(&queries);
        observer.on_response_fed(&QueryId::batch(0), &"x".repeat(42), None);
        observer.on_response_fed(&QueryId::batch(1), &"y".repeat(58), None);
        observer.on_resumed();
        observer.on_completed(&serde_json::json!(null));

        let json = metrics.to_json();
        let auto = json.get("auto").unwrap();
        // "hello" (5) + "sys" (3) + "world" (5) = 13 prompt chars.
        assert_eq!(auto.get("total_prompt_chars").unwrap(), 13);
        // 42 + 58 = 100 response chars.
        assert_eq!(auto.get("total_response_chars").unwrap(), 100);
        assert_eq!(auto.get("rounds").unwrap(), 1);
    }

    #[test]
    fn observer_tracks_multiple_rounds() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();

        let q = vec![LlmQuery {
            id: QueryId::single(),
            prompt: "p".into(),
            system: None,
            max_tokens: 10,
            grounded: false,
            underspecified: false,
        }];

        observer.on_paused(&q);
        observer.on_response_fed(&QueryId::single(), &"x".repeat(10), None);
        observer.on_resumed();
        observer.on_paused(&q);
        observer.on_response_fed(&QueryId::single(), &"y".repeat(20), None);
        observer.on_resumed();
        observer.on_paused(&q);
        observer.on_response_fed(&QueryId::single(), &"z".repeat(30), None);
        observer.on_resumed();

        observer.on_completed(&serde_json::json!(null));

        let json = metrics.to_json();
        let auto = json.get("auto").unwrap();
        assert_eq!(auto.get("rounds").unwrap(), 3);
        assert_eq!(auto.get("pauses").unwrap(), 3);
        assert_eq!(auto.get("llm_calls").unwrap(), 3);
        // "p" sent once per round: 3 prompt chars total.
        assert_eq!(auto.get("total_prompt_chars").unwrap(), 3);
        // 10 + 20 + 30 = 60 response chars.
        assert_eq!(auto.get("total_response_chars").unwrap(), 60);
    }

    #[test]
    fn transcript_records_prompt_response_pairs() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();

        let queries = vec![LlmQuery {
            id: QueryId::single(),
            prompt: "What is 2+2?".into(),
            system: Some("You are a calculator.".into()),
            max_tokens: 50,
            grounded: false,
            underspecified: false,
        }];

        observer.on_paused(&queries);
        observer.on_response_fed(&QueryId::single(), "4", None);
        observer.on_resumed();
        observer.on_completed(&serde_json::json!(null));

        let transcript = metrics.transcript_to_json();
        assert_eq!(transcript.len(), 1);
        assert_eq!(transcript[0]["query_id"], "q-0");
        assert_eq!(transcript[0]["prompt"], "What is 2+2?");
        assert_eq!(transcript[0]["system"], "You are a calculator.");
        assert_eq!(transcript[0]["response"], "4");
    }

    // The transcript is exposed only via transcript_to_json, never inside
    // the stats report.
    #[test]
    fn transcript_not_in_stats() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();
        observer.on_paused(&[LlmQuery {
            id: QueryId::single(),
            prompt: "p".into(),
            system: None,
            max_tokens: 10,
            grounded: false,
            underspecified: false,
        }]);
        observer.on_response_fed(&QueryId::single(), "r", None);
        observer.on_resumed();
        observer.on_completed(&serde_json::json!(null));

        let json = metrics.to_json();
        assert!(json["auto"].get("transcript").is_none());
    }

    #[test]
    fn transcript_multi_round() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();

        observer.on_paused(&[LlmQuery {
            id: QueryId::single(),
            prompt: "step1".into(),
            system: None,
            max_tokens: 100,
            grounded: false,
            underspecified: false,
        }]);
        observer.on_response_fed(&QueryId::single(), "answer1", None);
        observer.on_resumed();

        observer.on_paused(&[LlmQuery {
            id: QueryId::single(),
            prompt: "step2".into(),
            system: Some("expert".into()),
            max_tokens: 100,
            grounded: false,
            underspecified: false,
        }]);
        observer.on_response_fed(&QueryId::single(), "answer2", None);
        observer.on_resumed();

        observer.on_completed(&serde_json::json!(null));

        let transcript = metrics.transcript_to_json();
        assert_eq!(transcript.len(), 2);

        assert_eq!(transcript[0]["prompt"], "step1");
        assert!(transcript[0]["system"].is_null());
        assert_eq!(transcript[0]["response"], "answer1");

        assert_eq!(transcript[1]["prompt"], "step2");
        assert_eq!(transcript[1]["system"], "expert");
        assert_eq!(transcript[1]["response"], "answer2");
    }

    // Batched queries with distinct ids get matched to their own responses.
    #[test]
    fn transcript_batch_queries() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();

        let queries = vec![
            LlmQuery {
                id: QueryId::batch(0),
                prompt: "q0".into(),
                system: None,
                max_tokens: 50,
                grounded: false,
                underspecified: false,
            },
            LlmQuery {
                id: QueryId::batch(1),
                prompt: "q1".into(),
                system: None,
                max_tokens: 50,
                grounded: false,
                underspecified: false,
            },
        ];

        observer.on_paused(&queries);
        observer.on_response_fed(&QueryId::batch(0), "r0", None);
        observer.on_response_fed(&QueryId::batch(1), "r1", None);
        observer.on_resumed();
        observer.on_completed(&serde_json::json!(null));

        let transcript = metrics.transcript_to_json();
        assert_eq!(transcript.len(), 2);
        assert_eq!(transcript[0]["query_id"], "q-0");
        assert_eq!(transcript[0]["response"], "r0");
        assert_eq!(transcript[1]["query_id"], "q-1");
        assert_eq!(transcript[1]["response"], "r1");
    }
}