1#![forbid(unsafe_code)]
36
37use crate::cancellation::{CancellationSource, CancellationToken};
38use crate::program::{Cmd, TaskSpec};
39use web_time::Duration;
40
/// Upper bound on how long `join_task_thread_bounded` keeps polling a worker
/// thread before it gives up and returns without joining it.
const TASK_THREAD_JOIN_TIMEOUT: Duration = Duration::from_millis(250);
/// Sleep interval between successive `is_finished` checks in
/// `join_task_thread_bounded`.
const TASK_THREAD_JOIN_POLL: Duration = Duration::from_millis(5);
43
/// Blocks until the worker thread behind `handle` has terminated.
///
/// The join result is thrown away on purpose: if the worker panicked, the
/// panic payload is dropped here instead of being propagated to the caller.
fn join_task_thread(handle: std::thread::JoinHandle<()>) {
    drop(handle.join());
}
47
48fn join_task_thread_bounded(handle: std::thread::JoinHandle<()>, task_name: &'static str) {
49 let start = web_time::Instant::now();
50 while !handle.is_finished() {
51 if start.elapsed() >= TASK_THREAD_JOIN_TIMEOUT {
52 tracing::warn!(
53 task = task_name,
54 timeout_ms = TASK_THREAD_JOIN_TIMEOUT.as_millis() as u64,
55 "Timed-out worker thread did not exit within the cancellation join timeout; detaching"
56 );
57 return;
58 }
59 std::thread::sleep(TASK_THREAD_JOIN_POLL);
60 }
61 join_task_thread(handle);
62}
63
/// How the pause between consecutive retry attempts is computed.
///
/// All values are in milliseconds. Serializable via serde when the
/// `state-persistence` feature is enabled.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
    feature = "state-persistence",
    derive(serde::Serialize, serde::Deserialize)
)]
pub enum BackoffStrategy {
    /// The same delay before every retry.
    Fixed {
        /// Delay applied before each retry.
        delay_ms: u64,
    },
    /// `base_ms * 2^attempt`, never exceeding `max_ms`.
    Exponential {
        /// Delay used for attempt 0.
        base_ms: u64,
        /// Upper bound on the computed delay.
        max_ms: u64,
    },
    /// `base_ms * (attempt + 1)`, never exceeding `max_ms`.
    Linear {
        /// Delay used for attempt 0 and the per-attempt increment.
        base_ms: u64,
        /// Upper bound on the computed delay.
        max_ms: u64,
    },
}

/// Retry configuration: how many retries to allow and how long to wait
/// between them.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
    feature = "state-persistence",
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct RetryPolicy {
    /// Number of retries allowed after the initial attempt.
    pub max_retries: u32,
    /// Strategy used to compute the delay before each retry.
    pub backoff: BackoffStrategy,
}

impl RetryPolicy {
    /// Creates a policy with the given retry budget and backoff strategy.
    pub fn new(max_retries: u32, backoff: BackoffStrategy) -> Self {
        Self {
            max_retries,
            backoff,
        }
    }

    /// Creates a policy that never retries (zero retries, zero delay).
    pub fn no_retry() -> Self {
        Self {
            max_retries: 0,
            backoff: BackoffStrategy::Fixed { delay_ms: 0 },
        }
    }

    /// Returns the delay to wait after the zero-based `attempt` has failed.
    ///
    /// The arithmetic saturates instead of overflowing, so arbitrarily large
    /// `attempt` values are safe; the result is clamped to the strategy's
    /// `max_ms` where one exists.
    pub fn delay(&self, attempt: u32) -> Duration {
        match &self.backoff {
            BackoffStrategy::Fixed { delay_ms } => Duration::from_millis(*delay_ms),
            BackoffStrategy::Exponential { base_ms, max_ms } => {
                // `checked_shl` yields `None` once the shift reaches 64 bits;
                // treat that as "as large as possible" and let the cap apply.
                let multiplier = 1u64.checked_shl(attempt).unwrap_or(u64::MAX);
                let delay = base_ms.saturating_mul(multiplier);
                Duration::from_millis(delay.min(*max_ms))
            }
            BackoffStrategy::Linear { base_ms, max_ms } => {
                // `attempt` is a u32, so `+ 1` cannot overflow in u64.
                let delay = base_ms.saturating_mul(u64::from(attempt) + 1);
                Duration::from_millis(delay.min(*max_ms))
            }
        }
    }

    /// Returns the sum of the delays for attempts `0..max_retries`.
    ///
    /// The sum saturates at [`Duration::MAX`] rather than panicking when the
    /// configured delays are too large to be represented.
    pub fn total_max_delay(&self) -> Duration {
        let mut total = Duration::ZERO;
        for attempt in 0..self.max_retries {
            total = total.saturating_add(self.delay(attempt));
            if total == Duration::MAX {
                // Already saturated; further additions cannot change the result.
                break;
            }
        }
        total
    }
}
147
148pub fn task_with_timeout<M, F>(timeout: Duration, f: F, on_timeout: M) -> Cmd<M>
154where
155 M: Send + 'static,
156 F: FnOnce(CancellationToken) -> M + Send + 'static,
157{
158 Cmd::task(move || {
159 let source = CancellationSource::new();
160 let token = source.token();
161 let (tx, rx) = std::sync::mpsc::channel();
162 let handle = std::thread::spawn(move || {
163 let result = f(token);
164 let _ = tx.send(result);
165 });
166 match rx.recv_timeout(timeout) {
167 Ok(msg) => {
168 join_task_thread(handle);
169 msg
170 }
171 Err(_) => {
172 source.cancel();
173 join_task_thread_bounded(handle, "task_with_timeout");
174 on_timeout
175 }
176 }
177 })
178}
179
180pub fn task_with_timeout_named<M, F>(
182 name: impl Into<String>,
183 timeout: Duration,
184 f: F,
185 on_timeout: M,
186) -> Cmd<M>
187where
188 M: Send + 'static,
189 F: FnOnce(CancellationToken) -> M + Send + 'static,
190{
191 Cmd::task_with_spec(TaskSpec::default().with_name(name), move || {
192 let source = CancellationSource::new();
193 let token = source.token();
194 let (tx, rx) = std::sync::mpsc::channel();
195 let handle = std::thread::spawn(move || {
196 let result = f(token);
197 let _ = tx.send(result);
198 });
199 match rx.recv_timeout(timeout) {
200 Ok(msg) => {
201 join_task_thread(handle);
202 msg
203 }
204 Err(_) => {
205 source.cancel();
206 join_task_thread_bounded(handle, "task_with_timeout_named");
207 on_timeout
208 }
209 }
210 })
211}
212
213pub fn task_with_retry<M, F>(policy: RetryPolicy, f: F, on_exhaust: fn(String) -> M) -> Cmd<M>
220where
221 M: Send + 'static,
222 F: Fn() -> Result<M, String> + Send + 'static,
223{
224 Cmd::task(move || {
225 let mut last_err = String::new();
226 for attempt in 0..=policy.max_retries {
227 match f() {
228 Ok(msg) => return msg,
229 Err(e) => {
230 last_err = e;
231 if attempt < policy.max_retries {
232 std::thread::sleep(policy.delay(attempt));
233 }
234 }
235 }
236 }
237 on_exhaust(last_err)
238 })
239}
240
241pub fn task_with_retry_and_timeout<M, F>(
247 policy: RetryPolicy,
248 per_attempt_timeout: Duration,
249 f: F,
250 on_exhaust: fn(String) -> M,
251) -> Cmd<M>
252where
253 M: Send + 'static,
254 F: Fn(CancellationToken) -> Result<M, String> + Send + Sync + 'static,
255{
256 Cmd::task(move || {
257 let f = std::sync::Arc::new(f);
258 let mut last_err = String::new();
259 for attempt in 0..=policy.max_retries {
260 let source = CancellationSource::new();
261 let token = source.token();
262 let (tx, rx) = std::sync::mpsc::channel();
263 let f_clone = std::sync::Arc::clone(&f);
264 let handle = std::thread::spawn(move || {
265 let result = f_clone(token);
266 let _ = tx.send(result);
267 });
268 match rx.recv_timeout(per_attempt_timeout) {
269 Ok(Ok(msg)) => {
270 join_task_thread(handle);
271 return msg;
272 }
273 Ok(Err(e)) => {
274 join_task_thread(handle);
275 last_err = e;
276 }
277 Err(_) => {
278 source.cancel();
279 join_task_thread_bounded(handle, "task_with_retry_and_timeout");
280 last_err = "timeout".into();
281 }
282 }
283 if attempt < policy.max_retries {
284 std::thread::sleep(policy.delay(attempt));
285 }
286 }
287 on_exhaust(last_err)
288 })
289}
290
#[cfg(test)]
mod tests {
    use super::*;

    // ---- RetryPolicy::delay -------------------------------------------------

    #[test]
    fn fixed_backoff_constant_delay() {
        let policy = RetryPolicy::new(3, BackoffStrategy::Fixed { delay_ms: 100 });
        assert_eq!(policy.delay(0), Duration::from_millis(100));
        assert_eq!(policy.delay(1), Duration::from_millis(100));
        assert_eq!(policy.delay(2), Duration::from_millis(100));
    }

    #[test]
    fn exponential_backoff_doubles() {
        let policy = RetryPolicy::new(
            5,
            BackoffStrategy::Exponential {
                base_ms: 100,
                max_ms: 5000,
            },
        );
        assert_eq!(policy.delay(0), Duration::from_millis(100));
        assert_eq!(policy.delay(1), Duration::from_millis(200));
        assert_eq!(policy.delay(2), Duration::from_millis(400));
        assert_eq!(policy.delay(3), Duration::from_millis(800));
    }

    #[test]
    fn exponential_backoff_caps_at_max() {
        let policy = RetryPolicy::new(
            5,
            BackoffStrategy::Exponential {
                base_ms: 1000,
                max_ms: 3000,
            },
        );
        assert_eq!(policy.delay(0), Duration::from_millis(1000));
        assert_eq!(policy.delay(1), Duration::from_millis(2000));
        // 4000 and 8000 would exceed the cap.
        assert_eq!(policy.delay(2), Duration::from_millis(3000));
        assert_eq!(policy.delay(3), Duration::from_millis(3000));
    }

    #[test]
    fn linear_backoff_increments() {
        let policy = RetryPolicy::new(
            4,
            BackoffStrategy::Linear {
                base_ms: 100,
                max_ms: 500,
            },
        );
        assert_eq!(policy.delay(0), Duration::from_millis(100));
        assert_eq!(policy.delay(1), Duration::from_millis(200));
        assert_eq!(policy.delay(2), Duration::from_millis(300));
        assert_eq!(policy.delay(3), Duration::from_millis(400));
        assert_eq!(policy.delay(4), Duration::from_millis(500));
    }

    #[test]
    fn linear_backoff_caps_at_max() {
        let policy = RetryPolicy::new(
            4,
            BackoffStrategy::Linear {
                base_ms: 200,
                max_ms: 500,
            },
        );
        // 200 * 3 = 600 exceeds the cap.
        assert_eq!(policy.delay(2), Duration::from_millis(500));
    }

    #[test]
    fn no_retry_policy() {
        let policy = RetryPolicy::no_retry();
        assert_eq!(policy.max_retries, 0);
        assert_eq!(policy.backoff, BackoffStrategy::Fixed { delay_ms: 0 });
        assert_eq!(policy.delay(0), Duration::ZERO);
    }

    // ---- RetryPolicy::total_max_delay ---------------------------------------

    #[test]
    fn total_max_delay_fixed() {
        let policy = RetryPolicy::new(3, BackoffStrategy::Fixed { delay_ms: 100 });
        assert_eq!(policy.total_max_delay(), Duration::from_millis(300));
    }

    #[test]
    fn total_max_delay_exponential() {
        let policy = RetryPolicy::new(
            3,
            BackoffStrategy::Exponential {
                base_ms: 100,
                max_ms: 10000,
            },
        );
        // 100 + 200 + 400
        assert_eq!(policy.total_max_delay(), Duration::from_millis(700));
    }

    #[test]
    fn total_max_delay_zero_retries() {
        let policy = RetryPolicy::no_retry();
        assert_eq!(policy.total_max_delay(), Duration::ZERO);
    }

    // ---- Overflow behaviour -------------------------------------------------

    #[test]
    fn exponential_backoff_overflow_saturates() {
        let policy = RetryPolicy::new(
            1,
            BackoffStrategy::Exponential {
                base_ms: u64::MAX / 2,
                max_ms: u64::MAX,
            },
        );
        // (u64::MAX / 2) * 2^30 does not fit in u64 and must clamp.
        assert_eq!(policy.delay(30), Duration::from_millis(u64::MAX));
        // A shift of 64 or more bits takes the `checked_shl` fallback path.
        assert_eq!(policy.delay(64), Duration::from_millis(u64::MAX));
    }

    #[test]
    fn linear_backoff_overflow_saturates() {
        let policy = RetryPolicy::new(
            1,
            BackoffStrategy::Linear {
                base_ms: u64::MAX / 2,
                max_ms: u64::MAX,
            },
        );
        // (u64::MAX / 2) * 31 does not fit in u64 and must clamp.
        assert_eq!(policy.delay(30), Duration::from_millis(u64::MAX));
    }

    #[test]
    fn retry_policy_clone_eq() {
        let policy = RetryPolicy::new(
            3,
            BackoffStrategy::Exponential {
                base_ms: 100,
                max_ms: 5000,
            },
        );
        let cloned = policy.clone();
        assert_eq!(policy, cloned);
    }

    // ---- Command constructors -----------------------------------------------

    #[test]
    fn task_with_retry_succeeds_first_try() {
        #[derive(Debug, PartialEq)]
        enum Msg {
            Ok(i32),
            Err(String),
        }

        let policy = RetryPolicy::new(3, BackoffStrategy::Fixed { delay_ms: 1 });
        let cmd = task_with_retry(policy, || Ok(Msg::Ok(42)), Msg::Err);

        assert_eq!(cmd.type_name(), "Task");

        // Run the task: the first attempt succeeds, so no retry sleep happens.
        let result = match cmd {
            Cmd::Task(_, task) => task(),
            other => panic!("expected Task, got {other:?}"),
        };
        assert_eq!(result, Msg::Ok(42));
    }

    #[test]
    fn task_with_timeout_produces_task() {
        #[derive(Debug)]
        #[allow(dead_code)]
        enum Msg {
            Result(i32),
            Timeout,
        }

        let cmd = task_with_timeout(
            Duration::from_secs(1),
            |_token| Msg::Result(42),
            Msg::Timeout,
        );
        assert_eq!(cmd.type_name(), "Task");
    }

    #[test]
    fn task_with_timeout_requests_cancellation_on_timeout() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicBool, Ordering};

        #[derive(Debug, PartialEq)]
        enum Msg {
            Finished,
            Timeout,
        }

        let cancelled = Arc::new(AtomicBool::new(false));
        let worker_exited = Arc::new(AtomicBool::new(false));
        let cancelled_flag = Arc::clone(&cancelled);
        let exited_flag = Arc::clone(&worker_exited);

        let cmd = task_with_timeout(
            Duration::from_millis(10),
            move |token| {
                // Blocks well past the 10 ms deadline unless cancellation arrives.
                cancelled_flag.store(token.wait_timeout(Duration::from_secs(1)), Ordering::SeqCst);
                exited_flag.store(true, Ordering::SeqCst);
                Msg::Finished
            },
            Msg::Timeout,
        );

        let result = match cmd {
            Cmd::Task(_, task) => task(),
            other => panic!("expected Task, got {other:?}"),
        };

        assert_eq!(result, Msg::Timeout);
        // Extra slack for the worker to record its flags.
        std::thread::sleep(Duration::from_millis(50));
        assert!(cancelled.load(Ordering::SeqCst));
        assert!(worker_exited.load(Ordering::SeqCst));
    }

    #[test]
    fn task_with_retry_and_timeout_cancels_each_timed_out_attempt() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicUsize, Ordering};

        #[derive(Debug, PartialEq)]
        enum Msg {
            Exhausted(String),
        }

        fn on_exhaust(err: String) -> Msg {
            Msg::Exhausted(err)
        }

        let attempts = Arc::new(AtomicUsize::new(0));
        let cancelled = Arc::new(AtomicUsize::new(0));
        let attempts_flag = Arc::clone(&attempts);
        let cancelled_flag = Arc::clone(&cancelled);
        // One retry => two attempts in total.
        let policy = RetryPolicy::new(1, BackoffStrategy::Fixed { delay_ms: 0 });

        let cmd = task_with_retry_and_timeout(
            policy,
            Duration::from_millis(10),
            move |token| {
                attempts_flag.fetch_add(1, Ordering::SeqCst);
                if token.wait_timeout(Duration::from_secs(1)) {
                    cancelled_flag.fetch_add(1, Ordering::SeqCst);
                }
                Err("cancelled".to_owned())
            },
            on_exhaust,
        );

        let result = match cmd {
            Cmd::Task(_, task) => task(),
            other => panic!("expected Task, got {other:?}"),
        };

        // The worker's own error arrives after the deadline, so "timeout" wins.
        assert_eq!(result, Msg::Exhausted("timeout".to_owned()));
        std::thread::sleep(Duration::from_millis(50));
        assert_eq!(attempts.load(Ordering::SeqCst), 2);
        assert_eq!(cancelled.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn backoff_strategy_variants_debug() {
        let fixed = BackoffStrategy::Fixed { delay_ms: 100 };
        let exp = BackoffStrategy::Exponential {
            base_ms: 100,
            max_ms: 5000,
        };
        let linear = BackoffStrategy::Linear {
            base_ms: 100,
            max_ms: 500,
        };
        // The derived Debug output starts with the variant name.
        assert!(format!("{fixed:?}").contains("Fixed"));
        assert!(format!("{exp:?}").contains("Exponential"));
        assert!(format!("{linear:?}").contains("Linear"));
    }
}