#![forbid(unsafe_code)]
use std::fmt::Write as _;
use std::sync::Arc;
use std::time::Duration;
use tokio::time::timeout;
use crate::evaluator::Evaluator;
use crate::judge::{JudgeClient, JudgeError, JudgeVerdict};
use crate::score::Score;
use crate::types::{EvalCase, EvalMetricResult, Invocation, RecordedToolCall};
/// Per-call judge timeout used unless overridden with
/// `SemanticToolSelectionEvaluator::with_timeout`: five minutes.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5 * 60);
/// Evaluator that asks an LLM judge whether each recorded tool call was an
/// appropriate choice for the case's user goal.
///
/// It only produces a metric for cases with `semantic_tool_selection` set.
pub struct SemanticToolSelectionEvaluator {
    /// Judge consulted once per recorded tool call.
    judge: Arc<dyn JudgeClient>,
    /// Upper bound applied to each individual judge call.
    timeout: Duration,
}

impl std::fmt::Debug for SemanticToolSelectionEvaluator {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The judge is a trait object that is not required to implement
        // `Debug`, so only the configured timeout is shown.
        f.debug_struct("SemanticToolSelectionEvaluator")
            .field("timeout", &self.timeout)
            .finish_non_exhaustive()
    }
}
impl SemanticToolSelectionEvaluator {
    /// Builds an evaluator backed by `judge`, with the default per-call
    /// timeout of five minutes.
    #[must_use]
    pub fn new(judge: Arc<dyn JudgeClient>) -> Self {
        Self {
            timeout: DEFAULT_TIMEOUT,
            judge,
        }
    }

    /// Returns the evaluator with its per-call judge timeout replaced by
    /// `timeout`.
    #[must_use]
    pub const fn with_timeout(self, timeout: Duration) -> Self {
        let mut configured = self;
        configured.timeout = timeout;
        configured
    }
}
impl Evaluator for SemanticToolSelectionEvaluator {
fn name(&self) -> &'static str {
"semantic_tool_selection"
}
fn evaluate(&self, case: &EvalCase, invocation: &Invocation) -> Option<EvalMetricResult> {
if !case.semantic_tool_selection {
return None;
}
let calls: Vec<(usize, &RecordedToolCall)> = invocation
.turns
.iter()
.flat_map(|turn| turn.tool_calls.iter().map(move |tc| (turn.turn_index, tc)))
.collect();
if calls.is_empty() {
return None;
}
let goal = goal_from_case(case);
let tool_menu = available_tool_menu(invocation);
let outcomes: Vec<CallOutcome> = drive_judge_calls(|| async {
let mut results = Vec::with_capacity(calls.len());
let mut history = String::new();
for (turn_index, call) in &calls {
let prompt = build_prompt(&goal, &tool_menu, &history, *turn_index, call);
let outcome = match timeout(self.timeout, self.judge.judge(&prompt)).await {
Ok(Ok(verdict)) => CallOutcome::Verdict {
tool: call.name.clone(),
verdict,
},
Ok(Err(err)) => CallOutcome::JudgeError {
tool: call.name.clone(),
error: err,
},
Err(_elapsed) => CallOutcome::OuterTimeout {
tool: call.name.clone(),
limit: self.timeout,
},
};
append_history(&mut history, *turn_index, call);
results.push(outcome);
}
results
});
Some(aggregate(&outcomes))
}
}
fn drive_judge_calls<F, Fut, T>(make_future: F) -> T
where
F: FnOnce() -> Fut,
Fut: std::future::Future<Output = T>,
{
use tokio::runtime::{Handle, RuntimeFlavor};
if let Ok(handle) = Handle::try_current()
&& handle.runtime_flavor() == RuntimeFlavor::MultiThread
{
return tokio::task::block_in_place(|| handle.block_on(make_future()));
}
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.expect("build current-thread runtime for judge calls");
rt.block_on(make_future())
}
enum CallOutcome {
Verdict { tool: String, verdict: JudgeVerdict },
JudgeError { tool: String, error: JudgeError },
OuterTimeout { tool: String, limit: Duration },
}
/// Folds per-call judge outcomes into one `semantic_tool_selection` metric.
///
/// Scoring:
/// * Any failing verdict, judge error, or outer timeout yields `Score::fail()`.
/// * Otherwise the score is the mean of the verdict scores (each clamped to
///   `[0.0, 1.0]`, NaN treated as `0.0`), passed to `Score::new` with `0.5`.
///   The `0.5` is presumably the pass threshold; confirm against `Score::new`.
///
/// Details:
/// * `details` lists the verdict reasons, then the failure descriptions.
/// * Entries within a group are joined with `; ` and the groups with ` | `.
/// * `details` is `None` when both groups are empty.
fn aggregate(outcomes: &[CallOutcome]) -> EvalMetricResult {
    let mut verdict_scores: Vec<f64> = Vec::with_capacity(outcomes.len());
    let mut verdict_reasons: Vec<String> = Vec::new();
    let mut failure_details: Vec<String> = Vec::new();
    let mut had_failure = false;

    for outcome in outcomes {
        match outcome {
            CallOutcome::Verdict { tool, verdict } => {
                verdict_scores.push(sanitize_score(verdict.score));
                let reason = verdict.reason.as_deref().unwrap_or("no reason");
                let status = if verdict.pass { "pass" } else { "fail" };
                verdict_reasons.push(format!("{tool}: {status} ({reason})"));
                had_failure |= !verdict.pass;
            }
            CallOutcome::JudgeError { tool, error } => {
                had_failure = true;
                failure_details.push(format!(
                    "{tool}: judge error — {variant}: {error}",
                    variant = judge_error_variant(error),
                ));
            }
            CallOutcome::OuterTimeout { tool, limit } => {
                had_failure = true;
                failure_details.push(format!("{tool}: judge call exceeded {limit:?}"));
            }
        }
    }

    let score = if had_failure {
        Score::fail()
    } else {
        Score::new(mean_or_zero(&verdict_scores), 0.5)
    };

    let details: Vec<String> = [verdict_reasons, failure_details]
        .into_iter()
        .filter(|group| !group.is_empty())
        .map(|group| group.join("; "))
        .collect();

    EvalMetricResult {
        evaluator_name: "semantic_tool_selection".to_string(),
        score,
        details: (!details.is_empty()).then(|| details.join(" | ")),
    }
}

/// Clamps a judge score into `[0.0, 1.0]`, mapping NaN to `0.0`.
///
/// `f64::clamp` returns NaN unchanged, so without this guard a single NaN
/// verdict would turn the whole mean into NaN.
fn sanitize_score(score: f64) -> f64 {
    if score.is_nan() {
        0.0
    } else {
        score.clamp(0.0, 1.0)
    }
}

/// Arithmetic mean of `scores`, or `0.0` for an empty slice.
fn mean_or_zero(scores: &[f64]) -> f64 {
    if scores.is_empty() {
        return 0.0;
    }
    // Tool-call counts are far below 2^52, so the conversion is exact in practice.
    #[allow(clippy::cast_precision_loss)]
    let count = scores.len() as f64;
    scores.iter().sum::<f64>() / count
}
/// Returns the bare variant name of `err`, used to label judge failures in the
/// metric details.
const fn judge_error_variant(err: &JudgeError) -> &'static str {
    match err {
        JudgeError::Timeout => "Timeout",
        JudgeError::Transport(..) => "Transport",
        JudgeError::MalformedResponse(..) => "MalformedResponse",
        JudgeError::Other(..) => "Other",
    }
}
/// Renders the user goal shown to the judge: all user messages joined by
/// newlines, or a placeholder when the case has none.
fn goal_from_case(case: &EvalCase) -> String {
    match case.user_messages.as_slice() {
        [] => String::from("(no user goal provided)"),
        messages => messages.join("\n"),
    }
}
/// Lists the distinct tool names called anywhere in the invocation.
///
/// Names appear in first-seen order, separated by `, `. Returns `"(none)"`
/// when there were no calls.
fn available_tool_menu(invocation: &Invocation) -> String {
    let mut distinct: Vec<&str> = Vec::new();
    let names = invocation
        .turns
        .iter()
        .flat_map(|turn| &turn.tool_calls)
        .map(|call| call.name.as_str());
    for name in names {
        if !distinct.contains(&name) {
            distinct.push(name);
        }
    }
    if distinct.is_empty() {
        return "(none)".to_string();
    }
    distinct.join(", ")
}
/// Appends one `- turn N: name(args)` line describing `call` to `buf`.
///
/// The arguments are serialized with `serde_json::to_string`, falling back to
/// `<unserializable>`.
fn append_history(buf: &mut String, turn_index: usize, call: &RecordedToolCall) {
    let rendered_args = match serde_json::to_string(&call.arguments) {
        Ok(json) => json,
        Err(_) => String::from("<unserializable>"),
    };
    // Writing into a `String` cannot fail, so the `fmt::Result` is discarded.
    let _ = writeln!(buf, "- turn {turn_index}: {}({rendered_args})", call.name);
}
/// Builds the judge prompt for one tool call.
///
/// The prompt contains, in order:
/// * the user goal;
/// * the observed tool names;
/// * the calls reviewed before this one (`history`, as produced by
///   `append_history`);
/// * the call under review, with its arguments serialized by
///   `serde_json::to_string` (or `<unserializable>`).
fn build_prompt(
    goal: &str,
    tool_menu: &str,
    history: &str,
    turn_index: usize,
    call: &RecordedToolCall,
) -> String {
    let args = serde_json::to_string(&call.arguments).unwrap_or_else(|_| "<unserializable>".into());
    // History entries end with a newline. Trim it so the history section is
    // followed by exactly one blank line, matching every other section,
    // whether or not any history exists yet.
    let history_section = match history.trim_end() {
        "" => "(no prior tool calls)",
        trimmed => trimmed,
    };
    let name = &call.name;
    format!(
        "You are judging whether an agent's tool-selection decision was \
         semantically appropriate.\n\n\
         User goal:\n{goal}\n\n\
         Tools the agent has been observed using on this run:\n{tool_menu}\n\n\
         Session history so far:\n{history_section}\n\n\
         Current tool call under review (turn {turn_index}):\n {name}({args})\n\n\
         Decide whether the chosen tool is an appropriate selection for advancing the \
         user goal given the history. Respond with a Pass/Fail verdict and a short \
         reason.",
    )
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration as StdDuration;
    use swink_agent::{AssistantMessage, ContentBlock, Cost, ModelSpec, StopReason, Usage};
    use crate::testing::MockJudge;
    use crate::types::{EvalCase, Invocation, TurnRecord};

    /// A case that opts into semantic tool selection with a single user goal.
    fn simple_case() -> EvalCase {
        EvalCase {
            id: "c1".into(),
            name: "C1".into(),
            description: None,
            system_prompt: "be helpful".into(),
            user_messages: vec!["read the config".into()],
            expected_trajectory: None,
            expected_response: None,
            expected_assertion: None,
            expected_interactions: None,
            few_shot_examples: vec![],
            budget: None,
            evaluators: vec![],
            metadata: serde_json::Value::Null,
            attachments: vec![],
            session_id: None,
            expected_environment_state: None,
            expected_tool_intent: None,
            semantic_tool_selection: true,
            state_capture: None,
        }
    }

    /// A one-turn invocation whose turn records one tool call per entry in `names`.
    fn invocation_with_calls(names: &[&str]) -> Invocation {
        let tool_calls: Vec<RecordedToolCall> = names
            .iter()
            .enumerate()
            .map(|(i, n)| RecordedToolCall {
                id: format!("id{i}"),
                name: (*n).to_string(),
                arguments: serde_json::json!({"k": i}),
            })
            .collect();
        Invocation {
            turns: vec![TurnRecord {
                turn_index: 0,
                assistant_message: AssistantMessage {
                    content: vec![ContentBlock::Text { text: "ok".into() }],
                    provider: "p".into(),
                    model_id: "m".into(),
                    usage: Usage::default(),
                    cost: Cost::default(),
                    stop_reason: StopReason::Stop,
                    error_message: None,
                    error_kind: None,
                    timestamp: 0,
                    cache_hint: None,
                },
                tool_calls,
                tool_results: vec![],
                duration: StdDuration::from_millis(1),
            }],
            total_usage: Usage::default(),
            total_cost: Cost::default(),
            total_duration: StdDuration::from_millis(1),
            final_response: Some("done".into()),
            stop_reason: StopReason::Stop,
            model: ModelSpec::new("p", "m"),
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn returns_none_when_flag_disabled() {
        let mut case = simple_case();
        case.semantic_tool_selection = false;
        let invocation = invocation_with_calls(&["read_file"]);
        let judge: Arc<dyn JudgeClient> = Arc::new(MockJudge::always_pass());
        let evaluator = SemanticToolSelectionEvaluator::new(judge);
        assert!(evaluator.evaluate(&case, &invocation).is_none());
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn returns_none_when_trajectory_empty() {
        let case = simple_case();
        // No names means the single turn already carries no tool calls.
        let invocation = invocation_with_calls(&[]);
        let judge: Arc<dyn JudgeClient> = Arc::new(MockJudge::always_pass());
        let evaluator = SemanticToolSelectionEvaluator::new(judge);
        assert!(evaluator.evaluate(&case, &invocation).is_none());
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn judges_every_call_when_enabled() {
        let case = simple_case();
        let invocation = invocation_with_calls(&["read_file", "write_file"]);
        let judge: Arc<dyn JudgeClient> = Arc::new(MockJudge::always_pass());
        let evaluator = SemanticToolSelectionEvaluator::new(judge);
        let result = evaluator
            .evaluate(&case, &invocation)
            .expect("an enabled case with tool calls must produce a metric");
        assert_eq!(result.evaluator_name, "semantic_tool_selection");
        let details = result
            .details
            .expect("per-call verdicts are reported in details");
        assert!(details.contains("read_file: pass"), "{details}");
        assert!(details.contains("write_file: pass"), "{details}");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn default_timeout_is_five_minutes() {
        let judge: Arc<dyn JudgeClient> = Arc::new(MockJudge::always_pass());
        let evaluator = SemanticToolSelectionEvaluator::new(judge);
        assert_eq!(evaluator.timeout, Duration::from_mins(5));
    }
}