use std::future::Future;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Mutex;
use std::time::Duration;
use snapdir_core::StoreError;
use snapdir_stores::{
parse_retry_after, retry_network, AsyncSleeper, Attempt, Jitter, RateLimiter, RetryPolicy,
};
/// Test double that keeps a log of sleep durations behind a mutex.
///
/// The log starts empty (via `Default`) and is appended to by the
/// `AsyncSleeper` implementation in this file.
#[derive(Default)]
struct RecordingSleeper {
    /// Durations in the order they were recorded.
    delays: Mutex<Vec<Duration>>,
}

impl RecordingSleeper {
    /// Returns an owned snapshot of every duration recorded so far, oldest first.
    ///
    /// # Panics
    /// Panics if the internal mutex has been poisoned.
    fn recorded(&self) -> Vec<Duration> {
        let log = self.delays.lock().unwrap();
        log.to_vec()
    }
}
impl AsyncSleeper for RecordingSleeper {
    /// Logs `dur` and hands back a future that is already complete.
    ///
    /// The duration is recorded when `sleep` is *called*, not when the
    /// returned future is polled, and no real waiting takes place.
    fn sleep(&self, dur: Duration) -> impl Future<Output = ()> + Send {
        {
            // Scope the guard so the lock is released before returning.
            let mut log = self.delays.lock().unwrap();
            log.push(dur);
        }
        std::future::ready(())
    }
}
/// `Jitter` source that always yields the same, caller-chosen value,
/// making backoff computations in these tests deterministic.
struct FixedJitter(f64);

impl Jitter for FixedJitter {
    /// Returns the wrapped constant unchanged.
    fn jitter01(&self) -> f64 {
        let Self(fraction) = self;
        *fraction
    }
}
/// Creates a current-thread Tokio runtime with the time driver enabled.
///
/// # Panics
/// Panics if the runtime cannot be built.
fn runtime() -> tokio::runtime::Runtime {
    let mut builder = tokio::runtime::Builder::new_current_thread();
    builder.enable_time();
    builder.build().expect("build tokio runtime")
}
/// Produces the canned backend error used by every failing attempt in this
/// file: message `"boom"`, no underlying source.
fn boom() -> StoreError {
    let message = "boom".into();
    StoreError::Backend {
        message,
        source: None,
    }
}
/// Builds a failed `Attempt` flagged as transient, carrying the canned
/// `boom()` error and the given optional `retry_after` hint.
fn transient(retry_after: Option<Duration>) -> Attempt {
    let err = boom();
    Attempt {
        err,
        retry_after,
        transient: true,
    }
}
/// Builds a failed `Attempt` flagged as non-transient, carrying the canned
/// `boom()` error and no `retry_after` hint.
fn hard() -> Attempt {
    let err = boom();
    Attempt {
        err,
        retry_after: None,
        transient: false,
    }
}
/// Retry policy shared by most tests: 5 attempts, 100 ms base, 10 s cap.
fn small_policy() -> RetryPolicy {
    let base = Duration::from_millis(100);
    let cap = Duration::from_secs(10);
    RetryPolicy {
        max_attempts: 5,
        base,
        cap,
    }
}
/// The op fails transiently three times and then succeeds: `retry_network`
/// must return the success value, call the op four times, and request one
/// sleep per retry, each within the policy cap and non-decreasing.
#[test]
fn backoff_wire_success_after_k_transient_retries() {
    // Number of transient failures before the op succeeds.
    const FAILURES: usize = 3;

    runtime().block_on(async {
        let policy = small_policy();
        let limiter = RateLimiter::new(None);
        let sleeper = RecordingSleeper::default();
        let jitter = FixedJitter(0.5);
        let invocations = AtomicUsize::new(0);

        let outcome: Result<u32, StoreError> =
            retry_network(&policy, &limiter, &sleeper, &jitter, || {
                // Zero-based index of this invocation.
                let attempt_index = invocations.fetch_add(1, Ordering::SeqCst);
                async move {
                    if attempt_index >= FAILURES {
                        Ok(7)
                    } else {
                        Err(transient(None))
                    }
                }
            })
            .await;

        assert_eq!(outcome.unwrap(), 7, "succeeds on the (k+1)-th attempt");
        assert_eq!(
            invocations.load(Ordering::SeqCst),
            FAILURES + 1,
            "op invoked k+1 times"
        );

        let recorded = sleeper.recorded();
        assert_eq!(
            recorded.len(),
            FAILURES,
            "one sleep before each of the k retries"
        );
        for d in recorded.iter() {
            assert!(*d <= policy.cap, "sleep {d:?} must respect the cap");
        }
        // Each consecutive pair of delays must be non-decreasing.
        for pair in recorded.windows(2) {
            assert!(pair[1] >= pair[0], "backoff should grow");
        }
    });
}
/// A non-transient failure must be returned to the caller as-is: the op is
/// invoked exactly once and no sleep is ever requested.
#[test]
fn backoff_wire_hard_error_surfaces_without_retry() {
    runtime().block_on(async {
        let policy = small_policy();
        let limiter = RateLimiter::new(None);
        let sleeper = RecordingSleeper::default();
        let jitter = FixedJitter(0.5);
        let invocations = AtomicUsize::new(0);

        let outcome: Result<(), StoreError> =
            retry_network(&policy, &limiter, &sleeper, &jitter, || {
                invocations.fetch_add(1, Ordering::SeqCst);
                async { Err(hard()) }
            })
            .await;

        let err = outcome.expect_err("hard error surfaces");
        let surfaced_boom =
            matches!(err, StoreError::Backend { ref message, .. } if message == "boom");
        assert!(surfaced_boom, "the StoreError must surface: {err:?}");

        let call_count = invocations.load(Ordering::SeqCst);
        assert_eq!(call_count, 1, "non-transient => op called exactly once");

        let sleeps = sleeper.recorded();
        assert!(sleeps.is_empty(), "no sleeps for a non-transient error");
    });
}
/// Every attempt fails transiently with a 5 s `retry_after` hint; each sleep
/// the retry loop requests must be at least that long.
///
/// Jitter is pinned at 0.0 — presumably so the computed backoff alone would
/// fall below the hint; confirm against `retry_network`'s formula.
#[test]
fn backoff_wire_retry_after_hint_is_a_floor() {
    runtime().block_on(async {
        let policy = small_policy();
        let limiter = RateLimiter::new(None);
        let sleeper = RecordingSleeper::default();
        let jitter = FixedJitter(0.0);
        let hint = Duration::from_secs(5);

        // The outcome itself is irrelevant here; only the recorded delays matter.
        let _outcome: Result<(), StoreError> =
            retry_network(&policy, &limiter, &sleeper, &jitter, || async move {
                Err(transient(Some(hint)))
            })
            .await;

        let recorded = sleeper.recorded();
        assert!(!recorded.is_empty(), "at least one retry happened");
        for d in recorded.iter() {
            assert!(
                *d >= hint,
                "recorded delay {d:?} must be >= the server hint {hint:?}"
            );
        }
    });
}
/// An op that never stops failing transiently must be attempted exactly
/// `max_attempts` times, with a sleep between attempts but none after the
/// last one, and the error must be returned.
#[test]
fn backoff_wire_persistent_transient_exhausts_budget() {
    runtime().block_on(async {
        let policy = small_policy();
        let limiter = RateLimiter::new(None);
        let sleeper = RecordingSleeper::default();
        let jitter = FixedJitter(0.5);
        let invocations = AtomicUsize::new(0);

        let outcome: Result<(), StoreError> =
            retry_network(&policy, &limiter, &sleeper, &jitter, || {
                invocations.fetch_add(1, Ordering::SeqCst);
                async { Err(transient(None)) }
            })
            .await;

        assert!(outcome.is_err(), "persistent transient surfaces the err");

        let budget = policy.max_attempts as usize;
        assert_eq!(
            invocations.load(Ordering::SeqCst),
            budget,
            "op invoked exactly max_attempts times"
        );
        assert_eq!(
            sleeper.recorded().len(),
            budget - 1,
            "max_attempts-1 sleeps (none after the final attempt)"
        );
    });
}
/// `parse_retry_after` accepts a delta-seconds value (surrounding whitespace
/// tolerated) and yields `None` for an HTTP-date, non-numeric text, or an
/// empty string.
#[test]
fn backoff_wire_parse_retry_after_delta_seconds() {
    let cases: [(&str, Option<Duration>); 6] = [
        ("125", Some(Duration::from_secs(125))),
        (" 7 ", Some(Duration::from_secs(7))),
        ("0", Some(Duration::ZERO)),
        ("Wed, 21 Oct 2015 07:28:00 GMT", None),
        ("not-a-number", None),
        ("", None),
    ];

    for &(header, expected) in &cases {
        assert_eq!(
            parse_retry_after(header),
            expected,
            "header value {header:?}"
        );
    }
}
/// With a request limiter configured, every attempt must pass through it:
/// four attempts (three transient failures, then success) must take at
/// least 900 ms of wall-clock time.
///
/// `RecordingSleeper` returns an already-completed future, so the backoff
/// sleeps add no real delay; the elapsed time is expected to come from the
/// limiter. NOTE(review): `Some(2)` is presumably a per-second request
/// rate — confirm against `RateLimiter::new`.
#[test]
fn backoff_wire_request_limiter_paces_each_call() {
    // Number of transient failures before the op succeeds.
    const FAILURES: usize = 3;

    runtime().block_on(async {
        let tiny = Duration::from_millis(1);
        let policy = RetryPolicy {
            max_attempts: 5,
            base: tiny,
            cap: tiny,
        };
        let limiter = RateLimiter::new(Some(2));
        let sleeper = RecordingSleeper::default();
        let jitter = FixedJitter(0.0);
        let invocations = AtomicUsize::new(0);

        let start = tokio::time::Instant::now();
        let outcome: Result<u32, StoreError> =
            retry_network(&policy, &limiter, &sleeper, &jitter, || {
                let attempt_index = invocations.fetch_add(1, Ordering::SeqCst);
                async move {
                    if attempt_index >= FAILURES {
                        Ok(1)
                    } else {
                        Err(transient(None))
                    }
                }
            })
            .await;
        let elapsed = start.elapsed();

        assert_eq!(outcome.unwrap(), 1);
        assert_eq!(
            invocations.load(Ordering::SeqCst),
            FAILURES + 1,
            "four total attempts (= four request-token acquisitions)"
        );

        let floor = Duration::from_millis(900);
        assert!(
            elapsed >= floor,
            "the per-call request limiter must pace each attempt, took {elapsed:?}"
        );
    });
}