1#![forbid(unsafe_code)]
36
37use crate::cancellation::{CancellationSource, CancellationToken};
38use crate::program::{Cmd, TaskSpec};
39use web_time::Duration;
40
// Maximum time to block waiting for a cancelled worker thread to exit before
// detaching it (see `join_task_thread_bounded`).
const TASK_THREAD_JOIN_TIMEOUT: Duration = Duration::from_millis(250);
// Polling interval used while waiting for a cancelled worker thread to finish.
const TASK_THREAD_JOIN_POLL: Duration = Duration::from_millis(5);
43
/// Converts a millisecond count into a `Duration`, clamping to
/// `Duration::MAX` instead of overflowing when the value is unrepresentable.
fn duration_from_millis_saturating(millis: u128) -> Duration {
    if millis >= Duration::MAX.as_millis() {
        return Duration::MAX;
    }
    // Split into whole seconds and the sub-second remainder.
    let whole_seconds = match u64::try_from(millis / 1_000) {
        Ok(seconds) => seconds,
        Err(_) => return Duration::MAX,
    };
    // remainder < 1_000, so the nanosecond value always fits in u32; the
    // fallible conversion is kept purely as a defensive clamp.
    let nanos = match u32::try_from((millis % 1_000).saturating_mul(1_000_000)) {
        Ok(nanos) => nanos,
        Err(_) => return Duration::MAX,
    };
    Duration::new(whole_seconds, nanos)
}
59
/// Adds `millis` into `total`, saturating at the number of milliseconds
/// representable by `Duration::MAX`.
fn add_millis_saturating(total: &mut u128, millis: u128) {
    let sum = total.saturating_add(millis);
    *total = sum.min(Duration::MAX.as_millis());
}
63
/// Joins a worker thread, discarding any panic payload — there is no caller
/// to surface a worker panic to at this point.
fn join_task_thread(handle: std::thread::JoinHandle<()>) {
    drop(handle.join());
}
67
68fn join_task_thread_bounded(handle: std::thread::JoinHandle<()>, task_name: &'static str) {
69 let start = web_time::Instant::now();
70 while !handle.is_finished() {
71 if start.elapsed() >= TASK_THREAD_JOIN_TIMEOUT {
72 tracing::warn!(
73 task = task_name,
74 timeout_ms = TASK_THREAD_JOIN_TIMEOUT.as_millis() as u64,
75 "Timed-out worker thread did not exit within the cancellation join timeout; detaching"
76 );
77 return;
78 }
79 std::thread::sleep(TASK_THREAD_JOIN_POLL);
80 }
81 join_task_thread(handle);
82}
83
/// Strategy for computing the delay between retry attempts.
///
/// All delays are expressed in milliseconds; see [`RetryPolicy::delay`] for
/// how a strategy maps a zero-based attempt index to a concrete `Duration`.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
    feature = "state-persistence",
    derive(serde::Serialize, serde::Deserialize)
)]
pub enum BackoffStrategy {
    /// The same delay before every retry.
    Fixed {
        /// Delay applied before each retry, in milliseconds.
        delay_ms: u64,
    },
    /// Delay doubles each attempt (`base_ms * 2^attempt`), capped at `max_ms`.
    Exponential {
        /// Delay before the first retry, in milliseconds.
        base_ms: u64,
        /// Upper bound on any single delay, in milliseconds.
        max_ms: u64,
    },
    /// Delay grows linearly (`base_ms * (attempt + 1)`), capped at `max_ms`.
    Linear {
        /// Delay before the first retry, in milliseconds.
        base_ms: u64,
        /// Upper bound on any single delay, in milliseconds.
        max_ms: u64,
    },
}
111
/// Describes how many times a failing task is retried and how the delay
/// between attempts is computed.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
    feature = "state-persistence",
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct RetryPolicy {
    /// Number of retries after the initial attempt (0 means run once, never retry).
    pub max_retries: u32,
    /// Strategy used to compute the delay before each retry.
    pub backoff: BackoffStrategy,
}
124
impl RetryPolicy {
    /// Creates a policy that retries up to `max_retries` times, spacing
    /// attempts according to `backoff`.
    pub fn new(max_retries: u32, backoff: BackoffStrategy) -> Self {
        Self {
            max_retries,
            backoff,
        }
    }

    /// Creates a policy that never retries: a single attempt with no delay.
    pub fn no_retry() -> Self {
        Self {
            max_retries: 0,
            backoff: BackoffStrategy::Fixed { delay_ms: 0 },
        }
    }

    /// Returns the delay to sleep after the given zero-based failed `attempt`.
    ///
    /// All arithmetic saturates, so extreme `attempt`/`base_ms` values clamp
    /// at `u64::MAX` milliseconds (and then at `max_ms`) instead of wrapping.
    pub fn delay(&self, attempt: u32) -> Duration {
        match &self.backoff {
            BackoffStrategy::Fixed { delay_ms } => Duration::from_millis(*delay_ms),
            BackoffStrategy::Exponential { base_ms, max_ms } => {
                // 2^attempt; `checked_shl` returns None for attempt >= 64, in
                // which case the multiplier clamps to u64::MAX and the delay
                // ends up pinned at `max_ms`.
                let multiplier = 1u64.checked_shl(attempt).unwrap_or(u64::MAX);
                let delay = base_ms.saturating_mul(multiplier);
                Duration::from_millis(delay.min(*max_ms))
            }
            BackoffStrategy::Linear { base_ms, max_ms } => {
                let delay = base_ms.saturating_mul(u64::from(attempt) + 1);
                Duration::from_millis(delay.min(*max_ms))
            }
        }
    }

    /// Total time spent sleeping if every attempt fails: the sum of
    /// `delay(attempt)` for `attempt` in `0..max_retries`, computed in closed
    /// form (no O(max_retries) loop) and saturating at `Duration::MAX`.
    pub fn total_max_delay(&self) -> Duration {
        let retry_count = u128::from(self.max_retries);
        let max_duration_millis = Duration::MAX.as_millis();
        match &self.backoff {
            BackoffStrategy::Fixed { delay_ms } => {
                duration_from_millis_saturating(u128::from(*delay_ms).saturating_mul(retry_count))
            }
            BackoffStrategy::Linear { base_ms, max_ms } => {
                if self.max_retries == 0 || *base_ms == 0 || *max_ms == 0 {
                    return Duration::ZERO;
                }

                // Attempts whose delay `base_ms * (attempt + 1)` stays at or
                // below the cap; their total is an arithmetic series
                // base * (1 + 2 + ... + k) = base * k(k+1)/2.
                let uncapped_terms = retry_count.min(u128::from(*max_ms / *base_ms));
                let arithmetic_sum =
                    uncapped_terms.saturating_mul(uncapped_terms.saturating_add(1)) / 2;
                let mut total_millis = u128::from(*base_ms)
                    .saturating_mul(arithmetic_sum)
                    .min(max_duration_millis);

                // Every remaining attempt sleeps exactly `max_ms`.
                let capped_terms = retry_count.saturating_sub(uncapped_terms);
                add_millis_saturating(
                    &mut total_millis,
                    u128::from(*max_ms).saturating_mul(capped_terms),
                );
                duration_from_millis_saturating(total_millis)
            }
            BackoffStrategy::Exponential { base_ms, max_ms } => {
                if self.max_retries == 0 || *base_ms == 0 || *max_ms == 0 {
                    return Duration::ZERO;
                }

                // Accumulate the doubling delays one by one; once the cap is
                // hit every later attempt contributes `max_ms`, so those are
                // added in a single multiplication and the loop exits. The
                // loop therefore runs at most ~64 iterations.
                let mut total_millis = 0_u128;
                let mut attempt = 0_u32;
                while attempt < self.max_retries {
                    let delay_millis = match 1_u64.checked_shl(attempt) {
                        Some(multiplier) => base_ms.saturating_mul(multiplier).min(*max_ms),
                        None => *max_ms,
                    };
                    add_millis_saturating(&mut total_millis, u128::from(delay_millis));
                    attempt = attempt.saturating_add(1);

                    if delay_millis == *max_ms {
                        let remaining = u128::from(self.max_retries.saturating_sub(attempt));
                        add_millis_saturating(
                            &mut total_millis,
                            u128::from(*max_ms).saturating_mul(remaining),
                        );
                        break;
                    }
                }
                duration_from_millis_saturating(total_millis)
            }
        }
    }
}
214
215pub fn task_with_timeout<M, F>(timeout: Duration, f: F, on_timeout: M) -> Cmd<M>
221where
222 M: Send + 'static,
223 F: FnOnce(CancellationToken) -> M + Send + 'static,
224{
225 Cmd::task(move || {
226 let source = CancellationSource::new();
227 let token = source.token();
228 let (tx, rx) = std::sync::mpsc::channel();
229 let handle = std::thread::spawn(move || {
230 let result = f(token);
231 let _ = tx.send(result);
232 });
233 match rx.recv_timeout(timeout) {
234 Ok(msg) => {
235 join_task_thread(handle);
236 msg
237 }
238 Err(_) => {
239 source.cancel();
240 join_task_thread_bounded(handle, "task_with_timeout");
241 on_timeout
242 }
243 }
244 })
245}
246
247pub fn task_with_timeout_named<M, F>(
249 name: impl Into<String>,
250 timeout: Duration,
251 f: F,
252 on_timeout: M,
253) -> Cmd<M>
254where
255 M: Send + 'static,
256 F: FnOnce(CancellationToken) -> M + Send + 'static,
257{
258 Cmd::task_with_spec(TaskSpec::default().with_name(name), move || {
259 let source = CancellationSource::new();
260 let token = source.token();
261 let (tx, rx) = std::sync::mpsc::channel();
262 let handle = std::thread::spawn(move || {
263 let result = f(token);
264 let _ = tx.send(result);
265 });
266 match rx.recv_timeout(timeout) {
267 Ok(msg) => {
268 join_task_thread(handle);
269 msg
270 }
271 Err(_) => {
272 source.cancel();
273 join_task_thread_bounded(handle, "task_with_timeout_named");
274 on_timeout
275 }
276 }
277 })
278}
279
280pub fn task_with_retry<M, F>(policy: RetryPolicy, f: F, on_exhaust: fn(String) -> M) -> Cmd<M>
287where
288 M: Send + 'static,
289 F: Fn() -> Result<M, String> + Send + 'static,
290{
291 Cmd::task(move || {
292 let mut last_err = String::new();
293 for attempt in 0..=policy.max_retries {
294 match f() {
295 Ok(msg) => return msg,
296 Err(e) => {
297 last_err = e;
298 if attempt < policy.max_retries {
299 std::thread::sleep(policy.delay(attempt));
300 }
301 }
302 }
303 }
304 on_exhaust(last_err)
305 })
306}
307
/// Runs `f` with both a retry policy and a per-attempt timeout.
///
/// Each attempt executes on a fresh worker thread. If an attempt does not
/// complete within `per_attempt_timeout`, its cancellation token is
/// triggered, the thread is joined for a bounded period (then detached), and
/// the error is recorded as `"timeout"`. When all attempts are exhausted,
/// `on_exhaust` is invoked with the last error string.
pub fn task_with_retry_and_timeout<M, F>(
    policy: RetryPolicy,
    per_attempt_timeout: Duration,
    f: F,
    on_exhaust: fn(String) -> M,
) -> Cmd<M>
where
    M: Send + 'static,
    F: Fn(CancellationToken) -> Result<M, String> + Send + Sync + 'static,
{
    Cmd::task(move || {
        // `f` is shared across attempt threads; Arc avoids requiring `Clone`.
        let f = std::sync::Arc::new(f);
        let mut last_err = String::new();
        for attempt in 0..=policy.max_retries {
            // Fresh cancellation source and channel per attempt so a timed-out
            // worker from an earlier attempt cannot affect later ones.
            let source = CancellationSource::new();
            let token = source.token();
            let (tx, rx) = std::sync::mpsc::channel();
            let f_clone = std::sync::Arc::clone(&f);
            let handle = std::thread::spawn(move || {
                let result = f_clone(token);
                // Send failure means the receiver timed out and was dropped.
                let _ = tx.send(result);
            });
            match rx.recv_timeout(per_attempt_timeout) {
                Ok(Ok(msg)) => {
                    join_task_thread(handle);
                    return msg;
                }
                Ok(Err(e)) => {
                    join_task_thread(handle);
                    last_err = e;
                }
                Err(_) => {
                    // Attempt timed out: cancel the worker and bound the join.
                    source.cancel();
                    join_task_thread_bounded(handle, "task_with_retry_and_timeout");
                    last_err = "timeout".into();
                }
            }
            if attempt < policy.max_retries {
                std::thread::sleep(policy.delay(attempt));
            }
        }
        on_exhaust(last_err)
    })
}
357
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn fixed_backoff_constant_delay() {
        let policy = RetryPolicy::new(3, BackoffStrategy::Fixed { delay_ms: 100 });
        assert_eq!(policy.delay(0), Duration::from_millis(100));
        assert_eq!(policy.delay(1), Duration::from_millis(100));
        assert_eq!(policy.delay(2), Duration::from_millis(100));
    }

    #[test]
    fn exponential_backoff_doubles() {
        let policy = RetryPolicy::new(
            5,
            BackoffStrategy::Exponential {
                base_ms: 100,
                max_ms: 5000,
            },
        );
        assert_eq!(policy.delay(0), Duration::from_millis(100));
        assert_eq!(policy.delay(1), Duration::from_millis(200));
        assert_eq!(policy.delay(2), Duration::from_millis(400));
        assert_eq!(policy.delay(3), Duration::from_millis(800));
    }

    #[test]
    fn exponential_backoff_caps_at_max() {
        let policy = RetryPolicy::new(
            5,
            BackoffStrategy::Exponential {
                base_ms: 1000,
                max_ms: 3000,
            },
        );
        assert_eq!(policy.delay(0), Duration::from_millis(1000));
        assert_eq!(policy.delay(1), Duration::from_millis(2000));
        assert_eq!(policy.delay(2), Duration::from_millis(3000));
        assert_eq!(policy.delay(3), Duration::from_millis(3000));
    }

    #[test]
    fn linear_backoff_increments() {
        let policy = RetryPolicy::new(
            4,
            BackoffStrategy::Linear {
                base_ms: 100,
                max_ms: 500,
            },
        );
        assert_eq!(policy.delay(0), Duration::from_millis(100));
        assert_eq!(policy.delay(1), Duration::from_millis(200));
        assert_eq!(policy.delay(2), Duration::from_millis(300));
        assert_eq!(policy.delay(3), Duration::from_millis(400));
        assert_eq!(policy.delay(4), Duration::from_millis(500));
    }

    #[test]
    fn linear_backoff_caps_at_max() {
        let policy = RetryPolicy::new(
            4,
            BackoffStrategy::Linear {
                base_ms: 200,
                max_ms: 500,
            },
        );
        assert_eq!(policy.delay(2), Duration::from_millis(500));
    }

    #[test]
    fn no_retry_policy() {
        let policy = RetryPolicy::no_retry();
        assert_eq!(policy.max_retries, 0);
        // The backoff is irrelevant with zero retries, but pin it anyway.
        assert_eq!(policy.backoff, BackoffStrategy::Fixed { delay_ms: 0 });
    }

    #[test]
    fn total_max_delay_fixed() {
        let policy = RetryPolicy::new(3, BackoffStrategy::Fixed { delay_ms: 100 });
        assert_eq!(policy.total_max_delay(), Duration::from_millis(300));
    }

    #[test]
    fn total_max_delay_exponential() {
        let policy = RetryPolicy::new(
            3,
            BackoffStrategy::Exponential {
                base_ms: 100,
                max_ms: 10000,
            },
        );
        // 100 + 200 + 400
        assert_eq!(policy.total_max_delay(), Duration::from_millis(700));
    }

    #[test]
    fn total_max_delay_zero_retries() {
        let policy = RetryPolicy::no_retry();
        assert_eq!(policy.total_max_delay(), Duration::ZERO);
    }

    #[test]
    fn total_max_delay_fixed_saturates_without_iterating() {
        let policy = RetryPolicy::new(u32::MAX, BackoffStrategy::Fixed { delay_ms: u64::MAX });
        assert_eq!(policy.total_max_delay(), Duration::MAX);
    }

    #[test]
    fn total_max_delay_linear_handles_large_retry_counts() {
        let policy = RetryPolicy::new(
            u32::MAX,
            BackoffStrategy::Linear {
                base_ms: 1,
                max_ms: 10,
            },
        );
        // 1+2+...+10 = 55 uncapped, then (u32::MAX - 10) capped attempts of
        // 10ms each: 10 * u32::MAX - 100 + 55 = 10 * u32::MAX - 45.
        assert_eq!(
            policy.total_max_delay(),
            Duration::from_millis(10_u64.saturating_mul(u64::from(u32::MAX)) - 45)
        );
    }

    #[test]
    fn total_max_delay_exponential_saturates_after_cap() {
        let policy = RetryPolicy::new(
            6,
            BackoffStrategy::Exponential {
                base_ms: 10,
                max_ms: 35,
            },
        );
        // Delays: 10, 20, then capped at 35 for the remaining 4 attempts.
        assert_eq!(
            policy.total_max_delay(),
            Duration::from_millis(10 + 20 + 35 * 4)
        );
    }

    #[test]
    fn total_max_delay_matches_delay_sequence_for_representative_policies() {
        let policies = [
            RetryPolicy::new(5, BackoffStrategy::Fixed { delay_ms: 7 }),
            RetryPolicy::new(
                6,
                BackoffStrategy::Linear {
                    base_ms: 3,
                    max_ms: 10,
                },
            ),
            RetryPolicy::new(
                4,
                BackoffStrategy::Linear {
                    base_ms: 10,
                    max_ms: 3,
                },
            ),
            RetryPolicy::new(
                6,
                BackoffStrategy::Exponential {
                    base_ms: 2,
                    max_ms: 9,
                },
            ),
        ];

        // The closed-form total must agree with naively summing `delay`.
        for policy in policies {
            let expected_millis = (0..policy.max_retries)
                .map(|attempt| policy.delay(attempt).as_millis())
                .sum::<u128>();
            assert_eq!(policy.total_max_delay().as_millis(), expected_millis);
        }
    }

    #[test]
    fn exponential_backoff_overflow_saturates() {
        let policy = RetryPolicy::new(
            1,
            BackoffStrategy::Exponential {
                base_ms: u64::MAX / 2,
                max_ms: u64::MAX,
            },
        );
        // Previously this only checked "does not panic"; also pin the value:
        // (u64::MAX / 2) * 2^30 saturates to u64::MAX milliseconds.
        assert_eq!(policy.delay(30), Duration::from_millis(u64::MAX));
    }

    #[test]
    fn linear_backoff_overflow_saturates() {
        let policy = RetryPolicy::new(
            1,
            BackoffStrategy::Linear {
                base_ms: u64::MAX / 2,
                max_ms: u64::MAX,
            },
        );
        // Previously this only checked "does not panic"; also pin the value:
        // (u64::MAX / 2) * 31 saturates to u64::MAX milliseconds.
        assert_eq!(policy.delay(30), Duration::from_millis(u64::MAX));
    }

    #[test]
    fn retry_policy_clone_eq() {
        let policy = RetryPolicy::new(
            3,
            BackoffStrategy::Exponential {
                base_ms: 100,
                max_ms: 5000,
            },
        );
        let cloned = policy.clone();
        assert_eq!(policy, cloned);
    }

    #[test]
    fn task_with_retry_succeeds_first_try() {
        #[derive(Debug, PartialEq)]
        enum Msg {
            Ok(i32),
            Err(String),
        }

        let policy = RetryPolicy::new(3, BackoffStrategy::Fixed { delay_ms: 1 });
        let cmd = task_with_retry(policy, || Ok(Msg::Ok(42)), Msg::Err);

        assert_eq!(cmd.type_name(), "Task");
    }

    #[test]
    fn task_with_timeout_produces_task() {
        #[derive(Debug)]
        #[allow(dead_code)]
        enum Msg {
            Result(i32),
            Timeout,
        }

        let cmd = task_with_timeout(
            Duration::from_secs(1),
            |_token| Msg::Result(42),
            Msg::Timeout,
        );
        assert_eq!(cmd.type_name(), "Task");
    }

    #[test]
    fn task_with_timeout_requests_cancellation_on_timeout() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicBool, Ordering};

        #[derive(Debug, PartialEq)]
        enum Msg {
            Finished,
            Timeout,
        }

        let cancelled = Arc::new(AtomicBool::new(false));
        let worker_exited = Arc::new(AtomicBool::new(false));
        let cancelled_flag = Arc::clone(&cancelled);
        let exited_flag = Arc::clone(&worker_exited);

        let cmd = task_with_timeout(
            Duration::from_millis(10),
            move |token| {
                cancelled_flag.store(token.wait_timeout(Duration::from_secs(1)), Ordering::SeqCst);
                exited_flag.store(true, Ordering::SeqCst);
                Msg::Finished
            },
            Msg::Timeout,
        );

        let result = match cmd {
            Cmd::Task(_, task) => task(),
            other => panic!("expected Task, got {other:?}"),
        };

        assert_eq!(result, Msg::Timeout);
        // Give the (joined-or-detached) worker time to observe cancellation.
        std::thread::sleep(Duration::from_millis(50));
        assert!(cancelled.load(Ordering::SeqCst));
        assert!(worker_exited.load(Ordering::SeqCst));
    }

    #[test]
    fn task_with_retry_and_timeout_cancels_each_timed_out_attempt() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicUsize, Ordering};

        #[derive(Debug, PartialEq)]
        enum Msg {
            Exhausted(String),
        }

        fn on_exhaust(err: String) -> Msg {
            Msg::Exhausted(err)
        }

        let attempts = Arc::new(AtomicUsize::new(0));
        let cancelled = Arc::new(AtomicUsize::new(0));
        let attempts_flag = Arc::clone(&attempts);
        let cancelled_flag = Arc::clone(&cancelled);
        let policy = RetryPolicy::new(1, BackoffStrategy::Fixed { delay_ms: 0 });

        let cmd = task_with_retry_and_timeout(
            policy,
            Duration::from_millis(10),
            move |token| {
                attempts_flag.fetch_add(1, Ordering::SeqCst);
                if token.wait_timeout(Duration::from_secs(1)) {
                    cancelled_flag.fetch_add(1, Ordering::SeqCst);
                }
                Err("cancelled".to_owned())
            },
            on_exhaust,
        );

        let result = match cmd {
            Cmd::Task(_, task) => task(),
            other => panic!("expected Task, got {other:?}"),
        };

        assert_eq!(result, Msg::Exhausted("timeout".to_owned()));
        std::thread::sleep(Duration::from_millis(50));
        assert_eq!(attempts.load(Ordering::SeqCst), 2);
        assert_eq!(cancelled.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn backoff_strategy_variants_debug() {
        let fixed = BackoffStrategy::Fixed { delay_ms: 100 };
        let exp = BackoffStrategy::Exponential {
            base_ms: 100,
            max_ms: 5000,
        };
        let linear = BackoffStrategy::Linear {
            base_ms: 100,
            max_ms: 500,
        };
        let _ = format!("{fixed:?}");
        let _ = format!("{exp:?}");
        let _ = format!("{linear:?}");
    }
}