use crate::error::{BittensorError, RetryConfig};
use std::future::Future;
use std::time::Duration;
use tokio::time::sleep;
use tracing::{debug, info, warn};
/// Tracks exponential-backoff state across a sequence of retry attempts.
#[derive(Debug, Clone)]
pub struct ExponentialBackoff {
    // Delay schedule parameters (initial/max delay, multiplier, jitter, max attempts).
    config: RetryConfig,
    // Number of delays handed out so far; compared against config.max_attempts.
    current_attempt: u32,
}
impl ExponentialBackoff {
    /// Creates a new backoff tracker starting at attempt 0.
    pub fn new(config: RetryConfig) -> Self {
        Self {
            config,
            current_attempt: 0,
        }
    }

    /// Returns the delay to wait before the next attempt, or `None` once
    /// `max_attempts` delays have been handed out.
    ///
    /// The delay grows as `initial_delay * backoff_multiplier^attempt` and is
    /// capped at `max_delay`. When jitter is enabled, up to 25% extra random
    /// delay is added and the result is clamped again so the configured
    /// `max_delay` ceiling is always honored.
    pub fn next_delay(&mut self) -> Option<Duration> {
        if self.current_attempt >= self.config.max_attempts {
            return None;
        }
        let base_delay = self.config.initial_delay.as_millis() as f64;
        let multiplier = self
            .config
            .backoff_multiplier
            .powi(self.current_attempt as i32);
        // float -> u64 `as` cast saturates, so huge products degrade to u64::MAX
        // milliseconds rather than wrapping.
        let calculated_delay = Duration::from_millis((base_delay * multiplier) as u64);
        let mut delay = calculated_delay.min(self.config.max_delay);
        if self.config.jitter {
            // Fix: clamp after adding jitter too. Previously jitter could push
            // the delay up to 25% beyond config.max_delay.
            delay = Self::add_jitter(delay).min(self.config.max_delay);
        }
        self.current_attempt += 1;
        Some(delay)
    }

    /// Adds up to 25% random jitter to spread out concurrent retries.
    fn add_jitter(delay: Duration) -> Duration {
        use rand::Rng;
        let jitter_ms = rand::thread_rng().gen_range(0..=delay.as_millis() as u64 / 4);
        delay + Duration::from_millis(jitter_ms)
    }

    /// Resets the attempt counter so the delay schedule starts over.
    pub fn reset(&mut self) {
        self.current_attempt = 0;
    }

    /// Number of delays issued so far.
    pub fn attempts(&self) -> u32 {
        self.current_attempt
    }
}
/// Executes async operations with retry and exponential backoff, with an
/// optional overall wall-clock budget.
pub struct RetryNode {
    // When set, all attempts plus their delays must finish within this budget;
    // otherwise retries run until the per-config attempt limit is exhausted.
    total_timeout: Option<Duration>,
}
impl RetryNode {
    /// Creates a retry node with no overall deadline.
    pub fn new() -> Self {
        Self {
            total_timeout: None,
        }
    }

    /// Sets a total wall-clock budget covering all attempts and their delays.
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.total_timeout = Some(timeout);
        self
    }

    /// Runs `operation`, retrying on failure using the retry configuration
    /// carried by the error itself.
    ///
    /// Flow: the operation runs once up front with no delay. On error, retries
    /// happen only if the error reports itself retryable AND supplies a
    /// `RetryConfig`. Each retry waits an exponential-backoff delay first, and
    /// the loop aborts early if the upcoming delay would overrun
    /// `total_timeout`.
    ///
    /// # Errors
    /// - the original error, when non-retryable or lacking a retry config
    /// - `backoff_timeout` when the total timeout budget would be exceeded
    /// - `max_retries_exceeded` when every configured attempt fails
    pub async fn execute<F, Fut, T>(&self, operation: F) -> Result<T, BittensorError>
    where
        F: Fn() -> Fut,
        Fut: Future<Output = Result<T, BittensorError>>,
    {
        let start_time = tokio::time::Instant::now();
        // First attempt is free: no delay and no config lookup needed.
        match operation().await {
            Ok(result) => Ok(result),
            Err(error) => {
                if !error.is_retryable() {
                    debug!("Error is not retryable: {:?}", error);
                    return Err(error);
                }
                // The error type decides its own backoff parameters.
                let config = match error.retry_config() {
                    Some(config) => config,
                    None => {
                        debug!("No retry config for error: {:?}", error);
                        return Err(error);
                    }
                };
                info!(
                    "Starting retry for error category: {:?}, max_attempts: {}",
                    error.category(),
                    config.max_attempts
                );
                let mut backoff = ExponentialBackoff::new(config);
                let mut _last_error = error;
                while let Some(delay) = backoff.next_delay() {
                    // Abort before sleeping if `delay` would blow the budget.
                    if let Some(total_timeout) = self.total_timeout {
                        if start_time.elapsed() + delay >= total_timeout {
                            warn!(
                                "Total timeout reached after {} attempts",
                                backoff.attempts()
                            );
                            return Err(BittensorError::backoff_timeout(start_time.elapsed()));
                        }
                    }
                    debug!(
                        "Retry attempt {} after delay {:?}",
                        backoff.attempts(),
                        delay
                    );
                    sleep(delay).await;
                    match operation().await {
                        Ok(result) => {
                            info!("Operation succeeded after {} attempts", backoff.attempts());
                            return Ok(result);
                        }
                        Err(error) => {
                            _last_error = error;
                            // An error can change category mid-retry; stop immediately
                            // if it is no longer retryable.
                            if !_last_error.is_retryable() {
                                debug!("Error became non-retryable: {:?}", _last_error);
                                return Err(_last_error);
                            }
                            warn!(
                                "Retry attempt {} failed: {}",
                                backoff.attempts(),
                                _last_error
                            );
                        }
                    }
                }
                // NOTE(review): the last concrete error is only logged here; the
                // returned value is a generic max_retries_exceeded.
                warn!(
                    "All {} retry attempts exhausted, last error: {}",
                    backoff.config.max_attempts, _last_error
                );
                Err(BittensorError::max_retries_exceeded(
                    backoff.config.max_attempts,
                ))
            }
        }
    }

    /// Like [`Self::execute`], but uses the caller-supplied `config` instead of
    /// the error's own retry configuration, and retries EVERY error without
    /// checking retryability.
    ///
    /// # Errors
    /// - `backoff_timeout` when the total timeout budget would be exceeded
    /// - `max_retries_exceeded` when every configured attempt fails
    pub async fn execute_with_config<F, Fut, T>(
        &self,
        operation: F,
        config: RetryConfig,
    ) -> Result<T, BittensorError>
    where
        F: Fn() -> Fut,
        Fut: Future<Output = Result<T, BittensorError>>,
    {
        let start_time = tokio::time::Instant::now();
        let mut backoff = ExponentialBackoff::new(config);
        // First attempt is free: no delay before it.
        match operation().await {
            Ok(result) => Ok(result),
            Err(mut _last_error) => {
                info!(
                    "Starting custom retry, max_attempts: {}",
                    backoff.config.max_attempts
                );
                while let Some(delay) = backoff.next_delay() {
                    // Abort before sleeping if `delay` would blow the budget.
                    if let Some(total_timeout) = self.total_timeout {
                        if start_time.elapsed() + delay >= total_timeout {
                            warn!(
                                "Total timeout reached after {} attempts",
                                backoff.attempts()
                            );
                            return Err(BittensorError::backoff_timeout(start_time.elapsed()));
                        }
                    }
                    debug!(
                        "Custom retry attempt {} after delay {:?}",
                        backoff.attempts(),
                        delay
                    );
                    sleep(delay).await;
                    match operation().await {
                        Ok(result) => {
                            info!(
                                "Custom retry succeeded after {} attempts",
                                backoff.attempts()
                            );
                            return Ok(result);
                        }
                        Err(error) => {
                            _last_error = error;
                            warn!(
                                "Custom retry attempt {} failed: {}",
                                backoff.attempts(),
                                _last_error
                            );
                        }
                    }
                }
                Err(BittensorError::max_retries_exceeded(
                    backoff.config.max_attempts,
                ))
            }
        }
    }
}
impl Default for RetryNode {
fn default() -> Self {
Self::new()
}
}
/// Retries `operation` using the error's own retry configuration.
///
/// Convenience wrapper around [`RetryNode::execute`] with no overall timeout.
pub async fn retry_operation<F, Fut, T>(operation: F) -> Result<T, BittensorError>
where
    F: Fn() -> Fut,
    Fut: Future<Output = Result<T, BittensorError>>,
{
    let node = RetryNode::new();
    node.execute(operation).await
}
/// Retries `operation` like [`retry_operation`], but bounds all attempts and
/// delays by `timeout` total wall-clock time.
pub async fn retry_operation_with_timeout<F, Fut, T>(
    operation: F,
    timeout: Duration,
) -> Result<T, BittensorError>
where
    F: Fn() -> Fut,
    Fut: Future<Output = Result<T, BittensorError>>,
{
    let node = RetryNode::new().with_timeout(timeout);
    node.execute(operation).await
}
/// Simple in-process circuit breaker: opens after a streak of failures and
/// lets a probe request through after a recovery timeout.
#[derive(Debug, Clone)]
pub struct CircuitBreaker {
    // Consecutive failures that trip the circuit open.
    failure_threshold: u32,
    // How long the circuit stays open before allowing a half-open probe.
    recovery_timeout: Duration,
    // Current consecutive-failure count; reset to 0 on any success.
    current_failures: u32,
    // Closed / Open / HalfOpen state machine position.
    state: CircuitState,
    // Instant of the most recent failure; None after a success.
    last_failure_time: Option<tokio::time::Instant>,
}
/// Circuit-breaker state machine positions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CircuitState {
    /// Requests flow through normally.
    Closed,
    /// Requests are rejected until the recovery timeout elapses.
    Open,
    /// One probe request is allowed through to test recovery.
    HalfOpen,
}
impl CircuitBreaker {
    /// Creates a closed circuit breaker.
    ///
    /// `failure_threshold`: consecutive failures before the circuit opens.
    /// `recovery_timeout`: how long the circuit stays open before a probe
    /// request is allowed through (half-open).
    pub fn new(failure_threshold: u32, recovery_timeout: Duration) -> Self {
        Self {
            failure_threshold,
            recovery_timeout,
            current_failures: 0,
            state: CircuitState::Closed,
            last_failure_time: None,
        }
    }

    /// Runs `operation` subject to the breaker state.
    ///
    /// Open: rejects immediately with `ServiceUnavailable` until
    /// `recovery_timeout` has elapsed since the last failure, then transitions
    /// to half-open and lets the call through. Closed/half-open: runs the
    /// operation; success closes the circuit and resets counters, failure
    /// increments the streak and opens the circuit once it reaches the
    /// threshold. Errors from `operation` itself pass straight through.
    pub async fn execute<F, Fut, T>(&mut self, operation: F) -> Result<T, BittensorError>
    where
        F: Fn() -> Fut,
        Fut: Future<Output = Result<T, BittensorError>>,
    {
        match self.state {
            CircuitState::Open => {
                if let Some(last_failure) = self.last_failure_time {
                    if last_failure.elapsed() >= self.recovery_timeout {
                        debug!("Circuit breaker transitioning to half-open");
                        self.state = CircuitState::HalfOpen;
                    } else {
                        return Err(BittensorError::ServiceUnavailable {
                            message: "Circuit breaker is open".to_string(),
                        });
                    }
                } else {
                    // Open with no recorded failure time: keep rejecting. Should
                    // be unreachable since opening always records a timestamp.
                    return Err(BittensorError::ServiceUnavailable {
                        message: "Circuit breaker is open".to_string(),
                    });
                }
            }
            CircuitState::Closed | CircuitState::HalfOpen => {}
        }
        match operation().await {
            Ok(result) => {
                if self.state == CircuitState::HalfOpen {
                    debug!("Circuit breaker recovering - closing circuit");
                    self.state = CircuitState::Closed;
                }
                // Any success resets the failure streak.
                self.current_failures = 0;
                self.last_failure_time = None;
                Ok(result)
            }
            Err(error) => {
                self.current_failures += 1;
                self.last_failure_time = Some(tokio::time::Instant::now());
                // Note: failures are NOT reset when entering half-open, so a
                // failed probe immediately re-opens the circuit.
                if self.current_failures >= self.failure_threshold {
                    warn!(
                        "Circuit breaker opening after {} failures",
                        self.current_failures
                    );
                    self.state = CircuitState::Open;
                }
                Err(error)
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// The jitter-free schedule doubles each step and stops after max_attempts.
    #[test]
    fn test_exponential_backoff() {
        let config = RetryConfig {
            max_attempts: 3,
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(5),
            backoff_multiplier: 2.0,
            jitter: false,
        };
        let mut backoff = ExponentialBackoff::new(config);
        let delay1 = backoff.next_delay().unwrap();
        assert_eq!(delay1, Duration::from_millis(100));
        assert_eq!(backoff.attempts(), 1);
        let delay2 = backoff.next_delay().unwrap();
        assert_eq!(delay2, Duration::from_millis(200));
        assert_eq!(backoff.attempts(), 2);
        let delay3 = backoff.next_delay().unwrap();
        assert_eq!(delay3, Duration::from_millis(400));
        assert_eq!(backoff.attempts(), 3);
        // Schedule exhausted after max_attempts delays.
        assert!(backoff.next_delay().is_none());
    }

    /// Operation fails twice then succeeds: execute() should retry through
    /// both failures, calling the operation 3 times total.
    #[tokio::test]
    async fn test_retry_node_success_after_failure() {
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();
        let operation = move || {
            let counter = counter_clone.clone();
            async move {
                let count = counter.fetch_add(1, Ordering::SeqCst);
                if count < 2 {
                    Err(BittensorError::RpcConnectionError {
                        message: "Connection failed".to_string(),
                    })
                } else {
                    Ok("success")
                }
            }
        };
        let node = RetryNode::new();
        let result: Result<&str, BittensorError> = node.execute(operation).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    /// A non-retryable error must be returned immediately, unchanged.
    #[tokio::test]
    async fn test_retry_node_non_retryable_error() {
        let operation = || async {
            Err(BittensorError::InvalidHotkey {
                hotkey: "invalid".to_string(),
            })
        };
        let node = RetryNode::new();
        let result: Result<&str, BittensorError> = node.execute(operation).await;
        assert!(result.is_err());
        match result.unwrap_err() {
            BittensorError::InvalidHotkey { .. } => {}
            other => panic!("Expected InvalidHotkey, got {other:?}"),
        }
    }

    /// After 2 failures (the threshold) the circuit opens: the third call must
    /// be rejected without invoking the operation at all.
    #[tokio::test]
    async fn test_circuit_breaker() {
        let mut circuit_breaker = CircuitBreaker::new(2, Duration::from_millis(100));
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();
        // Failure 1 of 2.
        let result: Result<(), BittensorError> = circuit_breaker
            .execute(|| {
                let counter = counter_clone.clone();
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err(BittensorError::RpcConnectionError {
                        message: "Connection failed".to_string(),
                    })
                }
            })
            .await;
        assert!(result.is_err());
        let counter_clone = counter.clone();
        // Failure 2 of 2 — trips the breaker open.
        let result: Result<(), BittensorError> = circuit_breaker
            .execute(|| {
                let counter = counter_clone.clone();
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err(BittensorError::RpcConnectionError {
                        message: "Connection failed".to_string(),
                    })
                }
            })
            .await;
        assert!(result.is_err());
        let counter_before = counter.load(Ordering::SeqCst);
        // Circuit is open: operation must not run (counter unchanged).
        let result: Result<&str, BittensorError> = circuit_breaker
            .execute(|| {
                let counter = counter.clone();
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Ok("should not reach here")
                }
            })
            .await;
        assert!(result.is_err());
        assert_eq!(counter.load(Ordering::SeqCst), counter_before);
        match result.unwrap_err() {
            BittensorError::ServiceUnavailable { .. } => {}
            other => panic!("Expected ServiceUnavailable, got {other:?}"),
        }
    }
}