// ftui_runtime/retry.rs
1// SPDX-License-Identifier: Apache-2.0
2//! Retry policies and timeout-enforced task helpers.
3//!
4//! Provides [`RetryPolicy`] for configurable retry-with-backoff and
5//! [`task_with_timeout`] / [`task_with_retry`] constructors that wrap
6//! [`Cmd::Task`](crate::Cmd) with deterministic lifecycle guarantees.
7//!
8//! # Migration rationale
9//!
10//! Source frameworks often have retry/timeout baked into effect middleware.
11//! These helpers give the migration code emitter explicit, testable primitives
12//! to target instead of ad-hoc retry loops.
13//!
14//! # Determinism
15//!
16//! Backoff delays use fixed formulas (no jitter/randomness) so that
17//! replay-based determinism tests can reproduce exact timing sequences.
18//!
19//! # Example
20//!
21//! ```
22//! use ftui_runtime::retry::{RetryPolicy, BackoffStrategy};
23//! use std::time::Duration;
24//!
25//! let policy = RetryPolicy::new(3, BackoffStrategy::Exponential {
26//!     base_ms: 100,
27//!     max_ms: 5000,
28//! });
29//!
30//! assert_eq!(policy.delay(0), Duration::from_millis(100));
31//! assert_eq!(policy.delay(1), Duration::from_millis(200));
32//! assert_eq!(policy.delay(2), Duration::from_millis(400));
33//! ```
34
35#![forbid(unsafe_code)]
36
37use crate::cancellation::{CancellationSource, CancellationToken};
38use crate::program::{Cmd, TaskSpec};
39use web_time::Duration;
40
/// Upper bound on how long a cancelled worker thread is awaited before it is
/// detached (see `join_task_thread_bounded`).
const TASK_THREAD_JOIN_TIMEOUT: Duration = Duration::from_millis(250);
/// Polling interval used while waiting for a worker thread to report finished.
const TASK_THREAD_JOIN_POLL: Duration = Duration::from_millis(5);
43
/// Convert a millisecond count into a [`Duration`], clamping to
/// `Duration::MAX` instead of panicking on overflow.
fn duration_from_millis_saturating(millis: u128) -> Duration {
    if millis >= Duration::MAX.as_millis() {
        return Duration::MAX;
    }
    let whole_seconds = millis / 1_000;
    let remainder_millis = millis % 1_000;
    // Both conversions should succeed given the clamp above (seconds fits in
    // u64, remainder < 1000 so nanos < 10^9); saturate if they ever do not.
    match (
        u64::try_from(whole_seconds),
        u32::try_from(remainder_millis.saturating_mul(1_000_000)),
    ) {
        (Ok(secs), Ok(nanos)) => Duration::new(secs, nanos),
        _ => Duration::MAX,
    }
}
59
/// Accumulate `millis` into `total`, saturating at `Duration::MAX`'s
/// millisecond count so the running sum can never overflow.
fn add_millis_saturating(total: &mut u128, millis: u128) {
    let sum = total.saturating_add(millis);
    *total = sum.min(Duration::MAX.as_millis());
}
63
/// Join a finished (or finishing) worker thread.
fn join_task_thread(handle: std::thread::JoinHandle<()>) {
    // A panic in the worker surfaces as `Err` here; it is deliberately
    // discarded — task teardown must not propagate worker panics.
    drop(handle.join());
}
67
68fn join_task_thread_bounded(handle: std::thread::JoinHandle<()>, task_name: &'static str) {
69    let start = web_time::Instant::now();
70    while !handle.is_finished() {
71        if start.elapsed() >= TASK_THREAD_JOIN_TIMEOUT {
72            tracing::warn!(
73                task = task_name,
74                timeout_ms = TASK_THREAD_JOIN_TIMEOUT.as_millis() as u64,
75                "Timed-out worker thread did not exit within the cancellation join timeout; detaching"
76            );
77            return;
78        }
79        std::thread::sleep(TASK_THREAD_JOIN_POLL);
80    }
81    join_task_thread(handle);
82}
83
/// Backoff strategy for retry delays.
///
/// All strategies are deterministic (no jitter) so replay-based determinism
/// tests can reproduce exact timing sequences.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
    feature = "state-persistence",
    derive(serde::Serialize, serde::Deserialize)
)]
pub enum BackoffStrategy {
    /// Fixed delay between retries.
    Fixed {
        /// Delay in milliseconds.
        delay_ms: u64,
    },
    /// Exponential backoff: `base_ms * 2^attempt`, capped at `max_ms`.
    Exponential {
        /// Base delay in milliseconds.
        base_ms: u64,
        /// Maximum delay cap in milliseconds.
        max_ms: u64,
    },
    /// Linear backoff: `base_ms * (attempt + 1)`, capped at `max_ms`.
    Linear {
        /// Base delay in milliseconds.
        base_ms: u64,
        /// Maximum delay cap in milliseconds.
        max_ms: u64,
    },
}
111
/// A retry policy with configurable attempts and backoff.
///
/// A policy with `max_retries = N` yields up to `N + 1` total executions:
/// the initial attempt plus `N` retries, with a backoff delay before each
/// retry (see [`RetryPolicy::delay`]).
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
    feature = "state-persistence",
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct RetryPolicy {
    /// Maximum number of retry attempts (0 = no retries, just the initial attempt).
    pub max_retries: u32,
    /// Backoff strategy between retries.
    pub backoff: BackoffStrategy,
}
124
impl RetryPolicy {
    /// Create a new retry policy.
    pub fn new(max_retries: u32, backoff: BackoffStrategy) -> Self {
        Self {
            max_retries,
            backoff,
        }
    }

    /// No retries — execute once.
    pub fn no_retry() -> Self {
        Self {
            max_retries: 0,
            backoff: BackoffStrategy::Fixed { delay_ms: 0 },
        }
    }

    /// Compute the delay before the given attempt (0-indexed).
    ///
    /// All arithmetic saturates, so arbitrarily large `attempt` values are
    /// safe (the cap `max_ms` applies in the Exponential/Linear cases).
    pub fn delay(&self, attempt: u32) -> Duration {
        match &self.backoff {
            BackoffStrategy::Fixed { delay_ms } => Duration::from_millis(*delay_ms),
            BackoffStrategy::Exponential { base_ms, max_ms } => {
                // checked_shl is None once attempt >= 64; treat that as an
                // effectively infinite multiplier — the min() below caps it.
                let multiplier = 1u64.checked_shl(attempt).unwrap_or(u64::MAX);
                let delay = base_ms.saturating_mul(multiplier);
                Duration::from_millis(delay.min(*max_ms))
            }
            BackoffStrategy::Linear { base_ms, max_ms } => {
                let delay = base_ms.saturating_mul(u64::from(attempt) + 1);
                Duration::from_millis(delay.min(*max_ms))
            }
        }
    }

    /// Total maximum delay across all retries (for timeout budgeting).
    ///
    /// Equal to the sum of [`delay`](Self::delay) over attempts
    /// `0..max_retries`, computed in closed form (Fixed/Linear) or with an
    /// early-exit loop (Exponential) so that huge `max_retries` values do not
    /// require `max_retries` iterations. All arithmetic saturates at
    /// `Duration::MAX`.
    pub fn total_max_delay(&self) -> Duration {
        let retry_count = u128::from(self.max_retries);
        let max_duration_millis = Duration::MAX.as_millis();
        match &self.backoff {
            BackoffStrategy::Fixed { delay_ms } => {
                duration_from_millis_saturating(u128::from(*delay_ms).saturating_mul(retry_count))
            }
            BackoffStrategy::Linear { base_ms, max_ms } => {
                if self.max_retries == 0 || *base_ms == 0 || *max_ms == 0 {
                    return Duration::ZERO;
                }

                // The first `uncapped_terms` delays form the arithmetic series
                // base*1 + base*2 + ... + base*k = base * k(k+1)/2; every
                // later attempt is clamped to `max_ms`.
                let uncapped_terms = retry_count.min(u128::from(*max_ms / *base_ms));
                let arithmetic_sum =
                    uncapped_terms.saturating_mul(uncapped_terms.saturating_add(1)) / 2;
                let mut total_millis = u128::from(*base_ms)
                    .saturating_mul(arithmetic_sum)
                    .min(max_duration_millis);

                // Remaining attempts all contribute exactly `max_ms` each.
                let capped_terms = retry_count.saturating_sub(uncapped_terms);
                add_millis_saturating(
                    &mut total_millis,
                    u128::from(*max_ms).saturating_mul(capped_terms),
                );
                duration_from_millis_saturating(total_millis)
            }
            BackoffStrategy::Exponential { base_ms, max_ms } => {
                if self.max_retries == 0 || *base_ms == 0 || *max_ms == 0 {
                    return Duration::ZERO;
                }

                // Sum the doubling delays one at a time until the cap is hit;
                // once a delay equals `max_ms`, every remaining attempt also
                // contributes `max_ms`, so the tail is added in a single
                // multiplication. The loop therefore runs at most ~64 times
                // regardless of `max_retries`.
                let mut total_millis = 0_u128;
                let mut attempt = 0_u32;
                while attempt < self.max_retries {
                    let delay_millis = match 1_u64.checked_shl(attempt) {
                        Some(multiplier) => base_ms.saturating_mul(multiplier).min(*max_ms),
                        None => *max_ms,
                    };
                    add_millis_saturating(&mut total_millis, u128::from(delay_millis));
                    attempt = attempt.saturating_add(1);

                    if delay_millis == *max_ms {
                        let remaining = u128::from(self.max_retries.saturating_sub(attempt));
                        add_millis_saturating(
                            &mut total_millis,
                            u128::from(*max_ms).saturating_mul(remaining),
                        );
                        break;
                    }
                }
                duration_from_millis_saturating(total_millis)
            }
        }
    }
}
214
215/// Create a [`Cmd::Task`] that enforces a cooperative timeout.
216///
217/// The worker closure receives a [`CancellationToken`] and must honor it for
218/// timely timeout teardown. On timeout, the runtime requests cancellation and
219/// returns `on_timeout`; any late worker result is discarded.
220pub fn task_with_timeout<M, F>(timeout: Duration, f: F, on_timeout: M) -> Cmd<M>
221where
222    M: Send + 'static,
223    F: FnOnce(CancellationToken) -> M + Send + 'static,
224{
225    Cmd::task(move || {
226        let source = CancellationSource::new();
227        let token = source.token();
228        let (tx, rx) = std::sync::mpsc::channel();
229        let handle = std::thread::spawn(move || {
230            let result = f(token);
231            let _ = tx.send(result);
232        });
233        match rx.recv_timeout(timeout) {
234            Ok(msg) => {
235                join_task_thread(handle);
236                msg
237            }
238            Err(_) => {
239                source.cancel();
240                join_task_thread_bounded(handle, "task_with_timeout");
241                on_timeout
242            }
243        }
244    })
245}
246
247/// Create a [`Cmd::Task`] with a named spec and cooperative timeout.
248pub fn task_with_timeout_named<M, F>(
249    name: impl Into<String>,
250    timeout: Duration,
251    f: F,
252    on_timeout: M,
253) -> Cmd<M>
254where
255    M: Send + 'static,
256    F: FnOnce(CancellationToken) -> M + Send + 'static,
257{
258    Cmd::task_with_spec(TaskSpec::default().with_name(name), move || {
259        let source = CancellationSource::new();
260        let token = source.token();
261        let (tx, rx) = std::sync::mpsc::channel();
262        let handle = std::thread::spawn(move || {
263            let result = f(token);
264            let _ = tx.send(result);
265        });
266        match rx.recv_timeout(timeout) {
267            Ok(msg) => {
268                join_task_thread(handle);
269                msg
270            }
271            Err(_) => {
272                source.cancel();
273                join_task_thread_bounded(handle, "task_with_timeout_named");
274                on_timeout
275            }
276        }
277    })
278}
279
280/// Create a [`Cmd::Task`] that retries on failure with the given policy.
281///
282/// The `f` closure returns `Result<M, String>`. On `Ok`, the message is
283/// returned immediately. On `Err`, the task retries according to the policy,
284/// sleeping between attempts. After all retries are exhausted, `on_exhaust`
285/// is called with the last error to produce a fallback message.
286pub fn task_with_retry<M, F>(policy: RetryPolicy, f: F, on_exhaust: fn(String) -> M) -> Cmd<M>
287where
288    M: Send + 'static,
289    F: Fn() -> Result<M, String> + Send + 'static,
290{
291    Cmd::task(move || {
292        let mut last_err = String::new();
293        for attempt in 0..=policy.max_retries {
294            match f() {
295                Ok(msg) => return msg,
296                Err(e) => {
297                    last_err = e;
298                    if attempt < policy.max_retries {
299                        std::thread::sleep(policy.delay(attempt));
300                    }
301                }
302            }
303        }
304        on_exhaust(last_err)
305    })
306}
307
308/// Create a [`Cmd::Task`] with both retry and timeout.
309///
310/// Each individual attempt is bounded by `per_attempt_timeout`. The worker
311/// receives a [`CancellationToken`] and must honor it for timely timeout
312/// teardown. The total number of attempts is governed by the retry policy.
313pub fn task_with_retry_and_timeout<M, F>(
314    policy: RetryPolicy,
315    per_attempt_timeout: Duration,
316    f: F,
317    on_exhaust: fn(String) -> M,
318) -> Cmd<M>
319where
320    M: Send + 'static,
321    F: Fn(CancellationToken) -> Result<M, String> + Send + Sync + 'static,
322{
323    Cmd::task(move || {
324        let f = std::sync::Arc::new(f);
325        let mut last_err = String::new();
326        for attempt in 0..=policy.max_retries {
327            let source = CancellationSource::new();
328            let token = source.token();
329            let (tx, rx) = std::sync::mpsc::channel();
330            let f_clone = std::sync::Arc::clone(&f);
331            let handle = std::thread::spawn(move || {
332                let result = f_clone(token);
333                let _ = tx.send(result);
334            });
335            match rx.recv_timeout(per_attempt_timeout) {
336                Ok(Ok(msg)) => {
337                    join_task_thread(handle);
338                    return msg;
339                }
340                Ok(Err(e)) => {
341                    join_task_thread(handle);
342                    last_err = e;
343                }
344                Err(_) => {
345                    source.cancel();
346                    join_task_thread_bounded(handle, "task_with_retry_and_timeout");
347                    last_err = "timeout".into();
348                }
349            }
350            if attempt < policy.max_retries {
351                std::thread::sleep(policy.delay(attempt));
352            }
353        }
354        on_exhaust(last_err)
355    })
356}
357
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn fixed_backoff_constant_delay() {
        let policy = RetryPolicy::new(3, BackoffStrategy::Fixed { delay_ms: 100 });
        assert_eq!(policy.delay(0), Duration::from_millis(100));
        assert_eq!(policy.delay(1), Duration::from_millis(100));
        assert_eq!(policy.delay(2), Duration::from_millis(100));
    }

    #[test]
    fn exponential_backoff_doubles() {
        let policy = RetryPolicy::new(
            5,
            BackoffStrategy::Exponential {
                base_ms: 100,
                max_ms: 5000,
            },
        );
        assert_eq!(policy.delay(0), Duration::from_millis(100));
        assert_eq!(policy.delay(1), Duration::from_millis(200));
        assert_eq!(policy.delay(2), Duration::from_millis(400));
        assert_eq!(policy.delay(3), Duration::from_millis(800));
    }

    #[test]
    fn exponential_backoff_caps_at_max() {
        let policy = RetryPolicy::new(
            5,
            BackoffStrategy::Exponential {
                base_ms: 1000,
                max_ms: 3000,
            },
        );
        assert_eq!(policy.delay(0), Duration::from_millis(1000));
        assert_eq!(policy.delay(1), Duration::from_millis(2000));
        assert_eq!(policy.delay(2), Duration::from_millis(3000)); // capped
        assert_eq!(policy.delay(3), Duration::from_millis(3000)); // capped
    }

    #[test]
    fn linear_backoff_increments() {
        let policy = RetryPolicy::new(
            4,
            BackoffStrategy::Linear {
                base_ms: 100,
                max_ms: 500,
            },
        );
        assert_eq!(policy.delay(0), Duration::from_millis(100));
        assert_eq!(policy.delay(1), Duration::from_millis(200));
        assert_eq!(policy.delay(2), Duration::from_millis(300));
        assert_eq!(policy.delay(3), Duration::from_millis(400));
        assert_eq!(policy.delay(4), Duration::from_millis(500)); // capped
    }

    #[test]
    fn linear_backoff_caps_at_max() {
        let policy = RetryPolicy::new(
            4,
            BackoffStrategy::Linear {
                base_ms: 200,
                max_ms: 500,
            },
        );
        assert_eq!(policy.delay(2), Duration::from_millis(500)); // 200*3 = 600, capped at 500
    }

    #[test]
    fn no_retry_policy() {
        let policy = RetryPolicy::no_retry();
        assert_eq!(policy.max_retries, 0);
    }

    #[test]
    fn total_max_delay_fixed() {
        // 3 retries * 100ms fixed delay = 300ms.
        let policy = RetryPolicy::new(3, BackoffStrategy::Fixed { delay_ms: 100 });
        assert_eq!(policy.total_max_delay(), Duration::from_millis(300));
    }

    #[test]
    fn total_max_delay_exponential() {
        let policy = RetryPolicy::new(
            3,
            BackoffStrategy::Exponential {
                base_ms: 100,
                max_ms: 10000,
            },
        );
        // Delays: 100 + 200 + 400 = 700
        assert_eq!(policy.total_max_delay(), Duration::from_millis(700));
    }

    #[test]
    fn total_max_delay_zero_retries() {
        let policy = RetryPolicy::no_retry();
        assert_eq!(policy.total_max_delay(), Duration::ZERO);
    }

    #[test]
    fn total_max_delay_fixed_saturates_without_iterating() {
        // u32::MAX retries of u64::MAX ms must saturate to Duration::MAX in
        // closed form (no per-attempt iteration).
        let policy = RetryPolicy::new(u32::MAX, BackoffStrategy::Fixed { delay_ms: u64::MAX });
        assert_eq!(policy.total_max_delay(), Duration::MAX);
    }

    #[test]
    fn total_max_delay_linear_handles_large_retry_counts() {
        let policy = RetryPolicy::new(
            u32::MAX,
            BackoffStrategy::Linear {
                base_ms: 1,
                max_ms: 10,
            },
        );
        // First 10 delays are 1+2+...+10 = 55 ms; the remaining
        // (u32::MAX - 10) delays are 10 ms each, hence 10*N - 45 overall.
        assert_eq!(
            policy.total_max_delay(),
            Duration::from_millis(10_u64.saturating_mul(u64::from(u32::MAX)) - 45)
        );
    }

    #[test]
    fn total_max_delay_exponential_saturates_after_cap() {
        let policy = RetryPolicy::new(
            6,
            BackoffStrategy::Exponential {
                base_ms: 10,
                max_ms: 35,
            },
        );
        // Delays: 10, 20, then capped at 35 for the remaining 4 attempts.
        assert_eq!(
            policy.total_max_delay(),
            Duration::from_millis(10 + 20 + 35 * 4)
        );
    }

    #[test]
    fn total_max_delay_matches_delay_sequence_for_representative_policies() {
        // Cross-check the closed-form total against a brute-force sum of
        // per-attempt delays for each backoff variant, capped and uncapped.
        let policies = [
            RetryPolicy::new(5, BackoffStrategy::Fixed { delay_ms: 7 }),
            RetryPolicy::new(
                6,
                BackoffStrategy::Linear {
                    base_ms: 3,
                    max_ms: 10,
                },
            ),
            RetryPolicy::new(
                4,
                BackoffStrategy::Linear {
                    base_ms: 10,
                    max_ms: 3,
                },
            ),
            RetryPolicy::new(
                6,
                BackoffStrategy::Exponential {
                    base_ms: 2,
                    max_ms: 9,
                },
            ),
        ];

        for policy in policies {
            let expected_millis = (0..policy.max_retries)
                .map(|attempt| policy.delay(attempt).as_millis())
                .sum::<u128>();
            assert_eq!(policy.total_max_delay().as_millis(), expected_millis);
        }
    }

    #[test]
    fn exponential_backoff_overflow_saturates() {
        let policy = RetryPolicy::new(
            1,
            BackoffStrategy::Exponential {
                base_ms: u64::MAX / 2,
                max_ms: u64::MAX,
            },
        );
        // Should not panic on overflow
        let _ = policy.delay(30);
    }

    #[test]
    fn linear_backoff_overflow_saturates() {
        let policy = RetryPolicy::new(
            1,
            BackoffStrategy::Linear {
                base_ms: u64::MAX / 2,
                max_ms: u64::MAX,
            },
        );
        // Should not panic on overflow
        let _ = policy.delay(30);
    }

    #[test]
    fn retry_policy_clone_eq() {
        let policy = RetryPolicy::new(
            3,
            BackoffStrategy::Exponential {
                base_ms: 100,
                max_ms: 5000,
            },
        );
        let cloned = policy.clone();
        assert_eq!(policy, cloned);
    }

    #[test]
    fn task_with_retry_succeeds_first_try() {
        #[derive(Debug, PartialEq)]
        enum Msg {
            Ok(i32),
            Err(String),
        }

        let policy = RetryPolicy::new(3, BackoffStrategy::Fixed { delay_ms: 1 });
        let cmd = task_with_retry(policy, || Ok(Msg::Ok(42)), Msg::Err);

        // Verify it produces a Task variant
        assert_eq!(cmd.type_name(), "Task");
    }

    #[test]
    fn task_with_timeout_produces_task() {
        #[derive(Debug)]
        #[allow(dead_code)]
        enum Msg {
            Result(i32),
            Timeout,
        }

        let cmd = task_with_timeout(
            Duration::from_secs(1),
            |_token| Msg::Result(42),
            Msg::Timeout,
        );
        assert_eq!(cmd.type_name(), "Task");
    }

    #[test]
    fn task_with_timeout_requests_cancellation_on_timeout() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicBool, Ordering};

        #[derive(Debug, PartialEq)]
        enum Msg {
            Finished,
            Timeout,
        }

        let cancelled = Arc::new(AtomicBool::new(false));
        let worker_exited = Arc::new(AtomicBool::new(false));
        let cancelled_flag = Arc::clone(&cancelled);
        let exited_flag = Arc::clone(&worker_exited);

        // Worker blocks until the token fires; the 10ms timeout should fire
        // first and trigger cooperative cancellation.
        let cmd = task_with_timeout(
            Duration::from_millis(10),
            move |token| {
                cancelled_flag.store(token.wait_timeout(Duration::from_secs(1)), Ordering::SeqCst);
                exited_flag.store(true, Ordering::SeqCst);
                Msg::Finished
            },
            Msg::Timeout,
        );

        let result = match cmd {
            Cmd::Task(_, task) => task(),
            other => panic!("expected Task, got {other:?}"),
        };

        assert_eq!(result, Msg::Timeout);
        // Give the (joined-or-detached) worker a moment to record its exit.
        std::thread::sleep(Duration::from_millis(50));
        assert!(cancelled.load(Ordering::SeqCst));
        assert!(worker_exited.load(Ordering::SeqCst));
    }

    #[test]
    fn task_with_retry_and_timeout_cancels_each_timed_out_attempt() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicUsize, Ordering};

        #[derive(Debug, PartialEq)]
        enum Msg {
            Exhausted(String),
        }

        fn on_exhaust(err: String) -> Msg {
            Msg::Exhausted(err)
        }

        let attempts = Arc::new(AtomicUsize::new(0));
        let cancelled = Arc::new(AtomicUsize::new(0));
        let attempts_flag = Arc::clone(&attempts);
        let cancelled_flag = Arc::clone(&cancelled);
        let policy = RetryPolicy::new(1, BackoffStrategy::Fixed { delay_ms: 0 });

        // Each attempt blocks past the per-attempt timeout, so both the
        // initial attempt and the single retry should be cancelled.
        let cmd = task_with_retry_and_timeout(
            policy,
            Duration::from_millis(10),
            move |token| {
                attempts_flag.fetch_add(1, Ordering::SeqCst);
                if token.wait_timeout(Duration::from_secs(1)) {
                    cancelled_flag.fetch_add(1, Ordering::SeqCst);
                }
                Err("cancelled".to_owned())
            },
            on_exhaust,
        );

        let result = match cmd {
            Cmd::Task(_, task) => task(),
            other => panic!("expected Task, got {other:?}"),
        };

        assert_eq!(result, Msg::Exhausted("timeout".to_owned()));
        std::thread::sleep(Duration::from_millis(50));
        assert_eq!(attempts.load(Ordering::SeqCst), 2);
        assert_eq!(cancelled.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn backoff_strategy_variants_debug() {
        let fixed = BackoffStrategy::Fixed { delay_ms: 100 };
        let exp = BackoffStrategy::Exponential {
            base_ms: 100,
            max_ms: 5000,
        };
        let linear = BackoffStrategy::Linear {
            base_ms: 100,
            max_ms: 500,
        };
        // Just verify Debug doesn't panic
        let _ = format!("{fixed:?}");
        let _ = format!("{exp:?}");
        let _ = format!("{linear:?}");
    }
}