use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use std::time::{Duration, Instant};
/// Shared, thread-safe iteration budget.
///
/// The counter lives behind an [`Arc`], so every clone is a handle to the
/// same value: one holder can [`extend`](Self::extend) the budget and every
/// other holder observes the new value on its next [`get`](Self::get).
#[derive(Clone, Debug)]
pub struct IterationBudget(Arc<AtomicU32>);

impl IterationBudget {
    /// Creates a budget initialised to `max` iterations.
    pub fn new(max: u32) -> Self {
        Self(Arc::new(AtomicU32::new(max)))
    }

    /// Returns the current budget value.
    pub fn get(&self) -> u32 {
        // Relaxed is sufficient: the counter does not publish any other data.
        self.0.load(Ordering::Relaxed)
    }

    /// Raises the budget by `extra` iterations, saturating at `u32::MAX`.
    ///
    /// `AtomicU32::fetch_add` wraps on overflow, which would silently shrink
    /// the budget; the compare-and-swap loop below clamps instead.
    pub fn extend(&self, extra: u32) {
        // The closure always returns `Some`, so `fetch_update` cannot fail and
        // the returned previous value is not needed.
        let _ = self
            .0
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                Some(current.saturating_add(extra))
            });
    }
}
/// Per-task safety limits for an agent loop: an iteration cap, a
/// consecutive-tool-failure circuit breaker, and a wall-clock task timeout.
///
/// Also remembers which one-shot "approaching limit" warnings were already
/// issued so each is emitted only once per task.
pub struct AgentGuard {
    // Iteration limit.
    /// Number of iterations consumed so far.
    iteration: u32,
    /// Current iteration cap (overwritten from `budget` when one is attached).
    max_iterations: u32,
    /// Optional shared budget whose value replaces `max_iterations`.
    budget: Option<IterationBudget>,

    // Circuit breaker.
    /// Tool failures seen since the last success or breaker reset.
    consecutive_failures: u32,
    /// Consecutive failures at which the breaker trips.
    failure_threshold: u32,
    /// How many times the breaker has already been reset for this task.
    circuit_breaker_recoveries: u32,

    // Timing.
    /// Per-tool timeout handed out to callers.
    tool_timeout: Duration,
    /// Moment the current task started.
    task_start: Instant,
    /// Wall-clock limit for the whole task.
    task_timeout: Duration,

    // One-shot warning flags.
    /// Set once the approaching-timeout warning has been issued.
    approaching_notified: bool,
    /// Set once the approaching-iteration-limit warning has been issued.
    approaching_iteration_notified: bool,
}
/// Why an agent task was stopped.
#[derive(Clone, Debug)]
pub enum StopReason {
    /// The task ran past its wall-clock limit.
    TaskTimeout { elapsed_secs: u64 },
    /// The task used up its iteration cap.
    IterationLimit { count: u32 },
    /// Too many consecutive tool failures.
    CircuitBreaker { failures: u32 },
    /// The user cancelled the task.
    Cancelled,
}

impl std::fmt::Display for StopReason {
    /// Renders a short, human-readable description of the stop reason.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::TaskTimeout { elapsed_secs } => {
                write!(f, "Task timeout after {}s", elapsed_secs)
            }
            Self::IterationLimit { count } => {
                write!(f, "Iteration limit reached ({})", count)
            }
            Self::CircuitBreaker { failures } => {
                write!(f, "Circuit breaker: {} consecutive tool failures", failures)
            }
            Self::Cancelled => f.write_str("Cancelled by user"),
        }
    }
}

impl StopReason {
    /// Returns `true` for stops caused by resource limits (time or
    /// iterations), and `false` for circuit-breaker trips and user
    /// cancellation.
    pub fn is_continuable(&self) -> bool {
        match self {
            Self::TaskTimeout { .. } | Self::IterationLimit { .. } => true,
            Self::CircuitBreaker { .. } | Self::Cancelled => false,
        }
    }
}
/// Outcome of a single [`AgentGuard::check`] call.
#[derive(Clone, Debug)]
pub enum GuardVerdict {
    /// No limit is near; continue normally.
    Proceed,
    /// One-time warning: the task timeout is close.
    ApproachingTimeout { remaining: Duration },
    /// One-time warning: the iteration cap is close.
    ApproachingIterationLimit { remaining: u32 },
    /// The circuit breaker tripped but was reset, allowing another attempt.
    CircuitBreakerRecovery { failures: u32 },
    /// The iteration cap has been reached.
    IterationLimit { count: u32 },
    /// The circuit breaker tripped and no resets remain.
    CircuitBreaker { failures: u32 },
    /// The task ran past its wall-clock limit.
    TaskTimeout { elapsed: Duration },
}
/// Remaining wall-clock seconds at or below which a one-time
/// `ApproachingTimeout` warning is issued.
const APPROACHING_TIMEOUT_SECS: u64 = 60;
/// Fraction of the iteration cap; once the remaining iterations drop to or
/// below this share, a one-time `ApproachingIterationLimit` warning is issued.
const APPROACHING_ITERATION_RATIO: f32 = 0.2;
/// Number of times a tripped circuit breaker may be reset before the task
/// is stopped for good.
const MAX_CIRCUIT_BREAKER_RECOVERIES: u32 = 1;
impl AgentGuard {
pub fn new(
max_iterations: u32,
failure_threshold: u32,
tool_timeout_secs: u64,
task_timeout_secs: u64,
) -> Self {
Self {
max_iterations,
budget: None,
iteration: 0,
consecutive_failures: 0,
failure_threshold,
tool_timeout: Duration::from_secs(tool_timeout_secs),
task_start: Instant::now(),
task_timeout: Duration::from_secs(task_timeout_secs),
approaching_notified: false,
approaching_iteration_notified: false,
circuit_breaker_recoveries: 0,
}
}
pub fn with_budget(mut self, budget: IterationBudget) -> Self {
self.budget = Some(budget);
self
}
pub fn budget(&self) -> Option<&IterationBudget> {
self.budget.as_ref()
}
pub fn check(&mut self) -> GuardVerdict {
if let Some(ref budget) = self.budget {
self.max_iterations = budget.get();
}
if self.iteration >= self.max_iterations {
return GuardVerdict::IterationLimit {
count: self.iteration,
};
}
if self.consecutive_failures >= self.failure_threshold {
if self.circuit_breaker_recoveries < MAX_CIRCUIT_BREAKER_RECOVERIES {
self.circuit_breaker_recoveries += 1;
self.consecutive_failures = 0;
return GuardVerdict::CircuitBreakerRecovery {
failures: self.failure_threshold,
};
}
return GuardVerdict::CircuitBreaker {
failures: self.consecutive_failures,
};
}
let elapsed = self.task_start.elapsed();
if elapsed > self.task_timeout {
return GuardVerdict::TaskTimeout { elapsed };
}
let remaining_time = self.task_timeout.saturating_sub(elapsed);
if remaining_time <= Duration::from_secs(APPROACHING_TIMEOUT_SECS)
&& !self.approaching_notified
{
self.approaching_notified = true;
return GuardVerdict::ApproachingTimeout {
remaining: remaining_time,
};
}
let approaching_threshold =
((self.max_iterations as f32) * APPROACHING_ITERATION_RATIO) as u32;
let remaining_iters = self.max_iterations.saturating_sub(self.iteration);
if remaining_iters > 0
&& remaining_iters <= approaching_threshold
&& !self.approaching_iteration_notified
{
self.approaching_iteration_notified = true;
return GuardVerdict::ApproachingIterationLimit {
remaining: remaining_iters,
};
}
GuardVerdict::Proceed
}
pub fn tick(&mut self) {
self.iteration += 1;
}
pub fn record_success(&mut self) {
self.consecutive_failures = 0;
}
pub fn record_failure(&mut self) {
self.consecutive_failures += 1;
}
pub fn reset(&mut self) {
self.iteration = 0;
self.consecutive_failures = 0;
self.task_start = Instant::now();
self.approaching_notified = false;
self.approaching_iteration_notified = false;
self.circuit_breaker_recoveries = 0;
}
pub fn tool_timeout(&self) -> Duration {
self.tool_timeout
}
pub fn iteration(&self) -> u32 {
self.iteration
}
pub fn max_iterations(&self) -> u32 {
self.max_iterations
}
pub fn elapsed(&self) -> Duration {
self.task_start.elapsed()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a guard with the 120 s tool timeout most tests share.
    fn make_guard(max_iterations: u32, failure_threshold: u32, task_timeout_secs: u64) -> AgentGuard {
        AgentGuard::new(max_iterations, failure_threshold, 120, task_timeout_secs)
    }

    #[test]
    fn test_guard_proceed() {
        let mut g = make_guard(25, 3, 600);
        assert!(matches!(g.check(), GuardVerdict::Proceed));
    }

    #[test]
    fn test_guard_iteration_limit() {
        let mut g = make_guard(3, 3, 600);
        for _ in 0..3 {
            g.tick();
        }
        let verdict = g.check();
        assert!(matches!(verdict, GuardVerdict::IterationLimit { count: 3 }));
    }

    #[test]
    fn test_guard_circuit_breaker() {
        let mut g = make_guard(25, 3, 600);
        // First trip: the breaker is reset once.
        for _ in 0..3 {
            g.record_failure();
        }
        assert!(matches!(g.check(), GuardVerdict::CircuitBreakerRecovery { .. }));
        // Second trip: no recoveries left.
        for _ in 0..3 {
            g.record_failure();
        }
        assert!(matches!(g.check(), GuardVerdict::CircuitBreaker { .. }));
    }

    #[test]
    fn test_guard_success_resets_failures() {
        let mut g = make_guard(25, 3, 600);
        for _ in 0..2 {
            g.record_failure();
        }
        g.record_success();
        assert!(matches!(g.check(), GuardVerdict::Proceed));
    }

    #[test]
    fn test_guard_reset() {
        let mut g = make_guard(3, 3, 600);
        for _ in 0..3 {
            g.tick();
        }
        assert!(matches!(g.check(), GuardVerdict::IterationLimit { .. }));
        g.reset();
        assert!(matches!(g.check(), GuardVerdict::Proceed));
        assert_eq!(g.iteration(), 0);
    }

    #[test]
    fn test_guard_tool_timeout() {
        let g = AgentGuard::new(25, 3, 60, 600);
        assert_eq!(g.tool_timeout(), Duration::from_secs(60));
    }

    #[test]
    fn test_guard_approaching_timeout() {
        // A 2 s task limit is already inside the warning window.
        let mut g = make_guard(25, 3, 2);
        let first = g.check();
        assert!(
            matches!(first, GuardVerdict::ApproachingTimeout { .. }),
            "Expected ApproachingTimeout, got {:?}",
            first
        );
        let second = g.check();
        assert!(
            matches!(second, GuardVerdict::Proceed),
            "Expected Proceed after notification, got {:?}",
            second
        );
    }

    #[test]
    fn test_guard_configurable_task_timeout() {
        let g = make_guard(25, 3, 1200);
        assert_eq!(g.task_timeout, Duration::from_secs(1200));
    }

    #[test]
    fn test_stop_reason_continuable() {
        let continuable = [
            StopReason::TaskTimeout { elapsed_secs: 600 },
            StopReason::IterationLimit { count: 50 },
        ];
        let terminal = [StopReason::CircuitBreaker { failures: 3 }, StopReason::Cancelled];
        assert!(continuable.iter().all(StopReason::is_continuable));
        assert!(!terminal.iter().any(StopReason::is_continuable));
    }
}