// enact_runner/retry.rs

//! Retry handler with exponential backoff and error classification.
//!
//! Classifies errors as retryable (rate limit, network, transient server
//! errors) or not retried (auth, invalid request, tool not found, and context
//! overflow, which needs compaction instead) and computes exponential backoff
//! delays for the retryable ones.
//! Ported from patterns in zeroclaw and openclaw.

7use crate::config::RetryConfig;
8use std::time::Duration;
9
/// Coarse category of a failure, used to decide whether retrying is worthwhile.
///
/// The retry decision for each variant lives in [`ErrorKind::is_retryable`];
/// the note on each variant states what that method returns for it.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ErrorKind {
    /// The provider throttled the request (retryable).
    RateLimited,
    /// The connection failed or timed out (retryable).
    NetworkError,
    /// The server reported a transient 5xx failure (retryable).
    ServerError,
    /// Authentication or authorization was rejected (not retryable).
    AuthError,
    /// The request itself was malformed or failed validation (not retryable).
    InvalidRequest,
    /// The requested tool does not exist (not retryable).
    ToolNotFound,
    /// The context window was exceeded; the caller should compact rather than
    /// retry (not retryable).
    ContextOverflow,
    /// Nothing recognisable matched; retryable, under the same retry limit as
    /// the other retryable kinds.
    Unknown,
}

impl ErrorKind {
    /// Returns `true` when an error of this kind may succeed if attempted again.
    ///
    /// Every variant is listed explicitly, with no catch-all arm, so adding a new
    /// kind is a compile error until its retry behavior is decided here.
    pub fn is_retryable(&self) -> bool {
        match self {
            ErrorKind::RateLimited
            | ErrorKind::NetworkError
            | ErrorKind::ServerError
            | ErrorKind::Unknown => true,
            ErrorKind::AuthError
            | ErrorKind::InvalidRequest
            | ErrorKind::ToolNotFound
            | ErrorKind::ContextOverflow => false,
        }
    }
}

44/// Classifies an error string into an `ErrorKind`.
45///
46/// Uses heuristic keyword matching on error messages.
47/// This is intentionally simple — real providers should return structured errors.
48pub fn classify_error(error_msg: &str) -> ErrorKind {
49    let lower = error_msg.to_lowercase();
50
51    if lower.contains("rate limit") || lower.contains("429") || lower.contains("too many requests")
52    {
53        ErrorKind::RateLimited
54    } else if lower.contains("connection")
55        || lower.contains("timeout")
56        || lower.contains("network")
57        || lower.contains("dns")
58    {
59        ErrorKind::NetworkError
60    } else if lower.contains("500")
61        || lower.contains("502")
62        || lower.contains("503")
63        || lower.contains("internal server error")
64        || lower.contains("service unavailable")
65    {
66        ErrorKind::ServerError
67    } else if lower.contains("auth")
68        || lower.contains("unauthorized")
69        || lower.contains("401")
70        || lower.contains("403")
71        || lower.contains("forbidden")
72        || lower.contains("invalid api key")
73    {
74        ErrorKind::AuthError
75    } else if lower.contains("context")
76        && (lower.contains("length") || lower.contains("window") || lower.contains("exceeded"))
77    {
78        ErrorKind::ContextOverflow
79    } else if lower.contains("not found") && lower.contains("tool") {
80        ErrorKind::ToolNotFound
81    } else if lower.contains("invalid")
82        || lower.contains("malformed")
83        || lower.contains("400")
84        || lower.contains("bad request")
85    {
86        ErrorKind::InvalidRequest
87    } else {
88        ErrorKind::Unknown
89    }
90}
91
92/// Manages retry state and computes backoff delays.
93pub struct RetryHandler {
94    config: RetryConfig,
95    attempt: u32,
96}
97
98impl RetryHandler {
99    /// Create a new retry handler from config.
100    pub fn new(config: RetryConfig) -> Self {
101        Self { config, attempt: 0 }
102    }
103
104    /// Check if we should retry the given error.
105    /// Returns `Some(delay)` if retryable, `None` if fatal or max retries exceeded.
106    pub fn should_retry(&mut self, error_msg: &str) -> Option<Duration> {
107        let kind = classify_error(error_msg);
108
109        if !kind.is_retryable() {
110            tracing::warn!(
111                error_kind = ?kind,
112                "Non-retryable error, failing immediately"
113            );
114            return None;
115        }
116
117        if self.attempt >= self.config.max_retries {
118            tracing::warn!(
119                attempt = self.attempt,
120                max = self.config.max_retries,
121                "Max retries exceeded"
122            );
123            return None;
124        }
125
126        let delay = self.next_delay();
127        self.attempt += 1;
128
129        tracing::info!(
130            attempt = self.attempt,
131            delay_ms = delay.as_millis() as u64,
132            error_kind = ?kind,
133            "Retrying after transient error"
134        );
135
136        Some(delay)
137    }
138
139    /// Compute the next backoff delay using exponential backoff.
140    fn next_delay(&self) -> Duration {
141        let base = self.config.initial_delay.as_millis() as f64;
142        let multiplied = base * self.config.backoff_multiplier.powi(self.attempt as i32);
143        let capped = multiplied.min(self.config.max_delay.as_millis() as f64);
144        Duration::from_millis(capped as u64)
145    }
146
147    /// Reset retry state (call after a successful operation).
148    pub fn reset(&mut self) {
149        self.attempt = 0;
150    }
151}
152
#[cfg(test)]
mod tests {
    // Unit tests for error classification and the retry/backoff state machine.
    use super::*;

    #[test]
    fn test_classify_rate_limit() {
        assert_eq!(
            classify_error("Rate limit exceeded"),
            ErrorKind::RateLimited
        );
        assert_eq!(
            classify_error("HTTP 429 Too Many Requests"),
            ErrorKind::RateLimited
        );
    }

    #[test]
    fn test_classify_network() {
        assert_eq!(
            classify_error("Connection refused"),
            ErrorKind::NetworkError
        );
        assert_eq!(
            classify_error("Request timeout after 30s"),
            ErrorKind::NetworkError
        );
    }

    #[test]
    fn test_classify_server_error() {
        assert_eq!(
            classify_error("HTTP 500 Internal Server Error"),
            ErrorKind::ServerError
        );
        assert_eq!(
            classify_error("HTTP 502 Bad Gateway"),
            ErrorKind::ServerError
        );
        assert_eq!(
            classify_error("503 Service Unavailable"),
            ErrorKind::ServerError
        );
    }

    #[test]
    fn test_classify_auth() {
        assert_eq!(classify_error("401 Unauthorized"), ErrorKind::AuthError);
        assert_eq!(
            classify_error("Invalid API key provided"),
            ErrorKind::AuthError
        );
    }

    #[test]
    fn test_classify_context_overflow() {
        assert_eq!(
            classify_error("Context length exceeded: 128000 tokens"),
            ErrorKind::ContextOverflow
        );
    }

    #[test]
    fn test_classify_tool_not_found() {
        assert_eq!(
            classify_error("Tool 'web_search' not found"),
            ErrorKind::ToolNotFound
        );
    }

    #[test]
    fn test_classify_invalid_request() {
        assert_eq!(
            classify_error("400 Bad Request: malformed JSON"),
            ErrorKind::InvalidRequest
        );
        // Mentions "tool" but not "not found", so it falls through to "invalid".
        assert_eq!(
            classify_error("Invalid tool arguments: missing field 'query'"),
            ErrorKind::InvalidRequest
        );
    }

    #[test]
    fn test_classify_unknown() {
        assert_eq!(
            classify_error("Something weird happened"),
            ErrorKind::Unknown
        );
    }

    #[test]
    fn test_retryable() {
        assert!(ErrorKind::RateLimited.is_retryable());
        assert!(ErrorKind::NetworkError.is_retryable());
        assert!(ErrorKind::ServerError.is_retryable());
        assert!(ErrorKind::Unknown.is_retryable());
        assert!(!ErrorKind::AuthError.is_retryable());
        assert!(!ErrorKind::InvalidRequest.is_retryable());
        assert!(!ErrorKind::ToolNotFound.is_retryable());
        assert!(!ErrorKind::ContextOverflow.is_retryable());
    }

    #[test]
    fn test_retry_handler_backoff() {
        let config = RetryConfig {
            max_retries: 3,
            initial_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(30),
            backoff_multiplier: 2.0,
        };
        let mut handler = RetryHandler::new(config);

        // First retry: 1s
        let delay = handler.should_retry("rate limit exceeded").unwrap();
        assert_eq!(delay, Duration::from_secs(1));

        // Second retry: 2s
        let delay = handler.should_retry("rate limit exceeded").unwrap();
        assert_eq!(delay, Duration::from_secs(2));

        // Third retry: 4s
        let delay = handler.should_retry("rate limit exceeded").unwrap();
        assert_eq!(delay, Duration::from_secs(4));

        // Fourth: exceeded max_retries=3
        assert!(handler.should_retry("rate limit exceeded").is_none());
    }

    #[test]
    fn test_retry_handler_caps_at_max_delay() {
        let mut handler = RetryHandler::new(RetryConfig {
            max_retries: 3,
            initial_delay: Duration::from_secs(10),
            max_delay: Duration::from_secs(15),
            backoff_multiplier: 2.0,
        });

        // 10s, then 20s and 40s are both clamped to the 15s ceiling.
        let msg = "503 Service Unavailable";
        assert_eq!(handler.should_retry(msg), Some(Duration::from_secs(10)));
        assert_eq!(handler.should_retry(msg), Some(Duration::from_secs(15)));
        assert_eq!(handler.should_retry(msg), Some(Duration::from_secs(15)));
        assert_eq!(handler.should_retry(msg), None);
    }

    #[test]
    fn test_retry_handler_zero_retries() {
        let mut handler = RetryHandler::new(RetryConfig {
            max_retries: 0,
            ..Default::default()
        });
        assert!(handler.should_retry("Connection refused").is_none());
    }

    #[test]
    fn test_retry_handler_non_retryable() {
        let mut handler = RetryHandler::new(RetryConfig::default());
        assert!(handler.should_retry("401 Unauthorized").is_none());
        assert!(handler
            .should_retry("Context length exceeded: 128000 tokens")
            .is_none());
    }

    #[test]
    fn test_non_retryable_does_not_consume_budget() {
        let mut handler = RetryHandler::new(RetryConfig {
            max_retries: 1,
            ..Default::default()
        });

        assert!(handler.should_retry("401 Unauthorized").is_none());
        // The single retry is still available after the fatal error above.
        assert!(handler.should_retry("Connection refused").is_some());
        assert!(handler.should_retry("Connection refused").is_none());
    }

    #[test]
    fn test_retry_handler_reset() {
        let mut handler = RetryHandler::new(RetryConfig {
            max_retries: 1,
            ..Default::default()
        });

        handler.should_retry("rate limit").unwrap();
        assert!(handler.should_retry("rate limit").is_none());

        handler.reset();
        assert!(handler.should_retry("rate limit").is_some());
    }
}