1use std::collections::HashMap;
10use std::hash::{DefaultHasher, Hash, Hasher};
11use std::sync::LazyLock;
12
13use regex::Regex;
14
15use crate::config::UtilityScoringConfig;
16use crate::executor::ToolCall;
17
18#[must_use]
27pub fn has_explicit_tool_request(user_message: &str) -> bool {
28 static RE: LazyLock<Regex> = LazyLock::new(|| {
29 Regex::new(
30 r"(?xi)
31 using\s+a\s+tool
32 | call\s+(the\s+)?[a-z_]+\s+tool
33 | use\s+(the\s+)?[a-z_]+\s+tool
34 | run\s+(the\s+)?[a-z_]+\s+tool
35 | invoke\s+(the\s+)?[a-z_]+\s+tool
36 | execute\s+(the\s+)?[a-z_]+\s+tool
37 ",
38 )
39 .expect("static regex is valid")
40 });
41 RE.is_match(user_message)
42}
43
/// Prior estimate of how useful a call to `tool_name` is likely to be.
///
/// The name is matched case-sensitively, first by prefix and then exactly:
/// - any `memory*` tool scores 0.8;
/// - any `mcp_*` tool scores 0.5;
/// - a handful of built-in tools have fixed priors;
/// - everything else falls back to 0.5.
fn default_gain(tool_name: &str) -> f32 {
    match tool_name {
        // Prefix families are checked before the exact names below.
        name if name.starts_with("memory") => 0.8,
        name if name.starts_with("mcp_") => 0.5,
        "bash" | "shell" => 0.6,
        "read" | "write" => 0.55,
        "search_code" | "grep" | "glob" => 0.65,
        _ => 0.5,
    }
}
62
/// Component breakdown of the utility estimate for one candidate tool call.
///
/// Values returned by `UtilityScorer::score` always have every component finite.
#[derive(Debug, Clone)]
pub struct UtilityScore {
    /// Per-tool prior benefit of running the call.
    pub gain: f32,
    /// Share of the token budget already consumed, clamped to `[0, 1]`.
    pub cost: f32,
    /// `1.0` when an identical call has already been recorded, else `0.0`.
    pub redundancy: f32,
    /// Falls from `1.0` towards `0.0` as more tool calls are made in the turn.
    pub uncertainty: f32,
    /// Weighted combination of the components above.
    pub total: f32,
}

impl UtilityScore {
    /// Reports whether no component is NaN or infinite.
    fn is_valid(&self) -> bool {
        let components = [
            self.gain,
            self.cost,
            self.redundancy,
            self.uncertainty,
            self.total,
        ];
        components.iter().all(|c| c.is_finite())
    }
}
88
/// Per-turn inputs the scorer uses in addition to the tool call itself.
#[derive(Debug, Clone)]
pub struct UtilityContext {
    /// Number of tool calls already made in this turn; lowers `uncertainty`
    /// (which reaches 0 after ten calls) and enables the `Verify` recommendation.
    pub tool_calls_this_turn: usize,
    /// Tokens consumed so far; divided by `token_budget` to obtain `cost`.
    pub tokens_consumed: usize,
    /// Token budget for the turn; `0` makes the cost component zero.
    pub token_budget: usize,
    /// When `true`, `recommend_action` always returns `ToolCall`.
    /// NOTE(review): presumably set from `has_explicit_tool_request` — confirm with caller.
    pub user_requested: bool,
}
103
/// Next step recommended by `UtilityScorer::recommend_action`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UtilityAction {
    /// Answer without running the tool (e.g. the identical call was already made,
    /// or the score is below threshold on the first call of the turn).
    Respond,
    /// Recommended for a moderately useful call while uncertainty is still high.
    Retrieve,
    /// Execute the tool call.
    ToolCall,
    /// Recommended when the score is below threshold after earlier calls this turn.
    Verify,
    /// Stop making tool calls (no usable score, or the token budget is nearly spent).
    Stop,
}
118
119fn call_hash(call: &ToolCall) -> u64 {
121 let mut h = DefaultHasher::new();
122 call.tool_id.hash(&mut h);
123 format!("{:?}", call.params).hash(&mut h);
127 h.finish()
128}
129
130#[derive(Debug)]
135pub struct UtilityScorer {
136 config: UtilityScoringConfig,
137 recent_calls: HashMap<u64, u32>,
139}
140
141impl UtilityScorer {
142 #[must_use]
144 pub fn new(config: UtilityScoringConfig) -> Self {
145 Self {
146 config,
147 recent_calls: HashMap::new(),
148 }
149 }
150
151 #[must_use]
153 pub fn is_enabled(&self) -> bool {
154 self.config.enabled
155 }
156
157 #[must_use]
163 pub fn score(&self, call: &ToolCall, ctx: &UtilityContext) -> Option<UtilityScore> {
164 if !self.config.enabled {
165 return None;
166 }
167
168 let gain = default_gain(call.tool_id.as_str());
169
170 let cost = if ctx.token_budget > 0 {
171 #[allow(clippy::cast_precision_loss)]
172 (ctx.tokens_consumed as f32 / ctx.token_budget as f32).clamp(0.0, 1.0)
173 } else {
174 0.0
175 };
176
177 let hash = call_hash(call);
178 let redundancy = if self.recent_calls.contains_key(&hash) {
179 1.0_f32
180 } else {
181 0.0_f32
182 };
183
184 #[allow(clippy::cast_precision_loss)]
187 let uncertainty = (1.0_f32 - ctx.tool_calls_this_turn as f32 / 10.0).clamp(0.0, 1.0);
188
189 let total = self.config.gain_weight * gain
190 - self.config.cost_weight * cost
191 - self.config.redundancy_weight * redundancy
192 + self.config.uncertainty_bonus * uncertainty;
193
194 let score = UtilityScore {
195 gain,
196 cost,
197 redundancy,
198 uncertainty,
199 total,
200 };
201
202 if score.is_valid() { Some(score) } else { None }
203 }
204
205 #[must_use]
219 pub fn recommend_action(
220 &self,
221 score: Option<&UtilityScore>,
222 ctx: &UtilityContext,
223 ) -> UtilityAction {
224 if ctx.user_requested {
226 return UtilityAction::ToolCall;
227 }
228 if !self.config.enabled {
230 return UtilityAction::ToolCall;
231 }
232 let Some(s) = score else {
233 return UtilityAction::Stop;
235 };
236
237 if s.cost > 0.9 {
239 return UtilityAction::Stop;
240 }
241 if s.redundancy >= 1.0 {
243 return UtilityAction::Respond;
244 }
245 if s.gain >= 0.7 && s.total >= self.config.threshold {
247 return UtilityAction::ToolCall;
248 }
249 if s.gain >= 0.5 && s.uncertainty > 0.5 {
251 return UtilityAction::Retrieve;
252 }
253 if s.total < self.config.threshold && ctx.tool_calls_this_turn > 0 {
255 return UtilityAction::Verify;
256 }
257 if s.total >= self.config.threshold {
259 return UtilityAction::ToolCall;
260 }
261 UtilityAction::Respond
262 }
263
264 pub fn record_call(&mut self, call: &ToolCall) {
269 let hash = call_hash(call);
270 *self.recent_calls.entry(hash).or_insert(0) += 1;
271 }
272
273 pub fn clear(&mut self) {
275 self.recent_calls.clear();
276 }
277
278 #[must_use]
282 pub fn is_exempt(&self, tool_name: &str) -> bool {
283 let lower = tool_name.to_lowercase();
284 self.config
285 .exempt_tools
286 .iter()
287 .any(|e| e.to_lowercase() == lower)
288 }
289
290 #[must_use]
292 pub fn threshold(&self) -> f32 {
293 self.config.threshold
294 }
295}
296
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ToolName;
    use serde_json::json;

    /// Builds a `ToolCall` for `name`; a non-object `params` becomes an empty map.
    fn make_call(name: &str, params: serde_json::Value) -> ToolCall {
        ToolCall {
            tool_id: ToolName::new(name),
            params: if let serde_json::Value::Object(m) = params {
                m
            } else {
                serde_json::Map::new()
            },
            caller_id: None,
        }
    }

    /// Fresh turn: no prior calls, nothing consumed of a 1000-token budget, not user-requested.
    fn default_ctx() -> UtilityContext {
        UtilityContext {
            tool_calls_this_turn: 0,
            tokens_consumed: 0,
            token_budget: 1000,
            user_requested: false,
        }
    }

    /// Default configuration with scoring switched on (the plain default is disabled,
    /// as `disabled_returns_none` asserts).
    fn default_config() -> UtilityScoringConfig {
        UtilityScoringConfig {
            enabled: true,
            ..UtilityScoringConfig::default()
        }
    }

    #[test]
    fn disabled_returns_none() {
        let scorer = UtilityScorer::new(UtilityScoringConfig::default());
        assert!(!scorer.is_enabled());
        let call = make_call("bash", json!({}));
        let score = scorer.score(&call, &default_ctx());
        assert!(score.is_none());
        // A disabled scorer must never block execution.
        assert_eq!(
            scorer.recommend_action(score.as_ref(), &default_ctx()),
            UtilityAction::ToolCall
        );
    }

    #[test]
    fn first_call_passes_default_threshold() {
        let scorer = UtilityScorer::new(default_config());
        let call = make_call("bash", json!({"cmd": "ls"}));
        let score = scorer.score(&call, &default_ctx());
        assert!(score.is_some());
        let s = score.unwrap();
        assert!(
            s.total >= 0.1,
            "first call should exceed threshold: {}",
            s.total
        );
        let action = scorer.recommend_action(Some(&s), &default_ctx());
        // High uncertainty on the first call may route to Retrieve; both count as "not blocked".
        assert!(
            action == UtilityAction::ToolCall || action == UtilityAction::Retrieve,
            "first call should not be blocked, got {action:?}",
        );
    }

    #[test]
    fn redundant_call_penalized() {
        let mut scorer = UtilityScorer::new(default_config());
        let call = make_call("bash", json!({"cmd": "ls"}));
        scorer.record_call(&call);
        let score = scorer.score(&call, &default_ctx()).unwrap();
        assert!((score.redundancy - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn clear_resets_redundancy() {
        let mut scorer = UtilityScorer::new(default_config());
        let call = make_call("bash", json!({"cmd": "ls"}));
        scorer.record_call(&call);
        scorer.clear();
        let score = scorer.score(&call, &default_ctx()).unwrap();
        assert!(score.redundancy.abs() < f32::EPSILON);
    }

    #[test]
    fn user_requested_always_executes() {
        let scorer = UtilityScorer::new(default_config());
        // Worst possible score: the explicit user request must still win.
        let score = UtilityScore {
            gain: 0.0,
            cost: 1.0,
            redundancy: 1.0,
            uncertainty: 0.0,
            total: -100.0,
        };
        let ctx = UtilityContext {
            user_requested: true,
            ..default_ctx()
        };
        assert_eq!(
            scorer.recommend_action(Some(&score), &ctx),
            UtilityAction::ToolCall
        );
    }

    #[test]
    fn none_score_fail_closed_when_enabled() {
        let scorer = UtilityScorer::new(default_config());
        assert_eq!(
            scorer.recommend_action(None, &default_ctx()),
            UtilityAction::Stop
        );
    }

    #[test]
    fn none_score_executes_when_disabled() {
        let scorer = UtilityScorer::new(UtilityScoringConfig::default());
        assert_eq!(
            scorer.recommend_action(None, &default_ctx()),
            UtilityAction::ToolCall
        );
    }

    #[test]
    fn cost_increases_with_token_consumption() {
        let scorer = UtilityScorer::new(default_config());
        let call = make_call("bash", json!({}));
        let ctx_low = UtilityContext {
            tokens_consumed: 100,
            token_budget: 1000,
            ..default_ctx()
        };
        let ctx_high = UtilityContext {
            tokens_consumed: 900,
            token_budget: 1000,
            ..default_ctx()
        };
        let s_low = scorer.score(&call, &ctx_low).unwrap();
        let s_high = scorer.score(&call, &ctx_high).unwrap();
        assert!(s_low.cost < s_high.cost);
        assert!(s_low.total > s_high.total);
    }

    #[test]
    fn uncertainty_decreases_with_call_count() {
        let scorer = UtilityScorer::new(default_config());
        let call = make_call("bash", json!({}));
        let ctx_early = UtilityContext {
            tool_calls_this_turn: 0,
            ..default_ctx()
        };
        let ctx_late = UtilityContext {
            tool_calls_this_turn: 9,
            ..default_ctx()
        };
        let s_early = scorer.score(&call, &ctx_early).unwrap();
        let s_late = scorer.score(&call, &ctx_late).unwrap();
        assert!(s_early.uncertainty > s_late.uncertainty);
    }

    #[test]
    fn memory_tool_has_higher_gain_than_scrape() {
        let scorer = UtilityScorer::new(default_config());
        let mem_call = make_call("memory_search", json!({}));
        let web_call = make_call("scrape", json!({}));
        let s_mem = scorer.score(&mem_call, &default_ctx()).unwrap();
        let s_web = scorer.score(&web_call, &default_ctx()).unwrap();
        assert!(s_mem.gain > s_web.gain);
    }

    #[test]
    fn zero_token_budget_zeroes_cost() {
        let scorer = UtilityScorer::new(default_config());
        let call = make_call("bash", json!({}));
        let ctx = UtilityContext {
            tokens_consumed: 500,
            token_budget: 0,
            ..default_ctx()
        };
        let s = scorer.score(&call, &ctx).unwrap();
        assert!(s.cost.abs() < f32::EPSILON);
    }

    #[test]
    fn validate_rejects_negative_weights() {
        let cfg = UtilityScoringConfig {
            enabled: true,
            gain_weight: -1.0,
            ..UtilityScoringConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn validate_rejects_nan_weights() {
        let cfg = UtilityScoringConfig {
            enabled: true,
            threshold: f32::NAN,
            ..UtilityScoringConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn validate_accepts_default() {
        assert!(UtilityScoringConfig::default().validate().is_ok());
    }

    #[test]
    fn threshold_zero_all_calls_pass() {
        let scorer = UtilityScorer::new(UtilityScoringConfig {
            enabled: true,
            threshold: 0.0,
            ..UtilityScoringConfig::default()
        });
        let call = make_call("bash", json!({}));
        let score = scorer.score(&call, &default_ctx()).unwrap();
        assert!(
            score.total >= 0.0,
            "total should be non-negative: {}",
            score.total
        );
        let action = scorer.recommend_action(Some(&score), &default_ctx());
        assert!(
            action == UtilityAction::ToolCall || action == UtilityAction::Retrieve,
            "threshold=0 should not block calls, got {action:?}",
        );
    }

    #[test]
    fn threshold_one_blocks_all_calls() {
        let scorer = UtilityScorer::new(UtilityScoringConfig {
            enabled: true,
            threshold: 1.0,
            ..UtilityScoringConfig::default()
        });
        let call = make_call("bash", json!({}));
        let score = scorer.score(&call, &default_ctx()).unwrap();
        assert!(
            score.total < 1.0,
            "realistic score should be below 1.0: {}",
            score.total
        );
        assert_ne!(
            scorer.recommend_action(Some(&score), &default_ctx()),
            UtilityAction::ToolCall
        );
    }

    // The `recommend_action_*` tests below walk the decision rules in order; a few
    // overlap with the scenario tests above on purpose.

    #[test]
    fn recommend_action_user_requested_always_tool_call() {
        let scorer = UtilityScorer::new(default_config());
        let score = UtilityScore {
            gain: 0.0,
            cost: 1.0,
            redundancy: 1.0,
            uncertainty: 0.0,
            total: -100.0,
        };
        let ctx = UtilityContext {
            user_requested: true,
            ..default_ctx()
        };
        assert_eq!(
            scorer.recommend_action(Some(&score), &ctx),
            UtilityAction::ToolCall
        );
    }

    #[test]
    fn recommend_action_disabled_scorer_always_tool_call() {
        let scorer = UtilityScorer::new(UtilityScoringConfig::default());
        let ctx = default_ctx();
        assert_eq!(scorer.recommend_action(None, &ctx), UtilityAction::ToolCall);
    }

    #[test]
    fn recommend_action_none_score_enabled_stops() {
        let scorer = UtilityScorer::new(default_config());
        let ctx = default_ctx();
        assert_eq!(scorer.recommend_action(None, &ctx), UtilityAction::Stop);
    }

    #[test]
    fn recommend_action_budget_exhausted_stops() {
        let scorer = UtilityScorer::new(default_config());
        let score = UtilityScore {
            gain: 0.8,
            cost: 0.95,
            redundancy: 0.0,
            uncertainty: 0.5,
            total: 0.5,
        };
        assert_eq!(
            scorer.recommend_action(Some(&score), &default_ctx()),
            UtilityAction::Stop
        );
    }

    #[test]
    fn recommend_action_redundant_responds() {
        let scorer = UtilityScorer::new(default_config());
        let score = UtilityScore {
            gain: 0.8,
            cost: 0.1,
            redundancy: 1.0,
            uncertainty: 0.5,
            total: 0.5,
        };
        assert_eq!(
            scorer.recommend_action(Some(&score), &default_ctx()),
            UtilityAction::Respond
        );
    }

    #[test]
    fn recommend_action_high_gain_above_threshold_tool_call() {
        let scorer = UtilityScorer::new(default_config());
        let score = UtilityScore {
            gain: 0.8,
            cost: 0.1,
            redundancy: 0.0,
            uncertainty: 0.4,
            total: 0.6,
        };
        assert_eq!(
            scorer.recommend_action(Some(&score), &default_ctx()),
            UtilityAction::ToolCall
        );
    }

    #[test]
    fn recommend_action_uncertain_retrieves() {
        let scorer = UtilityScorer::new(default_config());
        // Gain below the 0.7 fast path but uncertainty still high.
        let score = UtilityScore {
            gain: 0.6,
            cost: 0.1,
            redundancy: 0.0,
            uncertainty: 0.8,
            total: 0.4,
        };
        assert_eq!(
            scorer.recommend_action(Some(&score), &default_ctx()),
            UtilityAction::Retrieve
        );
    }

    #[test]
    fn recommend_action_below_threshold_with_prior_calls_verifies() {
        let scorer = UtilityScorer::new(default_config());
        let score = UtilityScore {
            gain: 0.3,
            cost: 0.1,
            redundancy: 0.0,
            uncertainty: 0.2,
            total: 0.05,
        };
        let ctx = UtilityContext {
            tool_calls_this_turn: 1,
            ..default_ctx()
        };
        assert_eq!(
            scorer.recommend_action(Some(&score), &ctx),
            UtilityAction::Verify
        );
    }

    #[test]
    fn recommend_action_default_responds() {
        let scorer = UtilityScorer::new(default_config());
        // Same low score as the Verify test, but with no prior calls this turn.
        let score = UtilityScore {
            gain: 0.3,
            cost: 0.1,
            redundancy: 0.0,
            uncertainty: 0.2,
            total: 0.05,
        };
        let ctx = UtilityContext {
            tool_calls_this_turn: 0,
            ..default_ctx()
        };
        assert_eq!(
            scorer.recommend_action(Some(&score), &ctx),
            UtilityAction::Respond
        );
    }

    #[test]
    fn explicit_request_using_a_tool() {
        assert!(has_explicit_tool_request(
            "Please list the files in the current directory using a tool"
        ));
    }

    #[test]
    fn explicit_request_call_the_tool() {
        assert!(has_explicit_tool_request("call the list_directory tool"));
    }

    #[test]
    fn explicit_request_use_the_tool() {
        assert!(has_explicit_tool_request("use the shell tool to run ls"));
    }

    #[test]
    fn explicit_request_run_the_tool() {
        assert!(has_explicit_tool_request("run the bash tool"));
    }

    #[test]
    fn explicit_request_invoke_the_tool() {
        assert!(has_explicit_tool_request("invoke the search_code tool"));
    }

    #[test]
    fn explicit_request_execute_the_tool() {
        assert!(has_explicit_tool_request("execute the grep tool for me"));
    }

    #[test]
    fn explicit_request_case_insensitive() {
        assert!(has_explicit_tool_request("USING A TOOL to find files"));
    }

    #[test]
    fn explicit_request_no_match_plain_message() {
        assert!(!has_explicit_tool_request("what is the weather today?"));
    }

    #[test]
    fn explicit_request_no_match_tool_mentioned_without_invocation() {
        assert!(!has_explicit_tool_request(
            "the shell tool is very useful in general"
        ));
    }

    #[test]
    fn is_exempt_matches_case_insensitively() {
        let scorer = UtilityScorer::new(UtilityScoringConfig {
            enabled: true,
            exempt_tools: vec!["Read".to_owned(), "file_read".to_owned()],
            ..UtilityScoringConfig::default()
        });
        assert!(scorer.is_exempt("read"));
        assert!(scorer.is_exempt("READ"));
        assert!(scorer.is_exempt("FILE_READ"));
        assert!(!scorer.is_exempt("write"));
        assert!(!scorer.is_exempt("bash"));
    }

    #[test]
    fn is_exempt_empty_list_returns_false() {
        let scorer = UtilityScorer::new(UtilityScoringConfig::default());
        assert!(!scorer.is_exempt("read"));
    }
}