Skip to main content

anvil_ssh/
retry.rs

1// SPDX-License-Identifier: GPL-3.0-or-later
2// Rust guideline compliant 2026-03-30
3//! Connection retry, backoff, and timeouts (PRD §5.8.7, M18).
4//!
5//! Three pieces:
6//!
7//! 1. [`RetryPolicy`] — caller-tunable knobs: attempt count, base /
8//!    factor / cap on exponential backoff, max wall-clock window,
9//!    per-attempt connect timeout.  Mirrors OpenSSH's
10//!    `ConnectionAttempts` + `ConnectTimeout` semantics with a
11//!    Gitway-specific cap on total elapsed time.
12//! 2. [`classify`] / [`Disposition`] — FR-82's transient-vs-fatal
13//!    error classifier.  Network noise (ECONNREFUSED, ETIMEDOUT,
14//!    EHOSTUNREACH, DNS NXDOMAIN) is `Retry`; everything else
15//!    (auth failure, host-key mismatch, protocol error, signing
16//!    error) is `Fatal`.
17//! 3. [`run`] — the loop driver.  Calls the supplied async op,
18//!    sleeps with jittered exponential backoff between attempts,
19//!    captures a [`RetryAttempt`] history for FR-83's `--test
20//!    --json` envelope, and emits a `tracing::warn!` event at
21//!    [`crate::log::CAT_RETRY`] per failed attempt.
22//!
23//! ## Trust model
24//!
25//! `run` is timeout-agnostic — its job is the loop + classifier +
26//! jitter + history.  The per-attempt `tokio::time::timeout` wrap
27//! lives at the call site (currently `session.rs::connect`) so the
28//! same loop driver can be reused for non-network operations
29//! (agent reconnects, key-load retries) without forcing every
30//! caller to think about timeouts.
31//!
32//! ## Why russh-handshake failures are NOT retried
33//!
34//! Once the TCP socket is up, any failure is either a fatal
35//! user-input error (auth rejected, host-key mismatch) or an
36//! in-flight protocol error mid-handshake.  Re-driving an in-flight
37//! handshake is unsafe (the server may have already consumed our
38//! key-exchange contribution) and the failure modes are server-side
39//! — surfacing them clearly is more useful than silently retrying.
40//! [`classify`] returns `Fatal` for every `russh::Error` variant
41//! for this reason.
42
43use std::future::Future;
44use std::time::{Duration, Instant};
45
46use rand_core::{OsRng, RngCore};
47
48use crate::error::AnvilError;
49
50// ── RetryPolicy ─────────────────────────────────────────────────────────────
51
/// Tunable parameters for the retry loop (PRD §5.8.7 FR-80, FR-81).
///
/// [`Default`] yields the PRD §5.8.7 values: three attempts, a 250 ms
/// base delay doubled on every retry, an 8 s ceiling per backoff
/// interval, a 30 s overall window, and no per-attempt connect
/// timeout.  Every field has a same-named builder setter that consumes
/// and returns `Self`, so overrides can be chained.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RetryPolicy {
    /// How many times the operation may be tried in total, counting
    /// the initial attempt.  `1` means "never retry"; [`run`] treats
    /// anything below `1` as `1`.  Default `3`.
    pub attempts: u32,
    /// Delay applied before the first retry; also the quantity the
    /// jitter range (`[0, base / 2]`) is derived from.  Default 250 ms.
    pub base: Duration,
    /// Growth multiplier applied to the delay on each further retry.
    /// Default `2`.
    pub factor: u32,
    /// Upper bound on one backoff interval, before jitter is added.
    /// Default 8 s.
    pub cap: Duration,
    /// Upper bound on the wall-clock time the whole loop may consume.
    /// If the next backoff sleep would cross this bound the loop gives
    /// up instead of starting another attempt.  Default 30 s.
    pub max_window: Duration,
    /// Timeout for a single TCP connect attempt.  `None` disables the
    /// timeout, matching OpenSSH when `ConnectTimeout` is unset.
    /// Default `None`.
    pub connect_timeout: Option<Duration>,
}
79
80impl Default for RetryPolicy {
81    fn default() -> Self {
82        Self {
83            attempts: 3,
84            base: Duration::from_millis(250),
85            factor: 2,
86            cap: Duration::from_secs(8),
87            max_window: Duration::from_secs(30),
88            connect_timeout: None,
89        }
90    }
91}
92
93impl RetryPolicy {
94    /// Builder setter for [`Self::attempts`].
95    #[must_use]
96    pub fn attempts(mut self, n: u32) -> Self {
97        self.attempts = n;
98        self
99    }
100
101    /// Builder setter for [`Self::base`].
102    #[must_use]
103    pub fn base(mut self, d: Duration) -> Self {
104        self.base = d;
105        self
106    }
107
108    /// Builder setter for [`Self::factor`].
109    #[must_use]
110    pub fn factor(mut self, f: u32) -> Self {
111        self.factor = f;
112        self
113    }
114
115    /// Builder setter for [`Self::cap`].
116    #[must_use]
117    pub fn cap(mut self, d: Duration) -> Self {
118        self.cap = d;
119        self
120    }
121
122    /// Builder setter for [`Self::max_window`].
123    #[must_use]
124    pub fn max_window(mut self, d: Duration) -> Self {
125        self.max_window = d;
126        self
127    }
128
129    /// Builder setter for [`Self::connect_timeout`].
130    #[must_use]
131    pub fn connect_timeout(mut self, d: Option<Duration>) -> Self {
132        self.connect_timeout = d;
133        self
134    }
135}
136
137// ── Classifier (FR-82) ─────────────────────────────────────────────────────
138
/// Verdict produced by [`classify`]: tells [`run`] whether a failed
/// attempt is worth repeating.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Disposition {
    /// The failure looks transient; back off and try again.
    Retry,
    /// The failure will not go away on its own; stop and report it.
    Fatal,
}
147
148/// Classifies an [`AnvilError`] as transient or fatal per FR-82.
149///
150/// Transient (returns [`Disposition::Retry`]):
151///
152/// - I/O errors with `io::ErrorKind` ∈ {`ConnectionRefused`,
153///   `TimedOut`, `HostUnreachable`, `NetworkUnreachable`, `NotFound`
154///   (DNS NXDOMAIN on Linux), `AddrNotAvailable`}
155///
156/// Fatal (returns [`Disposition::Fatal`]):
157///
158/// - Authentication failure / host-key mismatch / no-key-found /
159///   invalid-config / signing / signature-invalid (user-input
160///   errors).
161/// - Russh protocol errors — re-driving an in-flight handshake is
162///   unsafe; see the module-level docs.
163/// - Other I/O kinds (e.g. `PermissionDenied`, `Interrupted`) —
164///   conservative default; these are unlikely to recover on retry.
165///
166/// **HTTP 429/503 detection** (PRD FR-82 also mentions these) is
167/// out of scope: Anvil speaks raw SSH; HTTP statuses only surface
168/// in `ProxyCommand` subprocess output, which Anvil doesn't parse.
169/// A future ProxyCommand-HTTP-CONNECT milestone may extend this
170/// classifier to handle them.
171#[must_use]
172pub fn classify(err: &AnvilError) -> Disposition {
173    if err.is_authentication_failed()
174        || err.is_host_key_mismatch()
175        || err.is_no_key_found()
176        || err.is_key_encrypted()
177    {
178        return Disposition::Fatal;
179    }
180
181    if err.is_io() {
182        if let Some(kind) = err.io_kind() {
183            return classify_io_kind(kind);
184        }
185    }
186
187    Disposition::Fatal
188}
189
190/// Inner classifier — split out so it's testable without
191/// constructing full `AnvilError`s.
192fn classify_io_kind(kind: std::io::ErrorKind) -> Disposition {
193    use std::io::ErrorKind as K;
194    match kind {
195        K::ConnectionRefused
196        | K::TimedOut
197        | K::HostUnreachable
198        | K::NetworkUnreachable
199        | K::NotFound
200        | K::AddrNotAvailable => Disposition::Retry,
201        _ => Disposition::Fatal,
202    }
203}
204
205// ── RetryAttempt history (FR-83) ───────────────────────────────────────────
206
/// Record of a single failed attempt, collected by [`run`] and exposed
/// through [`crate::session::AnvilSession::retry_history`] and the
/// `data.retry_attempts` field of `gitway --test --json`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RetryAttempt {
    /// Attempt number, counting from 1.  Successful attempts are never
    /// recorded — the history holds failures only.
    pub attempt: u32,
    /// The stable error code reported by [`AnvilError::error_code`]
    /// for the failure.
    pub reason: String,
    /// Wall-clock time between the start of the retry loop and the
    /// moment this attempt failed.
    pub elapsed: Duration,
}
222
223// ── Loop driver (FR-81 + FR-83) ────────────────────────────────────────────
224
225/// Drives the retry loop for the supplied async operation.
226///
227/// On success returns `Ok((value, history))` where `history` is the
228/// list of failed attempts (empty if the first try succeeded).  On
229/// terminal failure (fatal classification or attempt-count /
230/// max-window exhaustion) returns the most-recent error.
231///
232/// Sleep duration between attempt `n` and `n+1` is
233/// `min(base * factor^(n-1), cap) + uniform_jitter([0, base/2])`.
234/// Jitter is sourced from [`OsRng`] so concurrent processes recovering
235/// from a shared outage don't dogpile.
236///
237/// # Errors
238///
239/// Returns the underlying `AnvilError` from the last failed attempt
240/// when the loop exits without success.
241pub async fn run<F, Fut, T>(
242    policy: &RetryPolicy,
243    mut op: F,
244) -> Result<(T, Vec<RetryAttempt>), AnvilError>
245where
246    F: FnMut() -> Fut,
247    Fut: Future<Output = Result<T, AnvilError>>,
248{
249    let started_at = Instant::now();
250    let mut history: Vec<RetryAttempt> = Vec::new();
251    let attempts = policy.attempts.max(1);
252
253    for attempt in 1..=attempts {
254        if attempt > 1 {
255            // Sleep before retrying (jittered exponential backoff).
256            let delay = backoff_delay(policy, attempt - 1);
257            // Bail if max_window would be exceeded.
258            if started_at.elapsed() + delay > policy.max_window {
259                if let Some(last) = history.last() {
260                    tracing::warn!(
261                        target: crate::log::CAT_RETRY,
262                        attempt = last.attempt,
263                        reason = %last.reason,
264                        elapsed_ms = u64::try_from(last.elapsed.as_millis()).unwrap_or(u64::MAX),
265                        max_window_ms = u64::try_from(policy.max_window.as_millis()).unwrap_or(u64::MAX),
266                        "retry max_window exhausted; giving up",
267                    );
268                }
269                return Err(history_to_terminal_error(&history));
270            }
271            tokio::time::sleep(delay).await;
272        }
273
274        match op().await {
275            Ok(value) => return Ok((value, history)),
276            Err(e) => {
277                let reason = e.error_code().to_owned();
278                let elapsed = started_at.elapsed();
279                let disposition = classify(&e);
280
281                if disposition == Disposition::Fatal || attempt == attempts {
282                    // Record the terminal attempt before returning so
283                    // the caller can still see why we gave up.
284                    history.push(RetryAttempt {
285                        attempt,
286                        reason: reason.clone(),
287                        elapsed,
288                    });
289                    tracing::warn!(
290                        target: crate::log::CAT_RETRY,
291                        attempt,
292                        reason = %reason,
293                        elapsed_ms = u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX),
294                        disposition = if disposition == Disposition::Fatal { "fatal" } else { "exhausted" },
295                        "retry loop terminating",
296                    );
297                    return Err(e);
298                }
299
300                history.push(RetryAttempt {
301                    attempt,
302                    reason: reason.clone(),
303                    elapsed,
304                });
305                tracing::warn!(
306                    target: crate::log::CAT_RETRY,
307                    attempt,
308                    reason = %reason,
309                    elapsed_ms = u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX),
310                    "retrying after transient error",
311                );
312            }
313        }
314    }
315
316    // Unreachable in practice (the loop body returns on every path),
317    // but the type system can't prove it.
318    Err(history_to_terminal_error(&history))
319}
320
321/// Returns the backoff delay for the `step`-th retry (1-indexed).
322fn backoff_delay(policy: &RetryPolicy, step: u32) -> Duration {
323    let base_ms = u64::try_from(policy.base.as_millis()).unwrap_or(u64::MAX);
324    let exponent_ms = base_ms.saturating_mul(u64::from(policy.factor).saturating_pow(step - 1));
325    let cap_ms = u64::try_from(policy.cap.as_millis()).unwrap_or(u64::MAX);
326    let core_ms = exponent_ms.min(cap_ms);
327
328    // Jitter: uniform on [0, base / 2] to avoid dogpile.  Drawn from
329    // OsRng for cryptographic-grade unpredictability — overkill for
330    // backoff but cheap and consistent with the rest of the crate.
331    let jitter_max_ms = base_ms / 2;
332    let jitter_ms = if jitter_max_ms == 0 {
333        0
334    } else {
335        let mut buf = [0u8; 8];
336        OsRng.fill_bytes(&mut buf);
337        let raw = u64::from_le_bytes(buf);
338        raw % (jitter_max_ms + 1)
339    };
340
341    Duration::from_millis(core_ms.saturating_add(jitter_ms))
342}
343
344/// Synthesizes an `AnvilError` from an exhausted retry history when
345/// the loop bails on `max_window` before any op-call has the chance
346/// to fail in the current iteration.  The history's last entry is
347/// the actual cause — we surface that as an `invalid_config` since
348/// we don't have the original `AnvilError` instance to clone.
349fn history_to_terminal_error(history: &[RetryAttempt]) -> AnvilError {
350    let last = history.last().map_or("unknown", |a| a.reason.as_str());
351    AnvilError::invalid_config(format!(
352        "retry exhausted (max_window reached); last error: {last}"
353    ))
354}
355
/// Unit tests for the policy defaults, the FR-82 classifier, the loop
/// driver and the backoff curve.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an I/O-flavoured `AnvilError` of the given kind.
    fn io_error(kind: std::io::ErrorKind) -> AnvilError {
        AnvilError::new(crate::error::AnvilErrorKind::Io(std::io::Error::from(kind)))
    }

    // ── RetryPolicy defaults ───────────────────────────────────────────────

    #[test]
    fn default_policy_matches_prd() {
        let p = RetryPolicy::default();
        assert_eq!(p.attempts, 3);
        assert_eq!(p.base, Duration::from_millis(250));
        assert_eq!(p.factor, 2);
        assert_eq!(p.cap, Duration::from_secs(8));
        assert_eq!(p.max_window, Duration::from_secs(30));
        assert_eq!(p.connect_timeout, None);
    }

    #[test]
    fn builder_setters_are_chainable() {
        let p = RetryPolicy::default()
            .attempts(5)
            .base(Duration::from_millis(100))
            .factor(3)
            .cap(Duration::from_secs(2))
            .max_window(Duration::from_secs(10))
            .connect_timeout(Some(Duration::from_secs(5)));
        assert_eq!(p.attempts, 5);
        assert_eq!(p.base, Duration::from_millis(100));
        assert_eq!(p.factor, 3);
        assert_eq!(p.cap, Duration::from_secs(2));
        assert_eq!(p.max_window, Duration::from_secs(10));
        assert_eq!(p.connect_timeout, Some(Duration::from_secs(5)));
    }

    // ── Classifier matrix (FR-82) ─────────────────────────────────────────

    #[test]
    fn auth_failure_is_fatal() {
        let err = AnvilError::authentication_failed();
        assert_eq!(classify(&err), Disposition::Fatal);
    }

    #[test]
    fn host_key_mismatch_is_fatal() {
        let err = AnvilError::host_key_mismatch("SHA256:abc");
        assert_eq!(classify(&err), Disposition::Fatal);
    }

    #[test]
    fn no_key_found_is_fatal() {
        let err = AnvilError::no_key_found();
        assert_eq!(classify(&err), Disposition::Fatal);
    }

    #[test]
    fn io_connection_refused_is_retry() {
        assert_eq!(
            classify_io_kind(std::io::ErrorKind::ConnectionRefused),
            Disposition::Retry,
        );
    }

    #[test]
    fn io_timed_out_is_retry() {
        assert_eq!(
            classify_io_kind(std::io::ErrorKind::TimedOut),
            Disposition::Retry,
        );
    }

    #[test]
    fn io_not_found_is_retry_for_dns_nxdomain() {
        assert_eq!(
            classify_io_kind(std::io::ErrorKind::NotFound),
            Disposition::Retry,
        );
    }

    #[test]
    fn io_unreachable_and_addr_not_available_are_retry() {
        for kind in [
            std::io::ErrorKind::HostUnreachable,
            std::io::ErrorKind::NetworkUnreachable,
            std::io::ErrorKind::AddrNotAvailable,
        ] {
            assert_eq!(
                classify_io_kind(kind),
                Disposition::Retry,
                "{kind:?} must be transient",
            );
        }
    }

    #[test]
    fn io_permission_denied_is_fatal() {
        assert_eq!(
            classify_io_kind(std::io::ErrorKind::PermissionDenied),
            Disposition::Fatal,
        );
    }

    #[test]
    fn io_interrupted_is_fatal() {
        assert_eq!(
            classify_io_kind(std::io::ErrorKind::Interrupted),
            Disposition::Fatal,
        );
    }

    // ── Loop driver ───────────────────────────────────────────────────────

    #[tokio::test]
    async fn run_succeeds_on_first_try_with_empty_history() {
        let p = RetryPolicy::default().attempts(3);
        let (value, history) = run(&p, || async { Ok::<_, AnvilError>(42_u32) })
            .await
            .expect("must succeed");
        assert_eq!(value, 42);
        assert!(history.is_empty());
    }

    #[tokio::test]
    async fn run_bails_immediately_on_fatal() {
        let p = RetryPolicy::default().attempts(5);
        let (call_count, result) = run_count_calls(&p, |_n| {
            futures::future::ready::<Result<u32, AnvilError>>(Err(
                AnvilError::authentication_failed(),
            ))
        })
        .await;
        // Fatal error → 1 attempt, no retry, and the original error is
        // what the caller sees.
        assert_eq!(call_count, 1);
        let err = result.expect_err("fatal error must propagate");
        assert!(err.is_authentication_failed());
    }

    #[tokio::test]
    async fn run_retries_transient_errors_and_records_history() {
        let p = RetryPolicy::default()
            .attempts(3)
            .base(Duration::from_millis(1))
            .cap(Duration::from_millis(2))
            .max_window(Duration::from_secs(60));

        let calls = std::sync::atomic::AtomicU32::new(0);
        let result: Result<(u32, Vec<RetryAttempt>), AnvilError> = run(&p, || async {
            let n = calls.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            if n < 2 {
                Err(io_error(std::io::ErrorKind::ConnectionRefused))
            } else {
                Ok::<_, AnvilError>(99)
            }
        })
        .await;

        let (value, history) = result.expect("third attempt must succeed");
        assert_eq!(value, 99);
        // Two failures recorded, third succeeded.
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].attempt, 1);
        assert_eq!(history[1].attempt, 2);
        // Elapsed is measured from loop start, so it never goes backwards.
        assert!(history[0].elapsed <= history[1].elapsed);
        // I/O errors map to AnvilError::error_code() == "GENERAL_ERROR"
        // per the error-code table in error.rs.
        for entry in &history {
            assert_eq!(
                entry.reason, "GENERAL_ERROR",
                "expected GENERAL_ERROR (io variant), got: {}",
                entry.reason,
            );
        }
    }

    #[tokio::test]
    async fn run_attempts_caps_after_exhausting_count() {
        let p = RetryPolicy::default()
            .attempts(2)
            .base(Duration::from_millis(1))
            .cap(Duration::from_millis(1))
            .max_window(Duration::from_secs(60));

        let (call_count, result) = run_count_calls(&p, |_n| {
            futures::future::ready::<Result<u32, AnvilError>>(Err(io_error(
                std::io::ErrorKind::TimedOut,
            )))
        })
        .await;

        // Both attempts must run — and no more; result is the last error.
        assert_eq!(call_count, 2);
        let err = result.expect_err("must exhaust");
        assert!(err.is_io());
    }

    #[tokio::test]
    async fn run_treats_zero_attempts_as_one() {
        let p = RetryPolicy::default().attempts(0);
        let (call_count, result) = run_count_calls(&p, |_n| {
            futures::future::ready::<Result<u32, AnvilError>>(Ok(7))
        })
        .await;
        assert_eq!(call_count, 1);
        assert_eq!(result.expect("single attempt must succeed"), 7);
    }

    #[tokio::test]
    async fn run_stops_when_backoff_would_exceed_max_window() {
        // The first backoff is at least 50 ms, which already overshoots
        // the 10 ms window, so no second attempt may start.
        let p = RetryPolicy::default()
            .attempts(5)
            .base(Duration::from_millis(50))
            .cap(Duration::from_millis(50))
            .max_window(Duration::from_millis(10));

        let (call_count, result) = run_count_calls(&p, |_n| {
            futures::future::ready::<Result<u32, AnvilError>>(Err(io_error(
                std::io::ErrorKind::ConnectionRefused,
            )))
        })
        .await;

        assert_eq!(call_count, 1);
        assert!(result.is_err());
    }

    /// Helper: counts how many times `op` was called before `run`
    /// returned, regardless of success / failure.
    async fn run_count_calls<F, Fut>(
        policy: &RetryPolicy,
        mut op: F,
    ) -> (u32, Result<u32, AnvilError>)
    where
        F: FnMut(u32) -> Fut,
        Fut: Future<Output = Result<u32, AnvilError>>,
    {
        let calls = std::sync::atomic::AtomicU32::new(0);
        let result: Result<(u32, Vec<RetryAttempt>), AnvilError> = run(policy, || {
            let n = calls.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            op(n)
        })
        .await;
        let count = calls.load(std::sync::atomic::Ordering::SeqCst);
        let final_result = result.map(|(v, _)| v);
        (count, final_result)
    }

    // ── Backoff curve ──────────────────────────────────────────────────────

    #[test]
    fn backoff_delay_grows_exponentially_until_cap() {
        let p = RetryPolicy::default()
            .base(Duration::from_millis(10))
            .factor(2)
            .cap(Duration::from_millis(40));
        // Step 1: 10ms (+ jitter ≤ 5ms)
        // Step 2: 20ms (+ jitter ≤ 5ms)
        // Step 3: 40ms (+ jitter ≤ 5ms) — capped
        // Step 4: 40ms (+ jitter ≤ 5ms) — still capped
        let d1 = backoff_delay(&p, 1);
        let d2 = backoff_delay(&p, 2);
        let d3 = backoff_delay(&p, 3);
        let d4 = backoff_delay(&p, 4);
        assert!(d1.as_millis() >= 10 && d1.as_millis() <= 15);
        assert!(d2.as_millis() >= 20 && d2.as_millis() <= 25);
        assert!(d3.as_millis() >= 40 && d3.as_millis() <= 45);
        assert!(d4.as_millis() >= 40 && d4.as_millis() <= 45);
    }

    #[test]
    fn backoff_delay_saturates_instead_of_overflowing() {
        // factor^31 overflows u64 by a wide margin; the result must
        // clamp to the cap rather than panic or wrap.
        let p = RetryPolicy::default()
            .base(Duration::from_millis(10))
            .factor(u32::MAX)
            .cap(Duration::from_millis(40));
        let ms = backoff_delay(&p, 32).as_millis();
        assert!(
            (40..=45).contains(&ms),
            "delay {ms}ms outside [40,45]ms after saturation",
        );
    }

    #[test]
    fn backoff_delay_is_zero_for_zero_base() {
        // base = 0 ⇒ no core delay and an empty jitter range.
        let p = RetryPolicy::default().base(Duration::ZERO);
        assert_eq!(backoff_delay(&p, 1), Duration::ZERO);
        assert_eq!(backoff_delay(&p, 3), Duration::ZERO);
    }

    #[test]
    fn backoff_jitter_stays_within_documented_window() {
        // 1000 draws with base = 10ms, factor = 1, cap = 10ms ⇒ all
        // sleeps are exactly 10ms + jitter([0, 5ms]).
        let p = RetryPolicy::default()
            .base(Duration::from_millis(10))
            .factor(1)
            .cap(Duration::from_millis(10));
        for _ in 0..1000 {
            let d = backoff_delay(&p, 1);
            let ms = d.as_millis();
            assert!(
                (10..=15).contains(&ms),
                "delay {ms}ms outside [10,15]ms jitter window",
            );
        }
    }
}