adk_eval/
scoring.rs

//! Scoring implementations for evaluation criteria.
//!
//! Provides scorers for tool-call trajectories and for response-text similarity
//! (exact, containment, Levenshtein, Jaccard, ROUGE-1/2/L).
4
5#![allow(clippy::needless_range_loop)] // Intentional for DP algorithms
6
7use crate::criteria::{ResponseMatchConfig, SimilarityAlgorithm, ToolTrajectoryConfig};
8use crate::schema::ToolUse;
9use std::collections::HashSet;
10
11/// Scorer for tool trajectory matching
12pub struct ToolTrajectoryScorer {
13    config: ToolTrajectoryConfig,
14}
15
16impl ToolTrajectoryScorer {
17    /// Create a new scorer with default config
18    pub fn new() -> Self {
19        Self { config: ToolTrajectoryConfig::default() }
20    }
21
22    /// Create with custom config
23    pub fn with_config(config: ToolTrajectoryConfig) -> Self {
24        Self { config }
25    }
26
27    /// Score tool trajectory
28    ///
29    /// Returns a score from 0.0 to 1.0 indicating how well the actual
30    /// tool calls match the expected tool calls.
31    pub fn score(&self, expected: &[ToolUse], actual: &[ToolUse]) -> f64 {
32        if expected.is_empty() && actual.is_empty() {
33            return 1.0;
34        }
35
36        if expected.is_empty() || actual.is_empty() {
37            return 0.0;
38        }
39
40        if self.config.strict_order {
41            self.score_ordered(expected, actual)
42        } else {
43            self.score_unordered(expected, actual)
44        }
45    }
46
47    /// Score with strict ordering
48    fn score_ordered(&self, expected: &[ToolUse], actual: &[ToolUse]) -> f64 {
49        let mut matches = 0;
50        let mut exp_idx = 0;
51        let mut act_idx = 0;
52
53        while exp_idx < expected.len() && act_idx < actual.len() {
54            if expected[exp_idx].matches(&actual[act_idx], self.config.strict_args) {
55                matches += 1;
56                exp_idx += 1;
57                act_idx += 1;
58            } else {
59                // Try to find the expected tool in remaining actual calls
60                let mut found = false;
61                for i in (act_idx + 1)..actual.len() {
62                    if expected[exp_idx].matches(&actual[i], self.config.strict_args) {
63                        matches += 1;
64                        exp_idx += 1;
65                        act_idx = i + 1;
66                        found = true;
67                        break;
68                    }
69                }
70                if !found {
71                    exp_idx += 1;
72                }
73            }
74        }
75
76        let max_len = expected.len().max(actual.len());
77        matches as f64 / max_len as f64
78    }
79
80    /// Score without strict ordering (set comparison)
81    fn score_unordered(&self, expected: &[ToolUse], actual: &[ToolUse]) -> f64 {
82        let mut matched_actual: HashSet<usize> = HashSet::new();
83        let mut matches = 0;
84
85        for exp in expected {
86            for (i, act) in actual.iter().enumerate() {
87                if !matched_actual.contains(&i) && exp.matches(act, self.config.strict_args) {
88                    matches += 1;
89                    matched_actual.insert(i);
90                    break;
91                }
92            }
93        }
94
95        let max_len = expected.len().max(actual.len());
96        matches as f64 / max_len as f64
97    }
98
99    /// Get detailed comparison
100    pub fn compare(&self, expected: &[ToolUse], actual: &[ToolUse]) -> ToolTrajectoryComparison {
101        let mut matched = Vec::new();
102        let mut missing = Vec::new();
103        let mut extra = Vec::new();
104        let mut matched_actual: HashSet<usize> = HashSet::new();
105
106        for exp in expected {
107            let mut found = false;
108            for (i, act) in actual.iter().enumerate() {
109                if !matched_actual.contains(&i) && exp.matches(act, self.config.strict_args) {
110                    matched.push((exp.clone(), act.clone()));
111                    matched_actual.insert(i);
112                    found = true;
113                    break;
114                }
115            }
116            if !found {
117                missing.push(exp.clone());
118            }
119        }
120
121        for (i, act) in actual.iter().enumerate() {
122            if !matched_actual.contains(&i) {
123                extra.push(act.clone());
124            }
125        }
126
127        ToolTrajectoryComparison { matched, missing, extra, score: self.score(expected, actual) }
128    }
129}
130
131impl Default for ToolTrajectoryScorer {
132    fn default() -> Self {
133        Self::new()
134    }
135}
136
137/// Detailed comparison of tool trajectories
138#[derive(Debug, Clone)]
139pub struct ToolTrajectoryComparison {
140    /// Tools that matched
141    pub matched: Vec<(ToolUse, ToolUse)>,
142    /// Expected tools that weren't called
143    pub missing: Vec<ToolUse>,
144    /// Actual tools that weren't expected
145    pub extra: Vec<ToolUse>,
146    /// Overall score
147    pub score: f64,
148}
149
150/// Scorer for response text similarity
151pub struct ResponseScorer {
152    config: ResponseMatchConfig,
153}
154
155impl ResponseScorer {
156    /// Create a new scorer with default config
157    pub fn new() -> Self {
158        Self { config: ResponseMatchConfig::default() }
159    }
160
161    /// Create with custom config
162    pub fn with_config(config: ResponseMatchConfig) -> Self {
163        Self { config }
164    }
165
166    /// Score response similarity
167    pub fn score(&self, expected: &str, actual: &str) -> f64 {
168        let (expected, actual) = if self.config.normalize {
169            (self.normalize(expected), self.normalize(actual))
170        } else {
171            (expected.to_string(), actual.to_string())
172        };
173
174        match self.config.algorithm {
175            SimilarityAlgorithm::Exact => {
176                if expected == actual {
177                    1.0
178                } else {
179                    0.0
180                }
181            }
182            SimilarityAlgorithm::Contains => {
183                if actual.contains(&expected) || expected.contains(&actual) { 1.0 } else { 0.0 }
184            }
185            SimilarityAlgorithm::Levenshtein => self.levenshtein_similarity(&expected, &actual),
186            SimilarityAlgorithm::Jaccard => self.jaccard_similarity(&expected, &actual),
187            SimilarityAlgorithm::Rouge1 => self.rouge_n(&expected, &actual, 1),
188            SimilarityAlgorithm::Rouge2 => self.rouge_n(&expected, &actual, 2),
189            SimilarityAlgorithm::RougeL => self.rouge_l(&expected, &actual),
190        }
191    }
192
193    /// Normalize text for comparison
194    fn normalize(&self, text: &str) -> String {
195        let mut result = text.to_string();
196
197        if self.config.ignore_case {
198            result = result.to_lowercase();
199        }
200
201        if self.config.ignore_punctuation {
202            result = result.chars().filter(|c| c.is_alphanumeric() || c.is_whitespace()).collect();
203        }
204
205        // Normalize whitespace
206        result.split_whitespace().collect::<Vec<_>>().join(" ")
207    }
208
209    /// Levenshtein distance based similarity
210    fn levenshtein_similarity(&self, a: &str, b: &str) -> f64 {
211        let distance = self.levenshtein_distance(a, b);
212        let max_len = a.len().max(b.len());
213        if max_len == 0 { 1.0 } else { 1.0 - (distance as f64 / max_len as f64) }
214    }
215
216    /// Calculate Levenshtein distance
217    fn levenshtein_distance(&self, a: &str, b: &str) -> usize {
218        let a_chars: Vec<char> = a.chars().collect();
219        let b_chars: Vec<char> = b.chars().collect();
220        let m = a_chars.len();
221        let n = b_chars.len();
222
223        if m == 0 {
224            return n;
225        }
226        if n == 0 {
227            return m;
228        }
229
230        let mut dp = vec![vec![0; n + 1]; m + 1];
231
232        for i in 0..=m {
233            dp[i][0] = i;
234        }
235        for j in 0..=n {
236            dp[0][j] = j;
237        }
238
239        for i in 1..=m {
240            for j in 1..=n {
241                let cost = if a_chars[i - 1] == b_chars[j - 1] { 0 } else { 1 };
242                dp[i][j] = (dp[i - 1][j] + 1).min(dp[i][j - 1] + 1).min(dp[i - 1][j - 1] + cost);
243            }
244        }
245
246        dp[m][n]
247    }
248
249    /// Jaccard similarity (word overlap)
250    fn jaccard_similarity(&self, a: &str, b: &str) -> f64 {
251        let a_words: HashSet<&str> = a.split_whitespace().collect();
252        let b_words: HashSet<&str> = b.split_whitespace().collect();
253
254        if a_words.is_empty() && b_words.is_empty() {
255            return 1.0;
256        }
257
258        let intersection = a_words.intersection(&b_words).count();
259        let union = a_words.union(&b_words).count();
260
261        if union == 0 { 0.0 } else { intersection as f64 / union as f64 }
262    }
263
264    /// ROUGE-N score (n-gram overlap)
265    fn rouge_n(&self, reference: &str, candidate: &str, n: usize) -> f64 {
266        let ref_ngrams = self.get_ngrams(reference, n);
267        let cand_ngrams = self.get_ngrams(candidate, n);
268
269        if ref_ngrams.is_empty() {
270            return if cand_ngrams.is_empty() { 1.0 } else { 0.0 };
271        }
272
273        let overlap = ref_ngrams.intersection(&cand_ngrams).count();
274        overlap as f64 / ref_ngrams.len() as f64
275    }
276
277    /// Get n-grams from text
278    fn get_ngrams<'a>(&self, text: &'a str, n: usize) -> HashSet<Vec<&'a str>> {
279        let words: Vec<&str> = text.split_whitespace().collect();
280        if words.len() < n {
281            return HashSet::new();
282        }
283
284        words.windows(n).map(|w| w.to_vec()).collect()
285    }
286
287    /// ROUGE-L score (longest common subsequence)
288    fn rouge_l(&self, reference: &str, candidate: &str) -> f64 {
289        let ref_words: Vec<&str> = reference.split_whitespace().collect();
290        let cand_words: Vec<&str> = candidate.split_whitespace().collect();
291
292        if ref_words.is_empty() {
293            return if cand_words.is_empty() { 1.0 } else { 0.0 };
294        }
295
296        let lcs_len = self.lcs_length(&ref_words, &cand_words);
297
298        // F1 score of precision and recall
299        let precision =
300            if cand_words.is_empty() { 0.0 } else { lcs_len as f64 / cand_words.len() as f64 };
301        let recall = lcs_len as f64 / ref_words.len() as f64;
302
303        if precision + recall == 0.0 {
304            0.0
305        } else {
306            2.0 * precision * recall / (precision + recall)
307        }
308    }
309
310    /// Length of longest common subsequence
311    fn lcs_length(&self, a: &[&str], b: &[&str]) -> usize {
312        let m = a.len();
313        let n = b.len();
314
315        if m == 0 || n == 0 {
316            return 0;
317        }
318
319        let mut dp = vec![vec![0; n + 1]; m + 1];
320
321        for i in 1..=m {
322            for j in 1..=n {
323                if a[i - 1] == b[j - 1] {
324                    dp[i][j] = dp[i - 1][j - 1] + 1;
325                } else {
326                    dp[i][j] = dp[i - 1][j].max(dp[i][j - 1]);
327                }
328            }
329        }
330
331        dp[m][n]
332    }
333}
334
335impl Default for ResponseScorer {
336    fn default() -> Self {
337        Self::new()
338    }
339}
340
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Two identical weather-lookup calls, rebuilt fresh for each side of a comparison.
    fn weather_calls() -> Vec<ToolUse> {
        vec![
            ToolUse::new("get_weather").with_args(json!({"location": "NYC"})),
            ToolUse::new("get_forecast").with_args(json!({"days": 3})),
        ]
    }

    #[test]
    fn test_tool_trajectory_exact_match() {
        let (expected, actual) = (weather_calls(), weather_calls());
        assert_eq!(ToolTrajectoryScorer::new().score(&expected, &actual), 1.0);
    }

    #[test]
    fn test_tool_trajectory_partial_match() {
        let expected = [ToolUse::new("tool_a"), ToolUse::new("tool_b")];
        let actual = [ToolUse::new("tool_a"), ToolUse::new("tool_c")];

        let score = ToolTrajectoryScorer::new().score(&expected, &actual);
        assert!(0.0 < score && score < 1.0, "expected a partial score, got {}", score);
    }

    #[test]
    fn test_tool_trajectory_unordered() {
        let config = ToolTrajectoryConfig { strict_order: false, strict_args: false };
        let scorer = ToolTrajectoryScorer::with_config(config);

        let expected = [ToolUse::new("tool_a"), ToolUse::new("tool_b")];
        let actual = [ToolUse::new("tool_b"), ToolUse::new("tool_a")];

        assert_eq!(scorer.score(&expected, &actual), 1.0);
    }

    #[test]
    fn test_response_exact_match() {
        let config = ResponseMatchConfig {
            algorithm: SimilarityAlgorithm::Exact,
            normalize: true,
            ignore_case: true,
            ignore_punctuation: false,
        };
        let scorer = ResponseScorer::with_config(config);

        for (expected, actual, want) in [("Hello World", "hello world", 1.0), ("Hello", "World", 0.0)]
        {
            assert_eq!(scorer.score(expected, actual), want);
        }
    }

    #[test]
    fn test_response_jaccard() {
        let score = ResponseScorer::new().score("the quick brown fox", "the quick brown dog");
        assert!(0.5 < score && score < 1.0, "unexpected jaccard score {}", score);
    }

    #[test]
    fn test_response_levenshtein() {
        let scorer = ResponseScorer::with_config(ResponseMatchConfig {
            algorithm: SimilarityAlgorithm::Levenshtein,
            ..Default::default()
        });

        let near = scorer.score("hello", "hallo");
        assert!(near > 0.7, "near-identical words scored {}", near);

        let far = scorer.score("abc", "xyz");
        assert!(far < 0.5, "disjoint words scored {}", far);
    }

    #[test]
    fn test_rouge_l() {
        let scorer = ResponseScorer::with_config(ResponseMatchConfig {
            algorithm: SimilarityAlgorithm::RougeL,
            ..Default::default()
        });

        let score = scorer.score("the cat sat on the mat", "the cat was on the mat");
        assert!(score > 0.7, "rouge-l scored {}", score);
    }
}