1#![allow(clippy::needless_range_loop)] use crate::criteria::{ResponseMatchConfig, SimilarityAlgorithm, ToolTrajectoryConfig};
8use crate::schema::ToolUse;
9use std::collections::HashSet;
10
11pub struct ToolTrajectoryScorer {
13 config: ToolTrajectoryConfig,
14}
15
16impl ToolTrajectoryScorer {
17 pub fn new() -> Self {
19 Self { config: ToolTrajectoryConfig::default() }
20 }
21
22 pub fn with_config(config: ToolTrajectoryConfig) -> Self {
24 Self { config }
25 }
26
27 pub fn score(&self, expected: &[ToolUse], actual: &[ToolUse]) -> f64 {
32 if expected.is_empty() && actual.is_empty() {
33 return 1.0;
34 }
35
36 if expected.is_empty() || actual.is_empty() {
37 return 0.0;
38 }
39
40 if self.config.strict_order {
41 self.score_ordered(expected, actual)
42 } else {
43 self.score_unordered(expected, actual)
44 }
45 }
46
47 fn score_ordered(&self, expected: &[ToolUse], actual: &[ToolUse]) -> f64 {
49 let mut matches = 0;
50 let mut exp_idx = 0;
51 let mut act_idx = 0;
52
53 while exp_idx < expected.len() && act_idx < actual.len() {
54 if expected[exp_idx].matches(&actual[act_idx], self.config.strict_args) {
55 matches += 1;
56 exp_idx += 1;
57 act_idx += 1;
58 } else {
59 let mut found = false;
61 for i in (act_idx + 1)..actual.len() {
62 if expected[exp_idx].matches(&actual[i], self.config.strict_args) {
63 matches += 1;
64 exp_idx += 1;
65 act_idx = i + 1;
66 found = true;
67 break;
68 }
69 }
70 if !found {
71 exp_idx += 1;
72 }
73 }
74 }
75
76 let max_len = expected.len().max(actual.len());
77 matches as f64 / max_len as f64
78 }
79
80 fn score_unordered(&self, expected: &[ToolUse], actual: &[ToolUse]) -> f64 {
82 let mut matched_actual: HashSet<usize> = HashSet::new();
83 let mut matches = 0;
84
85 for exp in expected {
86 for (i, act) in actual.iter().enumerate() {
87 if !matched_actual.contains(&i) && exp.matches(act, self.config.strict_args) {
88 matches += 1;
89 matched_actual.insert(i);
90 break;
91 }
92 }
93 }
94
95 let max_len = expected.len().max(actual.len());
96 matches as f64 / max_len as f64
97 }
98
99 pub fn compare(&self, expected: &[ToolUse], actual: &[ToolUse]) -> ToolTrajectoryComparison {
101 let mut matched = Vec::new();
102 let mut missing = Vec::new();
103 let mut extra = Vec::new();
104 let mut matched_actual: HashSet<usize> = HashSet::new();
105
106 for exp in expected {
107 let mut found = false;
108 for (i, act) in actual.iter().enumerate() {
109 if !matched_actual.contains(&i) && exp.matches(act, self.config.strict_args) {
110 matched.push((exp.clone(), act.clone()));
111 matched_actual.insert(i);
112 found = true;
113 break;
114 }
115 }
116 if !found {
117 missing.push(exp.clone());
118 }
119 }
120
121 for (i, act) in actual.iter().enumerate() {
122 if !matched_actual.contains(&i) {
123 extra.push(act.clone());
124 }
125 }
126
127 ToolTrajectoryComparison { matched, missing, extra, score: self.score(expected, actual) }
128 }
129}
130
131impl Default for ToolTrajectoryScorer {
132 fn default() -> Self {
133 Self::new()
134 }
135}
136
137#[derive(Debug, Clone)]
139pub struct ToolTrajectoryComparison {
140 pub matched: Vec<(ToolUse, ToolUse)>,
142 pub missing: Vec<ToolUse>,
144 pub extra: Vec<ToolUse>,
146 pub score: f64,
148}
149
150pub struct ResponseScorer {
152 config: ResponseMatchConfig,
153}
154
155impl ResponseScorer {
156 pub fn new() -> Self {
158 Self { config: ResponseMatchConfig::default() }
159 }
160
161 pub fn with_config(config: ResponseMatchConfig) -> Self {
163 Self { config }
164 }
165
166 pub fn score(&self, expected: &str, actual: &str) -> f64 {
168 let (expected, actual) = if self.config.normalize {
169 (self.normalize(expected), self.normalize(actual))
170 } else {
171 (expected.to_string(), actual.to_string())
172 };
173
174 match self.config.algorithm {
175 SimilarityAlgorithm::Exact => {
176 if expected == actual {
177 1.0
178 } else {
179 0.0
180 }
181 }
182 SimilarityAlgorithm::Contains => {
183 if actual.contains(&expected) || expected.contains(&actual) { 1.0 } else { 0.0 }
184 }
185 SimilarityAlgorithm::Levenshtein => self.levenshtein_similarity(&expected, &actual),
186 SimilarityAlgorithm::Jaccard => self.jaccard_similarity(&expected, &actual),
187 SimilarityAlgorithm::Rouge1 => self.rouge_n(&expected, &actual, 1),
188 SimilarityAlgorithm::Rouge2 => self.rouge_n(&expected, &actual, 2),
189 SimilarityAlgorithm::RougeL => self.rouge_l(&expected, &actual),
190 }
191 }
192
193 fn normalize(&self, text: &str) -> String {
195 let mut result = text.to_string();
196
197 if self.config.ignore_case {
198 result = result.to_lowercase();
199 }
200
201 if self.config.ignore_punctuation {
202 result = result.chars().filter(|c| c.is_alphanumeric() || c.is_whitespace()).collect();
203 }
204
205 result.split_whitespace().collect::<Vec<_>>().join(" ")
207 }
208
209 fn levenshtein_similarity(&self, a: &str, b: &str) -> f64 {
211 let distance = self.levenshtein_distance(a, b);
212 let max_len = a.len().max(b.len());
213 if max_len == 0 { 1.0 } else { 1.0 - (distance as f64 / max_len as f64) }
214 }
215
216 fn levenshtein_distance(&self, a: &str, b: &str) -> usize {
218 let a_chars: Vec<char> = a.chars().collect();
219 let b_chars: Vec<char> = b.chars().collect();
220 let m = a_chars.len();
221 let n = b_chars.len();
222
223 if m == 0 {
224 return n;
225 }
226 if n == 0 {
227 return m;
228 }
229
230 let mut dp = vec![vec![0; n + 1]; m + 1];
231
232 for i in 0..=m {
233 dp[i][0] = i;
234 }
235 for j in 0..=n {
236 dp[0][j] = j;
237 }
238
239 for i in 1..=m {
240 for j in 1..=n {
241 let cost = if a_chars[i - 1] == b_chars[j - 1] { 0 } else { 1 };
242 dp[i][j] = (dp[i - 1][j] + 1).min(dp[i][j - 1] + 1).min(dp[i - 1][j - 1] + cost);
243 }
244 }
245
246 dp[m][n]
247 }
248
249 fn jaccard_similarity(&self, a: &str, b: &str) -> f64 {
251 let a_words: HashSet<&str> = a.split_whitespace().collect();
252 let b_words: HashSet<&str> = b.split_whitespace().collect();
253
254 if a_words.is_empty() && b_words.is_empty() {
255 return 1.0;
256 }
257
258 let intersection = a_words.intersection(&b_words).count();
259 let union = a_words.union(&b_words).count();
260
261 if union == 0 { 0.0 } else { intersection as f64 / union as f64 }
262 }
263
264 fn rouge_n(&self, reference: &str, candidate: &str, n: usize) -> f64 {
266 let ref_ngrams = self.get_ngrams(reference, n);
267 let cand_ngrams = self.get_ngrams(candidate, n);
268
269 if ref_ngrams.is_empty() {
270 return if cand_ngrams.is_empty() { 1.0 } else { 0.0 };
271 }
272
273 let overlap = ref_ngrams.intersection(&cand_ngrams).count();
274 overlap as f64 / ref_ngrams.len() as f64
275 }
276
277 fn get_ngrams<'a>(&self, text: &'a str, n: usize) -> HashSet<Vec<&'a str>> {
279 let words: Vec<&str> = text.split_whitespace().collect();
280 if words.len() < n {
281 return HashSet::new();
282 }
283
284 words.windows(n).map(|w| w.to_vec()).collect()
285 }
286
287 fn rouge_l(&self, reference: &str, candidate: &str) -> f64 {
289 let ref_words: Vec<&str> = reference.split_whitespace().collect();
290 let cand_words: Vec<&str> = candidate.split_whitespace().collect();
291
292 if ref_words.is_empty() {
293 return if cand_words.is_empty() { 1.0 } else { 0.0 };
294 }
295
296 let lcs_len = self.lcs_length(&ref_words, &cand_words);
297
298 let precision =
300 if cand_words.is_empty() { 0.0 } else { lcs_len as f64 / cand_words.len() as f64 };
301 let recall = lcs_len as f64 / ref_words.len() as f64;
302
303 if precision + recall == 0.0 {
304 0.0
305 } else {
306 2.0 * precision * recall / (precision + recall)
307 }
308 }
309
310 fn lcs_length(&self, a: &[&str], b: &[&str]) -> usize {
312 let m = a.len();
313 let n = b.len();
314
315 if m == 0 || n == 0 {
316 return 0;
317 }
318
319 let mut dp = vec![vec![0; n + 1]; m + 1];
320
321 for i in 1..=m {
322 for j in 1..=n {
323 if a[i - 1] == b[j - 1] {
324 dp[i][j] = dp[i - 1][j - 1] + 1;
325 } else {
326 dp[i][j] = dp[i - 1][j].max(dp[i][j - 1]);
327 }
328 }
329 }
330
331 dp[m][n]
332 }
333}
334
335impl Default for ResponseScorer {
336 fn default() -> Self {
337 Self::new()
338 }
339}
340
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Identical call sequences (names and arguments) earn a perfect score.
    #[test]
    fn test_tool_trajectory_exact_match() {
        fn weather_calls() -> Vec<ToolUse> {
            vec![
                ToolUse::new("get_weather").with_args(json!({"location": "NYC"})),
                ToolUse::new("get_forecast").with_args(json!({"days": 3})),
            ]
        }

        let (expected, actual) = (weather_calls(), weather_calls());
        let result = ToolTrajectoryScorer::new().score(&expected, &actual);

        assert_eq!(result, 1.0);
    }

    /// One shared call out of two lands strictly between 0 and 1.
    #[test]
    fn test_tool_trajectory_partial_match() {
        let expected = [ToolUse::new("tool_a"), ToolUse::new("tool_b")];
        let actual = [ToolUse::new("tool_a"), ToolUse::new("tool_c")];

        let result = ToolTrajectoryScorer::new().score(&expected, &actual);

        assert!(result > 0.0);
        assert!(result < 1.0);
    }

    /// With `strict_order` off, a permutation of the expected calls is a
    /// perfect match.
    #[test]
    fn test_tool_trajectory_unordered() {
        let config = ToolTrajectoryConfig { strict_order: false, strict_args: false };
        let scorer = ToolTrajectoryScorer::with_config(config);

        let expected = [ToolUse::new("tool_a"), ToolUse::new("tool_b")];
        let actual = [ToolUse::new("tool_b"), ToolUse::new("tool_a")];

        assert_eq!(scorer.score(&expected, &actual), 1.0);
    }

    /// Exact matching after case-insensitive normalization.
    #[test]
    fn test_response_exact_match() {
        let config = ResponseMatchConfig {
            algorithm: SimilarityAlgorithm::Exact,
            normalize: true,
            ignore_case: true,
            ignore_punctuation: false,
        };
        let scorer = ResponseScorer::with_config(config);

        let same_ignoring_case = scorer.score("Hello World", "hello world");
        let different = scorer.score("Hello", "World");

        assert_eq!(same_ignoring_case, 1.0);
        assert_eq!(different, 0.0);
    }

    /// The default scorer rates sentences differing in one word out of four
    /// above 0.5 but below 1.0.
    #[test]
    fn test_response_jaccard() {
        let result =
            ResponseScorer::new().score("the quick brown fox", "the quick brown dog");

        assert!(result > 0.5);
        assert!(result < 1.0);
    }

    /// Levenshtein: one edit in five characters is high, no overlap is low.
    #[test]
    fn test_response_levenshtein() {
        let config = ResponseMatchConfig {
            algorithm: SimilarityAlgorithm::Levenshtein,
            ..Default::default()
        };
        let scorer = ResponseScorer::with_config(config);

        let near = scorer.score("hello", "hallo");
        assert!(near > 0.7);

        let far = scorer.score("abc", "xyz");
        assert!(far < 0.5);
    }

    /// ROUGE-L stays high when a single word in the middle differs.
    #[test]
    fn test_rouge_l() {
        let config =
            ResponseMatchConfig { algorithm: SimilarityAlgorithm::RougeL, ..Default::default() };
        let scorer = ResponseScorer::with_config(config);

        let result = scorer.score("the cat sat on the mat", "the cat was on the mat");

        assert!(result > 0.7);
    }
}