// ftui_runtime/retry.rs

// SPDX-License-Identifier: Apache-2.0
//! Retry policies and timeout-enforced task helpers.
//!
//! Provides [`RetryPolicy`] for configurable retry-with-backoff and
//! [`task_with_timeout`] / [`task_with_retry`] constructors that wrap
//! [`Cmd::Task`](crate::Cmd) with deterministic lifecycle guarantees.
//!
//! # Migration rationale
//!
//! Source frameworks often have retry/timeout baked into effect middleware.
//! These helpers give the migration code emitter explicit, testable primitives
//! to target instead of ad-hoc retry loops.
//!
//! # Determinism
//!
//! Backoff delays use fixed formulas (no jitter/randomness) so that
//! replay-based determinism tests can reproduce exact timing sequences.
//!
//! # Example
//!
//! ```
//! use ftui_runtime::retry::{RetryPolicy, BackoffStrategy};
//! use std::time::Duration;
//!
//! let policy = RetryPolicy::new(3, BackoffStrategy::Exponential {
//!     base_ms: 100,
//!     max_ms: 5000,
//! });
//!
//! assert_eq!(policy.delay(0), Duration::from_millis(100));
//! assert_eq!(policy.delay(1), Duration::from_millis(200));
//! assert_eq!(policy.delay(2), Duration::from_millis(400));
//! ```

#![forbid(unsafe_code)]

use crate::cancellation::{CancellationSource, CancellationToken};
use crate::program::{Cmd, TaskSpec};
use web_time::Duration;

// How long a timed-out (cancelled) worker thread is given to exit before we
// give up and detach it; see `join_task_thread_bounded`.
const TASK_THREAD_JOIN_TIMEOUT: Duration = Duration::from_millis(250);
// Polling interval used while waiting for the worker thread to finish.
const TASK_THREAD_JOIN_POLL: Duration = Duration::from_millis(5);

/// Block until the worker thread exits.
///
/// The join result is deliberately discarded: a worker panic at this point
/// has no consumer (its message was already delivered or is being dropped),
/// so swallowing it is the intended behavior, not an oversight.
fn join_task_thread(handle: std::thread::JoinHandle<()>) {
    let _ = handle.join();
}

48fn join_task_thread_bounded(handle: std::thread::JoinHandle<()>, task_name: &'static str) {
49    let start = web_time::Instant::now();
50    while !handle.is_finished() {
51        if start.elapsed() >= TASK_THREAD_JOIN_TIMEOUT {
52            tracing::warn!(
53                task = task_name,
54                timeout_ms = TASK_THREAD_JOIN_TIMEOUT.as_millis() as u64,
55                "Timed-out worker thread did not exit within the cancellation join timeout; detaching"
56            );
57            return;
58        }
59        std::thread::sleep(TASK_THREAD_JOIN_POLL);
60    }
61    join_task_thread(handle);
62}
63
/// Backoff strategy for retry delays.
///
/// All strategies are deterministic (no jitter) so replay-based determinism
/// tests can reproduce exact timing sequences.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(
    feature = "state-persistence",
    derive(serde::Serialize, serde::Deserialize)
)]
pub enum BackoffStrategy {
    /// Fixed delay between retries.
    Fixed {
        /// Delay in milliseconds.
        delay_ms: u64,
    },
    /// Exponential backoff: `base_ms * 2^attempt`, capped at `max_ms`.
    Exponential {
        /// Base delay in milliseconds.
        base_ms: u64,
        /// Maximum delay cap in milliseconds.
        max_ms: u64,
    },
    /// Linear backoff: `base_ms * (attempt + 1)`, capped at `max_ms`.
    Linear {
        /// Base delay in milliseconds.
        base_ms: u64,
        /// Maximum delay cap in milliseconds.
        max_ms: u64,
    },
}

/// A retry policy with configurable attempts and backoff.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(
    feature = "state-persistence",
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct RetryPolicy {
    /// Maximum number of retry attempts (0 = no retries, just the initial attempt).
    pub max_retries: u32,
    /// Backoff strategy between retries.
    pub backoff: BackoffStrategy,
}

impl RetryPolicy {
    /// Create a new retry policy.
    pub fn new(max_retries: u32, backoff: BackoffStrategy) -> Self {
        Self {
            max_retries,
            backoff,
        }
    }

    /// No retries — execute once.
    pub fn no_retry() -> Self {
        Self {
            max_retries: 0,
            backoff: BackoffStrategy::Fixed { delay_ms: 0 },
        }
    }

    /// Compute the delay before the given attempt (0-indexed).
    ///
    /// Never panics: all arithmetic saturates, and capped strategies clamp
    /// the result to their `max_ms`.
    pub fn delay(&self, attempt: u32) -> Duration {
        match &self.backoff {
            BackoffStrategy::Fixed { delay_ms } => Duration::from_millis(*delay_ms),
            BackoffStrategy::Exponential { base_ms, max_ms } => {
                // 2^attempt; `checked_shl` yields None for attempt >= 64, in
                // which case we saturate so the `max_ms` cap still applies.
                let multiplier = 1u64.checked_shl(attempt).unwrap_or(u64::MAX);
                let delay = base_ms.saturating_mul(multiplier);
                Duration::from_millis(delay.min(*max_ms))
            }
            BackoffStrategy::Linear { base_ms, max_ms } => {
                let delay = base_ms.saturating_mul(u64::from(attempt) + 1);
                Duration::from_millis(delay.min(*max_ms))
            }
        }
    }

    /// Total maximum delay across all retries (for timeout budgeting).
    ///
    /// Saturates at `Duration::MAX` instead of panicking: a pathological
    /// policy (huge base delay, many retries) could otherwise overflow the
    /// `Duration` addition.
    pub fn total_max_delay(&self) -> Duration {
        let mut total = Duration::ZERO;
        for attempt in 0..self.max_retries {
            total = total.saturating_add(self.delay(attempt));
        }
        total
    }
}

148/// Create a [`Cmd::Task`] that enforces a cooperative timeout.
149///
150/// The worker closure receives a [`CancellationToken`] and must honor it for
151/// timely timeout teardown. On timeout, the runtime requests cancellation and
152/// returns `on_timeout`; any late worker result is discarded.
153pub fn task_with_timeout<M, F>(timeout: Duration, f: F, on_timeout: M) -> Cmd<M>
154where
155    M: Send + 'static,
156    F: FnOnce(CancellationToken) -> M + Send + 'static,
157{
158    Cmd::task(move || {
159        let source = CancellationSource::new();
160        let token = source.token();
161        let (tx, rx) = std::sync::mpsc::channel();
162        let handle = std::thread::spawn(move || {
163            let result = f(token);
164            let _ = tx.send(result);
165        });
166        match rx.recv_timeout(timeout) {
167            Ok(msg) => {
168                join_task_thread(handle);
169                msg
170            }
171            Err(_) => {
172                source.cancel();
173                join_task_thread_bounded(handle, "task_with_timeout");
174                on_timeout
175            }
176        }
177    })
178}
179
180/// Create a [`Cmd::Task`] with a named spec and cooperative timeout.
181pub fn task_with_timeout_named<M, F>(
182    name: impl Into<String>,
183    timeout: Duration,
184    f: F,
185    on_timeout: M,
186) -> Cmd<M>
187where
188    M: Send + 'static,
189    F: FnOnce(CancellationToken) -> M + Send + 'static,
190{
191    Cmd::task_with_spec(TaskSpec::default().with_name(name), move || {
192        let source = CancellationSource::new();
193        let token = source.token();
194        let (tx, rx) = std::sync::mpsc::channel();
195        let handle = std::thread::spawn(move || {
196            let result = f(token);
197            let _ = tx.send(result);
198        });
199        match rx.recv_timeout(timeout) {
200            Ok(msg) => {
201                join_task_thread(handle);
202                msg
203            }
204            Err(_) => {
205                source.cancel();
206                join_task_thread_bounded(handle, "task_with_timeout_named");
207                on_timeout
208            }
209        }
210    })
211}
212
213/// Create a [`Cmd::Task`] that retries on failure with the given policy.
214///
215/// The `f` closure returns `Result<M, String>`. On `Ok`, the message is
216/// returned immediately. On `Err`, the task retries according to the policy,
217/// sleeping between attempts. After all retries are exhausted, `on_exhaust`
218/// is called with the last error to produce a fallback message.
219pub fn task_with_retry<M, F>(policy: RetryPolicy, f: F, on_exhaust: fn(String) -> M) -> Cmd<M>
220where
221    M: Send + 'static,
222    F: Fn() -> Result<M, String> + Send + 'static,
223{
224    Cmd::task(move || {
225        let mut last_err = String::new();
226        for attempt in 0..=policy.max_retries {
227            match f() {
228                Ok(msg) => return msg,
229                Err(e) => {
230                    last_err = e;
231                    if attempt < policy.max_retries {
232                        std::thread::sleep(policy.delay(attempt));
233                    }
234                }
235            }
236        }
237        on_exhaust(last_err)
238    })
239}
240
241/// Create a [`Cmd::Task`] with both retry and timeout.
242///
243/// Each individual attempt is bounded by `per_attempt_timeout`. The worker
244/// receives a [`CancellationToken`] and must honor it for timely timeout
245/// teardown. The total number of attempts is governed by the retry policy.
246pub fn task_with_retry_and_timeout<M, F>(
247    policy: RetryPolicy,
248    per_attempt_timeout: Duration,
249    f: F,
250    on_exhaust: fn(String) -> M,
251) -> Cmd<M>
252where
253    M: Send + 'static,
254    F: Fn(CancellationToken) -> Result<M, String> + Send + Sync + 'static,
255{
256    Cmd::task(move || {
257        let f = std::sync::Arc::new(f);
258        let mut last_err = String::new();
259        for attempt in 0..=policy.max_retries {
260            let source = CancellationSource::new();
261            let token = source.token();
262            let (tx, rx) = std::sync::mpsc::channel();
263            let f_clone = std::sync::Arc::clone(&f);
264            let handle = std::thread::spawn(move || {
265                let result = f_clone(token);
266                let _ = tx.send(result);
267            });
268            match rx.recv_timeout(per_attempt_timeout) {
269                Ok(Ok(msg)) => {
270                    join_task_thread(handle);
271                    return msg;
272                }
273                Ok(Err(e)) => {
274                    join_task_thread(handle);
275                    last_err = e;
276                }
277                Err(_) => {
278                    source.cancel();
279                    join_task_thread_bounded(handle, "task_with_retry_and_timeout");
280                    last_err = "timeout".into();
281                }
282            }
283            if attempt < policy.max_retries {
284                std::thread::sleep(policy.delay(attempt));
285            }
286        }
287        on_exhaust(last_err)
288    })
289}
290
#[cfg(test)]
mod tests {
    use super::*;

    // --- Backoff delay computation ---

    #[test]
    fn fixed_backoff_constant_delay() {
        let policy = RetryPolicy::new(3, BackoffStrategy::Fixed { delay_ms: 100 });
        assert_eq!(policy.delay(0), Duration::from_millis(100));
        assert_eq!(policy.delay(1), Duration::from_millis(100));
        assert_eq!(policy.delay(2), Duration::from_millis(100));
    }

    #[test]
    fn exponential_backoff_doubles() {
        let policy = RetryPolicy::new(
            5,
            BackoffStrategy::Exponential {
                base_ms: 100,
                max_ms: 5000,
            },
        );
        assert_eq!(policy.delay(0), Duration::from_millis(100));
        assert_eq!(policy.delay(1), Duration::from_millis(200));
        assert_eq!(policy.delay(2), Duration::from_millis(400));
        assert_eq!(policy.delay(3), Duration::from_millis(800));
    }

    #[test]
    fn exponential_backoff_caps_at_max() {
        let policy = RetryPolicy::new(
            5,
            BackoffStrategy::Exponential {
                base_ms: 1000,
                max_ms: 3000,
            },
        );
        assert_eq!(policy.delay(0), Duration::from_millis(1000));
        assert_eq!(policy.delay(1), Duration::from_millis(2000));
        assert_eq!(policy.delay(2), Duration::from_millis(3000)); // capped
        assert_eq!(policy.delay(3), Duration::from_millis(3000)); // capped
    }

    #[test]
    fn linear_backoff_increments() {
        let policy = RetryPolicy::new(
            4,
            BackoffStrategy::Linear {
                base_ms: 100,
                max_ms: 500,
            },
        );
        assert_eq!(policy.delay(0), Duration::from_millis(100));
        assert_eq!(policy.delay(1), Duration::from_millis(200));
        assert_eq!(policy.delay(2), Duration::from_millis(300));
        assert_eq!(policy.delay(3), Duration::from_millis(400));
        assert_eq!(policy.delay(4), Duration::from_millis(500)); // capped
    }

    #[test]
    fn linear_backoff_caps_at_max() {
        let policy = RetryPolicy::new(
            4,
            BackoffStrategy::Linear {
                base_ms: 200,
                max_ms: 500,
            },
        );
        assert_eq!(policy.delay(2), Duration::from_millis(500)); // 200*3 = 600, capped at 500
    }

    #[test]
    fn no_retry_policy() {
        let policy = RetryPolicy::no_retry();
        assert_eq!(policy.max_retries, 0);
    }

    // --- Total delay budgeting ---

    #[test]
    fn total_max_delay_fixed() {
        let policy = RetryPolicy::new(3, BackoffStrategy::Fixed { delay_ms: 100 });
        assert_eq!(policy.total_max_delay(), Duration::from_millis(300));
    }

    #[test]
    fn total_max_delay_exponential() {
        let policy = RetryPolicy::new(
            3,
            BackoffStrategy::Exponential {
                base_ms: 100,
                max_ms: 10000,
            },
        );
        // Delays: 100 + 200 + 400 = 700
        assert_eq!(policy.total_max_delay(), Duration::from_millis(700));
    }

    #[test]
    fn total_max_delay_zero_retries() {
        let policy = RetryPolicy::no_retry();
        assert_eq!(policy.total_max_delay(), Duration::ZERO);
    }

    // --- Arithmetic saturation (must never panic) ---

    #[test]
    fn exponential_backoff_overflow_saturates() {
        let policy = RetryPolicy::new(
            1,
            BackoffStrategy::Exponential {
                base_ms: u64::MAX / 2,
                max_ms: u64::MAX,
            },
        );
        // Should not panic on overflow
        let _ = policy.delay(30);
    }

    #[test]
    fn linear_backoff_overflow_saturates() {
        let policy = RetryPolicy::new(
            1,
            BackoffStrategy::Linear {
                base_ms: u64::MAX / 2,
                max_ms: u64::MAX,
            },
        );
        let _ = policy.delay(30);
    }

    #[test]
    fn retry_policy_clone_eq() {
        let policy = RetryPolicy::new(
            3,
            BackoffStrategy::Exponential {
                base_ms: 100,
                max_ms: 5000,
            },
        );
        let cloned = policy.clone();
        assert_eq!(policy, cloned);
    }

    // --- Task constructors (shape only; execution covered below) ---

    #[test]
    fn task_with_retry_succeeds_first_try() {
        #[derive(Debug, PartialEq)]
        enum Msg {
            Ok(i32),
            Err(String),
        }

        let policy = RetryPolicy::new(3, BackoffStrategy::Fixed { delay_ms: 1 });
        let cmd = task_with_retry(policy, || Ok(Msg::Ok(42)), Msg::Err);

        // Verify it produces a Task variant
        assert_eq!(cmd.type_name(), "Task");
    }

    #[test]
    fn task_with_timeout_produces_task() {
        #[derive(Debug)]
        #[allow(dead_code)]
        enum Msg {
            Result(i32),
            Timeout,
        }

        let cmd = task_with_timeout(
            Duration::from_secs(1),
            |_token| Msg::Result(42),
            Msg::Timeout,
        );
        assert_eq!(cmd.type_name(), "Task");
    }

    // --- Cancellation behavior on timeout (runs the task closures) ---

    #[test]
    fn task_with_timeout_requests_cancellation_on_timeout() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicBool, Ordering};

        #[derive(Debug, PartialEq)]
        enum Msg {
            Finished,
            Timeout,
        }

        let cancelled = Arc::new(AtomicBool::new(false));
        let worker_exited = Arc::new(AtomicBool::new(false));
        let cancelled_flag = Arc::clone(&cancelled);
        let exited_flag = Arc::clone(&worker_exited);

        let cmd = task_with_timeout(
            Duration::from_millis(10),
            move |token| {
                cancelled_flag.store(token.wait_timeout(Duration::from_secs(1)), Ordering::SeqCst);
                exited_flag.store(true, Ordering::SeqCst);
                Msg::Finished
            },
            Msg::Timeout,
        );

        let result = match cmd {
            Cmd::Task(_, task) => task(),
            other => panic!("expected Task, got {other:?}"),
        };

        assert_eq!(result, Msg::Timeout);
        std::thread::sleep(Duration::from_millis(50));
        assert!(cancelled.load(Ordering::SeqCst));
        assert!(worker_exited.load(Ordering::SeqCst));
    }

    #[test]
    fn task_with_retry_and_timeout_cancels_each_timed_out_attempt() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicUsize, Ordering};

        #[derive(Debug, PartialEq)]
        enum Msg {
            Exhausted(String),
        }

        fn on_exhaust(err: String) -> Msg {
            Msg::Exhausted(err)
        }

        let attempts = Arc::new(AtomicUsize::new(0));
        let cancelled = Arc::new(AtomicUsize::new(0));
        let attempts_flag = Arc::clone(&attempts);
        let cancelled_flag = Arc::clone(&cancelled);
        let policy = RetryPolicy::new(1, BackoffStrategy::Fixed { delay_ms: 0 });

        let cmd = task_with_retry_and_timeout(
            policy,
            Duration::from_millis(10),
            move |token| {
                attempts_flag.fetch_add(1, Ordering::SeqCst);
                if token.wait_timeout(Duration::from_secs(1)) {
                    cancelled_flag.fetch_add(1, Ordering::SeqCst);
                }
                Err("cancelled".to_owned())
            },
            on_exhaust,
        );

        let result = match cmd {
            Cmd::Task(_, task) => task(),
            other => panic!("expected Task, got {other:?}"),
        };

        assert_eq!(result, Msg::Exhausted("timeout".to_owned()));
        std::thread::sleep(Duration::from_millis(50));
        assert_eq!(attempts.load(Ordering::SeqCst), 2);
        assert_eq!(cancelled.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn backoff_strategy_variants_debug() {
        let fixed = BackoffStrategy::Fixed { delay_ms: 100 };
        let exp = BackoffStrategy::Exponential {
            base_ms: 100,
            max_ms: 5000,
        };
        let linear = BackoffStrategy::Linear {
            base_ms: 100,
            max_ms: 500,
        };
        // Just verify Debug doesn't panic
        let _ = format!("{fixed:?}");
        let _ = format!("{exp:?}");
        let _ = format!("{linear:?}");
    }
}
559}