// bittensor_rs/retry.rs

//! # Retry Logic with Exponential Backoff
//!
//! Production-ready retry mechanisms for Bittensor operations with configurable
//! exponential backoff, jitter, and error-specific retry strategies.
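//!
//! A minimal usage sketch (marked `ignore` so it is not compiled as a doctest;
//! it assumes the crate is named `bittensor_rs`, and `fetch_block` is a
//! hypothetical fallible async operation you supply):
//!
//! ```ignore
//! use bittensor_rs::error::BittensorError;
//! use bittensor_rs::retry::retry_operation;
//!
//! async fn fetch_latest_block() -> Result<u64, BittensorError> {
//!     // The closure is re-invoked on each attempt, using the retry
//!     // configuration derived from the returned error type.
//!     retry_operation(|| async { fetch_block().await }).await
//! }
//! ```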

use crate::error::{BittensorError, RetryConfig};
use std::future::Future;
use std::time::Duration;
use tokio::time::sleep;
use tracing::{debug, info, warn};

/// Exponential backoff calculator with jitter support
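///
/// A small sketch of the delay sequence this produces (marked `ignore` so it
/// is not compiled as a doctest; the `RetryConfig` fields mirror those used
/// in the tests at the bottom of this file):
///
/// ```ignore
/// let config = RetryConfig {
///     max_attempts: 3,
///     initial_delay: Duration::from_millis(100),
///     max_delay: Duration::from_secs(5),
///     backoff_multiplier: 2.0,
///     jitter: false,
/// };
/// let mut backoff = ExponentialBackoff::new(config);
/// // Yields 100ms, 200ms, 400ms, then `None` once max_attempts is exhausted.
/// while let Some(delay) = backoff.next_delay() {
///     println!("next delay: {delay:?}");
/// }
/// ```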
#[derive(Debug, Clone)]
pub struct ExponentialBackoff {
    config: RetryConfig,
    current_attempt: u32,
}

impl ExponentialBackoff {
    /// Creates a new exponential backoff instance
    pub fn new(config: RetryConfig) -> Self {
        Self {
            config,
            current_attempt: 0,
        }
    }

    /// Calculates the next delay, or returns `None` once `max_attempts` is reached
    pub fn next_delay(&mut self) -> Option<Duration> {
        if self.current_attempt >= self.config.max_attempts {
            return None;
        }

        let base_delay = self.config.initial_delay.as_millis() as f64;
        let multiplier = self
            .config
            .backoff_multiplier
            .powi(self.current_attempt as i32);
        let calculated_delay = Duration::from_millis((base_delay * multiplier) as u64);

        // Cap at max_delay
        let mut delay = if calculated_delay > self.config.max_delay {
            self.config.max_delay
        } else {
            calculated_delay
        };

        // Add jitter if enabled
        if self.config.jitter {
            delay = Self::add_jitter(delay);
        }

        self.current_attempt += 1;
        Some(delay)
    }

    /// Adds random jitter (up to 25% of the delay) to prevent thundering herd
    fn add_jitter(delay: Duration) -> Duration {
        use rand::Rng;
        let jitter_ms = rand::thread_rng().gen_range(0..=delay.as_millis() as u64 / 4);
        delay + Duration::from_millis(jitter_ms)
    }

    /// Resets the backoff state
    pub fn reset(&mut self) {
        self.current_attempt = 0;
    }

    /// Gets the current attempt number
    pub fn attempts(&self) -> u32 {
        self.current_attempt
    }
}

/// Retry node that re-executes failed operations using error-aware exponential backoff
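///
/// A minimal usage sketch (marked `ignore` so it is not compiled as a doctest;
/// `query_chain` is a hypothetical fallible async operation):
///
/// ```ignore
/// let node = RetryNode::new().with_timeout(Duration::from_secs(30));
/// // The closure is invoked once per attempt, so it must be re-callable.
/// let result = node.execute(|| async { query_chain().await }).await;
/// // `result` is `Err(..)` if the call still fails after all retries
/// // or once the total timeout is reached.
/// ```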
pub struct RetryNode {
    total_timeout: Option<Duration>,
}

impl RetryNode {
    /// Creates a new retry node
    pub fn new() -> Self {
        Self {
            total_timeout: None,
        }
    }

    /// Sets a total timeout for all retry attempts
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.total_timeout = Some(timeout);
        self
    }

    /// Executes an operation with retry logic based on the error type.
    ///
    /// Non-retryable errors, and errors without a retry configuration, are
    /// returned immediately without any delay.
    pub async fn execute<F, Fut, T>(&self, operation: F) -> Result<T, BittensorError>
    where
        F: Fn() -> Fut,
        Fut: Future<Output = Result<T, BittensorError>>,
    {
        let start_time = tokio::time::Instant::now();

        // First attempt without delay
        match operation().await {
            Ok(result) => Ok(result),
            Err(error) => {
                if !error.is_retryable() {
                    debug!("Error is not retryable: {:?}", error);
                    return Err(error);
                }

                let config = match error.retry_config() {
                    Some(config) => config,
                    None => {
                        debug!("No retry config for error: {:?}", error);
                        return Err(error);
                    }
                };

                info!(
                    "Starting retry for error category: {:?}, max_attempts: {}",
                    error.category(),
                    config.max_attempts
                );

                let mut backoff = ExponentialBackoff::new(config);
                let mut _last_error = error;

                // Retry loop
                while let Some(delay) = backoff.next_delay() {
                    // Check total timeout
                    if let Some(total_timeout) = self.total_timeout {
                        if start_time.elapsed() + delay >= total_timeout {
                            warn!(
                                "Total timeout reached after {} attempts",
                                backoff.attempts()
                            );
                            return Err(BittensorError::backoff_timeout(start_time.elapsed()));
                        }
                    }

                    debug!(
                        "Retry attempt {} after delay {:?}",
                        backoff.attempts(),
                        delay
                    );
                    sleep(delay).await;

                    match operation().await {
                        Ok(result) => {
                            info!("Operation succeeded after {} attempts", backoff.attempts());
                            return Ok(result);
                        }
                        Err(error) => {
                            _last_error = error;

                            // Stop retrying if the new error is no longer retryable
                            if !_last_error.is_retryable() {
                                debug!("Error became non-retryable: {:?}", _last_error);
                                return Err(_last_error);
                            }

                            warn!(
                                "Retry attempt {} failed: {}",
                                backoff.attempts(),
                                _last_error
                            );
                        }
                    }
                }

                warn!(
                    "All {} retry attempts exhausted, last error: {}",
                    backoff.config.max_attempts, _last_error
                );
                Err(BittensorError::max_retries_exceeded(
                    backoff.config.max_attempts,
                ))
            }
        }
    }

    /// Executes an operation with a caller-supplied retry configuration.
    ///
    /// Unlike [`execute`](Self::execute), this retries every error up to
    /// `max_attempts` without checking whether it is retryable.
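    ///
    /// A usage sketch with an explicit config (marked `ignore` so it is not
    /// compiled as a doctest; `send_extrinsic` is a hypothetical operation):
    ///
    /// ```ignore
    /// let config = RetryConfig {
    ///     max_attempts: 5,
    ///     initial_delay: Duration::from_millis(250),
    ///     max_delay: Duration::from_secs(10),
    ///     backoff_multiplier: 2.0,
    ///     jitter: true,
    /// };
    /// let result = RetryNode::new()
    ///     .execute_with_config(|| async { send_extrinsic().await }, config)
    ///     .await;
    /// ```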
    pub async fn execute_with_config<F, Fut, T>(
        &self,
        operation: F,
        config: RetryConfig,
    ) -> Result<T, BittensorError>
    where
        F: Fn() -> Fut,
        Fut: Future<Output = Result<T, BittensorError>>,
    {
        let start_time = tokio::time::Instant::now();
        let mut backoff = ExponentialBackoff::new(config);

        // First attempt
        match operation().await {
            Ok(result) => Ok(result),
            Err(mut _last_error) => {
                info!(
                    "Starting custom retry, max_attempts: {}",
                    backoff.config.max_attempts
                );

                // Retry loop
                while let Some(delay) = backoff.next_delay() {
                    // Check total timeout
                    if let Some(total_timeout) = self.total_timeout {
                        if start_time.elapsed() + delay >= total_timeout {
                            warn!(
                                "Total timeout reached after {} attempts",
                                backoff.attempts()
                            );
                            return Err(BittensorError::backoff_timeout(start_time.elapsed()));
                        }
                    }

                    debug!(
                        "Custom retry attempt {} after delay {:?}",
                        backoff.attempts(),
                        delay
                    );
                    sleep(delay).await;

                    match operation().await {
                        Ok(result) => {
                            info!(
                                "Custom retry succeeded after {} attempts",
                                backoff.attempts()
                            );
                            return Ok(result);
                        }
                        Err(error) => {
                            _last_error = error;
                            warn!(
                                "Custom retry attempt {} failed: {}",
                                backoff.attempts(),
                                _last_error
                            );
                        }
                    }
                }

                Err(BittensorError::max_retries_exceeded(
                    backoff.config.max_attempts,
                ))
            }
        }
    }
}

impl Default for RetryNode {
    fn default() -> Self {
        Self::new()
    }
}

/// Convenience function for retrying operations with default settings
pub async fn retry_operation<F, Fut, T>(operation: F) -> Result<T, BittensorError>
where
    F: Fn() -> Fut,
    Fut: Future<Output = Result<T, BittensorError>>,
{
    RetryNode::new().execute(operation).await
}

/// Convenience function for retrying operations with a total timeout across all attempts
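///
/// A usage sketch (marked `ignore` so it is not compiled as a doctest;
/// `sync_metagraph` is a hypothetical fallible async operation):
///
/// ```ignore
/// // Stop retrying once 30 seconds have elapsed in total.
/// let result = retry_operation_with_timeout(
///     || async { sync_metagraph().await },
///     Duration::from_secs(30),
/// )
/// .await;
/// ```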
pub async fn retry_operation_with_timeout<F, Fut, T>(
    operation: F,
    timeout: Duration,
) -> Result<T, BittensorError>
where
    F: Fn() -> Fut,
    Fut: Future<Output = Result<T, BittensorError>>,
{
    RetryNode::new()
        .with_timeout(timeout)
        .execute(operation)
        .await
}

/// Circuit breaker for preventing cascade failures
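///
/// A minimal usage sketch (marked `ignore` so it is not compiled as a doctest;
/// `ping_endpoint` is a hypothetical fallible async operation):
///
/// ```ignore
/// // Open the circuit after 3 consecutive failures; probe again after 5 seconds.
/// let mut breaker = CircuitBreaker::new(3, Duration::from_secs(5));
/// match breaker.execute(|| async { ping_endpoint().await }).await {
///     Ok(response) => println!("ok: {response:?}"),
///     Err(err) => eprintln!("call failed or circuit is open: {err}"),
/// }
/// ```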
#[derive(Debug, Clone)]
pub struct CircuitBreaker {
    failure_threshold: u32,
    recovery_timeout: Duration,
    current_failures: u32,
    state: CircuitState,
    last_failure_time: Option<tokio::time::Instant>,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CircuitState {
    Closed,   // Normal operation
    Open,     // Failing fast
    HalfOpen, // Testing recovery
}

impl CircuitBreaker {
    /// Creates a new circuit breaker
    pub fn new(failure_threshold: u32, recovery_timeout: Duration) -> Self {
        Self {
            failure_threshold,
            recovery_timeout,
            current_failures: 0,
            state: CircuitState::Closed,
            last_failure_time: None,
        }
    }

    /// Executes an operation through the circuit breaker
    pub async fn execute<F, Fut, T>(&mut self, operation: F) -> Result<T, BittensorError>
    where
        F: Fn() -> Fut,
        Fut: Future<Output = Result<T, BittensorError>>,
    {
        match self.state {
            CircuitState::Open => {
                if let Some(last_failure) = self.last_failure_time {
                    if last_failure.elapsed() >= self.recovery_timeout {
                        debug!("Circuit breaker transitioning to half-open");
                        self.state = CircuitState::HalfOpen;
                    } else {
                        return Err(BittensorError::ServiceUnavailable {
                            message: "Circuit breaker is open".to_string(),
                        });
                    }
                } else {
                    return Err(BittensorError::ServiceUnavailable {
                        message: "Circuit breaker is open".to_string(),
                    });
                }
            }
            CircuitState::Closed | CircuitState::HalfOpen => {}
        }

        match operation().await {
            Ok(result) => {
                // Success - reset circuit breaker
                if self.state == CircuitState::HalfOpen {
                    debug!("Circuit breaker recovering - closing circuit");
                    self.state = CircuitState::Closed;
                }
                self.current_failures = 0;
                self.last_failure_time = None;
                Ok(result)
            }
            Err(error) => {
                // Failure - update circuit breaker state
                self.current_failures += 1;
                self.last_failure_time = Some(tokio::time::Instant::now());

                if self.current_failures >= self.failure_threshold {
                    warn!(
                        "Circuit breaker opening after {} failures",
                        self.current_failures
                    );
                    self.state = CircuitState::Open;
                }

                Err(error)
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    #[test]
    fn test_exponential_backoff() {
        let config = RetryConfig {
            max_attempts: 3,
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(5),
            backoff_multiplier: 2.0,
            jitter: false,
        };

        let mut backoff = ExponentialBackoff::new(config);

        // First delay
        let delay1 = backoff.next_delay().unwrap();
        assert_eq!(delay1, Duration::from_millis(100));
        assert_eq!(backoff.attempts(), 1);

        // Second delay
        let delay2 = backoff.next_delay().unwrap();
        assert_eq!(delay2, Duration::from_millis(200));
        assert_eq!(backoff.attempts(), 2);

        // Third delay
        let delay3 = backoff.next_delay().unwrap();
        assert_eq!(delay3, Duration::from_millis(400));
        assert_eq!(backoff.attempts(), 3);

        // Should return None after max attempts
        assert!(backoff.next_delay().is_none());
    }

    #[tokio::test]
    async fn test_retry_node_success_after_failure() {
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let operation = move || {
            let counter = counter_clone.clone();
            async move {
                let count = counter.fetch_add(1, Ordering::SeqCst);
                if count < 2 {
                    Err(BittensorError::RpcConnectionError {
                        message: "Connection failed".to_string(),
                    })
                } else {
                    Ok("success")
                }
            }
        };

        let node = RetryNode::new();
        let result: Result<&str, BittensorError> = node.execute(operation).await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_node_non_retryable_error() {
        let operation = || async {
            Err(BittensorError::InvalidHotkey {
                hotkey: "invalid".to_string(),
            })
        };

        let node = RetryNode::new();
        let result: Result<&str, BittensorError> = node.execute(operation).await;

        assert!(result.is_err());
        match result.unwrap_err() {
            BittensorError::InvalidHotkey { .. } => {}
            other => panic!("Expected InvalidHotkey, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_circuit_breaker() {
        let mut circuit_breaker = CircuitBreaker::new(2, Duration::from_millis(100));
        let counter = Arc::new(AtomicU32::new(0));

        // First failure
        let counter_clone = counter.clone();
        let result: Result<(), BittensorError> = circuit_breaker
            .execute(|| {
                let counter = counter_clone.clone();
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err(BittensorError::RpcConnectionError {
                        message: "Connection failed".to_string(),
                    })
                }
            })
            .await;
        assert!(result.is_err());

        // Second failure - should open circuit
        let counter_clone = counter.clone();
        let result: Result<(), BittensorError> = circuit_breaker
            .execute(|| {
                let counter = counter_clone.clone();
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err(BittensorError::RpcConnectionError {
                        message: "Connection failed".to_string(),
                    })
                }
            })
            .await;
        assert!(result.is_err());

        // Third call should fail fast without calling operation
        let counter_before = counter.load(Ordering::SeqCst);
        let result: Result<&str, BittensorError> = circuit_breaker
            .execute(|| {
                let counter = counter.clone();
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Ok("should not reach here")
                }
            })
            .await;
        assert!(result.is_err());
        assert_eq!(counter.load(Ordering::SeqCst), counter_before); // No increment

        match result.unwrap_err() {
            BittensorError::ServiceUnavailable { .. } => {}
            other => panic!("Expected ServiceUnavailable, got {other:?}"),
        }
    }
}