use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use tokio::time;
#[derive(Debug, Clone, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum OperationType {
ApiCall,
FileOperation,
CodeAnalysis,
ToolExecution,
NetworkRequest,
Processing,
Custom(String),
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TimeoutConfig {
pub timeout_duration: Duration,
pub max_retries: u32,
pub initial_retry_delay: Duration,
pub max_retry_delay: Duration,
pub backoff_multiplier: f64,
pub use_jitter: bool,
pub retry_on_timeout: bool,
pub retry_on_errors: Vec<String>,
}
impl Default for TimeoutConfig {
fn default() -> Self {
Self {
timeout_duration: Duration::from_secs(30),
max_retries: 3,
initial_retry_delay: Duration::from_millis(100),
max_retry_delay: Duration::from_secs(30),
backoff_multiplier: 2.0,
use_jitter: true,
retry_on_timeout: true,
retry_on_errors: vec![
"timeout".to_string(),
"connection".to_string(),
"network".to_string(),
"server_error".to_string(),
],
}
}
}
impl TimeoutConfig {
pub fn api_call() -> Self {
Self {
timeout_duration: Duration::from_secs(60),
max_retries: 5,
initial_retry_delay: Duration::from_millis(200),
max_retry_delay: Duration::from_secs(10),
backoff_multiplier: 1.5,
..Default::default()
}
}
pub fn file_operation() -> Self {
Self {
timeout_duration: Duration::from_secs(10),
max_retries: 2,
initial_retry_delay: Duration::from_millis(50),
max_retry_delay: Duration::from_secs(2),
backoff_multiplier: 2.0,
retry_on_timeout: false, ..Default::default()
}
}
pub fn analysis() -> Self {
Self {
timeout_duration: Duration::from_secs(120),
max_retries: 1,
initial_retry_delay: Duration::from_secs(5),
max_retry_delay: Duration::from_secs(10),
backoff_multiplier: 2.0,
..Default::default()
}
}
}
#[derive(Debug, Clone)]
pub struct TimeoutEvent {
pub operation_id: String,
pub operation_type: OperationType,
pub start_time: Instant,
pub timeout_duration: Duration,
pub retry_count: u32,
pub error_message: Option<String>,
}
#[derive(Debug, Clone, Default)]
pub struct TimeoutStats {
pub total_operations: usize,
pub timed_out_operations: usize,
pub successful_retries: usize,
pub failed_retries: usize,
pub average_timeout_duration: Duration,
pub total_retry_attempts: usize,
}
pub struct TimeoutDetector {
configs: Arc<RwLock<HashMap<OperationType, TimeoutConfig>>>,
active_operations: Arc<RwLock<HashMap<String, TimeoutEvent>>>,
stats: Arc<RwLock<TimeoutStats>>,
}
impl Default for TimeoutDetector {
fn default() -> Self {
Self::new()
}
}
impl TimeoutDetector {
pub fn new() -> Self {
let mut configs = HashMap::new();
configs.insert(OperationType::ApiCall, TimeoutConfig::api_call());
configs.insert(
OperationType::FileOperation,
TimeoutConfig::file_operation(),
);
configs.insert(OperationType::CodeAnalysis, TimeoutConfig::analysis());
configs.insert(OperationType::ToolExecution, TimeoutConfig::default());
configs.insert(OperationType::NetworkRequest, TimeoutConfig::api_call());
configs.insert(OperationType::Processing, TimeoutConfig::analysis());
Self {
configs: Arc::new(RwLock::new(configs)),
active_operations: Arc::new(RwLock::new(HashMap::new())),
stats: Arc::new(RwLock::new(TimeoutStats::default())),
}
}
pub async fn set_config(&self, operation_type: OperationType, config: TimeoutConfig) {
let mut configs = self.configs.write().await;
configs.insert(operation_type, config);
}
pub async fn get_config(&self, operation_type: &OperationType) -> TimeoutConfig {
let configs = self.configs.read().await;
configs.get(operation_type).cloned().unwrap_or_default()
}
pub async fn start_operation(
&self,
operation_id: String,
operation_type: OperationType,
) -> TimeoutHandle {
let config = self.get_config(&operation_type).await;
let event = TimeoutEvent {
operation_id: operation_id.clone(),
operation_type,
start_time: Instant::now(),
timeout_duration: config.timeout_duration,
retry_count: 0,
error_message: None,
};
let mut active_ops = self.active_operations.write().await;
active_ops.insert(operation_id.clone(), event);
let mut stats = self.stats.write().await;
stats.total_operations += 1;
TimeoutHandle {
operation_id,
detector: Arc::new(self.clone()),
}
}
pub async fn check_timeout(&self, operation_id: &str) -> Option<TimeoutEvent> {
let active_ops = self.active_operations.read().await;
if let Some(event) = active_ops.get(operation_id) {
if event.start_time.elapsed() >= event.timeout_duration {
return Some(event.clone());
}
}
None
}
pub async fn record_timeout(&self, operation_id: &str, error_message: Option<String>) {
let mut active_ops = self.active_operations.write().await;
if let Some(event) = active_ops.get_mut(operation_id) {
event.error_message = error_message;
}
let mut stats = self.stats.write().await;
stats.timed_out_operations += 1;
}
pub async fn record_successful_retry(&self, _operation_id: &str) {
let mut stats = self.stats.write().await;
stats.successful_retries += 1;
stats.total_retry_attempts += 1;
}
pub async fn record_failed_retry(&self, _operation_id: &str) {
let mut stats = self.stats.write().await;
stats.failed_retries += 1;
stats.total_retry_attempts += 1;
}
pub async fn end_operation(&self, operation_id: &str) {
let mut active_ops = self.active_operations.write().await;
if let Some(event) = active_ops.remove(operation_id) {
let duration = event.start_time.elapsed();
let mut stats = self.stats.write().await;
if stats.total_operations > 0 {
let total_duration =
stats.average_timeout_duration * (stats.total_operations - 1) as u32;
stats.average_timeout_duration =
(total_duration + duration) / stats.total_operations as u32;
}
}
}
pub async fn get_stats(&self) -> TimeoutStats {
self.stats.read().await.clone()
}
pub async fn calculate_retry_delay(
&self,
operation_type: &OperationType,
attempt: u32,
) -> Duration {
let config = self.get_config(operation_type).await;
let base_delay = config.initial_retry_delay.as_millis() as f64;
let multiplier = config.backoff_multiplier.powi(attempt as i32);
let delay_ms = (base_delay * multiplier) as u64;
let mut delay =
Duration::from_millis(delay_ms.min(config.max_retry_delay.as_millis() as u64));
if config.use_jitter {
use std::time::SystemTime;
let seed = SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_nanos() as u64;
let jitter_factor = (seed % 100) as f64 / 100.0; let jitter_ms = (delay.as_millis() as f64 * 0.1 * jitter_factor) as u64; delay += Duration::from_millis(jitter_ms);
}
delay
}
pub async fn should_retry(
&self,
operation_type: &OperationType,
error: &anyhow::Error,
attempt: u32,
) -> bool {
let config = self.get_config(operation_type).await;
if attempt >= config.max_retries {
return false;
}
let error_str = error.to_string().to_lowercase();
for retry_error in &config.retry_on_errors {
if error_str.contains(retry_error) {
return true;
}
}
if config.retry_on_timeout && error_str.contains("timeout") {
return true;
}
false
}
pub async fn execute_with_timeout_retry<F, Fut, T>(
&self,
operation_id: String,
operation_type: OperationType,
mut operation: F,
) -> Result<T, anyhow::Error>
where
F: FnMut() -> Fut,
Fut: std::future::Future<Output = Result<T, anyhow::Error>>,
{
let config = self.get_config(&operation_type).await;
let mut attempt = 0;
let _last_error: Option<anyhow::Error> = None;
loop {
let handle = self
.start_operation(
format!("{}_{}", operation_id, attempt),
operation_type.clone(),
)
.await;
let result = match time::timeout(config.timeout_duration, operation()).await {
Ok(result) => result,
Err(_) => {
self.record_timeout(
&handle.operation_id,
Some("Operation timed out".to_string()),
)
.await;
Err(anyhow::anyhow!(
"Operation '{}' timed out after {:?}",
operation_id,
config.timeout_duration
))
}
};
handle.end().await;
match result {
Ok(value) => {
if attempt > 0 {
self.record_successful_retry(&format!("{}_{}", operation_id, attempt))
.await;
}
return Ok(value);
}
Err(error) => {
let should_retry_op = self.should_retry(&operation_type, &error, attempt).await;
if !should_retry_op {
if attempt > 0 {
self.record_failed_retry(&format!("{}_{}", operation_id, attempt))
.await;
}
return Err(error);
}
attempt += 1;
self.record_failed_retry(&format!("{}_{}", operation_id, attempt))
.await;
let delay = self.calculate_retry_delay(&operation_type, attempt).await;
eprintln!(
"Operation '{}' failed (attempt {}/{}), retrying in {:?}",
operation_id, attempt, config.max_retries, delay
);
time::sleep(delay).await;
}
}
}
}
}
impl Clone for TimeoutDetector {
fn clone(&self) -> Self {
Self {
configs: Arc::clone(&self.configs),
active_operations: Arc::clone(&self.active_operations),
stats: Arc::clone(&self.stats),
}
}
}
pub struct TimeoutHandle {
operation_id: String,
detector: Arc<TimeoutDetector>,
}
impl TimeoutHandle {
pub async fn end(self) {
self.detector.end_operation(&self.operation_id).await;
}
pub fn operation_id(&self) -> &str {
&self.operation_id
}
}
impl Drop for TimeoutHandle {
fn drop(&mut self) {
let operation_id = self.operation_id.clone();
let detector = Arc::clone(&self.detector);
tokio::spawn(async move {
detector.end_operation(&operation_id).await;
});
}
}
use once_cell::sync::Lazy;
pub static TIMEOUT_DETECTOR: Lazy<TimeoutDetector> = Lazy::new(TimeoutDetector::new);
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
use tokio::time::sleep;
#[tokio::test]
async fn test_timeout_detection() {
let detector = TimeoutDetector::new();
let config = TimeoutConfig {
timeout_duration: Duration::from_millis(10),
max_retries: 0,
..Default::default()
};
detector.set_config(OperationType::ApiCall, config).await;
let result = detector
.execute_with_timeout_retry(
"test_operation".to_string(),
OperationType::ApiCall,
|| async {
sleep(Duration::from_millis(20)).await;
Ok("success")
},
)
.await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("timed out"));
}
#[tokio::test]
async fn test_successful_retry() {
let detector = TimeoutDetector::new();
let config = TimeoutConfig {
timeout_duration: Duration::from_millis(50),
max_retries: 2,
initial_retry_delay: Duration::from_millis(5),
retry_on_timeout: true,
..Default::default()
};
detector.set_config(OperationType::ApiCall, config).await;
let call_count = Arc::new(AtomicUsize::new(0));
let call_count_clone = call_count.clone();
let result = detector
.execute_with_timeout_retry(
"test_retry".to_string(),
OperationType::ApiCall,
move || {
let call_count = call_count_clone.clone();
async move {
let count = call_count.fetch_add(1, Ordering::SeqCst) + 1;
if count == 1 {
sleep(Duration::from_millis(60)).await;
Ok("should not reach here")
} else {
sleep(Duration::from_millis(10)).await;
Ok("success")
}
}
},
)
.await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), "success");
assert_eq!(call_count.load(Ordering::SeqCst), 2);
let stats = detector.get_stats().await;
assert_eq!(stats.successful_retries, 1);
assert_eq!(stats.total_retry_attempts, 1);
}
#[tokio::test]
async fn test_calculate_retry_delay() {
let detector = TimeoutDetector::new();
let delay = detector
.calculate_retry_delay(&OperationType::ApiCall, 0)
.await;
assert!(delay >= Duration::from_millis(200));
let delay2 = detector
.calculate_retry_delay(&OperationType::ApiCall, 1)
.await;
assert!(delay2 > delay); }
}