use std::collections::HashMap;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::LazyLock;
use regex::Regex;
use crate::config::UtilityScoringConfig;
use crate::executor::ToolCall;
/// Heuristically detects whether `user_message` explicitly asks for a tool to be run.
///
/// Returns `true` when the message either
/// * contains an explicit invocation phrase, matched case-insensitively and only at the
///   start of a word: "using a tool", "call / (re)use / (re)run / invoke / execute
///   [the] <name> tool", "show me the result of:", "(re)run:", "execute:",
///   "what does …", "what would …", "what is the output of …"; or
/// * contains an inline code span (text between a matched pair of backticks) holding a
///   shell metacharacter: `|`, `>`, `<`, `$`, `;` or `&`.
///
/// This is a heuristic, not a parser: "what does this function do?" is a known false
/// positive.
#[must_use]
pub fn has_explicit_tool_request(user_message: &str) -> bool {
    // Every alternative is anchored with `\b` so a verb cannot match inside another
    // word ("because the shell tool" must not read as "use the shell tool", "somewhat
    // would" must not read as "what would", "overrun:" must not read as "run:").
    // `(re)?` keeps "rerun …" / "reuse …", which the unanchored pattern accepted.
    static RE: LazyLock<Regex> = LazyLock::new(|| {
        Regex::new(
            r"(?xi)
            \busing\s+a\s+tool
            | \bcall\s+(the\s+)?[a-z_]+\s+tool
            | \b(re)?use\s+(the\s+)?[a-z_]+\s+tool
            | \b(re)?run\s+(the\s+)?[a-z_]+\s+tool
            | \binvoke\s+(the\s+)?[a-z_]+\s+tool
            | \bexecute\s+(the\s+)?[a-z_]+\s+tool
            | \bshow\s+me\s+the\s+result\s+of\s*:
            | \b(re)?run\s*:
            | \bexecute\s*:
            | \bwhat\s+(does|would|is\s+the\s+output\s+of)
            ",
        )
        .expect("static regex is valid")
    });
    // Matches one complete inline code span. `find_iter` yields non-overlapping matches
    // left to right, so backticks are paired in order. The prose *between* two separate
    // spans (e.g. "`a` costs $5 vs `b`") is therefore never inspected as code.
    static RE_CODE: LazyLock<Regex> =
        LazyLock::new(|| Regex::new(r"`[^`]*`").expect("static regex is valid"));
    const SHELL_METACHARS: [char; 6] = ['|', '>', '<', '$', ';', '&'];

    RE.is_match(user_message)
        || RE_CODE
            .find_iter(user_message)
            .any(|span| span.as_str().contains(SHELL_METACHARS))
}
/// Prior estimate of how useful a call to `tool_name` is likely to be.
///
/// Prefix rules are checked first (`memory*` → 0.8, `mcp_*` → 0.5). Then a few
/// well-known tool names get fixed priors. Anything else gets a neutral 0.5. Matching
/// is exact and case-sensitive.
fn default_gain(tool_name: &str) -> f32 {
    match tool_name {
        name if name.starts_with("memory") => 0.8,
        name if name.starts_with("mcp_") => 0.5,
        "search_code" | "grep" | "glob" => 0.65,
        "bash" | "shell" => 0.6,
        "read" | "write" => 0.55,
        _ => 0.5,
    }
}
/// Breakdown of the utility estimate for one prospective tool call.
///
/// The field descriptions below state how [`UtilityScorer::score`] fills them in.
/// Values built by hand may hold anything.
#[derive(Debug, Clone)]
pub struct UtilityScore {
    /// Per-tool prior benefit of running the call.
    pub gain: f32,
    /// Fraction of the token budget already consumed, clamped to `[0, 1]`.
    pub cost: f32,
    /// `1.0` if an identical call has already been recorded, otherwise `0.0`.
    pub redundancy: f32,
    /// Decays linearly from `1.0` to `0.0` as tool calls accumulate within a turn.
    pub uncertainty: f32,
    /// Weighted combination of the four components above.
    pub total: f32,
}

impl UtilityScore {
    /// Returns `true` when every component is finite (neither NaN nor infinite).
    fn is_valid(&self) -> bool {
        let components = [
            self.gain,
            self.cost,
            self.redundancy,
            self.uncertainty,
            self.total,
        ];
        components.iter().all(|component| component.is_finite())
    }
}
/// Per-turn inputs that [`UtilityScorer`] needs in addition to the tool call itself.
#[derive(Debug, Clone)]
pub struct UtilityContext {
/// Tool calls counted so far in the current turn (presumably already executed — confirm
/// with callers). Drives the uncertainty decay, and a non-zero value enables the
/// `Verify` recommendation.
pub tool_calls_this_turn: usize,
/// Tokens used so far; divided by `token_budget` to obtain the cost component.
pub tokens_consumed: usize,
/// Token budget for the cost component; `0` makes the cost component `0.0`.
pub token_budget: usize,
/// When `true`, `recommend_action` always returns `UtilityAction::ToolCall`.
pub user_requested: bool,
}
/// Recommendation produced by [`UtilityScorer::recommend_action`] for a prospective tool
/// call. How each variant is acted upon is decided by the caller.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UtilityAction {
/// Returned for a call that duplicates a recorded one. Also returned when the call
/// scores below threshold with no tool calls yet this turn. Presumably means "answer
/// without running the tool".
Respond,
/// Returned for a moderate-gain call while uncertainty is still high.
Retrieve,
/// Run the tool call (also the answer whenever scoring is disabled or the user asked).
ToolCall,
/// Returned when the call scores below threshold after at least one tool call this turn.
Verify,
/// Returned when the cost component exceeds 0.9, or when no usable score is available
/// while scoring is enabled.
Stop,
}
/// Key under which a call is stored in the redundancy record.
///
/// Combines the tool id with the `Debug` rendering of the parameters. `caller_id` is not
/// part of the key, so the same call from different callers counts as a repeat.
fn call_hash(call: &ToolCall) -> u64 {
    let rendered_params = format!("{:?}", call.params);
    let mut hasher = DefaultHasher::new();
    call.tool_id.hash(&mut hasher);
    rendered_params.hash(&mut hasher);
    hasher.finish()
}
/// Scores prospective tool calls and recommends what to do with them.
///
/// Calls passed to `record_call` are remembered, so a later identical call scores as
/// redundant. The record is only reset by `clear`.
#[derive(Debug)]
pub struct UtilityScorer {
config: UtilityScoringConfig,
/// `call_hash` of each recorded call → number of times it was recorded.
/// `score` currently only checks whether a key is present.
recent_calls: HashMap<u64, u32>,
}
impl UtilityScorer {
    /// Cost (fraction of the token budget consumed) above which
    /// [`Self::recommend_action`] stops regardless of the other components.
    const STOP_COST: f32 = 0.9;
    /// Gain at or above which a call that also clears the threshold is executed.
    const HIGH_GAIN: f32 = 0.7;
    /// Minimum gain for a still-uncertain call to be turned into a retrieval.
    const RETRIEVE_MIN_GAIN: f32 = 0.5;
    /// Uncertainty above which a moderate-gain call is turned into a retrieval.
    const RETRIEVE_MIN_UNCERTAINTY: f32 = 0.5;
    /// Number of tool calls in a turn after which the uncertainty component reaches 0.
    const UNCERTAINTY_DECAY_CALLS: f32 = 10.0;

    /// Creates a scorer with the given configuration and an empty call record.
    #[must_use]
    pub fn new(config: UtilityScoringConfig) -> Self {
        Self {
            config,
            recent_calls: HashMap::new(),
        }
    }

    /// Returns whether utility scoring is enabled in the configuration.
    #[must_use]
    pub fn is_enabled(&self) -> bool {
        self.config.enabled
    }

    /// Scores a prospective tool call.
    ///
    /// The components are:
    /// * `gain` — per-tool prior;
    /// * `cost` — `tokens_consumed / token_budget` clamped to `[0, 1]` (`0` when the
    ///   budget is `0`);
    /// * `redundancy` — `1` if an identical call was recorded, else `0`;
    /// * `uncertainty` — `1 − tool_calls_this_turn / 10`, clamped to `[0, 1]`.
    ///
    /// They are combined as
    /// `total = gain_weight·gain − cost_weight·cost − redundancy_weight·redundancy
    /// + uncertainty_bonus·uncertainty`.
    ///
    /// Returns `None` when scoring is disabled. Also returns `None` when any component,
    /// `total` included, is not finite (e.g. a configured weight is NaN).
    #[must_use]
    pub fn score(&self, call: &ToolCall, ctx: &UtilityContext) -> Option<UtilityScore> {
        if !self.config.enabled {
            return None;
        }

        let gain = default_gain(call.tool_id.as_str());

        // `usize -> f32` is exact up to 2^24; counts beyond that only lose precision in
        // a ratio that is clamped anyway.
        #[allow(clippy::cast_precision_loss)]
        let cost = if ctx.token_budget > 0 {
            (ctx.tokens_consumed as f32 / ctx.token_budget as f32).clamp(0.0, 1.0)
        } else {
            0.0
        };

        let redundancy = if self.recent_calls.contains_key(&call_hash(call)) {
            1.0_f32
        } else {
            0.0_f32
        };

        // Same reasoning as for `cost`: the result is clamped to `[0, 1]`.
        #[allow(clippy::cast_precision_loss)]
        let uncertainty = (1.0_f32
            - ctx.tool_calls_this_turn as f32 / Self::UNCERTAINTY_DECAY_CALLS)
            .clamp(0.0, 1.0);

        let total = self.config.gain_weight * gain
            - self.config.cost_weight * cost
            - self.config.redundancy_weight * redundancy
            + self.config.uncertainty_bonus * uncertainty;

        let score = UtilityScore {
            gain,
            cost,
            redundancy,
            uncertainty,
            total,
        };
        score.is_valid().then_some(score)
    }

    /// Recommends what to do with a prospective tool call, given its score.
    ///
    /// Rules are applied in order and the first match wins:
    /// 1. the user explicitly requested a tool → [`UtilityAction::ToolCall`];
    /// 2. scoring is disabled → `ToolCall` (scoring never blocks when off);
    /// 3. the score is missing or has a non-finite component → [`UtilityAction::Stop`]
    ///    (fail closed);
    /// 4. `cost > 0.9` → `Stop`;
    /// 5. `redundancy >= 1` → [`UtilityAction::Respond`];
    /// 6. `gain >= 0.7` and `total >= threshold` → `ToolCall`;
    /// 7. `gain >= 0.5` and `uncertainty > 0.5` → [`UtilityAction::Retrieve`];
    /// 8. `total < threshold` with at least one tool call this turn →
    ///    [`UtilityAction::Verify`];
    /// 9. `total >= threshold` → `ToolCall`;
    /// 10. otherwise → `Respond`.
    #[must_use]
    pub fn recommend_action(
        &self,
        score: Option<&UtilityScore>,
        ctx: &UtilityContext,
    ) -> UtilityAction {
        if ctx.user_requested || !self.config.enabled {
            return UtilityAction::ToolCall;
        }
        // With a NaN component every comparison below is false. The call would then
        // fall through to whichever later branch happens to test only finite fields.
        // Treat such a score exactly like a missing one, which is what `score`
        // produces for non-finite values.
        let Some(s) = score.filter(|s| s.is_valid()) else {
            return UtilityAction::Stop;
        };
        let threshold = self.config.threshold;

        if s.cost > Self::STOP_COST {
            return UtilityAction::Stop;
        }
        if s.redundancy >= 1.0 {
            return UtilityAction::Respond;
        }
        if s.gain >= Self::HIGH_GAIN && s.total >= threshold {
            return UtilityAction::ToolCall;
        }
        if s.gain >= Self::RETRIEVE_MIN_GAIN && s.uncertainty > Self::RETRIEVE_MIN_UNCERTAINTY
        {
            return UtilityAction::Retrieve;
        }
        if s.total < threshold && ctx.tool_calls_this_turn > 0 {
            return UtilityAction::Verify;
        }
        if s.total >= threshold {
            UtilityAction::ToolCall
        } else {
            UtilityAction::Respond
        }
    }

    /// Records that `call` was made. A later call with the same tool id and parameters
    /// then scores as redundant until [`Self::clear`] is called.
    pub fn record_call(&mut self, call: &ToolCall) {
        *self.recent_calls.entry(call_hash(call)).or_default() += 1;
    }

    /// Forgets every recorded call.
    pub fn clear(&mut self) {
        self.recent_calls.clear();
    }

    /// Returns whether `tool_name` is in the configured exempt list.
    /// The comparison is case-insensitive, using Unicode lowercasing.
    #[must_use]
    pub fn is_exempt(&self, tool_name: &str) -> bool {
        let lower = tool_name.to_lowercase();
        self.config
            .exempt_tools
            .iter()
            .any(|exempt| exempt.to_lowercase() == lower)
    }

    /// Configured minimum `total` for a call to clear the threshold.
    #[must_use]
    pub fn threshold(&self) -> f32 {
        self.config.threshold
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ToolName;
    use serde_json::json;

    /// Builds a `ToolCall`; non-object `params` become an empty parameter map.
    fn make_call(name: &str, params: serde_json::Value) -> ToolCall {
        ToolCall {
            tool_id: ToolName::new(name),
            params: if let serde_json::Value::Object(m) = params {
                m
            } else {
                serde_json::Map::new()
            },
            caller_id: None,
        }
    }

    fn default_ctx() -> UtilityContext {
        UtilityContext {
            tool_calls_this_turn: 0,
            tokens_consumed: 0,
            token_budget: 1000,
            user_requested: false,
        }
    }

    fn default_config() -> UtilityScoringConfig {
        UtilityScoringConfig {
            enabled: true,
            ..UtilityScoringConfig::default()
        }
    }

    #[test]
    fn disabled_returns_none() {
        let scorer = UtilityScorer::new(UtilityScoringConfig::default());
        assert!(!scorer.is_enabled());
        let call = make_call("bash", json!({}));
        let score = scorer.score(&call, &default_ctx());
        assert!(score.is_none());
        assert_eq!(
            scorer.recommend_action(score.as_ref(), &default_ctx()),
            UtilityAction::ToolCall
        );
    }

    #[test]
    fn first_call_passes_default_threshold() {
        let scorer = UtilityScorer::new(default_config());
        let call = make_call("bash", json!({"cmd": "ls"}));
        let score = scorer.score(&call, &default_ctx());
        assert!(score.is_some());
        let s = score.unwrap();
        assert!(
            s.total >= 0.1,
            "first call should exceed threshold: {}",
            s.total
        );
        let action = scorer.recommend_action(Some(&s), &default_ctx());
        assert!(
            action == UtilityAction::ToolCall || action == UtilityAction::Retrieve,
            "first call should not be blocked, got {action:?}",
        );
    }

    #[test]
    fn redundant_call_penalized() {
        let mut scorer = UtilityScorer::new(default_config());
        let call = make_call("bash", json!({"cmd": "ls"}));
        scorer.record_call(&call);
        let score = scorer.score(&call, &default_ctx()).unwrap();
        assert!((score.redundancy - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn different_params_not_redundant() {
        let mut scorer = UtilityScorer::new(default_config());
        scorer.record_call(&make_call("bash", json!({"cmd": "ls"})));
        let other = make_call("bash", json!({"cmd": "pwd"}));
        let score = scorer.score(&other, &default_ctx()).unwrap();
        assert!(score.redundancy.abs() < f32::EPSILON);
    }

    #[test]
    fn clear_resets_redundancy() {
        let mut scorer = UtilityScorer::new(default_config());
        let call = make_call("bash", json!({"cmd": "ls"}));
        scorer.record_call(&call);
        scorer.clear();
        let score = scorer.score(&call, &default_ctx()).unwrap();
        assert!(score.redundancy.abs() < f32::EPSILON);
    }

    #[test]
    fn cost_increases_with_token_consumption() {
        let scorer = UtilityScorer::new(default_config());
        let call = make_call("bash", json!({}));
        let ctx_low = UtilityContext {
            tokens_consumed: 100,
            token_budget: 1000,
            ..default_ctx()
        };
        let ctx_high = UtilityContext {
            tokens_consumed: 900,
            token_budget: 1000,
            ..default_ctx()
        };
        let s_low = scorer.score(&call, &ctx_low).unwrap();
        let s_high = scorer.score(&call, &ctx_high).unwrap();
        assert!(s_low.cost < s_high.cost);
        assert!(s_low.total > s_high.total);
    }

    #[test]
    fn uncertainty_decreases_with_call_count() {
        let scorer = UtilityScorer::new(default_config());
        let call = make_call("bash", json!({}));
        let ctx_early = UtilityContext {
            tool_calls_this_turn: 0,
            ..default_ctx()
        };
        let ctx_late = UtilityContext {
            tool_calls_this_turn: 9,
            ..default_ctx()
        };
        let s_early = scorer.score(&call, &ctx_early).unwrap();
        let s_late = scorer.score(&call, &ctx_late).unwrap();
        assert!(s_early.uncertainty > s_late.uncertainty);
    }

    #[test]
    fn memory_tool_has_higher_gain_than_scrape() {
        let scorer = UtilityScorer::new(default_config());
        let mem_call = make_call("memory_search", json!({}));
        let web_call = make_call("scrape", json!({}));
        let s_mem = scorer.score(&mem_call, &default_ctx()).unwrap();
        let s_web = scorer.score(&web_call, &default_ctx()).unwrap();
        assert!(s_mem.gain > s_web.gain);
    }

    #[test]
    fn zero_token_budget_zeroes_cost() {
        let scorer = UtilityScorer::new(default_config());
        let call = make_call("bash", json!({}));
        let ctx = UtilityContext {
            tokens_consumed: 500,
            token_budget: 0,
            ..default_ctx()
        };
        let s = scorer.score(&call, &ctx).unwrap();
        assert!(s.cost.abs() < f32::EPSILON);
    }

    #[test]
    fn validate_rejects_negative_weights() {
        let cfg = UtilityScoringConfig {
            enabled: true,
            gain_weight: -1.0,
            ..UtilityScoringConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn validate_rejects_nan_threshold() {
        let cfg = UtilityScoringConfig {
            enabled: true,
            threshold: f32::NAN,
            ..UtilityScoringConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn validate_accepts_default() {
        assert!(UtilityScoringConfig::default().validate().is_ok());
    }

    #[test]
    fn threshold_zero_all_calls_pass() {
        let scorer = UtilityScorer::new(UtilityScoringConfig {
            enabled: true,
            threshold: 0.0,
            ..UtilityScoringConfig::default()
        });
        let call = make_call("bash", json!({}));
        let score = scorer.score(&call, &default_ctx()).unwrap();
        assert!(
            score.total >= 0.0,
            "total should be non-negative: {}",
            score.total
        );
        let action = scorer.recommend_action(Some(&score), &default_ctx());
        assert!(
            action == UtilityAction::ToolCall || action == UtilityAction::Retrieve,
            "threshold=0 should not block calls, got {action:?}",
        );
    }

    #[test]
    fn threshold_one_blocks_all_calls() {
        let scorer = UtilityScorer::new(UtilityScoringConfig {
            enabled: true,
            threshold: 1.0,
            ..UtilityScoringConfig::default()
        });
        let call = make_call("bash", json!({}));
        let score = scorer.score(&call, &default_ctx()).unwrap();
        assert!(
            score.total < 1.0,
            "realistic score should be below 1.0: {}",
            score.total
        );
        assert_ne!(
            scorer.recommend_action(Some(&score), &default_ctx()),
            UtilityAction::ToolCall
        );
    }

    #[test]
    fn recommend_action_user_requested_always_tool_call() {
        let scorer = UtilityScorer::new(default_config());
        let score = UtilityScore {
            gain: 0.0,
            cost: 1.0,
            redundancy: 1.0,
            uncertainty: 0.0,
            total: -100.0,
        };
        let ctx = UtilityContext {
            user_requested: true,
            ..default_ctx()
        };
        assert_eq!(
            scorer.recommend_action(Some(&score), &ctx),
            UtilityAction::ToolCall
        );
    }

    #[test]
    fn recommend_action_disabled_scorer_always_tool_call() {
        let scorer = UtilityScorer::new(UtilityScoringConfig::default());
        let ctx = default_ctx();
        assert_eq!(scorer.recommend_action(None, &ctx), UtilityAction::ToolCall);
    }

    #[test]
    fn recommend_action_none_score_enabled_stops() {
        let scorer = UtilityScorer::new(default_config());
        let ctx = default_ctx();
        assert_eq!(scorer.recommend_action(None, &ctx), UtilityAction::Stop);
    }

    #[test]
    fn recommend_action_budget_exhausted_stops() {
        let scorer = UtilityScorer::new(default_config());
        let score = UtilityScore {
            gain: 0.8,
            cost: 0.95,
            redundancy: 0.0,
            uncertainty: 0.5,
            total: 0.5,
        };
        assert_eq!(
            scorer.recommend_action(Some(&score), &default_ctx()),
            UtilityAction::Stop
        );
    }

    #[test]
    fn recommend_action_redundant_responds() {
        let scorer = UtilityScorer::new(default_config());
        let score = UtilityScore {
            gain: 0.8,
            cost: 0.1,
            redundancy: 1.0,
            uncertainty: 0.5,
            total: 0.5,
        };
        assert_eq!(
            scorer.recommend_action(Some(&score), &default_ctx()),
            UtilityAction::Respond
        );
    }

    #[test]
    fn recommend_action_high_gain_above_threshold_tool_call() {
        let scorer = UtilityScorer::new(default_config());
        let score = UtilityScore {
            gain: 0.8,
            cost: 0.1,
            redundancy: 0.0,
            uncertainty: 0.4,
            total: 0.6,
        };
        assert_eq!(
            scorer.recommend_action(Some(&score), &default_ctx()),
            UtilityAction::ToolCall
        );
    }

    #[test]
    fn recommend_action_uncertain_retrieves() {
        let scorer = UtilityScorer::new(default_config());
        let score = UtilityScore {
            gain: 0.6,
            cost: 0.1,
            redundancy: 0.0,
            uncertainty: 0.8,
            total: 0.4,
        };
        assert_eq!(
            scorer.recommend_action(Some(&score), &default_ctx()),
            UtilityAction::Retrieve
        );
    }

    #[test]
    fn recommend_action_below_threshold_with_prior_calls_verifies() {
        let scorer = UtilityScorer::new(default_config());
        let score = UtilityScore {
            gain: 0.3,
            cost: 0.1,
            redundancy: 0.0,
            uncertainty: 0.2,
            total: 0.05,
        };
        let ctx = UtilityContext {
            tool_calls_this_turn: 1,
            ..default_ctx()
        };
        assert_eq!(
            scorer.recommend_action(Some(&score), &ctx),
            UtilityAction::Verify
        );
    }

    #[test]
    fn recommend_action_default_responds() {
        let scorer = UtilityScorer::new(default_config());
        let score = UtilityScore {
            gain: 0.3,
            cost: 0.1,
            redundancy: 0.0,
            uncertainty: 0.2,
            total: 0.05,
        };
        let ctx = UtilityContext {
            tool_calls_this_turn: 0,
            ..default_ctx()
        };
        assert_eq!(
            scorer.recommend_action(Some(&score), &ctx),
            UtilityAction::Respond
        );
    }

    #[test]
    fn is_valid_rejects_non_finite_components() {
        let ok = UtilityScore {
            gain: 0.5,
            cost: 0.1,
            redundancy: 0.0,
            uncertainty: 1.0,
            total: 0.4,
        };
        assert!(ok.is_valid());
        assert!(!UtilityScore {
            total: f32::NAN,
            ..ok.clone()
        }
        .is_valid());
        assert!(!UtilityScore {
            cost: f32::INFINITY,
            ..ok.clone()
        }
        .is_valid());
    }

    #[test]
    fn explicit_request_using_a_tool() {
        assert!(has_explicit_tool_request(
            "Please list the files in the current directory using a tool"
        ));
    }

    #[test]
    fn explicit_request_call_the_tool() {
        assert!(has_explicit_tool_request("call the list_directory tool"));
    }

    #[test]
    fn explicit_request_use_the_tool() {
        assert!(has_explicit_tool_request("use the shell tool to run ls"));
    }

    #[test]
    fn explicit_request_run_the_tool() {
        assert!(has_explicit_tool_request("run the bash tool"));
    }

    #[test]
    fn explicit_request_invoke_the_tool() {
        assert!(has_explicit_tool_request("invoke the search_code tool"));
    }

    #[test]
    fn explicit_request_execute_the_tool() {
        assert!(has_explicit_tool_request("execute the grep tool for me"));
    }

    #[test]
    fn explicit_request_case_insensitive() {
        assert!(has_explicit_tool_request("USING A TOOL to find files"));
    }

    #[test]
    fn explicit_request_no_match_plain_message() {
        assert!(!has_explicit_tool_request("what is the weather today?"));
    }

    #[test]
    fn explicit_request_no_match_tool_mentioned_without_invocation() {
        assert!(!has_explicit_tool_request(
            "the shell tool is very useful in general"
        ));
    }

    #[test]
    fn explicit_request_show_me_result_of() {
        assert!(has_explicit_tool_request(
            "show me the result of: echo hello"
        ));
    }

    #[test]
    fn explicit_request_run_colon() {
        assert!(has_explicit_tool_request("run: echo hello"));
    }

    #[test]
    fn explicit_request_execute_colon() {
        assert!(has_explicit_tool_request("execute: ls -la"));
    }

    #[test]
    fn explicit_request_what_does() {
        assert!(has_explicit_tool_request("what does echo hello output?"));
    }

    #[test]
    fn explicit_request_what_would() {
        assert!(has_explicit_tool_request("what would cat /etc/hosts show?"));
    }

    #[test]
    fn explicit_request_what_is_the_output_of() {
        assert!(has_explicit_tool_request(
            "what is the output of ls | grep foo?"
        ));
    }

    #[test]
    fn explicit_request_inline_code_pipe() {
        assert!(has_explicit_tool_request("try running `ls | grep foo`"));
    }

    #[test]
    fn explicit_request_inline_code_redirect() {
        assert!(has_explicit_tool_request("run `echo hello > /tmp/out`"));
    }

    #[test]
    fn explicit_request_inline_code_dollar() {
        assert!(has_explicit_tool_request("check `$HOME/bin`"));
    }

    #[test]
    fn explicit_request_inline_code_and() {
        assert!(has_explicit_tool_request("try `git fetch && git rebase`"));
    }

    #[test]
    fn no_match_run_the_tests() {
        assert!(!has_explicit_tool_request("run the tests please"));
    }

    #[test]
    fn no_match_execute_the_plan() {
        assert!(!has_explicit_tool_request("execute the plan we discussed"));
    }

    #[test]
    fn no_match_inline_code_no_shell_syntax() {
        assert!(!has_explicit_tool_request(
            "the function `process_items` handles it"
        ));
    }

    #[test]
    fn known_fp_what_does_function_do() {
        assert!(has_explicit_tool_request("what does this function do?"));
    }

    #[test]
    fn no_match_show_me_result_without_colon() {
        assert!(!has_explicit_tool_request(
            "show me the result of running it"
        ));
    }

    #[test]
    fn is_exempt_matches_case_insensitively() {
        let scorer = UtilityScorer::new(UtilityScoringConfig {
            enabled: true,
            exempt_tools: vec!["Read".to_owned(), "file_read".to_owned()],
            ..UtilityScoringConfig::default()
        });
        assert!(scorer.is_exempt("read"));
        assert!(scorer.is_exempt("READ"));
        assert!(scorer.is_exempt("FILE_READ"));
        assert!(!scorer.is_exempt("write"));
        assert!(!scorer.is_exempt("bash"));
    }

    #[test]
    fn is_exempt_empty_list_returns_false() {
        let scorer = UtilityScorer::new(UtilityScoringConfig::default());
        assert!(!scorer.is_exempt("read"));
    }
}