1use std::sync::{Arc, Mutex};
2use std::time::Instant;
3
4use crate::budget::Budget;
5use crate::observer::ExecutionObserver;
6use crate::progress::ProgressInfo;
7use crate::tokens::{estimate_tokens, TokenCount, TokenSource};
8use crate::{BudgetHandle, CustomMetrics, CustomMetricsHandle, LlmQuery, ProgressHandle, QueryId};
9
/// One prompt/response pair captured for the session transcript.
struct TranscriptEntry {
    // String form of the originating `QueryId` (via `QueryId::as_str`).
    query_id: String,
    // The prompt text sent for this query.
    prompt: String,
    // Optional system prompt that accompanied the query.
    system: Option<String>,
    // Set by `on_response_fed`; `None` while the query is still pending.
    response: Option<String>,
}
19
20impl TranscriptEntry {
21 fn to_json(&self) -> serde_json::Value {
22 serde_json::json!({
23 "query_id": self.query_id,
24 "prompt": self.prompt,
25 "system": self.system,
26 "response": self.response,
27 })
28 }
29}
30
/// Automatically-tracked session statistics, updated by `MetricsObserver`
/// callbacks and read through `ExecutionMetrics`.
pub(crate) struct SessionStatus {
    // Set once at construction; basis for `elapsed_ms`.
    started_at: Instant,
    // Set by on_completed/on_failed/on_cancelled; `None` while running.
    ended_at: Option<Instant>,
    // Total queries seen across all pauses.
    pub(crate) llm_calls: u64,
    // Number of `on_paused` callbacks.
    pauses: u64,
    // Number of `on_resumed` callbacks.
    rounds: u64,
    // Character totals over prompts (incl. system prompts) and responses.
    total_prompt_chars: u64,
    total_response_chars: u64,
    // Estimated token counts accumulated per prompt/response.
    prompt_tokens: TokenCount,
    response_tokens: TokenCount,
    // Ordered prompt/response pairs; intentionally excluded from `to_json`.
    transcript: Vec<TranscriptEntry>,
    // Optional spend limit checked against llm_calls and elapsed time.
    pub(crate) budget: Option<Budget>,
    // Optional progress report surfaced in `snapshot`.
    pub(crate) progress: Option<ProgressInfo>,
}
98
99impl SessionStatus {
100 fn new() -> Self {
101 Self {
102 started_at: Instant::now(),
103 ended_at: None,
104 llm_calls: 0,
105 pauses: 0,
106 rounds: 0,
107 total_prompt_chars: 0,
108 total_response_chars: 0,
109 prompt_tokens: TokenCount::new(TokenSource::Estimated),
110 response_tokens: TokenCount::new(TokenSource::Estimated),
111 transcript: Vec::new(),
112 budget: None,
113 progress: None,
114 }
115 }
116
117 fn elapsed_ms(&self) -> u64 {
119 self.ended_at
120 .map(|end| end.duration_since(self.started_at).as_millis() as u64)
121 .unwrap_or_else(|| self.started_at.elapsed().as_millis() as u64)
122 }
123
124 fn to_json(&self) -> serde_json::Value {
125 let total_tokens = TokenCount {
126 tokens: self.prompt_tokens.tokens + self.response_tokens.tokens,
127 source: self
128 .prompt_tokens
129 .source
130 .weaker(self.response_tokens.source),
131 };
132 let mut json = serde_json::json!({
133 "elapsed_ms": self.elapsed_ms(),
134 "llm_calls": self.llm_calls,
135 "pauses": self.pauses,
136 "rounds": self.rounds,
137 "total_prompt_chars": self.total_prompt_chars,
138 "total_response_chars": self.total_response_chars,
139 "prompt_tokens": self.prompt_tokens.to_json(),
140 "response_tokens": self.response_tokens.to_json(),
141 "total_tokens": total_tokens.to_json(),
142 });
143 if let Some(ref b) = self.budget {
144 json["budget"] = b.to_json();
145 }
146 json
147 }
148
149 pub(crate) fn check_budget(&self) -> Result<(), String> {
150 match self.budget {
151 Some(ref b) => b.check(self.llm_calls, self.elapsed_ms()),
152 None => Ok(()),
153 }
154 }
155
156 fn snapshot(&self) -> serde_json::Value {
160 let mut json = serde_json::json!({
161 "elapsed_ms": self.elapsed_ms(),
162 "llm_calls": self.llm_calls,
163 "rounds": self.rounds,
164 });
165
166 if let Some(ref p) = self.progress {
167 json["progress"] = serde_json::json!({
168 "step": p.step,
169 "total": p.total,
170 "message": p.message,
171 });
172 }
173
174 if let Some(ref b) = self.budget {
175 json["budget_remaining"] = b.remaining_json(self.llm_calls, self.elapsed_ms());
176 }
177
178 json
179 }
180
181 pub(crate) fn budget_remaining(&self) -> serde_json::Value {
182 match self.budget {
183 None => serde_json::Value::Null,
184 Some(ref b) => b.remaining_json(self.llm_calls, self.elapsed_ms()),
185 }
186 }
187}
188
/// Aggregates a session's metrics behind shared, lockable state:
/// automatically-tracked statistics plus user-recorded custom metrics.
/// Handles and observers created from this share the same `Arc`s.
pub struct ExecutionMetrics {
    // Auto-tracked counters/transcript, written by `MetricsObserver`.
    auto: Arc<Mutex<SessionStatus>>,
    // User-defined key/value metrics, written via `CustomMetricsHandle`.
    custom: Arc<Mutex<CustomMetrics>>,
}
200
201impl ExecutionMetrics {
202 pub fn new() -> Self {
203 Self {
204 auto: Arc::new(Mutex::new(SessionStatus::new())),
205 custom: Arc::new(Mutex::new(CustomMetrics::new())),
206 }
207 }
208
209 pub fn to_json(&self) -> serde_json::Value {
211 let auto_json = self
212 .auto
213 .lock()
214 .map(|m| m.to_json())
215 .unwrap_or(serde_json::Value::Null);
216
217 let custom_json = self
218 .custom
219 .lock()
220 .map(|m| m.to_json())
221 .unwrap_or(serde_json::Value::Null);
222
223 serde_json::json!({
224 "auto": auto_json,
225 "custom": custom_json,
226 })
227 }
228
229 pub fn transcript_to_json(&self) -> Vec<serde_json::Value> {
231 self.auto
232 .lock()
233 .map(|m| m.transcript.iter().map(|e| e.to_json()).collect())
234 .unwrap_or_default()
235 }
236
237 pub fn custom_metrics_handle(&self) -> CustomMetricsHandle {
239 CustomMetricsHandle::new(Arc::clone(&self.custom))
240 }
241
242 pub fn set_budget(&self, budget: Budget) {
244 if let Ok(mut m) = self.auto.lock() {
245 m.budget = Some(budget);
246 }
247 }
248
249 pub fn budget_handle(&self) -> BudgetHandle {
251 BudgetHandle::new(Arc::clone(&self.auto))
252 }
253
254 pub fn progress_handle(&self) -> ProgressHandle {
256 ProgressHandle::new(Arc::clone(&self.auto))
257 }
258
259 pub fn snapshot(&self) -> serde_json::Value {
262 self.auto
263 .lock()
264 .map(|m| m.snapshot())
265 .unwrap_or(serde_json::Value::Null)
266 }
267
268 pub fn create_observer(&self) -> MetricsObserver {
269 MetricsObserver::new(Arc::clone(&self.auto))
270 }
271}
272
273impl Default for ExecutionMetrics {
274 fn default() -> Self {
275 Self::new()
276 }
277}
278
279impl serde::Serialize for ExecutionMetrics {
280 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
281 self.to_json().serialize(serializer)
282 }
283}
284
/// `ExecutionObserver` implementation that records pause/response/round
/// events into the shared `SessionStatus` of an `ExecutionMetrics`.
pub struct MetricsObserver {
    // Shared with `ExecutionMetrics::auto`; every callback locks this.
    auto: Arc<Mutex<SessionStatus>>,
}
289
290impl MetricsObserver {
291 pub(crate) fn new(auto: Arc<Mutex<SessionStatus>>) -> Self {
292 Self { auto }
293 }
294}
295
296impl ExecutionObserver for MetricsObserver {
297 fn on_paused(&self, queries: &[LlmQuery]) {
298 if let Ok(mut m) = self.auto.lock() {
299 m.pauses += 1;
300 m.llm_calls += queries.len() as u64;
301 for q in queries {
302 m.total_prompt_chars += q.prompt.len() as u64;
303 m.prompt_tokens
304 .accumulate(estimate_tokens(&q.prompt), TokenSource::Estimated);
305 if let Some(ref sys) = q.system {
306 m.total_prompt_chars += sys.len() as u64;
307 m.prompt_tokens
308 .accumulate(estimate_tokens(sys), TokenSource::Estimated);
309 }
310 m.transcript.push(TranscriptEntry {
311 query_id: q.id.as_str().to_string(),
312 prompt: q.prompt.clone(),
313 system: q.system.clone(),
314 response: None,
315 });
316 }
317 }
318 }
319
320 fn on_response_fed(&self, query_id: &QueryId, response: &str) {
321 if let Ok(mut m) = self.auto.lock() {
322 m.total_response_chars += response.len() as u64;
323 m.response_tokens
324 .accumulate(estimate_tokens(response), TokenSource::Estimated);
325 if let Some(entry) = m
327 .transcript
328 .iter_mut()
329 .rev()
330 .find(|e| e.query_id == query_id.as_str())
331 {
332 entry.response = Some(response.to_string());
333 }
334 }
335 }
336
337 fn on_resumed(&self) {
338 if let Ok(mut m) = self.auto.lock() {
339 m.rounds += 1;
340 }
341 }
342
343 fn on_completed(&self, _result: &serde_json::Value) {
344 if let Ok(mut m) = self.auto.lock() {
345 m.ended_at = Some(Instant::now());
346 }
347 }
348
349 fn on_failed(&self, _error: &str) {
350 if let Ok(mut m) = self.auto.lock() {
351 m.ended_at = Some(Instant::now());
352 }
353 }
354
355 fn on_cancelled(&self) {
356 if let Ok(mut m) = self.auto.lock() {
357 m.ended_at = Some(Instant::now());
358 }
359 }
360}
361
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{LlmQuery, QueryId};

    #[test]
    fn metrics_to_json_has_auto_and_custom() {
        let json = ExecutionMetrics::new().to_json();
        assert!(json.get("auto").is_some());
        assert!(json.get("custom").is_some());
    }

    #[test]
    fn custom_handle_shares_state() {
        let metrics = ExecutionMetrics::new();

        // Recording through the handle must be visible via the parent.
        metrics
            .custom_metrics_handle()
            .record("key".into(), serde_json::json!("value"));

        let json = metrics.to_json();
        let custom = json.get("custom").unwrap();
        assert_eq!(custom.get("key").unwrap(), "value");
    }

    #[test]
    fn observer_updates_auto_metrics() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();

        let batch = vec![LlmQuery {
            id: QueryId::batch(0),
            prompt: "test".into(),
            system: None,
            max_tokens: 100,
            grounded: false,
            underspecified: false,
        }];

        observer.on_paused(&batch);
        observer.on_completed(&serde_json::json!(null));

        let json = metrics.to_json();
        let auto = json.get("auto").unwrap();
        assert_eq!(auto.get("llm_calls").unwrap(), 1);
        assert_eq!(auto.get("pauses").unwrap(), 1);
        assert_eq!(auto.get("rounds").unwrap(), 0);
        assert_eq!(auto.get("total_prompt_chars").unwrap(), 4);
        assert_eq!(auto.get("total_response_chars").unwrap(), 0);
    }

    #[test]
    fn observer_tracks_prompt_and_response_chars() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();

        let batch = vec![
            LlmQuery {
                id: QueryId::batch(0),
                prompt: "hello".into(),
                system: Some("sys".into()),
                max_tokens: 100,
                grounded: false,
                underspecified: false,
            },
            LlmQuery {
                id: QueryId::batch(1),
                prompt: "world".into(),
                system: None,
                max_tokens: 100,
                grounded: false,
                underspecified: false,
            },
        ];

        observer.on_paused(&batch);
        observer.on_response_fed(&QueryId::batch(0), &"x".repeat(42));
        observer.on_response_fed(&QueryId::batch(1), &"y".repeat(58));
        observer.on_resumed();
        observer.on_completed(&serde_json::json!(null));

        let json = metrics.to_json();
        let auto = json.get("auto").unwrap();
        // "hello" + "sys" + "world" = 13 prompt chars; 42 + 58 responses.
        assert_eq!(auto.get("total_prompt_chars").unwrap(), 13);
        assert_eq!(auto.get("total_response_chars").unwrap(), 100);
        assert_eq!(auto.get("rounds").unwrap(), 1);
    }

    #[test]
    fn observer_tracks_multiple_rounds() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();

        let single = vec![LlmQuery {
            id: QueryId::single(),
            prompt: "p".into(),
            system: None,
            max_tokens: 10,
            grounded: false,
            underspecified: false,
        }];

        // Three pause/feed/resume cycles with growing responses.
        for &(ch, len) in [("x", 10), ("y", 20), ("z", 30)].iter() {
            observer.on_paused(&single);
            observer.on_response_fed(&QueryId::single(), &ch.repeat(len));
            observer.on_resumed();
        }

        observer.on_completed(&serde_json::json!(null));

        let json = metrics.to_json();
        let auto = json.get("auto").unwrap();
        assert_eq!(auto.get("rounds").unwrap(), 3);
        assert_eq!(auto.get("pauses").unwrap(), 3);
        assert_eq!(auto.get("llm_calls").unwrap(), 3);
        assert_eq!(auto.get("total_prompt_chars").unwrap(), 3);
        assert_eq!(auto.get("total_response_chars").unwrap(), 60);
    }

    #[test]
    fn transcript_records_prompt_response_pairs() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();

        observer.on_paused(&[LlmQuery {
            id: QueryId::single(),
            prompt: "What is 2+2?".into(),
            system: Some("You are a calculator.".into()),
            max_tokens: 50,
            grounded: false,
            underspecified: false,
        }]);
        observer.on_response_fed(&QueryId::single(), "4");
        observer.on_resumed();
        observer.on_completed(&serde_json::json!(null));

        let transcript = metrics.transcript_to_json();
        assert_eq!(transcript.len(), 1);
        assert_eq!(transcript[0]["query_id"], "q-0");
        assert_eq!(transcript[0]["prompt"], "What is 2+2?");
        assert_eq!(transcript[0]["system"], "You are a calculator.");
        assert_eq!(transcript[0]["response"], "4");
    }

    #[test]
    fn transcript_not_in_stats() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();

        observer.on_paused(&[LlmQuery {
            id: QueryId::single(),
            prompt: "p".into(),
            system: None,
            max_tokens: 10,
            grounded: false,
            underspecified: false,
        }]);
        observer.on_response_fed(&QueryId::single(), "r");
        observer.on_resumed();
        observer.on_completed(&serde_json::json!(null));

        // The transcript is exposed only via transcript_to_json, never
        // embedded in the stats object.
        let json = metrics.to_json();
        assert!(json["auto"].get("transcript").is_none());
    }

    #[test]
    fn transcript_multi_round() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();

        observer.on_paused(&[LlmQuery {
            id: QueryId::single(),
            prompt: "step1".into(),
            system: None,
            max_tokens: 100,
            grounded: false,
            underspecified: false,
        }]);
        observer.on_response_fed(&QueryId::single(), "answer1");
        observer.on_resumed();

        observer.on_paused(&[LlmQuery {
            id: QueryId::single(),
            prompt: "step2".into(),
            system: Some("expert".into()),
            max_tokens: 100,
            grounded: false,
            underspecified: false,
        }]);
        observer.on_response_fed(&QueryId::single(), "answer2");
        observer.on_resumed();

        observer.on_completed(&serde_json::json!(null));

        let transcript = metrics.transcript_to_json();
        assert_eq!(transcript.len(), 2);

        assert_eq!(transcript[0]["prompt"], "step1");
        assert!(transcript[0]["system"].is_null());
        assert_eq!(transcript[0]["response"], "answer1");

        assert_eq!(transcript[1]["prompt"], "step2");
        assert_eq!(transcript[1]["system"], "expert");
        assert_eq!(transcript[1]["response"], "answer2");
    }

    #[test]
    fn transcript_batch_queries() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();

        let batch = vec![
            LlmQuery {
                id: QueryId::batch(0),
                prompt: "q0".into(),
                system: None,
                max_tokens: 50,
                grounded: false,
                underspecified: false,
            },
            LlmQuery {
                id: QueryId::batch(1),
                prompt: "q1".into(),
                system: None,
                max_tokens: 50,
                grounded: false,
                underspecified: false,
            },
        ];

        observer.on_paused(&batch);
        observer.on_response_fed(&QueryId::batch(0), "r0");
        observer.on_response_fed(&QueryId::batch(1), "r1");
        observer.on_resumed();
        observer.on_completed(&serde_json::json!(null));

        let transcript = metrics.transcript_to_json();
        assert_eq!(transcript.len(), 2);
        assert_eq!(transcript[0]["query_id"], "q-0");
        assert_eq!(transcript[0]["response"], "r0");
        assert_eq!(transcript[1]["query_id"], "q-1");
        assert_eq!(transcript[1]["response"], "r1");
    }
}