Skip to main content

adk_eval/
scoring.rs

1//! Scoring implementations for evaluation criteria
2//!
3//! Provides various scorers for tool trajectory, response similarity, etc.
4
5#![allow(clippy::needless_range_loop)] // Intentional for DP algorithms
6
7use crate::criteria::{ResponseMatchConfig, SimilarityAlgorithm, ToolTrajectoryConfig};
8use crate::schema::ToolUse;
9use std::collections::HashSet;
10
11/// Unicode-aware text tokenizer for scoring.
12///
13/// For whitespace-delimited languages (English, etc.), splits on whitespace.
14/// For CJK text (Chinese, Japanese, Korean) which has no word-separating
15/// whitespace, emits each character as an individual token.
16/// This enables meaningful Jaccard, ROUGE, and n-gram scoring across all languages.
17fn unicode_tokenize(text: &str) -> impl Iterator<Item = &str> {
18    UnicodeTokenizer { text, pos: 0 }
19}
20
21struct UnicodeTokenizer<'a> {
22    text: &'a str,
23    pos: usize,
24}
25
26impl<'a> Iterator for UnicodeTokenizer<'a> {
27    type Item = &'a str;
28
29    fn next(&mut self) -> Option<Self::Item> {
30        // Skip whitespace
31        while self.pos < self.text.len() {
32            if self.text[self.pos..].starts_with(char::is_whitespace) {
33                self.pos += self.text[self.pos..].chars().next().unwrap().len_utf8();
34            } else {
35                break;
36            }
37        }
38
39        if self.pos >= self.text.len() {
40            return None;
41        }
42
43        let start = self.pos;
44        let c = self.text[start..].chars().next().unwrap();
45
46        // CJK character: emit as a single-character token
47        if is_cjk_char(c) {
48            self.pos += c.len_utf8();
49            return Some(&self.text[start..self.pos]);
50        }
51
52        // Non-CJK: accumulate until whitespace or CJK
53        while self.pos < self.text.len() {
54            let ch = self.text[self.pos..].chars().next().unwrap();
55            if ch.is_whitespace() || is_cjk_char(ch) {
56                break;
57            }
58            self.pos += ch.len_utf8();
59        }
60
61        Some(&self.text[start..self.pos])
62    }
63}
64
/// Returns `true` if `c` lies in one of the CJK script blocks listed below.
///
/// The tokenizer emits every such character as a token of its own. Besides
/// the BMP ideograph, kana and Hangul-syllable blocks, the two supplementary
/// ideographic planes are included, so ideographs from CJK Extension B and
/// later (for example U+20BB7) are classified like any other ideograph
/// instead of being merged into a neighbouring token.
fn is_cjk_char(c: char) -> bool {
    matches!(c,
        '\u{4e00}'..='\u{9fff}'     // CJK Unified Ideographs
        | '\u{3400}'..='\u{4dbf}'   // CJK Unified Ideographs Extension A
        | '\u{f900}'..='\u{faff}'   // CJK Compatibility Ideographs
        | '\u{20000}'..='\u{2ffff}' // Supplementary Ideographic Plane (Extension B onward)
        | '\u{30000}'..='\u{3ffff}' // Tertiary Ideographic Plane (Extension G onward)
        | '\u{3040}'..='\u{309f}'   // Hiragana
        | '\u{30a0}'..='\u{30ff}'   // Katakana
        | '\u{ac00}'..='\u{d7af}'   // Hangul Syllables
    )
}
76
77/// Scorer for tool trajectory matching
78pub struct ToolTrajectoryScorer {
79    config: ToolTrajectoryConfig,
80}
81
82impl ToolTrajectoryScorer {
83    /// Create a new scorer with default config
84    pub fn new() -> Self {
85        Self { config: ToolTrajectoryConfig::default() }
86    }
87
88    /// Create with custom config
89    pub fn with_config(config: ToolTrajectoryConfig) -> Self {
90        Self { config }
91    }
92
93    /// Score tool trajectory
94    ///
95    /// Returns a score from 0.0 to 1.0 indicating how well the actual
96    /// tool calls match the expected tool calls.
97    pub fn score(&self, expected: &[ToolUse], actual: &[ToolUse]) -> f64 {
98        if expected.is_empty() && actual.is_empty() {
99            return 1.0;
100        }
101
102        if expected.is_empty() || actual.is_empty() {
103            return 0.0;
104        }
105
106        if self.config.strict_order {
107            self.score_ordered(expected, actual)
108        } else {
109            self.score_unordered(expected, actual)
110        }
111    }
112
113    /// Score with strict ordering
114    fn score_ordered(&self, expected: &[ToolUse], actual: &[ToolUse]) -> f64 {
115        let mut matches = 0;
116        let mut exp_idx = 0;
117        let mut act_idx = 0;
118
119        while exp_idx < expected.len() && act_idx < actual.len() {
120            if expected[exp_idx].matches(&actual[act_idx], self.config.strict_args) {
121                matches += 1;
122                exp_idx += 1;
123                act_idx += 1;
124            } else {
125                // Try to find the expected tool in remaining actual calls
126                let mut found = false;
127                for i in (act_idx + 1)..actual.len() {
128                    if expected[exp_idx].matches(&actual[i], self.config.strict_args) {
129                        matches += 1;
130                        exp_idx += 1;
131                        act_idx = i + 1;
132                        found = true;
133                        break;
134                    }
135                }
136                if !found {
137                    exp_idx += 1;
138                }
139            }
140        }
141
142        let max_len = expected.len().max(actual.len());
143        matches as f64 / max_len as f64
144    }
145
146    /// Score without strict ordering (set comparison)
147    fn score_unordered(&self, expected: &[ToolUse], actual: &[ToolUse]) -> f64 {
148        let mut matched_actual: HashSet<usize> = HashSet::new();
149        let mut matches = 0;
150
151        for exp in expected {
152            for (i, act) in actual.iter().enumerate() {
153                if !matched_actual.contains(&i) && exp.matches(act, self.config.strict_args) {
154                    matches += 1;
155                    matched_actual.insert(i);
156                    break;
157                }
158            }
159        }
160
161        let max_len = expected.len().max(actual.len());
162        matches as f64 / max_len as f64
163    }
164
165    /// Get detailed comparison
166    pub fn compare(&self, expected: &[ToolUse], actual: &[ToolUse]) -> ToolTrajectoryComparison {
167        let mut matched = Vec::new();
168        let mut missing = Vec::new();
169        let mut extra = Vec::new();
170        let mut matched_actual: HashSet<usize> = HashSet::new();
171
172        for exp in expected {
173            let mut found = false;
174            for (i, act) in actual.iter().enumerate() {
175                if !matched_actual.contains(&i) && exp.matches(act, self.config.strict_args) {
176                    matched.push((exp.clone(), act.clone()));
177                    matched_actual.insert(i);
178                    found = true;
179                    break;
180                }
181            }
182            if !found {
183                missing.push(exp.clone());
184            }
185        }
186
187        for (i, act) in actual.iter().enumerate() {
188            if !matched_actual.contains(&i) {
189                extra.push(act.clone());
190            }
191        }
192
193        ToolTrajectoryComparison { matched, missing, extra, score: self.score(expected, actual) }
194    }
195}
196
197impl Default for ToolTrajectoryScorer {
198    fn default() -> Self {
199        Self::new()
200    }
201}
202
/// Detailed comparison of an expected and an actual tool trajectory.
#[derive(Debug, Clone)]
pub struct ToolTrajectoryComparison {
    /// Tools that matched, as `(expected, actual)` pairs
    pub matched: Vec<(ToolUse, ToolUse)>,
    /// Expected tools for which no matching actual call was found
    pub missing: Vec<ToolUse>,
    /// Actual tool calls that were not paired with any expected tool
    pub extra: Vec<ToolUse>,
    /// Overall score
    pub score: f64,
}
215
216/// Scorer for response text similarity
217pub struct ResponseScorer {
218    config: ResponseMatchConfig,
219}
220
221impl ResponseScorer {
222    /// Create a new scorer with default config
223    pub fn new() -> Self {
224        Self { config: ResponseMatchConfig::default() }
225    }
226
227    /// Create with custom config
228    pub fn with_config(config: ResponseMatchConfig) -> Self {
229        Self { config }
230    }
231
232    /// Score response similarity
233    pub fn score(&self, expected: &str, actual: &str) -> f64 {
234        let (expected, actual) = if self.config.normalize {
235            (self.normalize(expected), self.normalize(actual))
236        } else {
237            (expected.to_string(), actual.to_string())
238        };
239
240        match self.config.algorithm {
241            SimilarityAlgorithm::Exact => {
242                if expected == actual {
243                    1.0
244                } else {
245                    0.0
246                }
247            }
248            SimilarityAlgorithm::Contains => {
249                if actual.contains(&expected) || expected.contains(&actual) { 1.0 } else { 0.0 }
250            }
251            SimilarityAlgorithm::Levenshtein => self.levenshtein_similarity(&expected, &actual),
252            SimilarityAlgorithm::Jaccard => self.jaccard_similarity(&expected, &actual),
253            SimilarityAlgorithm::Rouge1 => self.rouge_n(&expected, &actual, 1),
254            SimilarityAlgorithm::Rouge2 => self.rouge_n(&expected, &actual, 2),
255            SimilarityAlgorithm::RougeL => self.rouge_l(&expected, &actual),
256        }
257    }
258
259    /// Normalize text for comparison
260    fn normalize(&self, text: &str) -> String {
261        let mut result = text.to_string();
262
263        if self.config.ignore_case {
264            result = result.to_lowercase();
265        }
266
267        if self.config.ignore_punctuation {
268            result = result.chars().filter(|c| c.is_alphanumeric() || c.is_whitespace()).collect();
269        }
270
271        // Normalize whitespace
272        result.split_whitespace().collect::<Vec<_>>().join(" ")
273    }
274
275    /// Levenshtein distance based similarity
276    fn levenshtein_similarity(&self, a: &str, b: &str) -> f64 {
277        let distance = self.levenshtein_distance(a, b);
278        let max_len = a.chars().count().max(b.chars().count());
279        if max_len == 0 { 1.0 } else { 1.0 - (distance as f64 / max_len as f64) }
280    }
281
282    /// Calculate Levenshtein distance
283    fn levenshtein_distance(&self, a: &str, b: &str) -> usize {
284        let a_chars: Vec<char> = a.chars().collect();
285        let b_chars: Vec<char> = b.chars().collect();
286        let m = a_chars.len();
287        let n = b_chars.len();
288
289        if m == 0 {
290            return n;
291        }
292        if n == 0 {
293            return m;
294        }
295
296        let mut dp = vec![vec![0; n + 1]; m + 1];
297
298        for i in 0..=m {
299            dp[i][0] = i;
300        }
301        for j in 0..=n {
302            dp[0][j] = j;
303        }
304
305        for i in 1..=m {
306            for j in 1..=n {
307                let cost = if a_chars[i - 1] == b_chars[j - 1] { 0 } else { 1 };
308                dp[i][j] = (dp[i - 1][j] + 1).min(dp[i][j - 1] + 1).min(dp[i - 1][j - 1] + cost);
309            }
310        }
311
312        dp[m][n]
313    }
314
315    /// Jaccard similarity (word overlap)
316    fn jaccard_similarity(&self, a: &str, b: &str) -> f64 {
317        let a_words: HashSet<&str> = unicode_tokenize(a).collect();
318        let b_words: HashSet<&str> = unicode_tokenize(b).collect();
319
320        if a_words.is_empty() && b_words.is_empty() {
321            return 1.0;
322        }
323
324        let intersection = a_words.intersection(&b_words).count();
325        let union = a_words.union(&b_words).count();
326
327        if union == 0 { 0.0 } else { intersection as f64 / union as f64 }
328    }
329
330    /// ROUGE-N score (n-gram overlap)
331    fn rouge_n(&self, reference: &str, candidate: &str, n: usize) -> f64 {
332        let ref_ngrams = self.get_ngrams(reference, n);
333        let cand_ngrams = self.get_ngrams(candidate, n);
334
335        if ref_ngrams.is_empty() {
336            return if cand_ngrams.is_empty() { 1.0 } else { 0.0 };
337        }
338
339        let overlap = ref_ngrams.intersection(&cand_ngrams).count();
340        overlap as f64 / ref_ngrams.len() as f64
341    }
342
343    /// Get n-grams from text
344    fn get_ngrams(&self, text: &str, n: usize) -> HashSet<Vec<String>> {
345        let words: Vec<String> = unicode_tokenize(text).map(|s| s.to_string()).collect();
346        if words.len() < n {
347            return HashSet::new();
348        }
349
350        words.windows(n).map(|w| w.to_vec()).collect()
351    }
352
353    /// ROUGE-L score (longest common subsequence)
354    fn rouge_l(&self, reference: &str, candidate: &str) -> f64 {
355        let ref_words: Vec<&str> = unicode_tokenize(reference).collect();
356        let cand_words: Vec<&str> = unicode_tokenize(candidate).collect();
357
358        if ref_words.is_empty() {
359            return if cand_words.is_empty() { 1.0 } else { 0.0 };
360        }
361
362        let lcs_len = self.lcs_length(&ref_words, &cand_words);
363
364        // F1 score of precision and recall
365        let precision =
366            if cand_words.is_empty() { 0.0 } else { lcs_len as f64 / cand_words.len() as f64 };
367        let recall = lcs_len as f64 / ref_words.len() as f64;
368
369        if precision + recall == 0.0 {
370            0.0
371        } else {
372            2.0 * precision * recall / (precision + recall)
373        }
374    }
375
376    /// Length of longest common subsequence
377    fn lcs_length(&self, a: &[&str], b: &[&str]) -> usize {
378        let m = a.len();
379        let n = b.len();
380
381        if m == 0 || n == 0 {
382            return 0;
383        }
384
385        let mut dp = vec![vec![0; n + 1]; m + 1];
386
387        for i in 1..=m {
388            for j in 1..=n {
389                if a[i - 1] == b[j - 1] {
390                    dp[i][j] = dp[i - 1][j - 1] + 1;
391                } else {
392                    dp[i][j] = dp[i - 1][j].max(dp[i][j - 1]);
393                }
394            }
395        }
396
397        dp[m][n]
398    }
399}
400
401impl Default for ResponseScorer {
402    fn default() -> Self {
403        Self::new()
404    }
405}
406
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Identical call sequences (same names, same arguments) score 1.0.
    #[test]
    fn test_tool_trajectory_exact_match() {
        let wanted = [
            ToolUse::new("get_weather").with_args(json!({"location": "NYC"})),
            ToolUse::new("get_forecast").with_args(json!({"days": 3})),
        ];
        let observed = [
            ToolUse::new("get_weather").with_args(json!({"location": "NYC"})),
            ToolUse::new("get_forecast").with_args(json!({"days": 3})),
        ];

        let result = ToolTrajectoryScorer::new().score(&wanted, &observed);
        assert_eq!(result, 1.0);
    }

    /// One matching and one differing call gives a score strictly between 0 and 1.
    #[test]
    fn test_tool_trajectory_partial_match() {
        let wanted = [ToolUse::new("tool_a"), ToolUse::new("tool_b")];
        let observed = [ToolUse::new("tool_a"), ToolUse::new("tool_c")];

        let result = ToolTrajectoryScorer::new().score(&wanted, &observed);
        assert!(result > 0.0);
        assert!(result < 1.0);
    }

    /// With `strict_order` off, the same calls in swapped order still score 1.0.
    #[test]
    fn test_tool_trajectory_unordered() {
        let relaxed = ToolTrajectoryConfig { strict_order: false, strict_args: false };
        let scorer = ToolTrajectoryScorer::with_config(relaxed);

        let wanted = [ToolUse::new("tool_a"), ToolUse::new("tool_b")];
        let observed = [ToolUse::new("tool_b"), ToolUse::new("tool_a")];

        assert_eq!(scorer.score(&wanted, &observed), 1.0);
    }

    /// Exact matching after case-insensitive normalization.
    #[test]
    fn test_response_exact_match() {
        let config = ResponseMatchConfig {
            algorithm: SimilarityAlgorithm::Exact,
            normalize: true,
            ignore_case: true,
            ignore_punctuation: false,
        };
        let scorer = ResponseScorer::with_config(config);

        let same_modulo_case = scorer.score("Hello World", "hello world");
        assert_eq!(same_modulo_case, 1.0);

        let different = scorer.score("Hello", "World");
        assert_eq!(different, 0.0);
    }

    /// Default-config scorer on two sentences that differ in their last word.
    #[test]
    fn test_response_jaccard() {
        let result = ResponseScorer::new().score("the quick brown fox", "the quick brown dog");
        assert!(result > 0.5 && result < 1.0);
    }

    /// Levenshtein similarity: a one-letter edit stays high, a full rewrite is low.
    #[test]
    fn test_response_levenshtein() {
        let config = ResponseMatchConfig {
            algorithm: SimilarityAlgorithm::Levenshtein,
            ..Default::default()
        };
        let scorer = ResponseScorer::with_config(config);

        let near = scorer.score("hello", "hallo");
        assert!(near > 0.7);

        let far = scorer.score("abc", "xyz");
        assert!(far < 0.5);
    }

    /// ROUGE-L on two sentences that differ in a single word.
    #[test]
    fn test_rouge_l() {
        let config =
            ResponseMatchConfig { algorithm: SimilarityAlgorithm::RougeL, ..Default::default() };
        let scorer = ResponseScorer::with_config(config);

        let result = scorer.score("the cat sat on the mat", "the cat was on the mat");
        assert!(result > 0.7);
    }
}