// aster/agents/error_handling/retry_handler.rs
1//! Retry Handler
2//!
3//! Provides configurable retry behavior for transient failures.
4//! Supports multiple retry strategies and backoff algorithms.
5//!
6//! **Validates: Requirements 15.4**
7
8use chrono::{DateTime, Utc};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::Arc;
12use std::time::Duration;
13use tokio::sync::RwLock;
14
/// Retry strategy types.
///
/// Serialized in `snake_case` (e.g. `"exponential_with_jitter"`), matching
/// the `Display` impl below.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum RetryStrategy {
    /// Fixed delay between retries (always `base_delay`)
    Fixed,
    /// Linear backoff: `base_delay * (attempt + 1)`
    Linear,
    /// Exponential backoff: `base_delay * 2^attempt` (the default)
    #[default]
    Exponential,
    /// Exponential backoff plus a jitter term scaled by `jitter_factor`
    ExponentialWithJitter,
}
29
30impl std::fmt::Display for RetryStrategy {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        match self {
33            RetryStrategy::Fixed => write!(f, "fixed"),
34            RetryStrategy::Linear => write!(f, "linear"),
35            RetryStrategy::Exponential => write!(f, "exponential"),
36            RetryStrategy::ExponentialWithJitter => write!(f, "exponential_with_jitter"),
37        }
38    }
39}
40
/// Retry configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RetryConfig {
    /// Maximum number of retry attempts
    pub max_retries: u32,
    /// Base delay between retries
    pub base_delay: Duration,
    /// Maximum delay between retries (caps every strategy's computed delay)
    pub max_delay: Duration,
    /// Retry strategy
    pub strategy: RetryStrategy,
    /// Jitter factor for exponential with jitter (0.0 - 1.0; clamped by
    /// `with_jitter_factor`, but not validated when set directly)
    pub jitter_factor: f64,
    /// Whether to retry on timeout errors
    /// NOTE(review): not consulted anywhere in this module — confirm callers
    /// elsewhere actually read it.
    pub retry_on_timeout: bool,
    /// Error types that should be retried; matched case-insensitively as
    /// substrings of the reported error type (see `is_retryable`)
    pub retryable_errors: Vec<String>,
}
60
61impl Default for RetryConfig {
62    fn default() -> Self {
63        Self {
64            max_retries: 3,
65            base_delay: Duration::from_millis(1000),
66            max_delay: Duration::from_secs(30),
67            strategy: RetryStrategy::Exponential,
68            jitter_factor: 0.1,
69            retry_on_timeout: true,
70            retryable_errors: vec![
71                "network".to_string(),
72                "timeout".to_string(),
73                "rate_limit".to_string(),
74                "temporary".to_string(),
75            ],
76        }
77    }
78}
79
80impl RetryConfig {
81    /// Create a new retry config
82    pub fn new(max_retries: u32, base_delay: Duration) -> Self {
83        Self {
84            max_retries,
85            base_delay,
86            ..Default::default()
87        }
88    }
89
90    /// Set the strategy
91    pub fn with_strategy(mut self, strategy: RetryStrategy) -> Self {
92        self.strategy = strategy;
93        self
94    }
95
96    /// Set the maximum delay
97    pub fn with_max_delay(mut self, max_delay: Duration) -> Self {
98        self.max_delay = max_delay;
99        self
100    }
101
102    /// Set the jitter factor
103    pub fn with_jitter_factor(mut self, factor: f64) -> Self {
104        self.jitter_factor = factor.clamp(0.0, 1.0);
105        self
106    }
107
108    /// Set whether to retry on timeout
109    pub fn with_retry_on_timeout(mut self, retry: bool) -> Self {
110        self.retry_on_timeout = retry;
111        self
112    }
113
114    /// Add a retryable error type
115    pub fn with_retryable_error(mut self, error_type: impl Into<String>) -> Self {
116        self.retryable_errors.push(error_type.into());
117        self
118    }
119
120    /// Calculate delay for a given attempt
121    pub fn calculate_delay(&self, attempt: u32) -> Duration {
122        let base_ms = self.base_delay.as_millis() as f64;
123        let max_ms = self.max_delay.as_millis() as f64;
124
125        let delay_ms = match self.strategy {
126            RetryStrategy::Fixed => base_ms,
127            RetryStrategy::Linear => base_ms * (attempt as f64 + 1.0),
128            RetryStrategy::Exponential => base_ms * 2.0_f64.powi(attempt as i32),
129            RetryStrategy::ExponentialWithJitter => {
130                let exp_delay = base_ms * 2.0_f64.powi(attempt as i32);
131                let jitter = exp_delay * self.jitter_factor * rand_jitter();
132                exp_delay + jitter
133            }
134        };
135
136        Duration::from_millis(delay_ms.min(max_ms) as u64)
137    }
138
139    /// Check if an error type is retryable
140    pub fn is_retryable(&self, error_type: &str) -> bool {
141        self.retryable_errors
142            .iter()
143            .any(|e| error_type.to_lowercase().contains(&e.to_lowercase()))
144    }
145
146    /// Validate the configuration
147    pub fn validate(&self) -> Result<(), String> {
148        if self.max_retries == 0 {
149            return Err("max_retries must be greater than 0".to_string());
150        }
151        if self.base_delay.is_zero() {
152            return Err("base_delay must be greater than 0".to_string());
153        }
154        if self.max_delay < self.base_delay {
155            return Err("max_delay must be >= base_delay".to_string());
156        }
157        Ok(())
158    }
159}
160
/// Cheap pseudo-random value in [-1.0, 1.0), derived from the sub-second
/// portion of the current system time (avoids pulling in an RNG crate).
fn rand_jitter() -> f64 {
    use std::time::SystemTime;
    let nanos = match SystemTime::now().duration_since(SystemTime::UNIX_EPOCH) {
        Ok(d) => d.subsec_nanos(),
        // Clock before the epoch: fall back to a fixed value.
        Err(_) => 0,
    };
    // Map the residue 0..2000 onto -1.0..1.0 in steps of 0.001.
    f64::from(nanos % 2000) / 1000.0 - 1.0
}
171
/// Result of a retry decision made by `RetryHandler::handle_failure`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RetryResult {
    /// Operation succeeded
    Success,
    /// Operation should be retried
    Retry,
    /// Maximum retries exceeded
    MaxRetriesExceeded,
    /// Error is not retryable (did not match any configured pattern)
    NotRetryable,
    /// Retry was skipped (operation was never started / no tracked state)
    Skipped,
}
187
188impl std::fmt::Display for RetryResult {
189    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190        match self {
191            RetryResult::Success => write!(f, "success"),
192            RetryResult::Retry => write!(f, "retry"),
193            RetryResult::MaxRetriesExceeded => write!(f, "max_retries_exceeded"),
194            RetryResult::NotRetryable => write!(f, "not_retryable"),
195            RetryResult::Skipped => write!(f, "skipped"),
196        }
197    }
198}
199
/// Retry state for tracking retry attempts of a single operation.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RetryState {
    /// Operation ID
    pub operation_id: String,
    /// Current attempt number (0-based; incremented by `record_attempt`)
    pub attempt: u32,
    /// Configuration used
    pub config: RetryConfig,
    /// Start time of first attempt
    pub started_at: DateTime<Utc>,
    /// Last attempt time (also updated on success)
    pub last_attempt_at: Option<DateTime<Utc>>,
    /// Last error message
    pub last_error: Option<String>,
    /// Total delay accumulated across retries
    pub total_delay: Duration,
    /// Whether the operation succeeded
    pub succeeded: bool,
}
221
222impl RetryState {
223    /// Create a new retry state
224    pub fn new(operation_id: impl Into<String>, config: RetryConfig) -> Self {
225        Self {
226            operation_id: operation_id.into(),
227            attempt: 0,
228            config,
229            started_at: Utc::now(),
230            last_attempt_at: None,
231            last_error: None,
232            total_delay: Duration::ZERO,
233            succeeded: false,
234        }
235    }
236
237    /// Check if more retries are available
238    pub fn can_retry(&self) -> bool {
239        self.attempt < self.config.max_retries
240    }
241
242    /// Get the next delay
243    pub fn next_delay(&self) -> Duration {
244        self.config.calculate_delay(self.attempt)
245    }
246
247    /// Record an attempt
248    pub fn record_attempt(&mut self, error: Option<String>) {
249        self.attempt += 1;
250        self.last_attempt_at = Some(Utc::now());
251        self.last_error = error;
252    }
253
254    /// Record success
255    pub fn record_success(&mut self) {
256        self.succeeded = true;
257        self.last_attempt_at = Some(Utc::now());
258    }
259
260    /// Add delay to total
261    pub fn add_delay(&mut self, delay: Duration) {
262        self.total_delay += delay;
263    }
264
265    /// Get total elapsed time
266    pub fn elapsed(&self) -> Duration {
267        let elapsed = Utc::now().signed_duration_since(self.started_at);
268        elapsed.to_std().unwrap_or(Duration::ZERO)
269    }
270}
271
/// Retry handler for managing retry operations.
///
/// States accumulate until `complete` or `clear` removes them; callers are
/// responsible for calling `complete` when an operation finishes.
#[derive(Debug)]
pub struct RetryHandler {
    /// Active retry states indexed by operation ID
    states: HashMap<String, RetryState>,
    /// Default configuration applied by `start`
    default_config: RetryConfig,
}
280
281impl Default for RetryHandler {
282    fn default() -> Self {
283        Self::new()
284    }
285}
286
287impl RetryHandler {
288    /// Create a new retry handler
289    pub fn new() -> Self {
290        Self {
291            states: HashMap::new(),
292            default_config: RetryConfig::default(),
293        }
294    }
295
296    /// Create with custom default configuration
297    pub fn with_default_config(config: RetryConfig) -> Self {
298        Self {
299            states: HashMap::new(),
300            default_config: config,
301        }
302    }
303
304    /// Start tracking a retry operation with default config
305    pub fn start(&mut self, operation_id: &str) -> &RetryState {
306        self.start_with_config(operation_id, self.default_config.clone())
307    }
308
309    /// Start tracking a retry operation with custom config
310    pub fn start_with_config(&mut self, operation_id: &str, config: RetryConfig) -> &RetryState {
311        let state = RetryState::new(operation_id, config);
312        self.states.insert(operation_id.to_string(), state);
313        self.states.get(operation_id).unwrap()
314    }
315
316    /// Handle a failure and determine if retry should occur
317    pub fn handle_failure(
318        &mut self,
319        operation_id: &str,
320        error_type: &str,
321        error_message: &str,
322    ) -> RetryResult {
323        let state = match self.states.get_mut(operation_id) {
324            Some(s) => s,
325            None => return RetryResult::Skipped,
326        };
327
328        // Check if error is retryable
329        if !state.config.is_retryable(error_type) {
330            return RetryResult::NotRetryable;
331        }
332
333        // Check if we have retries left
334        if !state.can_retry() {
335            return RetryResult::MaxRetriesExceeded;
336        }
337
338        // Record the attempt
339        state.record_attempt(Some(error_message.to_string()));
340
341        RetryResult::Retry
342    }
343
344    /// Get the delay before next retry
345    pub fn get_retry_delay(&self, operation_id: &str) -> Option<Duration> {
346        self.states.get(operation_id).map(|s| s.next_delay())
347    }
348
349    /// Record that a delay was applied
350    pub fn record_delay(&mut self, operation_id: &str, delay: Duration) {
351        if let Some(state) = self.states.get_mut(operation_id) {
352            state.add_delay(delay);
353        }
354    }
355
356    /// Record success for an operation
357    pub fn record_success(&mut self, operation_id: &str) {
358        if let Some(state) = self.states.get_mut(operation_id) {
359            state.record_success();
360        }
361    }
362
363    /// Get the current state for an operation
364    pub fn get_state(&self, operation_id: &str) -> Option<&RetryState> {
365        self.states.get(operation_id)
366    }
367
368    /// Get the current attempt number
369    pub fn get_attempt(&self, operation_id: &str) -> Option<u32> {
370        self.states.get(operation_id).map(|s| s.attempt)
371    }
372
373    /// Check if an operation can retry
374    pub fn can_retry(&self, operation_id: &str) -> bool {
375        self.states
376            .get(operation_id)
377            .map(|s| s.can_retry())
378            .unwrap_or(false)
379    }
380
381    /// Remove a completed operation
382    pub fn complete(&mut self, operation_id: &str) -> Option<RetryState> {
383        self.states.remove(operation_id)
384    }
385
386    /// Clear all states
387    pub fn clear(&mut self) {
388        self.states.clear();
389    }
390
391    /// Get the number of active operations
392    pub fn active_count(&self) -> usize {
393        self.states.len()
394    }
395
396    /// Set default configuration
397    pub fn set_default_config(&mut self, config: RetryConfig) {
398        self.default_config = config;
399    }
400
401    /// Get default configuration
402    pub fn default_config(&self) -> &RetryConfig {
403        &self.default_config
404    }
405
406    /// Execute an async operation with retry
407    pub async fn execute_with_retry<F, Fut, T, E>(
408        &mut self,
409        operation_id: &str,
410        mut operation: F,
411    ) -> Result<T, E>
412    where
413        F: FnMut() -> Fut,
414        Fut: std::future::Future<Output = Result<T, E>>,
415        E: std::fmt::Display,
416    {
417        self.start(operation_id);
418
419        loop {
420            match operation().await {
421                Ok(result) => {
422                    self.record_success(operation_id);
423                    return Ok(result);
424                }
425                Err(e) => {
426                    let error_msg = e.to_string();
427                    let result = self.handle_failure(operation_id, "general", &error_msg);
428
429                    match result {
430                        RetryResult::Retry => {
431                            if let Some(delay) = self.get_retry_delay(operation_id) {
432                                tokio::time::sleep(delay).await;
433                                self.record_delay(operation_id, delay);
434                            }
435                        }
436                        _ => return Err(e),
437                    }
438                }
439            }
440        }
441    }
442}
443
/// Thread-safe retry handler wrapper (tokio `RwLock` inside an `Arc`, so it
/// can be shared across async tasks).
#[allow(dead_code)]
pub type SharedRetryHandler = Arc<RwLock<RetryHandler>>;

/// Create a new shared retry handler with the default configuration.
#[allow(dead_code)]
pub fn new_shared_retry_handler() -> SharedRetryHandler {
    Arc::new(RwLock::new(RetryHandler::new()))
}
453
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_retry_config_default() {
        let config = RetryConfig::default();
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.base_delay, Duration::from_millis(1000));
        assert_eq!(config.strategy, RetryStrategy::Exponential);
    }

    #[test]
    fn test_retry_config_calculate_delay_fixed() {
        let config =
            RetryConfig::new(3, Duration::from_millis(100)).with_strategy(RetryStrategy::Fixed);

        assert_eq!(config.calculate_delay(0), Duration::from_millis(100));
        assert_eq!(config.calculate_delay(1), Duration::from_millis(100));
        assert_eq!(config.calculate_delay(2), Duration::from_millis(100));
    }

    #[test]
    fn test_retry_config_calculate_delay_linear() {
        let config =
            RetryConfig::new(3, Duration::from_millis(100)).with_strategy(RetryStrategy::Linear);

        assert_eq!(config.calculate_delay(0), Duration::from_millis(100));
        assert_eq!(config.calculate_delay(1), Duration::from_millis(200));
        assert_eq!(config.calculate_delay(2), Duration::from_millis(300));
    }

    #[test]
    fn test_retry_config_calculate_delay_exponential() {
        let config = RetryConfig::new(3, Duration::from_millis(100))
            .with_strategy(RetryStrategy::Exponential);

        assert_eq!(config.calculate_delay(0), Duration::from_millis(100));
        assert_eq!(config.calculate_delay(1), Duration::from_millis(200));
        assert_eq!(config.calculate_delay(2), Duration::from_millis(400));
    }

    #[test]
    fn test_retry_config_max_delay() {
        let config = RetryConfig::new(10, Duration::from_millis(100))
            .with_strategy(RetryStrategy::Exponential)
            .with_max_delay(Duration::from_millis(500));

        // 100 * 2^5 = 3200, but should be capped at 500
        assert_eq!(config.calculate_delay(5), Duration::from_millis(500));
    }

    // New: jittered delay stays within base * 2^attempt * (1 ± jitter_factor).
    #[test]
    fn test_retry_config_jitter_delay_bounds() {
        let config = RetryConfig::new(3, Duration::from_millis(100))
            .with_strategy(RetryStrategy::ExponentialWithJitter)
            .with_jitter_factor(0.5);

        // base * 2^1 = 200 ms; jitter is within ±50% of that.
        let d = config.calculate_delay(1).as_millis();
        assert!((100..=300).contains(&d));
    }

    // New: with_jitter_factor clamps out-of-range inputs into [0.0, 1.0].
    #[test]
    fn test_retry_config_jitter_factor_clamped() {
        let config = RetryConfig::default().with_jitter_factor(2.0);
        assert_eq!(config.jitter_factor, 1.0);

        let config = RetryConfig::default().with_jitter_factor(-1.0);
        assert_eq!(config.jitter_factor, 0.0);
    }

    #[test]
    fn test_retry_config_is_retryable() {
        let config = RetryConfig::default();

        assert!(config.is_retryable("network_error"));
        assert!(config.is_retryable("timeout"));
        assert!(config.is_retryable("rate_limit_exceeded"));
        assert!(!config.is_retryable("invalid_input"));
    }

    #[test]
    fn test_retry_config_validate() {
        let valid = RetryConfig::default();
        assert!(valid.validate().is_ok());

        let invalid_retries = RetryConfig {
            max_retries: 0,
            ..Default::default()
        };
        assert!(invalid_retries.validate().is_err());

        let invalid_delay = RetryConfig {
            base_delay: Duration::ZERO,
            ..Default::default()
        };
        assert!(invalid_delay.validate().is_err());
    }

    // New: covers the third validate() branch (max_delay < base_delay),
    // which was previously untested.
    #[test]
    fn test_retry_config_validate_max_delay() {
        let invalid = RetryConfig {
            base_delay: Duration::from_secs(10),
            max_delay: Duration::from_secs(1),
            ..Default::default()
        };
        assert!(invalid.validate().is_err());
    }

    #[test]
    fn test_retry_state_creation() {
        let config = RetryConfig::default();
        let state = RetryState::new("op-1", config);

        assert_eq!(state.operation_id, "op-1");
        assert_eq!(state.attempt, 0);
        assert!(!state.succeeded);
        assert!(state.can_retry());
    }

    #[test]
    fn test_retry_state_record_attempt() {
        let config = RetryConfig::new(3, Duration::from_millis(100));
        let mut state = RetryState::new("op-1", config);

        state.record_attempt(Some("Error 1".to_string()));
        assert_eq!(state.attempt, 1);
        assert_eq!(state.last_error, Some("Error 1".to_string()));
        assert!(state.can_retry());

        state.record_attempt(Some("Error 2".to_string()));
        state.record_attempt(Some("Error 3".to_string()));
        assert_eq!(state.attempt, 3);
        assert!(!state.can_retry());
    }

    #[test]
    fn test_retry_handler_start() {
        let mut handler = RetryHandler::new();
        handler.start("op-1");

        assert_eq!(handler.active_count(), 1);
        assert!(handler.get_state("op-1").is_some());
    }

    #[test]
    fn test_retry_handler_handle_failure() {
        let mut handler = RetryHandler::new();
        handler.start("op-1");

        let result = handler.handle_failure("op-1", "network", "Connection failed");
        assert_eq!(result, RetryResult::Retry);
        assert_eq!(handler.get_attempt("op-1"), Some(1));
    }

    #[test]
    fn test_retry_handler_handle_failure_not_retryable() {
        let mut handler = RetryHandler::new();
        handler.start("op-1");

        let result = handler.handle_failure("op-1", "invalid_input", "Bad request");
        assert_eq!(result, RetryResult::NotRetryable);
    }

    #[test]
    fn test_retry_handler_handle_failure_max_exceeded() {
        let config = RetryConfig::new(2, Duration::from_millis(100));
        let mut handler = RetryHandler::with_default_config(config);
        handler.start("op-1");

        handler.handle_failure("op-1", "network", "Error 1");
        handler.handle_failure("op-1", "network", "Error 2");
        let result = handler.handle_failure("op-1", "network", "Error 3");

        assert_eq!(result, RetryResult::MaxRetriesExceeded);
    }

    #[test]
    fn test_retry_handler_record_success() {
        let mut handler = RetryHandler::new();
        handler.start("op-1");
        handler.record_success("op-1");

        let state = handler.get_state("op-1").unwrap();
        assert!(state.succeeded);
    }

    #[test]
    fn test_retry_handler_complete() {
        let mut handler = RetryHandler::new();
        handler.start("op-1");

        let state = handler.complete("op-1");
        assert!(state.is_some());
        assert_eq!(handler.active_count(), 0);
    }

    #[test]
    fn test_retry_result_display() {
        assert_eq!(format!("{}", RetryResult::Success), "success");
        assert_eq!(format!("{}", RetryResult::Retry), "retry");
        assert_eq!(
            format!("{}", RetryResult::MaxRetriesExceeded),
            "max_retries_exceeded"
        );
    }
}