use crate::error::{BittensorError, RetryConfig};
use std::future::Future;
use std::time::Duration;
use tokio::time::sleep;
use tracing::{debug, info, warn};

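/// Exponential backoff schedule driven by a [`RetryConfig`]: each call to
/// [`ExponentialBackoff::next_delay`] grows the delay by `backoff_multiplier`,
/// caps it at `max_delay`, and optionally adds jitter. A minimal sketch of the
/// intended loop:
///
/// ```ignore
/// let mut backoff = ExponentialBackoff::new(config);
/// while let Some(delay) = backoff.next_delay() {
///     tokio::time::sleep(delay).await;
///     // retry the failed operation here
/// }
/// ```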
#[derive(Debug, Clone)]
pub struct ExponentialBackoff {
    config: RetryConfig,
    current_attempt: u32,
}

impl ExponentialBackoff {
    pub fn new(config: RetryConfig) -> Self {
        Self {
            config,
            current_attempt: 0,
        }
    }

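    /// Returns the delay to wait before the next attempt, or `None` once
    /// `max_attempts` delays have been handed out.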
    pub fn next_delay(&mut self) -> Option<Duration> {
        if self.current_attempt >= self.config.max_attempts {
            return None;
        }

        let base_delay = self.config.initial_delay.as_millis() as f64;
        let multiplier = self
            .config
            .backoff_multiplier
            .powi(self.current_attempt as i32);
        let calculated_delay = Duration::from_millis((base_delay * multiplier) as u64);

        // Cap the computed delay at the configured maximum.
        let mut delay = calculated_delay.min(self.config.max_delay);

        if self.config.jitter {
            delay = Self::add_jitter(delay);
        }

        self.current_attempt += 1;
        Some(delay)
    }

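    /// Adds up to 25% random jitter so that concurrent clients backing off
    /// from the same failure do not retry in lockstep.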
    fn add_jitter(delay: Duration) -> Duration {
        use rand::Rng;
        let jitter_ms = rand::thread_rng().gen_range(0..=delay.as_millis() as u64 / 4);
        delay + Duration::from_millis(jitter_ms)
    }

    /// Resets the schedule so the next call starts again from the first attempt.
    pub fn reset(&mut self) {
        self.current_attempt = 0;
    }

    /// Number of delays handed out so far.
    pub fn attempts(&self) -> u32 {
        self.current_attempt
    }
}

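/// Retries a fallible async operation, deriving the retry policy from the
/// error itself and honoring an optional total timeout across all attempts.
/// A usage sketch; the `client` and its `fetch_latest_block` call are
/// illustrative, not part of this module:
///
/// ```ignore
/// let node = RetryNode::new().with_timeout(Duration::from_secs(30));
/// let block = node.execute(|| client.fetch_latest_block()).await?;
/// ```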
pub struct RetryNode {
    total_timeout: Option<Duration>,
}

impl RetryNode {
    pub fn new() -> Self {
        Self {
            total_timeout: None,
        }
    }

    /// Sets an overall deadline across all attempts, including backoff delays.
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.total_timeout = Some(timeout);
        self
    }

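    /// Runs `operation`, retrying only errors that report themselves as
    /// retryable and carry a retry configuration; gives up once the backoff
    /// schedule or the optional total timeout is exhausted.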
    pub async fn execute<F, Fut, T>(&self, operation: F) -> Result<T, BittensorError>
    where
        F: Fn() -> Fut,
        Fut: Future<Output = Result<T, BittensorError>>,
    {
        let start_time = tokio::time::Instant::now();

        match operation().await {
            Ok(result) => Ok(result),
            Err(error) => {
                if !error.is_retryable() {
                    debug!("Error is not retryable: {:?}", error);
                    return Err(error);
                }

                let config = match error.retry_config() {
                    Some(config) => config,
                    None => {
                        debug!("No retry config for error: {:?}", error);
                        return Err(error);
                    }
                };

                info!(
                    "Starting retry for error category: {:?}, max_attempts: {}",
                    error.category(),
                    config.max_attempts
                );

                let mut backoff = ExponentialBackoff::new(config);
                let mut last_error = error;

                while let Some(delay) = backoff.next_delay() {
                    // Abort early if waiting out the next delay would exceed
                    // the total timeout budget.
                    if let Some(total_timeout) = self.total_timeout {
                        if start_time.elapsed() + delay >= total_timeout {
                            warn!(
                                "Total timeout reached after {} attempts",
                                backoff.attempts()
                            );
                            return Err(BittensorError::backoff_timeout(start_time.elapsed()));
                        }
                    }

                    debug!(
                        "Retry attempt {} after delay {:?}",
                        backoff.attempts(),
                        delay
                    );
                    sleep(delay).await;

                    match operation().await {
                        Ok(result) => {
                            info!("Operation succeeded after {} attempts", backoff.attempts());
                            return Ok(result);
                        }
                        Err(error) => {
                            last_error = error;

                            if !last_error.is_retryable() {
                                debug!("Error became non-retryable: {:?}", last_error);
                                return Err(last_error);
                            }

                            warn!(
                                "Retry attempt {} failed: {}",
                                backoff.attempts(),
                                last_error
                            );
                        }
                    }
                }

                warn!(
                    "All {} retry attempts exhausted, last error: {}",
                    backoff.config.max_attempts, last_error
                );
                Err(BittensorError::max_retries_exceeded(
                    backoff.config.max_attempts,
                ))
            }
        }
    }

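    /// Like [`RetryNode::execute`], but with a caller-supplied [`RetryConfig`]
    /// instead of the one derived from the error. Note that, unlike `execute`,
    /// this retries every failure without checking `is_retryable`.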
    pub async fn execute_with_config<F, Fut, T>(
        &self,
        operation: F,
        config: RetryConfig,
    ) -> Result<T, BittensorError>
    where
        F: Fn() -> Fut,
        Fut: Future<Output = Result<T, BittensorError>>,
    {
        let start_time = tokio::time::Instant::now();
        let mut backoff = ExponentialBackoff::new(config);

        match operation().await {
            Ok(result) => Ok(result),
            Err(mut last_error) => {
                info!(
                    "Starting custom retry, max_attempts: {}",
                    backoff.config.max_attempts
                );

                while let Some(delay) = backoff.next_delay() {
                    if let Some(total_timeout) = self.total_timeout {
                        if start_time.elapsed() + delay >= total_timeout {
                            warn!(
                                "Total timeout reached after {} attempts",
                                backoff.attempts()
                            );
                            return Err(BittensorError::backoff_timeout(start_time.elapsed()));
                        }
                    }

                    debug!(
                        "Custom retry attempt {} after delay {:?}",
                        backoff.attempts(),
                        delay
                    );
                    sleep(delay).await;

                    match operation().await {
                        Ok(result) => {
                            info!(
                                "Custom retry succeeded after {} attempts",
                                backoff.attempts()
                            );
                            return Ok(result);
                        }
                        Err(error) => {
                            last_error = error;
                            warn!(
                                "Custom retry attempt {} failed: {}",
                                backoff.attempts(),
                                last_error
                            );
                        }
                    }
                }

                warn!(
                    "All {} custom retry attempts exhausted, last error: {}",
                    backoff.config.max_attempts, last_error
                );
                Err(BittensorError::max_retries_exceeded(
                    backoff.config.max_attempts,
                ))
            }
        }
    }
}

impl Default for RetryNode {
    fn default() -> Self {
        Self::new()
    }
}

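/// Convenience wrapper around [`RetryNode::execute`] with no overall timeout.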
pub async fn retry_operation<F, Fut, T>(operation: F) -> Result<T, BittensorError>
where
    F: Fn() -> Fut,
    Fut: Future<Output = Result<T, BittensorError>>,
{
    RetryNode::new().execute(operation).await
}

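/// Convenience wrapper around [`RetryNode::execute`] with a total deadline
/// spanning all attempts, including backoff delays.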
pub async fn retry_operation_with_timeout<F, Fut, T>(
    operation: F,
    timeout: Duration,
) -> Result<T, BittensorError>
where
    F: Fn() -> Fut,
    Fut: Future<Output = Result<T, BittensorError>>,
{
    RetryNode::new()
        .with_timeout(timeout)
        .execute(operation)
        .await
}

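/// Minimal circuit breaker: after `failure_threshold` consecutive failures the
/// circuit opens and rejects calls immediately with a `ServiceUnavailable`
/// error; once `recovery_timeout` elapses, a single probe call is let through
/// (half-open) and a success closes the circuit again. A usage sketch (the
/// wrapped RPC call is illustrative only):
///
/// ```ignore
/// let mut breaker = CircuitBreaker::new(5, Duration::from_secs(10));
/// let block = breaker.execute(|| client.fetch_latest_block()).await?;
/// ```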
#[derive(Debug, Clone)]
pub struct CircuitBreaker {
    failure_threshold: u32,
    recovery_timeout: Duration,
    current_failures: u32,
    state: CircuitState,
    last_failure_time: Option<tokio::time::Instant>,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CircuitState {
    /// Operations pass through; failures are counted.
    Closed,
    /// Operations are rejected until the recovery timeout elapses.
    Open,
    /// One probe operation is allowed through to test recovery.
    HalfOpen,
}

impl CircuitBreaker {
    pub fn new(failure_threshold: u32, recovery_timeout: Duration) -> Self {
        Self {
            failure_threshold,
            recovery_timeout,
            current_failures: 0,
            state: CircuitState::Closed,
            last_failure_time: None,
        }
    }

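    /// Runs `operation` through the breaker: rejects immediately while open,
    /// resets the failure count on success, and opens the circuit once the
    /// failure threshold is reached.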
    pub async fn execute<F, Fut, T>(&mut self, operation: F) -> Result<T, BittensorError>
    where
        F: Fn() -> Fut,
        Fut: Future<Output = Result<T, BittensorError>>,
    {
        match self.state {
            CircuitState::Open => {
                if let Some(last_failure) = self.last_failure_time {
                    if last_failure.elapsed() >= self.recovery_timeout {
                        debug!("Circuit breaker transitioning to half-open");
                        self.state = CircuitState::HalfOpen;
                    } else {
                        return Err(BittensorError::ServiceUnavailable {
                            message: "Circuit breaker is open".to_string(),
                        });
                    }
                } else {
                    return Err(BittensorError::ServiceUnavailable {
                        message: "Circuit breaker is open".to_string(),
                    });
                }
            }
            CircuitState::Closed | CircuitState::HalfOpen => {}
        }

        match operation().await {
            Ok(result) => {
                if self.state == CircuitState::HalfOpen {
                    debug!("Circuit breaker recovering - closing circuit");
                    self.state = CircuitState::Closed;
                }
                self.current_failures = 0;
                self.last_failure_time = None;
                Ok(result)
            }
            Err(error) => {
                self.current_failures += 1;
                self.last_failure_time = Some(tokio::time::Instant::now());

                if self.current_failures >= self.failure_threshold {
                    warn!(
                        "Circuit breaker opening after {} failures",
                        self.current_failures
                    );
                    self.state = CircuitState::Open;
                }

                Err(error)
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    #[test]
    fn test_exponential_backoff() {
        let config = RetryConfig {
            max_attempts: 3,
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(5),
            backoff_multiplier: 2.0,
            jitter: false,
        };

        let mut backoff = ExponentialBackoff::new(config);

        let delay1 = backoff.next_delay().unwrap();
        assert_eq!(delay1, Duration::from_millis(100));
        assert_eq!(backoff.attempts(), 1);

        let delay2 = backoff.next_delay().unwrap();
        assert_eq!(delay2, Duration::from_millis(200));
        assert_eq!(backoff.attempts(), 2);

        let delay3 = backoff.next_delay().unwrap();
        assert_eq!(delay3, Duration::from_millis(400));
        assert_eq!(backoff.attempts(), 3);

        assert!(backoff.next_delay().is_none());
    }

    #[tokio::test]
    async fn test_retry_node_success_after_failure() {
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        // Fails twice, then succeeds on the third call.
        let operation = move || {
            let counter = counter_clone.clone();
            async move {
                let count = counter.fetch_add(1, Ordering::SeqCst);
                if count < 2 {
                    Err(BittensorError::RpcConnectionError {
                        message: "Connection failed".to_string(),
                    })
                } else {
                    Ok("success")
                }
            }
        };

        let node = RetryNode::new();
        let result: Result<&str, BittensorError> = node.execute(operation).await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_node_non_retryable_error() {
        let operation = || async {
            Err(BittensorError::InvalidHotkey {
                hotkey: "invalid".to_string(),
            })
        };

        let node = RetryNode::new();
        let result: Result<&str, BittensorError> = node.execute(operation).await;

        assert!(result.is_err());
        match result.unwrap_err() {
            BittensorError::InvalidHotkey { .. } => {}
            other => panic!("Expected InvalidHotkey, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_circuit_breaker() {
        let mut circuit_breaker = CircuitBreaker::new(2, Duration::from_millis(100));
        let counter = Arc::new(AtomicU32::new(0));

        // First failure: circuit stays closed.
        let counter_clone = counter.clone();
        let result: Result<(), BittensorError> = circuit_breaker
            .execute(|| {
                let counter = counter_clone.clone();
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err(BittensorError::RpcConnectionError {
                        message: "Connection failed".to_string(),
                    })
                }
            })
            .await;
        assert!(result.is_err());

        // Second failure: threshold reached, circuit opens.
        let counter_clone = counter.clone();
        let result: Result<(), BittensorError> = circuit_breaker
            .execute(|| {
                let counter = counter_clone.clone();
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err(BittensorError::RpcConnectionError {
                        message: "Connection failed".to_string(),
                    })
                }
            })
            .await;
        assert!(result.is_err());

        // Circuit is open: the operation must be rejected without running.
        let counter_before = counter.load(Ordering::SeqCst);
        let result: Result<&str, BittensorError> = circuit_breaker
            .execute(|| {
                let counter = counter.clone();
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Ok("should not reach here")
                }
            })
            .await;
        assert!(result.is_err());
        assert_eq!(counter.load(Ordering::SeqCst), counter_before);

        match result.unwrap_err() {
            BittensorError::ServiceUnavailable { .. } => {}
            other => panic!("Expected ServiceUnavailable, got {other:?}"),
        }
    }
}