use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use super::result::TokenUsage;
/// Safety limits for an agent loop.
///
/// Tracks iteration count, total tool calls, repeated identical tool calls, consecutive
/// truncated (`MaxTokens`) responses, accumulated token usage and accumulated cost, and
/// turns each into a [`LoopVerdict`] via the methods on this type.
pub struct LoopGuard {
    /// Iteration limit enforced by `check_iteration`.
    max_iterations: u32,
    /// Iterations counted so far by `check_iteration`.
    current_iteration: u32,
    /// Limit on the total number of tool calls enforced by `check_tool_call`.
    max_tool_calls: u32,
    /// Tool calls counted so far by `check_tool_call`.
    total_tool_calls: u32,
    /// How many times each tool call has been seen, keyed by `fx_hash_tool_call(name, input)`.
    tool_call_counts: HashMap<u64, u32>,
    /// Consecutive truncated responses recorded since the last `reset_max_tokens`.
    consecutive_max_tokens: u32,
    /// Token usage accumulated through `record_usage`.
    usage: TokenUsage,
    /// Cost ceiling in USD checked by `record_cost`; a non-positive value disables it.
    max_cost_usd: f64,
    /// Cost accumulated through `record_cost`.
    accumulated_cost_usd: f64,
    /// Optional ceiling on input + output tokens; `None` disables the token check.
    max_tokens_budget: Option<u64>,
}
/// Result of a single `LoopGuard` check.
///
/// Every non-`Allow` variant carries a human-readable explanation.
#[derive(Clone, Debug, PartialEq)]
pub enum LoopVerdict {
    /// No limit is close; carry on.
    Allow,
    /// Iteration count or token usage has reached the warning fraction of its limit.
    Warn(String),
    /// The same tool call (name and input) has been repeated too many times.
    Block(String),
    /// A hard limit was exceeded: iterations, tool calls, token budget, cost budget,
    /// or consecutive truncated responses.
    CircuitBreak(String),
}
/// Occurrence count at which an identical tool call (same name and input) is blocked:
/// the 3rd identical call, and every later one, yields `LoopVerdict::Block`.
const PINGPONG_THRESHOLD: u32 = 3;
/// Number of consecutive truncated responses after which `record_max_tokens`
/// returns `LoopVerdict::CircuitBreak`.
const MAX_CONSECUTIVE_TRUNCATION: u32 = 5;
/// Fraction of a limit at which a `LoopVerdict::Warn` is issued.
///
/// Despite the name, this is shared by the iteration limit and the token budget.
const WARN_ITERATION_FRACTION: f64 = 0.8;
impl LoopGuard {
    /// Creates a guard with nothing recorded yet.
    ///
    /// * `max_iterations` — `check_iteration` returns `CircuitBreak` once exceeded.
    /// * `max_tool_calls` — `check_tool_call` returns `CircuitBreak` once exceeded.
    /// * `max_cost_usd` — `record_cost` returns `CircuitBreak` once the accumulated cost
    ///   exceeds it; a value `<= 0.0` disables the cost check.
    ///
    /// No token budget is set; chain [`LoopGuard::with_token_budget`] to add one.
    pub fn new(max_iterations: u32, max_tool_calls: u32, max_cost_usd: f64) -> Self {
        Self {
            max_iterations,
            current_iteration: 0,
            max_tool_calls,
            total_tool_calls: 0,
            tool_call_counts: HashMap::new(),
            consecutive_max_tokens: 0,
            usage: TokenUsage::default(),
            max_cost_usd,
            accumulated_cost_usd: 0.0,
            max_tokens_budget: None,
        }
    }

    /// Sets the input + output token budget checked by `record_usage`.
    ///
    /// Passing `None` disables the token check.
    pub fn with_token_budget(mut self, budget: Option<u64>) -> Self {
        self.max_tokens_budget = budget;
        self
    }

    /// Counts one loop iteration and checks it against `max_iterations`.
    ///
    /// Returns `CircuitBreak` once the count exceeds `max_iterations`. Returns `Warn` once it
    /// reaches `WARN_ITERATION_FRACTION` of the limit (rounded down). Otherwise returns
    /// `Allow`.
    pub fn check_iteration(&mut self) -> LoopVerdict {
        // Saturate: a caller that keeps looping after a CircuitBreak must not overflow.
        self.current_iteration = self.current_iteration.saturating_add(1);
        if self.current_iteration > self.max_iterations {
            return LoopVerdict::CircuitBreak(format!(
                "max iterations reached ({})",
                self.max_iterations
            ));
        }
        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss, clippy::cast_lossless)]
        let threshold = (self.max_iterations as f64 * WARN_ITERATION_FRACTION) as u32;
        if self.current_iteration >= threshold {
            return LoopVerdict::Warn(format!(
                "iteration {}/{} ({}% of budget)",
                self.current_iteration,
                self.max_iterations,
                Self::percent_of(
                    u64::from(self.current_iteration),
                    u64::from(self.max_iterations)
                )
            ));
        }
        LoopVerdict::Allow
    }

    /// Counts one tool call, then checks the total-call limit and repeat detection.
    ///
    /// The verdict is chosen as follows:
    /// * `CircuitBreak` — the total call count exceeds `max_tool_calls`. The total is
    ///   incremented before this check, so rejected calls still count.
    /// * `Block` — this exact `tool_name` + `input` pair has now been seen
    ///   `PINGPONG_THRESHOLD` or more times. Pairs are identified by
    ///   `fx_hash_tool_call`. Repeats are counted over the guard's whole lifetime, not
    ///   just consecutively; nothing in this type resets them.
    /// * `Allow` — neither of the above.
    pub fn check_tool_call(&mut self, tool_name: &str, input: &serde_json::Value) -> LoopVerdict {
        self.total_tool_calls = self.total_tool_calls.saturating_add(1);
        if self.total_tool_calls > self.max_tool_calls {
            return LoopVerdict::CircuitBreak(format!(
                "max tool calls reached ({})",
                self.max_tool_calls
            ));
        }
        let hash = fx_hash_tool_call(tool_name, input);
        let slot = self.tool_call_counts.entry(hash).or_default();
        *slot = slot.saturating_add(1);
        let count = *slot;
        if count >= PINGPONG_THRESHOLD {
            return LoopVerdict::Block(format!(
                "ping-pong detected: tool '{tool_name}' called \
                 {count} times with same input"
            ));
        }
        LoopVerdict::Allow
    }

    /// Records one truncated (`MaxTokens`) response.
    ///
    /// Returns `CircuitBreak` once `MAX_CONSECUTIVE_TRUNCATION` have been recorded without an
    /// intervening `reset_max_tokens`; otherwise returns `Allow`.
    pub fn record_max_tokens(&mut self) -> LoopVerdict {
        self.consecutive_max_tokens = self.consecutive_max_tokens.saturating_add(1);
        if self.consecutive_max_tokens >= MAX_CONSECUTIVE_TRUNCATION {
            LoopVerdict::CircuitBreak(format!(
                "{MAX_CONSECUTIVE_TRUNCATION} consecutive MaxTokens responses"
            ))
        } else {
            LoopVerdict::Allow
        }
    }

    /// Clears the consecutive-truncation counter used by `record_max_tokens`.
    pub fn reset_max_tokens(&mut self) {
        self.consecutive_max_tokens = 0;
    }

    /// Adds `usage` to the running total (via `TokenUsage::accumulate`) and checks it
    /// against the token budget; see `check_token_budget`.
    pub fn record_usage(&mut self, usage: &TokenUsage) -> LoopVerdict {
        self.usage.accumulate(usage);
        self.check_token_budget()
    }

    /// Compares accumulated input + output tokens with the optional budget.
    ///
    /// Returns `Allow` when no budget is set. Returns `CircuitBreak` when the total exceeds
    /// the budget. Returns `Warn` at or above `WARN_ITERATION_FRACTION` of it. Otherwise
    /// returns `Allow`.
    fn check_token_budget(&self) -> LoopVerdict {
        let Some(budget) = self.max_tokens_budget else {
            return LoopVerdict::Allow;
        };
        let total = self.usage.input_tokens.saturating_add(self.usage.output_tokens);
        if total > budget {
            return LoopVerdict::CircuitBreak(format!(
                "token budget exhausted: {total} > {budget}"
            ));
        }
        #[allow(
            clippy::cast_possible_truncation,
            clippy::cast_sign_loss,
            clippy::cast_precision_loss
        )]
        let threshold = (budget as f64 * WARN_ITERATION_FRACTION) as u64;
        if total >= threshold {
            // `budget` may be 0 here (0 tokens used against a zero budget); `percent_of`
            // guards the division that previously panicked in that case.
            return LoopVerdict::Warn(format!(
                "token usage {total}/{budget} ({}% of budget)",
                Self::percent_of(total, budget)
            ));
        }
        LoopVerdict::Allow
    }

    /// Adds `cost_usd` to the accumulated cost and checks it against `max_cost_usd`.
    ///
    /// Returns `CircuitBreak` when `max_cost_usd > 0.0` and the accumulated cost exceeds it;
    /// otherwise returns `Allow`. A NaN cost is ignored rather than accumulated.
    #[cfg_attr(
        feature = "agents-contracts",
        provable_contracts_macros::contract("agent-loop-v1", equation = "guard_budget")
    )]
    pub fn record_cost(&mut self, cost_usd: f64) -> LoopVerdict {
        // Adding NaN would make the total NaN, after which `total > max` is always false and
        // the cost budget could never trip again.
        if !cost_usd.is_nan() {
            self.accumulated_cost_usd += cost_usd;
        }
        if self.max_cost_usd > 0.0 && self.accumulated_cost_usd > self.max_cost_usd {
            LoopVerdict::CircuitBreak(format!(
                "cost budget exceeded: ${:.4} > ${:.4}",
                self.accumulated_cost_usd, self.max_cost_usd
            ))
        } else {
            LoopVerdict::Allow
        }
    }

    /// Token usage accumulated so far through `record_usage`.
    pub fn usage(&self) -> &TokenUsage {
        &self.usage
    }

    /// Number of iterations counted by `check_iteration`.
    pub fn current_iteration(&self) -> u32 {
        self.current_iteration
    }

    /// Number of tool calls counted by `check_tool_call`, including rejected ones.
    pub fn total_tool_calls(&self) -> u32 {
        self.total_tool_calls
    }

    /// Integer percentage `part * 100 / whole`, rounded down.
    ///
    /// The calculation is widened to `u128` so the multiplication cannot overflow. A zero
    /// `whole` is reported as 100% instead of dividing by zero.
    fn percent_of(part: u64, whole: u64) -> u128 {
        if whole == 0 {
            100
        } else {
            u128::from(part) * 100 / u128::from(whole)
        }
    }
}
/// Reduces a tool invocation to a 64-bit key for repeat detection.
///
/// The tool name and `input.to_string()` are fed, in that order, into a fresh
/// [`FxHasher`]. Identical name/serialization pairs therefore produce identical keys;
/// distinct pairs may (rarely) collide.
fn fx_hash_tool_call(tool_name: &str, input: &serde_json::Value) -> u64 {
    let serialized = input.to_string();
    let mut state = FxHasher::default();
    // Hashing a tuple hashes each element in order, i.e. name first, then input.
    (tool_name, serialized.as_str()).hash(&mut state);
    state.finish()
}
/// Small non-cryptographic hasher that mixes input one byte at a time, using the
/// FxHash multiplier. Used by `fx_hash_tool_call`.
#[derive(Default)]
struct FxHasher {
    /// Running hash state; starts at zero.
    hash: u64,
}
/// Multiplier applied after each byte is mixed into the state.
const FX_SEED: u64 = 0x517C_C1B7_2722_0A95;
impl Hasher for FxHasher {
    /// Returns the current state; reading it does not alter the hasher.
    fn finish(&self) -> u64 {
        self.hash
    }

    /// Folds each byte into the state in order. For each byte, the state is rotated left
    /// by 5, XORed with the byte, then multiplied (wrapping) by `FX_SEED`.
    fn write(&mut self, bytes: &[u8]) {
        self.hash = bytes.iter().fold(self.hash, |state, &b| {
            (state.rotate_left(5) ^ u64::from(b)).wrapping_mul(FX_SEED)
        });
    }
}
#[cfg(test)]
#[path = "guard_tests.rs"]
mod tests;