Skip to main content

zeph_tools/
utility.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Utility-guided tool dispatch gate (arXiv:2603.19896).
5//!
6//! Computes a scalar utility score for each candidate tool call before execution.
7//! Calls below the configured threshold are skipped (fail-closed on scoring errors).
8
9use std::collections::HashMap;
10use std::hash::{DefaultHasher, Hash, Hasher};
11
12use crate::config::UtilityScoringConfig;
13use crate::executor::ToolCall;
14
/// Heuristic information-gain estimate for a tool, keyed by its name.
///
/// Names starting with `memory` score highest (0.8). The exact names
/// `search_code`/`grep`/`glob` score 0.65, `bash`/`shell` 0.6 and `read`/`write` 0.55.
/// MCP tools (`mcp_` prefix) and every unrecognized name get the neutral 0.5.
fn default_gain(tool_name: &str) -> f32 {
    match tool_name {
        name if name.starts_with("memory") => 0.8,
        // Listed explicitly to document that MCP tools are deliberately neutral.
        name if name.starts_with("mcp_") => 0.5,
        "search_code" | "grep" | "glob" => 0.65,
        "bash" | "shell" => 0.6,
        "read" | "write" => 0.55,
        _ => 0.5,
    }
}
33
/// Breakdown of the utility estimate for a single candidate tool call.
///
/// Every component is filled in by the scorer; `total` combines them using the
/// configured weights.
#[derive(Debug, Clone)]
pub struct UtilityScore {
    /// Expected information gain of running the tool (per-tool-name heuristic).
    pub gain: f32,
    /// Share of the turn's token budget already spent: `tokens_consumed / token_budget`.
    pub cost: f32,
    /// 1.0 when an identical `(tool_name, params_hash)` was already recorded this turn, else 0.0.
    pub redundancy: f32,
    /// Exploration bonus that shrinks as the turn proceeds (`1 - tool_calls_this_turn / max_calls`).
    pub uncertainty: f32,
    /// Weighted combination of the components above.
    pub total: f32,
}

impl UtilityScore {
    /// Reports whether every component, `total` included, is a finite number
    /// (neither NaN nor ±infinity).
    fn is_valid(&self) -> bool {
        let components = [
            self.gain,
            self.cost,
            self.redundancy,
            self.uncertainty,
            self.total,
        ];
        components.iter().all(|value| value.is_finite())
    }
}
59
/// Context required to compute utility — provided by the agent loop.
///
/// `tool_calls_this_turn`, `tokens_consumed` and `token_budget` feed the `uncertainty` and
/// `cost` components of `UtilityScorer::score`. `user_requested` is not read by `score`;
/// presumably the caller forwards it to `UtilityScorer::should_execute` (NOTE(review): confirm).
#[derive(Debug, Clone)]
pub struct UtilityContext {
    /// Number of tool calls already dispatched in the current LLM turn.
    /// The exploration bonus falls linearly from 1.0 at 0 calls to 0.0 at 10 or more calls.
    pub tool_calls_this_turn: usize,
    /// Tokens consumed so far in this turn. Divided by `token_budget` (clamped to [0, 1])
    /// to obtain the cost component.
    pub tokens_consumed: usize,
    /// Token budget for the current turn. 0 = budget unknown (cost component treated as 0).
    pub token_budget: usize,
    /// True only when the tool was explicitly invoked via a `/tool` slash command.
    /// Must NOT be set based on tool names found inside user message text or tool outputs,
    /// because a `true` value bypasses the utility gate entirely.
    pub user_requested: bool,
}
73
74/// Hashes `(tool_name, serialized_params)` pre-execution for redundancy detection.
75fn call_hash(call: &ToolCall) -> u64 {
76    let mut h = DefaultHasher::new();
77    call.tool_id.hash(&mut h);
78    // Stable iteration order is not guaranteed for serde_json::Map, but it is insertion-order
79    // in practice for the same LLM output. Using the debug representation is simple and
80    // deterministic within a session (no cross-session persistence of these hashes).
81    format!("{:?}", call.params).hash(&mut h);
82    h.finish()
83}
84
/// Computes utility scores for tool calls before dispatch.
///
/// Holds the scoring config plus per-turn redundancy state. The mutating methods
/// (`record_call`, `clear`) take `&mut self`, so the scorer is intended to be owned by the
/// agent's tool loop (same lifecycle as `ToolResultCache` and `recent_tool_calls`).
///
/// NOTE(review): nothing here opts out of the `Send`/`Sync` auto traits. `HashMap<u64, u32>`
/// is both, so whether this type is `Send + Sync` depends only on `UtilityScoringConfig`.
/// Single-threaded use is a convention, not a type-level guarantee — confirm intent.
#[derive(Debug)]
pub struct UtilityScorer {
    config: UtilityScoringConfig,
    /// Hashes of `(tool_name, params)` recorded via `record_call` in the current LLM turn,
    /// mapped to how many times each was recorded (`score` only checks presence).
    recent_calls: HashMap<u64, u32>,
}
95
96impl UtilityScorer {
97    /// Create a new scorer from the given config.
98    #[must_use]
99    pub fn new(config: UtilityScoringConfig) -> Self {
100        Self {
101            config,
102            recent_calls: HashMap::new(),
103        }
104    }
105
106    /// Whether utility scoring is enabled.
107    #[must_use]
108    pub fn is_enabled(&self) -> bool {
109        self.config.enabled
110    }
111
112    /// Score a candidate tool call.
113    ///
114    /// Returns `None` when scoring is disabled. When scoring produces a non-finite
115    /// result (misconfigured weights), returns `None` — the caller treats `None` as
116    /// fail-closed (skip the tool call) unless `user_requested` is set.
117    #[must_use]
118    pub fn score(&self, call: &ToolCall, ctx: &UtilityContext) -> Option<UtilityScore> {
119        if !self.config.enabled {
120            return None;
121        }
122
123        let gain = default_gain(&call.tool_id);
124
125        let cost = if ctx.token_budget > 0 {
126            #[allow(clippy::cast_precision_loss)]
127            (ctx.tokens_consumed as f32 / ctx.token_budget as f32).clamp(0.0, 1.0)
128        } else {
129            0.0
130        };
131
132        let hash = call_hash(call);
133        let redundancy = if self.recent_calls.contains_key(&hash) {
134            1.0_f32
135        } else {
136            0.0_f32
137        };
138
139        // Uncertainty decreases as turn progresses. At tool call 0 it equals 1.0;
140        // at tool_calls_this_turn >= 10 it saturates to 0.0.
141        #[allow(clippy::cast_precision_loss)]
142        let uncertainty = (1.0_f32 - ctx.tool_calls_this_turn as f32 / 10.0).clamp(0.0, 1.0);
143
144        let total = self.config.gain_weight * gain
145            - self.config.cost_weight * cost
146            - self.config.redundancy_weight * redundancy
147            + self.config.uncertainty_bonus * uncertainty;
148
149        let score = UtilityScore {
150            gain,
151            cost,
152            redundancy,
153            uncertainty,
154            total,
155        };
156
157        if score.is_valid() { Some(score) } else { None }
158    }
159
160    /// Returns `true` when the tool call should be executed based on its score.
161    ///
162    /// `user_requested` tools bypass the gate unconditionally.
163    /// When `score` is `None` (scoring disabled or produced invalid result) and
164    /// `user_requested` is false, the tool is skipped (fail-closed).
165    #[must_use]
166    pub fn should_execute(&self, score: Option<&UtilityScore>, user_requested: bool) -> bool {
167        if user_requested {
168            return true;
169        }
170        match score {
171            Some(s) => s.total >= self.config.threshold,
172            // Scoring disabled → always execute.
173            // Scoring produced invalid result → fail-closed: skip.
174            None if !self.config.enabled => true,
175            None => false,
176        }
177    }
178
179    /// Record a call as executed for redundancy tracking.
180    ///
181    /// Must be called after `score()` and before the next call to `score()` for the
182    /// same tool in the same turn.
183    pub fn record_call(&mut self, call: &ToolCall) {
184        let hash = call_hash(call);
185        *self.recent_calls.entry(hash).or_insert(0) += 1;
186    }
187
188    /// Reset per-turn state. Call at the start of each LLM tool round.
189    pub fn clear(&mut self) {
190        self.recent_calls.clear();
191    }
192
193    /// The configured threshold.
194    #[must_use]
195    pub fn threshold(&self) -> f32 {
196        self.config.threshold
197    }
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a `ToolCall`; a non-object `params` value collapses to an empty map.
    fn make_call(name: &str, params: serde_json::Value) -> ToolCall {
        ToolCall {
            tool_id: name.to_owned(),
            params: if let serde_json::Value::Object(m) = params {
                m
            } else {
                serde_json::Map::new()
            },
        }
    }

    /// First call of a turn, nothing spent out of a 1000-token budget.
    fn default_ctx() -> UtilityContext {
        UtilityContext {
            tool_calls_this_turn: 0,
            tokens_consumed: 0,
            token_budget: 1000,
            user_requested: false,
        }
    }

    /// Default config with the gate switched on.
    fn default_config() -> UtilityScoringConfig {
        UtilityScoringConfig {
            enabled: true,
            ..UtilityScoringConfig::default()
        }
    }

    #[test]
    fn disabled_returns_none() {
        let scorer = UtilityScorer::new(UtilityScoringConfig::default());
        assert!(!scorer.is_enabled());
        let call = make_call("bash", json!({}));
        let score = scorer.score(&call, &default_ctx());
        assert!(score.is_none());
        // When disabled, should_execute always returns true (never gated).
        assert!(scorer.should_execute(score.as_ref(), false));
    }

    #[test]
    fn first_call_passes_default_threshold() {
        let scorer = UtilityScorer::new(default_config());
        let call = make_call("bash", json!({"cmd": "ls"}));
        let score = scorer.score(&call, &default_ctx());
        assert!(score.is_some());
        let s = score.unwrap();
        // Compare against the configured threshold rather than a hard-coded copy of its
        // default value, so the test keeps its meaning if the default changes.
        assert!(
            s.total >= scorer.threshold(),
            "first call should reach threshold {}: {}",
            scorer.threshold(),
            s.total
        );
        assert!(scorer.should_execute(Some(&s), false));
    }

    #[test]
    fn redundant_call_penalized() {
        let mut scorer = UtilityScorer::new(default_config());
        let call = make_call("bash", json!({"cmd": "ls"}));
        scorer.record_call(&call);
        let score = scorer.score(&call, &default_ctx()).unwrap();
        assert!((score.redundancy - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn distinct_params_not_redundant() {
        let mut scorer = UtilityScorer::new(default_config());
        scorer.record_call(&make_call("bash", json!({"cmd": "ls"})));
        let other = make_call("bash", json!({"cmd": "pwd"}));
        let score = scorer.score(&other, &default_ctx()).unwrap();
        assert!(score.redundancy.abs() < f32::EPSILON);
    }

    #[test]
    fn clear_resets_redundancy() {
        let mut scorer = UtilityScorer::new(default_config());
        let call = make_call("bash", json!({"cmd": "ls"}));
        scorer.record_call(&call);
        scorer.clear();
        let score = scorer.score(&call, &default_ctx()).unwrap();
        assert!(score.redundancy.abs() < f32::EPSILON);
    }

    #[test]
    fn user_requested_always_executes() {
        let scorer = UtilityScorer::new(default_config());
        // Simulate a call that would score very low.
        let score = UtilityScore {
            gain: 0.0,
            cost: 1.0,
            redundancy: 1.0,
            uncertainty: 0.0,
            total: -100.0,
        };
        assert!(scorer.should_execute(Some(&score), true));
    }

    #[test]
    fn total_equal_to_threshold_executes() {
        // The gate is inclusive: `total >= threshold` executes.
        let scorer = UtilityScorer::new(default_config());
        let score = UtilityScore {
            gain: 0.5,
            cost: 0.0,
            redundancy: 0.0,
            uncertainty: 0.0,
            total: scorer.threshold(),
        };
        assert!(scorer.should_execute(Some(&score), false));
    }

    #[test]
    fn none_score_fail_closed_when_enabled() {
        let scorer = UtilityScorer::new(default_config());
        // Simulate scoring failure (None with scoring enabled).
        assert!(!scorer.should_execute(None, false));
    }

    #[test]
    fn none_score_executes_when_disabled() {
        let scorer = UtilityScorer::new(UtilityScoringConfig::default()); // disabled
        assert!(scorer.should_execute(None, false));
    }

    #[test]
    fn cost_increases_with_token_consumption() {
        let scorer = UtilityScorer::new(default_config());
        let call = make_call("bash", json!({}));
        let ctx_low = UtilityContext {
            tokens_consumed: 100,
            token_budget: 1000,
            ..default_ctx()
        };
        let ctx_high = UtilityContext {
            tokens_consumed: 900,
            token_budget: 1000,
            ..default_ctx()
        };
        let s_low = scorer.score(&call, &ctx_low).unwrap();
        let s_high = scorer.score(&call, &ctx_high).unwrap();
        assert!(s_low.cost < s_high.cost);
        assert!(s_low.total > s_high.total);
    }

    #[test]
    fn uncertainty_decreases_with_call_count() {
        let scorer = UtilityScorer::new(default_config());
        let call = make_call("bash", json!({}));
        let ctx_early = UtilityContext {
            tool_calls_this_turn: 0,
            ..default_ctx()
        };
        let ctx_late = UtilityContext {
            tool_calls_this_turn: 9,
            ..default_ctx()
        };
        let s_early = scorer.score(&call, &ctx_early).unwrap();
        let s_late = scorer.score(&call, &ctx_late).unwrap();
        assert!(s_early.uncertainty > s_late.uncertainty);
    }

    #[test]
    fn uncertainty_saturates_at_ten_calls() {
        let scorer = UtilityScorer::new(default_config());
        let call = make_call("bash", json!({}));
        let fresh = scorer.score(&call, &default_ctx()).unwrap();
        assert!((fresh.uncertainty - 1.0).abs() < f32::EPSILON);
        for calls in [10_usize, 25] {
            let ctx = UtilityContext {
                tool_calls_this_turn: calls,
                ..default_ctx()
            };
            let s = scorer.score(&call, &ctx).unwrap();
            assert!(
                s.uncertainty.abs() < f32::EPSILON,
                "calls={calls}: uncertainty={}",
                s.uncertainty
            );
        }
    }

    #[test]
    fn memory_tool_has_higher_gain_than_scrape() {
        let scorer = UtilityScorer::new(default_config());
        let mem_call = make_call("memory_search", json!({}));
        let web_call = make_call("scrape", json!({}));
        let s_mem = scorer.score(&mem_call, &default_ctx()).unwrap();
        let s_web = scorer.score(&web_call, &default_ctx()).unwrap();
        assert!(s_mem.gain > s_web.gain);
    }

    #[test]
    fn zero_token_budget_zeroes_cost() {
        let scorer = UtilityScorer::new(default_config());
        let call = make_call("bash", json!({}));
        let ctx = UtilityContext {
            tokens_consumed: 500,
            token_budget: 0,
            ..default_ctx()
        };
        let s = scorer.score(&call, &ctx).unwrap();
        assert!(s.cost.abs() < f32::EPSILON);
    }

    #[test]
    fn validate_rejects_negative_weights() {
        let cfg = UtilityScoringConfig {
            enabled: true,
            gain_weight: -1.0,
            ..UtilityScoringConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn validate_rejects_nan_weights() {
        let cfg = UtilityScoringConfig {
            enabled: true,
            threshold: f32::NAN,
            ..UtilityScoringConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn validate_accepts_default() {
        assert!(UtilityScoringConfig::default().validate().is_ok());
    }

    #[test]
    fn threshold_zero_all_calls_pass() {
        // threshold=0.0: every call with a non-negative total should execute.
        let scorer = UtilityScorer::new(UtilityScoringConfig {
            enabled: true,
            threshold: 0.0,
            ..UtilityScoringConfig::default()
        });
        let call = make_call("bash", json!({}));
        let score = scorer.score(&call, &default_ctx()).unwrap();
        // total must be >= 0.0 for a fresh call with default weights.
        assert!(
            score.total >= 0.0,
            "total should be non-negative: {}",
            score.total
        );
        assert!(scorer.should_execute(Some(&score), false));
    }

    #[test]
    fn threshold_one_blocks_all_calls() {
        // threshold=1.0: realistic scores never reach 1.0, so every call is blocked.
        let scorer = UtilityScorer::new(UtilityScoringConfig {
            enabled: true,
            threshold: 1.0,
            ..UtilityScoringConfig::default()
        });
        let call = make_call("bash", json!({}));
        let score = scorer.score(&call, &default_ctx()).unwrap();
        assert!(
            score.total < 1.0,
            "realistic score should be below 1.0: {}",
            score.total
        );
        assert!(!scorer.should_execute(Some(&score), false));
    }
}