#![forbid(unsafe_code)]
use crate::cancellation::{CancellationSource, CancellationToken};
use crate::program::{Cmd, TaskSpec};
use web_time::Duration;
const TASK_THREAD_JOIN_TIMEOUT: Duration = Duration::from_millis(250);
const TASK_THREAD_JOIN_POLL: Duration = Duration::from_millis(5);
fn join_task_thread(handle: std::thread::JoinHandle<()>) {
let _ = handle.join();
}
fn join_task_thread_bounded(handle: std::thread::JoinHandle<()>, task_name: &'static str) {
let start = web_time::Instant::now();
while !handle.is_finished() {
if start.elapsed() >= TASK_THREAD_JOIN_TIMEOUT {
tracing::warn!(
task = task_name,
timeout_ms = TASK_THREAD_JOIN_TIMEOUT.as_millis() as u64,
"Timed-out worker thread did not exit within the cancellation join timeout; detaching"
);
return;
}
std::thread::sleep(TASK_THREAD_JOIN_POLL);
}
join_task_thread(handle);
}
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
feature = "state-persistence",
derive(serde::Serialize, serde::Deserialize)
)]
pub enum BackoffStrategy {
Fixed {
delay_ms: u64,
},
Exponential {
base_ms: u64,
max_ms: u64,
},
Linear {
base_ms: u64,
max_ms: u64,
},
}
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
feature = "state-persistence",
derive(serde::Serialize, serde::Deserialize)
)]
pub struct RetryPolicy {
pub max_retries: u32,
pub backoff: BackoffStrategy,
}
impl RetryPolicy {
pub fn new(max_retries: u32, backoff: BackoffStrategy) -> Self {
Self {
max_retries,
backoff,
}
}
pub fn no_retry() -> Self {
Self {
max_retries: 0,
backoff: BackoffStrategy::Fixed { delay_ms: 0 },
}
}
pub fn delay(&self, attempt: u32) -> Duration {
match &self.backoff {
BackoffStrategy::Fixed { delay_ms } => Duration::from_millis(*delay_ms),
BackoffStrategy::Exponential { base_ms, max_ms } => {
let multiplier = 1u64.checked_shl(attempt).unwrap_or(u64::MAX);
let delay = base_ms.saturating_mul(multiplier);
Duration::from_millis(delay.min(*max_ms))
}
BackoffStrategy::Linear { base_ms, max_ms } => {
let delay = base_ms.saturating_mul(u64::from(attempt) + 1);
Duration::from_millis(delay.min(*max_ms))
}
}
}
pub fn total_max_delay(&self) -> Duration {
let mut total = Duration::ZERO;
for i in 0..self.max_retries {
total += self.delay(i);
}
total
}
}
pub fn task_with_timeout<M, F>(timeout: Duration, f: F, on_timeout: M) -> Cmd<M>
where
M: Send + 'static,
F: FnOnce(CancellationToken) -> M + Send + 'static,
{
Cmd::task(move || {
let source = CancellationSource::new();
let token = source.token();
let (tx, rx) = std::sync::mpsc::channel();
let handle = std::thread::spawn(move || {
let result = f(token);
let _ = tx.send(result);
});
match rx.recv_timeout(timeout) {
Ok(msg) => {
join_task_thread(handle);
msg
}
Err(_) => {
source.cancel();
join_task_thread_bounded(handle, "task_with_timeout");
on_timeout
}
}
})
}
pub fn task_with_timeout_named<M, F>(
name: impl Into<String>,
timeout: Duration,
f: F,
on_timeout: M,
) -> Cmd<M>
where
M: Send + 'static,
F: FnOnce(CancellationToken) -> M + Send + 'static,
{
Cmd::task_with_spec(TaskSpec::default().with_name(name), move || {
let source = CancellationSource::new();
let token = source.token();
let (tx, rx) = std::sync::mpsc::channel();
let handle = std::thread::spawn(move || {
let result = f(token);
let _ = tx.send(result);
});
match rx.recv_timeout(timeout) {
Ok(msg) => {
join_task_thread(handle);
msg
}
Err(_) => {
source.cancel();
join_task_thread_bounded(handle, "task_with_timeout_named");
on_timeout
}
}
})
}
pub fn task_with_retry<M, F>(policy: RetryPolicy, f: F, on_exhaust: fn(String) -> M) -> Cmd<M>
where
M: Send + 'static,
F: Fn() -> Result<M, String> + Send + 'static,
{
Cmd::task(move || {
let mut last_err = String::new();
for attempt in 0..=policy.max_retries {
match f() {
Ok(msg) => return msg,
Err(e) => {
last_err = e;
if attempt < policy.max_retries {
std::thread::sleep(policy.delay(attempt));
}
}
}
}
on_exhaust(last_err)
})
}
pub fn task_with_retry_and_timeout<M, F>(
policy: RetryPolicy,
per_attempt_timeout: Duration,
f: F,
on_exhaust: fn(String) -> M,
) -> Cmd<M>
where
M: Send + 'static,
F: Fn(CancellationToken) -> Result<M, String> + Send + Sync + 'static,
{
Cmd::task(move || {
let f = std::sync::Arc::new(f);
let mut last_err = String::new();
for attempt in 0..=policy.max_retries {
let source = CancellationSource::new();
let token = source.token();
let (tx, rx) = std::sync::mpsc::channel();
let f_clone = std::sync::Arc::clone(&f);
let handle = std::thread::spawn(move || {
let result = f_clone(token);
let _ = tx.send(result);
});
match rx.recv_timeout(per_attempt_timeout) {
Ok(Ok(msg)) => {
join_task_thread(handle);
return msg;
}
Ok(Err(e)) => {
join_task_thread(handle);
last_err = e;
}
Err(_) => {
source.cancel();
join_task_thread_bounded(handle, "task_with_retry_and_timeout");
last_err = "timeout".into();
}
}
if attempt < policy.max_retries {
std::thread::sleep(policy.delay(attempt));
}
}
on_exhaust(last_err)
})
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn fixed_backoff_constant_delay() {
let policy = RetryPolicy::new(3, BackoffStrategy::Fixed { delay_ms: 100 });
assert_eq!(policy.delay(0), Duration::from_millis(100));
assert_eq!(policy.delay(1), Duration::from_millis(100));
assert_eq!(policy.delay(2), Duration::from_millis(100));
}
#[test]
fn exponential_backoff_doubles() {
let policy = RetryPolicy::new(
5,
BackoffStrategy::Exponential {
base_ms: 100,
max_ms: 5000,
},
);
assert_eq!(policy.delay(0), Duration::from_millis(100));
assert_eq!(policy.delay(1), Duration::from_millis(200));
assert_eq!(policy.delay(2), Duration::from_millis(400));
assert_eq!(policy.delay(3), Duration::from_millis(800));
}
#[test]
fn exponential_backoff_caps_at_max() {
let policy = RetryPolicy::new(
5,
BackoffStrategy::Exponential {
base_ms: 1000,
max_ms: 3000,
},
);
assert_eq!(policy.delay(0), Duration::from_millis(1000));
assert_eq!(policy.delay(1), Duration::from_millis(2000));
assert_eq!(policy.delay(2), Duration::from_millis(3000)); assert_eq!(policy.delay(3), Duration::from_millis(3000)); }
#[test]
fn linear_backoff_increments() {
let policy = RetryPolicy::new(
4,
BackoffStrategy::Linear {
base_ms: 100,
max_ms: 500,
},
);
assert_eq!(policy.delay(0), Duration::from_millis(100));
assert_eq!(policy.delay(1), Duration::from_millis(200));
assert_eq!(policy.delay(2), Duration::from_millis(300));
assert_eq!(policy.delay(3), Duration::from_millis(400));
assert_eq!(policy.delay(4), Duration::from_millis(500)); }
#[test]
fn linear_backoff_caps_at_max() {
let policy = RetryPolicy::new(
4,
BackoffStrategy::Linear {
base_ms: 200,
max_ms: 500,
},
);
assert_eq!(policy.delay(2), Duration::from_millis(500)); }
#[test]
fn no_retry_policy() {
let policy = RetryPolicy::no_retry();
assert_eq!(policy.max_retries, 0);
}
#[test]
fn total_max_delay_fixed() {
let policy = RetryPolicy::new(3, BackoffStrategy::Fixed { delay_ms: 100 });
assert_eq!(policy.total_max_delay(), Duration::from_millis(300));
}
#[test]
fn total_max_delay_exponential() {
let policy = RetryPolicy::new(
3,
BackoffStrategy::Exponential {
base_ms: 100,
max_ms: 10000,
},
);
assert_eq!(policy.total_max_delay(), Duration::from_millis(700));
}
#[test]
fn total_max_delay_zero_retries() {
let policy = RetryPolicy::no_retry();
assert_eq!(policy.total_max_delay(), Duration::ZERO);
}
#[test]
fn exponential_backoff_overflow_saturates() {
let policy = RetryPolicy::new(
1,
BackoffStrategy::Exponential {
base_ms: u64::MAX / 2,
max_ms: u64::MAX,
},
);
let _ = policy.delay(30);
}
#[test]
fn linear_backoff_overflow_saturates() {
let policy = RetryPolicy::new(
1,
BackoffStrategy::Linear {
base_ms: u64::MAX / 2,
max_ms: u64::MAX,
},
);
let _ = policy.delay(30);
}
#[test]
fn retry_policy_clone_eq() {
let policy = RetryPolicy::new(
3,
BackoffStrategy::Exponential {
base_ms: 100,
max_ms: 5000,
},
);
let cloned = policy.clone();
assert_eq!(policy, cloned);
}
#[test]
fn task_with_retry_succeeds_first_try() {
#[derive(Debug, PartialEq)]
enum Msg {
Ok(i32),
Err(String),
}
let policy = RetryPolicy::new(3, BackoffStrategy::Fixed { delay_ms: 1 });
let cmd = task_with_retry(policy, || Ok(Msg::Ok(42)), Msg::Err);
assert_eq!(cmd.type_name(), "Task");
}
#[test]
fn task_with_timeout_produces_task() {
#[derive(Debug)]
#[allow(dead_code)]
enum Msg {
Result(i32),
Timeout,
}
let cmd = task_with_timeout(
Duration::from_secs(1),
|_token| Msg::Result(42),
Msg::Timeout,
);
assert_eq!(cmd.type_name(), "Task");
}
#[test]
fn task_with_timeout_requests_cancellation_on_timeout() {
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
#[derive(Debug, PartialEq)]
enum Msg {
Finished,
Timeout,
}
let cancelled = Arc::new(AtomicBool::new(false));
let worker_exited = Arc::new(AtomicBool::new(false));
let cancelled_flag = Arc::clone(&cancelled);
let exited_flag = Arc::clone(&worker_exited);
let cmd = task_with_timeout(
Duration::from_millis(10),
move |token| {
cancelled_flag.store(token.wait_timeout(Duration::from_secs(1)), Ordering::SeqCst);
exited_flag.store(true, Ordering::SeqCst);
Msg::Finished
},
Msg::Timeout,
);
let result = match cmd {
Cmd::Task(_, task) => task(),
other => panic!("expected Task, got {other:?}"),
};
assert_eq!(result, Msg::Timeout);
std::thread::sleep(Duration::from_millis(50));
assert!(cancelled.load(Ordering::SeqCst));
assert!(worker_exited.load(Ordering::SeqCst));
}
#[test]
fn task_with_retry_and_timeout_cancels_each_timed_out_attempt() {
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
#[derive(Debug, PartialEq)]
enum Msg {
Exhausted(String),
}
fn on_exhaust(err: String) -> Msg {
Msg::Exhausted(err)
}
let attempts = Arc::new(AtomicUsize::new(0));
let cancelled = Arc::new(AtomicUsize::new(0));
let attempts_flag = Arc::clone(&attempts);
let cancelled_flag = Arc::clone(&cancelled);
let policy = RetryPolicy::new(1, BackoffStrategy::Fixed { delay_ms: 0 });
let cmd = task_with_retry_and_timeout(
policy,
Duration::from_millis(10),
move |token| {
attempts_flag.fetch_add(1, Ordering::SeqCst);
if token.wait_timeout(Duration::from_secs(1)) {
cancelled_flag.fetch_add(1, Ordering::SeqCst);
}
Err("cancelled".to_owned())
},
on_exhaust,
);
let result = match cmd {
Cmd::Task(_, task) => task(),
other => panic!("expected Task, got {other:?}"),
};
assert_eq!(result, Msg::Exhausted("timeout".to_owned()));
std::thread::sleep(Duration::from_millis(50));
assert_eq!(attempts.load(Ordering::SeqCst), 2);
assert_eq!(cancelled.load(Ordering::SeqCst), 2);
}
#[test]
fn backoff_strategy_variants_debug() {
let fixed = BackoffStrategy::Fixed { delay_ms: 100 };
let exp = BackoffStrategy::Exponential {
base_ms: 100,
max_ms: 5000,
};
let linear = BackoffStrategy::Linear {
base_ms: 100,
max_ms: 500,
};
let _ = format!("{fixed:?}");
let _ = format!("{exp:?}");
let _ = format!("{linear:?}");
}
}