Skip to main content

vtcode_commons/
retry.rs

1//! Canonical retry policy shared across the workspace.
2//!
3//! This module owns the retry *policy math*: attempt budgets, exponential
4//! backoff with an optional deterministic jitter, and category-based retry
5//! decisions built on [`ErrorCategory::is_retryable`]. Domain-specific
6//! adapters (typed error downcasts, tool-aware timeout rules, LLM
7//! `Retry-After` extraction) live in `vtcode-core::retry` as an extension
8//! trait over this policy.
9//!
10//! Wire-level HTTP clients that only need "should I retry this call?" use
11//! [`RetryPolicy::classify_anyhow`] / [`RetryPolicy::classify_status`];
12//! richer loops use [`RetryPolicy::decision_for_category`].
13
14use std::time::Duration;
15
16use crate::error_category::{ErrorCategory, classify_anyhow_error};
17
/// Typed retry policy shared across runtime layers.
///
/// All fields are public, so a value built by struct literal or through the
/// derived `Deserialize` impl does not pass through the clamping performed
/// by [`RetryPolicy::new`].
#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct RetryPolicy {
    /// Maximum number of total attempts, including the initial call.
    pub max_attempts: u32,
    /// Backoff delay for attempt index 0, before any multiplier growth.
    pub initial_delay: Duration,
    /// Upper bound applied to the computed backoff before jitter is added.
    pub max_delay: Duration,
    /// Growth factor: the delay scales by `multiplier` raised to the attempt index.
    pub multiplier: f64,
    /// Jitter as a fraction of the capped delay; non-finite or non-positive
    /// values disable jitter.
    pub jitter: f64,
}
28
29impl RetryPolicy {
30    pub fn new(
31        max_attempts: u32,
32        initial_delay: Duration,
33        max_delay: Duration,
34        multiplier: f64,
35    ) -> Self {
36        Self {
37            max_attempts: max_attempts.max(1),
38            initial_delay,
39            max_delay,
40            multiplier: multiplier.max(1.0),
41            jitter: 0.0,
42        }
43    }
44
45    pub fn from_retries(
46        max_retries: u32,
47        initial_delay: Duration,
48        max_delay: Duration,
49        multiplier: f64,
50    ) -> Self {
51        Self::new(
52            max_retries.saturating_add(1),
53            initial_delay,
54            max_delay,
55            multiplier,
56        )
57    }
58
59    /// Millisecond-based constructor for wire clients.
60    ///
61    /// Uses a 2.0 multiplier and no jitter, so
62    /// [`Self::delay_for_attempt`] reproduces the classic
63    /// `base_ms << attempt` doubling curve capped at `max_delay_ms`.
64    pub fn simple(max_retries: u32, base_delay_ms: u64, max_delay_ms: u64) -> Self {
65        Self::from_retries(
66            max_retries,
67            Duration::from_millis(base_delay_ms),
68            Duration::from_millis(max_delay_ms),
69            2.0,
70        )
71    }
72
73    pub fn delay_for_attempt(&self, attempt_index: u32) -> Duration {
74        let multiplier = self.multiplier.powi(attempt_index as i32);
75        let base_delay = Duration::try_from_secs_f64(self.initial_delay.as_secs_f64() * multiplier)
76            .unwrap_or(self.max_delay)
77            .min(self.max_delay);
78
79        if !self.jitter.is_finite() || self.jitter <= 0.0 {
80            return base_delay;
81        }
82
83        #[allow(clippy::cast_sign_loss)]
84        let max_jitter_ms = (base_delay.as_millis() as f64 * self.jitter)
85            .round()
86            .clamp(0.0, u64::MAX as f64) as u64;
87        if max_jitter_ms == 0 {
88            return base_delay;
89        }
90
91        let offset = (u64::from(attempt_index) * 31) % max_jitter_ms.saturating_add(1);
92        base_delay.saturating_add(Duration::from_millis(offset))
93    }
94
95    pub fn decision_for_category(
96        &self,
97        category: ErrorCategory,
98        attempt_index: u32,
99        retry_after: Option<Duration>,
100    ) -> RetryDecision {
101        let has_remaining_attempts = attempt_index.saturating_add(1) < self.max_attempts;
102        if !category.is_retryable() || !has_remaining_attempts {
103            return RetryDecision {
104                category,
105                retryable: false,
106                delay: None,
107                retry_after,
108            };
109        }
110
111        let delay = retry_after.unwrap_or_else(|| self.delay_for_attempt(attempt_index));
112        RetryDecision {
113            category,
114            retryable: true,
115            delay: Some(delay),
116            retry_after,
117        }
118    }
119
120    /// Classify an `anyhow::Error` for retry eligibility.
121    ///
122    /// Attempt-agnostic: `retryable` reflects only the error category, not
123    /// the remaining attempt budget. Wire clients that manage their own
124    /// attempt counting use this; loops that want budget-aware decisions
125    /// use [`Self::decision_for_category`].
126    pub fn classify_anyhow(&self, error: &anyhow::Error) -> RetryDecision {
127        let category = classify_anyhow_error(error);
128        RetryDecision {
129            category,
130            retryable: category.is_retryable(),
131            delay: None,
132            retry_after: None,
133        }
134    }
135
136    /// Classify an HTTP status code for retry eligibility.
137    ///
138    /// Attempt-agnostic, like [`Self::classify_anyhow`].
139    pub fn classify_status(&self, status: u16) -> RetryDecision {
140        let category = match status {
141            429 => ErrorCategory::RateLimit,
142            500 | 502 | 504 => ErrorCategory::Network,
143            503 => ErrorCategory::ServiceUnavailable,
144            401 | 403 => ErrorCategory::Authentication,
145            _ => ErrorCategory::ExecutionError,
146        };
147        RetryDecision {
148            category,
149            retryable: category.is_retryable(),
150            delay: None,
151            retry_after: None,
152        }
153    }
154}
155
156impl Default for RetryPolicy {
157    fn default() -> Self {
158        Self::from_retries(2, Duration::from_secs(1), Duration::from_secs(60), 2.0)
159    }
160}
161
/// Result of classifying a failure for retry handling.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RetryDecision {
    /// Category the failure was classified into.
    pub category: ErrorCategory,
    /// Whether the caller should retry. Budget-aware when produced by
    /// `RetryPolicy::decision_for_category`; category-only when produced by
    /// the `classify_*` helpers.
    pub retryable: bool,
    /// Suggested wait before the next attempt. `None` when not retryable,
    /// and always `None` from the attempt-agnostic `classify_*` helpers.
    pub delay: Option<Duration>,
    /// Caller-supplied retry hint, echoed back unchanged by
    /// `RetryPolicy::decision_for_category`; `None` from the `classify_*` helpers.
    pub retry_after: Option<Duration>,
}
170
#[cfg(test)]
mod tests {
    use super::*;

    /// Policy with a 1s initial delay and an 8s cap, used by the backoff tests.
    fn backoff_policy(max_retries: u32, multiplier: f64) -> RetryPolicy {
        RetryPolicy::from_retries(
            max_retries,
            Duration::from_secs(1),
            Duration::from_secs(8),
            multiplier,
        )
    }

    #[test]
    fn default_policy_allows_two_retries() {
        let RetryPolicy {
            max_attempts,
            initial_delay,
            max_delay,
            ..
        } = RetryPolicy::default();

        assert_eq!(max_attempts, 3);
        assert_eq!(initial_delay, Duration::from_secs(1));
        assert_eq!(max_delay, Duration::from_secs(60));
    }

    #[test]
    fn classify_status_rate_limit() {
        let outcome = RetryPolicy::default().classify_status(429);
        assert!(outcome.retryable);
        assert_eq!(outcome.category, ErrorCategory::RateLimit);
    }

    #[test]
    fn classify_status_server_error() {
        let outcome = RetryPolicy::default().classify_status(503);
        assert!(outcome.retryable);
        assert_eq!(outcome.category, ErrorCategory::ServiceUnavailable);
    }

    #[test]
    fn classify_status_auth_not_retryable() {
        let outcome = RetryPolicy::default().classify_status(401);
        assert!(!outcome.retryable);
        assert_eq!(outcome.category, ErrorCategory::Authentication);
    }

    #[test]
    fn classify_anyhow_network_error() {
        let failure = anyhow::anyhow!("connection refused");
        let outcome = RetryPolicy::default().classify_anyhow(&failure);
        assert!(outcome.retryable);
    }

    #[test]
    fn simple_policy_matches_bit_shift_doubling() {
        // Wire clients historically computed `base_ms << attempt`, capped at
        // the maximum; the consolidated policy must reproduce that curve.
        let policy = RetryPolicy::simple(10, 1000, 5000);
        for attempt in 0..6u32 {
            let shift = attempt.min(16);
            let legacy_ms = 1000u64.saturating_mul(1u64 << shift).min(5000);
            assert_eq!(
                policy.delay_for_attempt(attempt),
                Duration::from_millis(legacy_ms),
                "delay mismatch at attempt {attempt}"
            );
        }
    }

    #[test]
    fn delay_for_attempt_clamps_overflowing_backoff_to_max_delay() {
        let runaway = backoff_policy(3, f64::MAX);

        assert_eq!(runaway.delay_for_attempt(2), Duration::from_secs(8));
    }

    #[test]
    fn delay_for_attempt_ignores_non_finite_jitter() {
        let policy = RetryPolicy {
            jitter: f64::INFINITY,
            ..backoff_policy(3, 2.0)
        };

        assert_eq!(policy.delay_for_attempt(1), Duration::from_secs(2));
    }

    #[test]
    fn delay_for_attempt_handles_huge_finite_jitter() {
        let policy = RetryPolicy {
            jitter: f64::MAX,
            ..backoff_policy(3, 2.0)
        };

        assert!(policy.delay_for_attempt(1) >= Duration::from_secs(2));
    }

    #[test]
    fn decision_for_category_respects_attempt_budget() {
        // One retry means two attempts: index 0 may retry, index 1 may not.
        let policy = backoff_policy(1, 2.0);

        let initial = policy.decision_for_category(ErrorCategory::Network, 0, None);
        assert!(initial.retryable);
        assert_eq!(initial.delay, Some(Duration::from_secs(1)));

        let spent = policy.decision_for_category(ErrorCategory::Network, 1, None);
        assert!(!spent.retryable);
        assert!(spent.delay.is_none());
    }

    #[test]
    fn decision_for_category_prefers_retry_after() {
        let hint = Duration::from_secs(7);
        let outcome =
            backoff_policy(3, 2.0).decision_for_category(ErrorCategory::RateLimit, 0, Some(hint));

        assert!(outcome.retryable);
        assert_eq!(outcome.delay, Some(Duration::from_secs(7)));
        assert_eq!(outcome.retry_after, Some(Duration::from_secs(7)));
    }
}