apex_sdk/error_recovery.rs

//! Error recovery strategies for Apex SDK
//!
//! This module provides automatic retry logic and error recovery strategies
//! for transient failures in blockchain operations.
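//!
//! As a quick orientation, a minimal usage sketch (assuming this module is
//! mounted at `apex_sdk::error_recovery`; `fetch_block` is a hypothetical
//! caller-defined async operation, not part of this SDK):
//!
//! ```ignore
//! use apex_sdk::error_recovery::{with_retry, RetryConfig};
//!
//! let config = RetryConfig::new().with_max_retries(5);
//! // fetch_block: hypothetical fallible operation supplied by the caller
//! let block = with_retry(&config, || async { fetch_block().await }).await?;
//! ```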

use crate::error::{Error, Result};
use std::time::Duration;
use tokio::time::sleep;

/// Retry configuration for error recovery
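///
/// With the defaults below, successive delays grow roughly as 1s, 2s, 4s, ...,
/// capped at 30s. A builder-style sketch using the methods defined here:
///
/// ```ignore
/// let config = RetryConfig::new()
///     .with_max_retries(5)
///     .with_initial_backoff(Duration::from_millis(500))
///     .with_max_backoff(Duration::from_secs(10))
///     .with_multiplier(1.5)
///     .with_jitter(true);
/// ```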
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Maximum number of retry attempts
    pub max_retries: u32,
    /// Initial backoff duration
    pub initial_backoff: Duration,
    /// Maximum backoff duration
    pub max_backoff: Duration,
    /// Backoff multiplier
    pub multiplier: f64,
    /// Whether to use jitter to avoid thundering herd
    pub use_jitter: bool,
}

impl Default for RetryConfig {
    fn default() -> Self {
        Self {
            max_retries: 3,
            initial_backoff: Duration::from_millis(1000),
            max_backoff: Duration::from_secs(30),
            multiplier: 2.0,
            use_jitter: true,
        }
    }
}

impl RetryConfig {
    /// Create a new retry configuration
    pub fn new() -> Self {
        Self::default()
    }

    /// Set maximum retries
    pub fn with_max_retries(mut self, max_retries: u32) -> Self {
        self.max_retries = max_retries;
        self
    }

    /// Set initial backoff
    pub fn with_initial_backoff(mut self, backoff: Duration) -> Self {
        self.initial_backoff = backoff;
        self
    }

    /// Set maximum backoff
    pub fn with_max_backoff(mut self, max_backoff: Duration) -> Self {
        self.max_backoff = max_backoff;
        self
    }

    /// Set backoff multiplier
    pub fn with_multiplier(mut self, multiplier: f64) -> Self {
        self.multiplier = multiplier;
        self
    }

    /// Enable or disable jitter
    pub fn with_jitter(mut self, use_jitter: bool) -> Self {
        self.use_jitter = use_jitter;
        self
    }
}

/// Determines if an error is retryable
///
/// Connection errors are always considered transient. Transaction and generic
/// errors are retried only when their message suggests a transient condition
/// (e.g. "timeout", "network", "unavailable"); chain-specific, configuration,
/// and validation errors are never retried.
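///
/// For example (mirroring the unit tests below):
///
/// ```ignore
/// assert!(is_retryable(&Error::connection("connection reset")));
/// assert!(!is_retryable(&Error::InvalidAddress("bad".to_string())));
/// ```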
pub fn is_retryable(error: &Error) -> bool {
    match error {
        Error::Connection(_, _) => true,
        Error::Transaction(msg, _) => {
            // Retry on timeout or network errors
            msg.contains("timeout")
                || msg.contains("network")
                || msg.contains("connection")
                || msg.contains("unavailable")
        }
        Error::Substrate(_) => false, // Chain-specific errors are typically not retryable
        Error::Evm(_) => false,
        Error::Config(_, _) => false,
        Error::UnsupportedChain(_) => false,
        Error::InvalidAddress(_) => false,
        Error::Serialization(_) => false,
        Error::Other(msg) => msg.contains("temporary") || msg.contains("timeout"),
    }
}

/// Execute an async operation with automatic retry logic
///
/// Retries only errors deemed retryable by [`is_retryable`], sleeping with
/// exponential backoff (optionally jittered) between attempts, with each
/// delay capped at `max_backoff`.
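///
/// A hedged usage sketch (`client.get_balance` stands in for any fallible
/// async operation and is not part of this module):
///
/// ```ignore
/// let config = RetryConfig::default();
/// let balance = with_retry(&config, || async {
///     client.get_balance(&address).await // hypothetical RPC client call
/// })
/// .await?;
/// ```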
pub async fn with_retry<F, Fut, T>(config: &RetryConfig, operation: F) -> Result<T>
where
    F: Fn() -> Fut,
    Fut: std::future::Future<Output = Result<T>>,
{
    let mut attempt = 0;
    let mut backoff = config.initial_backoff;

    loop {
        match operation().await {
            Ok(result) => return Ok(result),
            Err(error) => {
                if !is_retryable(&error) || attempt >= config.max_retries {
                    return Err(error);
                }

                attempt += 1;

                // Calculate backoff with jitter
                let delay = if config.use_jitter {
                    // Scale the delay by a uniform factor in [0.85, 1.15),
                    // i.e. +/- 15% jitter
                    let jitter = rand::random::<f64>() * 0.3;
                    let multiplier = 1.0 + (jitter - 0.15);
                    Duration::from_millis((backoff.as_millis() as f64 * multiplier) as u64)
                } else {
                    backoff
                };

                let delay = delay.min(config.max_backoff);

                tracing::warn!(
                    "Operation failed (attempt {}/{}): {}. Retrying in {:?}",
                    attempt,
                    config.max_retries,
                    error,
                    delay
                );

                sleep(delay).await;

                // Exponential backoff for the next attempt
                backoff =
                    Duration::from_millis((backoff.as_millis() as f64 * config.multiplier) as u64)
                        .min(config.max_backoff);
            }
        }
    }
}

/// Circuit breaker for preventing cascading failures
///
/// The breaker starts `Closed`. After `failure_threshold` consecutive
/// failures it trips to `Open` and rejects calls immediately. Once `timeout`
/// has elapsed it moves to `HalfOpen`, and `success_threshold` consecutive
/// successes close it again.
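///
/// A minimal sketch of guarding an operation, mirroring the unit tests below
/// (`client.send_transaction` is a hypothetical call, not part of this module):
///
/// ```ignore
/// let mut breaker = CircuitBreaker::new(3, Duration::from_secs(5));
/// // Hypothetical fallible operation guarded by the breaker
/// let result = breaker
///     .call(|| async { client.send_transaction(&tx).await })
///     .await;
/// if breaker.is_open() {
///     // Back off: further calls will fail fast until the timeout elapses
/// }
/// ```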
#[derive(Debug)]
pub struct CircuitBreaker {
    failure_threshold: u32,
    success_threshold: u32,
    timeout: Duration,
    state: CircuitState,
    failure_count: u32,
    success_count: u32,
    last_failure_time: Option<std::time::Instant>,
}

#[derive(Debug, Clone, PartialEq)]
enum CircuitState {
    Closed,
    Open,
    HalfOpen,
}

impl CircuitBreaker {
    /// Create a new circuit breaker
    ///
    /// The number of consecutive successes required to close a half-open
    /// circuit is currently fixed at 2.
    pub fn new(failure_threshold: u32, timeout: Duration) -> Self {
        Self {
            failure_threshold,
            success_threshold: 2,
            timeout,
            state: CircuitState::Closed,
            failure_count: 0,
            success_count: 0,
            last_failure_time: None,
        }
    }

    /// Execute an operation through the circuit breaker
    ///
    /// While the circuit is open and the timeout has not yet elapsed, this
    /// returns a connection error immediately without invoking `operation`.
    pub async fn call<F, Fut, T>(&mut self, operation: F) -> Result<T>
    where
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<T>>,
    {
        // Check if circuit should transition from Open to HalfOpen
        if self.state == CircuitState::Open {
            if let Some(last_failure) = self.last_failure_time {
                if last_failure.elapsed() > self.timeout {
                    self.state = CircuitState::HalfOpen;
                    self.success_count = 0;
                } else {
                    return Err(Error::connection("Circuit breaker is open"));
                }
            }
        }

        match operation().await {
            Ok(result) => {
                self.on_success();
                Ok(result)
            }
            Err(error) => {
                self.on_failure();
                Err(error)
            }
        }
    }

    fn on_success(&mut self) {
        self.failure_count = 0;

        if self.state == CircuitState::HalfOpen {
            self.success_count += 1;
            if self.success_count >= self.success_threshold {
                self.state = CircuitState::Closed;
            }
        }
    }

    fn on_failure(&mut self) {
        self.failure_count += 1;
        self.last_failure_time = Some(std::time::Instant::now());

        // A failure while half-open immediately re-opens the circuit;
        // otherwise the circuit opens once the failure threshold is reached.
        if self.state == CircuitState::HalfOpen || self.failure_count >= self.failure_threshold {
            self.state = CircuitState::Open;
        }
    }

    /// Check if the circuit is open
    pub fn is_open(&self) -> bool {
        self.state == CircuitState::Open
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_retry_config_builder() {
        let config = RetryConfig::new()
            .with_max_retries(5)
            .with_initial_backoff(Duration::from_millis(500))
            .with_multiplier(1.5);

        assert_eq!(config.max_retries, 5);
        assert_eq!(config.initial_backoff, Duration::from_millis(500));
        assert_eq!(config.multiplier, 1.5);
    }

    #[test]
    fn test_is_retryable() {
        assert!(is_retryable(&Error::connection("test")));
        assert!(is_retryable(&Error::transaction("timeout error")));
        assert!(!is_retryable(&Error::InvalidAddress("test".to_string())));
        assert!(!is_retryable(&Error::config("test")));
    }

    #[tokio::test]
    async fn test_with_retry_success() {
        let config = RetryConfig::new().with_max_retries(3);

        let result = with_retry(&config, || async { Ok::<_, Error>(42) }).await;

        assert_eq!(result.unwrap(), 42);
    }

    #[tokio::test]
    async fn test_with_retry_non_retryable_error() {
        let config = RetryConfig::new().with_max_retries(3);

        let result = with_retry(&config, || async {
            Err::<i32, _>(Error::InvalidAddress("test".to_string()))
        })
        .await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_circuit_breaker_opens_after_failures() {
        let mut breaker = CircuitBreaker::new(2, Duration::from_secs(1));

        // First failure
        let _ = breaker
            .call(|| async { Err::<(), _>(Error::connection("test")) })
            .await;
        assert!(!breaker.is_open());

        // Second failure - circuit should open
        let _ = breaker
            .call(|| async { Err::<(), _>(Error::connection("test")) })
            .await;
        assert!(breaker.is_open());

        // Subsequent calls should fail immediately
        let result = breaker.call(|| async { Ok(()) }).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_circuit_breaker_success_closes() {
        let mut breaker = CircuitBreaker::new(1, Duration::from_millis(100));

        // Trigger failure
        let _ = breaker
            .call(|| async { Err::<(), _>(Error::connection("test")) })
            .await;
        assert!(breaker.is_open());

        // Wait for timeout
        tokio::time::sleep(Duration::from_millis(150)).await;

        // Circuit should be half-open; two successes (the success threshold)
        // should close it
        let _ = breaker.call(|| async { Ok::<_, Error>(()) }).await;
        let _ = breaker.call(|| async { Ok::<_, Error>(()) }).await;

        // Should be closed now
        assert!(!breaker.is_open());
    }
}