use crate::criteria::EvaluationCriteria;
use crate::error::Result;
use crate::llm_judge::LlmJudge;
use crate::report::{EvaluationReport, EvaluationResult, Failure, TurnResult};
use crate::schema::{EvalCase, TestFile, ToolUse, Turn};
use crate::scoring::{ResponseScorer, ToolTrajectoryScorer};

use adk_core::{Agent, Content, Event, Llm};
use async_trait::async_trait;
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
use std::time::{Duration, Instant};

/// Configuration controlling how an [`Evaluator`] runs and scores eval cases.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct EvaluationConfig {
    /// Criteria and thresholds applied to each turn.
    #[serde(default)]
    pub criteria: EvaluationCriteria,
    /// Keep evaluating remaining cases after a failure instead of stopping.
    #[serde(default)]
    pub continue_on_failure: bool,
    /// Optional time budget for each case.
    #[serde(default)]
    pub timeout_per_case: Option<Duration>,
    /// How many times to retry a case.
    #[serde(default)]
    pub retries: usize,
    /// Record per-turn details (responses, tool calls) in the results.
    #[serde(default = "default_true")]
    pub collect_turn_details: bool,
}

fn default_true() -> bool {
    true
}

impl EvaluationConfig {
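    /// Builds a config with the given criteria and defaults for everything else.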
    pub fn with_criteria(criteria: EvaluationCriteria) -> Self {
        Self { criteria, ..Default::default() }
    }
}

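/// Runs eval cases against an agent and scores each turn with the configured
/// criteria, optionally delegating semantic, rubric, safety, and hallucination
/// checks to an LLM judge.
///
/// Illustrative sketch only (the `agent` value, the surrounding async context,
/// and the test-file path are assumed):
///
/// ```ignore
/// let config = EvaluationConfig::with_criteria(
///     EvaluationCriteria::exact_tools().with_response_similarity(0.8),
/// );
/// let evaluator = Evaluator::new(config);
/// let report = evaluator.evaluate_file(agent.clone(), "tests/example.test.json").await?;
/// ```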
pub struct Evaluator {
    config: EvaluationConfig,
    tool_scorer: ToolTrajectoryScorer,
    response_scorer: ResponseScorer,
    llm_judge: Option<LlmJudge>,
}

impl Evaluator {
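    /// Builds an evaluator, constructing scorers from any per-criterion
    /// configs present in `config.criteria`.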
    pub fn new(config: EvaluationConfig) -> Self {
        let tool_scorer = if let Some(tc) = &config.criteria.tool_trajectory_config {
            ToolTrajectoryScorer::with_config(tc.clone())
        } else {
            ToolTrajectoryScorer::new()
        };

        let response_scorer = if let Some(rc) = &config.criteria.response_match_config {
            ResponseScorer::with_config(rc.clone())
        } else {
            ResponseScorer::new()
        };

        Self { config, tool_scorer, response_scorer, llm_judge: None }
    }

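    /// Like [`Evaluator::new`], but also wires up an LLM judge so the
    /// semantic-match, rubric, safety, and hallucination criteria can run.
    ///
    /// Illustrative sketch (`config` and `judge_model: Arc<dyn Llm>` are
    /// assumed to exist):
    ///
    /// ```ignore
    /// let evaluator = Evaluator::with_llm_judge(config, judge_model);
    /// assert!(evaluator.has_llm_judge());
    /// ```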
    pub fn with_llm_judge(config: EvaluationConfig, judge_model: Arc<dyn Llm>) -> Self {
        let mut evaluator = Self::new(config);
        evaluator.llm_judge = Some(LlmJudge::new(judge_model));
        evaluator
    }

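    /// Attaches an LLM judge to an already-constructed evaluator.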
    pub fn set_llm_judge(&mut self, judge_model: Arc<dyn Llm>) {
        self.llm_judge = Some(LlmJudge::new(judge_model));
    }

    pub fn has_llm_judge(&self) -> bool {
        self.llm_judge.is_some()
    }

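    /// Loads a test file from disk and evaluates every case in it.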
    pub async fn evaluate_file(
        &self,
        agent: Arc<dyn Agent>,
        path: impl AsRef<Path>,
    ) -> Result<EvaluationReport> {
        let test_file = TestFile::load(path)?;
        self.evaluate_test_file(agent, &test_file).await
    }

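    /// Evaluates all cases in `test_file` sequentially, stopping at the first
    /// failed case unless `continue_on_failure` is set. A case that errors
    /// during execution is recorded as a failed result rather than aborting
    /// the whole run.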
    pub async fn evaluate_test_file(
        &self,
        agent: Arc<dyn Agent>,
        test_file: &TestFile,
    ) -> Result<EvaluationReport> {
        let started_at = chrono::Utc::now();
        let run_id = format!("{}_{}", test_file.eval_set_id, uuid::Uuid::new_v4());
        let mut results = Vec::new();

        for eval_case in &test_file.eval_cases {
            let result = self.evaluate_case(agent.clone(), eval_case).await;

            match result {
                Ok(r) => {
                    let passed = r.passed;
                    results.push(r);
                    if !passed && !self.config.continue_on_failure {
                        break;
                    }
                }
                Err(e) => {
                    results.push(EvaluationResult::failed(
                        &eval_case.eval_id,
                        HashMap::new(),
                        vec![Failure::new(
                            "execution",
                            Value::Null,
                            Value::String(e.to_string()),
                            0.0,
                            1.0,
                        )],
                        Duration::from_secs(0),
                    ));
                    if !self.config.continue_on_failure {
                        break;
                    }
                }
            }
        }

        Ok(EvaluationReport::new(&run_id, results, started_at))
    }

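    /// Runs every turn of a case, aggregates per-criterion scores, and returns
    /// a passed/failed result. The case passes only if no turn produced a
    /// failure.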
    pub async fn evaluate_case(
        &self,
        agent: Arc<dyn Agent>,
        eval_case: &EvalCase,
    ) -> Result<EvaluationResult> {
        let start = Instant::now();
        let mut all_scores: HashMap<String, f64> = HashMap::new();
        let mut all_failures: Vec<Failure> = Vec::new();
        let mut turn_results: Vec<TurnResult> = Vec::new();

        for turn in &eval_case.conversation {
            let turn_result = self.execute_turn(agent.clone(), turn).await?;

            let (scores, failures) = self.score_turn(turn, &turn_result).await;

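            // Fold this turn's scores into the running totals. Note that the
            // pairwise average `(s + score) / 2` weights later turns more
            // heavily than earlier ones.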
            for (criterion, score) in &scores {
                all_scores
                    .entry(criterion.clone())
                    .and_modify(|s| *s = (*s + score) / 2.0)
                    .or_insert(*score);
            }
            all_failures.extend(failures);

            if self.config.collect_turn_details {
                turn_results.push(turn_result);
            }
        }

        let duration = start.elapsed();
        let passed = all_failures.is_empty();

        let mut result = if passed {
            EvaluationResult::passed(&eval_case.eval_id, all_scores, duration)
        } else {
            EvaluationResult::failed(&eval_case.eval_id, all_scores, all_failures, duration)
        };

        if self.config.collect_turn_details {
            result = result.with_turn_results(turn_results);
        }

        Ok(result)
    }

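    /// Runs a single conversation turn against the agent and pairs what
    /// actually happened with what the test file expected.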
    async fn execute_turn(&self, agent: Arc<dyn Agent>, turn: &Turn) -> Result<TurnResult> {
        let input_content = turn.user_content.to_adk_content();

        let events = self.run_agent(agent, input_content).await?;

        let (actual_response, actual_tool_calls) = self.extract_from_events(&events);

        let expected_response = turn.final_response.as_ref().map(|c| c.get_text());
        let expected_tool_calls =
            turn.intermediate_data.as_ref().map(|d| d.tool_uses.clone()).unwrap_or_default();

        Ok(TurnResult {
            invocation_id: turn.invocation_id.clone(),
            actual_response,
            expected_response,
            actual_tool_calls,
            expected_tool_calls,
            scores: HashMap::new(),
        })
    }

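    /// Drives the agent with a synthetic invocation context and collects its
    /// event stream. Events that yield errors are dropped rather than failing
    /// the turn.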
    async fn run_agent(&self, agent: Arc<dyn Agent>, input: Content) -> Result<Vec<Event>> {
        let invocation_id = uuid::Uuid::new_v4().to_string();
        let ctx = Arc::new(EvalInvocationContext::new(invocation_id, input, agent.clone()));

        let stream = agent.run(ctx).await.map_err(|e| {
            crate::error::EvalError::ExecutionError(format!("Agent run failed: {}", e))
        })?;

        let events: Vec<Event> = stream.filter_map(|r| async { r.ok() }).collect().await;

        Ok(events)
    }

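    /// Concatenates all text parts across events into one response string and
    /// records every function call as a tool use.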
    fn extract_from_events(&self, events: &[Event]) -> (Option<String>, Vec<ToolUse>) {
        let mut response_text = String::new();
        let mut tool_calls = Vec::new();

        for event in events {
            if let Some(content) = event.content() {
                for part in &content.parts {
                    if let Some(text) = part.text() {
                        response_text.push_str(text);
                    }
                    if let adk_core::Part::FunctionCall { name, args, .. } = part {
                        tool_calls.push(ToolUse {
                            name: name.clone(),
                            args: args.clone(),
                            expected_response: None,
                        });
                    }
                }
            }
        }

        let response = if response_text.is_empty() { None } else { Some(response_text) };

        (response, tool_calls)
    }

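    /// Scores one turn against every enabled criterion. Each criterion below
    /// is gated on its threshold being set; the LLM-judged criteria are
    /// additionally skipped when no judge is configured.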
    async fn score_turn(
        &self,
        turn: &Turn,
        result: &TurnResult,
    ) -> (HashMap<String, f64>, Vec<Failure>) {
        let mut scores = HashMap::new();
        let mut failures = Vec::new();

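        // Tool trajectory: compare expected vs. actual tool calls.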
        if let Some(threshold) = self.config.criteria.tool_trajectory_score {
            let score =
                self.tool_scorer.score(&result.expected_tool_calls, &result.actual_tool_calls);
            scores.insert("tool_trajectory".to_string(), score);

            if score < threshold {
                failures.push(
                    Failure::new(
                        "tool_trajectory",
                        serde_json::to_value(&result.expected_tool_calls).unwrap_or_default(),
                        serde_json::to_value(&result.actual_tool_calls).unwrap_or_default(),
                        score,
                        threshold,
                    )
                    .with_details(&format!(
                        "Expected {} tool calls, got {}",
                        result.expected_tool_calls.len(),
                        result.actual_tool_calls.len()
                    )),
                );
            }
        }

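        // Response similarity: text-level match between expected and actual
        // responses; an expected response with no actual response scores 0.0.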
        if let Some(threshold) = self.config.criteria.response_similarity {
            if let (Some(expected), Some(actual)) =
                (&result.expected_response, &result.actual_response)
            {
                let score = self.response_scorer.score(expected, actual);
                scores.insert("response_similarity".to_string(), score);

                if score < threshold {
                    failures.push(
                        Failure::new(
                            "response_similarity",
                            Value::String(expected.clone()),
                            Value::String(actual.clone()),
                            score,
                            threshold,
                        )
                        .with_details("Response text differs from expected"),
                    );
                }
            } else if result.expected_response.is_some() && result.actual_response.is_none() {
                scores.insert("response_similarity".to_string(), 0.0);
                failures.push(
                    Failure::new(
                        "response_similarity",
                        Value::String(result.expected_response.clone().unwrap_or_default()),
                        Value::Null,
                        0.0,
                        threshold,
                    )
                    .with_details("No response received"),
                );
            }
        }

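        // Semantic match: LLM-judged equivalence of expected vs. actual.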
        if let Some(threshold) = self.config.criteria.semantic_match_score {
            if let Some(judge) = &self.llm_judge {
                if let (Some(expected), Some(actual)) =
                    (&result.expected_response, &result.actual_response)
                {
                    match judge
                        .semantic_match(
                            expected,
                            actual,
                            self.config.criteria.semantic_match_config.as_ref(),
                        )
                        .await
                    {
                        Ok(semantic_result) => {
                            scores.insert("semantic_match".to_string(), semantic_result.score);
                            if semantic_result.score < threshold {
                                failures.push(
                                    Failure::new(
                                        "semantic_match",
                                        Value::String(expected.clone()),
                                        Value::String(actual.clone()),
                                        semantic_result.score,
                                        threshold,
                                    )
                                    .with_details(&semantic_result.reasoning),
                                );
                            }
                        }
                        Err(e) => {
                            failures.push(
                                Failure::new(
                                    "semantic_match",
                                    Value::String(expected.clone()),
                                    Value::String(actual.clone()),
                                    0.0,
                                    threshold,
                                )
                                .with_details(&format!("LLM judge error: {}", e)),
                            );
                        }
                    }
                }
            }
        }

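        // Rubric quality: LLM-judged scoring against the configured rubrics,
        // with one sub-score recorded per rubric.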
        if let Some(threshold) = self.config.criteria.rubric_quality_score {
            if let Some(judge) = &self.llm_judge {
                if let Some(rubric_config) = &self.config.criteria.rubric_config {
                    if let Some(actual) = &result.actual_response {
                        let context = turn.user_content.get_text();
                        match judge.evaluate_rubrics(actual, &context, rubric_config).await {
                            Ok(rubric_result) => {
                                scores.insert(
                                    "rubric_quality".to_string(),
                                    rubric_result.overall_score,
                                );
                                for rs in &rubric_result.rubric_scores {
                                    scores.insert(format!("rubric_{}", rs.name), rs.score);
                                }
                                if rubric_result.overall_score < threshold {
                                    let details = rubric_result
                                        .rubric_scores
                                        .iter()
                                        .map(|rs| {
                                            format!(
                                                "{}: {:.2} - {}",
                                                rs.name, rs.score, rs.reasoning
                                            )
                                        })
                                        .collect::<Vec<_>>()
                                        .join("; ");
                                    failures.push(
                                        Failure::new(
                                            "rubric_quality",
                                            Value::Number(
                                                serde_json::Number::from_f64(threshold)
                                                    .unwrap_or(serde_json::Number::from(0)),
                                            ),
                                            Value::Number(
                                                serde_json::Number::from_f64(
                                                    rubric_result.overall_score,
                                                )
                                                .unwrap_or(serde_json::Number::from(0)),
                                            ),
                                            rubric_result.overall_score,
                                            threshold,
                                        )
                                        .with_details(&details),
                                    );
                                }
                            }
                            Err(e) => {
                                failures.push(
                                    Failure::new(
                                        "rubric_quality",
                                        Value::Null,
                                        Value::Null,
                                        0.0,
                                        threshold,
                                    )
                                    .with_details(&format!("LLM judge error: {}", e)),
                                );
                            }
                        }
                    }
                }
            }
        }

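        // Safety: LLM-judged check of the actual response.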
        if let Some(threshold) = self.config.criteria.safety_score {
            if let Some(judge) = &self.llm_judge {
                if let Some(actual) = &result.actual_response {
                    match judge.evaluate_safety(actual).await {
                        Ok(safety_result) => {
                            scores.insert("safety".to_string(), safety_result.score);
                            if safety_result.score < threshold {
                                failures.push(
                                    Failure::new(
                                        "safety",
                                        Value::Number(
                                            serde_json::Number::from_f64(threshold)
                                                .unwrap_or(serde_json::Number::from(0)),
                                        ),
                                        Value::Number(
                                            serde_json::Number::from_f64(safety_result.score)
                                                .unwrap_or(serde_json::Number::from(0)),
                                        ),
                                        safety_result.score,
                                        threshold,
                                    )
                                    .with_details(&format!(
                                        "Safety issues: {}",
                                        safety_result.issues.join(", ")
                                    )),
                                );
                            }
                        }
                        Err(e) => {
                            failures.push(
                                Failure::new("safety", Value::Null, Value::Null, 0.0, threshold)
                                    .with_details(&format!("LLM judge error: {}", e)),
                            );
                        }
                    }
                }
            }
        }

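        // Hallucination: LLM-judged grounding check against the turn's user
        // content and, when available, the expected response as ground truth.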
        if let Some(threshold) = self.config.criteria.hallucination_score {
            if let Some(judge) = &self.llm_judge {
                if let Some(actual) = &result.actual_response {
                    let context = turn.user_content.get_text();
                    let ground_truth = result.expected_response.as_deref();
                    match judge.detect_hallucinations(actual, &context, ground_truth).await {
                        Ok(hallucination_result) => {
                            scores.insert("hallucination".to_string(), hallucination_result.score);
                            if hallucination_result.score < threshold {
                                failures.push(
                                    Failure::new(
                                        "hallucination",
                                        Value::Number(
                                            serde_json::Number::from_f64(threshold)
                                                .unwrap_or(serde_json::Number::from(0)),
                                        ),
                                        Value::Number(
                                            serde_json::Number::from_f64(
                                                hallucination_result.score,
                                            )
                                            .unwrap_or(serde_json::Number::from(0)),
                                        ),
                                        hallucination_result.score,
                                        threshold,
                                    )
                                    .with_details(&format!(
                                        "Hallucinations detected: {}",
                                        hallucination_result.issues.join(", ")
                                    )),
                                );
                            }
                        }
                        Err(e) => {
                            failures.push(
                                Failure::new(
                                    "hallucination",
                                    Value::Null,
                                    Value::Null,
                                    0.0,
                                    threshold,
                                )
                                .with_details(&format!("LLM judge error: {}", e)),
                            );
                        }
                    }
                }
            }
        }

        (scores, failures)
    }

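    /// Evaluates cases concurrently, running at most `concurrency` cases at a
    /// time. `buffer_unordered` yields results in completion order, so they
    /// are not guaranteed to match the input order.
    ///
    /// Illustrative sketch (`evaluator`, `agent`, and `test_file` assumed):
    ///
    /// ```ignore
    /// let results = evaluator
    ///     .evaluate_cases_parallel(agent.clone(), &test_file.eval_cases, 4)
    ///     .await;
    /// ```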
    pub async fn evaluate_cases_parallel(
        &self,
        agent: Arc<dyn Agent>,
        cases: &[EvalCase],
        concurrency: usize,
    ) -> Vec<Result<EvaluationResult>> {
        use futures::stream::{self, StreamExt};

        stream::iter(cases)
            .map(|case| {
                let agent = agent.clone();
                async move { self.evaluate_case(agent, case).await }
            })
            .buffer_unordered(concurrency)
            .collect()
            .await
    }

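    /// Evaluates every `*.test.json` file directly under `dir` (non-recursive)
    /// and returns one report per file.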
    pub async fn evaluate_directory(
        &self,
        agent: Arc<dyn Agent>,
        dir: impl AsRef<Path>,
    ) -> Result<Vec<EvaluationReport>> {
        let mut reports = Vec::new();

        let entries = std::fs::read_dir(dir)?;
        for entry in entries {
            let entry = entry?;
            let path = entry.path();

            // A `.test.json` suffix implies a `json` extension, so one check suffices.
            let is_test_file = path
                .file_name()
                .and_then(|n| n.to_str())
                .is_some_and(|name| name.ends_with(".test.json"));
            if is_test_file {
                let report = self.evaluate_file(agent.clone(), &path).await?;
                reports.push(report);
            }
        }

        Ok(reports)
    }
}

impl Default for Evaluator {
    fn default() -> Self {
        Self::new(EvaluationConfig::default())
    }
}

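/// Minimal invocation context so an agent can run outside a full runtime:
/// a fixed "eval_user"/"eval_app" identity, a throwaway session, and no
/// artifacts or memory.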
struct EvalInvocationContext {
    invocation_id: String,
    user_content: Content,
    agent: Arc<dyn Agent>,
    session: EvalSession,
    run_config: adk_core::RunConfig,
    ended: std::sync::atomic::AtomicBool,
}

impl EvalInvocationContext {
    fn new(invocation_id: String, user_content: Content, agent: Arc<dyn Agent>) -> Self {
        let session_id = format!("eval-session-{}", uuid::Uuid::new_v4());
        Self {
            invocation_id,
            user_content,
            agent,
            session: EvalSession::new(session_id),
            run_config: adk_core::RunConfig::default(),
            ended: std::sync::atomic::AtomicBool::new(false),
        }
    }
}

impl adk_core::ReadonlyContext for EvalInvocationContext {
    fn invocation_id(&self) -> &str {
        &self.invocation_id
    }

    fn agent_name(&self) -> &str {
        self.agent.name()
    }

    fn user_id(&self) -> &str {
        "eval_user"
    }

    fn app_name(&self) -> &str {
        "eval_app"
    }

    fn session_id(&self) -> &str {
        &self.session.id
    }

    fn branch(&self) -> &str {
        "main"
    }

    fn user_content(&self) -> &Content {
        &self.user_content
    }
}

#[async_trait]
impl adk_core::CallbackContext for EvalInvocationContext {
    fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
        None
    }
}

#[async_trait]
impl adk_core::InvocationContext for EvalInvocationContext {
    fn agent(&self) -> Arc<dyn Agent> {
        self.agent.clone()
    }

    fn memory(&self) -> Option<Arc<dyn adk_core::Memory>> {
        None
    }

    fn session(&self) -> &dyn adk_core::Session {
        &self.session
    }

    fn run_config(&self) -> &adk_core::RunConfig {
        &self.run_config
    }

    fn end_invocation(&self) {
        self.ended.store(true, std::sync::atomic::Ordering::SeqCst);
    }

    fn ended(&self) -> bool {
        self.ended.load(std::sync::atomic::Ordering::SeqCst)
    }
}

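/// Throwaway session with fresh state and no conversation history, used only
/// for eval runs.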
struct EvalSession {
    id: String,
    state: EvalState,
}

impl EvalSession {
    fn new(id: String) -> Self {
        Self { id, state: EvalState::new() }
    }
}

impl adk_core::Session for EvalSession {
    fn id(&self) -> &str {
        &self.id
    }

    fn app_name(&self) -> &str {
        "eval_app"
    }

    fn user_id(&self) -> &str {
        "eval_user"
    }

    fn state(&self) -> &dyn adk_core::State {
        &self.state
    }

    fn conversation_history(&self) -> Vec<Content> {
        vec![]
    }
}

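/// In-memory key-value state behind an `RwLock`; poisoned-lock errors are
/// swallowed rather than propagated.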
struct EvalState {
    data: std::sync::RwLock<HashMap<String, serde_json::Value>>,
}

impl EvalState {
    fn new() -> Self {
        Self { data: std::sync::RwLock::new(HashMap::new()) }
    }
}

impl adk_core::State for EvalState {
    fn get(&self, key: &str) -> Option<serde_json::Value> {
        self.data.read().ok()?.get(key).cloned()
    }

    fn set(&mut self, key: String, value: serde_json::Value) {
        if let Ok(mut data) = self.data.write() {
            data.insert(key, value);
        }
    }

    fn all(&self) -> HashMap<String, serde_json::Value> {
        self.data.read().ok().map(|d| d.clone()).unwrap_or_default()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_evaluator_creation() {
        let config = EvaluationConfig::with_criteria(
            EvaluationCriteria::exact_tools().with_response_similarity(0.8),
        );
        let evaluator = Evaluator::new(config);
        assert!(evaluator.config.criteria.tool_trajectory_score.is_some());
        assert!(evaluator.config.criteria.response_similarity.is_some());
    }

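    // Sanity check for the in-memory EvalState backing EvalInvocationContext,
    // grounded in the trait impl above.
    #[test]
    fn test_eval_state_roundtrip() {
        use adk_core::State;

        let mut state = EvalState::new();
        state.set("key".to_string(), serde_json::json!(42));
        assert_eq!(state.get("key"), Some(serde_json::json!(42)));
        assert_eq!(state.get("missing"), None);
        assert_eq!(state.all().len(), 1);
    }
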
    #[tokio::test]
    async fn test_turn_scoring() {
        let config = EvaluationConfig::with_criteria(EvaluationCriteria {
            tool_trajectory_score: Some(1.0),
            response_similarity: Some(0.8),
            ..Default::default()
        });
        let evaluator = Evaluator::new(config);

        let turn = Turn {
            invocation_id: "test".to_string(),
            user_content: crate::schema::ContentData::text("Hello"),
            final_response: Some(crate::schema::ContentData::model_response("Hi there!")),
            intermediate_data: Some(crate::schema::IntermediateData {
                tool_uses: vec![ToolUse::new("greet")],
                ..Default::default()
            }),
        };

        let result = TurnResult {
            invocation_id: "test".to_string(),
            actual_response: Some("Hi there!".to_string()),
            expected_response: Some("Hi there!".to_string()),
            actual_tool_calls: vec![ToolUse::new("greet")],
            expected_tool_calls: vec![ToolUse::new("greet")],
            scores: HashMap::new(),
        };

        let (scores, failures) = evaluator.score_turn(&turn, &result).await;
        assert!(failures.is_empty());
        assert_eq!(scores.get("tool_trajectory"), Some(&1.0));
        assert_eq!(scores.get("response_similarity"), Some(&1.0));
    }
}