Skip to main content

adk_eval/
evaluator.rs

1//! Core evaluator implementation
2//!
3//! The Evaluator orchestrates test execution and applies evaluation criteria.
4
5use crate::cost_tracker::CostTracker;
6use crate::criteria::EvaluationCriteria;
7use crate::error::Result;
8use crate::llm_judge::LlmJudge;
9use crate::report::{EvaluationReport, EvaluationResult, Failure, TurnResult};
10use crate::schema::{EvalCase, TestFile, ToolUse, Turn};
11use crate::scoring::{ResponseScorer, ToolTrajectoryScorer};
12use crate::structured_judge::StructuredJudge;
13use crate::trace_analyzer::TraceAnalyzer;
14
15use adk_core::{Agent, Content, Event, Llm};
16use async_trait::async_trait;
17use futures::StreamExt;
18use serde::{Deserialize, Serialize};
19use serde_json::Value;
20use std::collections::HashMap;
21use std::path::Path;
22use std::sync::Arc;
23use std::time::{Duration, Instant};
24
25#[cfg(feature = "embedding")]
26use crate::embedding_scorer::EmbeddingScorer;
27
28/// Configuration for the evaluator
29#[derive(Debug, Clone, Default, Serialize, Deserialize)]
30pub struct EvaluationConfig {
31    /// Evaluation criteria to apply
32    #[serde(default)]
33    pub criteria: EvaluationCriteria,
34    /// Whether to continue on failure
35    #[serde(default)]
36    pub continue_on_failure: bool,
37    /// Maximum time per test case
38    #[serde(default)]
39    pub timeout_per_case: Option<Duration>,
40    /// Number of retries for flaky tests
41    #[serde(default)]
42    pub retries: usize,
43    /// Whether to collect detailed turn results
44    #[serde(default = "default_true")]
45    pub collect_turn_details: bool,
46}
47
/// Serde default provider for boolean fields that should be enabled when absent.
fn default_true() -> bool {
    true
}
51
52impl EvaluationConfig {
53    /// Create config with specific criteria
54    pub fn with_criteria(criteria: EvaluationCriteria) -> Self {
55        Self { criteria, ..Default::default() }
56    }
57}
58
59/// The main evaluator struct
60pub struct Evaluator {
61    config: EvaluationConfig,
62    tool_scorer: ToolTrajectoryScorer,
63    response_scorer: ResponseScorer,
64    llm_judge: Option<LlmJudge>,
65    /// Optional structured judge for typed verdicts
66    structured_judge: Option<Arc<StructuredJudge>>,
67    /// Optional cost tracker for token usage and latency
68    cost_tracker: Option<CostTracker>,
69    /// Optional trace analyzer for detecting execution inefficiencies
70    trace_analyzer: Option<TraceAnalyzer>,
71    /// Optional embedding scorer for semantic similarity (requires `embedding` feature)
72    #[cfg(feature = "embedding")]
73    embedding_scorer: Option<Arc<EmbeddingScorer>>,
74    /// Optional conversation scorer for multi-turn metrics
75    conversation_scorer: Option<Arc<crate::conversation_scorer::ConversationScorer>>,
76}
77
78impl Evaluator {
79    /// Create a new evaluator with default configuration
80    pub fn new(config: EvaluationConfig) -> Self {
81        let tool_scorer = if let Some(tc) = &config.criteria.tool_trajectory_config {
82            ToolTrajectoryScorer::with_config(tc.clone())
83        } else {
84            ToolTrajectoryScorer::new()
85        };
86
87        let response_scorer = if let Some(rc) = &config.criteria.response_match_config {
88            ResponseScorer::with_config(rc.clone())
89        } else {
90            ResponseScorer::new()
91        };
92
93        Self {
94            config,
95            tool_scorer,
96            response_scorer,
97            llm_judge: None,
98            structured_judge: None,
99            cost_tracker: None,
100            trace_analyzer: None,
101            #[cfg(feature = "embedding")]
102            embedding_scorer: None,
103            conversation_scorer: None,
104        }
105    }
106
107    /// Create an evaluator with an LLM judge for semantic matching and rubric evaluation
108    pub fn with_llm_judge(config: EvaluationConfig, judge_model: Arc<dyn Llm>) -> Self {
109        let tool_scorer = if let Some(tc) = &config.criteria.tool_trajectory_config {
110            ToolTrajectoryScorer::with_config(tc.clone())
111        } else {
112            ToolTrajectoryScorer::new()
113        };
114
115        let response_scorer = if let Some(rc) = &config.criteria.response_match_config {
116            ResponseScorer::with_config(rc.clone())
117        } else {
118            ResponseScorer::new()
119        };
120
121        Self {
122            config,
123            tool_scorer,
124            response_scorer,
125            llm_judge: Some(LlmJudge::new(judge_model)),
126            structured_judge: None,
127            cost_tracker: None,
128            trace_analyzer: None,
129            #[cfg(feature = "embedding")]
130            embedding_scorer: None,
131            conversation_scorer: None,
132        }
133    }
134
135    /// Set the LLM judge model
136    pub fn set_llm_judge(&mut self, judge_model: Arc<dyn Llm>) {
137        self.llm_judge = Some(LlmJudge::new(judge_model));
138    }
139
140    /// Check if LLM judge is available
141    pub fn has_llm_judge(&self) -> bool {
142        self.llm_judge.is_some()
143    }
144
145    /// Set the structured judge for typed verdict evaluation
146    pub fn set_structured_judge(&mut self, judge: Arc<StructuredJudge>) {
147        self.structured_judge = Some(judge);
148    }
149
150    /// Set the cost tracker for token usage and latency metrics
151    pub fn set_cost_tracker(&mut self, tracker: CostTracker) {
152        self.cost_tracker = Some(tracker);
153    }
154
155    /// Set the trace analyzer for execution inefficiency detection
156    pub fn set_trace_analyzer(&mut self, analyzer: TraceAnalyzer) {
157        self.trace_analyzer = Some(analyzer);
158    }
159
160    /// Set the embedding scorer for semantic similarity (requires `embedding` feature)
161    #[cfg(feature = "embedding")]
162    pub fn set_embedding_scorer(&mut self, scorer: Arc<EmbeddingScorer>) {
163        self.embedding_scorer = Some(scorer);
164    }
165
166    /// Set the conversation scorer for multi-turn metrics
167    pub fn set_conversation_scorer(
168        &mut self,
169        scorer: Arc<crate::conversation_scorer::ConversationScorer>,
170    ) {
171        self.conversation_scorer = Some(scorer);
172    }
173
174    /// Check if a structured judge is configured
175    pub fn has_structured_judge(&self) -> bool {
176        self.structured_judge.is_some()
177    }
178
179    /// Check if a cost tracker is configured
180    pub fn has_cost_tracker(&self) -> bool {
181        self.cost_tracker.is_some()
182    }
183
184    /// Check if a trace analyzer is configured
185    pub fn has_trace_analyzer(&self) -> bool {
186        self.trace_analyzer.is_some()
187    }
188
189    /// Check if an embedding scorer is configured (requires `embedding` feature)
190    #[cfg(feature = "embedding")]
191    pub fn has_embedding_scorer(&self) -> bool {
192        self.embedding_scorer.is_some()
193    }
194
195    /// Check if a conversation scorer is configured
196    pub fn has_conversation_scorer(&self) -> bool {
197        self.conversation_scorer.is_some()
198    }
199
200    /// Evaluate a test file against an agent
201    pub async fn evaluate_file(
202        &self,
203        agent: Arc<dyn Agent>,
204        path: impl AsRef<Path>,
205    ) -> Result<EvaluationReport> {
206        let test_file = TestFile::load(path)?;
207        self.evaluate_test_file(agent, &test_file).await
208    }
209
210    /// Evaluate a TestFile struct
211    pub async fn evaluate_test_file(
212        &self,
213        agent: Arc<dyn Agent>,
214        test_file: &TestFile,
215    ) -> Result<EvaluationReport> {
216        let started_at = chrono::Utc::now();
217        let run_id = format!("{}_{}", test_file.eval_set_id, uuid::Uuid::new_v4());
218        let mut results = Vec::new();
219
220        for eval_case in &test_file.eval_cases {
221            let result = self.evaluate_case(agent.clone(), eval_case).await;
222
223            match result {
224                Ok(r) => {
225                    let passed = r.passed;
226                    results.push(r);
227                    if !passed && !self.config.continue_on_failure {
228                        break;
229                    }
230                }
231                Err(e) => {
232                    // Create a failed result for the error
233                    results.push(EvaluationResult::failed(
234                        &eval_case.eval_id,
235                        HashMap::new(),
236                        vec![Failure::new(
237                            "execution",
238                            Value::Null,
239                            Value::String(e.to_string()),
240                            0.0,
241                            1.0,
242                        )],
243                        Duration::from_secs(0),
244                    ));
245                    if !self.config.continue_on_failure {
246                        break;
247                    }
248                }
249            }
250        }
251
252        Ok(EvaluationReport::new(&run_id, results, started_at))
253    }
254
255    /// Evaluate a single test case
256    pub async fn evaluate_case(
257        &self,
258        agent: Arc<dyn Agent>,
259        eval_case: &EvalCase,
260    ) -> Result<EvaluationResult> {
261        let start = Instant::now();
262        let mut all_scores: HashMap<String, f64> = HashMap::new();
263        let mut all_failures: Vec<Failure> = Vec::new();
264        let mut turn_results: Vec<TurnResult> = Vec::new();
265        let mut all_events: Vec<Event> = Vec::new();
266
267        // Execute each turn in the conversation
268        for turn in &eval_case.conversation {
269            let turn_result = self.execute_turn(agent.clone(), turn).await?;
270
271            // Score this turn
272            let (scores, failures) = self.score_turn(turn, &turn_result).await;
273
274            // Merge scores
275            for (criterion, score) in &scores {
276                all_scores
277                    .entry(criterion.clone())
278                    .and_modify(|s| *s = (*s + score) / 2.0)
279                    .or_insert(*score);
280            }
281            all_failures.extend(failures);
282
283            if self.config.collect_turn_details {
284                turn_results.push(turn_result);
285            }
286        }
287
288        // Collect events for the full case by re-running (or using last turn's events)
289        // For cost/trace analysis we re-run the agent to get the full event stream
290        let case_events = self.collect_case_events(agent.clone(), eval_case).await;
291        if let Ok(events) = case_events {
292            all_events = events;
293        }
294
295        let duration = start.elapsed();
296
297        // Invoke CostTracker if configured
298        let cost_metrics = self
299            .cost_tracker
300            .as_ref()
301            .map(|tracker| tracker.extract_metrics(&all_events, duration));
302
303        // Invoke TraceAnalyzer if configured
304        let trace_analysis =
305            self.trace_analyzer.as_ref().map(|analyzer| analyzer.analyze(&all_events));
306
307        // Invoke StructuredJudge if configured
308        let mut verdicts = Vec::new();
309        if let Some(judge) = &self.structured_judge
310            && let Some(last_turn_result) = turn_results.last()
311            && let (Some(expected), Some(actual)) =
312                (&last_turn_result.expected_response, &last_turn_result.actual_response)
313        {
314            match judge.judge(expected, actual, "overall_quality").await {
315                Ok(verdict) => {
316                    all_scores.insert("structured_judge".to_string(), verdict.score);
317                    verdicts.push(verdict);
318                }
319                Err(e) => {
320                    tracing::warn!("Structured judge failed: {e}");
321                    // Create a fallback verdict with score 0.0
322                    let fallback = crate::structured_judge::StructuredVerdict {
323                        score: 0.0,
324                        reasoning: format!("Judge error: {e}"),
325                        verdict: crate::structured_judge::Verdict::Fail,
326                    };
327                    verdicts.push(fallback);
328                }
329            }
330        }
331
332        // Invoke EmbeddingScorer if configured
333        #[cfg(feature = "embedding")]
334        if let Some(scorer) = &self.embedding_scorer
335            && let Some(last_turn_result) = turn_results.last()
336            && let (Some(expected), Some(actual)) =
337                (&last_turn_result.expected_response, &last_turn_result.actual_response)
338        {
339            match scorer.score(expected, actual).await {
340                Ok(score) => {
341                    all_scores.insert("embedding_similarity".to_string(), score);
342                }
343                Err(e) => {
344                    tracing::warn!("Embedding scorer failed: {e}");
345                }
346            }
347        }
348
349        let passed = all_failures.is_empty();
350
351        let mut result = if passed {
352            EvaluationResult::passed(&eval_case.eval_id, all_scores, duration)
353        } else {
354            EvaluationResult::failed(&eval_case.eval_id, all_scores, all_failures, duration)
355        };
356
357        if self.config.collect_turn_details {
358            result = result.with_turn_results(turn_results);
359        }
360
361        // Populate extended fields
362        result.cost_metrics = cost_metrics;
363        result.trace_analysis = trace_analysis;
364        result.verdicts = verdicts;
365
366        Ok(result)
367    }
368
369    /// Collect all events for a case by running the agent on the first turn input.
370    /// Used by cost tracker and trace analyzer to analyze the full execution.
371    async fn collect_case_events(
372        &self,
373        agent: Arc<dyn Agent>,
374        eval_case: &EvalCase,
375    ) -> Result<Vec<Event>> {
376        // Only collect events if we have a cost tracker or trace analyzer configured
377        if self.cost_tracker.is_none() && self.trace_analyzer.is_none() {
378            return Ok(Vec::new());
379        }
380
381        // Use events from the first turn as representative
382        if let Some(first_turn) = eval_case.conversation.first() {
383            let input_content = first_turn.user_content.to_adk_content();
384            self.run_agent(agent, input_content).await
385        } else {
386            Ok(Vec::new())
387        }
388    }
389
390    /// Execute a single turn and collect results
391    async fn execute_turn(&self, agent: Arc<dyn Agent>, turn: &Turn) -> Result<TurnResult> {
392        // Create input content
393        let input_content = turn.user_content.to_adk_content();
394
395        // Run the agent
396        let events = self.run_agent(agent, input_content).await?;
397
398        // Extract response and tool calls from events
399        let (actual_response, actual_tool_calls) = self.extract_from_events(&events);
400
401        // Get expected values
402        let expected_response = turn.final_response.as_ref().map(|c| c.get_text());
403        let expected_tool_calls =
404            turn.intermediate_data.as_ref().map(|d| d.tool_uses.clone()).unwrap_or_default();
405
406        Ok(TurnResult {
407            invocation_id: turn.invocation_id.clone(),
408            actual_response,
409            expected_response,
410            actual_tool_calls,
411            expected_tool_calls,
412            scores: HashMap::new(),
413        })
414    }
415
416    /// Run agent and collect events
417    async fn run_agent(&self, agent: Arc<dyn Agent>, input: Content) -> Result<Vec<Event>> {
418        // Create a minimal invocation context for evaluation
419        let invocation_id = uuid::Uuid::new_v4().to_string();
420        let ctx = Arc::new(EvalInvocationContext::new(invocation_id, input, agent.clone()));
421
422        // Run the agent and collect all events
423        let stream = agent.run(ctx).await.map_err(|e| {
424            crate::error::EvalError::ExecutionError(format!("Agent run failed: {}", e))
425        })?;
426
427        // Collect all events from the stream
428        let events: Vec<Event> = stream.filter_map(|r| async { r.ok() }).collect().await;
429
430        Ok(events)
431    }
432
433    /// Extract response text and tool calls from events
434    fn extract_from_events(&self, events: &[Event]) -> (Option<String>, Vec<ToolUse>) {
435        let mut response_text = String::new();
436        let mut tool_calls = Vec::new();
437
438        for event in events {
439            // Extract text content
440            if let Some(content) = event.content() {
441                for part in &content.parts {
442                    // Extract text content
443                    if let Some(text) = part.text() {
444                        response_text.push_str(text);
445                    }
446                    // Extract function calls using pattern matching
447                    if let adk_core::Part::FunctionCall { name, args, .. } = part {
448                        tool_calls.push(ToolUse {
449                            name: name.clone(),
450                            args: args.clone(),
451                            expected_response: None,
452                        });
453                    }
454                }
455            }
456        }
457
458        let response = if response_text.is_empty() { None } else { Some(response_text) };
459
460        (response, tool_calls)
461    }
462
463    /// Score a turn against criteria
464    async fn score_turn(
465        &self,
466        turn: &Turn,
467        result: &TurnResult,
468    ) -> (HashMap<String, f64>, Vec<Failure>) {
469        let mut scores = HashMap::new();
470        let mut failures = Vec::new();
471
472        // Tool trajectory scoring
473        if let Some(threshold) = self.config.criteria.tool_trajectory_score {
474            let score =
475                self.tool_scorer.score(&result.expected_tool_calls, &result.actual_tool_calls);
476            scores.insert("tool_trajectory".to_string(), score);
477
478            if score < threshold {
479                failures.push(
480                    Failure::new(
481                        "tool_trajectory",
482                        serde_json::to_value(&result.expected_tool_calls).unwrap_or_default(),
483                        serde_json::to_value(&result.actual_tool_calls).unwrap_or_default(),
484                        score,
485                        threshold,
486                    )
487                    .with_details(&format!(
488                        "Expected {} tool calls, got {}",
489                        result.expected_tool_calls.len(),
490                        result.actual_tool_calls.len()
491                    )),
492                );
493            }
494        }
495
496        // Response similarity scoring (text-based)
497        if let Some(threshold) = self.config.criteria.response_similarity {
498            if let (Some(expected), Some(actual)) =
499                (&result.expected_response, &result.actual_response)
500            {
501                let score = self.response_scorer.score(expected, actual);
502                scores.insert("response_similarity".to_string(), score);
503
504                if score < threshold {
505                    failures.push(
506                        Failure::new(
507                            "response_similarity",
508                            Value::String(expected.clone()),
509                            Value::String(actual.clone()),
510                            score,
511                            threshold,
512                        )
513                        .with_details("Response text differs from expected"),
514                    );
515                }
516            } else if result.expected_response.is_some() && result.actual_response.is_none() {
517                scores.insert("response_similarity".to_string(), 0.0);
518                failures.push(
519                    Failure::new(
520                        "response_similarity",
521                        Value::String(result.expected_response.clone().unwrap_or_default()),
522                        Value::Null,
523                        0.0,
524                        threshold,
525                    )
526                    .with_details("No response received"),
527                );
528            }
529        }
530
531        // LLM-judged semantic matching
532        if let Some(threshold) = self.config.criteria.semantic_match_score
533            && let Some(judge) = &self.llm_judge
534            && let (Some(expected), Some(actual)) =
535                (&result.expected_response, &result.actual_response)
536        {
537            match judge
538                .semantic_match(
539                    expected,
540                    actual,
541                    self.config.criteria.semantic_match_config.as_ref(),
542                )
543                .await
544            {
545                Ok(semantic_result) => {
546                    scores.insert("semantic_match".to_string(), semantic_result.score);
547                    if semantic_result.score < threshold {
548                        failures.push(
549                            Failure::new(
550                                "semantic_match",
551                                Value::String(expected.clone()),
552                                Value::String(actual.clone()),
553                                semantic_result.score,
554                                threshold,
555                            )
556                            .with_details(&semantic_result.reasoning),
557                        );
558                    }
559                }
560                Err(e) => {
561                    // Record error but don't fail the whole evaluation
562                    failures.push(
563                        Failure::new(
564                            "semantic_match",
565                            Value::String(expected.clone()),
566                            Value::String(actual.clone()),
567                            0.0,
568                            threshold,
569                        )
570                        .with_details(&format!("LLM judge error: {}", e)),
571                    );
572                }
573            }
574        }
575
576        // Rubric-based evaluation
577        if let Some(threshold) = self.config.criteria.rubric_quality_score
578            && let Some(judge) = &self.llm_judge
579            && let Some(rubric_config) = &self.config.criteria.rubric_config
580            && let Some(actual) = &result.actual_response
581        {
582            // Use user input as context for rubric evaluation
583            let context = turn.user_content.get_text();
584            match judge.evaluate_rubrics(actual, &context, rubric_config).await {
585                Ok(rubric_result) => {
586                    scores.insert("rubric_quality".to_string(), rubric_result.overall_score);
587                    // Also store individual rubric scores
588                    for rs in &rubric_result.rubric_scores {
589                        scores.insert(format!("rubric_{}", rs.name), rs.score);
590                    }
591                    if rubric_result.overall_score < threshold {
592                        let details = rubric_result
593                            .rubric_scores
594                            .iter()
595                            .map(|rs| format!("{}: {:.2} - {}", rs.name, rs.score, rs.reasoning))
596                            .collect::<Vec<_>>()
597                            .join("; ");
598                        failures.push(
599                            Failure::new(
600                                "rubric_quality",
601                                Value::Number(
602                                    serde_json::Number::from_f64(threshold)
603                                        .unwrap_or(serde_json::Number::from(0)),
604                                ),
605                                Value::Number(
606                                    serde_json::Number::from_f64(rubric_result.overall_score)
607                                        .unwrap_or(serde_json::Number::from(0)),
608                                ),
609                                rubric_result.overall_score,
610                                threshold,
611                            )
612                            .with_details(&details),
613                        );
614                    }
615                }
616                Err(e) => {
617                    failures.push(
618                        Failure::new("rubric_quality", Value::Null, Value::Null, 0.0, threshold)
619                            .with_details(&format!("LLM judge error: {}", e)),
620                    );
621                }
622            }
623        }
624
625        // Safety evaluation
626        if let Some(threshold) = self.config.criteria.safety_score
627            && let Some(judge) = &self.llm_judge
628            && let Some(actual) = &result.actual_response
629        {
630            match judge.evaluate_safety(actual).await {
631                Ok(safety_result) => {
632                    scores.insert("safety".to_string(), safety_result.score);
633                    if safety_result.score < threshold {
634                        failures.push(
635                            Failure::new(
636                                "safety",
637                                Value::Number(
638                                    serde_json::Number::from_f64(threshold)
639                                        .unwrap_or(serde_json::Number::from(0)),
640                                ),
641                                Value::Number(
642                                    serde_json::Number::from_f64(safety_result.score)
643                                        .unwrap_or(serde_json::Number::from(0)),
644                                ),
645                                safety_result.score,
646                                threshold,
647                            )
648                            .with_details(&format!(
649                                "Safety issues: {}",
650                                safety_result.issues.join(", ")
651                            )),
652                        );
653                    }
654                }
655                Err(e) => {
656                    failures.push(
657                        Failure::new("safety", Value::Null, Value::Null, 0.0, threshold)
658                            .with_details(&format!("LLM judge error: {}", e)),
659                    );
660                }
661            }
662        }
663
664        // Hallucination detection
665        if let Some(threshold) = self.config.criteria.hallucination_score
666            && let Some(judge) = &self.llm_judge
667            && let Some(actual) = &result.actual_response
668        {
669            let context = turn.user_content.get_text();
670            let ground_truth = result.expected_response.as_deref();
671            match judge.detect_hallucinations(actual, &context, ground_truth).await {
672                Ok(hallucination_result) => {
673                    scores.insert("hallucination".to_string(), hallucination_result.score);
674                    if hallucination_result.score < threshold {
675                        failures.push(
676                            Failure::new(
677                                "hallucination",
678                                Value::Number(
679                                    serde_json::Number::from_f64(threshold)
680                                        .unwrap_or(serde_json::Number::from(0)),
681                                ),
682                                Value::Number(
683                                    serde_json::Number::from_f64(hallucination_result.score)
684                                        .unwrap_or(serde_json::Number::from(0)),
685                                ),
686                                hallucination_result.score,
687                                threshold,
688                            )
689                            .with_details(&format!(
690                                "Hallucinations detected: {}",
691                                hallucination_result.issues.join(", ")
692                            )),
693                        );
694                    }
695                }
696                Err(e) => {
697                    failures.push(
698                        Failure::new("hallucination", Value::Null, Value::Null, 0.0, threshold)
699                            .with_details(&format!("LLM judge error: {}", e)),
700                    );
701                }
702            }
703        }
704
705        (scores, failures)
706    }
707
708    /// Evaluate multiple test cases in parallel
709    pub async fn evaluate_cases_parallel(
710        &self,
711        agent: Arc<dyn Agent>,
712        cases: &[EvalCase],
713        concurrency: usize,
714    ) -> Vec<Result<EvaluationResult>> {
715        use futures::stream::{self, StreamExt};
716
717        let results: Vec<_> = stream::iter(cases)
718            .map(|case| {
719                let agent = agent.clone();
720                async move { self.evaluate_case(agent, case).await }
721            })
722            .buffer_unordered(concurrency)
723            .collect()
724            .await;
725
726        results
727    }
728
729    /// Evaluate a directory of test files
730    pub async fn evaluate_directory(
731        &self,
732        agent: Arc<dyn Agent>,
733        dir: impl AsRef<Path>,
734    ) -> Result<Vec<EvaluationReport>> {
735        let mut reports = Vec::new();
736
737        let entries = std::fs::read_dir(dir)?;
738        for entry in entries {
739            let entry = entry?;
740            let path = entry.path();
741
742            if path.extension().is_some_and(|ext| ext == "json")
743                && let Some(name) = path.file_name().and_then(|n| n.to_str())
744                && name.ends_with(".test.json")
745            {
746                let report = self.evaluate_file(agent.clone(), &path).await?;
747                reports.push(report);
748            }
749        }
750
751        Ok(reports)
752    }
753
754    /// Run a multi-turn conversation between a [`UserSimulator`](crate::personas::UserSimulator)
755    /// and the agent under test for a configurable number of turns.
756    ///
757    /// Each turn consists of:
758    /// 1. The `UserSimulator` generates a user message based on the conversation history
759    /// 2. The agent processes the user message and produces a response
760    /// 3. Both messages are appended to the conversation history
761    ///
762    /// Returns the full conversation history as a `Vec<Content>`.
763    ///
764    /// # Errors
765    ///
766    /// Returns an error if the simulator or agent fails during any turn.
767    #[cfg(feature = "personas")]
768    pub async fn evaluate_multi_turn(
769        &self,
770        agent: Arc<dyn Agent>,
771        simulator: &crate::personas::UserSimulator,
772        num_turns: usize,
773    ) -> Result<Vec<Content>> {
774        let mut history: Vec<Content> = Vec::new();
775
776        for _turn_idx in 0..num_turns {
777            // 1. Generate user message from the simulator
778            let user_message = simulator.generate_message(&history).await?;
779            history.push(user_message.clone());
780
781            // 2. Run the agent with the user message
782            let events = self.run_agent(agent.clone(), user_message).await?;
783
784            // 3. Extract the agent's response text
785            let (response_text, _tool_calls) = self.extract_from_events(&events);
786            if let Some(text) = response_text {
787                history.push(Content::new("model").with_text(text));
788            }
789        }
790
791        Ok(history)
792    }
793}
794
795impl Default for Evaluator {
796    fn default() -> Self {
797        Self::new(EvaluationConfig::default())
798    }
799}
800
801// ============================================================================
802// EvalInvocationContext - Minimal context for running agents during evaluation
803// ============================================================================
804
805/// Minimal InvocationContext implementation for evaluation
806struct EvalInvocationContext {
807    invocation_id: String,
808    user_content: Content,
809    agent: Arc<dyn Agent>,
810    session: EvalSession,
811    run_config: adk_core::RunConfig,
812    ended: std::sync::atomic::AtomicBool,
813}
814
815impl EvalInvocationContext {
816    fn new(invocation_id: String, user_content: Content, agent: Arc<dyn Agent>) -> Self {
817        let session_id = format!("eval-session-{}", uuid::Uuid::new_v4());
818        Self {
819            invocation_id,
820            user_content,
821            agent,
822            session: EvalSession::new(session_id),
823            run_config: adk_core::RunConfig::default(),
824            ended: std::sync::atomic::AtomicBool::new(false),
825        }
826    }
827}
828
829impl adk_core::ReadonlyContext for EvalInvocationContext {
830    fn invocation_id(&self) -> &str {
831        &self.invocation_id
832    }
833
834    fn agent_name(&self) -> &str {
835        self.agent.name()
836    }
837
838    fn user_id(&self) -> &str {
839        "eval_user"
840    }
841
842    fn app_name(&self) -> &str {
843        "eval_app"
844    }
845
846    fn session_id(&self) -> &str {
847        &self.session.id
848    }
849
850    fn branch(&self) -> &str {
851        "main"
852    }
853
854    fn user_content(&self) -> &Content {
855        &self.user_content
856    }
857}
858
859#[async_trait]
860impl adk_core::CallbackContext for EvalInvocationContext {
861    fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
862        None
863    }
864}
865
866#[async_trait]
867impl adk_core::InvocationContext for EvalInvocationContext {
868    fn agent(&self) -> Arc<dyn Agent> {
869        self.agent.clone()
870    }
871
872    fn memory(&self) -> Option<Arc<dyn adk_core::Memory>> {
873        None
874    }
875
876    fn session(&self) -> &dyn adk_core::Session {
877        &self.session
878    }
879
880    fn run_config(&self) -> &adk_core::RunConfig {
881        &self.run_config
882    }
883
884    fn end_invocation(&self) {
885        self.ended.store(true, std::sync::atomic::Ordering::SeqCst);
886    }
887
888    fn ended(&self) -> bool {
889        self.ended.load(std::sync::atomic::Ordering::SeqCst)
890    }
891}
892
893/// Minimal Session implementation for evaluation
894struct EvalSession {
895    id: String,
896    state: EvalState,
897}
898
899impl EvalSession {
900    fn new(id: String) -> Self {
901        Self { id, state: EvalState::new() }
902    }
903}
904
905impl adk_core::Session for EvalSession {
906    fn id(&self) -> &str {
907        &self.id
908    }
909
910    fn app_name(&self) -> &str {
911        "eval_app"
912    }
913
914    fn user_id(&self) -> &str {
915        "eval_user"
916    }
917
918    fn state(&self) -> &dyn adk_core::State {
919        &self.state
920    }
921
922    fn conversation_history(&self) -> Vec<Content> {
923        vec![]
924    }
925}
926
927/// Minimal State implementation for evaluation
928struct EvalState {
929    data: std::sync::RwLock<HashMap<String, serde_json::Value>>,
930}
931
932impl EvalState {
933    fn new() -> Self {
934        Self { data: std::sync::RwLock::new(HashMap::new()) }
935    }
936}
937
938impl adk_core::State for EvalState {
939    fn get(&self, key: &str) -> Option<serde_json::Value> {
940        self.data.read().ok()?.get(key).cloned()
941    }
942
943    fn set(&mut self, key: String, value: serde_json::Value) {
944        if let Err(msg) = adk_core::validate_state_key(&key) {
945            tracing::warn!(key = %key, "rejecting invalid state key: {msg}");
946            return;
947        }
948        if let Ok(mut data) = self.data.write() {
949            data.insert(key, value);
950        }
951    }
952
953    fn all(&self) -> HashMap<String, serde_json::Value> {
954        self.data.read().ok().map(|d| d.clone()).unwrap_or_default()
955    }
956}
957
#[cfg(test)]
mod tests {
    use super::*;

    /// Criteria supplied through `EvaluationConfig::with_criteria` are kept by the evaluator.
    #[test]
    fn test_evaluator_creation() {
        let criteria = EvaluationCriteria::exact_tools().with_response_similarity(0.8);
        let evaluator = Evaluator::new(EvaluationConfig::with_criteria(criteria));

        let stored = &evaluator.config.criteria;
        assert!(stored.tool_trajectory_score.is_some());
        assert!(stored.response_similarity.is_some());
    }

    /// A turn whose actual response and tool calls match the expectation exactly scores
    /// 1.0 on both criteria and reports no failures.
    #[tokio::test]
    async fn test_turn_scoring() {
        let criteria = EvaluationCriteria {
            tool_trajectory_score: Some(1.0),
            response_similarity: Some(0.8),
            ..Default::default()
        };
        let evaluator = Evaluator::new(EvaluationConfig::with_criteria(criteria));

        let expected_turn = Turn {
            invocation_id: "test".to_string(),
            user_content: crate::schema::ContentData::text("Hello"),
            final_response: Some(crate::schema::ContentData::model_response("Hi there!")),
            intermediate_data: Some(crate::schema::IntermediateData {
                tool_uses: vec![ToolUse::new("greet")],
                ..Default::default()
            }),
        };

        let observed = TurnResult {
            invocation_id: "test".to_string(),
            actual_response: Some("Hi there!".to_string()),
            expected_response: Some("Hi there!".to_string()),
            actual_tool_calls: vec![ToolUse::new("greet")],
            expected_tool_calls: vec![ToolUse::new("greet")],
            scores: HashMap::new(),
        };

        let (scores, failures) = evaluator.score_turn(&expected_turn, &observed).await;

        assert!(failures.is_empty());
        assert_eq!(scores.get("tool_trajectory"), Some(&1.0));
        assert_eq!(scores.get("response_similarity"), Some(&1.0));
    }
}