Skip to main content

batuta/agent/
guard.rs

//! Loop guard — prevents runaway agent loops (Jidoka pattern).
//!
//! Tracks iteration count, tool call hashes (FxHash ping-pong detection),
//! cost budget, optional token budget, and consecutive `MaxTokens` responses.
//! Verdicts: `Allow`, `Warn`, `Block`, or `CircuitBreak`. See arXiv:2512.10350 (Tacheny).
6
7use std::collections::HashMap;
8use std::hash::{Hash, Hasher};
9
10use super::result::TokenUsage;
11
/// Prevents runaway agent loops (Jidoka pattern).
///
/// Holds the configured limits together with the running counters that the
/// `check_*` / `record_*` methods compare against them.
pub struct LoopGuard {
    /// Upper bound on iterations; `check_iteration` circuit-breaks once
    /// `current_iteration` exceeds it.
    max_iterations: u32,
    /// Number of `check_iteration` calls made so far.
    current_iteration: u32,
    /// Upper bound on tool calls; `check_tool_call` circuit-breaks once
    /// `total_tool_calls` exceeds it.
    max_tool_calls: u32,
    /// Number of `check_tool_call` calls made so far.
    total_tool_calls: u32,
    /// FxHash(tool,input) → count of calls seen with that hash.
    tool_call_counts: HashMap<u64, u32>,
    /// Count of `MaxTokens` responses recorded since the last
    /// `reset_max_tokens`.
    consecutive_max_tokens: u32,
    /// Token usage accumulated through `record_usage`.
    usage: TokenUsage,
    /// Cost ceiling in USD; 0.0 = unlimited (sovereign).
    max_cost_usd: f64,
    /// Cost in USD accumulated through `record_cost`.
    accumulated_cost_usd: f64,
    /// Cumulative input+output token ceiling; `None` = unlimited.
    max_tokens_budget: Option<u64>,
}
25
/// Verdict from the loop guard on whether to proceed.
///
/// The `String` payloads carry a human-readable explanation of which limit
/// was approached or hit. All payloads are `String`, so equality is total and
/// the type derives `Eq` in addition to `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LoopVerdict {
    /// Proceed with execution.
    Allow,
    /// Proceed but warn (approaching limits).
    Warn(String),
    /// Block this specific tool call (repeated pattern).
    Block(String),
    /// Hard stop the entire loop.
    CircuitBreak(String),
}
38
/// Number of calls with the same (tool, input) hash at which
/// `check_tool_call` starts returning `Block`.
const PINGPONG_THRESHOLD: u32 = 3;
/// Number of consecutive `MaxTokens` responses at which
/// `record_max_tokens` returns `CircuitBreak`.
const MAX_CONSECUTIVE_TRUNCATION: u32 = 5;
/// Fraction of a limit at which the guard starts returning `Warn`.
/// Applied to the iteration limit and to the token budget.
const WARN_ITERATION_FRACTION: f64 = 0.8;
45
46impl LoopGuard {
47    /// Create a new guard from resource quotas.
48    pub fn new(max_iterations: u32, max_tool_calls: u32, max_cost_usd: f64) -> Self {
49        Self {
50            max_iterations,
51            current_iteration: 0,
52            max_tool_calls,
53            total_tool_calls: 0,
54            tool_call_counts: HashMap::new(),
55            consecutive_max_tokens: 0,
56            usage: TokenUsage::default(),
57            max_cost_usd,
58            accumulated_cost_usd: 0.0,
59            max_tokens_budget: None,
60        }
61    }
62
63    /// Set the token budget (input+output cumulative limit).
64    pub fn with_token_budget(mut self, budget: Option<u64>) -> Self {
65        self.max_tokens_budget = budget;
66        self
67    }
68
69    /// Check if another iteration is allowed.
70    pub fn check_iteration(&mut self) -> LoopVerdict {
71        self.current_iteration += 1;
72
73        if self.current_iteration > self.max_iterations {
74            return LoopVerdict::CircuitBreak(format!(
75                "max iterations reached ({})",
76                self.max_iterations
77            ));
78        }
79
80        // Precision loss acceptable: max_iterations is small enough for f64
81        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss, clippy::cast_lossless)]
82        let threshold = (self.max_iterations as f64 * WARN_ITERATION_FRACTION) as u32;
83        if self.current_iteration >= threshold {
84            return LoopVerdict::Warn(format!(
85                "iteration {}/{} ({}% of budget)",
86                self.current_iteration,
87                self.max_iterations,
88                self.current_iteration * 100 / self.max_iterations
89            ));
90        }
91
92        LoopVerdict::Allow
93    }
94
95    /// Check if a tool call is allowed (ping-pong detection).
96    pub fn check_tool_call(&mut self, tool_name: &str, input: &serde_json::Value) -> LoopVerdict {
97        self.total_tool_calls += 1;
98
99        if self.total_tool_calls > self.max_tool_calls {
100            return LoopVerdict::CircuitBreak(format!(
101                "max tool calls reached ({})",
102                self.max_tool_calls
103            ));
104        }
105
106        let hash = fx_hash_tool_call(tool_name, input);
107        let count = self.tool_call_counts.entry(hash).or_insert(0);
108        *count += 1;
109
110        if *count >= PINGPONG_THRESHOLD {
111            return LoopVerdict::Block(format!(
112                "ping-pong detected: tool '{tool_name}' called \
113                 {count} times with same input"
114            ));
115        }
116
117        LoopVerdict::Allow
118    }
119
120    /// Record a `MaxTokens` stop reason. Returns `CircuitBreak` if
121    /// consecutive truncations exceed threshold.
122    pub fn record_max_tokens(&mut self) -> LoopVerdict {
123        self.consecutive_max_tokens += 1;
124        if self.consecutive_max_tokens >= MAX_CONSECUTIVE_TRUNCATION {
125            LoopVerdict::CircuitBreak(format!(
126                "{MAX_CONSECUTIVE_TRUNCATION} consecutive MaxTokens responses"
127            ))
128        } else {
129            LoopVerdict::Allow
130        }
131    }
132
133    /// Reset consecutive `MaxTokens` counter (on `EndTurn` or `ToolUse`).
134    pub fn reset_max_tokens(&mut self) {
135        self.consecutive_max_tokens = 0;
136    }
137
138    /// Record token usage from a completion. Returns verdict
139    /// based on token budget (if configured).
140    pub fn record_usage(&mut self, usage: &TokenUsage) -> LoopVerdict {
141        self.usage.accumulate(usage);
142        self.check_token_budget()
143    }
144
145    /// Check cumulative token usage against budget.
146    fn check_token_budget(&self) -> LoopVerdict {
147        let Some(budget) = self.max_tokens_budget else {
148            return LoopVerdict::Allow;
149        };
150        let total = self.usage.input_tokens + self.usage.output_tokens;
151        if total > budget {
152            return LoopVerdict::CircuitBreak(format!(
153                "token budget exhausted: {total} > {budget}"
154            ));
155        }
156        let threshold = (budget as f64 * WARN_ITERATION_FRACTION) as u64;
157        if total >= threshold {
158            return LoopVerdict::Warn(format!(
159                "token usage {total}/{budget} ({}% of budget)",
160                total * 100 / budget
161            ));
162        }
163        LoopVerdict::Allow
164    }
165
166    /// Record estimated cost and check budget.
167    #[cfg_attr(
168        feature = "agents-contracts",
169        provable_contracts_macros::contract("agent-loop-v1", equation = "guard_budget")
170    )]
171    pub fn record_cost(&mut self, cost_usd: f64) -> LoopVerdict {
172        self.accumulated_cost_usd += cost_usd;
173        if self.max_cost_usd > 0.0 && self.accumulated_cost_usd > self.max_cost_usd {
174            LoopVerdict::CircuitBreak(format!(
175                "cost budget exceeded: ${:.4} > ${:.4}",
176                self.accumulated_cost_usd, self.max_cost_usd
177            ))
178        } else {
179            LoopVerdict::Allow
180        }
181    }
182
183    /// Get accumulated usage.
184    pub fn usage(&self) -> &TokenUsage {
185        &self.usage
186    }
187
188    /// Get current iteration count.
189    pub fn current_iteration(&self) -> u32 {
190        self.current_iteration
191    }
192
193    /// Get total tool calls made.
194    pub fn total_tool_calls(&self) -> u32 {
195        self.total_tool_calls
196    }
197}
198
199/// `FxHash` a tool call for ping-pong detection.
200///
201/// Uses a simple multiplicative hash (non-cryptographic) — we only
202/// need collision resistance across ~50 values, not security.
203fn fx_hash_tool_call(tool_name: &str, input: &serde_json::Value) -> u64 {
204    let mut hasher = FxHasher::default();
205    tool_name.hash(&mut hasher);
206    let input_str = input.to_string();
207    input_str.hash(&mut hasher);
208    hasher.finish()
209}
210
/// Multiplier applied after each byte is mixed into the state.
const FX_SEED: u64 = 0x517c_c1b7_2722_0a95;

/// Small Fx-style hasher kept in-tree to avoid an external dependency.
///
/// The state starts at zero (via `Default`) and is updated one byte at a time.
#[derive(Default)]
struct FxHasher {
    hash: u64,
}

impl Hasher for FxHasher {
    /// Mix every byte into the state: rotate left by 5, XOR the byte in,
    /// then multiply by `FX_SEED` with wrapping arithmetic.
    fn write(&mut self, bytes: &[u8]) {
        self.hash = bytes.iter().fold(self.hash, |state, &byte| {
            (state.rotate_left(5) ^ u64::from(byte)).wrapping_mul(FX_SEED)
        });
    }

    /// Return the current state unchanged.
    fn finish(&self) -> u64 {
        self.hash
    }
}
230
// Unit tests are kept out of line in `guard_tests.rs` and are compiled only
// for test builds.
#[cfg(test)]
#[path = "guard_tests.rs"]
mod tests;