use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use serde_json::Value;
use tokio::sync::Mutex;
/// How the delay between consecutive retries evolves.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RetryStrategy {
    /// Every retry waits `initial_delay`.
    Fixed,
    /// Each retry multiplies the previous delay by `backoff_factor`, capped at `max_delay`.
    Exponential,
}

/// Retry policy for a tool call.
#[derive(Clone, Debug)]
pub struct RetryConfig {
    /// Maximum number of retries; the default of 0 means "never retry".
    pub max_retries: u32,
    /// Delay used for the first retry (and for every retry under `Fixed`).
    pub initial_delay: Duration,
    /// Upper bound applied to exponentially growing delays.
    pub max_delay: Duration,
    /// Backoff shape.
    pub strategy: RetryStrategy,
    /// Per-attempt multiplier used by `Exponential`.
    pub backoff_factor: f64,
    /// When true, only outputs whose `error.transient` is JSON `true` are retried.
    pub only_transient: bool,
}

impl Default for RetryConfig {
    /// Retries disabled; if enabled, exponential doubling from 100 ms up to 10 s,
    /// retrying every failure.
    fn default() -> Self {
        let initial_delay = Duration::from_millis(100);
        let max_delay = Duration::from_secs(10);
        Self {
            max_retries: 0,
            initial_delay,
            max_delay,
            strategy: RetryStrategy::Exponential,
            backoff_factor: 2.0,
            only_transient: false,
        }
    }
}
impl RetryConfig {
    /// Returns a configuration that never retries.
    ///
    /// Identical to [`RetryConfig::default`], whose `max_retries` is 0.
    pub fn disabled() -> Self {
        Self::default()
    }

    /// Returns an exponential-backoff configuration allowing `max_retries`
    /// retries; every other field keeps its `Default` value.
    pub fn exponential(max_retries: u32) -> Self {
        Self {
            max_retries,
            strategy: RetryStrategy::Exponential,
            ..Default::default()
        }
    }

    /// Computes how long to wait before retry number `attempt`.
    ///
    /// * [`RetryStrategy::Fixed`] always returns `initial_delay`.
    /// * [`RetryStrategy::Exponential`] returns
    ///   `initial_delay * backoff_factor^attempt`, capped at `max_delay`.
    ///   A negative or NaN product (e.g. from a negative or NaN
    ///   `backoff_factor`) yields [`Duration::ZERO`].
    pub fn delay_for(&self, attempt: u32) -> Duration {
        match self.strategy {
            RetryStrategy::Fixed => self.initial_delay,
            RetryStrategy::Exponential => {
                // `powi` takes an `i32`; saturate instead of letting attempts above
                // `i32::MAX` wrap to a negative exponent that would shrink the delay.
                let exponent = i32::try_from(attempt).unwrap_or(i32::MAX);
                let factor = self.backoff_factor.powi(exponent);
                // Scale in nanoseconds so sub-millisecond parts of `initial_delay`
                // are kept (attempt 0 then matches what `Fixed` returns).
                let nanos = self.initial_delay.as_nanos() as f64 * factor;
                let scaled = if nanos >= self.max_delay.as_nanos() as f64 {
                    // Covers ordinary capping as well as overflow to +inf.
                    self.max_delay
                } else {
                    // Float-to-int `as` saturates negatives and maps NaN to 0.
                    Duration::from_nanos(nanos as u64)
                };
                scaled.min(self.max_delay)
            }
        }
    }

    /// Decides whether a failed call whose result is `output` should be retried.
    ///
    /// With `only_transient == false` every failure is retryable. Otherwise only
    /// outputs carrying `{"error": {"transient": true}}` are; a missing or
    /// non-boolean `transient` field counts as non-transient.
    pub fn should_retry(&self, output: &Value) -> bool {
        if !self.only_transient {
            return true;
        }
        output
            .get("error")
            .and_then(|e| e.get("transient"))
            .and_then(|t| t.as_bool())
            .unwrap_or(false)
    }
}
/// Per-tool execution timeouts with an optional fallback.
#[derive(Clone, Debug, Default)]
pub struct TimeoutConfig {
    /// Timeout for tools without an explicit override; `None` means no timeout.
    pub default_timeout: Option<Duration>,
    /// Per-tool overrides keyed by tool name; these win over `default_timeout`.
    pub tool_timeouts: HashMap<String, Duration>,
}

impl TimeoutConfig {
    /// No fallback timeout and no per-tool overrides.
    pub fn disabled() -> Self {
        Self {
            default_timeout: None,
            tool_timeouts: HashMap::new(),
        }
    }

    /// Applies `timeout` to every tool that has no specific override.
    pub fn default_timeout(timeout: Duration) -> Self {
        Self {
            default_timeout: Some(timeout),
            ..Self::default()
        }
    }

    /// Builder: sets (or replaces) the timeout for `tool_name`.
    pub fn with_tool_timeout(self, tool_name: impl Into<String>, timeout: Duration) -> Self {
        let mut cfg = self;
        cfg.tool_timeouts.insert(tool_name.into(), timeout);
        cfg
    }

    /// Resolves the effective timeout: the per-tool override if present,
    /// otherwise `default_timeout`.
    pub fn get_timeout(&self, tool_name: &str) -> Option<Duration> {
        match self.tool_timeouts.get(tool_name) {
            Some(&specific) => Some(specific),
            None => self.default_timeout,
        }
    }
}
/// Lifecycle state of a single circuit breaker.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Calls pass through; consecutive failures are counted.
    Closed,
    /// Calls are rejected until `recovery_timeout` has elapsed since the breaker opened.
    Open,
    /// A bounded number of probe calls are admitted to test whether the tool recovered.
    HalfOpen,
}

/// Tuning parameters for a circuit breaker.
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Whether breaking is active. `CircuitBreakerState` never reads this flag;
    /// callers decide whether to consult the breaker at all.
    pub enabled: bool,
    /// Consecutive failures that trip the breaker from `Closed` to `Open`.
    pub failure_threshold: u32,
    /// Minimum time the breaker stays `Open` before admitting half-open probes.
    pub recovery_timeout: Duration,
    /// Maximum probe calls admitted per half-open episode. This count includes the
    /// call that triggers the `Open` -> `HalfOpen` transition. That call is always
    /// admitted, so at least one probe gets through even when this is 0.
    pub half_open_max_calls: u32,
}

impl Default for CircuitBreakerConfig {
    /// Disabled; trips after 5 consecutive failures; 30 s recovery; 1 probe.
    fn default() -> Self {
        Self {
            enabled: false,
            failure_threshold: 5,
            recovery_timeout: Duration::from_secs(30),
            half_open_max_calls: 1,
        }
    }
}

/// Mutable bookkeeping for one circuit breaker.
#[derive(Debug)]
pub struct CircuitBreakerState {
    /// Current lifecycle state.
    pub state: CircuitState,
    /// Failures recorded since the last success (saturates at `u32::MAX`).
    pub consecutive_failures: u32,
    /// Most recent time `record_failure` (re)opened the breaker.
    pub last_opened_at: Option<Instant>,
    /// Probe calls admitted during the current half-open episode.
    pub half_open_calls: u32,
}

impl Default for CircuitBreakerState {
    /// A closed breaker with no recorded failures.
    fn default() -> Self {
        Self {
            state: CircuitState::Closed,
            consecutive_failures: 0,
            last_opened_at: None,
            half_open_calls: 0,
        }
    }
}

impl CircuitBreakerState {
    /// Decides whether a call may proceed.
    ///
    /// Once `recovery_timeout` has elapsed since opening, the breaker moves from
    /// `Open` to `HalfOpen`. Every `true` returned from then on counts as a probe
    /// against `half_open_max_calls`, and that includes the call that triggered
    /// the transition.
    pub fn can_pass(&mut self, config: &CircuitBreakerConfig) -> bool {
        match self.state {
            CircuitState::Closed => true,
            CircuitState::Open => {
                let recovered = self
                    .last_opened_at
                    .is_some_and(|opened_at| opened_at.elapsed() >= config.recovery_timeout);
                if recovered {
                    self.state = CircuitState::HalfOpen;
                    // The call admitted right now is the first probe; count it so no
                    // more than `half_open_max_calls` probes get through in total.
                    self.half_open_calls = 1;
                }
                recovered
            }
            CircuitState::HalfOpen => {
                if self.half_open_calls < config.half_open_max_calls {
                    self.half_open_calls += 1;
                    true
                } else {
                    false
                }
            }
        }
    }

    /// Records a successful call: closes the breaker and clears all counters.
    pub fn record_success(&mut self) {
        self.consecutive_failures = 0;
        self.state = CircuitState::Closed;
        self.half_open_calls = 0;
    }

    /// Records a failed call.
    ///
    /// The breaker opens (or re-opens, restarting the recovery window) in either
    /// case:
    /// * a half-open probe failed, or
    /// * the consecutive-failure count reaches `failure_threshold`.
    pub fn record_failure(&mut self, config: &CircuitBreakerConfig) {
        self.consecutive_failures = self.consecutive_failures.saturating_add(1);
        let probe_failed = self.state == CircuitState::HalfOpen;
        if probe_failed || self.consecutive_failures >= config.failure_threshold {
            self.state = CircuitState::Open;
            self.last_opened_at = Some(Instant::now());
        }
    }
}
/// Map from tool name to that tool's circuit-breaker state.
///
/// The map lives behind an `Arc<Mutex<..>>`, so clones of the registry share state.
#[derive(Clone, Debug, Default)]
pub struct CircuitBreakerRegistry {
    states: Arc<Mutex<HashMap<String, CircuitBreakerState>>>,
}

impl CircuitBreakerRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self::default()
    }

    /// Asks `tool_name`'s breaker whether a call may proceed.
    ///
    /// When `config.enabled` is false this always returns `true` and touches no
    /// state. Otherwise it creates a closed breaker for unseen tools.
    pub async fn can_pass(&self, tool_name: &str, config: &CircuitBreakerConfig) -> bool {
        if config.enabled {
            let mut guard = self.states.lock().await;
            guard
                .entry(tool_name.to_owned())
                .or_default()
                .can_pass(config)
        } else {
            true
        }
    }

    /// Current state of `tool_name`'s breaker; tools never seen report `Closed`.
    pub async fn get_state(&self, tool_name: &str) -> CircuitState {
        let guard = self.states.lock().await;
        match guard.get(tool_name) {
            Some(breaker) => breaker.state,
            None => CircuitState::Closed,
        }
    }

    /// Resets `tool_name`'s breaker to `Closed`; a no-op for tools never seen.
    pub async fn record_success(&self, tool_name: &str) {
        let mut guard = self.states.lock().await;
        let Some(breaker) = guard.get_mut(tool_name) else {
            return;
        };
        breaker.record_success();
    }

    /// Records a failure for `tool_name`, creating its breaker if needed.
    ///
    /// Unlike `can_pass`, this does not check `config.enabled`.
    pub async fn record_failure(&self, tool_name: &str, config: &CircuitBreakerConfig) {
        let mut guard = self.states.lock().await;
        guard
            .entry(tool_name.to_owned())
            .or_default()
            .record_failure(config);
    }
}
/// Per-tool limits on concurrent executions, with an optional fallback.
#[derive(Clone, Debug, Default)]
pub struct ConcurrencyConfig {
    /// Limit for tools without an explicit entry; `None` defers to the caller's global maximum.
    pub default_concurrency: Option<usize>,
    /// Per-tool limits keyed by tool name.
    pub tool_concurrency: HashMap<String, usize>,
}

impl ConcurrencyConfig {
    /// Empty configuration: every tool falls back to the global maximum.
    pub fn new() -> Self {
        Self::default()
    }

    /// Builder: sets `tool_name`'s limit, raising 0 to 1.
    pub fn with_tool_concurrency(
        self,
        tool_name: impl Into<String>,
        concurrency: usize,
    ) -> Self {
        let mut cfg = self;
        let limit = if concurrency == 0 { 1 } else { concurrency };
        cfg.tool_concurrency.insert(tool_name.into(), limit);
        cfg
    }

    /// Builder: sets the fallback limit, raising 0 to 1.
    pub fn with_default_concurrency(self, concurrency: usize) -> Self {
        Self {
            default_concurrency: Some(concurrency.max(1)),
            ..self
        }
    }

    /// Resolves the effective limit for `tool_name`.
    ///
    /// The lookup order is: the per-tool limit, then `default_concurrency`, then
    /// `global_max`. The result is not clamped to `global_max`.
    pub fn get_concurrency(&self, tool_name: &str, global_max: usize) -> usize {
        if let Some(&limit) = self.tool_concurrency.get(tool_name) {
            return limit;
        }
        self.default_concurrency.unwrap_or(global_max)
    }
}
/// A cached tool result, the instant it was stored, and how long it stays valid.
#[derive(Clone, Debug)]
struct CacheEntry {
    value: Value,
    cached_at: Instant,
    ttl: Duration,
}

impl CacheEntry {
    /// True once strictly more than `ttl` has elapsed since `cached_at`.
    fn is_expired(&self) -> bool {
        let age = self.cached_at.elapsed();
        self.ttl < age
    }
}
/// Settings for caching tool results.
#[derive(Clone, Debug)]
pub struct CacheConfig {
    /// Master switch for result caching (read by callers, not by this type).
    pub enabled: bool,
    /// TTL for tools without an explicit override.
    pub default_ttl: Duration,
    /// Per-tool TTL overrides keyed by tool name.
    pub tool_ttls: HashMap<String, Duration>,
    /// Entry cap; presumably passed to `ToolResultCache::set` — TODO confirm with caller.
    pub max_entries: usize,
}

impl Default for CacheConfig {
    /// Disabled, 5-minute default TTL, no overrides, at most 1000 entries.
    fn default() -> Self {
        Self {
            enabled: false,
            default_ttl: Duration::from_secs(300),
            tool_ttls: HashMap::default(),
            max_entries: 1_000,
        }
    }
}

impl CacheConfig {
    /// Effective TTL for `tool_name`: its override if present, else `default_ttl`.
    pub fn get_ttl(&self, tool_name: &str) -> Duration {
        match self.tool_ttls.get(tool_name) {
            Some(ttl) => *ttl,
            None => self.default_ttl,
        }
    }

    /// Builder: sets (or replaces) the TTL override for `tool_name`.
    pub fn with_tool_ttl(self, tool_name: impl Into<String>, ttl: Duration) -> Self {
        let mut cfg = self;
        cfg.tool_ttls.insert(tool_name.into(), ttl);
        cfg
    }
}
/// TTL cache of tool results keyed by tool name plus JSON input.
///
/// The map lives behind an `Arc<Mutex<..>>`, so clones of the cache share entries.
#[derive(Debug, Default, Clone)]
pub struct ToolResultCache {
    entries: Arc<Mutex<HashMap<String, CacheEntry>>>,
}

impl ToolResultCache {
    /// Creates an empty cache.
    pub fn new() -> Self {
        Self::default()
    }

    /// Key prefix shared by every entry of `tool_name`.
    ///
    /// The tool name is length-prefixed, so the prefix matches exactly one tool.
    /// For example, the prefix for `"a"` (`"1:a:"`) does not match keys of tool
    /// `"a:b"` (`"3:a:b:..."`).
    fn tool_prefix(tool_name: &str) -> String {
        format!("{}:{tool_name}:", tool_name.len())
    }

    /// Builds the unambiguous map key for a (tool, input) pair.
    fn make_key(tool_name: &str, input: &Value) -> String {
        format!("{}{input}", Self::tool_prefix(tool_name))
    }

    /// Returns a clone of the cached value, or `None` if absent or expired.
    ///
    /// Expired entries found here are removed.
    pub async fn get(&self, tool_name: &str, input: &Value) -> Option<Value> {
        let key = Self::make_key(tool_name, input);
        let mut entries = self.entries.lock().await;
        let entry = entries.get(&key)?;
        if entry.is_expired() {
            entries.remove(&key);
            return None;
        }
        Some(entry.value.clone())
    }

    /// Stores `value` for (`tool_name`, `input`) with the given `ttl`.
    ///
    /// Expired entries are purged first. If storing a new key would exceed
    /// `max_entries`, the entries with the earliest `cached_at` are evicted.
    /// Overwriting an existing key never evicts. With `max_entries == 0` nothing
    /// is stored and the cache is emptied.
    pub async fn set(
        &self,
        tool_name: &str,
        input: &Value,
        value: Value,
        ttl: Duration,
        max_entries: usize,
    ) {
        let key = Self::make_key(tool_name, input);
        let mut entries = self.entries.lock().await;
        if max_entries == 0 {
            // A zero-capacity cache holds nothing, including leftovers from a larger limit.
            entries.clear();
            return;
        }
        entries.retain(|_, entry| !entry.is_expired());
        // Replacing an existing key does not grow the map, so only make room for new keys.
        if !entries.contains_key(&key) {
            while entries.len() >= max_entries {
                // HashMap iteration order is arbitrary; pick the oldest entry explicitly.
                let Some(oldest) = entries
                    .iter()
                    .min_by_key(|(_, entry)| entry.cached_at)
                    .map(|(k, _)| k.clone())
                else {
                    break;
                };
                entries.remove(&oldest);
            }
        }
        entries.insert(
            key,
            CacheEntry {
                value,
                cached_at: Instant::now(),
                ttl,
            },
        );
    }

    /// Removes every entry.
    pub async fn clear(&self) {
        let mut entries = self.entries.lock().await;
        entries.clear();
    }

    /// Removes every entry belonging to exactly `tool_name`.
    pub async fn clear_tool(&self, tool_name: &str) {
        let prefix = Self::tool_prefix(tool_name);
        let mut entries = self.entries.lock().await;
        entries.retain(|k, _| !k.starts_with(&prefix));
    }
}
/// Aggregated settings for enhanced tool calls: retry, timeout, circuit
/// breaking, concurrency limits and result caching.
///
/// `Default` uses each sub-configuration's own default.
#[derive(Clone, Debug, Default)]
pub struct ToolCallEnhancedConfig {
    /// Retry policy.
    pub retry: RetryConfig,
    /// Execution timeouts.
    pub timeout: TimeoutConfig,
    /// Circuit-breaker tuning.
    pub circuit_breaker: CircuitBreakerConfig,
    /// Concurrency limits.
    pub concurrency: ConcurrencyConfig,
    /// Result-cache settings.
    pub cache: CacheConfig,
}

impl ToolCallEnhancedConfig {
    /// Equivalent to `ToolCallEnhancedConfig::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Builder: replaces the retry policy.
    pub fn with_retry(self, retry: RetryConfig) -> Self {
        Self { retry, ..self }
    }

    /// Builder: replaces the timeout settings.
    pub fn with_timeout(self, timeout: TimeoutConfig) -> Self {
        Self { timeout, ..self }
    }

    /// Builder: replaces the circuit-breaker settings.
    pub fn with_circuit_breaker(self, cb: CircuitBreakerConfig) -> Self {
        Self {
            circuit_breaker: cb,
            ..self
        }
    }

    /// Builder: replaces the concurrency limits.
    pub fn with_concurrency(self, concurrency: ConcurrencyConfig) -> Self {
        Self {
            concurrency,
            ..self
        }
    }

    /// Builder: replaces the cache settings.
    pub fn with_cache(self, cache: CacheConfig) -> Self {
        Self { cache, ..self }
    }
}
#[derive(Debug, Clone, Default)]
pub struct ToolCallEnhancedRuntime {
pub circuit_breaker: CircuitBreakerRegistry,
pub cache: ToolResultCache,
}
impl ToolCallEnhancedRuntime {
pub fn new() -> Self {
Self::default()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_retry_delay_exponential() {
        let cfg = RetryConfig::exponential(3);
        // Doubling from the 100 ms default initial delay.
        for (attempt, expected_ms) in [(0u32, 100u64), (1, 200), (2, 400)] {
            assert_eq!(cfg.delay_for(attempt), Duration::from_millis(expected_ms));
        }
    }

    #[test]
    fn test_retry_delay_max() {
        let cap = Duration::from_secs(5);
        let cfg = RetryConfig {
            max_retries: 10,
            initial_delay: Duration::from_millis(1000),
            max_delay: cap,
            strategy: RetryStrategy::Exponential,
            backoff_factor: 2.0,
            only_transient: false,
        };
        let delay = cfg.delay_for(10);
        assert!(delay <= cap, "delay {delay:?} exceeded cap {cap:?}");
    }

    #[test]
    fn test_timeout_config() {
        let fallback = Duration::from_secs(30);
        let http = Duration::from_secs(60);
        let cfg = TimeoutConfig::default_timeout(fallback).with_tool_timeout("http_request", http);
        assert_eq!(cfg.get_timeout("http_request"), Some(http));
        assert_eq!(cfg.get_timeout("shell"), Some(fallback));
        assert_eq!(TimeoutConfig::disabled().get_timeout("any"), None);
    }

    #[test]
    fn test_concurrency_config() {
        let cfg = ConcurrencyConfig::new()
            .with_tool_concurrency("http_request", 5)
            .with_tool_concurrency("shell", 1);
        for (tool, expected) in [("http_request", 5), ("shell", 1), ("other", 3)] {
            assert_eq!(cfg.get_concurrency(tool, 3), expected, "tool {tool}");
        }
    }

    #[test]
    fn test_circuit_breaker_state() {
        let config = CircuitBreakerConfig {
            enabled: true,
            failure_threshold: 3,
            recovery_timeout: Duration::from_millis(50),
            half_open_max_calls: 1,
        };
        let mut breaker = CircuitBreakerState::default();
        assert!(breaker.can_pass(&config), "a fresh breaker starts closed");
        for _ in 0..config.failure_threshold {
            breaker.record_failure(&config);
        }
        assert_eq!(breaker.state, CircuitState::Open);
        // Still inside the 50 ms recovery window, so the call is rejected.
        assert!(!breaker.can_pass(&config));
    }

    #[tokio::test]
    async fn test_tool_result_cache() {
        let cache = ToolResultCache::new();
        let input = serde_json::json!({"path": "/tmp/test.txt"});
        let value = serde_json::json!({"content": "hello"});
        assert_eq!(cache.get("file_read", &input).await, None);
        cache
            .set("file_read", &input, value.clone(), Duration::from_secs(60), 1000)
            .await;
        assert_eq!(cache.get("file_read", &input).await, Some(value));
    }

    #[tokio::test]
    async fn test_cache_expiry() {
        let cache = ToolResultCache::new();
        let input = serde_json::json!({"key": "val"});
        let short_ttl = Duration::from_millis(1);
        cache
            .set("tool", &input, serde_json::json!({"result": "ok"}), short_ttl, 1000)
            .await;
        tokio::time::sleep(Duration::from_millis(10)).await;
        assert_eq!(cache.get("tool", &input).await, None);
    }
}