1use std::collections::BTreeSet;
2use std::fmt;
3use std::fs;
4use std::path::{Path, PathBuf};
5
6use serde::{Deserialize, Serialize};
7use serde_json::Value as JsonValue;
8
/// Largest absolute difference at which two JSON numbers, compared as `f64`
/// values, are still treated as equal by
/// `json_deep_equal_with_numeric_tolerance`.
const FLOAT_TOLERANCE: f64 = 1e-6;
10
/// One tool-call evaluation case as stored in a dataset file.
///
/// Cases are deserialized by `load_tool_call_eval_dataset`, checked by
/// `validate_case`, and scored by `score_tool_call_case`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCallEvalCase {
    /// Case identifier. The loader rejects blank ids and ids that repeat
    /// within a dataset, and returns the loaded cases sorted by this value.
    pub id: String,
    /// Prompt text for the case; the loader rejects blank prompts.
    pub prompt: String,
    /// Tools declared for this case. Defaults to an empty list when the field
    /// is absent; tool names must be non-blank and unique within the case.
    #[serde(default)]
    pub tools: Vec<ToolDef>,
    /// Outcome the observed behaviour is scored against.
    pub expected: ExpectedToolCall,
    /// Optional pass rate; when present the loader requires it to lie in
    /// `[0, 1]`. Left out of serialized output when `None`.
    // NOTE(review): presumably a reference pass rate from an earlier run —
    // nothing in this file reads it beyond the range check; confirm with callers.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub baseline_pass_rate: Option<f64>,
    /// Optional free-form string, left out of serialized output when `None`.
    // NOTE(review): presumably provenance of the case — not interpreted here.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub source: Option<String>,
    /// Optional labels, left out of serialized output when empty. Not
    /// interpreted anywhere in this file.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tags: Vec<String>,
}
25
/// Declaration of a tool that is available within an eval case.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolDef {
    /// Tool name; must be non-blank and unique among the case's tools.
    pub name: String,
    /// Tool description; defaults to an empty string when absent.
    #[serde(default)]
    pub description: String,
    /// Parameter description. `validate_case` requires this to be a JSON
    /// object, so although serde fills in the `JsonValue` default (`null`)
    /// when the field is absent, such a case is rejected by the loader.
    // NOTE(review): presumably a JSON-Schema-style object — only "is an
    // object" is enforced here.
    #[serde(default)]
    pub parameters: JsonValue,
    /// Optional output description, carried in JSON under the key
    /// `outputSchema` and left out of serialized output when `None`.
    #[serde(
        default,
        skip_serializing_if = "Option::is_none",
        rename = "outputSchema"
    )]
    pub output_schema: Option<JsonValue>,
    /// Optional namespace string; not interpreted in this file.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub namespace: Option<String>,
    /// Optional flag; not interpreted in this file.
    // NOTE(review): semantics are defined by whatever consumes the tool list — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub defer_loading: Option<bool>,
}
46
/// The outcome a case expects.
///
/// In JSON the variant is selected by a `kind` field holding the snake_case
/// variant name (`exact`, `predicate` or `refusal`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum ExpectedToolCall {
    /// A specific tool call is expected.
    Exact {
        /// Name of the expected tool; `validate_case` requires it to be one
        /// of the tools declared on the case.
        name: String,
        /// Expected arguments, compared with the observed ones by
        /// `json_deep_equal_with_numeric_tolerance`.
        args: JsonValue,
    },
    /// The outcome is decided by an externally supplied
    /// `PredicateJudgeVerdict`; without a verdict the case scores as failed.
    Predicate {
        /// Description of the predicate. Not read by the scoring code in this file.
        description: String,
        /// Prompt for the judge. Not read by the scoring code in this file;
        // NOTE(review): presumably consumed by whatever produces the verdict.
        judge_prompt: String,
    },
    /// No tool call is expected; the final text must match a pattern.
    Refusal {
        /// Pattern compiled with `regex::Regex::new` and matched against the
        /// observed final text.
        reason_must_match: String,
    },
}
62
/// A tool call that was actually observed for a case.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ObservedToolCall {
    /// Name of the tool that was called.
    pub name: String,
    /// Arguments passed to the tool; takes the `JsonValue` default when the
    /// field is absent from the input.
    #[serde(default)]
    pub args: JsonValue,
}
69
/// Everything observed for one case: an optional tool call plus final text.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ObservedToolCallOutcome {
    /// The observed tool call, or `None` when no tool was called. Left out of
    /// serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_call: Option<ObservedToolCall>,
    /// Final text of the observed outcome; defaults to an empty string. Refusal
    /// cases match their pattern against this text.
    #[serde(default)]
    pub final_text: String,
}
77
/// Verdict supplied by a judge for a `Predicate` case.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct PredicateJudgeVerdict {
    /// Whether the judge considered the case passed.
    pub passed: bool,
    /// Judge's explanation; defaults to an empty string, which
    /// `score_tool_call_case` replaces with a placeholder reason.
    #[serde(default)]
    pub reason: String,
}
84
/// Result of scoring one case.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCallScore {
    /// Whether the observed outcome satisfied the case's expectation.
    pub passed: bool,
    /// Human-readable explanation of the pass/fail decision.
    pub reason: String,
}
90
/// Failure raised while loading a tool-call eval dataset.
///
/// Every variant records the path the failure relates to together with a
/// human-readable message.
#[derive(Debug)]
pub enum ToolCallEvalDatasetError {
    /// A file or directory could not be read.
    Io { path: PathBuf, message: String },
    /// File contents could not be parsed as JSON or deserialized into cases.
    Json { path: PathBuf, message: String },
    /// A case, or the dataset as a whole, broke a validation rule.
    Validation { path: PathBuf, message: String },
}

impl fmt::Display for ToolCallEvalDatasetError {
    /// Renders the error as `<path>: <message>`, whichever variant it is.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // All variants carry the same two fields, so one irrefutable
        // or-pattern extracts them for a single `write!`.
        let (Self::Io { path, message }
        | Self::Json { path, message }
        | Self::Validation { path, message }) = self;
        write!(f, "{}: {message}", path.display())
    }
}

impl std::error::Error for ToolCallEvalDatasetError {}
109
110pub fn load_tool_call_eval_dataset(
111 path: &Path,
112) -> Result<Vec<ToolCallEvalCase>, ToolCallEvalDatasetError> {
113 let mut cases = Vec::new();
114 for file in tool_call_eval_case_files(path)? {
115 let raw = fs::read_to_string(&file).map_err(|error| ToolCallEvalDatasetError::Io {
116 path: file.clone(),
117 message: error.to_string(),
118 })?;
119 let value: JsonValue =
120 serde_json::from_str(&raw).map_err(|error| ToolCallEvalDatasetError::Json {
121 path: file.clone(),
122 message: error.to_string(),
123 })?;
124 let mut loaded = if value.is_array() {
125 serde_json::from_value::<Vec<ToolCallEvalCase>>(value).map_err(|error| {
126 ToolCallEvalDatasetError::Json {
127 path: file.clone(),
128 message: error.to_string(),
129 }
130 })?
131 } else {
132 vec![
133 serde_json::from_value::<ToolCallEvalCase>(value).map_err(|error| {
134 ToolCallEvalDatasetError::Json {
135 path: file.clone(),
136 message: error.to_string(),
137 }
138 })?,
139 ]
140 };
141 for case in &loaded {
142 validate_case(case, &file)?;
143 }
144 cases.append(&mut loaded);
145 }
146 cases.sort_by(|left, right| left.id.cmp(&right.id));
147 validate_unique_case_ids(&cases, path)?;
148 Ok(cases)
149}
150
151fn tool_call_eval_case_files(path: &Path) -> Result<Vec<PathBuf>, ToolCallEvalDatasetError> {
152 if path.is_file() {
153 return Ok(vec![path.to_path_buf()]);
154 }
155 let cases_dir = path.join("cases");
156 let root = if cases_dir.is_dir() {
157 cases_dir
158 } else {
159 path.to_path_buf()
160 };
161 let mut files = Vec::new();
162 collect_json_files(&root, &mut files)?;
163 files.sort();
164 Ok(files)
165}
166
167fn collect_json_files(dir: &Path, out: &mut Vec<PathBuf>) -> Result<(), ToolCallEvalDatasetError> {
168 let entries = fs::read_dir(dir).map_err(|error| ToolCallEvalDatasetError::Io {
169 path: dir.to_path_buf(),
170 message: error.to_string(),
171 })?;
172 for entry in entries {
173 let entry = entry.map_err(|error| ToolCallEvalDatasetError::Io {
174 path: dir.to_path_buf(),
175 message: error.to_string(),
176 })?;
177 let path = entry.path();
178 if path.is_dir() {
179 collect_json_files(&path, out)?;
180 } else if path.extension().is_some_and(|ext| ext == "json") {
181 out.push(path);
182 }
183 }
184 Ok(())
185}
186
187fn validate_case(case: &ToolCallEvalCase, path: &Path) -> Result<(), ToolCallEvalDatasetError> {
188 if case.id.trim().is_empty() {
189 return validation_error(path, "case id must not be empty");
190 }
191 if case.prompt.trim().is_empty() {
192 return validation_error(path, format!("{}: prompt must not be empty", case.id));
193 }
194 let mut names = BTreeSet::new();
195 for tool in &case.tools {
196 if tool.name.trim().is_empty() {
197 return validation_error(path, format!("{}: tool name must not be empty", case.id));
198 }
199 if !names.insert(tool.name.as_str()) {
200 return validation_error(
201 path,
202 format!("{}: duplicate tool name `{}`", case.id, tool.name),
203 );
204 }
205 if !tool.parameters.is_object() {
206 return validation_error(
207 path,
208 format!(
209 "{}: tool `{}` parameters must be an object",
210 case.id, tool.name
211 ),
212 );
213 }
214 }
215 if let ExpectedToolCall::Exact { name, .. } = &case.expected {
216 if !names.contains(name.as_str()) {
217 return validation_error(
218 path,
219 format!("{}: expected tool `{name}` is not declared", case.id),
220 );
221 }
222 }
223 if let Some(rate) = case.baseline_pass_rate {
224 if !(0.0..=1.0).contains(&rate) {
225 return validation_error(
226 path,
227 format!("{}: baseline_pass_rate must be in [0, 1]", case.id),
228 );
229 }
230 }
231 Ok(())
232}
233
234fn validation_error<T>(
235 path: &Path,
236 message: impl Into<String>,
237) -> Result<T, ToolCallEvalDatasetError> {
238 Err(ToolCallEvalDatasetError::Validation {
239 path: path.to_path_buf(),
240 message: message.into(),
241 })
242}
243
244fn validate_unique_case_ids(
245 cases: &[ToolCallEvalCase],
246 path: &Path,
247) -> Result<(), ToolCallEvalDatasetError> {
248 let mut seen = BTreeSet::new();
249 for case in cases {
250 if !seen.insert(case.id.as_str()) {
251 return validation_error(path, format!("duplicate case id `{}`", case.id));
252 }
253 }
254 Ok(())
255}
256
257pub fn score_tool_call_case(
258 case: &ToolCallEvalCase,
259 observed: &ObservedToolCallOutcome,
260 predicate_verdict: Option<&PredicateJudgeVerdict>,
261) -> ToolCallScore {
262 match &case.expected {
263 ExpectedToolCall::Exact { name, args } => score_exact(name, args, observed),
264 ExpectedToolCall::Predicate { .. } => match predicate_verdict {
265 Some(verdict) => ToolCallScore {
266 passed: verdict.passed,
267 reason: if verdict.reason.is_empty() {
268 "predicate judge returned no reason".to_string()
269 } else {
270 verdict.reason.clone()
271 },
272 },
273 None => ToolCallScore {
274 passed: false,
275 reason: "predicate case was not judged".to_string(),
276 },
277 },
278 ExpectedToolCall::Refusal { reason_must_match } => {
279 score_refusal(reason_must_match, observed)
280 }
281 }
282}
283
284fn score_exact(name: &str, args: &JsonValue, observed: &ObservedToolCallOutcome) -> ToolCallScore {
285 let Some(call) = observed.tool_call.as_ref() else {
286 return ToolCallScore {
287 passed: false,
288 reason: format!("expected `{name}` tool call, observed no tool call"),
289 };
290 };
291 if call.name != name {
292 return ToolCallScore {
293 passed: false,
294 reason: format!("expected tool `{name}`, observed `{}`", call.name),
295 };
296 }
297 if !json_deep_equal_with_numeric_tolerance(args, &call.args) {
298 return ToolCallScore {
299 passed: false,
300 reason: format!("expected args {args}, observed {}", call.args),
301 };
302 }
303 ToolCallScore {
304 passed: true,
305 reason: format!("matched `{name}` and canonical arguments"),
306 }
307}
308
309fn score_refusal(pattern: &str, observed: &ObservedToolCallOutcome) -> ToolCallScore {
310 if let Some(call) = observed.tool_call.as_ref() {
311 return ToolCallScore {
312 passed: false,
313 reason: format!("expected refusal, observed tool `{}`", call.name),
314 };
315 }
316 match regex::Regex::new(pattern) {
317 Ok(regex) if regex.is_match(&observed.final_text) => ToolCallScore {
318 passed: true,
319 reason: "refusal text matched expected reason pattern".to_string(),
320 },
321 Ok(_) => ToolCallScore {
322 passed: false,
323 reason: format!(
324 "refusal text did not match `{pattern}`: {}",
325 observed.final_text
326 ),
327 },
328 Err(error) => ToolCallScore {
329 passed: false,
330 reason: format!("invalid refusal regex `{pattern}`: {error}"),
331 },
332 }
333}
334
335pub fn json_deep_equal_with_numeric_tolerance(left: &JsonValue, right: &JsonValue) -> bool {
336 match (left, right) {
337 (JsonValue::Null, JsonValue::Null) => true,
338 (JsonValue::Bool(left), JsonValue::Bool(right)) => left == right,
339 (JsonValue::String(left), JsonValue::String(right)) => left == right,
340 (JsonValue::Number(left), JsonValue::Number(right)) => {
341 match (left.as_f64(), right.as_f64()) {
342 (Some(left), Some(right)) => (left - right).abs() <= FLOAT_TOLERANCE,
343 _ => left == right,
344 }
345 }
346 (JsonValue::Array(left), JsonValue::Array(right)) => {
347 left.len() == right.len()
348 && left
349 .iter()
350 .zip(right)
351 .all(|(l, r)| json_deep_equal_with_numeric_tolerance(l, r))
352 }
353 (JsonValue::Object(left), JsonValue::Object(right)) => {
354 left.len() == right.len()
355 && left.iter().all(|(key, left_value)| {
356 right.get(key).is_some_and(|right_value| {
357 json_deep_equal_with_numeric_tolerance(left_value, right_value)
358 })
359 })
360 }
361 _ => false,
362 }
363}
364
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Fixture: a case expecting an exact `add` call with `left = 2` and
    /// `right = 3.0`, declaring `add` as its only tool.
    fn exact_case() -> ToolCallEvalCase {
        let add_tool = ToolDef {
            name: "add".to_string(),
            description: String::new(),
            parameters: json!({
                "left": {"type": "integer"},
                "right": {"type": "integer"}
            }),
            output_schema: None,
            namespace: None,
            defer_loading: None,
        };
        ToolCallEvalCase {
            id: "exact".to_string(),
            prompt: "Add two numbers".to_string(),
            tools: vec![add_tool],
            expected: ExpectedToolCall::Exact {
                name: "add".to_string(),
                args: json!({"left": 2, "right": 3.0}),
            },
            baseline_pass_rate: None,
            source: None,
            tags: Vec::new(),
        }
    }

    /// Scores the `exact_case` fixture against an observed `add` call made
    /// with `args`, no final text and no predicate verdict.
    fn score_observed_add(args: JsonValue) -> ToolCallScore {
        let observed = ObservedToolCallOutcome {
            tool_call: Some(ObservedToolCall {
                name: "add".to_string(),
                args,
            }),
            final_text: String::new(),
        };
        score_tool_call_case(&exact_case(), &observed, None)
    }

    #[test]
    fn exact_scoring_accepts_numeric_tolerance() {
        let score = score_observed_add(json!({"right": 3.0000001, "left": 2}));
        assert!(score.passed, "{score:?}");
    }

    #[test]
    fn exact_scoring_rejects_extra_args() {
        let score = score_observed_add(json!({"left": 2, "right": 3, "extra": true}));
        assert!(!score.passed);
        assert!(score.reason.contains("expected args"));
    }

    #[test]
    fn refusal_requires_no_tool_and_matching_text() {
        let case = ToolCallEvalCase {
            id: "refusal".to_string(),
            prompt: "Tell a joke".to_string(),
            tools: Vec::new(),
            expected: ExpectedToolCall::Refusal {
                reason_must_match: "(?i)not.*available".to_string(),
            },
            baseline_pass_rate: None,
            source: None,
            tags: Vec::new(),
        };
        let observed = ObservedToolCallOutcome {
            tool_call: None,
            final_text: "That tool is not available for this request.".to_string(),
        };
        let score = score_tool_call_case(&case, &observed, None);
        assert!(score.passed, "{score:?}");
    }

    #[test]
    fn dataset_loader_accepts_arrays() {
        let tmp = tempfile::tempdir().unwrap();
        let cases_dir = tmp.path().join("cases");
        fs::create_dir(&cases_dir).unwrap();

        // One file holding a JSON array with a single case.
        let payload = serde_json::to_string(&vec![exact_case()]).unwrap();
        fs::write(cases_dir.join("cases.json"), payload).unwrap();

        let loaded = load_tool_call_eval_dataset(tmp.path()).unwrap();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].id, "exact");
    }
}