1use crate::cost_tracker::CostTracker;
6use crate::criteria::EvaluationCriteria;
7use crate::error::Result;
8use crate::llm_judge::LlmJudge;
9use crate::report::{EvaluationReport, EvaluationResult, Failure, TurnResult};
10use crate::schema::{EvalCase, TestFile, ToolUse, Turn};
11use crate::scoring::{ResponseScorer, ToolTrajectoryScorer};
12use crate::structured_judge::StructuredJudge;
13use crate::trace_analyzer::TraceAnalyzer;
14
15use adk_core::{Agent, Content, Event, Llm};
16use async_trait::async_trait;
17use futures::StreamExt;
18use serde::{Deserialize, Serialize};
19use serde_json::Value;
20use std::collections::HashMap;
21use std::path::Path;
22use std::sync::Arc;
23use std::time::{Duration, Instant};
24
25#[cfg(feature = "embedding")]
26use crate::embedding_scorer::EmbeddingScorer;
27
28#[derive(Debug, Clone, Default, Serialize, Deserialize)]
30pub struct EvaluationConfig {
31 #[serde(default)]
33 pub criteria: EvaluationCriteria,
34 #[serde(default)]
36 pub continue_on_failure: bool,
37 #[serde(default)]
39 pub timeout_per_case: Option<Duration>,
40 #[serde(default)]
42 pub retries: usize,
43 #[serde(default = "default_true")]
45 pub collect_turn_details: bool,
46}
47
/// Serde default helper for boolean fields that should be enabled unless stated otherwise.
const fn default_true() -> bool {
    true
}
51
52impl EvaluationConfig {
53 pub fn with_criteria(criteria: EvaluationCriteria) -> Self {
55 Self { criteria, ..Default::default() }
56 }
57}
58
59pub struct Evaluator {
61 config: EvaluationConfig,
62 tool_scorer: ToolTrajectoryScorer,
63 response_scorer: ResponseScorer,
64 llm_judge: Option<LlmJudge>,
65 structured_judge: Option<Arc<StructuredJudge>>,
67 cost_tracker: Option<CostTracker>,
69 trace_analyzer: Option<TraceAnalyzer>,
71 #[cfg(feature = "embedding")]
73 embedding_scorer: Option<Arc<EmbeddingScorer>>,
74 conversation_scorer: Option<Arc<crate::conversation_scorer::ConversationScorer>>,
76}
77
78impl Evaluator {
79 pub fn new(config: EvaluationConfig) -> Self {
81 let tool_scorer = if let Some(tc) = &config.criteria.tool_trajectory_config {
82 ToolTrajectoryScorer::with_config(tc.clone())
83 } else {
84 ToolTrajectoryScorer::new()
85 };
86
87 let response_scorer = if let Some(rc) = &config.criteria.response_match_config {
88 ResponseScorer::with_config(rc.clone())
89 } else {
90 ResponseScorer::new()
91 };
92
93 Self {
94 config,
95 tool_scorer,
96 response_scorer,
97 llm_judge: None,
98 structured_judge: None,
99 cost_tracker: None,
100 trace_analyzer: None,
101 #[cfg(feature = "embedding")]
102 embedding_scorer: None,
103 conversation_scorer: None,
104 }
105 }
106
107 pub fn with_llm_judge(config: EvaluationConfig, judge_model: Arc<dyn Llm>) -> Self {
109 let tool_scorer = if let Some(tc) = &config.criteria.tool_trajectory_config {
110 ToolTrajectoryScorer::with_config(tc.clone())
111 } else {
112 ToolTrajectoryScorer::new()
113 };
114
115 let response_scorer = if let Some(rc) = &config.criteria.response_match_config {
116 ResponseScorer::with_config(rc.clone())
117 } else {
118 ResponseScorer::new()
119 };
120
121 Self {
122 config,
123 tool_scorer,
124 response_scorer,
125 llm_judge: Some(LlmJudge::new(judge_model)),
126 structured_judge: None,
127 cost_tracker: None,
128 trace_analyzer: None,
129 #[cfg(feature = "embedding")]
130 embedding_scorer: None,
131 conversation_scorer: None,
132 }
133 }
134
135 pub fn set_llm_judge(&mut self, judge_model: Arc<dyn Llm>) {
137 self.llm_judge = Some(LlmJudge::new(judge_model));
138 }
139
140 pub fn has_llm_judge(&self) -> bool {
142 self.llm_judge.is_some()
143 }
144
145 pub fn set_structured_judge(&mut self, judge: Arc<StructuredJudge>) {
147 self.structured_judge = Some(judge);
148 }
149
150 pub fn set_cost_tracker(&mut self, tracker: CostTracker) {
152 self.cost_tracker = Some(tracker);
153 }
154
155 pub fn set_trace_analyzer(&mut self, analyzer: TraceAnalyzer) {
157 self.trace_analyzer = Some(analyzer);
158 }
159
160 #[cfg(feature = "embedding")]
162 pub fn set_embedding_scorer(&mut self, scorer: Arc<EmbeddingScorer>) {
163 self.embedding_scorer = Some(scorer);
164 }
165
166 pub fn set_conversation_scorer(
168 &mut self,
169 scorer: Arc<crate::conversation_scorer::ConversationScorer>,
170 ) {
171 self.conversation_scorer = Some(scorer);
172 }
173
174 pub fn has_structured_judge(&self) -> bool {
176 self.structured_judge.is_some()
177 }
178
179 pub fn has_cost_tracker(&self) -> bool {
181 self.cost_tracker.is_some()
182 }
183
184 pub fn has_trace_analyzer(&self) -> bool {
186 self.trace_analyzer.is_some()
187 }
188
189 #[cfg(feature = "embedding")]
191 pub fn has_embedding_scorer(&self) -> bool {
192 self.embedding_scorer.is_some()
193 }
194
195 pub fn has_conversation_scorer(&self) -> bool {
197 self.conversation_scorer.is_some()
198 }
199
200 pub async fn evaluate_file(
202 &self,
203 agent: Arc<dyn Agent>,
204 path: impl AsRef<Path>,
205 ) -> Result<EvaluationReport> {
206 let test_file = TestFile::load(path)?;
207 self.evaluate_test_file(agent, &test_file).await
208 }
209
210 pub async fn evaluate_test_file(
212 &self,
213 agent: Arc<dyn Agent>,
214 test_file: &TestFile,
215 ) -> Result<EvaluationReport> {
216 let started_at = chrono::Utc::now();
217 let run_id = format!("{}_{}", test_file.eval_set_id, uuid::Uuid::new_v4());
218 let mut results = Vec::new();
219
220 for eval_case in &test_file.eval_cases {
221 let result = self.evaluate_case(agent.clone(), eval_case).await;
222
223 match result {
224 Ok(r) => {
225 let passed = r.passed;
226 results.push(r);
227 if !passed && !self.config.continue_on_failure {
228 break;
229 }
230 }
231 Err(e) => {
232 results.push(EvaluationResult::failed(
234 &eval_case.eval_id,
235 HashMap::new(),
236 vec![Failure::new(
237 "execution",
238 Value::Null,
239 Value::String(e.to_string()),
240 0.0,
241 1.0,
242 )],
243 Duration::from_secs(0),
244 ));
245 if !self.config.continue_on_failure {
246 break;
247 }
248 }
249 }
250 }
251
252 Ok(EvaluationReport::new(&run_id, results, started_at))
253 }
254
255 pub async fn evaluate_case(
257 &self,
258 agent: Arc<dyn Agent>,
259 eval_case: &EvalCase,
260 ) -> Result<EvaluationResult> {
261 let start = Instant::now();
262 let mut all_scores: HashMap<String, f64> = HashMap::new();
263 let mut all_failures: Vec<Failure> = Vec::new();
264 let mut turn_results: Vec<TurnResult> = Vec::new();
265 let mut all_events: Vec<Event> = Vec::new();
266
267 for turn in &eval_case.conversation {
269 let turn_result = self.execute_turn(agent.clone(), turn).await?;
270
271 let (scores, failures) = self.score_turn(turn, &turn_result).await;
273
274 for (criterion, score) in &scores {
276 all_scores
277 .entry(criterion.clone())
278 .and_modify(|s| *s = (*s + score) / 2.0)
279 .or_insert(*score);
280 }
281 all_failures.extend(failures);
282
283 if self.config.collect_turn_details {
284 turn_results.push(turn_result);
285 }
286 }
287
288 let case_events = self.collect_case_events(agent.clone(), eval_case).await;
291 if let Ok(events) = case_events {
292 all_events = events;
293 }
294
295 let duration = start.elapsed();
296
297 let cost_metrics = self
299 .cost_tracker
300 .as_ref()
301 .map(|tracker| tracker.extract_metrics(&all_events, duration));
302
303 let trace_analysis =
305 self.trace_analyzer.as_ref().map(|analyzer| analyzer.analyze(&all_events));
306
307 let mut verdicts = Vec::new();
309 if let Some(judge) = &self.structured_judge
310 && let Some(last_turn_result) = turn_results.last()
311 && let (Some(expected), Some(actual)) =
312 (&last_turn_result.expected_response, &last_turn_result.actual_response)
313 {
314 match judge.judge(expected, actual, "overall_quality").await {
315 Ok(verdict) => {
316 all_scores.insert("structured_judge".to_string(), verdict.score);
317 verdicts.push(verdict);
318 }
319 Err(e) => {
320 tracing::warn!("Structured judge failed: {e}");
321 let fallback = crate::structured_judge::StructuredVerdict {
323 score: 0.0,
324 reasoning: format!("Judge error: {e}"),
325 verdict: crate::structured_judge::Verdict::Fail,
326 };
327 verdicts.push(fallback);
328 }
329 }
330 }
331
332 #[cfg(feature = "embedding")]
334 if let Some(scorer) = &self.embedding_scorer
335 && let Some(last_turn_result) = turn_results.last()
336 && let (Some(expected), Some(actual)) =
337 (&last_turn_result.expected_response, &last_turn_result.actual_response)
338 {
339 match scorer.score(expected, actual).await {
340 Ok(score) => {
341 all_scores.insert("embedding_similarity".to_string(), score);
342 }
343 Err(e) => {
344 tracing::warn!("Embedding scorer failed: {e}");
345 }
346 }
347 }
348
349 let passed = all_failures.is_empty();
350
351 let mut result = if passed {
352 EvaluationResult::passed(&eval_case.eval_id, all_scores, duration)
353 } else {
354 EvaluationResult::failed(&eval_case.eval_id, all_scores, all_failures, duration)
355 };
356
357 if self.config.collect_turn_details {
358 result = result.with_turn_results(turn_results);
359 }
360
361 result.cost_metrics = cost_metrics;
363 result.trace_analysis = trace_analysis;
364 result.verdicts = verdicts;
365
366 Ok(result)
367 }
368
369 async fn collect_case_events(
372 &self,
373 agent: Arc<dyn Agent>,
374 eval_case: &EvalCase,
375 ) -> Result<Vec<Event>> {
376 if self.cost_tracker.is_none() && self.trace_analyzer.is_none() {
378 return Ok(Vec::new());
379 }
380
381 if let Some(first_turn) = eval_case.conversation.first() {
383 let input_content = first_turn.user_content.to_adk_content();
384 self.run_agent(agent, input_content).await
385 } else {
386 Ok(Vec::new())
387 }
388 }
389
390 async fn execute_turn(&self, agent: Arc<dyn Agent>, turn: &Turn) -> Result<TurnResult> {
392 let input_content = turn.user_content.to_adk_content();
394
395 let events = self.run_agent(agent, input_content).await?;
397
398 let (actual_response, actual_tool_calls) = self.extract_from_events(&events);
400
401 let expected_response = turn.final_response.as_ref().map(|c| c.get_text());
403 let expected_tool_calls =
404 turn.intermediate_data.as_ref().map(|d| d.tool_uses.clone()).unwrap_or_default();
405
406 Ok(TurnResult {
407 invocation_id: turn.invocation_id.clone(),
408 actual_response,
409 expected_response,
410 actual_tool_calls,
411 expected_tool_calls,
412 scores: HashMap::new(),
413 })
414 }
415
416 async fn run_agent(&self, agent: Arc<dyn Agent>, input: Content) -> Result<Vec<Event>> {
418 let invocation_id = uuid::Uuid::new_v4().to_string();
420 let ctx = Arc::new(EvalInvocationContext::new(invocation_id, input, agent.clone()));
421
422 let stream = agent.run(ctx).await.map_err(|e| {
424 crate::error::EvalError::ExecutionError(format!("Agent run failed: {}", e))
425 })?;
426
427 let events: Vec<Event> = stream.filter_map(|r| async { r.ok() }).collect().await;
429
430 Ok(events)
431 }
432
433 fn extract_from_events(&self, events: &[Event]) -> (Option<String>, Vec<ToolUse>) {
435 let mut response_text = String::new();
436 let mut tool_calls = Vec::new();
437
438 for event in events {
439 if let Some(content) = event.content() {
441 for part in &content.parts {
442 if let Some(text) = part.text() {
444 response_text.push_str(text);
445 }
446 if let adk_core::Part::FunctionCall { name, args, .. } = part {
448 tool_calls.push(ToolUse {
449 name: name.clone(),
450 args: args.clone(),
451 expected_response: None,
452 });
453 }
454 }
455 }
456 }
457
458 let response = if response_text.is_empty() { None } else { Some(response_text) };
459
460 (response, tool_calls)
461 }
462
463 async fn score_turn(
465 &self,
466 turn: &Turn,
467 result: &TurnResult,
468 ) -> (HashMap<String, f64>, Vec<Failure>) {
469 let mut scores = HashMap::new();
470 let mut failures = Vec::new();
471
472 if let Some(threshold) = self.config.criteria.tool_trajectory_score {
474 let score =
475 self.tool_scorer.score(&result.expected_tool_calls, &result.actual_tool_calls);
476 scores.insert("tool_trajectory".to_string(), score);
477
478 if score < threshold {
479 failures.push(
480 Failure::new(
481 "tool_trajectory",
482 serde_json::to_value(&result.expected_tool_calls).unwrap_or_default(),
483 serde_json::to_value(&result.actual_tool_calls).unwrap_or_default(),
484 score,
485 threshold,
486 )
487 .with_details(&format!(
488 "Expected {} tool calls, got {}",
489 result.expected_tool_calls.len(),
490 result.actual_tool_calls.len()
491 )),
492 );
493 }
494 }
495
496 if let Some(threshold) = self.config.criteria.response_similarity {
498 if let (Some(expected), Some(actual)) =
499 (&result.expected_response, &result.actual_response)
500 {
501 let score = self.response_scorer.score(expected, actual);
502 scores.insert("response_similarity".to_string(), score);
503
504 if score < threshold {
505 failures.push(
506 Failure::new(
507 "response_similarity",
508 Value::String(expected.clone()),
509 Value::String(actual.clone()),
510 score,
511 threshold,
512 )
513 .with_details("Response text differs from expected"),
514 );
515 }
516 } else if result.expected_response.is_some() && result.actual_response.is_none() {
517 scores.insert("response_similarity".to_string(), 0.0);
518 failures.push(
519 Failure::new(
520 "response_similarity",
521 Value::String(result.expected_response.clone().unwrap_or_default()),
522 Value::Null,
523 0.0,
524 threshold,
525 )
526 .with_details("No response received"),
527 );
528 }
529 }
530
531 if let Some(threshold) = self.config.criteria.semantic_match_score
533 && let Some(judge) = &self.llm_judge
534 && let (Some(expected), Some(actual)) =
535 (&result.expected_response, &result.actual_response)
536 {
537 match judge
538 .semantic_match(
539 expected,
540 actual,
541 self.config.criteria.semantic_match_config.as_ref(),
542 )
543 .await
544 {
545 Ok(semantic_result) => {
546 scores.insert("semantic_match".to_string(), semantic_result.score);
547 if semantic_result.score < threshold {
548 failures.push(
549 Failure::new(
550 "semantic_match",
551 Value::String(expected.clone()),
552 Value::String(actual.clone()),
553 semantic_result.score,
554 threshold,
555 )
556 .with_details(&semantic_result.reasoning),
557 );
558 }
559 }
560 Err(e) => {
561 failures.push(
563 Failure::new(
564 "semantic_match",
565 Value::String(expected.clone()),
566 Value::String(actual.clone()),
567 0.0,
568 threshold,
569 )
570 .with_details(&format!("LLM judge error: {}", e)),
571 );
572 }
573 }
574 }
575
576 if let Some(threshold) = self.config.criteria.rubric_quality_score
578 && let Some(judge) = &self.llm_judge
579 && let Some(rubric_config) = &self.config.criteria.rubric_config
580 && let Some(actual) = &result.actual_response
581 {
582 let context = turn.user_content.get_text();
584 match judge.evaluate_rubrics(actual, &context, rubric_config).await {
585 Ok(rubric_result) => {
586 scores.insert("rubric_quality".to_string(), rubric_result.overall_score);
587 for rs in &rubric_result.rubric_scores {
589 scores.insert(format!("rubric_{}", rs.name), rs.score);
590 }
591 if rubric_result.overall_score < threshold {
592 let details = rubric_result
593 .rubric_scores
594 .iter()
595 .map(|rs| format!("{}: {:.2} - {}", rs.name, rs.score, rs.reasoning))
596 .collect::<Vec<_>>()
597 .join("; ");
598 failures.push(
599 Failure::new(
600 "rubric_quality",
601 Value::Number(
602 serde_json::Number::from_f64(threshold)
603 .unwrap_or(serde_json::Number::from(0)),
604 ),
605 Value::Number(
606 serde_json::Number::from_f64(rubric_result.overall_score)
607 .unwrap_or(serde_json::Number::from(0)),
608 ),
609 rubric_result.overall_score,
610 threshold,
611 )
612 .with_details(&details),
613 );
614 }
615 }
616 Err(e) => {
617 failures.push(
618 Failure::new("rubric_quality", Value::Null, Value::Null, 0.0, threshold)
619 .with_details(&format!("LLM judge error: {}", e)),
620 );
621 }
622 }
623 }
624
625 if let Some(threshold) = self.config.criteria.safety_score
627 && let Some(judge) = &self.llm_judge
628 && let Some(actual) = &result.actual_response
629 {
630 match judge.evaluate_safety(actual).await {
631 Ok(safety_result) => {
632 scores.insert("safety".to_string(), safety_result.score);
633 if safety_result.score < threshold {
634 failures.push(
635 Failure::new(
636 "safety",
637 Value::Number(
638 serde_json::Number::from_f64(threshold)
639 .unwrap_or(serde_json::Number::from(0)),
640 ),
641 Value::Number(
642 serde_json::Number::from_f64(safety_result.score)
643 .unwrap_or(serde_json::Number::from(0)),
644 ),
645 safety_result.score,
646 threshold,
647 )
648 .with_details(&format!(
649 "Safety issues: {}",
650 safety_result.issues.join(", ")
651 )),
652 );
653 }
654 }
655 Err(e) => {
656 failures.push(
657 Failure::new("safety", Value::Null, Value::Null, 0.0, threshold)
658 .with_details(&format!("LLM judge error: {}", e)),
659 );
660 }
661 }
662 }
663
664 if let Some(threshold) = self.config.criteria.hallucination_score
666 && let Some(judge) = &self.llm_judge
667 && let Some(actual) = &result.actual_response
668 {
669 let context = turn.user_content.get_text();
670 let ground_truth = result.expected_response.as_deref();
671 match judge.detect_hallucinations(actual, &context, ground_truth).await {
672 Ok(hallucination_result) => {
673 scores.insert("hallucination".to_string(), hallucination_result.score);
674 if hallucination_result.score < threshold {
675 failures.push(
676 Failure::new(
677 "hallucination",
678 Value::Number(
679 serde_json::Number::from_f64(threshold)
680 .unwrap_or(serde_json::Number::from(0)),
681 ),
682 Value::Number(
683 serde_json::Number::from_f64(hallucination_result.score)
684 .unwrap_or(serde_json::Number::from(0)),
685 ),
686 hallucination_result.score,
687 threshold,
688 )
689 .with_details(&format!(
690 "Hallucinations detected: {}",
691 hallucination_result.issues.join(", ")
692 )),
693 );
694 }
695 }
696 Err(e) => {
697 failures.push(
698 Failure::new("hallucination", Value::Null, Value::Null, 0.0, threshold)
699 .with_details(&format!("LLM judge error: {}", e)),
700 );
701 }
702 }
703 }
704
705 (scores, failures)
706 }
707
708 pub async fn evaluate_cases_parallel(
710 &self,
711 agent: Arc<dyn Agent>,
712 cases: &[EvalCase],
713 concurrency: usize,
714 ) -> Vec<Result<EvaluationResult>> {
715 use futures::stream::{self, StreamExt};
716
717 let results: Vec<_> = stream::iter(cases)
718 .map(|case| {
719 let agent = agent.clone();
720 async move { self.evaluate_case(agent, case).await }
721 })
722 .buffer_unordered(concurrency)
723 .collect()
724 .await;
725
726 results
727 }
728
729 pub async fn evaluate_directory(
731 &self,
732 agent: Arc<dyn Agent>,
733 dir: impl AsRef<Path>,
734 ) -> Result<Vec<EvaluationReport>> {
735 let mut reports = Vec::new();
736
737 let entries = std::fs::read_dir(dir)?;
738 for entry in entries {
739 let entry = entry?;
740 let path = entry.path();
741
742 if path.extension().is_some_and(|ext| ext == "json")
743 && let Some(name) = path.file_name().and_then(|n| n.to_str())
744 && name.ends_with(".test.json")
745 {
746 let report = self.evaluate_file(agent.clone(), &path).await?;
747 reports.push(report);
748 }
749 }
750
751 Ok(reports)
752 }
753
754 #[cfg(feature = "personas")]
768 pub async fn evaluate_multi_turn(
769 &self,
770 agent: Arc<dyn Agent>,
771 simulator: &crate::personas::UserSimulator,
772 num_turns: usize,
773 ) -> Result<Vec<Content>> {
774 let mut history: Vec<Content> = Vec::new();
775
776 for _turn_idx in 0..num_turns {
777 let user_message = simulator.generate_message(&history).await?;
779 history.push(user_message.clone());
780
781 let events = self.run_agent(agent.clone(), user_message).await?;
783
784 let (response_text, _tool_calls) = self.extract_from_events(&events);
786 if let Some(text) = response_text {
787 history.push(Content::new("model").with_text(text));
788 }
789 }
790
791 Ok(history)
792 }
793}
794
795impl Default for Evaluator {
796 fn default() -> Self {
797 Self::new(EvaluationConfig::default())
798 }
799}
800
801struct EvalInvocationContext {
807 invocation_id: String,
808 user_content: Content,
809 agent: Arc<dyn Agent>,
810 session: EvalSession,
811 run_config: adk_core::RunConfig,
812 ended: std::sync::atomic::AtomicBool,
813}
814
815impl EvalInvocationContext {
816 fn new(invocation_id: String, user_content: Content, agent: Arc<dyn Agent>) -> Self {
817 let session_id = format!("eval-session-{}", uuid::Uuid::new_v4());
818 Self {
819 invocation_id,
820 user_content,
821 agent,
822 session: EvalSession::new(session_id),
823 run_config: adk_core::RunConfig::default(),
824 ended: std::sync::atomic::AtomicBool::new(false),
825 }
826 }
827}
828
829impl adk_core::ReadonlyContext for EvalInvocationContext {
830 fn invocation_id(&self) -> &str {
831 &self.invocation_id
832 }
833
834 fn agent_name(&self) -> &str {
835 self.agent.name()
836 }
837
838 fn user_id(&self) -> &str {
839 "eval_user"
840 }
841
842 fn app_name(&self) -> &str {
843 "eval_app"
844 }
845
846 fn session_id(&self) -> &str {
847 &self.session.id
848 }
849
850 fn branch(&self) -> &str {
851 "main"
852 }
853
854 fn user_content(&self) -> &Content {
855 &self.user_content
856 }
857}
858
859#[async_trait]
860impl adk_core::CallbackContext for EvalInvocationContext {
861 fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
862 None
863 }
864}
865
866#[async_trait]
867impl adk_core::InvocationContext for EvalInvocationContext {
868 fn agent(&self) -> Arc<dyn Agent> {
869 self.agent.clone()
870 }
871
872 fn memory(&self) -> Option<Arc<dyn adk_core::Memory>> {
873 None
874 }
875
876 fn session(&self) -> &dyn adk_core::Session {
877 &self.session
878 }
879
880 fn run_config(&self) -> &adk_core::RunConfig {
881 &self.run_config
882 }
883
884 fn end_invocation(&self) {
885 self.ended.store(true, std::sync::atomic::Ordering::SeqCst);
886 }
887
888 fn ended(&self) -> bool {
889 self.ended.load(std::sync::atomic::Ordering::SeqCst)
890 }
891}
892
893struct EvalSession {
895 id: String,
896 state: EvalState,
897}
898
899impl EvalSession {
900 fn new(id: String) -> Self {
901 Self { id, state: EvalState::new() }
902 }
903}
904
905impl adk_core::Session for EvalSession {
906 fn id(&self) -> &str {
907 &self.id
908 }
909
910 fn app_name(&self) -> &str {
911 "eval_app"
912 }
913
914 fn user_id(&self) -> &str {
915 "eval_user"
916 }
917
918 fn state(&self) -> &dyn adk_core::State {
919 &self.state
920 }
921
922 fn conversation_history(&self) -> Vec<Content> {
923 vec![]
924 }
925}
926
927struct EvalState {
929 data: std::sync::RwLock<HashMap<String, serde_json::Value>>,
930}
931
932impl EvalState {
933 fn new() -> Self {
934 Self { data: std::sync::RwLock::new(HashMap::new()) }
935 }
936}
937
938impl adk_core::State for EvalState {
939 fn get(&self, key: &str) -> Option<serde_json::Value> {
940 self.data.read().ok()?.get(key).cloned()
941 }
942
943 fn set(&mut self, key: String, value: serde_json::Value) {
944 if let Err(msg) = adk_core::validate_state_key(&key) {
945 tracing::warn!(key = %key, "rejecting invalid state key: {msg}");
946 return;
947 }
948 if let Ok(mut data) = self.data.write() {
949 data.insert(key, value);
950 }
951 }
952
953 fn all(&self) -> HashMap<String, serde_json::Value> {
954 self.data.read().ok().map(|d| d.clone()).unwrap_or_default()
955 }
956}
957
#[cfg(test)]
mod tests {
    use super::*;

    /// The criteria passed to the constructor are kept on the evaluator's config.
    #[test]
    fn test_evaluator_creation() {
        let criteria = EvaluationCriteria::exact_tools().with_response_similarity(0.8);
        let evaluator = Evaluator::new(EvaluationConfig::with_criteria(criteria));

        let kept = &evaluator.config.criteria;
        assert!(kept.tool_trajectory_score.is_some());
        assert!(kept.response_similarity.is_some());
    }

    /// A turn whose actual output matches its expectations scores perfectly and records no
    /// failures.
    #[tokio::test]
    async fn test_turn_scoring() {
        let evaluator = Evaluator::new(EvaluationConfig::with_criteria(EvaluationCriteria {
            tool_trajectory_score: Some(1.0),
            response_similarity: Some(0.8),
            ..Default::default()
        }));

        let greet_call = || vec![ToolUse::new("greet")];

        let turn = Turn {
            invocation_id: "test".to_string(),
            user_content: crate::schema::ContentData::text("Hello"),
            final_response: Some(crate::schema::ContentData::model_response("Hi there!")),
            intermediate_data: Some(crate::schema::IntermediateData {
                tool_uses: greet_call(),
                ..Default::default()
            }),
        };

        let observed = TurnResult {
            invocation_id: "test".to_string(),
            actual_response: Some("Hi there!".to_string()),
            expected_response: Some("Hi there!".to_string()),
            actual_tool_calls: greet_call(),
            expected_tool_calls: greet_call(),
            scores: HashMap::new(),
        };

        let (scores, failures) = evaluator.score_turn(&turn, &observed).await;

        assert!(failures.is_empty());
        assert_eq!(scores.get("tool_trajectory"), Some(&1.0));
        assert_eq!(scores.get("response_similarity"), Some(&1.0));
    }
}