1use std::collections::HashMap;
10use std::hash::{DefaultHasher, Hash, Hasher};
11use std::sync::LazyLock;
12
13use regex::Regex;
14
15use crate::config::UtilityScoringConfig;
16use crate::executor::ToolCall;
17
18#[must_use]
27pub fn has_explicit_tool_request(user_message: &str) -> bool {
28 static RE: LazyLock<Regex> = LazyLock::new(|| {
29 Regex::new(
30 r"(?xi)
31 using\s+a\s+tool
32 | call\s+(the\s+)?[a-z_]+\s+tool
33 | use\s+(the\s+)?[a-z_]+\s+tool
34 | run\s+(the\s+)?[a-z_]+\s+tool
35 | invoke\s+(the\s+)?[a-z_]+\s+tool
36 | execute\s+(the\s+)?[a-z_]+\s+tool
37 | show\s+me\s+the\s+result\s+of\s*:
38 | run\s*:
39 | execute\s*:
40 | what\s+(does|would|is\s+the\s+output\s+of)
41 ",
42 )
43 .expect("static regex is valid")
44 });
45 static RE_CODE: LazyLock<Regex> =
48 LazyLock::new(|| Regex::new(r"`[^`]*[|><$;&][^`]*`").expect("static regex is valid"));
49 RE.is_match(user_message) || RE_CODE.is_match(user_message)
50}
51
/// Returns the built-in gain estimate for a tool, keyed on its name.
///
/// Prefix rules are checked first: names starting with `memory` yield `0.8` and names
/// starting with `mcp_` yield `0.5`. Otherwise a fixed table of exact names applies, and
/// any name not listed falls back to `0.5`. Matching is case-sensitive.
fn default_gain(tool_name: &str) -> f32 {
    match tool_name {
        name if name.starts_with("memory") => 0.8,
        name if name.starts_with("mcp_") => 0.5,
        "search_code" | "grep" | "glob" => 0.65,
        "bash" | "shell" => 0.6,
        "read" | "write" => 0.55,
        _ => 0.5,
    }
}
70
/// Per-component utility estimate for one candidate tool call.
///
/// The field descriptions below state how `UtilityScorer::score` fills them in; the
/// fields are public, so callers may also construct a score by hand.
#[derive(Debug, Clone)]
pub struct UtilityScore {
    /// Name-based gain estimate for the tool.
    pub gain: f32,
    /// Fraction of the token budget already consumed, clamped to `[0, 1]`
    /// (`0.0` when the budget is zero).
    pub cost: f32,
    /// `1.0` when an identical call has already been recorded, otherwise `0.0`.
    pub redundancy: f32,
    /// `1 - tool_calls_this_turn / 10`, clamped to `[0, 1]`.
    pub uncertainty: f32,
    /// Weighted combination: gain and uncertainty add, cost and redundancy subtract.
    pub total: f32,
}

impl UtilityScore {
    /// Returns `true` only when every component, `total` included, is finite
    /// (neither NaN nor an infinity).
    fn is_valid(&self) -> bool {
        [
            self.gain,
            self.cost,
            self.redundancy,
            self.uncertainty,
            self.total,
        ]
        .iter()
        .all(|component| component.is_finite())
    }
}
96
/// Per-turn inputs consulted when scoring a tool call and choosing an action.
#[derive(Debug, Clone)]
pub struct UtilityContext {
    /// Count of tool calls in the current turn; the scorer's uncertainty term shrinks as
    /// this grows, and a non-zero value allows the `Verify` recommendation.
    pub tool_calls_this_turn: usize,
    /// Tokens consumed so far; numerator of the cost ratio.
    pub tokens_consumed: usize,
    /// Token budget; denominator of the cost ratio. A value of `0` makes the cost `0.0`.
    pub token_budget: usize,
    /// When `true`, `UtilityScorer::recommend_action` returns `ToolCall` without
    /// looking at the score.
    pub user_requested: bool,
}
111
/// Action recommended by `UtilityScorer::recommend_action` for a candidate tool call.
///
/// The variant notes describe when the scorer returns each value; how a caller acts on
/// a recommendation is decided outside this type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UtilityAction {
    /// Returned for a redundant call, and as the fallback when no other rule applies.
    Respond,
    /// Returned when gain is at least `0.5` and uncertainty is above `0.5`, provided no
    /// earlier rule fired.
    Retrieve,
    /// Returned when the user requested the call, when scoring is disabled, or when the
    /// score meets the configured threshold.
    ToolCall,
    /// Returned when the score is below the threshold and at least one tool call has
    /// already happened this turn.
    Verify,
    /// Returned when scoring is enabled but no score is available, or when cost
    /// exceeds `0.9`.
    Stop,
}
126
127fn call_hash(call: &ToolCall) -> u64 {
129 let mut h = DefaultHasher::new();
130 call.tool_id.hash(&mut h);
131 format!("{:?}", call.params).hash(&mut h);
135 h.finish()
136}
137
138#[derive(Debug)]
143pub struct UtilityScorer {
144 config: UtilityScoringConfig,
145 recent_calls: HashMap<u64, u32>,
147}
148
149impl UtilityScorer {
150 #[must_use]
152 pub fn new(config: UtilityScoringConfig) -> Self {
153 Self {
154 config,
155 recent_calls: HashMap::new(),
156 }
157 }
158
159 #[must_use]
161 pub fn is_enabled(&self) -> bool {
162 self.config.enabled
163 }
164
165 #[must_use]
171 pub fn score(&self, call: &ToolCall, ctx: &UtilityContext) -> Option<UtilityScore> {
172 if !self.config.enabled {
173 return None;
174 }
175
176 let gain = default_gain(call.tool_id.as_str());
177
178 let cost = if ctx.token_budget > 0 {
179 #[allow(clippy::cast_precision_loss)]
180 (ctx.tokens_consumed as f32 / ctx.token_budget as f32).clamp(0.0, 1.0)
181 } else {
182 0.0
183 };
184
185 let hash = call_hash(call);
186 let redundancy = if self.recent_calls.contains_key(&hash) {
187 1.0_f32
188 } else {
189 0.0_f32
190 };
191
192 #[allow(clippy::cast_precision_loss)]
195 let uncertainty = (1.0_f32 - ctx.tool_calls_this_turn as f32 / 10.0).clamp(0.0, 1.0);
196
197 let total = self.config.gain_weight * gain
198 - self.config.cost_weight * cost
199 - self.config.redundancy_weight * redundancy
200 + self.config.uncertainty_bonus * uncertainty;
201
202 let score = UtilityScore {
203 gain,
204 cost,
205 redundancy,
206 uncertainty,
207 total,
208 };
209
210 if score.is_valid() { Some(score) } else { None }
211 }
212
213 #[must_use]
227 pub fn recommend_action(
228 &self,
229 score: Option<&UtilityScore>,
230 ctx: &UtilityContext,
231 ) -> UtilityAction {
232 if ctx.user_requested {
234 return UtilityAction::ToolCall;
235 }
236 if !self.config.enabled {
238 return UtilityAction::ToolCall;
239 }
240 let Some(s) = score else {
241 return UtilityAction::Stop;
243 };
244
245 if s.cost > 0.9 {
247 return UtilityAction::Stop;
248 }
249 if s.redundancy >= 1.0 {
251 return UtilityAction::Respond;
252 }
253 if s.gain >= 0.7 && s.total >= self.config.threshold {
255 return UtilityAction::ToolCall;
256 }
257 if s.gain >= 0.5 && s.uncertainty > 0.5 {
259 return UtilityAction::Retrieve;
260 }
261 if s.total < self.config.threshold && ctx.tool_calls_this_turn > 0 {
263 return UtilityAction::Verify;
264 }
265 if s.total >= self.config.threshold {
267 return UtilityAction::ToolCall;
268 }
269 UtilityAction::Respond
270 }
271
272 pub fn record_call(&mut self, call: &ToolCall) {
277 let hash = call_hash(call);
278 *self.recent_calls.entry(hash).or_insert(0) += 1;
279 }
280
281 pub fn clear(&mut self) {
283 self.recent_calls.clear();
284 }
285
286 #[must_use]
290 pub fn is_exempt(&self, tool_name: &str) -> bool {
291 let lower = tool_name.to_lowercase();
292 self.config
293 .exempt_tools
294 .iter()
295 .any(|e| e.to_lowercase() == lower)
296 }
297
298 #[must_use]
300 pub fn threshold(&self) -> f32 {
301 self.config.threshold
302 }
303}
304
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ToolName;
    use serde_json::json;

    /// Builds a `ToolCall` named `name`; a non-object `params` value becomes an empty map.
    fn make_call(name: &str, params: serde_json::Value) -> ToolCall {
        let params = match params {
            serde_json::Value::Object(map) => map,
            _ => serde_json::Map::new(),
        };
        ToolCall {
            tool_id: ToolName::new(name),
            params,
            caller_id: None,
        }
    }

    /// Fresh turn: no prior calls, nothing consumed, budget of 1000 tokens.
    fn default_ctx() -> UtilityContext {
        UtilityContext {
            tool_calls_this_turn: 0,
            tokens_consumed: 0,
            token_budget: 1000,
            user_requested: false,
        }
    }

    /// Default configuration with scoring switched on.
    fn default_config() -> UtilityScoringConfig {
        UtilityScoringConfig {
            enabled: true,
            ..UtilityScoringConfig::default()
        }
    }

    /// Scorer using `default_config()` (scoring enabled).
    fn enabled_scorer() -> UtilityScorer {
        UtilityScorer::new(default_config())
    }

    /// Scorer using the untouched default configuration, which leaves scoring disabled
    /// (asserted in `disabled_returns_none`).
    fn disabled_scorer() -> UtilityScorer {
        UtilityScorer::new(UtilityScoringConfig::default())
    }

    // ---- scoring ---------------------------------------------------------------------

    #[test]
    fn disabled_returns_none() {
        let scorer = disabled_scorer();
        assert!(!scorer.is_enabled());

        let ctx = default_ctx();
        let score = scorer.score(&make_call("bash", json!({})), &ctx);
        assert!(score.is_none());
        assert_eq!(
            scorer.recommend_action(score.as_ref(), &ctx),
            UtilityAction::ToolCall
        );
    }

    #[test]
    fn first_call_passes_default_threshold() {
        let scorer = enabled_scorer();
        let ctx = default_ctx();
        let score = scorer.score(&make_call("bash", json!({"cmd": "ls"})), &ctx);
        assert!(score.is_some());

        let s = score.unwrap();
        assert!(
            s.total >= 0.1,
            "first call should exceed threshold: {}",
            s.total
        );

        let action = scorer.recommend_action(Some(&s), &ctx);
        assert!(
            matches!(action, UtilityAction::ToolCall | UtilityAction::Retrieve),
            "first call should not be blocked, got {action:?}",
        );
    }

    #[test]
    fn redundant_call_penalized() {
        let mut scorer = enabled_scorer();
        let call = make_call("bash", json!({"cmd": "ls"}));
        scorer.record_call(&call);

        let redundancy = scorer.score(&call, &default_ctx()).unwrap().redundancy;
        assert!((redundancy - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn clear_resets_redundancy() {
        let mut scorer = enabled_scorer();
        let call = make_call("bash", json!({"cmd": "ls"}));
        scorer.record_call(&call);
        scorer.clear();

        let redundancy = scorer.score(&call, &default_ctx()).unwrap().redundancy;
        assert!(redundancy.abs() < f32::EPSILON);
    }

    #[test]
    fn user_requested_always_executes() {
        let scorer = enabled_scorer();
        // Deliberately terrible score: it must be ignored for user-requested calls.
        let worst = UtilityScore {
            gain: 0.0,
            cost: 1.0,
            redundancy: 1.0,
            uncertainty: 0.0,
            total: -100.0,
        };
        let requested = UtilityContext {
            user_requested: true,
            ..default_ctx()
        };
        assert_eq!(
            scorer.recommend_action(Some(&worst), &requested),
            UtilityAction::ToolCall
        );
    }

    #[test]
    fn none_score_fail_closed_when_enabled() {
        let action = enabled_scorer().recommend_action(None, &default_ctx());
        assert_eq!(action, UtilityAction::Stop);
    }

    #[test]
    fn none_score_executes_when_disabled() {
        let action = disabled_scorer().recommend_action(None, &default_ctx());
        assert_eq!(action, UtilityAction::ToolCall);
    }

    #[test]
    fn cost_increases_with_token_consumption() {
        let scorer = enabled_scorer();
        let call = make_call("bash", json!({}));
        let with_consumed = |tokens_consumed| UtilityContext {
            tokens_consumed,
            token_budget: 1000,
            ..default_ctx()
        };

        let s_low = scorer.score(&call, &with_consumed(100)).unwrap();
        let s_high = scorer.score(&call, &with_consumed(900)).unwrap();
        assert!(s_low.cost < s_high.cost);
        assert!(s_low.total > s_high.total);
    }

    #[test]
    fn uncertainty_decreases_with_call_count() {
        let scorer = enabled_scorer();
        let call = make_call("bash", json!({}));
        let after_calls = |tool_calls_this_turn| UtilityContext {
            tool_calls_this_turn,
            ..default_ctx()
        };

        let s_early = scorer.score(&call, &after_calls(0)).unwrap();
        let s_late = scorer.score(&call, &after_calls(9)).unwrap();
        assert!(s_early.uncertainty > s_late.uncertainty);
    }

    #[test]
    fn memory_tool_has_higher_gain_than_scrape() {
        let scorer = enabled_scorer();
        let ctx = default_ctx();
        let s_mem = scorer
            .score(&make_call("memory_search", json!({})), &ctx)
            .unwrap();
        let s_web = scorer.score(&make_call("scrape", json!({})), &ctx).unwrap();
        assert!(s_mem.gain > s_web.gain);
    }

    #[test]
    fn zero_token_budget_zeroes_cost() {
        let scorer = enabled_scorer();
        let no_budget = UtilityContext {
            tokens_consumed: 500,
            token_budget: 0,
            ..default_ctx()
        };
        let s = scorer
            .score(&make_call("bash", json!({})), &no_budget)
            .unwrap();
        assert!(s.cost.abs() < f32::EPSILON);
    }

    // ---- configuration validation ----------------------------------------------------

    #[test]
    fn validate_rejects_negative_weights() {
        let cfg = UtilityScoringConfig {
            enabled: true,
            gain_weight: -1.0,
            ..UtilityScoringConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn validate_rejects_nan_weights() {
        let cfg = UtilityScoringConfig {
            enabled: true,
            threshold: f32::NAN,
            ..UtilityScoringConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn validate_accepts_default() {
        let cfg = UtilityScoringConfig::default();
        assert!(cfg.validate().is_ok());
    }

    // ---- threshold extremes ----------------------------------------------------------

    #[test]
    fn threshold_zero_all_calls_pass() {
        let scorer = UtilityScorer::new(UtilityScoringConfig {
            enabled: true,
            threshold: 0.0,
            ..UtilityScoringConfig::default()
        });
        let ctx = default_ctx();
        let score = scorer.score(&make_call("bash", json!({})), &ctx).unwrap();
        assert!(
            score.total >= 0.0,
            "total should be non-negative: {}",
            score.total
        );

        let action = scorer.recommend_action(Some(&score), &ctx);
        assert!(
            matches!(action, UtilityAction::ToolCall | UtilityAction::Retrieve),
            "threshold=0 should not block calls, got {action:?}",
        );
    }

    #[test]
    fn threshold_one_blocks_all_calls() {
        let scorer = UtilityScorer::new(UtilityScoringConfig {
            enabled: true,
            threshold: 1.0,
            ..UtilityScoringConfig::default()
        });
        let ctx = default_ctx();
        let score = scorer.score(&make_call("bash", json!({})), &ctx).unwrap();
        assert!(
            score.total < 1.0,
            "realistic score should be below 1.0: {}",
            score.total
        );
        assert_ne!(
            scorer.recommend_action(Some(&score), &ctx),
            UtilityAction::ToolCall
        );
    }

    // ---- recommend_action, one test per rule -----------------------------------------

    #[test]
    fn recommend_action_user_requested_always_tool_call() {
        let scorer = enabled_scorer();
        let worst = UtilityScore {
            gain: 0.0,
            cost: 1.0,
            redundancy: 1.0,
            uncertainty: 0.0,
            total: -100.0,
        };
        let requested = UtilityContext {
            user_requested: true,
            ..default_ctx()
        };
        assert_eq!(
            scorer.recommend_action(Some(&worst), &requested),
            UtilityAction::ToolCall
        );
    }

    #[test]
    fn recommend_action_disabled_scorer_always_tool_call() {
        let scorer = disabled_scorer();
        let ctx = default_ctx();
        assert_eq!(scorer.recommend_action(None, &ctx), UtilityAction::ToolCall);
    }

    #[test]
    fn recommend_action_none_score_enabled_stops() {
        let scorer = enabled_scorer();
        let ctx = default_ctx();
        assert_eq!(scorer.recommend_action(None, &ctx), UtilityAction::Stop);
    }

    #[test]
    fn recommend_action_budget_exhausted_stops() {
        let scorer = enabled_scorer();
        // cost above 0.9 wins even though gain and total look good.
        let exhausted = UtilityScore {
            gain: 0.8,
            cost: 0.95,
            redundancy: 0.0,
            uncertainty: 0.5,
            total: 0.5,
        };
        assert_eq!(
            scorer.recommend_action(Some(&exhausted), &default_ctx()),
            UtilityAction::Stop
        );
    }

    #[test]
    fn recommend_action_redundant_responds() {
        let scorer = enabled_scorer();
        let repeated = UtilityScore {
            gain: 0.8,
            cost: 0.1,
            redundancy: 1.0,
            uncertainty: 0.5,
            total: 0.5,
        };
        assert_eq!(
            scorer.recommend_action(Some(&repeated), &default_ctx()),
            UtilityAction::Respond
        );
    }

    #[test]
    fn recommend_action_high_gain_above_threshold_tool_call() {
        let scorer = enabled_scorer();
        let strong = UtilityScore {
            gain: 0.8,
            cost: 0.1,
            redundancy: 0.0,
            uncertainty: 0.4,
            total: 0.6,
        };
        assert_eq!(
            scorer.recommend_action(Some(&strong), &default_ctx()),
            UtilityAction::ToolCall
        );
    }

    #[test]
    fn recommend_action_uncertain_retrieves() {
        let scorer = enabled_scorer();
        // gain in [0.5, 0.7) with uncertainty above 0.5 lands on the Retrieve rule.
        let uncertain = UtilityScore {
            gain: 0.6,
            cost: 0.1,
            redundancy: 0.0,
            uncertainty: 0.8,
            total: 0.4,
        };
        assert_eq!(
            scorer.recommend_action(Some(&uncertain), &default_ctx()),
            UtilityAction::Retrieve
        );
    }

    #[test]
    fn recommend_action_below_threshold_with_prior_calls_verifies() {
        let scorer = enabled_scorer();
        // total is expected to sit below the default threshold for this test to hold.
        let weak = UtilityScore {
            gain: 0.3,
            cost: 0.1,
            redundancy: 0.0,
            uncertainty: 0.2,
            total: 0.05,
        };
        let after_one_call = UtilityContext {
            tool_calls_this_turn: 1,
            ..default_ctx()
        };
        assert_eq!(
            scorer.recommend_action(Some(&weak), &after_one_call),
            UtilityAction::Verify
        );
    }

    #[test]
    fn recommend_action_default_responds() {
        let scorer = enabled_scorer();
        // Same weak score as the Verify test, but with no prior calls this turn.
        let weak = UtilityScore {
            gain: 0.3,
            cost: 0.1,
            redundancy: 0.0,
            uncertainty: 0.2,
            total: 0.05,
        };
        let first_call = UtilityContext {
            tool_calls_this_turn: 0,
            ..default_ctx()
        };
        assert_eq!(
            scorer.recommend_action(Some(&weak), &first_call),
            UtilityAction::Respond
        );
    }

    // ---- has_explicit_tool_request: phrase patterns ----------------------------------

    #[test]
    fn explicit_request_using_a_tool() {
        let msg = "Please list the files in the current directory using a tool";
        assert!(has_explicit_tool_request(msg));
    }

    #[test]
    fn explicit_request_call_the_tool() {
        let msg = "call the list_directory tool";
        assert!(has_explicit_tool_request(msg));
    }

    #[test]
    fn explicit_request_use_the_tool() {
        let msg = "use the shell tool to run ls";
        assert!(has_explicit_tool_request(msg));
    }

    #[test]
    fn explicit_request_run_the_tool() {
        let msg = "run the bash tool";
        assert!(has_explicit_tool_request(msg));
    }

    #[test]
    fn explicit_request_invoke_the_tool() {
        let msg = "invoke the search_code tool";
        assert!(has_explicit_tool_request(msg));
    }

    #[test]
    fn explicit_request_execute_the_tool() {
        let msg = "execute the grep tool for me";
        assert!(has_explicit_tool_request(msg));
    }

    #[test]
    fn explicit_request_case_insensitive() {
        let msg = "USING A TOOL to find files";
        assert!(has_explicit_tool_request(msg));
    }

    #[test]
    fn explicit_request_no_match_plain_message() {
        let msg = "what is the weather today?";
        assert!(!has_explicit_tool_request(msg));
    }

    #[test]
    fn explicit_request_no_match_tool_mentioned_without_invocation() {
        let msg = "the shell tool is very useful in general";
        assert!(!has_explicit_tool_request(msg));
    }

    #[test]
    fn explicit_request_show_me_result_of() {
        let msg = "show me the result of: echo hello";
        assert!(has_explicit_tool_request(msg));
    }

    #[test]
    fn explicit_request_run_colon() {
        let msg = "run: echo hello";
        assert!(has_explicit_tool_request(msg));
    }

    #[test]
    fn explicit_request_execute_colon() {
        let msg = "execute: ls -la";
        assert!(has_explicit_tool_request(msg));
    }

    #[test]
    fn explicit_request_what_does() {
        let msg = "what does echo hello output?";
        assert!(has_explicit_tool_request(msg));
    }

    #[test]
    fn explicit_request_what_would() {
        let msg = "what would cat /etc/hosts show?";
        assert!(has_explicit_tool_request(msg));
    }

    #[test]
    fn explicit_request_what_is_the_output_of() {
        let msg = "what is the output of ls | grep foo?";
        assert!(has_explicit_tool_request(msg));
    }

    // ---- has_explicit_tool_request: inline code spans --------------------------------

    #[test]
    fn explicit_request_inline_code_pipe() {
        let msg = "try running `ls | grep foo`";
        assert!(has_explicit_tool_request(msg));
    }

    #[test]
    fn explicit_request_inline_code_redirect() {
        let msg = "run `echo hello > /tmp/out`";
        assert!(has_explicit_tool_request(msg));
    }

    #[test]
    fn explicit_request_inline_code_dollar() {
        let msg = "check `$HOME/bin`";
        assert!(has_explicit_tool_request(msg));
    }

    #[test]
    fn explicit_request_inline_code_and() {
        let msg = "try `git fetch && git rebase`";
        assert!(has_explicit_tool_request(msg));
    }

    // ---- has_explicit_tool_request: negatives and known false positives --------------

    #[test]
    fn no_match_run_the_tests() {
        let msg = "run the tests please";
        assert!(!has_explicit_tool_request(msg));
    }

    #[test]
    fn no_match_execute_the_plan() {
        let msg = "execute the plan we discussed";
        assert!(!has_explicit_tool_request(msg));
    }

    #[test]
    fn no_match_inline_code_no_shell_syntax() {
        let msg = "the function `process_items` handles it";
        assert!(!has_explicit_tool_request(msg));
    }

    #[test]
    fn known_fp_what_does_function_do() {
        // Pins current behaviour: the broad "what does" phrase matches this question.
        let msg = "what does this function do?";
        assert!(has_explicit_tool_request(msg));
    }

    #[test]
    fn no_match_show_me_result_without_colon() {
        let msg = "show me the result of running it";
        assert!(!has_explicit_tool_request(msg));
    }

    // ---- is_exempt -------------------------------------------------------------------

    #[test]
    fn is_exempt_matches_case_insensitively() {
        let scorer = UtilityScorer::new(UtilityScoringConfig {
            enabled: true,
            exempt_tools: vec!["Read".to_owned(), "file_read".to_owned()],
            ..UtilityScoringConfig::default()
        });

        for exempt in ["read", "READ", "FILE_READ"] {
            assert!(scorer.is_exempt(exempt));
        }
        for not_exempt in ["write", "bash"] {
            assert!(!scorer.is_exempt(not_exempt));
        }
    }

    #[test]
    fn is_exempt_empty_list_returns_false() {
        let scorer = disabled_scorer();
        assert!(!scorer.is_exempt("read"));
    }
}