use std::time::Duration;
use thiserror::Error;
use tokio::sync::Semaphore;
#[derive(Debug, Error)]
pub enum AsyncControlError {
#[error("Semaphore closed")]
SemaphoreClosed,
#[error("Timed out after {0}s")]
Timeout(u64),
#[error("{0}")]
Operation(String),
}
/// Runs `op` while holding one permit from `semaphore`, giving the operation
/// at most `timeout_dur` to complete.
///
/// The permit is awaited first, with no time limit. Only the future returned by
/// `op` is subject to `timeout_dur`. The permit is held until this function
/// returns, whether the operation succeeded, failed, or timed out.
///
/// # Errors
///
/// - [`AsyncControlError::SemaphoreClosed`] if no permit can be acquired
///   because the semaphore is closed.
/// - [`AsyncControlError::Timeout`] if `op` does not finish within
///   `timeout_dur`. The payload is `timeout_dur.as_secs()`, so sub-second
///   timeouts are reported as `0`.
/// - [`AsyncControlError::Operation`] with the `Display` text of the error
///   returned by `op`.
pub async fn with_permit_and_timeout<F, Fut, T, E>(
    semaphore: &Semaphore,
    timeout_dur: Duration,
    op: F,
) -> Result<T, AsyncControlError>
where
    F: FnOnce() -> Fut,
    Fut: std::future::Future<Output = Result<T, E>>,
    E: std::fmt::Display,
{
    // Bound to a named `_permit` (not `_`) so the permit is released only
    // when this function returns, not immediately.
    let Ok(_permit) = semaphore.acquire().await else {
        return Err(AsyncControlError::SemaphoreClosed);
    };

    let op_outcome = tokio::time::timeout(timeout_dur, op())
        .await
        .map_err(|_elapsed| AsyncControlError::Timeout(timeout_dur.as_secs()))?;

    op_outcome.map_err(|op_err| AsyncControlError::Operation(op_err.to_string()))
}
/// Runs `op` up to `delays.len()` times, returning the first success.
///
/// Before attempt `i`, this awaits `sleep_fn(delays[i])`. A delay is applied
/// before the first attempt too, so pass `Duration::ZERO` as the first element
/// for an immediate first try. `sleep_fn` is injected so callers and tests
/// control how waiting happens.
///
/// If `delays` is empty, `op` is attempted exactly once with no preceding
/// sleep. This avoids panicking when there is no error to return.
///
/// # Errors
///
/// If every attempt fails, returns the error from the final attempt. Errors
/// from earlier attempts are discarded.
pub async fn retry_fixed_delays<F, Fut, SleepFn, SleepFut, T, E>(
    delays: &[Duration],
    mut sleep_fn: SleepFn,
    mut op: F,
) -> Result<T, E>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = Result<T, E>>,
    SleepFn: FnMut(Duration) -> SleepFut,
    SleepFut: std::future::Future<Output = ()>,
{
    // Split off the final delay so the last attempt's result (success or
    // error) can be returned as-is. That removes any need to stash the
    // "last error" in an Option and unwrap it later.
    let Some((&final_delay, leading)) = delays.split_last() else {
        // Empty schedule: make a single immediate attempt instead of panicking.
        return op().await;
    };

    for &delay in leading {
        sleep_fn(delay).await;
        if let Ok(value) = op().await {
            return Ok(value);
        }
    }

    sleep_fn(final_delay).await;
    op().await
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;

    /// Minimal error type whose `Display` output is the wrapped message.
    #[derive(Debug)]
    struct TestError(String);

    impl std::fmt::Display for TestError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.0)
        }
    }

    #[tokio::test]
    async fn semaphore_limits_concurrency() {
        let semaphore = Semaphore::new(2);
        let in_flight = Arc::new(AtomicUsize::new(0));
        let max_observed = Arc::new(AtomicUsize::new(0));
        let mut handles = vec![];
        for _ in 0..4 {
            let sem = &semaphore;
            let in_flight = Arc::clone(&in_flight);
            let max_observed = Arc::clone(&max_observed);
            handles.push(async move {
                let result: Result<(), AsyncControlError> =
                    with_permit_and_timeout(sem, Duration::from_secs(10), || async {
                        let current = in_flight.fetch_add(1, Ordering::SeqCst) + 1;
                        max_observed.fetch_max(current, Ordering::SeqCst);
                        tokio::time::sleep(Duration::from_millis(50)).await;
                        in_flight.fetch_sub(1, Ordering::SeqCst);
                        Ok::<_, TestError>(())
                    })
                    .await;
                result
            });
        }
        // Every task must succeed. Otherwise a closed semaphore or a timeout
        // could go unnoticed while the concurrency check below still passes.
        let results = futures::future::join_all(handles).await;
        assert!(
            results.iter().all(Result::is_ok),
            "unexpected failures: {results:?}"
        );
        assert_eq!(max_observed.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn timeout_returns_error_when_exceeded() {
        let semaphore = Semaphore::new(1);
        let result: Result<(), AsyncControlError> =
            with_permit_and_timeout(&semaphore, Duration::from_millis(10), || async {
                tokio::time::sleep(Duration::from_millis(100)).await;
                Ok::<_, TestError>(())
            })
            .await;
        assert!(
            matches!(result, Err(AsyncControlError::Timeout(_))),
            "Expected Timeout error, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn timeout_returns_success_when_op_completes_in_time() {
        let semaphore = Semaphore::new(1);
        let result: Result<i32, AsyncControlError> =
            with_permit_and_timeout(&semaphore, Duration::from_secs(10), || async {
                Ok::<_, TestError>(42)
            })
            .await;
        assert!(matches!(result, Ok(42)), "unexpected result: {result:?}");
    }

    #[tokio::test]
    async fn retry_succeeds_on_third_attempt() {
        let attempt_count = Arc::new(AtomicUsize::new(0));
        let delays_observed = Arc::new(std::sync::Mutex::new(Vec::new()));
        let delays = [
            Duration::from_millis(0),
            Duration::from_millis(10),
            Duration::from_millis(20),
        ];
        let result: Result<&str, TestError> = retry_fixed_delays(
            &delays,
            |d| {
                let delays_observed = Arc::clone(&delays_observed);
                async move {
                    delays_observed.lock().unwrap().push(d);
                }
            },
            || {
                let attempt_count = Arc::clone(&attempt_count);
                async move {
                    let attempt = attempt_count.fetch_add(1, Ordering::SeqCst) + 1;
                    if attempt < 3 {
                        Err(TestError(format!("attempt {attempt} failed")))
                    } else {
                        Ok("success")
                    }
                }
            },
        )
        .await;
        assert!(matches!(result, Ok("success")), "unexpected result: {result:?}");
        assert_eq!(attempt_count.load(Ordering::SeqCst), 3);
        // Each configured delay is applied exactly once, in order.
        assert_eq!(*delays_observed.lock().unwrap(), delays);
    }

    #[tokio::test]
    async fn retry_returns_last_error_when_all_fail() {
        let delays = [Duration::from_millis(0), Duration::from_millis(0)];
        let attempt_count = AtomicUsize::new(0);
        // Each attempt fails with a distinct message so the assertion can
        // tell the last error apart from earlier ones.
        let result: Result<(), TestError> = retry_fixed_delays(
            &delays,
            |_| async {},
            || {
                let attempt = attempt_count.fetch_add(1, Ordering::SeqCst) + 1;
                async move { Err(TestError(format!("attempt {attempt} failed"))) }
            },
        )
        .await;
        assert_eq!(result.unwrap_err().0, "attempt 2 failed");
        assert_eq!(attempt_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn retry_succeeds_on_first_attempt() {
        let delays = [Duration::from_millis(0)];
        let result: Result<i32, TestError> =
            retry_fixed_delays(&delays, |_| async {}, || async { Ok(42) }).await;
        assert!(matches!(result, Ok(42)), "unexpected result: {result:?}");
    }
}