1use std::collections::HashMap;
10use std::hash::{DefaultHasher, Hash, Hasher};
11use std::sync::LazyLock;
12
13use regex::Regex;
14
15use crate::config::UtilityScoringConfig;
16use crate::executor::ToolCall;
17
18#[must_use]
27pub fn has_explicit_tool_request(user_message: &str) -> bool {
28 static RE: LazyLock<Regex> = LazyLock::new(|| {
29 Regex::new(
30 r"(?xi)
31 using\s+a\s+tool
32 | call\s+(the\s+)?[a-z_]+\s+tool
33 | use\s+(the\s+)?[a-z_]+\s+tool
34 | run\s+(the\s+)?[a-z_]+\s+tool
35 | invoke\s+(the\s+)?[a-z_]+\s+tool
36 | execute\s+(the\s+)?[a-z_]+\s+tool
37 ",
38 )
39 .expect("static regex is valid")
40 });
41 RE.is_match(user_message)
42}
43
/// Returns the built-in gain estimate for a tool, looked up by its name.
///
/// Prefix rules take precedence over exact names:
/// - names starting with `memory` -> `0.8`
/// - names starting with `mcp_` -> `0.5`
/// - `bash`, `shell` -> `0.6`
/// - `read`, `write` -> `0.55`
/// - `search_code`, `grep`, `glob` -> `0.65`
/// - anything else -> `0.5`
///
/// Matching is case-sensitive.
fn default_gain(tool_name: &str) -> f32 {
    match tool_name {
        name if name.starts_with("memory") => 0.8,
        name if name.starts_with("mcp_") => 0.5,
        "bash" | "shell" => 0.6,
        "read" | "write" => 0.55,
        "search_code" | "grep" | "glob" => 0.65,
        _ => 0.5,
    }
}
62
/// Utility estimate for a single tool call: four components and their weighted total.
///
/// The per-field notes describe the values `UtilityScorer::score` assigns; because the
/// fields are public, other code may construct a score with arbitrary values.
#[derive(Debug, Clone)]
pub struct UtilityScore {
    /// Estimated benefit of the call, taken from the per-tool gain table.
    pub gain: f32,
    /// Fraction of the token budget already consumed, clamped to `[0, 1]`.
    pub cost: f32,
    /// `1.0` when an identical call has already been recorded, otherwise `0.0`.
    pub redundancy: f32,
    /// Starts at `1.0` and drops by `0.1` per tool call already made this turn, floored at `0.0`.
    pub uncertainty: f32,
    /// Weighted combination: gain and uncertainty add, cost and redundancy subtract.
    pub total: f32,
}

impl UtilityScore {
    /// Returns `true` only when every component and the total are finite
    /// (neither NaN nor infinite).
    fn is_valid(&self) -> bool {
        [
            self.gain,
            self.cost,
            self.redundancy,
            self.uncertainty,
            self.total,
        ]
        .iter()
        .all(|component| component.is_finite())
    }
}
88
/// Per-turn inputs consulted when scoring a tool call and choosing an action.
#[derive(Debug, Clone)]
pub struct UtilityContext {
    /// Tool calls already made this turn; each one lowers the uncertainty component, and a
    /// non-zero count is required for the `Verify` recommendation.
    pub tool_calls_this_turn: usize,
    /// Tokens consumed so far; the numerator of the cost ratio.
    pub tokens_consumed: usize,
    /// Token budget; the denominator of the cost ratio. A value of `0` makes the cost `0.0`.
    pub token_budget: usize,
    /// When `true`, `UtilityScorer::recommend_action` returns `ToolCall` without looking at
    /// the score.
    pub user_requested: bool,
}
103
/// Next step recommended by `UtilityScorer::recommend_action`.
///
/// The per-variant notes list the conditions under which that method returns each value;
/// how callers act on a recommendation is not defined in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UtilityAction {
    /// Returned for a redundant call (`redundancy >= 1.0`) and as the fallback when no other
    /// rule applies.
    Respond,
    /// Returned when gain is at least `0.5` and uncertainty is above `0.5`.
    Retrieve,
    /// Returned when the user asked for a tool, when scoring is disabled, or when the total
    /// meets the configured threshold.
    ToolCall,
    /// Returned when the total is below the threshold and at least one tool call has already
    /// been made this turn.
    Verify,
    /// Returned when no score is available while scoring is enabled, or when cost exceeds `0.9`.
    Stop,
}
118
119fn call_hash(call: &ToolCall) -> u64 {
121 let mut h = DefaultHasher::new();
122 call.tool_id.hash(&mut h);
123 format!("{:?}", call.params).hash(&mut h);
127 h.finish()
128}
129
130#[derive(Debug)]
135pub struct UtilityScorer {
136 config: UtilityScoringConfig,
137 recent_calls: HashMap<u64, u32>,
139}
140
141impl UtilityScorer {
142 #[must_use]
144 pub fn new(config: UtilityScoringConfig) -> Self {
145 Self {
146 config,
147 recent_calls: HashMap::new(),
148 }
149 }
150
151 #[must_use]
153 pub fn is_enabled(&self) -> bool {
154 self.config.enabled
155 }
156
157 #[must_use]
163 pub fn score(&self, call: &ToolCall, ctx: &UtilityContext) -> Option<UtilityScore> {
164 if !self.config.enabled {
165 return None;
166 }
167
168 let gain = default_gain(&call.tool_id);
169
170 let cost = if ctx.token_budget > 0 {
171 #[allow(clippy::cast_precision_loss)]
172 (ctx.tokens_consumed as f32 / ctx.token_budget as f32).clamp(0.0, 1.0)
173 } else {
174 0.0
175 };
176
177 let hash = call_hash(call);
178 let redundancy = if self.recent_calls.contains_key(&hash) {
179 1.0_f32
180 } else {
181 0.0_f32
182 };
183
184 #[allow(clippy::cast_precision_loss)]
187 let uncertainty = (1.0_f32 - ctx.tool_calls_this_turn as f32 / 10.0).clamp(0.0, 1.0);
188
189 let total = self.config.gain_weight * gain
190 - self.config.cost_weight * cost
191 - self.config.redundancy_weight * redundancy
192 + self.config.uncertainty_bonus * uncertainty;
193
194 let score = UtilityScore {
195 gain,
196 cost,
197 redundancy,
198 uncertainty,
199 total,
200 };
201
202 if score.is_valid() { Some(score) } else { None }
203 }
204
205 #[must_use]
219 pub fn recommend_action(
220 &self,
221 score: Option<&UtilityScore>,
222 ctx: &UtilityContext,
223 ) -> UtilityAction {
224 if ctx.user_requested {
226 return UtilityAction::ToolCall;
227 }
228 if !self.config.enabled {
230 return UtilityAction::ToolCall;
231 }
232 let Some(s) = score else {
233 return UtilityAction::Stop;
235 };
236
237 if s.cost > 0.9 {
239 return UtilityAction::Stop;
240 }
241 if s.redundancy >= 1.0 {
243 return UtilityAction::Respond;
244 }
245 if s.gain >= 0.7 && s.total >= self.config.threshold {
247 return UtilityAction::ToolCall;
248 }
249 if s.gain >= 0.5 && s.uncertainty > 0.5 {
251 return UtilityAction::Retrieve;
252 }
253 if s.total < self.config.threshold && ctx.tool_calls_this_turn > 0 {
255 return UtilityAction::Verify;
256 }
257 if s.total >= self.config.threshold {
259 return UtilityAction::ToolCall;
260 }
261 UtilityAction::Respond
262 }
263
264 pub fn record_call(&mut self, call: &ToolCall) {
269 let hash = call_hash(call);
270 *self.recent_calls.entry(hash).or_insert(0) += 1;
271 }
272
273 pub fn clear(&mut self) {
275 self.recent_calls.clear();
276 }
277
278 #[must_use]
282 pub fn is_exempt(&self, tool_name: &str) -> bool {
283 let lower = tool_name.to_lowercase();
284 self.config
285 .exempt_tools
286 .iter()
287 .any(|e| e.to_lowercase() == lower)
288 }
289
290 #[must_use]
292 pub fn threshold(&self) -> f32 {
293 self.config.threshold
294 }
295}
296
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // ---- fixtures -------------------------------------------------------------------------

    /// Builds a `ToolCall` for `name`; a non-object `params` value becomes an empty map.
    fn make_call(name: &str, params: serde_json::Value) -> ToolCall {
        let params = match params {
            serde_json::Value::Object(map) => map,
            _ => serde_json::Map::new(),
        };
        ToolCall {
            tool_id: name.to_owned(),
            params,
            caller_id: None,
        }
    }

    /// Fresh turn: no prior calls, nothing consumed, 1000-token budget, no user request.
    fn default_ctx() -> UtilityContext {
        UtilityContext {
            tool_calls_this_turn: 0,
            tokens_consumed: 0,
            token_budget: 1000,
            user_requested: false,
        }
    }

    /// Default configuration with scoring switched on.
    fn default_config() -> UtilityScoringConfig {
        UtilityScoringConfig {
            enabled: true,
            ..UtilityScoringConfig::default()
        }
    }

    /// Scorer built from `default_config()`.
    fn enabled_scorer() -> UtilityScorer {
        UtilityScorer::new(default_config())
    }

    /// Scorer built from the untouched default configuration, which leaves scoring off.
    fn disabled_scorer() -> UtilityScorer {
        UtilityScorer::new(UtilityScoringConfig::default())
    }

    /// A score that every rule except the user-request override would reject.
    fn worst_case_score() -> UtilityScore {
        UtilityScore {
            gain: 0.0,
            cost: 1.0,
            redundancy: 1.0,
            uncertainty: 0.0,
            total: -100.0,
        }
    }

    // ---- scoring --------------------------------------------------------------------------

    #[test]
    fn disabled_returns_none() {
        let scorer = disabled_scorer();
        assert!(!scorer.is_enabled());

        let ctx = default_ctx();
        let score = scorer.score(&make_call("bash", json!({})), &ctx);
        assert!(score.is_none());
        // With scoring off, the recommendation is to run the tool regardless.
        assert_eq!(
            scorer.recommend_action(score.as_ref(), &ctx),
            UtilityAction::ToolCall
        );
    }

    #[test]
    fn first_call_passes_default_threshold() {
        let scorer = enabled_scorer();
        let ctx = default_ctx();
        let call = make_call("bash", json!({"cmd": "ls"}));

        let score = scorer.score(&call, &ctx);
        assert!(score.is_some());
        let s = score.unwrap();
        assert!(
            s.total >= 0.1,
            "first call should exceed threshold: {}",
            s.total
        );

        let action = scorer.recommend_action(Some(&s), &ctx);
        assert!(
            matches!(action, UtilityAction::ToolCall | UtilityAction::Retrieve),
            "first call should not be blocked, got {action:?}",
        );
    }

    #[test]
    fn redundant_call_penalized() {
        let mut scorer = enabled_scorer();
        let call = make_call("bash", json!({"cmd": "ls"}));
        scorer.record_call(&call);

        let score = scorer.score(&call, &default_ctx()).unwrap();
        assert!((score.redundancy - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn clear_resets_redundancy() {
        let mut scorer = enabled_scorer();
        let call = make_call("bash", json!({"cmd": "ls"}));
        scorer.record_call(&call);
        scorer.clear();

        let score = scorer.score(&call, &default_ctx()).unwrap();
        assert!(score.redundancy.abs() < f32::EPSILON);
    }

    #[test]
    fn user_requested_always_executes() {
        let scorer = enabled_scorer();
        let score = worst_case_score();
        let ctx = UtilityContext {
            user_requested: true,
            ..default_ctx()
        };
        assert_eq!(
            scorer.recommend_action(Some(&score), &ctx),
            UtilityAction::ToolCall
        );
    }

    #[test]
    fn none_score_fail_closed_when_enabled() {
        let scorer = enabled_scorer();
        assert_eq!(
            scorer.recommend_action(None, &default_ctx()),
            UtilityAction::Stop
        );
    }

    #[test]
    fn none_score_executes_when_disabled() {
        let scorer = disabled_scorer();
        assert_eq!(
            scorer.recommend_action(None, &default_ctx()),
            UtilityAction::ToolCall
        );
    }

    #[test]
    fn cost_increases_with_token_consumption() {
        let scorer = enabled_scorer();
        let call = make_call("bash", json!({}));
        let score_at = |tokens_consumed: usize| {
            let ctx = UtilityContext {
                tokens_consumed,
                token_budget: 1000,
                ..default_ctx()
            };
            scorer.score(&call, &ctx).unwrap()
        };

        let s_low = score_at(100);
        let s_high = score_at(900);
        assert!(s_low.cost < s_high.cost);
        assert!(s_low.total > s_high.total);
    }

    #[test]
    fn uncertainty_decreases_with_call_count() {
        let scorer = enabled_scorer();
        let call = make_call("bash", json!({}));
        let score_after = |tool_calls_this_turn: usize| {
            let ctx = UtilityContext {
                tool_calls_this_turn,
                ..default_ctx()
            };
            scorer.score(&call, &ctx).unwrap()
        };

        let s_early = score_after(0);
        let s_late = score_after(9);
        assert!(s_early.uncertainty > s_late.uncertainty);
    }

    #[test]
    fn memory_tool_has_higher_gain_than_scrape() {
        let scorer = enabled_scorer();
        let ctx = default_ctx();
        let s_mem = scorer
            .score(&make_call("memory_search", json!({})), &ctx)
            .unwrap();
        let s_web = scorer.score(&make_call("scrape", json!({})), &ctx).unwrap();
        assert!(s_mem.gain > s_web.gain);
    }

    #[test]
    fn zero_token_budget_zeroes_cost() {
        let scorer = enabled_scorer();
        let ctx = UtilityContext {
            tokens_consumed: 500,
            token_budget: 0,
            ..default_ctx()
        };
        let s = scorer.score(&make_call("bash", json!({})), &ctx).unwrap();
        assert!(s.cost.abs() < f32::EPSILON);
    }

    // ---- configuration validation ---------------------------------------------------------

    #[test]
    fn validate_rejects_negative_weights() {
        let cfg = UtilityScoringConfig {
            enabled: true,
            gain_weight: -1.0,
            ..UtilityScoringConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn validate_rejects_nan_weights() {
        let cfg = UtilityScoringConfig {
            enabled: true,
            threshold: f32::NAN,
            ..UtilityScoringConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn validate_accepts_default() {
        assert!(UtilityScoringConfig::default().validate().is_ok());
    }

    // ---- threshold extremes ---------------------------------------------------------------

    #[test]
    fn threshold_zero_all_calls_pass() {
        let scorer = UtilityScorer::new(UtilityScoringConfig {
            enabled: true,
            threshold: 0.0,
            ..UtilityScoringConfig::default()
        });
        let ctx = default_ctx();
        let score = scorer.score(&make_call("bash", json!({})), &ctx).unwrap();
        assert!(
            score.total >= 0.0,
            "total should be non-negative: {}",
            score.total
        );

        let action = scorer.recommend_action(Some(&score), &ctx);
        assert!(
            matches!(action, UtilityAction::ToolCall | UtilityAction::Retrieve),
            "threshold=0 should not block calls, got {action:?}",
        );
    }

    #[test]
    fn threshold_one_blocks_all_calls() {
        let scorer = UtilityScorer::new(UtilityScoringConfig {
            enabled: true,
            threshold: 1.0,
            ..UtilityScoringConfig::default()
        });
        let ctx = default_ctx();
        let score = scorer.score(&make_call("bash", json!({})), &ctx).unwrap();
        assert!(
            score.total < 1.0,
            "realistic score should be below 1.0: {}",
            score.total
        );
        assert_ne!(
            scorer.recommend_action(Some(&score), &ctx),
            UtilityAction::ToolCall
        );
    }

    // ---- recommend_action decision ladder -------------------------------------------------

    #[test]
    fn recommend_action_user_requested_always_tool_call() {
        let scorer = enabled_scorer();
        let score = worst_case_score();
        let ctx = UtilityContext {
            user_requested: true,
            ..default_ctx()
        };
        assert_eq!(
            scorer.recommend_action(Some(&score), &ctx),
            UtilityAction::ToolCall
        );
    }

    #[test]
    fn recommend_action_disabled_scorer_always_tool_call() {
        let scorer = disabled_scorer();
        let ctx = default_ctx();
        assert_eq!(scorer.recommend_action(None, &ctx), UtilityAction::ToolCall);
    }

    #[test]
    fn recommend_action_none_score_enabled_stops() {
        let scorer = enabled_scorer();
        let ctx = default_ctx();
        assert_eq!(scorer.recommend_action(None, &ctx), UtilityAction::Stop);
    }

    #[test]
    fn recommend_action_budget_exhausted_stops() {
        let scorer = enabled_scorer();
        // cost above 0.9 wins even though gain and total look healthy.
        let score = UtilityScore {
            gain: 0.8,
            cost: 0.95,
            redundancy: 0.0,
            uncertainty: 0.5,
            total: 0.5,
        };
        assert_eq!(
            scorer.recommend_action(Some(&score), &default_ctx()),
            UtilityAction::Stop
        );
    }

    #[test]
    fn recommend_action_redundant_responds() {
        let scorer = enabled_scorer();
        let score = UtilityScore {
            gain: 0.8,
            cost: 0.1,
            redundancy: 1.0,
            uncertainty: 0.5,
            total: 0.5,
        };
        assert_eq!(
            scorer.recommend_action(Some(&score), &default_ctx()),
            UtilityAction::Respond
        );
    }

    #[test]
    fn recommend_action_high_gain_above_threshold_tool_call() {
        let scorer = enabled_scorer();
        let score = UtilityScore {
            gain: 0.8,
            cost: 0.1,
            redundancy: 0.0,
            uncertainty: 0.4,
            total: 0.6,
        };
        assert_eq!(
            scorer.recommend_action(Some(&score), &default_ctx()),
            UtilityAction::ToolCall
        );
    }

    #[test]
    fn recommend_action_uncertain_retrieves() {
        let scorer = enabled_scorer();
        // gain is below 0.7, so the high-gain rule is skipped and the uncertainty rule fires.
        let score = UtilityScore {
            gain: 0.6,
            cost: 0.1,
            redundancy: 0.0,
            uncertainty: 0.8,
            total: 0.4,
        };
        assert_eq!(
            scorer.recommend_action(Some(&score), &default_ctx()),
            UtilityAction::Retrieve
        );
    }

    #[test]
    fn recommend_action_below_threshold_with_prior_calls_verifies() {
        let scorer = enabled_scorer();
        // total is expected to sit below the default threshold.
        let score = UtilityScore {
            gain: 0.3,
            cost: 0.1,
            redundancy: 0.0,
            uncertainty: 0.2,
            total: 0.05,
        };
        let ctx = UtilityContext {
            tool_calls_this_turn: 1,
            ..default_ctx()
        };
        assert_eq!(
            scorer.recommend_action(Some(&score), &ctx),
            UtilityAction::Verify
        );
    }

    #[test]
    fn recommend_action_default_responds() {
        let scorer = enabled_scorer();
        // Same low score as the Verify case, but with no prior calls this turn.
        let score = UtilityScore {
            gain: 0.3,
            cost: 0.1,
            redundancy: 0.0,
            uncertainty: 0.2,
            total: 0.05,
        };
        let ctx = UtilityContext {
            tool_calls_this_turn: 0,
            ..default_ctx()
        };
        assert_eq!(
            scorer.recommend_action(Some(&score), &ctx),
            UtilityAction::Respond
        );
    }

    // ---- has_explicit_tool_request --------------------------------------------------------

    #[test]
    fn explicit_request_using_a_tool() {
        let message = "Please list the files in the current directory using a tool";
        assert!(has_explicit_tool_request(message));
    }

    #[test]
    fn explicit_request_call_the_tool() {
        assert!(has_explicit_tool_request("call the list_directory tool"));
    }

    #[test]
    fn explicit_request_use_the_tool() {
        assert!(has_explicit_tool_request("use the shell tool to run ls"));
    }

    #[test]
    fn explicit_request_run_the_tool() {
        assert!(has_explicit_tool_request("run the bash tool"));
    }

    #[test]
    fn explicit_request_invoke_the_tool() {
        assert!(has_explicit_tool_request("invoke the search_code tool"));
    }

    #[test]
    fn explicit_request_execute_the_tool() {
        assert!(has_explicit_tool_request("execute the grep tool for me"));
    }

    #[test]
    fn explicit_request_case_insensitive() {
        assert!(has_explicit_tool_request("USING A TOOL to find files"));
    }

    #[test]
    fn explicit_request_no_match_plain_message() {
        assert!(!has_explicit_tool_request("what is the weather today?"));
    }

    #[test]
    fn explicit_request_no_match_tool_mentioned_without_invocation() {
        let message = "the shell tool is very useful in general";
        assert!(!has_explicit_tool_request(message));
    }

    // ---- is_exempt ------------------------------------------------------------------------

    #[test]
    fn is_exempt_matches_case_insensitively() {
        let scorer = UtilityScorer::new(UtilityScoringConfig {
            enabled: true,
            exempt_tools: vec!["Read".to_owned(), "file_read".to_owned()],
            ..UtilityScoringConfig::default()
        });
        assert!(scorer.is_exempt("read"));
        assert!(scorer.is_exempt("READ"));
        assert!(scorer.is_exempt("FILE_READ"));
        assert!(!scorer.is_exempt("write"));
        assert!(!scorer.is_exempt("bash"));
    }

    #[test]
    fn is_exempt_empty_list_returns_false() {
        let scorer = disabled_scorer();
        assert!(!scorer.is_exempt("read"));
    }
}