adk_eval/criteria.rs

//! Evaluation criteria definitions
//!
//! Defines the various criteria that can be used to evaluate agent responses.
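//!
//! # Example
//!
//! A minimal sketch of composing criteria with the builder methods defined in
//! this module (the `adk_eval::criteria` import path is assumed):
//!
//! ```
//! use adk_eval::criteria::EvaluationCriteria;
//!
//! // Require an exact tool trajectory and a semantic match of at least 0.9.
//! let criteria = EvaluationCriteria::exact_tools().with_semantic_match(0.9);
//! assert_eq!(criteria.tool_trajectory_score, Some(1.0));
//! assert_eq!(criteria.semantic_match_score, Some(0.9));
//! assert!(criteria.has_criteria());
//! ```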

use serde::{Deserialize, Serialize};

/// Collection of evaluation criteria
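///
/// # Example
///
/// Every field carries a `serde` default, so a partial document deserializes
/// cleanly and unspecified criteria stay unset. A small sketch (the
/// `adk_eval::criteria` import path is assumed):
///
/// ```
/// use adk_eval::criteria::EvaluationCriteria;
///
/// let criteria: EvaluationCriteria = serde_json::from_str(
///     r#"{ "tool_trajectory_score": 1.0, "response_similarity": 0.8 }"#,
/// )
/// .unwrap();
/// assert_eq!(criteria.response_similarity, Some(0.8));
/// assert!(criteria.semantic_match_score.is_none());
/// ```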
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct EvaluationCriteria {
    /// Tool trajectory matching score threshold (0.0 - 1.0)
    /// Checks if the agent called the expected tools in the expected order
    #[serde(default)]
    pub tool_trajectory_score: Option<f64>,

    /// Tool trajectory configuration
    #[serde(default)]
    pub tool_trajectory_config: Option<ToolTrajectoryConfig>,

    /// Response text similarity threshold (0.0 - 1.0)
    /// Uses text similarity metrics to compare expected vs actual response
    #[serde(default)]
    pub response_similarity: Option<f64>,

    /// Response matching configuration
    #[serde(default)]
    pub response_match_config: Option<ResponseMatchConfig>,

    /// LLM-judged semantic match threshold (0.0 - 1.0)
    /// Uses an LLM to judge if responses are semantically equivalent
    #[serde(default)]
    pub semantic_match_score: Option<f64>,

    /// Semantic match configuration
    #[serde(default)]
    pub semantic_match_config: Option<SemanticMatchConfig>,

    /// Rubric-based quality score threshold (0.0 - 1.0)
    /// Evaluates response quality against defined rubrics
    #[serde(default)]
    pub rubric_quality_score: Option<f64>,

    /// Rubric configuration
    #[serde(default)]
    pub rubric_config: Option<RubricConfig>,

    /// Safety score threshold (0.0 - 1.0)
    /// Checks for unsafe or harmful content
    #[serde(default)]
    pub safety_score: Option<f64>,

    /// Hallucination detection threshold (0.0 - 1.0)
    /// Detects factual inaccuracies or made-up information
    #[serde(default)]
    pub hallucination_score: Option<f64>,

    /// Custom criteria for extensibility
    #[serde(default)]
    pub custom: Vec<CustomCriterion>,
}

impl EvaluationCriteria {
    /// Create criteria requiring exact tool trajectory match
    pub fn exact_tools() -> Self {
        Self {
            tool_trajectory_score: Some(1.0),
            tool_trajectory_config: Some(ToolTrajectoryConfig {
                strict_order: true,
                strict_args: true,
            }),
            ..Default::default()
        }
    }

    /// Create criteria for semantic response matching
    pub fn semantic_match(threshold: f64) -> Self {
        Self { semantic_match_score: Some(threshold), ..Default::default() }
    }

    /// Create criteria with response similarity
    pub fn response_similarity(threshold: f64) -> Self {
        Self { response_similarity: Some(threshold), ..Default::default() }
    }

    /// Add tool trajectory requirement
    pub fn with_tool_trajectory(mut self, threshold: f64) -> Self {
        self.tool_trajectory_score = Some(threshold);
        self
    }

    /// Add response similarity requirement
    pub fn with_response_similarity(mut self, threshold: f64) -> Self {
        self.response_similarity = Some(threshold);
        self
    }

    /// Add semantic match requirement
    pub fn with_semantic_match(mut self, threshold: f64) -> Self {
        self.semantic_match_score = Some(threshold);
        self
    }

    /// Add rubric-based evaluation
    pub fn with_rubrics(mut self, threshold: f64, rubrics: Vec<Rubric>) -> Self {
        self.rubric_quality_score = Some(threshold);
        self.rubric_config = Some(RubricConfig { rubrics });
        self
    }

    /// Check if any criteria are defined
    pub fn has_criteria(&self) -> bool {
        self.tool_trajectory_score.is_some()
            || self.response_similarity.is_some()
            || self.semantic_match_score.is_some()
            || self.rubric_quality_score.is_some()
            || self.safety_score.is_some()
            || self.hallucination_score.is_some()
            || !self.custom.is_empty()
    }
}

/// Configuration for tool trajectory matching
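///
/// # Example
///
/// A small sketch of the field-level `serde` defaults (import path assumed as
/// `adk_eval::criteria`):
///
/// ```
/// use adk_eval::criteria::ToolTrajectoryConfig;
///
/// // An empty JSON object picks up the defaults: strict order, lenient args.
/// let config: ToolTrajectoryConfig = serde_json::from_str("{}").unwrap();
/// assert!(config.strict_order);
/// assert!(!config.strict_args);
/// ```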
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolTrajectoryConfig {
    /// Require tools to be called in exact order
    #[serde(default = "default_true")]
    pub strict_order: bool,
    /// Require exact argument match (vs partial)
    #[serde(default)]
    pub strict_args: bool,
}

impl Default for ToolTrajectoryConfig {
    fn default() -> Self {
        Self { strict_order: true, strict_args: false }
    }
}

/// Configuration for response matching
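///
/// # Example
///
/// A sketch of deserializing a partial config; omitted fields fall back to the
/// `serde` defaults, and algorithm names use `snake_case` (import path assumed
/// as `adk_eval::criteria`):
///
/// ```
/// use adk_eval::criteria::{ResponseMatchConfig, SimilarityAlgorithm};
///
/// let config: ResponseMatchConfig =
///     serde_json::from_str(r#"{ "algorithm": "rouge_l" }"#).unwrap();
/// assert!(matches!(config.algorithm, SimilarityAlgorithm::RougeL));
/// assert!(config.normalize);
/// assert!(config.ignore_case);
/// assert!(!config.ignore_punctuation);
/// ```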
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResponseMatchConfig {
    /// Similarity algorithm to use
    #[serde(default)]
    pub algorithm: SimilarityAlgorithm,
    /// Whether to normalize text before comparison
    #[serde(default = "default_true")]
    pub normalize: bool,
    /// Whether to ignore case
    #[serde(default = "default_true")]
    pub ignore_case: bool,
    /// Whether to ignore punctuation
    #[serde(default)]
    pub ignore_punctuation: bool,
}

impl Default for ResponseMatchConfig {
    fn default() -> Self {
        Self {
            algorithm: SimilarityAlgorithm::default(),
            normalize: true,
            ignore_case: true,
            ignore_punctuation: false,
        }
    }
}

/// Similarity algorithms for text comparison
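///
/// # Example
///
/// Variants serialize in `snake_case`, which is the form to use in serialized
/// criteria (import path assumed as `adk_eval::criteria`):
///
/// ```
/// use adk_eval::criteria::SimilarityAlgorithm;
///
/// let json = serde_json::to_string(&SimilarityAlgorithm::RougeL).unwrap();
/// assert_eq!(json, r#""rouge_l""#);
/// ```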
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SimilarityAlgorithm {
    /// Exact string match
    Exact,
    /// Contains check
    Contains,
    /// Levenshtein distance based
    Levenshtein,
    /// Jaccard similarity (word overlap)
    #[default]
    Jaccard,
    /// ROUGE-1 (unigram overlap)
    Rouge1,
    /// ROUGE-2 (bigram overlap)
    Rouge2,
    /// ROUGE-L (longest common subsequence)
    RougeL,
}

/// Configuration for LLM-judged semantic matching
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticMatchConfig {
    /// Model to use for judging
    #[serde(default = "default_judge_model")]
    pub judge_model: String,
    /// Custom prompt for the judge (optional)
    pub custom_prompt: Option<String>,
}

impl Default for SemanticMatchConfig {
    fn default() -> Self {
        Self { judge_model: default_judge_model(), custom_prompt: None }
    }
}

fn default_judge_model() -> String {
    "gemini-2.0-flash".to_string()
}

/// Configuration for rubric-based evaluation
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RubricConfig {
    /// List of rubrics to evaluate against
    pub rubrics: Vec<Rubric>,
}

/// A single rubric for quality assessment
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Rubric {
    /// Rubric name
    pub name: String,
    /// What this rubric measures
    pub description: String,
    /// Weight for this rubric (0.0 - 1.0)
    #[serde(default = "default_weight")]
    pub weight: f64,
    /// Scoring levels (optional)
    #[serde(default)]
    pub levels: Vec<RubricLevel>,
}

impl Rubric {
    /// Create a new rubric
    pub fn new(name: &str, description: &str) -> Self {
        Self {
            name: name.to_string(),
            description: description.to_string(),
            weight: 1.0,
            levels: vec![],
        }
    }

    /// Set weight
    pub fn with_weight(mut self, weight: f64) -> Self {
        self.weight = weight;
        self
    }

    /// Add scoring levels
    pub fn with_levels(mut self, levels: Vec<RubricLevel>) -> Self {
        self.levels = levels;
        self
    }
}

/// A scoring level for a rubric
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RubricLevel {
    /// Score for this level (0.0 - 1.0)
    pub score: f64,
    /// Description of what qualifies for this level
    pub description: String,
}

/// Custom evaluation criterion
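///
/// # Example
///
/// A sketch of a hand-built custom criterion; the name, description, and config
/// keys below are hypothetical (import path assumed as `adk_eval::criteria`):
///
/// ```
/// use adk_eval::criteria::CustomCriterion;
///
/// let criterion = CustomCriterion {
///     name: "latency".to_string(),
///     description: "Response produced within the time budget".to_string(),
///     threshold: 0.9,
///     // Arbitrary JSON config interpreted by whatever evaluator handles "latency".
///     config: serde_json::json!({ "max_seconds": 5 }),
/// };
/// assert_eq!(criterion.threshold, 0.9);
/// ```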
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CustomCriterion {
    /// Criterion name
    pub name: String,
    /// Description of what this measures
    pub description: String,
    /// Score threshold (0.0 - 1.0)
    pub threshold: f64,
    /// Custom configuration as JSON
    #[serde(default)]
    pub config: serde_json::Value,
}

fn default_true() -> bool {
    true
}

fn default_weight() -> f64 {
    1.0
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_criteria_builder() {
        let criteria = EvaluationCriteria::exact_tools()
            .with_response_similarity(0.8)
            .with_semantic_match(0.9);

        assert_eq!(criteria.tool_trajectory_score, Some(1.0));
        assert_eq!(criteria.response_similarity, Some(0.8));
        assert_eq!(criteria.semantic_match_score, Some(0.9));
        assert!(criteria.has_criteria());
    }

    #[test]
    fn test_rubric_creation() {
        let rubric = Rubric::new("Accuracy", "Response is factually correct")
            .with_weight(0.7)
            .with_levels(vec![
                RubricLevel { score: 1.0, description: "Completely accurate".to_string() },
                RubricLevel { score: 0.5, description: "Partially accurate".to_string() },
                RubricLevel { score: 0.0, description: "Inaccurate".to_string() },
            ]);

        assert_eq!(rubric.name, "Accuracy");
        assert_eq!(rubric.weight, 0.7);
        assert_eq!(rubric.levels.len(), 3);
    }

    #[test]
    fn test_default_criteria() {
        let criteria = EvaluationCriteria::default();
        assert!(!criteria.has_criteria());
    }
}