use crate::error::{AgentError, Result};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::RwLock;
use std::time::{Duration, Instant};
use tracing::{debug, warn};
/// Limits that apply across all sessions rather than to any single one.
///
/// Only the values are defined here; nothing in this file enforces them.
#[derive(Debug, Clone)]
pub struct GlobalLimits {
    /// Maximum number of sessions.
    pub max_sessions: usize,
    /// Maximum number of context blocks summed over all sessions.
    pub max_total_context_blocks: usize,
    /// Maximum operation rate, in operations per second.
    pub max_ops_per_second: f64,
    /// Time allowed for a single operation.
    pub operation_timeout: Duration,
}

impl Default for GlobalLimits {
    /// 100 sessions, 100k context blocks, 1000 ops/s, 30 s per operation.
    fn default() -> Self {
        let operation_timeout = Duration::from_secs(30);
        Self {
            operation_timeout,
            max_ops_per_second: 1_000.0,
            max_total_context_blocks: 100_000,
            max_sessions: 100,
        }
    }
}
/// Limits that apply to a single session.
///
/// Only the values are defined here; nothing in this file reads them except
/// through the embedded [`OperationBudget`].
#[derive(Debug, Clone)]
pub struct SessionLimits {
    /// Maximum number of tokens in the session's context.
    pub max_context_tokens: usize,
    /// Maximum number of blocks in the session's context.
    pub max_context_blocks: usize,
    /// Maximum expansion depth (presumably fed to a `DepthGuard` — confirm at call sites).
    pub max_expand_depth: usize,
    /// Maximum number of results a single operation may return.
    pub max_results_per_operation: usize,
    /// Number of operations allowed before a checkpoint is expected.
    pub max_operations_before_checkpoint: usize,
    /// Session lifetime / timeout.
    pub session_timeout: Duration,
    /// Maximum number of retained history entries.
    pub max_history_size: usize,
    /// Per-session operation budget.
    pub budget: OperationBudget,
}

impl Default for SessionLimits {
    fn default() -> Self {
        // Default session timeout: 30 minutes.
        const THIRTY_MINUTES: Duration = Duration::from_secs(30 * 60);

        let budget = OperationBudget::default();
        Self {
            budget,
            max_history_size: 100,
            session_timeout: THIRTY_MINUTES,
            max_operations_before_checkpoint: 1000,
            max_results_per_operation: 100,
            max_expand_depth: 10,
            max_context_blocks: 200,
            max_context_tokens: 8_000,
        }
    }
}
/// Caps on how many operations of each kind a session may perform.
///
/// Compared against the running totals kept by [`BudgetTracker`].
#[derive(Debug, Clone)]
pub struct OperationBudget {
    /// Maximum number of traversal operations.
    pub traversal_operations: usize,
    /// Maximum number of search operations.
    pub search_operations: usize,
    /// Maximum number of blocks that may be read.
    pub blocks_read: usize,
}

impl Default for OperationBudget {
    /// 10k traversals, 100 searches, 50k block reads.
    fn default() -> Self {
        let (traversal_operations, search_operations, blocks_read) = (10_000, 100, 50_000);
        Self {
            traversal_operations,
            search_operations,
            blocks_read,
        }
    }
}
/// Running usage totals checked against an [`OperationBudget`].
///
/// Recording and checking are separate calls. A `check_*` call only compares
/// the current total with the limit; it does not reserve anything.
#[derive(Debug, Default)]
pub struct BudgetTracker {
    /// Traversal operations recorded so far.
    pub traversal_ops_used: AtomicUsize,
    /// Search operations recorded so far.
    pub search_ops_used: AtomicUsize,
    /// Blocks read so far.
    pub blocks_read_used: AtomicUsize,
}

impl BudgetTracker {
    /// Creates a tracker with every counter at zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds one traversal operation to the running total.
    pub fn record_traversal(&self) {
        self.traversal_ops_used.fetch_add(1, Ordering::Relaxed);
    }

    /// Adds one search operation to the running total.
    pub fn record_search(&self) {
        self.search_ops_used.fetch_add(1, Ordering::Relaxed);
    }

    /// Adds `count` blocks to the blocks-read total.
    pub fn record_blocks_read(&self, count: usize) {
        self.blocks_read_used.fetch_add(count, Ordering::Relaxed);
    }

    /// Succeeds while fewer than `budget.traversal_operations` traversals are recorded.
    ///
    /// # Errors
    /// `AgentError::BudgetExhausted` with `operation_type` `"traversal"` otherwise.
    pub fn check_traversal_budget(&self, budget: &OperationBudget) -> Result<()> {
        Self::ensure_within(
            &self.traversal_ops_used,
            budget.traversal_operations,
            "traversal",
        )
    }

    /// Succeeds while fewer than `budget.search_operations` searches are recorded.
    ///
    /// # Errors
    /// `AgentError::BudgetExhausted` with `operation_type` `"search"` otherwise.
    pub fn check_search_budget(&self, budget: &OperationBudget) -> Result<()> {
        Self::ensure_within(&self.search_ops_used, budget.search_operations, "search")
    }

    /// Succeeds while fewer than `budget.blocks_read` blocks are recorded.
    ///
    /// # Errors
    /// `AgentError::BudgetExhausted` with `operation_type` `"blocks_read"` otherwise.
    pub fn check_blocks_budget(&self, budget: &OperationBudget) -> Result<()> {
        Self::ensure_within(&self.blocks_read_used, budget.blocks_read, "blocks_read")
    }

    /// Zeroes every counter. Each counter is stored individually, not as one atomic step.
    pub fn reset(&self) {
        for counter in [
            &self.traversal_ops_used,
            &self.search_ops_used,
            &self.blocks_read_used,
        ] {
            counter.store(0, Ordering::Relaxed);
        }
    }

    /// Shared comparison: `Ok` while `counter < limit`, `BudgetExhausted` once it reaches `limit`.
    fn ensure_within(counter: &AtomicUsize, limit: usize, operation_type: &str) -> Result<()> {
        if counter.load(Ordering::Relaxed) < limit {
            Ok(())
        } else {
            Err(AgentError::BudgetExhausted {
                operation_type: operation_type.to_string(),
            })
        }
    }
}
/// Position of a [`CircuitBreaker`] in its recovery cycle.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitState {
    /// Operations are admitted; consecutive failures are being counted.
    Closed,
    /// Operations are rejected until the recovery timeout elapses.
    Open,
    /// Operations are admitted on probation after the recovery timeout.
    HalfOpen,
}
/// Three-state circuit breaker that stops admitting work after repeated failures.
///
/// * [`CircuitState::Closed`]
///   - [`can_proceed`](Self::can_proceed) always succeeds.
///   - Consecutive failures are counted, and any success clears the count.
///   - Reaching `failure_threshold` trips the breaker to `Open`.
/// * [`CircuitState::Open`]
///   - `can_proceed` fails until `recovery_timeout` has elapsed since the last
///     recorded failure.
///   - The first call after that moves the breaker to `HalfOpen` and succeeds.
/// * [`CircuitState::HalfOpen`]
///   - `can_proceed` succeeds.
///   - Three recorded successes close the breaker; any failure re-opens it.
///
/// Concurrency:
/// - Every state transition happens while holding the `state` write lock, and
///   the lock is re-checked after acquisition.
/// - `last_failure` is only read or written while that write lock is held.
/// - Locks are always taken in the order `state` → `last_failure`.
///
/// Together these make each check-then-transition atomic. A stale decision can
/// therefore never overwrite a fresher transition — for example, a late
/// half-open success cannot close a breaker that a concurrent failure just
/// re-opened.
#[derive(Debug)]
pub struct CircuitBreaker {
    /// Current position in the Closed → Open → HalfOpen cycle.
    state: RwLock<CircuitState>,
    /// Consecutive failures observed while `Closed`.
    failure_count: AtomicUsize,
    /// Consecutive failures that trip a `Closed` breaker to `Open`.
    failure_threshold: usize,
    /// How long the breaker stays `Open` after the last failure before probing.
    recovery_timeout: Duration,
    /// Time of the failure that last opened the breaker, or of the latest failure
    /// recorded while open. `None` after construction or `reset`.
    last_failure: RwLock<Option<Instant>>,
    /// Successes recorded since the breaker last entered `HalfOpen`.
    success_count_in_half_open: AtomicUsize,
    /// Half-open successes required to close the breaker.
    success_threshold: usize,
}

impl CircuitBreaker {
    /// Half-open successes needed before the breaker closes again.
    const DEFAULT_SUCCESS_THRESHOLD: usize = 3;

    /// Creates a closed breaker.
    ///
    /// The breaker opens after `failure_threshold` consecutive failures and
    /// starts probing again once `recovery_timeout` has passed since the last
    /// failure.
    pub fn new(failure_threshold: usize, recovery_timeout: Duration) -> Self {
        Self {
            state: RwLock::new(CircuitState::Closed),
            failure_count: AtomicUsize::new(0),
            failure_threshold,
            recovery_timeout,
            last_failure: RwLock::new(None),
            success_count_in_half_open: AtomicUsize::new(0),
            success_threshold: Self::DEFAULT_SUCCESS_THRESHOLD,
        }
    }

    /// Returns a snapshot of the current state; other threads may change it immediately after.
    ///
    /// # Panics
    /// Panics if the state lock is poisoned.
    pub fn state(&self) -> CircuitState {
        *self.state.read().unwrap()
    }

    /// Checks whether a new operation may run.
    ///
    /// An `Open` breaker whose recovery timeout has elapsed is moved to `HalfOpen`
    /// by this call, and the call succeeds.
    ///
    /// # Errors
    /// Returns `AgentError::CircuitOpen` while the breaker is open and the recovery
    /// timeout has not yet elapsed.
    ///
    /// # Panics
    /// Panics if an internal lock is poisoned.
    pub fn can_proceed(&self) -> Result<()> {
        // Fast path: only an open breaker needs the write lock.
        if self.state() != CircuitState::Open {
            return Ok(());
        }

        let mut state = self.state.write().unwrap();
        // Re-check under the write lock. Another thread may already have moved the
        // breaker to HalfOpen (or reset it) since the peek above.
        if *state != CircuitState::Open {
            return Ok(());
        }

        let last_failure = *self.last_failure.read().unwrap();
        let recovered =
            matches!(last_failure, Some(last) if last.elapsed() >= self.recovery_timeout);
        if !recovered {
            return Err(AgentError::CircuitOpen {
                reason: "Too many failures, circuit is open".to_string(),
            });
        }

        *state = CircuitState::HalfOpen;
        self.success_count_in_half_open.store(0, Ordering::Relaxed);
        // Release the lock before emitting the log event.
        drop(state);
        debug!("Circuit breaker transitioning to half-open");
        Ok(())
    }

    /// Records a successful operation.
    ///
    /// - `Closed`: clears the consecutive-failure count.
    /// - `HalfOpen`: counts toward closing the breaker.
    /// - `Open`: ignored.
    ///
    /// # Panics
    /// Panics if the state lock is poisoned.
    pub fn record_success(&self) {
        match self.state() {
            CircuitState::Closed => {
                self.failure_count.store(0, Ordering::Relaxed);
            }
            CircuitState::Open => {}
            CircuitState::HalfOpen => {
                let mut state = self.state.write().unwrap();
                // A concurrent failure may have re-opened the breaker after the peek;
                // a late success must not close it again.
                if *state != CircuitState::HalfOpen {
                    return;
                }
                let successes = self
                    .success_count_in_half_open
                    .fetch_add(1, Ordering::Relaxed)
                    + 1;
                if successes >= self.success_threshold {
                    *state = CircuitState::Closed;
                    self.failure_count.store(0, Ordering::Relaxed);
                    drop(state);
                    debug!("Circuit breaker closed after successful recovery");
                }
            }
        }
    }

    /// Records a failed operation.
    ///
    /// - `Closed`: trips to `Open` once `failure_threshold` consecutive failures
    ///   are reached.
    /// - `HalfOpen`: re-opens immediately.
    /// - `Open`: refreshes the failure timestamp, which extends the open period.
    ///
    /// # Panics
    /// Panics if an internal lock is poisoned.
    pub fn record_failure(&self) {
        let mut state = self.state.write().unwrap();
        let current = *state;
        match current {
            CircuitState::Closed => {
                let failures = self.failure_count.fetch_add(1, Ordering::Relaxed) + 1;
                if failures >= self.failure_threshold {
                    // The timestamp and state change are published together under the
                    // state lock, so no reader can pair `Open` with a stale timestamp.
                    *self.last_failure.write().unwrap() = Some(Instant::now());
                    *state = CircuitState::Open;
                    drop(state);
                    warn!(
                        "Circuit breaker opened after {} failures",
                        self.failure_threshold
                    );
                }
            }
            CircuitState::HalfOpen => {
                *self.last_failure.write().unwrap() = Some(Instant::now());
                *state = CircuitState::Open;
                self.success_count_in_half_open.store(0, Ordering::Relaxed);
                drop(state);
                warn!("Circuit breaker re-opened after failure during half-open");
            }
            CircuitState::Open => {
                *self.last_failure.write().unwrap() = Some(Instant::now());
            }
        }
    }

    /// Forces the breaker back to `Closed` and clears all counters and the failure timestamp.
    ///
    /// # Panics
    /// Panics if an internal lock is poisoned.
    pub fn reset(&self) {
        let mut state = self.state.write().unwrap();
        *state = CircuitState::Closed;
        self.failure_count.store(0, Ordering::Relaxed);
        *self.last_failure.write().unwrap() = None;
        self.success_count_in_half_open.store(0, Ordering::Relaxed);
    }
}

impl Default for CircuitBreaker {
    /// Opens after 5 consecutive failures and probes again after 30 seconds.
    fn default() -> Self {
        Self::new(5, Duration::from_secs(30))
    }
}
/// Handle returned by [`DepthGuard::try_enter`]; dropping it releases the depth
/// level it reserved.
#[derive(Debug)]
#[must_use = "the depth level is released as soon as the handle is dropped"]
pub struct DepthGuardHandle<'a> {
    guard: &'a DepthGuard,
}

impl Drop for DepthGuardHandle<'_> {
    fn drop(&mut self) {
        self.guard.current.fetch_sub(1, Ordering::Relaxed);
    }
}

/// Counter that caps nesting depth at `max`.
///
/// Each successful [`try_enter`](Self::try_enter) reserves one level and returns a
/// [`DepthGuardHandle`]. The level is released when that handle is dropped.
#[derive(Debug)]
pub struct DepthGuard {
    /// Levels currently held by live handles.
    current: AtomicUsize,
    /// Maximum number of levels that may be held at once.
    max: usize,
}

impl DepthGuard {
    /// Creates a guard allowing at most `max` simultaneously held levels.
    pub fn new(max: usize) -> Self {
        Self {
            current: AtomicUsize::new(0),
            max,
        }
    }

    /// Reserves one depth level, or fails if `max` levels are already held.
    ///
    /// # Errors
    /// Returns `AgentError::DepthLimitExceeded` when the limit is reached. In that
    /// error, `current` is the depth the rejected entry would have reached.
    pub fn try_enter(&self) -> Result<DepthGuardHandle<'_>> {
        // Compare-and-swap so a rejected attempt never bumps the counter, even
        // briefly. A fetch_add/fetch_sub pair would push the depth past `max` and
        // could spuriously reject a concurrent caller that should have fit.
        self.current
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |depth| {
                (depth < self.max).then_some(depth + 1)
            })
            .map(|_| DepthGuardHandle { guard: self })
            .map_err(|depth| AgentError::DepthLimitExceeded {
                current: depth.saturating_add(1),
                max: self.max,
            })
    }

    /// Number of levels currently held.
    pub fn current_depth(&self) -> usize {
        self.current.load(Ordering::Relaxed)
    }

    /// Configured maximum depth.
    pub fn max_depth(&self) -> usize {
        self.max
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_budget_tracker() {
        let tracker = BudgetTracker::new();
        let budget = OperationBudget {
            traversal_operations: 3,
            search_operations: 2,
            blocks_read: 10,
        };
        tracker.record_traversal();
        tracker.record_traversal();
        assert!(tracker.check_traversal_budget(&budget).is_ok());
        tracker.record_traversal();
        assert!(tracker.check_traversal_budget(&budget).is_err());
        tracker.reset();
        assert!(tracker.check_traversal_budget(&budget).is_ok());
    }

    #[test]
    fn test_search_and_blocks_budgets() {
        let tracker = BudgetTracker::new();
        let budget = OperationBudget {
            traversal_operations: 1,
            search_operations: 1,
            blocks_read: 10,
        };
        assert!(tracker.check_search_budget(&budget).is_ok());
        tracker.record_search();
        assert!(tracker.check_search_budget(&budget).is_err());

        tracker.record_blocks_read(9);
        assert!(tracker.check_blocks_budget(&budget).is_ok());
        tracker.record_blocks_read(1);
        assert!(tracker.check_blocks_budget(&budget).is_err());

        tracker.reset();
        assert!(tracker.check_search_budget(&budget).is_ok());
        assert!(tracker.check_blocks_budget(&budget).is_ok());
    }

    #[test]
    fn test_circuit_breaker() {
        let cb = CircuitBreaker::new(3, Duration::from_millis(100));
        assert_eq!(cb.state(), CircuitState::Closed);
        assert!(cb.can_proceed().is_ok());
        cb.record_failure();
        cb.record_failure();
        assert!(cb.can_proceed().is_ok());
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
        assert!(cb.can_proceed().is_err());
        std::thread::sleep(Duration::from_millis(150));
        assert!(cb.can_proceed().is_ok());
        assert_eq!(cb.state(), CircuitState::HalfOpen);
        cb.record_success();
        cb.record_success();
        cb.record_success();
        assert_eq!(cb.state(), CircuitState::Closed);
    }

    #[test]
    fn test_circuit_breaker_reopens_on_half_open_failure() {
        let cb = CircuitBreaker::new(1, Duration::from_millis(50));
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
        std::thread::sleep(Duration::from_millis(80));
        assert!(cb.can_proceed().is_ok());
        assert_eq!(cb.state(), CircuitState::HalfOpen);
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
        // The re-open refreshed the failure timestamp, so the breaker stays closed to traffic.
        assert!(cb.can_proceed().is_err());
    }

    #[test]
    fn test_circuit_breaker_success_resets_failure_streak_and_reset() {
        let cb = CircuitBreaker::new(2, Duration::from_secs(60));
        cb.record_failure();
        cb.record_success();
        cb.record_failure();
        // Only one consecutive failure since the success.
        assert_eq!(cb.state(), CircuitState::Closed);
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
        assert!(cb.can_proceed().is_err());
        cb.reset();
        assert_eq!(cb.state(), CircuitState::Closed);
        assert!(cb.can_proceed().is_ok());
    }

    #[test]
    fn test_depth_guard() {
        let guard = DepthGuard::new(3);
        assert_eq!(guard.current_depth(), 0);
        {
            let _h1 = guard.try_enter().unwrap();
            assert_eq!(guard.current_depth(), 1);
            {
                let _h2 = guard.try_enter().unwrap();
                assert_eq!(guard.current_depth(), 2);
                {
                    let _h3 = guard.try_enter().unwrap();
                    assert_eq!(guard.current_depth(), 3);
                    assert!(guard.try_enter().is_err());
                }
                assert_eq!(guard.current_depth(), 2);
            }
            assert_eq!(guard.current_depth(), 1);
        }
        assert_eq!(guard.current_depth(), 0);
    }

    #[test]
    fn test_depth_guard_rejection_does_not_leak() {
        let guard = DepthGuard::new(1);
        assert_eq!(guard.max_depth(), 1);
        let handle = guard.try_enter().unwrap();
        for _ in 0..5 {
            assert!(guard.try_enter().is_err());
        }
        assert_eq!(guard.current_depth(), 1);
        drop(handle);
        assert_eq!(guard.current_depth(), 0);
        assert!(guard.try_enter().is_ok());
        assert_eq!(guard.current_depth(), 0);
    }
}