1use crate::error::EvaluationError;
2use crate::evaluate::evaluator::GenAIEvaluator;
3use crate::evaluate::scenario_results::{EvalMetrics, ScenarioEvalResults, ScenarioResult};
4use crate::evaluate::types::{EvalResults, EvaluationConfig};
5use crate::genai::{evaluate_genai_dataset, EvalDataset};
6use crate::scenario::EvalScenarios;
7use pyo3::prelude::*;
8use scouter_state::app_state;
9use scouter_types::genai::EvalScenario;
10use scouter_types::genai::{GenAIEvalConfig, GenAIEvalProfile};
11use scouter_types::trace::build_trace_spans;
12use scouter_types::trace::sql::TraceSpan;
13use scouter_types::EvalRecord;
14use serde_json::json;
15use std::collections::{BTreeMap, HashMap, HashSet};
16use std::sync::Arc;
17use tracing::{debug, error};
18
/// Accumulator used by `evaluate_datasets` to merge the per-scenario datasets
/// that share one alias into a single combined dataset before evaluation.
struct AliasData {
    // All records collected for this alias across every scenario.
    records: Vec<EvalRecord>,
    // Profile from the first scenario that supplied one for this alias;
    // profiles seen later are ignored (a debug message is logged instead).
    profile: Option<Arc<GenAIEvalProfile>>,
    // Trace spans accumulated from each contributing dataset.
    spans: Vec<TraceSpan>,
}
24
/// Python-exposed runner that collects per-scenario evaluation data and then
/// evaluates both per-alias datasets and scenario-level assertion tasks.
#[pyclass]
#[derive(Debug)]
pub struct EvalRunner {
    // Evaluation profile per sub-agent alias; Arc so datasets share without cloning.
    profiles: HashMap<String, Arc<GenAIEvalProfile>>,
    // Scenario definitions plus the datasets, contexts, and results collected for them.
    scenarios: EvalScenarios,
}
38
39#[pymethods]
40impl EvalRunner {
41 #[new]
42 #[pyo3(signature = (scenarios, profiles))]
43 pub fn new(scenarios: EvalScenarios, profiles: HashMap<String, GenAIEvalProfile>) -> Self {
44 let arc_profiles: HashMap<String, Arc<GenAIEvalProfile>> = profiles
45 .into_iter()
46 .map(|(k, v)| (k, Arc::new(v)))
47 .collect();
48 Self {
49 profiles: arc_profiles,
50 scenarios,
51 }
52 }
53
54 #[getter]
55 pub fn scenarios(&self) -> EvalScenarios {
56 self.scenarios.clone()
57 }
58
59 #[pyo3(signature = (config=None))]
68 pub fn evaluate(
69 &mut self,
70 config: Option<EvaluationConfig>,
71 ) -> Result<ScenarioEvalResults, EvaluationError> {
72 let config = Arc::new(config.unwrap_or_default());
73
74 if tokio::runtime::Handle::try_current().is_ok() {
75 return Err(EvaluationError::GenAIEvaluatorError(
76 "EvalRunner.evaluate() cannot be called from within an async context. \
77 Use evaluate_async() or call from a synchronous Python context."
78 .to_string(),
79 ));
80 }
81
82 app_state()
83 .handle()
84 .block_on(async { self.evaluate_async(&config).await })
85 }
86
87 #[pyo3(signature = (records, response, scenario))]
94 pub fn collect_scenario_data(
95 &mut self,
96 records: HashMap<String, Vec<EvalRecord>>,
97 response: String,
98 scenario: &EvalScenario,
99 ) -> Result<(), EvaluationError> {
100 let mut alias_datasets: HashMap<String, EvalDataset> = HashMap::new();
101 let scenario_id = scenario.id.clone();
102
103 for (alias, mut alias_records) in records {
104 let scenario_tag = format!("scouter.eval.scenario_id={}", scenario_id);
106 for record in &mut alias_records {
107 if !record.tags.contains(&scenario_tag) {
108 record.tags.push(scenario_tag.clone());
109 }
110 }
111
112 let profile = self.profiles.get(&alias).ok_or_else(|| {
113 EvaluationError::MissingKeyError(format!(
114 "No profile found for alias '{}' in scenario '{}'",
115 alias, scenario_id
116 ))
117 })?;
118
119 alias_datasets.insert(
120 alias,
121 EvalDataset {
122 records: Arc::new(alias_records),
123 profile: Arc::clone(profile),
124 spans: Arc::new(vec![]),
125 },
126 );
127 }
128
129 if self.scenarios.scenario_datasets.contains_key(&scenario_id) {
130 return Err(EvaluationError::MissingKeyError(format!(
131 "Scenario '{}' already has data — collect_scenario_data called twice",
132 scenario_id
133 )));
134 }
135
136 self.scenarios
137 .scenario_datasets
138 .insert(scenario_id.clone(), alias_datasets);
139
140 let context = json!({
142 "response": response,
143 "expected_outcome": scenario.expected_outcome,
144 "metadata": scenario.metadata,
145 });
146 self.scenarios
147 .scenario_contexts
148 .insert(scenario_id, context);
149
150 Ok(())
151 }
152}
153
impl EvalRunner {
    /// Core async evaluation pipeline.
    ///
    /// Order matters: trace spans must be resolved and attached to datasets
    /// (`set_dataset_spans`) before `evaluate_datasets` runs, and all results
    /// are cached on `self.scenarios` before being returned to the caller.
    async fn evaluate_async(
        &mut self,
        config: &Arc<EvaluationConfig>,
    ) -> Result<ScenarioEvalResults, EvaluationError> {
        let all_trace_ids = self.collect_all_trace_ids();

        // Only query the span-capture store when some record references a trace.
        let captured_spans = if !all_trace_ids.is_empty() {
            scouter_types::span_capture::get_captured_spans_by_trace_ids(&all_trace_ids)
        } else {
            vec![]
        };

        let trace_spans = build_trace_spans(captured_spans);

        self.set_dataset_spans(&trace_spans);
        let dataset_results = self.evaluate_datasets(config).await?;

        let scenario_results = self.evaluate_scenarios(&trace_spans).await?;

        let metrics = compute_metrics(&dataset_results, &scenario_results);

        // Cache on self so results stay inspectable via the `scenarios` getter.
        self.scenarios.dataset_results = dataset_results.clone();
        self.scenarios.scenario_results = scenario_results.clone();
        self.scenarios.metrics = Some(metrics.clone());

        Ok(ScenarioEvalResults {
            dataset_results,
            scenario_results,
            metrics,
        })
    }

    /// Attach to each dataset the subset of `trace_spans` whose trace id
    /// matches one of that dataset's records. Datasets with no matching
    /// records or spans are left untouched.
    fn set_dataset_spans(&mut self, trace_spans: &[scouter_types::trace::sql::TraceSpan]) {
        for datasets in self.scenarios.scenario_datasets.values_mut() {
            for dataset in datasets.values_mut() {
                // Hex trace ids referenced by this dataset's records.
                let trace_ids: HashSet<String> = dataset
                    .records
                    .iter()
                    .filter_map(|r| r.trace_id.as_ref().map(|tid| tid.to_hex()))
                    .collect();

                if trace_ids.is_empty() {
                    continue;
                }

                let matching: Vec<_> = trace_spans
                    .iter()
                    .filter(|s| trace_ids.contains(&s.trace_id))
                    .cloned()
                    .collect();

                if !matching.is_empty() {
                    dataset.spans = Arc::new(matching);
                }
            }
        }
    }

    /// Collect the deduplicated set of trace ids referenced by any record in
    /// any scenario's datasets.
    fn collect_all_trace_ids(&self) -> HashSet<scouter_types::TraceId> {
        self.scenarios
            .scenario_datasets
            .values()
            .flat_map(|datasets| datasets.values())
            .flat_map(|dataset| dataset.records.iter())
            .filter_map(|r| r.trace_id)
            .collect()
    }

    /// Evaluate one merged dataset per alias.
    ///
    /// Records/spans for the same alias across scenarios are concatenated;
    /// the first profile seen for an alias wins (later ones are debug-logged
    /// and ignored). Aliases with no records or no profile are skipped.
    async fn evaluate_datasets(
        &self,
        config: &Arc<EvaluationConfig>,
    ) -> Result<HashMap<String, EvalResults>, EvaluationError> {
        let mut alias_data: HashMap<String, AliasData> = HashMap::new();

        // Merge per-scenario datasets keyed by alias.
        for datasets in self.scenarios.scenario_datasets.values() {
            for (alias, dataset) in datasets {
                let entry = alias_data
                    .entry(alias.clone())
                    .or_insert_with(|| AliasData {
                        records: Vec::new(),
                        profile: None,
                        spans: Vec::new(),
                    });
                entry.records.extend(dataset.records.iter().cloned());
                if entry.profile.is_none() {
                    entry.profile = Some(Arc::clone(&dataset.profile));
                } else {
                    debug!(
                        "Alias '{}': profile already set, ignoring profile from another scenario",
                        alias
                    );
                }
                entry.spans.extend(dataset.spans.iter().cloned());
            }
        }

        let mut results = HashMap::new();

        for (
            alias,
            AliasData {
                records,
                profile,
                spans,
            },
        ) in alias_data
        {
            if records.is_empty() {
                continue;
            }

            // No profile means nothing to evaluate against; skip silently.
            let profile = match profile {
                Some(p) => p,
                None => continue,
            };

            let dataset = EvalDataset {
                records: Arc::new(records),
                profile,
                spans: Arc::new(spans),
            };

            debug!("Evaluating sub-agent dataset for alias '{}'", alias);
            // First failure aborts the whole evaluation (after logging it).
            match evaluate_genai_dataset(&dataset, config).await {
                Ok(eval_results) => {
                    results.insert(alias, eval_results);
                }
                Err(e) => {
                    error!("Failed to evaluate dataset for alias '{}': {:?}", alias, e);
                    return Err(e);
                }
            }
        }

        Ok(results)
    }

    /// Evaluate each scenario's own assertion tasks against its stored context.
    ///
    /// Scenarios without tasks are skipped; a scenario with tasks but no
    /// stored context is a hard error (collect_scenario_data was not called).
    /// Per-scenario evaluation failures are recorded as failed results rather
    /// than aborting the run.
    async fn evaluate_scenarios(
        &self,
        trace_spans: &[scouter_types::trace::sql::TraceSpan],
    ) -> Result<Vec<ScenarioResult>, EvaluationError> {
        let mut results = Vec::new();

        let scenario_trace_ids = self.collect_scenario_trace_ids();

        for scenario in &self.scenarios.scenarios {
            if !scenario.has_tasks() {
                continue;
            }

            let context = self
                .scenarios
                .scenario_contexts
                .get(&scenario.id)
                .cloned()
                .ok_or_else(|| {
                    EvaluationError::MissingKeyError(format!(
                        "Scenario '{}' has tasks but no context — call collect_scenario_data() first",
                        scenario.id
                    ))
                })?;

            // Synthetic record carrying the scenario context for the evaluator.
            let record = EvalRecord {
                context,
                record_id: scenario.id.clone(),
                tags: vec![format!("scouter.eval.scenario_id={}", scenario.id)],
                ..Default::default()
            };

            // Ad-hoc profile built from the scenario's own tasks with default config.
            let profile = GenAIEvalProfile::build_from_parts_async(
                GenAIEvalConfig::default(),
                scenario.tasks.clone(),
                None,
            )
            .await?;
            let profile = Arc::new(profile);

            // Restrict spans to those belonging to this scenario's traces.
            let filtered_spans = if let Some(trace_ids) = scenario_trace_ids.get(&scenario.id) {
                trace_spans
                    .iter()
                    .filter(|s| trace_ids.contains(&s.trace_id))
                    .cloned()
                    .collect::<Vec<_>>()
            } else {
                Vec::new()
            };
            let spans_arc = Arc::new(filtered_spans);

            match GenAIEvaluator::process_event_record(&record, profile, spans_arc).await {
                Ok(eval_set) => {
                    let mut eval_results = EvalResults::new();
                    eval_results.add_success(&record, eval_set, BTreeMap::new());

                    let (passed, pass_rate) = compute_pass_rate(&eval_results);

                    results.push(ScenarioResult {
                        scenario_id: scenario.id.clone(),
                        initial_query: scenario.initial_query.clone(),
                        eval_results,
                        passed,
                        pass_rate,
                    });
                }
                Err(e) => {
                    // Record the failure for this scenario and keep evaluating the rest.
                    error!("Failed to evaluate scenario '{}': {:?}", scenario.id, e);
                    let mut eval_results = EvalResults::new();
                    eval_results.add_failure(&record, e.to_string());

                    results.push(ScenarioResult {
                        scenario_id: scenario.id.clone(),
                        initial_query: scenario.initial_query.clone(),
                        eval_results,
                        passed: false,
                        pass_rate: 0.0,
                    });
                }
            }
        }

        Ok(results)
    }

    /// Map each scenario id to the hex trace ids referenced by its records.
    /// Scenarios whose records carry no trace ids are omitted from the map.
    fn collect_scenario_trace_ids(&self) -> HashMap<String, HashSet<String>> {
        let mut result: HashMap<String, HashSet<String>> = HashMap::new();

        for (scenario_id, datasets) in &self.scenarios.scenario_datasets {
            let mut trace_ids = HashSet::new();
            for dataset in datasets.values() {
                for record in dataset.records.iter() {
                    if let Some(ref tid) = record.trace_id {
                        trace_ids.insert(tid.to_hex());
                    }
                }
            }
            if !trace_ids.is_empty() {
                result.insert(scenario_id.clone(), trace_ids);
            }
        }

        result
    }
}
419
420fn compute_metrics(
422 dataset_results: &HashMap<String, EvalResults>,
423 scenario_results: &[ScenarioResult],
424) -> EvalMetrics {
425 let mut dataset_pass_rates: HashMap<String, f64> = HashMap::new();
426 for (alias, results) in dataset_results {
427 let (_, pass_rate) = compute_pass_rate(results);
428 dataset_pass_rates.insert(alias.clone(), pass_rate);
429 }
430
431 let total_scenarios = scenario_results.len();
432 let passed_scenarios = scenario_results.iter().filter(|s| s.passed).count();
433 let scenario_pass_rate = if total_scenarios > 0 {
434 passed_scenarios as f64 / total_scenarios as f64
435 } else {
436 0.0
437 };
438
439 let mut all_rates: Vec<f64> = dataset_pass_rates.values().copied().collect();
440 if total_scenarios > 0 {
441 all_rates.push(scenario_pass_rate);
442 }
443 let overall_pass_rate = if all_rates.is_empty() {
444 0.0
445 } else {
446 all_rates.iter().sum::<f64>() / all_rates.len() as f64
447 };
448
449 EvalMetrics {
450 overall_pass_rate,
451 dataset_pass_rates,
452 scenario_pass_rate,
453 total_scenarios,
454 passed_scenarios,
455 }
456}
457
458fn compute_pass_rate(results: &EvalResults) -> (bool, f64) {
460 if results.aligned_results.is_empty() {
461 return (false, 0.0);
462 }
463
464 let mut total_tasks = 0;
465 let mut passed_tasks = 0;
466
467 for aligned in &results.aligned_results {
468 for task_result in &aligned.eval_set.records {
469 total_tasks += 1;
470 if task_result.passed {
471 passed_tasks += 1;
472 }
473 }
474 }
475
476 if total_tasks == 0 {
477 return (false, 0.0);
478 }
479
480 let pass_rate = passed_tasks as f64 / total_tasks as f64;
481 let passed = passed_tasks == total_tasks;
482 (passed, pass_rate)
483}
484
#[cfg(test)]
mod tests {
    //! Unit tests for `EvalRunner` data collection, evaluation, and the
    //! metric/pass-rate helpers.
    use super::*;
    use scouter_types::genai::utils::AssertionTasks;
    use scouter_types::genai::EvalScenario;

    // Task set with no assertions of any kind (scenario is skipped by evaluate_scenarios).
    fn empty_tasks() -> AssertionTasks {
        AssertionTasks {
            assertion: vec![],
            judge: vec![],
            trace: vec![],
            agent: vec![],
        }
    }

    // Minimal scenario with no tasks; only dataset evaluation applies.
    fn make_scenario(id: &str, query: &str) -> EvalScenario {
        EvalScenario {
            id: id.to_string(),
            initial_query: query.to_string(),
            predefined_turns: vec![],
            simulated_user_persona: None,
            termination_signal: None,
            max_turns: 10,
            expected_outcome: Some("Expected output".to_string()),
            tasks: empty_tasks(),
            metadata: None,
        }
    }

    // Scenario with a single assertion: context["response"] must contain "hello".
    fn make_scenario_with_tasks(id: &str, query: &str) -> EvalScenario {
        use scouter_types::genai::{AssertionTask, ComparisonOperator, EvaluationTaskType};

        let task = AssertionTask {
            id: "check_response".to_string(),
            context_path: Some("response".to_string()),
            item_context_path: None,
            operator: ComparisonOperator::Contains,
            expected_value: serde_json::Value::String("hello".to_string()),
            description: None,
            depends_on: vec![],
            task_type: EvaluationTaskType::Assertion,
            result: None,
            condition: false,
        };

        EvalScenario {
            id: id.to_string(),
            initial_query: query.to_string(),
            predefined_turns: vec![],
            simulated_user_persona: None,
            termination_signal: None,
            max_turns: 10,
            expected_outcome: Some("Response contains hello".to_string()),
            tasks: AssertionTasks {
                assertion: vec![task],
                judge: vec![],
                trace: vec![],
                agent: vec![],
            },
            metadata: None,
        }
    }

    // Default profile registered under the "agent_a" alias.
    fn make_default_profiles() -> HashMap<String, GenAIEvalProfile> {
        let mut profiles = HashMap::new();
        profiles.insert("agent_a".to_string(), GenAIEvalProfile::default());
        profiles
    }

    // Happy path: records are tagged, dataset stored, context stored.
    #[test]
    fn collect_scenario_data_stores_datasets_and_contexts() {
        let mut runner = EvalRunner::new(
            EvalScenarios::new(vec![make_scenario("s1", "Hello")]),
            make_default_profiles(),
        );

        let mut records = HashMap::new();
        let record = EvalRecord::default();
        records.insert("agent_a".to_string(), vec![record]);

        let scenario = runner.scenarios.scenarios[0].clone();

        runner
            .collect_scenario_data(records, "Agent response".to_string(), &scenario)
            .unwrap();

        assert!(runner.scenarios.scenario_datasets.contains_key("s1"));
        let datasets = &runner.scenarios.scenario_datasets["s1"];
        assert!(datasets.contains_key("agent_a"));
        assert_eq!(datasets["agent_a"].records.len(), 1);

        // Records must carry the scenario id tag after collection.
        assert!(datasets["agent_a"].records[0]
            .tags
            .contains(&"scouter.eval.scenario_id=s1".to_string()));

        assert!(runner.scenarios.scenario_contexts.contains_key("s1"));
        let ctx = &runner.scenarios.scenario_contexts["s1"];
        assert_eq!(ctx["response"], "Agent response");
    }

    // An alias without a registered profile is rejected.
    #[test]
    fn collect_scenario_data_missing_profile_errors() {
        let mut runner = EvalRunner::new(
            EvalScenarios::new(vec![make_scenario("s1", "Hello")]),
            HashMap::new(),
        );

        let mut records = HashMap::new();
        records.insert("agent_a".to_string(), vec![EvalRecord::default()]);

        let scenario = runner.scenarios.scenarios[0].clone();

        let result = runner.collect_scenario_data(records, "Response".to_string(), &scenario);

        assert!(result.is_err());
    }

    // Multiple aliases in one call each get their own dataset with the right record counts.
    #[test]
    fn collect_scenario_data_multiple_aliases() {
        let mut profiles = make_default_profiles();
        profiles.insert("agent_b".to_string(), GenAIEvalProfile::default());

        let mut runner = EvalRunner::new(
            EvalScenarios::new(vec![make_scenario("s1", "Hello")]),
            profiles,
        );

        let mut records = HashMap::new();
        records.insert("agent_a".to_string(), vec![EvalRecord::default()]);
        records.insert(
            "agent_b".to_string(),
            vec![EvalRecord::default(), EvalRecord::default()],
        );

        let scenario = runner.scenarios.scenarios[0].clone();
        runner
            .collect_scenario_data(records, "Response".to_string(), &scenario)
            .unwrap();

        let datasets = &runner.scenarios.scenario_datasets["s1"];
        assert_eq!(datasets["agent_a"].records.len(), 1);
        assert_eq!(datasets["agent_b"].records.len(), 2);
    }

    // Scenario without tasks: only dataset results/metrics are produced.
    #[test]
    fn evaluate_no_tasks_only_datasets() {
        let mut runner = EvalRunner::new(
            EvalScenarios::new(vec![make_scenario("s1", "Hello")]),
            make_default_profiles(),
        );

        let mut records = HashMap::new();
        records.insert("agent_a".to_string(), vec![EvalRecord::default()]);

        let scenario = runner.scenarios.scenarios[0].clone();
        runner
            .collect_scenario_data(records, "Response".to_string(), &scenario)
            .unwrap();

        let result = runner.evaluate(None).unwrap();

        assert!(result.dataset_results.contains_key("agent_a"));
        assert!(result.scenario_results.is_empty());
        assert!(result.metrics.dataset_pass_rates.contains_key("agent_a"));
        assert_eq!(result.metrics.total_scenarios, 0);
    }

    // Passing assertion: "hello world" contains "hello" -> scenario passes.
    #[test]
    fn evaluate_with_assertion_tasks() {
        let scenario = make_scenario_with_tasks("s1", "Say hello");
        let mut runner =
            EvalRunner::new(EvalScenarios::new(vec![scenario.clone()]), HashMap::new());

        // Context inserted directly instead of via collect_scenario_data.
        let context = json!({
            "response": "hello world",
            "expected_outcome": "Response contains hello",
            "metadata": null,
        });
        runner
            .scenarios
            .scenario_contexts
            .insert("s1".to_string(), context);

        let result = runner.evaluate(None).unwrap();

        assert_eq!(result.scenario_results.len(), 1);
        assert_eq!(result.scenario_results[0].scenario_id, "s1");
        assert!(result.scenario_results[0].passed);
        assert_eq!(result.scenario_results[0].pass_rate, 1.0);
        assert_eq!(result.metrics.total_scenarios, 1);
        assert_eq!(result.metrics.passed_scenarios, 1);
        assert_eq!(result.metrics.scenario_pass_rate, 1.0);
    }

    // Failing assertion: "goodbye world" lacks "hello" -> scenario fails.
    #[test]
    fn evaluate_with_failing_assertion() {
        let scenario = make_scenario_with_tasks("s1", "Say hello");
        let mut runner =
            EvalRunner::new(EvalScenarios::new(vec![scenario.clone()]), HashMap::new());

        let context = json!({
            "response": "goodbye world",
            "expected_outcome": "Response contains hello",
            "metadata": null,
        });
        runner
            .scenarios
            .scenario_contexts
            .insert("s1".to_string(), context);

        let result = runner.evaluate(None).unwrap();

        assert_eq!(result.scenario_results.len(), 1);
        assert!(!result.scenario_results[0].passed);
        assert_eq!(result.scenario_results[0].pass_rate, 0.0);
        assert_eq!(result.metrics.passed_scenarios, 0);
    }

    // Tasks without a stored context is a hard error from evaluate().
    #[test]
    fn evaluate_scenario_with_tasks_but_no_context_errors() {
        let scenario = make_scenario_with_tasks("s1", "Say hello");
        let mut runner = EvalRunner::new(EvalScenarios::new(vec![scenario]), HashMap::new());

        let result = runner.evaluate(None);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("no context"));
    }

    // No aligned results at all -> (false, 0.0).
    #[test]
    fn compute_pass_rate_empty_results() {
        let results = EvalResults::new();
        let (passed, rate) = compute_pass_rate(&results);
        assert!(!passed);
        assert_eq!(rate, 0.0);
    }

    // Aligned results present but zero tasks -> (false, 0.0), no divide-by-zero.
    #[test]
    fn compute_pass_rate_zero_tasks() {
        let mut results = EvalResults::new();
        let record = EvalRecord::default();
        let eval_set = scouter_types::genai::EvalSet::new(vec![], Default::default());
        results.add_success(&record, eval_set, BTreeMap::new());

        let (passed, rate) = compute_pass_rate(&results);
        assert!(!passed);
        assert_eq!(rate, 0.0);
    }

    // One passing and one failing scenario -> 50% scenario pass rate.
    #[test]
    fn evaluate_multiple_scenarios_mixed_results() {
        let s_pass = make_scenario_with_tasks("s_pass", "Say hello");
        let s_fail = make_scenario_with_tasks("s_fail", "Say hello");
        let mut runner = EvalRunner::new(EvalScenarios::new(vec![s_pass, s_fail]), HashMap::new());

        runner.scenarios.scenario_contexts.insert(
            "s_pass".to_string(),
            json!({"response": "hello world", "expected_outcome": null, "metadata": null}),
        );
        runner.scenarios.scenario_contexts.insert(
            "s_fail".to_string(),
            json!({"response": "goodbye", "expected_outcome": null, "metadata": null}),
        );

        let result = runner.evaluate(None).unwrap();
        assert_eq!(result.scenario_results.len(), 2);
        assert_eq!(result.metrics.total_scenarios, 2);
        assert_eq!(result.metrics.passed_scenarios, 1);
        assert_eq!(result.metrics.scenario_pass_rate, 0.5);
    }

    // No inputs at all -> all metrics zero.
    #[test]
    fn compute_metrics_empty() {
        let metrics = compute_metrics(&HashMap::new(), &[]);

        assert_eq!(metrics.overall_pass_rate, 0.0);
        assert_eq!(metrics.scenario_pass_rate, 0.0);
        assert_eq!(metrics.total_scenarios, 0);
        assert_eq!(metrics.passed_scenarios, 0);
    }

    // Scenario-only metrics: overall rate equals scenario pass rate (no datasets).
    #[test]
    fn compute_metrics_with_scenario_results() {
        let scenario_results = vec![
            ScenarioResult {
                scenario_id: "s1".to_string(),
                initial_query: "Q1".to_string(),
                eval_results: EvalResults::new(),
                passed: true,
                pass_rate: 1.0,
            },
            ScenarioResult {
                scenario_id: "s2".to_string(),
                initial_query: "Q2".to_string(),
                eval_results: EvalResults::new(),
                passed: false,
                pass_rate: 0.5,
            },
        ];

        let metrics = compute_metrics(&HashMap::new(), &scenario_results);

        assert_eq!(metrics.total_scenarios, 2);
        assert_eq!(metrics.passed_scenarios, 1);
        assert_eq!(metrics.scenario_pass_rate, 0.5);
        assert_eq!(metrics.overall_pass_rate, 0.5);
    }
}