1use std::collections::HashMap;
8use std::hash::{Hash, Hasher};
9
10use super::result::TokenUsage;
11
12pub struct LoopGuard {
14 max_iterations: u32,
15 current_iteration: u32,
16 max_tool_calls: u32,
17 total_tool_calls: u32,
18 tool_call_counts: HashMap<u64, u32>, consecutive_max_tokens: u32,
20 usage: TokenUsage,
21 max_cost_usd: f64, accumulated_cost_usd: f64,
23 max_tokens_budget: Option<u64>, }
25
/// Result of a single [`LoopGuard`] check.
///
/// Every non-`Allow` variant carries a human-readable description of which
/// limit was involved.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LoopVerdict {
    /// No threshold was reached.
    Allow,
    /// A warning threshold (a fraction of some budget) was reached.
    Warn(String),
    /// The specific action was rejected; produced for repeated identical tool
    /// calls.
    Block(String),
    /// A hard limit was exceeded (iterations, tool calls, consecutive
    /// truncations, tokens, or cost).
    CircuitBreak(String),
}
38
/// Number of times the same `(tool name, input)` pair may be seen before
/// `check_tool_call` starts returning `Block`.
const PINGPONG_THRESHOLD: u32 = 3;

/// Length of a `MaxTokens` streak that trips the circuit breaker.
const MAX_CONSECUTIVE_TRUNCATION: u32 = 5;

/// Fraction of a budget at which a `Warn` verdict is produced.
///
/// Despite the name, it is shared by the iteration check and the token-budget
/// check.
const WARN_ITERATION_FRACTION: f64 = 0.8;
45
46impl LoopGuard {
47 pub fn new(max_iterations: u32, max_tool_calls: u32, max_cost_usd: f64) -> Self {
49 Self {
50 max_iterations,
51 current_iteration: 0,
52 max_tool_calls,
53 total_tool_calls: 0,
54 tool_call_counts: HashMap::new(),
55 consecutive_max_tokens: 0,
56 usage: TokenUsage::default(),
57 max_cost_usd,
58 accumulated_cost_usd: 0.0,
59 max_tokens_budget: None,
60 }
61 }
62
63 pub fn with_token_budget(mut self, budget: Option<u64>) -> Self {
65 self.max_tokens_budget = budget;
66 self
67 }
68
69 pub fn check_iteration(&mut self) -> LoopVerdict {
71 self.current_iteration += 1;
72
73 if self.current_iteration > self.max_iterations {
74 return LoopVerdict::CircuitBreak(format!(
75 "max iterations reached ({})",
76 self.max_iterations
77 ));
78 }
79
80 #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss, clippy::cast_lossless)]
82 let threshold = (self.max_iterations as f64 * WARN_ITERATION_FRACTION) as u32;
83 if self.current_iteration >= threshold {
84 return LoopVerdict::Warn(format!(
85 "iteration {}/{} ({}% of budget)",
86 self.current_iteration,
87 self.max_iterations,
88 self.current_iteration * 100 / self.max_iterations
89 ));
90 }
91
92 LoopVerdict::Allow
93 }
94
95 pub fn check_tool_call(&mut self, tool_name: &str, input: &serde_json::Value) -> LoopVerdict {
97 self.total_tool_calls += 1;
98
99 if self.total_tool_calls > self.max_tool_calls {
100 return LoopVerdict::CircuitBreak(format!(
101 "max tool calls reached ({})",
102 self.max_tool_calls
103 ));
104 }
105
106 let hash = fx_hash_tool_call(tool_name, input);
107 let count = self.tool_call_counts.entry(hash).or_insert(0);
108 *count += 1;
109
110 if *count >= PINGPONG_THRESHOLD {
111 return LoopVerdict::Block(format!(
112 "ping-pong detected: tool '{tool_name}' called \
113 {count} times with same input"
114 ));
115 }
116
117 LoopVerdict::Allow
118 }
119
120 pub fn record_max_tokens(&mut self) -> LoopVerdict {
123 self.consecutive_max_tokens += 1;
124 if self.consecutive_max_tokens >= MAX_CONSECUTIVE_TRUNCATION {
125 LoopVerdict::CircuitBreak(format!(
126 "{MAX_CONSECUTIVE_TRUNCATION} consecutive MaxTokens responses"
127 ))
128 } else {
129 LoopVerdict::Allow
130 }
131 }
132
133 pub fn reset_max_tokens(&mut self) {
135 self.consecutive_max_tokens = 0;
136 }
137
138 pub fn record_usage(&mut self, usage: &TokenUsage) -> LoopVerdict {
141 self.usage.accumulate(usage);
142 self.check_token_budget()
143 }
144
145 fn check_token_budget(&self) -> LoopVerdict {
147 let Some(budget) = self.max_tokens_budget else {
148 return LoopVerdict::Allow;
149 };
150 let total = self.usage.input_tokens + self.usage.output_tokens;
151 if total > budget {
152 return LoopVerdict::CircuitBreak(format!(
153 "token budget exhausted: {total} > {budget}"
154 ));
155 }
156 let threshold = (budget as f64 * WARN_ITERATION_FRACTION) as u64;
157 if total >= threshold {
158 return LoopVerdict::Warn(format!(
159 "token usage {total}/{budget} ({}% of budget)",
160 total * 100 / budget
161 ));
162 }
163 LoopVerdict::Allow
164 }
165
166 #[cfg_attr(
168 feature = "agents-contracts",
169 provable_contracts_macros::contract("agent-loop-v1", equation = "guard_budget")
170 )]
171 pub fn record_cost(&mut self, cost_usd: f64) -> LoopVerdict {
172 self.accumulated_cost_usd += cost_usd;
173 if self.max_cost_usd > 0.0 && self.accumulated_cost_usd > self.max_cost_usd {
174 LoopVerdict::CircuitBreak(format!(
175 "cost budget exceeded: ${:.4} > ${:.4}",
176 self.accumulated_cost_usd, self.max_cost_usd
177 ))
178 } else {
179 LoopVerdict::Allow
180 }
181 }
182
183 pub fn usage(&self) -> &TokenUsage {
185 &self.usage
186 }
187
188 pub fn current_iteration(&self) -> u32 {
190 self.current_iteration
191 }
192
193 pub fn total_tool_calls(&self) -> u32 {
195 self.total_tool_calls
196 }
197}
198
199fn fx_hash_tool_call(tool_name: &str, input: &serde_json::Value) -> u64 {
204 let mut hasher = FxHasher::default();
205 tool_name.hash(&mut hasher);
206 let input_str = input.to_string();
207 input_str.hash(&mut hasher);
208 hasher.finish()
209}
210
/// Small Fx-style hasher that mixes input one byte at a time with a
/// rotate/xor/multiply step (see its `Hasher` impl).
///
/// Not designed to resist adversarial collisions; here it only keys the
/// tool-call repeat counters.
#[derive(Debug, Default, Clone)]
struct FxHasher {
    /// Current hash state; `Default` starts it at zero.
    hash: u64,
}
216
/// Multiplier applied in each [`FxHasher`] mixing step; the same 64-bit
/// constant rustc's FxHash uses.
const FX_SEED: u64 = 0x51_7c_c1_b7_27_22_0a_95;
218
219impl Hasher for FxHasher {
220 fn finish(&self) -> u64 {
221 self.hash
222 }
223
224 fn write(&mut self, bytes: &[u8]) {
225 for &byte in bytes {
226 self.hash = (self.hash.rotate_left(5) ^ u64::from(byte)).wrapping_mul(FX_SEED);
227 }
228 }
229}
230
231#[cfg(test)]
232#[path = "guard_tests.rs"]
233mod tests;