1#![allow(clippy::needless_range_loop)] use crate::criteria::{ResponseMatchConfig, SimilarityAlgorithm, ToolTrajectoryConfig};
8use crate::schema::ToolUse;
9use std::collections::HashSet;
10
11pub struct ToolTrajectoryScorer {
13 config: ToolTrajectoryConfig,
14}
15
16impl ToolTrajectoryScorer {
17 pub fn new() -> Self {
19 Self { config: ToolTrajectoryConfig::default() }
20 }
21
22 pub fn with_config(config: ToolTrajectoryConfig) -> Self {
24 Self { config }
25 }
26
27 pub fn score(&self, expected: &[ToolUse], actual: &[ToolUse]) -> f64 {
32 if expected.is_empty() && actual.is_empty() {
33 return 1.0;
34 }
35
36 if expected.is_empty() || actual.is_empty() {
37 return 0.0;
38 }
39
40 if self.config.strict_order {
41 self.score_ordered(expected, actual)
42 } else {
43 self.score_unordered(expected, actual)
44 }
45 }
46
47 fn score_ordered(&self, expected: &[ToolUse], actual: &[ToolUse]) -> f64 {
49 let mut matches = 0;
50 let mut exp_idx = 0;
51 let mut act_idx = 0;
52
53 while exp_idx < expected.len() && act_idx < actual.len() {
54 if expected[exp_idx].matches(&actual[act_idx], self.config.strict_args) {
55 matches += 1;
56 exp_idx += 1;
57 act_idx += 1;
58 } else {
59 let mut found = false;
61 for i in (act_idx + 1)..actual.len() {
62 if expected[exp_idx].matches(&actual[i], self.config.strict_args) {
63 matches += 1;
64 exp_idx += 1;
65 act_idx = i + 1;
66 found = true;
67 break;
68 }
69 }
70 if !found {
71 exp_idx += 1;
72 }
73 }
74 }
75
76 let max_len = expected.len().max(actual.len());
77 matches as f64 / max_len as f64
78 }
79
80 fn score_unordered(&self, expected: &[ToolUse], actual: &[ToolUse]) -> f64 {
82 let mut matched_actual: HashSet<usize> = HashSet::new();
83 let mut matches = 0;
84
85 for exp in expected {
86 for (i, act) in actual.iter().enumerate() {
87 if !matched_actual.contains(&i) && exp.matches(act, self.config.strict_args) {
88 matches += 1;
89 matched_actual.insert(i);
90 break;
91 }
92 }
93 }
94
95 let max_len = expected.len().max(actual.len());
96 matches as f64 / max_len as f64
97 }
98
99 pub fn compare(&self, expected: &[ToolUse], actual: &[ToolUse]) -> ToolTrajectoryComparison {
101 let mut matched = Vec::new();
102 let mut missing = Vec::new();
103 let mut extra = Vec::new();
104 let mut matched_actual: HashSet<usize> = HashSet::new();
105
106 for exp in expected {
107 let mut found = false;
108 for (i, act) in actual.iter().enumerate() {
109 if !matched_actual.contains(&i) && exp.matches(act, self.config.strict_args) {
110 matched.push((exp.clone(), act.clone()));
111 matched_actual.insert(i);
112 found = true;
113 break;
114 }
115 }
116 if !found {
117 missing.push(exp.clone());
118 }
119 }
120
121 for (i, act) in actual.iter().enumerate() {
122 if !matched_actual.contains(&i) {
123 extra.push(act.clone());
124 }
125 }
126
127 ToolTrajectoryComparison { matched, missing, extra, score: self.score(expected, actual) }
128 }
129}
130
131impl Default for ToolTrajectoryScorer {
132 fn default() -> Self {
133 Self::new()
134 }
135}
136
137#[derive(Debug, Clone)]
139pub struct ToolTrajectoryComparison {
140 pub matched: Vec<(ToolUse, ToolUse)>,
142 pub missing: Vec<ToolUse>,
144 pub extra: Vec<ToolUse>,
146 pub score: f64,
148}
149
150pub struct ResponseScorer {
152 config: ResponseMatchConfig,
153}
154
155impl ResponseScorer {
156 pub fn new() -> Self {
158 Self { config: ResponseMatchConfig::default() }
159 }
160
161 pub fn with_config(config: ResponseMatchConfig) -> Self {
163 Self { config }
164 }
165
166 pub fn score(&self, expected: &str, actual: &str) -> f64 {
168 let (expected, actual) = if self.config.normalize {
169 (self.normalize(expected), self.normalize(actual))
170 } else {
171 (expected.to_string(), actual.to_string())
172 };
173
174 match self.config.algorithm {
175 SimilarityAlgorithm::Exact => {
176 if expected == actual {
177 1.0
178 } else {
179 0.0
180 }
181 }
182 SimilarityAlgorithm::Contains => {
183 if actual.contains(&expected) || expected.contains(&actual) {
184 1.0
185 } else {
186 0.0
187 }
188 }
189 SimilarityAlgorithm::Levenshtein => self.levenshtein_similarity(&expected, &actual),
190 SimilarityAlgorithm::Jaccard => self.jaccard_similarity(&expected, &actual),
191 SimilarityAlgorithm::Rouge1 => self.rouge_n(&expected, &actual, 1),
192 SimilarityAlgorithm::Rouge2 => self.rouge_n(&expected, &actual, 2),
193 SimilarityAlgorithm::RougeL => self.rouge_l(&expected, &actual),
194 }
195 }
196
197 fn normalize(&self, text: &str) -> String {
199 let mut result = text.to_string();
200
201 if self.config.ignore_case {
202 result = result.to_lowercase();
203 }
204
205 if self.config.ignore_punctuation {
206 result = result.chars().filter(|c| c.is_alphanumeric() || c.is_whitespace()).collect();
207 }
208
209 result.split_whitespace().collect::<Vec<_>>().join(" ")
211 }
212
213 fn levenshtein_similarity(&self, a: &str, b: &str) -> f64 {
215 let distance = self.levenshtein_distance(a, b);
216 let max_len = a.len().max(b.len());
217 if max_len == 0 {
218 1.0
219 } else {
220 1.0 - (distance as f64 / max_len as f64)
221 }
222 }
223
224 fn levenshtein_distance(&self, a: &str, b: &str) -> usize {
226 let a_chars: Vec<char> = a.chars().collect();
227 let b_chars: Vec<char> = b.chars().collect();
228 let m = a_chars.len();
229 let n = b_chars.len();
230
231 if m == 0 {
232 return n;
233 }
234 if n == 0 {
235 return m;
236 }
237
238 let mut dp = vec![vec![0; n + 1]; m + 1];
239
240 for i in 0..=m {
241 dp[i][0] = i;
242 }
243 for j in 0..=n {
244 dp[0][j] = j;
245 }
246
247 for i in 1..=m {
248 for j in 1..=n {
249 let cost = if a_chars[i - 1] == b_chars[j - 1] { 0 } else { 1 };
250 dp[i][j] = (dp[i - 1][j] + 1).min(dp[i][j - 1] + 1).min(dp[i - 1][j - 1] + cost);
251 }
252 }
253
254 dp[m][n]
255 }
256
257 fn jaccard_similarity(&self, a: &str, b: &str) -> f64 {
259 let a_words: HashSet<&str> = a.split_whitespace().collect();
260 let b_words: HashSet<&str> = b.split_whitespace().collect();
261
262 if a_words.is_empty() && b_words.is_empty() {
263 return 1.0;
264 }
265
266 let intersection = a_words.intersection(&b_words).count();
267 let union = a_words.union(&b_words).count();
268
269 if union == 0 {
270 0.0
271 } else {
272 intersection as f64 / union as f64
273 }
274 }
275
276 fn rouge_n(&self, reference: &str, candidate: &str, n: usize) -> f64 {
278 let ref_ngrams = self.get_ngrams(reference, n);
279 let cand_ngrams = self.get_ngrams(candidate, n);
280
281 if ref_ngrams.is_empty() {
282 return if cand_ngrams.is_empty() { 1.0 } else { 0.0 };
283 }
284
285 let overlap = ref_ngrams.intersection(&cand_ngrams).count();
286 overlap as f64 / ref_ngrams.len() as f64
287 }
288
289 fn get_ngrams<'a>(&self, text: &'a str, n: usize) -> HashSet<Vec<&'a str>> {
291 let words: Vec<&str> = text.split_whitespace().collect();
292 if words.len() < n {
293 return HashSet::new();
294 }
295
296 words.windows(n).map(|w| w.to_vec()).collect()
297 }
298
299 fn rouge_l(&self, reference: &str, candidate: &str) -> f64 {
301 let ref_words: Vec<&str> = reference.split_whitespace().collect();
302 let cand_words: Vec<&str> = candidate.split_whitespace().collect();
303
304 if ref_words.is_empty() {
305 return if cand_words.is_empty() { 1.0 } else { 0.0 };
306 }
307
308 let lcs_len = self.lcs_length(&ref_words, &cand_words);
309
310 let precision =
312 if cand_words.is_empty() { 0.0 } else { lcs_len as f64 / cand_words.len() as f64 };
313 let recall = lcs_len as f64 / ref_words.len() as f64;
314
315 if precision + recall == 0.0 {
316 0.0
317 } else {
318 2.0 * precision * recall / (precision + recall)
319 }
320 }
321
322 fn lcs_length(&self, a: &[&str], b: &[&str]) -> usize {
324 let m = a.len();
325 let n = b.len();
326
327 if m == 0 || n == 0 {
328 return 0;
329 }
330
331 let mut dp = vec![vec![0; n + 1]; m + 1];
332
333 for i in 1..=m {
334 for j in 1..=n {
335 if a[i - 1] == b[j - 1] {
336 dp[i][j] = dp[i - 1][j - 1] + 1;
337 } else {
338 dp[i][j] = dp[i - 1][j].max(dp[i][j - 1]);
339 }
340 }
341 }
342
343 dp[m][n]
344 }
345}
346
347impl Default for ResponseScorer {
348 fn default() -> Self {
349 Self::new()
350 }
351}
352
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Identical trajectories (names and arguments) score a perfect 1.0.
    #[test]
    fn test_tool_trajectory_exact_match() {
        let trajectory = || {
            vec![
                ToolUse::new("get_weather").with_args(json!({"location": "NYC"})),
                ToolUse::new("get_forecast").with_args(json!({"days": 3})),
            ]
        };
        let expected = trajectory();
        let actual = trajectory();

        let scorer = ToolTrajectoryScorer::new();
        assert_eq!(scorer.score(&expected, &actual), 1.0);
    }

    /// One matching and one differing call lands strictly between 0 and 1.
    #[test]
    fn test_tool_trajectory_partial_match() {
        let expected = vec![ToolUse::new("tool_a"), ToolUse::new("tool_b")];
        let actual = vec![ToolUse::new("tool_a"), ToolUse::new("tool_c")];

        let score = ToolTrajectoryScorer::new().score(&expected, &actual);
        assert!(score > 0.0);
        assert!(score < 1.0);
    }

    /// With `strict_order` off, a reordered trajectory still scores 1.0.
    #[test]
    fn test_tool_trajectory_unordered() {
        let config = ToolTrajectoryConfig { strict_order: false, strict_args: false };
        let scorer = ToolTrajectoryScorer::with_config(config);

        let expected = vec![ToolUse::new("tool_a"), ToolUse::new("tool_b")];
        let actual = vec![ToolUse::new("tool_b"), ToolUse::new("tool_a")];

        assert_eq!(scorer.score(&expected, &actual), 1.0);
    }

    /// Exact matching with case-insensitive normalization.
    #[test]
    fn test_response_exact_match() {
        let config = ResponseMatchConfig {
            algorithm: SimilarityAlgorithm::Exact,
            normalize: true,
            ignore_case: true,
            ignore_punctuation: false,
        };
        let scorer = ResponseScorer::with_config(config);

        assert_eq!(scorer.score("Hello World", "hello world"), 1.0);
        assert_eq!(scorer.score("Hello", "World"), 0.0);
    }

    /// The default configuration rates a one-word difference as mostly similar.
    #[test]
    fn test_response_jaccard() {
        let score = ResponseScorer::new().score("the quick brown fox", "the quick brown dog");
        assert!(score > 0.5);
        assert!(score < 1.0);
    }

    /// Levenshtein similarity is high for a single edit and low for disjoint text.
    #[test]
    fn test_response_levenshtein() {
        let config = ResponseMatchConfig {
            algorithm: SimilarityAlgorithm::Levenshtein,
            ..Default::default()
        };
        let scorer = ResponseScorer::with_config(config);

        let near = scorer.score("hello", "hallo");
        assert!(near > 0.7);

        let far = scorer.score("abc", "xyz");
        assert!(far < 0.5);
    }

    /// ROUGE-L stays high when a single word in the middle differs.
    #[test]
    fn test_rouge_l() {
        let config = ResponseMatchConfig {
            algorithm: SimilarityAlgorithm::RougeL,
            ..Default::default()
        };
        let scorer = ResponseScorer::with_config(config);

        let score = scorer.score("the cat sat on the mat", "the cat was on the mat");
        assert!(score > 0.7);
    }
}