Skip to main content

harn_vm/llm/eval/
tool_call_case.rs

1use std::collections::BTreeSet;
2use std::fmt;
3use std::fs;
4use std::path::{Path, PathBuf};
5
6use serde::{Deserialize, Serialize};
7use serde_json::Value as JsonValue;
8
/// Absolute tolerance applied when two JSON numbers are compared as `f64`
/// during exact-argument scoring.
const FLOAT_TOLERANCE: f64 = 0.000_001;
10
11#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
12pub struct ToolCallEvalCase {
13    pub id: String,
14    pub prompt: String,
15    #[serde(default)]
16    pub tools: Vec<ToolDef>,
17    pub expected: ExpectedToolCall,
18    #[serde(default, skip_serializing_if = "Option::is_none")]
19    pub baseline_pass_rate: Option<f64>,
20    #[serde(default, skip_serializing_if = "Option::is_none")]
21    pub source: Option<String>,
22    #[serde(default, skip_serializing_if = "Vec::is_empty")]
23    pub tags: Vec<String>,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
27pub struct ToolDef {
28    pub name: String,
29    #[serde(default)]
30    pub description: String,
31    /// Harn tool parameter schema: a map of parameter name to JSON Schema
32    /// fragments. `llm_call` wraps this into provider-native function schemas.
33    #[serde(default)]
34    pub parameters: JsonValue,
35    #[serde(
36        default,
37        skip_serializing_if = "Option::is_none",
38        rename = "outputSchema"
39    )]
40    pub output_schema: Option<JsonValue>,
41    #[serde(default, skip_serializing_if = "Option::is_none")]
42    pub namespace: Option<String>,
43    #[serde(default, skip_serializing_if = "Option::is_none")]
44    pub defer_loading: Option<bool>,
45}
46
47#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
48#[serde(tag = "kind", rename_all = "snake_case")]
49pub enum ExpectedToolCall {
50    Exact {
51        name: String,
52        args: JsonValue,
53    },
54    Predicate {
55        description: String,
56        judge_prompt: String,
57    },
58    Refusal {
59        reason_must_match: String,
60    },
61}
62
63#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
64pub struct ObservedToolCall {
65    pub name: String,
66    #[serde(default)]
67    pub args: JsonValue,
68}
69
70#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
71pub struct ObservedToolCallOutcome {
72    #[serde(default, skip_serializing_if = "Option::is_none")]
73    pub tool_call: Option<ObservedToolCall>,
74    #[serde(default)]
75    pub final_text: String,
76}
77
78#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
79pub struct PredicateJudgeVerdict {
80    pub passed: bool,
81    #[serde(default)]
82    pub reason: String,
83}
84
85#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
86pub struct ToolCallScore {
87    pub passed: bool,
88    pub reason: String,
89}
90
/// Failure raised while loading a tool-call eval dataset.
///
/// Every variant carries the path being processed and a human-readable
/// message.
#[derive(Debug)]
pub enum ToolCallEvalDatasetError {
    /// A file or directory could not be read.
    Io { path: PathBuf, message: String },
    /// A file was not valid JSON or did not match the case schema.
    Json { path: PathBuf, message: String },
    /// A case parsed but broke a dataset invariant (blank id, duplicate id, ...).
    Validation { path: PathBuf, message: String },
}

impl fmt::Display for ToolCallEvalDatasetError {
    /// Renders as `<path>: <message>`; the variant itself is not shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (path, message) = match self {
            Self::Io { path, message }
            | Self::Json { path, message }
            | Self::Validation { path, message } => (path, message),
        };
        write!(f, "{}: {}", path.display(), message)
    }
}

impl std::error::Error for ToolCallEvalDatasetError {}
109
110pub fn load_tool_call_eval_dataset(
111    path: &Path,
112) -> Result<Vec<ToolCallEvalCase>, ToolCallEvalDatasetError> {
113    let mut cases = Vec::new();
114    for file in tool_call_eval_case_files(path)? {
115        let raw = fs::read_to_string(&file).map_err(|error| ToolCallEvalDatasetError::Io {
116            path: file.clone(),
117            message: error.to_string(),
118        })?;
119        let value: JsonValue =
120            serde_json::from_str(&raw).map_err(|error| ToolCallEvalDatasetError::Json {
121                path: file.clone(),
122                message: error.to_string(),
123            })?;
124        let mut loaded = if value.is_array() {
125            serde_json::from_value::<Vec<ToolCallEvalCase>>(value).map_err(|error| {
126                ToolCallEvalDatasetError::Json {
127                    path: file.clone(),
128                    message: error.to_string(),
129                }
130            })?
131        } else {
132            vec![
133                serde_json::from_value::<ToolCallEvalCase>(value).map_err(|error| {
134                    ToolCallEvalDatasetError::Json {
135                        path: file.clone(),
136                        message: error.to_string(),
137                    }
138                })?,
139            ]
140        };
141        for case in &loaded {
142            validate_case(case, &file)?;
143        }
144        cases.append(&mut loaded);
145    }
146    cases.sort_by(|left, right| left.id.cmp(&right.id));
147    validate_unique_case_ids(&cases, path)?;
148    Ok(cases)
149}
150
151fn tool_call_eval_case_files(path: &Path) -> Result<Vec<PathBuf>, ToolCallEvalDatasetError> {
152    if path.is_file() {
153        return Ok(vec![path.to_path_buf()]);
154    }
155    let cases_dir = path.join("cases");
156    let root = if cases_dir.is_dir() {
157        cases_dir
158    } else {
159        path.to_path_buf()
160    };
161    let mut files = Vec::new();
162    collect_json_files(&root, &mut files)?;
163    files.sort();
164    Ok(files)
165}
166
167fn collect_json_files(dir: &Path, out: &mut Vec<PathBuf>) -> Result<(), ToolCallEvalDatasetError> {
168    let entries = fs::read_dir(dir).map_err(|error| ToolCallEvalDatasetError::Io {
169        path: dir.to_path_buf(),
170        message: error.to_string(),
171    })?;
172    for entry in entries {
173        let entry = entry.map_err(|error| ToolCallEvalDatasetError::Io {
174            path: dir.to_path_buf(),
175            message: error.to_string(),
176        })?;
177        let path = entry.path();
178        if path.is_dir() {
179            collect_json_files(&path, out)?;
180        } else if path.extension().is_some_and(|ext| ext == "json") {
181            out.push(path);
182        }
183    }
184    Ok(())
185}
186
187fn validate_case(case: &ToolCallEvalCase, path: &Path) -> Result<(), ToolCallEvalDatasetError> {
188    if case.id.trim().is_empty() {
189        return validation_error(path, "case id must not be empty");
190    }
191    if case.prompt.trim().is_empty() {
192        return validation_error(path, format!("{}: prompt must not be empty", case.id));
193    }
194    let mut names = BTreeSet::new();
195    for tool in &case.tools {
196        if tool.name.trim().is_empty() {
197            return validation_error(path, format!("{}: tool name must not be empty", case.id));
198        }
199        if !names.insert(tool.name.as_str()) {
200            return validation_error(
201                path,
202                format!("{}: duplicate tool name `{}`", case.id, tool.name),
203            );
204        }
205        if !tool.parameters.is_object() {
206            return validation_error(
207                path,
208                format!(
209                    "{}: tool `{}` parameters must be an object",
210                    case.id, tool.name
211                ),
212            );
213        }
214    }
215    if let ExpectedToolCall::Exact { name, .. } = &case.expected {
216        if !names.contains(name.as_str()) {
217            return validation_error(
218                path,
219                format!("{}: expected tool `{name}` is not declared", case.id),
220            );
221        }
222    }
223    if let Some(rate) = case.baseline_pass_rate {
224        if !(0.0..=1.0).contains(&rate) {
225            return validation_error(
226                path,
227                format!("{}: baseline_pass_rate must be in [0, 1]", case.id),
228            );
229        }
230    }
231    Ok(())
232}
233
234fn validation_error<T>(
235    path: &Path,
236    message: impl Into<String>,
237) -> Result<T, ToolCallEvalDatasetError> {
238    Err(ToolCallEvalDatasetError::Validation {
239        path: path.to_path_buf(),
240        message: message.into(),
241    })
242}
243
244fn validate_unique_case_ids(
245    cases: &[ToolCallEvalCase],
246    path: &Path,
247) -> Result<(), ToolCallEvalDatasetError> {
248    let mut seen = BTreeSet::new();
249    for case in cases {
250        if !seen.insert(case.id.as_str()) {
251            return validation_error(path, format!("duplicate case id `{}`", case.id));
252        }
253    }
254    Ok(())
255}
256
257pub fn score_tool_call_case(
258    case: &ToolCallEvalCase,
259    observed: &ObservedToolCallOutcome,
260    predicate_verdict: Option<&PredicateJudgeVerdict>,
261) -> ToolCallScore {
262    match &case.expected {
263        ExpectedToolCall::Exact { name, args } => score_exact(name, args, observed),
264        ExpectedToolCall::Predicate { .. } => match predicate_verdict {
265            Some(verdict) => ToolCallScore {
266                passed: verdict.passed,
267                reason: if verdict.reason.is_empty() {
268                    "predicate judge returned no reason".to_string()
269                } else {
270                    verdict.reason.clone()
271                },
272            },
273            None => ToolCallScore {
274                passed: false,
275                reason: "predicate case was not judged".to_string(),
276            },
277        },
278        ExpectedToolCall::Refusal { reason_must_match } => {
279            score_refusal(reason_must_match, observed)
280        }
281    }
282}
283
284fn score_exact(name: &str, args: &JsonValue, observed: &ObservedToolCallOutcome) -> ToolCallScore {
285    let Some(call) = observed.tool_call.as_ref() else {
286        return ToolCallScore {
287            passed: false,
288            reason: format!("expected `{name}` tool call, observed no tool call"),
289        };
290    };
291    if call.name != name {
292        return ToolCallScore {
293            passed: false,
294            reason: format!("expected tool `{name}`, observed `{}`", call.name),
295        };
296    }
297    if !json_deep_equal_with_numeric_tolerance(args, &call.args) {
298        return ToolCallScore {
299            passed: false,
300            reason: format!("expected args {args}, observed {}", call.args),
301        };
302    }
303    ToolCallScore {
304        passed: true,
305        reason: format!("matched `{name}` and canonical arguments"),
306    }
307}
308
309fn score_refusal(pattern: &str, observed: &ObservedToolCallOutcome) -> ToolCallScore {
310    if let Some(call) = observed.tool_call.as_ref() {
311        return ToolCallScore {
312            passed: false,
313            reason: format!("expected refusal, observed tool `{}`", call.name),
314        };
315    }
316    match regex::Regex::new(pattern) {
317        Ok(regex) if regex.is_match(&observed.final_text) => ToolCallScore {
318            passed: true,
319            reason: "refusal text matched expected reason pattern".to_string(),
320        },
321        Ok(_) => ToolCallScore {
322            passed: false,
323            reason: format!(
324                "refusal text did not match `{pattern}`: {}",
325                observed.final_text
326            ),
327        },
328        Err(error) => ToolCallScore {
329            passed: false,
330            reason: format!("invalid refusal regex `{pattern}`: {error}"),
331        },
332    }
333}
334
335pub fn json_deep_equal_with_numeric_tolerance(left: &JsonValue, right: &JsonValue) -> bool {
336    match (left, right) {
337        (JsonValue::Null, JsonValue::Null) => true,
338        (JsonValue::Bool(left), JsonValue::Bool(right)) => left == right,
339        (JsonValue::String(left), JsonValue::String(right)) => left == right,
340        (JsonValue::Number(left), JsonValue::Number(right)) => {
341            match (left.as_f64(), right.as_f64()) {
342                (Some(left), Some(right)) => (left - right).abs() <= FLOAT_TOLERANCE,
343                _ => left == right,
344            }
345        }
346        (JsonValue::Array(left), JsonValue::Array(right)) => {
347            left.len() == right.len()
348                && left
349                    .iter()
350                    .zip(right)
351                    .all(|(l, r)| json_deep_equal_with_numeric_tolerance(l, r))
352        }
353        (JsonValue::Object(left), JsonValue::Object(right)) => {
354            left.len() == right.len()
355                && left.iter().all(|(key, left_value)| {
356                    right.get(key).is_some_and(|right_value| {
357                        json_deep_equal_with_numeric_tolerance(left_value, right_value)
358                    })
359                })
360        }
361        _ => false,
362    }
363}
364
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// An `exact` case that declares a single `add` tool and expects it to be
    /// called with `{"left": 2, "right": 3.0}`.
    fn exact_case() -> ToolCallEvalCase {
        let add_tool = ToolDef {
            name: "add".to_string(),
            description: String::new(),
            parameters: json!({
                "left": {"type": "integer"},
                "right": {"type": "integer"}
            }),
            output_schema: None,
            namespace: None,
            defer_loading: None,
        };
        ToolCallEvalCase {
            id: "exact".to_string(),
            prompt: "Add two numbers".to_string(),
            tools: vec![add_tool],
            expected: ExpectedToolCall::Exact {
                name: "add".to_string(),
                args: json!({"left": 2, "right": 3.0}),
            },
            baseline_pass_rate: None,
            source: None,
            tags: Vec::new(),
        }
    }

    /// An outcome consisting of a single tool call and no final text.
    fn observed_call(name: &str, args: JsonValue) -> ObservedToolCallOutcome {
        ObservedToolCallOutcome {
            tool_call: Some(ObservedToolCall {
                name: name.to_string(),
                args,
            }),
            final_text: String::new(),
        }
    }

    #[test]
    fn exact_scoring_accepts_numeric_tolerance() {
        let observed = observed_call("add", json!({"right": 3.0000001, "left": 2}));
        let score = score_tool_call_case(&exact_case(), &observed, None);
        assert!(score.passed, "{score:?}");
    }

    #[test]
    fn exact_scoring_rejects_extra_args() {
        let observed = observed_call("add", json!({"left": 2, "right": 3, "extra": true}));
        let score = score_tool_call_case(&exact_case(), &observed, None);
        assert!(!score.passed);
        assert!(score.reason.contains("expected args"));
    }

    #[test]
    fn refusal_requires_no_tool_and_matching_text() {
        let case = ToolCallEvalCase {
            id: "refusal".to_string(),
            prompt: "Tell a joke".to_string(),
            tools: Vec::new(),
            expected: ExpectedToolCall::Refusal {
                reason_must_match: "(?i)not.*available".to_string(),
            },
            baseline_pass_rate: None,
            source: None,
            tags: Vec::new(),
        };
        let observed = ObservedToolCallOutcome {
            tool_call: None,
            final_text: "That tool is not available for this request.".to_string(),
        };
        let score = score_tool_call_case(&case, &observed, None);
        assert!(score.passed, "{score:?}");
    }

    #[test]
    fn dataset_loader_accepts_arrays() {
        let tmp = tempfile::tempdir().unwrap();
        let cases_dir = tmp.path().join("cases");
        fs::create_dir(&cases_dir).unwrap();
        let payload = serde_json::to_string(&vec![exact_case()]).unwrap();
        fs::write(cases_dir.join("cases.json"), payload).unwrap();

        let loaded = load_tool_call_eval_dataset(tmp.path()).unwrap();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].id, "exact");
    }
}