adk_eval/
scoring.rs

1//! Scoring implementations for evaluation criteria
2//!
3//! Provides various scorers for tool trajectory, response similarity, etc.
4
5#![allow(clippy::needless_range_loop)] // Intentional for DP algorithms
6
7use crate::criteria::{ResponseMatchConfig, SimilarityAlgorithm, ToolTrajectoryConfig};
8use crate::schema::ToolUse;
9use std::collections::HashSet;
10
11/// Scorer for tool trajectory matching
12pub struct ToolTrajectoryScorer {
13    config: ToolTrajectoryConfig,
14}
15
16impl ToolTrajectoryScorer {
17    /// Create a new scorer with default config
18    pub fn new() -> Self {
19        Self { config: ToolTrajectoryConfig::default() }
20    }
21
22    /// Create with custom config
23    pub fn with_config(config: ToolTrajectoryConfig) -> Self {
24        Self { config }
25    }
26
27    /// Score tool trajectory
28    ///
29    /// Returns a score from 0.0 to 1.0 indicating how well the actual
30    /// tool calls match the expected tool calls.
31    pub fn score(&self, expected: &[ToolUse], actual: &[ToolUse]) -> f64 {
32        if expected.is_empty() && actual.is_empty() {
33            return 1.0;
34        }
35
36        if expected.is_empty() || actual.is_empty() {
37            return 0.0;
38        }
39
40        if self.config.strict_order {
41            self.score_ordered(expected, actual)
42        } else {
43            self.score_unordered(expected, actual)
44        }
45    }
46
47    /// Score with strict ordering
48    fn score_ordered(&self, expected: &[ToolUse], actual: &[ToolUse]) -> f64 {
49        let mut matches = 0;
50        let mut exp_idx = 0;
51        let mut act_idx = 0;
52
53        while exp_idx < expected.len() && act_idx < actual.len() {
54            if expected[exp_idx].matches(&actual[act_idx], self.config.strict_args) {
55                matches += 1;
56                exp_idx += 1;
57                act_idx += 1;
58            } else {
59                // Try to find the expected tool in remaining actual calls
60                let mut found = false;
61                for i in (act_idx + 1)..actual.len() {
62                    if expected[exp_idx].matches(&actual[i], self.config.strict_args) {
63                        matches += 1;
64                        exp_idx += 1;
65                        act_idx = i + 1;
66                        found = true;
67                        break;
68                    }
69                }
70                if !found {
71                    exp_idx += 1;
72                }
73            }
74        }
75
76        let max_len = expected.len().max(actual.len());
77        matches as f64 / max_len as f64
78    }
79
80    /// Score without strict ordering (set comparison)
81    fn score_unordered(&self, expected: &[ToolUse], actual: &[ToolUse]) -> f64 {
82        let mut matched_actual: HashSet<usize> = HashSet::new();
83        let mut matches = 0;
84
85        for exp in expected {
86            for (i, act) in actual.iter().enumerate() {
87                if !matched_actual.contains(&i) && exp.matches(act, self.config.strict_args) {
88                    matches += 1;
89                    matched_actual.insert(i);
90                    break;
91                }
92            }
93        }
94
95        let max_len = expected.len().max(actual.len());
96        matches as f64 / max_len as f64
97    }
98
99    /// Get detailed comparison
100    pub fn compare(&self, expected: &[ToolUse], actual: &[ToolUse]) -> ToolTrajectoryComparison {
101        let mut matched = Vec::new();
102        let mut missing = Vec::new();
103        let mut extra = Vec::new();
104        let mut matched_actual: HashSet<usize> = HashSet::new();
105
106        for exp in expected {
107            let mut found = false;
108            for (i, act) in actual.iter().enumerate() {
109                if !matched_actual.contains(&i) && exp.matches(act, self.config.strict_args) {
110                    matched.push((exp.clone(), act.clone()));
111                    matched_actual.insert(i);
112                    found = true;
113                    break;
114                }
115            }
116            if !found {
117                missing.push(exp.clone());
118            }
119        }
120
121        for (i, act) in actual.iter().enumerate() {
122            if !matched_actual.contains(&i) {
123                extra.push(act.clone());
124            }
125        }
126
127        ToolTrajectoryComparison { matched, missing, extra, score: self.score(expected, actual) }
128    }
129}
130
131impl Default for ToolTrajectoryScorer {
132    fn default() -> Self {
133        Self::new()
134    }
135}
136
137/// Detailed comparison of tool trajectories
138#[derive(Debug, Clone)]
139pub struct ToolTrajectoryComparison {
140    /// Tools that matched
141    pub matched: Vec<(ToolUse, ToolUse)>,
142    /// Expected tools that weren't called
143    pub missing: Vec<ToolUse>,
144    /// Actual tools that weren't expected
145    pub extra: Vec<ToolUse>,
146    /// Overall score
147    pub score: f64,
148}
149
150/// Scorer for response text similarity
151pub struct ResponseScorer {
152    config: ResponseMatchConfig,
153}
154
155impl ResponseScorer {
156    /// Create a new scorer with default config
157    pub fn new() -> Self {
158        Self { config: ResponseMatchConfig::default() }
159    }
160
161    /// Create with custom config
162    pub fn with_config(config: ResponseMatchConfig) -> Self {
163        Self { config }
164    }
165
166    /// Score response similarity
167    pub fn score(&self, expected: &str, actual: &str) -> f64 {
168        let (expected, actual) = if self.config.normalize {
169            (self.normalize(expected), self.normalize(actual))
170        } else {
171            (expected.to_string(), actual.to_string())
172        };
173
174        match self.config.algorithm {
175            SimilarityAlgorithm::Exact => {
176                if expected == actual {
177                    1.0
178                } else {
179                    0.0
180                }
181            }
182            SimilarityAlgorithm::Contains => {
183                if actual.contains(&expected) || expected.contains(&actual) {
184                    1.0
185                } else {
186                    0.0
187                }
188            }
189            SimilarityAlgorithm::Levenshtein => self.levenshtein_similarity(&expected, &actual),
190            SimilarityAlgorithm::Jaccard => self.jaccard_similarity(&expected, &actual),
191            SimilarityAlgorithm::Rouge1 => self.rouge_n(&expected, &actual, 1),
192            SimilarityAlgorithm::Rouge2 => self.rouge_n(&expected, &actual, 2),
193            SimilarityAlgorithm::RougeL => self.rouge_l(&expected, &actual),
194        }
195    }
196
197    /// Normalize text for comparison
198    fn normalize(&self, text: &str) -> String {
199        let mut result = text.to_string();
200
201        if self.config.ignore_case {
202            result = result.to_lowercase();
203        }
204
205        if self.config.ignore_punctuation {
206            result = result.chars().filter(|c| c.is_alphanumeric() || c.is_whitespace()).collect();
207        }
208
209        // Normalize whitespace
210        result.split_whitespace().collect::<Vec<_>>().join(" ")
211    }
212
213    /// Levenshtein distance based similarity
214    fn levenshtein_similarity(&self, a: &str, b: &str) -> f64 {
215        let distance = self.levenshtein_distance(a, b);
216        let max_len = a.len().max(b.len());
217        if max_len == 0 {
218            1.0
219        } else {
220            1.0 - (distance as f64 / max_len as f64)
221        }
222    }
223
224    /// Calculate Levenshtein distance
225    fn levenshtein_distance(&self, a: &str, b: &str) -> usize {
226        let a_chars: Vec<char> = a.chars().collect();
227        let b_chars: Vec<char> = b.chars().collect();
228        let m = a_chars.len();
229        let n = b_chars.len();
230
231        if m == 0 {
232            return n;
233        }
234        if n == 0 {
235            return m;
236        }
237
238        let mut dp = vec![vec![0; n + 1]; m + 1];
239
240        for i in 0..=m {
241            dp[i][0] = i;
242        }
243        for j in 0..=n {
244            dp[0][j] = j;
245        }
246
247        for i in 1..=m {
248            for j in 1..=n {
249                let cost = if a_chars[i - 1] == b_chars[j - 1] { 0 } else { 1 };
250                dp[i][j] = (dp[i - 1][j] + 1).min(dp[i][j - 1] + 1).min(dp[i - 1][j - 1] + cost);
251            }
252        }
253
254        dp[m][n]
255    }
256
257    /// Jaccard similarity (word overlap)
258    fn jaccard_similarity(&self, a: &str, b: &str) -> f64 {
259        let a_words: HashSet<&str> = a.split_whitespace().collect();
260        let b_words: HashSet<&str> = b.split_whitespace().collect();
261
262        if a_words.is_empty() && b_words.is_empty() {
263            return 1.0;
264        }
265
266        let intersection = a_words.intersection(&b_words).count();
267        let union = a_words.union(&b_words).count();
268
269        if union == 0 {
270            0.0
271        } else {
272            intersection as f64 / union as f64
273        }
274    }
275
276    /// ROUGE-N score (n-gram overlap)
277    fn rouge_n(&self, reference: &str, candidate: &str, n: usize) -> f64 {
278        let ref_ngrams = self.get_ngrams(reference, n);
279        let cand_ngrams = self.get_ngrams(candidate, n);
280
281        if ref_ngrams.is_empty() {
282            return if cand_ngrams.is_empty() { 1.0 } else { 0.0 };
283        }
284
285        let overlap = ref_ngrams.intersection(&cand_ngrams).count();
286        overlap as f64 / ref_ngrams.len() as f64
287    }
288
289    /// Get n-grams from text
290    fn get_ngrams<'a>(&self, text: &'a str, n: usize) -> HashSet<Vec<&'a str>> {
291        let words: Vec<&str> = text.split_whitespace().collect();
292        if words.len() < n {
293            return HashSet::new();
294        }
295
296        words.windows(n).map(|w| w.to_vec()).collect()
297    }
298
299    /// ROUGE-L score (longest common subsequence)
300    fn rouge_l(&self, reference: &str, candidate: &str) -> f64 {
301        let ref_words: Vec<&str> = reference.split_whitespace().collect();
302        let cand_words: Vec<&str> = candidate.split_whitespace().collect();
303
304        if ref_words.is_empty() {
305            return if cand_words.is_empty() { 1.0 } else { 0.0 };
306        }
307
308        let lcs_len = self.lcs_length(&ref_words, &cand_words);
309
310        // F1 score of precision and recall
311        let precision =
312            if cand_words.is_empty() { 0.0 } else { lcs_len as f64 / cand_words.len() as f64 };
313        let recall = lcs_len as f64 / ref_words.len() as f64;
314
315        if precision + recall == 0.0 {
316            0.0
317        } else {
318            2.0 * precision * recall / (precision + recall)
319        }
320    }
321
322    /// Length of longest common subsequence
323    fn lcs_length(&self, a: &[&str], b: &[&str]) -> usize {
324        let m = a.len();
325        let n = b.len();
326
327        if m == 0 || n == 0 {
328            return 0;
329        }
330
331        let mut dp = vec![vec![0; n + 1]; m + 1];
332
333        for i in 1..=m {
334            for j in 1..=n {
335                if a[i - 1] == b[j - 1] {
336                    dp[i][j] = dp[i - 1][j - 1] + 1;
337                } else {
338                    dp[i][j] = dp[i - 1][j].max(dp[i][j - 1]);
339                }
340            }
341        }
342
343        dp[m][n]
344    }
345}
346
347impl Default for ResponseScorer {
348    fn default() -> Self {
349        Self::new()
350    }
351}
352
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // --- Tool trajectory scoring ---

    /// Identical calls (names and arguments) in the same order score 1.0.
    #[test]
    fn test_tool_trajectory_exact_match() {
        let wanted = [
            ToolUse::new("get_weather").with_args(json!({"location": "NYC"})),
            ToolUse::new("get_forecast").with_args(json!({"days": 3})),
        ];
        let observed = [
            ToolUse::new("get_weather").with_args(json!({"location": "NYC"})),
            ToolUse::new("get_forecast").with_args(json!({"days": 3})),
        ];

        let result = ToolTrajectoryScorer::new().score(&wanted, &observed);
        assert_eq!(result, 1.0);
    }

    /// One matching and one differing call yields a score strictly between 0 and 1.
    #[test]
    fn test_tool_trajectory_partial_match() {
        let wanted = [ToolUse::new("tool_a"), ToolUse::new("tool_b")];
        let observed = [ToolUse::new("tool_a"), ToolUse::new("tool_c")];

        let result = ToolTrajectoryScorer::new().score(&wanted, &observed);
        assert!(result > 0.0);
        assert!(result < 1.0);
    }

    /// With ordering disabled, the same calls in a different order still score 1.0.
    #[test]
    fn test_tool_trajectory_unordered() {
        let relaxed = ToolTrajectoryConfig { strict_order: false, strict_args: false };
        let scorer = ToolTrajectoryScorer::with_config(relaxed);

        let wanted = [ToolUse::new("tool_a"), ToolUse::new("tool_b")];
        let observed = [ToolUse::new("tool_b"), ToolUse::new("tool_a")];

        assert_eq!(scorer.score(&wanted, &observed), 1.0);
    }

    // --- Response scoring ---

    /// Exact matching with case-insensitive normalization.
    #[test]
    fn test_response_exact_match() {
        let config = ResponseMatchConfig {
            algorithm: SimilarityAlgorithm::Exact,
            normalize: true,
            ignore_case: true,
            ignore_punctuation: false,
        };
        let scorer = ResponseScorer::with_config(config);

        let same_ignoring_case = scorer.score("Hello World", "hello world");
        assert_eq!(same_ignoring_case, 1.0);

        let different = scorer.score("Hello", "World");
        assert_eq!(different, 0.0);
    }

    /// Default-configured scorer on sentences that differ in a single word.
    #[test]
    fn test_response_jaccard() {
        let result = ResponseScorer::new().score("the quick brown fox", "the quick brown dog");
        assert!(result > 0.5);
        assert!(result < 1.0);
    }

    /// Levenshtein similarity is high for a one-letter change, low for disjoint strings.
    #[test]
    fn test_response_levenshtein() {
        let config = ResponseMatchConfig {
            algorithm: SimilarityAlgorithm::Levenshtein,
            ..Default::default()
        };
        let scorer = ResponseScorer::with_config(config);

        let near = scorer.score("hello", "hallo");
        assert!(near > 0.7);

        let far = scorer.score("abc", "xyz");
        assert!(far < 0.5);
    }

    /// ROUGE-L stays high when a single word in the middle differs.
    #[test]
    fn test_rouge_l() {
        let config =
            ResponseMatchConfig { algorithm: SimilarityAlgorithm::RougeL, ..Default::default() };
        let scorer = ResponseScorer::with_config(config);

        let result = scorer.score("the cat sat on the mat", "the cat was on the mat");
        assert!(result > 0.7);
    }
}