use std::future::Future;
use std::time::Duration;
use rand::Rng;
use tokio::time::sleep;
use tracing::{debug, warn};
use crate::error::{Error, Result};
/// Settings that control how [`with_retry`] re-runs a failing operation.
///
/// The wait before each retry grows geometrically: it starts at
/// `initial_delay`, is multiplied by `backoff_multiplier` for every further
/// attempt, and is capped at `max_delay` before optional jitter is applied.
#[derive(Debug, Clone)]
pub struct RetryConfig {
/// Number of retries allowed after the first attempt, so an operation is
/// run at most `max_retries + 1` times. `0` disables retrying.
pub max_retries: u32,
/// Wait before the first retry (the delay computed for attempt `0`).
pub initial_delay: Duration,
/// Cap applied to the exponentially grown delay. Jitter is applied after
/// this cap, so a jittered wait can end up somewhat longer.
pub max_delay: Duration,
/// Factor raised to the power of the attempt index and multiplied with
/// `initial_delay`.
pub backoff_multiplier: f64,
/// When `true`, each capped delay is scaled by a random factor of
/// `1.0 + r * 0.25`, where `r` is a freshly drawn random `f64`.
pub jitter: bool,
}
impl Default for RetryConfig {
fn default() -> Self {
Self {
max_retries: 3,
initial_delay: Duration::from_millis(100),
max_delay: Duration::from_secs(5),
backoff_multiplier: 2.0,
jitter: true,
}
}
}
impl RetryConfig {
    /// Creates a configuration with `max_retries` retries and every other
    /// field taken from [`RetryConfig::default`].
    pub fn new(max_retries: u32) -> Self {
        Self {
            max_retries,
            ..Default::default()
        }
    }

    /// Creates a configuration that never retries (`max_retries` is 0); the
    /// remaining fields come from [`RetryConfig::default`].
    pub fn none() -> Self {
        Self {
            max_retries: 0,
            ..Default::default()
        }
    }

    /// Preset with many short retries: 5 retries, 50 ms initial delay,
    /// x1.5 backoff, capped at 10 s, jittered.
    pub fn aggressive() -> Self {
        Self {
            max_retries: 5,
            initial_delay: Duration::from_millis(50),
            max_delay: Duration::from_secs(10),
            backoff_multiplier: 1.5,
            jitter: true,
        }
    }

    /// Preset named for scan operations: 5 retries, 200 ms initial delay,
    /// x1.5 backoff, capped at 2 s, jittered.
    pub fn for_scan() -> Self {
        Self {
            max_retries: 5,
            initial_delay: Duration::from_millis(200),
            max_delay: Duration::from_secs(2),
            backoff_multiplier: 1.5,
            jitter: true,
        }
    }

    /// Preset named for connection attempts: 3 retries, 1 s initial delay,
    /// x2 backoff, capped at 10 s, jittered.
    pub fn for_connect() -> Self {
        Self {
            max_retries: 3,
            initial_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(10),
            backoff_multiplier: 2.0,
            jitter: true,
        }
    }

    /// Preset named for read operations: 3 retries, 100 ms initial delay,
    /// x2 backoff, capped at 2 s, jittered.
    pub fn for_read() -> Self {
        Self {
            max_retries: 3,
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(2),
            backoff_multiplier: 2.0,
            jitter: true,
        }
    }

    /// Preset named for write operations: 2 retries, 200 ms initial delay,
    /// x2 backoff, capped at 3 s, jittered.
    pub fn for_write() -> Self {
        Self {
            max_retries: 2,
            initial_delay: Duration::from_millis(200),
            max_delay: Duration::from_secs(3),
            backoff_multiplier: 2.0,
            jitter: true,
        }
    }

    /// Preset named for history operations: 5 retries, 500 ms initial delay,
    /// x2 backoff, capped at 15 s, jittered.
    pub fn for_history() -> Self {
        Self {
            max_retries: 5,
            initial_delay: Duration::from_millis(500),
            max_delay: Duration::from_secs(15),
            backoff_multiplier: 2.0,
            jitter: true,
        }
    }

    /// Preset named for reconnection: 5 retries, 2 s initial delay,
    /// x2 backoff, capped at 30 s, jittered.
    pub fn for_reconnect() -> Self {
        Self {
            max_retries: 5,
            initial_delay: Duration::from_secs(2),
            max_delay: Duration::from_secs(30),
            backoff_multiplier: 2.0,
            jitter: true,
        }
    }

    /// Preset with short, deterministic waits: 2 retries, 50 ms initial
    /// delay, x2 backoff, capped at 500 ms, no jitter.
    pub fn quick() -> Self {
        Self {
            max_retries: 2,
            initial_delay: Duration::from_millis(50),
            max_delay: Duration::from_millis(500),
            backoff_multiplier: 2.0,
            jitter: false,
        }
    }

    /// Returns the configuration with `max_retries` replaced by `retries`.
    #[must_use]
    pub fn max_retries(mut self, retries: u32) -> Self {
        self.max_retries = retries;
        self
    }

    /// Returns the configuration with `initial_delay` replaced by `delay`.
    #[must_use]
    pub fn initial_delay(mut self, delay: Duration) -> Self {
        self.initial_delay = delay;
        self
    }

    /// Returns the configuration with `max_delay` replaced by `delay`.
    #[must_use]
    pub fn max_delay(mut self, delay: Duration) -> Self {
        self.max_delay = delay;
        self
    }

    /// Returns the configuration with `backoff_multiplier` replaced by
    /// `multiplier`.
    #[must_use]
    pub fn backoff_multiplier(mut self, multiplier: f64) -> Self {
        self.backoff_multiplier = multiplier;
        self
    }

    /// Returns the configuration with jitter switched on or off.
    #[must_use]
    pub fn jitter(mut self, enabled: bool) -> Self {
        self.jitter = enabled;
        self
    }

    /// Computes how long to wait after the zero-based `attempt` has failed.
    ///
    /// The delay is `initial_delay * backoff_multiplier^attempt`, capped at
    /// `max_delay`. With `jitter` enabled the capped value is then multiplied
    /// by `1.0 + r * 0.25` for a freshly drawn random `r`, so a jittered delay
    /// may exceed `max_delay` (by up to a quarter, assuming `r` lies in
    /// `[0, 1)` as the rand crate documents for `f64`).
    ///
    /// This function does not panic on unusual configurations:
    /// * a negative product (negative multiplier) is clamped to zero,
    /// * a NaN product falls back to the `max_delay` cap,
    /// * a result too large for [`Duration`] saturates at [`Duration::MAX`].
    fn delay_for_attempt(&self, attempt: u32) -> Duration {
        // `powi` takes an `i32`. Clamp first so the cast is lossless; a plain
        // `as i32` would wrap large attempts to a negative exponent and turn
        // exponential growth into decay.
        let exponent = attempt.min(i32::MAX as u32) as i32;

        let ceiling = self.max_delay.as_secs_f64();
        let grown = self.initial_delay.as_secs_f64() * self.backoff_multiplier.powi(exponent);

        // `f64::min` ignores a NaN operand, so a NaN product yields `ceiling`.
        // `max(0.0)` then rejects negative values, which `Duration` cannot
        // represent.
        let capped = grown.min(ceiling).max(0.0);

        let seconds = if self.jitter {
            capped * (1.0 + rand::rng().random::<f64>() * 0.25)
        } else {
            capped
        };

        // After the clamping above the conversion can only fail on overflow
        // (for example `max_delay == Duration::MAX`), so saturate.
        Duration::try_from_secs_f64(seconds).unwrap_or(Duration::MAX)
    }
}
pub async fn with_retry<F, Fut, T>(
config: &RetryConfig,
operation_name: &str,
operation: F,
) -> Result<T>
where
F: Fn() -> Fut,
Fut: Future<Output = Result<T>>,
{
let mut last_error = None;
for attempt in 0..=config.max_retries {
match operation().await {
Ok(result) => {
if attempt > 0 {
debug!("{} succeeded after {} retries", operation_name, attempt);
}
return Ok(result);
}
Err(e) => {
if !is_retryable(&e) {
return Err(e);
}
last_error = Some(e);
if attempt < config.max_retries {
let delay = config.delay_for_attempt(attempt);
warn!(
"{} failed (attempt {}/{}), retrying in {:?}",
operation_name,
attempt + 1,
config.max_retries + 1,
delay
);
sleep(delay).await;
}
}
}
}
Err(last_error
.unwrap_or_else(|| Error::InvalidData("Operation failed with no error".to_string())))
}
fn is_retryable(error: &Error) -> bool {
use crate::error::ConnectionFailureReason;
match error {
Error::Timeout { .. } => true,
Error::Bluetooth(_) => true,
Error::ConnectionFailed { reason, .. } => {
matches!(
reason,
ConnectionFailureReason::OutOfRange
| ConnectionFailureReason::Timeout
| ConnectionFailureReason::BleError(_)
| ConnectionFailureReason::Other(_)
)
}
Error::NotConnected => true,
Error::WriteFailed { .. } => true,
Error::InvalidData(_) => false,
Error::InvalidHistoryData { .. } => false,
Error::InvalidReadingFormat { .. } => false,
Error::DeviceNotFound(_) => false,
Error::CharacteristicNotFound { .. } => false,
Error::Cancelled => false,
Error::Io(_) => true,
Error::InvalidConfig(_) => false,
Error::Unsupported(_) => false,
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::{ConnectionFailureReason, DeviceNotFoundReason};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    // Builds the retryable connection failure used by the retry-loop tests.
    fn connection_failure(message: &str) -> Error {
        Error::ConnectionFailed {
            device_id: None,
            reason: ConnectionFailureReason::Other(message.to_string()),
        }
    }

    // Jitter-free configuration with a 1 ms initial delay so tests stay fast.
    fn fast_config(max_retries: u32) -> RetryConfig {
        RetryConfig {
            max_retries,
            initial_delay: Duration::from_millis(1),
            jitter: false,
            ..Default::default()
        }
    }

    #[test]
    fn test_retry_config_default() {
        let RetryConfig {
            max_retries,
            jitter,
            ..
        } = RetryConfig::default();
        assert_eq!(max_retries, 3);
        assert!(jitter);
    }

    #[test]
    fn test_retry_config_none() {
        assert_eq!(RetryConfig::none().max_retries, 0);
    }

    #[test]
    fn test_delay_calculation() {
        let config = RetryConfig {
            max_retries: 5,
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(10),
            backoff_multiplier: 2.0,
            jitter: false,
        };

        // Without jitter the delay doubles with every attempt.
        for (attempt, expected_ms) in [(0, 100), (1, 200), (2, 400)] {
            assert_eq!(
                config.delay_for_attempt(attempt),
                Duration::from_millis(expected_ms)
            );
        }
    }

    #[test]
    fn test_is_retryable() {
        let timeout = Error::Timeout {
            operation: "test".to_string(),
            duration: Duration::from_secs(1),
        };
        assert!(is_retryable(&timeout));
        assert!(is_retryable(&connection_failure("test")));
        assert!(is_retryable(&Error::NotConnected));

        let invalid_data = Error::InvalidData("test".to_string());
        assert!(!is_retryable(&invalid_data));

        let not_found = Error::DeviceNotFound(DeviceNotFoundReason::NotFound {
            identifier: "test".to_string(),
        });
        assert!(!is_retryable(&not_found));
    }

    #[tokio::test]
    async fn test_with_retry_immediate_success() {
        let config = RetryConfig::new(3);
        let outcome = with_retry(&config, "test", || async { Ok::<_, Error>(42) }).await;
        assert_eq!(outcome.unwrap(), 42);
    }

    #[tokio::test]
    async fn test_with_retry_eventual_success() {
        let config = fast_config(3);
        let calls = Arc::new(AtomicU32::new(0));

        let outcome: Result<i32> = with_retry(&config, "test", || {
            let calls = Arc::clone(&calls);
            async move {
                // The first two calls fail; the third succeeds.
                match calls.fetch_add(1, Ordering::SeqCst) {
                    0 | 1 => Err(connection_failure("transient error")),
                    _ => Ok(42),
                }
            }
        })
        .await;

        assert_eq!(outcome.unwrap(), 42);
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_with_retry_all_fail() {
        let config = fast_config(2);
        let calls = Arc::new(AtomicU32::new(0));

        let outcome: Result<i32> = with_retry(&config, "test", || {
            let calls = Arc::clone(&calls);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Err::<i32, _>(connection_failure("persistent error"))
            }
        })
        .await;

        assert!(outcome.is_err());
        // One initial attempt plus two retries.
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_with_retry_non_retryable_error() {
        let config = RetryConfig::new(3);
        let calls = Arc::new(AtomicU32::new(0));

        let outcome: Result<i32> = with_retry(&config, "test", || {
            let calls = Arc::clone(&calls);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Err::<i32, _>(Error::InvalidData("not retryable".to_string()))
            }
        })
        .await;

        assert!(outcome.is_err());
        // A non-retryable error stops the loop after the first attempt.
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }
}