1#![allow(clippy::needless_range_loop)] use crate::criteria::{ResponseMatchConfig, SimilarityAlgorithm, ToolTrajectoryConfig};
8use crate::schema::ToolUse;
9use std::collections::HashSet;
10
11fn unicode_tokenize(text: &str) -> impl Iterator<Item = &str> {
18 UnicodeTokenizer { text, pos: 0 }
19}
20
21struct UnicodeTokenizer<'a> {
22 text: &'a str,
23 pos: usize,
24}
25
26impl<'a> Iterator for UnicodeTokenizer<'a> {
27 type Item = &'a str;
28
29 fn next(&mut self) -> Option<Self::Item> {
30 while self.pos < self.text.len() {
32 if self.text[self.pos..].starts_with(char::is_whitespace) {
33 self.pos += self.text[self.pos..].chars().next().unwrap().len_utf8();
34 } else {
35 break;
36 }
37 }
38
39 if self.pos >= self.text.len() {
40 return None;
41 }
42
43 let start = self.pos;
44 let c = self.text[start..].chars().next().unwrap();
45
46 if is_cjk_char(c) {
48 self.pos += c.len_utf8();
49 return Some(&self.text[start..self.pos]);
50 }
51
52 while self.pos < self.text.len() {
54 let ch = self.text[self.pos..].chars().next().unwrap();
55 if ch.is_whitespace() || is_cjk_char(ch) {
56 break;
57 }
58 self.pos += ch.len_utf8();
59 }
60
61 Some(&self.text[start..self.pos])
62 }
63}
64
/// Reports whether `c` falls in one of the code-point ranges that the tokenizer
/// treats as CJK, i.e. characters that are emitted as single-character tokens.
fn is_cjk_char(c: char) -> bool {
    // Inclusive `(first, last)` bounds, labelled with their Unicode block names.
    const CJK_RANGES: [(char, char); 6] = [
        ('\u{4e00}', '\u{9fff}'), // CJK Unified Ideographs
        ('\u{3400}', '\u{4dbf}'), // CJK Unified Ideographs Extension A
        ('\u{f900}', '\u{faff}'), // CJK Compatibility Ideographs
        ('\u{3040}', '\u{309f}'), // Hiragana
        ('\u{30a0}', '\u{30ff}'), // Katakana
        ('\u{ac00}', '\u{d7af}'), // Hangul Syllables
    ];

    CJK_RANGES.iter().any(|&(first, last)| first <= c && c <= last)
}
76
77pub struct ToolTrajectoryScorer {
79 config: ToolTrajectoryConfig,
80}
81
82impl ToolTrajectoryScorer {
83 pub fn new() -> Self {
85 Self { config: ToolTrajectoryConfig::default() }
86 }
87
88 pub fn with_config(config: ToolTrajectoryConfig) -> Self {
90 Self { config }
91 }
92
93 pub fn score(&self, expected: &[ToolUse], actual: &[ToolUse]) -> f64 {
98 if expected.is_empty() && actual.is_empty() {
99 return 1.0;
100 }
101
102 if expected.is_empty() || actual.is_empty() {
103 return 0.0;
104 }
105
106 if self.config.strict_order {
107 self.score_ordered(expected, actual)
108 } else {
109 self.score_unordered(expected, actual)
110 }
111 }
112
113 fn score_ordered(&self, expected: &[ToolUse], actual: &[ToolUse]) -> f64 {
115 let mut matches = 0;
116 let mut exp_idx = 0;
117 let mut act_idx = 0;
118
119 while exp_idx < expected.len() && act_idx < actual.len() {
120 if expected[exp_idx].matches(&actual[act_idx], self.config.strict_args) {
121 matches += 1;
122 exp_idx += 1;
123 act_idx += 1;
124 } else {
125 let mut found = false;
127 for i in (act_idx + 1)..actual.len() {
128 if expected[exp_idx].matches(&actual[i], self.config.strict_args) {
129 matches += 1;
130 exp_idx += 1;
131 act_idx = i + 1;
132 found = true;
133 break;
134 }
135 }
136 if !found {
137 exp_idx += 1;
138 }
139 }
140 }
141
142 let max_len = expected.len().max(actual.len());
143 matches as f64 / max_len as f64
144 }
145
146 fn score_unordered(&self, expected: &[ToolUse], actual: &[ToolUse]) -> f64 {
148 let mut matched_actual: HashSet<usize> = HashSet::new();
149 let mut matches = 0;
150
151 for exp in expected {
152 for (i, act) in actual.iter().enumerate() {
153 if !matched_actual.contains(&i) && exp.matches(act, self.config.strict_args) {
154 matches += 1;
155 matched_actual.insert(i);
156 break;
157 }
158 }
159 }
160
161 let max_len = expected.len().max(actual.len());
162 matches as f64 / max_len as f64
163 }
164
165 pub fn compare(&self, expected: &[ToolUse], actual: &[ToolUse]) -> ToolTrajectoryComparison {
167 let mut matched = Vec::new();
168 let mut missing = Vec::new();
169 let mut extra = Vec::new();
170 let mut matched_actual: HashSet<usize> = HashSet::new();
171
172 for exp in expected {
173 let mut found = false;
174 for (i, act) in actual.iter().enumerate() {
175 if !matched_actual.contains(&i) && exp.matches(act, self.config.strict_args) {
176 matched.push((exp.clone(), act.clone()));
177 matched_actual.insert(i);
178 found = true;
179 break;
180 }
181 }
182 if !found {
183 missing.push(exp.clone());
184 }
185 }
186
187 for (i, act) in actual.iter().enumerate() {
188 if !matched_actual.contains(&i) {
189 extra.push(act.clone());
190 }
191 }
192
193 ToolTrajectoryComparison { matched, missing, extra, score: self.score(expected, actual) }
194 }
195}
196
197impl Default for ToolTrajectoryScorer {
198 fn default() -> Self {
199 Self::new()
200 }
201}
202
/// Detailed outcome of comparing an expected tool trajectory with an actual one,
/// as produced by `ToolTrajectoryScorer::compare`.
#[derive(Debug, Clone)]
pub struct ToolTrajectoryComparison {
    /// `(expected, actual)` pairs of calls that were matched with each other.
    pub matched: Vec<(ToolUse, ToolUse)>,
    /// Expected calls for which no matching actual call was found.
    pub missing: Vec<ToolUse>,
    /// Actual calls that were not matched to any expected call.
    pub extra: Vec<ToolUse>,
    /// Overall score, as returned by `ToolTrajectoryScorer::score` for the same inputs.
    pub score: f64,
}
215
216pub struct ResponseScorer {
218 config: ResponseMatchConfig,
219}
220
221impl ResponseScorer {
222 pub fn new() -> Self {
224 Self { config: ResponseMatchConfig::default() }
225 }
226
227 pub fn with_config(config: ResponseMatchConfig) -> Self {
229 Self { config }
230 }
231
232 pub fn score(&self, expected: &str, actual: &str) -> f64 {
234 let (expected, actual) = if self.config.normalize {
235 (self.normalize(expected), self.normalize(actual))
236 } else {
237 (expected.to_string(), actual.to_string())
238 };
239
240 match self.config.algorithm {
241 SimilarityAlgorithm::Exact => {
242 if expected == actual {
243 1.0
244 } else {
245 0.0
246 }
247 }
248 SimilarityAlgorithm::Contains => {
249 if actual.contains(&expected) || expected.contains(&actual) { 1.0 } else { 0.0 }
250 }
251 SimilarityAlgorithm::Levenshtein => self.levenshtein_similarity(&expected, &actual),
252 SimilarityAlgorithm::Jaccard => self.jaccard_similarity(&expected, &actual),
253 SimilarityAlgorithm::Rouge1 => self.rouge_n(&expected, &actual, 1),
254 SimilarityAlgorithm::Rouge2 => self.rouge_n(&expected, &actual, 2),
255 SimilarityAlgorithm::RougeL => self.rouge_l(&expected, &actual),
256 }
257 }
258
259 fn normalize(&self, text: &str) -> String {
261 let mut result = text.to_string();
262
263 if self.config.ignore_case {
264 result = result.to_lowercase();
265 }
266
267 if self.config.ignore_punctuation {
268 result = result.chars().filter(|c| c.is_alphanumeric() || c.is_whitespace()).collect();
269 }
270
271 result.split_whitespace().collect::<Vec<_>>().join(" ")
273 }
274
275 fn levenshtein_similarity(&self, a: &str, b: &str) -> f64 {
277 let distance = self.levenshtein_distance(a, b);
278 let max_len = a.chars().count().max(b.chars().count());
279 if max_len == 0 { 1.0 } else { 1.0 - (distance as f64 / max_len as f64) }
280 }
281
282 fn levenshtein_distance(&self, a: &str, b: &str) -> usize {
284 let a_chars: Vec<char> = a.chars().collect();
285 let b_chars: Vec<char> = b.chars().collect();
286 let m = a_chars.len();
287 let n = b_chars.len();
288
289 if m == 0 {
290 return n;
291 }
292 if n == 0 {
293 return m;
294 }
295
296 let mut dp = vec![vec![0; n + 1]; m + 1];
297
298 for i in 0..=m {
299 dp[i][0] = i;
300 }
301 for j in 0..=n {
302 dp[0][j] = j;
303 }
304
305 for i in 1..=m {
306 for j in 1..=n {
307 let cost = if a_chars[i - 1] == b_chars[j - 1] { 0 } else { 1 };
308 dp[i][j] = (dp[i - 1][j] + 1).min(dp[i][j - 1] + 1).min(dp[i - 1][j - 1] + cost);
309 }
310 }
311
312 dp[m][n]
313 }
314
315 fn jaccard_similarity(&self, a: &str, b: &str) -> f64 {
317 let a_words: HashSet<&str> = unicode_tokenize(a).collect();
318 let b_words: HashSet<&str> = unicode_tokenize(b).collect();
319
320 if a_words.is_empty() && b_words.is_empty() {
321 return 1.0;
322 }
323
324 let intersection = a_words.intersection(&b_words).count();
325 let union = a_words.union(&b_words).count();
326
327 if union == 0 { 0.0 } else { intersection as f64 / union as f64 }
328 }
329
330 fn rouge_n(&self, reference: &str, candidate: &str, n: usize) -> f64 {
332 let ref_ngrams = self.get_ngrams(reference, n);
333 let cand_ngrams = self.get_ngrams(candidate, n);
334
335 if ref_ngrams.is_empty() {
336 return if cand_ngrams.is_empty() { 1.0 } else { 0.0 };
337 }
338
339 let overlap = ref_ngrams.intersection(&cand_ngrams).count();
340 overlap as f64 / ref_ngrams.len() as f64
341 }
342
343 fn get_ngrams(&self, text: &str, n: usize) -> HashSet<Vec<String>> {
345 let words: Vec<String> = unicode_tokenize(text).map(|s| s.to_string()).collect();
346 if words.len() < n {
347 return HashSet::new();
348 }
349
350 words.windows(n).map(|w| w.to_vec()).collect()
351 }
352
353 fn rouge_l(&self, reference: &str, candidate: &str) -> f64 {
355 let ref_words: Vec<&str> = unicode_tokenize(reference).collect();
356 let cand_words: Vec<&str> = unicode_tokenize(candidate).collect();
357
358 if ref_words.is_empty() {
359 return if cand_words.is_empty() { 1.0 } else { 0.0 };
360 }
361
362 let lcs_len = self.lcs_length(&ref_words, &cand_words);
363
364 let precision =
366 if cand_words.is_empty() { 0.0 } else { lcs_len as f64 / cand_words.len() as f64 };
367 let recall = lcs_len as f64 / ref_words.len() as f64;
368
369 if precision + recall == 0.0 {
370 0.0
371 } else {
372 2.0 * precision * recall / (precision + recall)
373 }
374 }
375
376 fn lcs_length(&self, a: &[&str], b: &[&str]) -> usize {
378 let m = a.len();
379 let n = b.len();
380
381 if m == 0 || n == 0 {
382 return 0;
383 }
384
385 let mut dp = vec![vec![0; n + 1]; m + 1];
386
387 for i in 1..=m {
388 for j in 1..=n {
389 if a[i - 1] == b[j - 1] {
390 dp[i][j] = dp[i - 1][j - 1] + 1;
391 } else {
392 dp[i][j] = dp[i - 1][j].max(dp[i][j - 1]);
393 }
394 }
395 }
396
397 dp[m][n]
398 }
399}
400
401impl Default for ResponseScorer {
402 fn default() -> Self {
403 Self::new()
404 }
405}
406
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a response scorer for `algorithm`, leaving every other setting at
    /// its default.
    fn response_scorer(algorithm: SimilarityAlgorithm) -> ResponseScorer {
        ResponseScorer::with_config(ResponseMatchConfig { algorithm, ..Default::default() })
    }

    /// Identical trajectories, arguments included, score a perfect 1.0.
    #[test]
    fn test_tool_trajectory_exact_match() {
        let trajectory = || {
            [
                ToolUse::new("get_weather").with_args(json!({"location": "NYC"})),
                ToolUse::new("get_forecast").with_args(json!({"days": 3})),
            ]
        };
        let (expected, actual) = (trajectory(), trajectory());

        assert_eq!(ToolTrajectoryScorer::new().score(&expected, &actual), 1.0);
    }

    /// One shared call out of two lands strictly between 0 and 1.
    #[test]
    fn test_tool_trajectory_partial_match() {
        let expected = [ToolUse::new("tool_a"), ToolUse::new("tool_b")];
        let actual = [ToolUse::new("tool_a"), ToolUse::new("tool_c")];

        let partial = ToolTrajectoryScorer::new().score(&expected, &actual);
        assert!(0.0 < partial && partial < 1.0);
    }

    /// With `strict_order` off, the same calls in a different order still match fully.
    #[test]
    fn test_tool_trajectory_unordered() {
        let config = ToolTrajectoryConfig { strict_order: false, strict_args: false };
        let expected = [ToolUse::new("tool_a"), ToolUse::new("tool_b")];
        let actual = [ToolUse::new("tool_b"), ToolUse::new("tool_a")];

        let score = ToolTrajectoryScorer::with_config(config).score(&expected, &actual);
        assert_eq!(score, 1.0);
    }

    /// Exact matching with case folding: case differences are ignored, different
    /// words are not.
    #[test]
    fn test_response_exact_match() {
        let config = ResponseMatchConfig {
            algorithm: SimilarityAlgorithm::Exact,
            normalize: true,
            ignore_case: true,
            ignore_punctuation: false,
        };
        let scorer = ResponseScorer::with_config(config);

        let same_modulo_case = scorer.score("Hello World", "hello world");
        let different_words = scorer.score("Hello", "World");

        assert_eq!(same_modulo_case, 1.0);
        assert_eq!(different_words, 0.0);
    }

    /// The default scorer rates a one-word difference in four words above 0.5
    /// but below 1.0.
    #[test]
    fn test_response_jaccard() {
        let similarity =
            ResponseScorer::new().score("the quick brown fox", "the quick brown dog");

        assert!(0.5 < similarity && similarity < 1.0);
    }

    /// Levenshtein: a single substitution stays high, a full rewrite drops low.
    #[test]
    fn test_response_levenshtein() {
        let scorer = response_scorer(SimilarityAlgorithm::Levenshtein);

        let one_edit = scorer.score("hello", "hallo");
        assert!(one_edit > 0.7);

        let all_different = scorer.score("abc", "xyz");
        assert!(all_different < 0.5);
    }

    /// ROUGE-L: one replaced word in six keeps the score above 0.7.
    #[test]
    fn test_rouge_l() {
        let scorer = response_scorer(SimilarityAlgorithm::RougeL);

        let similarity = scorer.score("the cat sat on the mat", "the cat was on the mat");
        assert!(similarity > 0.7);
    }
}