// sgr_agent/retry.rs

1//! RetryClient — wraps LlmClient with exponential backoff for transient errors.
2//!
3//! Retries on: rate limits (429), server errors (5xx), empty responses, network errors.
4//! Honors `retry_after_secs` from rate limit headers when available.
5
6use crate::client::LlmClient;
7use crate::tool::ToolDef;
8use crate::types::{Message, SgrError, ToolCall};
9use serde_json::Value;
10use std::time::Duration;
11
/// Retry configuration.
///
/// Controls the exponential-backoff schedule used by `RetryClient`:
/// the delay for attempt N is roughly `base_delay_ms * 2^N`, capped at
/// `max_delay_ms` (see `delay_for_attempt`).
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Max retry attempts (0 = no retries; the initial call always runs).
    pub max_retries: usize,
    /// Base delay in milliseconds (first backoff step).
    pub base_delay_ms: u64,
    /// Max delay cap in milliseconds (upper bound for any single backoff).
    pub max_delay_ms: u64,
}
22
23impl Default for RetryConfig {
24    fn default() -> Self {
25        Self {
26            max_retries: 3,
27            base_delay_ms: 500,
28            max_delay_ms: 30_000,
29        }
30    }
31}
32
33/// Determine if an error is retryable (transient: rate limit, timeout, server errors).
34pub fn is_retryable(err: &SgrError) -> bool {
35    match err {
36        SgrError::RateLimit { .. } => true,
37        SgrError::EmptyResponse => true,
38        // reqwest::Error — retryable if timeout or connect error
39        SgrError::Http(e) => e.is_timeout() || e.is_connect() || e.is_request(),
40        SgrError::Api { status, body } => {
41            *status >= 500
42                || *status == 408
43                || *status == 429
44                // Cloudflare Workers AI transient errors (bad JSON parse on their side)
45                || (*status == 400 && body.contains("AiError"))
46        }
47        // Empty response wrapped as Schema error — transient model behavior
48        SgrError::Schema(msg) => msg.contains("Empty response"),
49        _ => false,
50    }
51}
52
53/// Calculate delay for attempt N, honoring rate limit headers.
54pub fn delay_for_attempt(attempt: usize, config: &RetryConfig, err: &SgrError) -> Duration {
55    // Honor retry-after header from rate limit
56    if let Some(info) = err.rate_limit_info()
57        && let Some(secs) = info.retry_after_secs
58    {
59        return Duration::from_secs(secs + 1); // +1s safety margin
60    }
61
62    // Exponential backoff: base * 2^attempt, capped at max
63    let delay_ms = (config.base_delay_ms * (1 << attempt)).min(config.max_delay_ms);
64    // Add jitter ±10%
65    let jitter = (delay_ms as f64 * 0.1 * (attempt as f64 % 2.0 - 0.5)) as u64;
66    Duration::from_millis(delay_ms.saturating_add(jitter))
67}
68
/// LLM client wrapper with automatic retry on transient errors.
///
/// Delegates every call to `inner`; when the result is an error that
/// `is_retryable` classifies as transient, the call is re-issued up to
/// `config.max_retries` times with backoff from `delay_for_attempt`.
pub struct RetryClient<C: LlmClient> {
    // Wrapped client that performs the actual LLM calls.
    inner: C,
    // Backoff schedule (attempt budget, base delay, cap).
    config: RetryConfig,
}
74
75impl<C: LlmClient> RetryClient<C> {
76    pub fn new(inner: C) -> Self {
77        Self {
78            inner,
79            config: RetryConfig::default(),
80        }
81    }
82
83    pub fn with_config(mut self, config: RetryConfig) -> Self {
84        self.config = config;
85        self
86    }
87}
88
89#[async_trait::async_trait]
90impl<C: LlmClient> LlmClient for RetryClient<C> {
91    async fn structured_call(
92        &self,
93        messages: &[Message],
94        schema: &Value,
95    ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
96        let mut last_err = None;
97        for attempt in 0..=self.config.max_retries {
98            match self.inner.structured_call(messages, schema).await {
99                Ok(result) => return Ok(result),
100                Err(e) if is_retryable(&e) && attempt < self.config.max_retries => {
101                    let delay = delay_for_attempt(attempt, &self.config, &e);
102                    tracing::warn!(
103                        attempt = attempt + 1,
104                        max = self.config.max_retries,
105                        delay_ms = delay.as_millis() as u64,
106                        "Retrying structured_call: {}",
107                        e
108                    );
109                    tokio::time::sleep(delay).await;
110                    last_err = Some(e);
111                }
112                Err(e) => return Err(e),
113            }
114        }
115        Err(last_err.unwrap())
116    }
117
118    async fn tools_call(
119        &self,
120        messages: &[Message],
121        tools: &[ToolDef],
122    ) -> Result<Vec<ToolCall>, SgrError> {
123        let mut last_err = None;
124        for attempt in 0..=self.config.max_retries {
125            match self.inner.tools_call(messages, tools).await {
126                Ok(result) => return Ok(result),
127                Err(e) if is_retryable(&e) && attempt < self.config.max_retries => {
128                    let delay = delay_for_attempt(attempt, &self.config, &e);
129                    tracing::warn!(
130                        attempt = attempt + 1,
131                        max = self.config.max_retries,
132                        delay_ms = delay.as_millis() as u64,
133                        "Retrying tools_call: {}",
134                        e
135                    );
136                    tokio::time::sleep(delay).await;
137                    last_err = Some(e);
138                }
139                Err(e) => return Err(e),
140            }
141        }
142        Err(last_err.unwrap())
143    }
144
145    async fn complete(&self, messages: &[Message]) -> Result<String, SgrError> {
146        let mut last_err = None;
147        for attempt in 0..=self.config.max_retries {
148            match self.inner.complete(messages).await {
149                Ok(result) => return Ok(result),
150                Err(e) if is_retryable(&e) && attempt < self.config.max_retries => {
151                    let delay = delay_for_attempt(attempt, &self.config, &e);
152                    tracing::warn!(
153                        attempt = attempt + 1,
154                        max = self.config.max_retries,
155                        delay_ms = delay.as_millis() as u64,
156                        "Retrying complete: {}",
157                        e
158                    );
159                    tokio::time::sleep(delay).await;
160                    last_err = Some(e);
161                }
162                Err(e) => return Err(e),
163            }
164        }
165        Err(last_err.unwrap())
166    }
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test double: fails the first `fail_count` calls, then succeeds.
    /// `call_count` records total invocations (shared across methods).
    struct FailingClient {
        fail_count: usize,
        call_count: Arc<AtomicUsize>,
    }

    #[async_trait::async_trait]
    impl LlmClient for FailingClient {
        // Fails with EmptyResponse (retryable) until fail_count is exhausted.
        async fn structured_call(
            &self,
            _: &[Message],
            _: &Value,
        ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
            let n = self.call_count.fetch_add(1, Ordering::SeqCst);
            if n < self.fail_count {
                Err(SgrError::EmptyResponse)
            } else {
                Ok((None, vec![], "ok".into()))
            }
        }
        // Fails with a 500 Api error (retryable) until fail_count is exhausted.
        async fn tools_call(
            &self,
            _: &[Message],
            _: &[ToolDef],
        ) -> Result<Vec<ToolCall>, SgrError> {
            let n = self.call_count.fetch_add(1, Ordering::SeqCst);
            if n < self.fail_count {
                Err(SgrError::Api {
                    status: 500,
                    body: "internal error".into(),
                })
            } else {
                Ok(vec![])
            }
        }
        // Always succeeds; complete() retry path is not exercised here.
        async fn complete(&self, _: &[Message]) -> Result<String, SgrError> {
            Ok("ok".into())
        }
    }

    #[tokio::test]
    async fn retries_on_empty_response() {
        let count = Arc::new(AtomicUsize::new(0));
        // 2 failures with a budget of 3 retries: the call should succeed.
        let client = RetryClient::new(FailingClient {
            fail_count: 2,
            call_count: count.clone(),
        })
        .with_config(RetryConfig {
            max_retries: 3,
            base_delay_ms: 1,
            max_delay_ms: 10,
        });

        let result = client
            .structured_call(&[Message::user("hi")], &serde_json::json!({}))
            .await;
        assert!(result.is_ok());
        assert_eq!(count.load(Ordering::SeqCst), 3); // 2 fails + 1 success
    }

    #[tokio::test]
    async fn retries_on_server_error() {
        let count = Arc::new(AtomicUsize::new(0));
        // Single 500 failure, retry budget of 2: succeeds on second call.
        let client = RetryClient::new(FailingClient {
            fail_count: 1,
            call_count: count.clone(),
        })
        .with_config(RetryConfig {
            max_retries: 2,
            base_delay_ms: 1,
            max_delay_ms: 10,
        });

        let result = client.tools_call(&[Message::user("hi")], &[]).await;
        assert!(result.is_ok());
        assert_eq!(count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn fails_after_max_retries() {
        let count = Arc::new(AtomicUsize::new(0));
        // More failures (10) than budget allows: the last error is returned.
        let client = RetryClient::new(FailingClient {
            fail_count: 10,
            call_count: count.clone(),
        })
        .with_config(RetryConfig {
            max_retries: 2,
            base_delay_ms: 1,
            max_delay_ms: 10,
        });

        let result = client
            .structured_call(&[Message::user("hi")], &serde_json::json!({}))
            .await;
        assert!(result.is_err());
        assert_eq!(count.load(Ordering::SeqCst), 3); // 1 initial + 2 retries
    }

    /// Spot-checks is_retryable's classification of permanent vs transient.
    #[test]
    fn non_retryable_errors() {
        assert!(!is_retryable(&SgrError::Api {
            status: 400,
            body: "bad request".into()
        }));
        assert!(!is_retryable(&SgrError::Schema("parse".into())));
        // Schema errors are transient only when they wrap an empty response.
        assert!(is_retryable(&SgrError::Schema(
            "Empty response from model (parts: text)".into()
        )));
        assert!(is_retryable(&SgrError::EmptyResponse));
        assert!(is_retryable(&SgrError::Api {
            status: 503,
            body: "server error".into()
        }));
        assert!(is_retryable(&SgrError::Api {
            status: 429,
            body: "rate limit".into()
        }));
    }

    #[test]
    fn delay_exponential_backoff() {
        let config = RetryConfig {
            max_retries: 5,
            base_delay_ms: 100,
            max_delay_ms: 5000,
        };
        let err = SgrError::EmptyResponse;

        let d0 = delay_for_attempt(0, &config, &err);
        let d1 = delay_for_attempt(1, &config, &err);
        let d2 = delay_for_attempt(2, &config, &err);

        // Roughly 100ms, 200ms, 400ms (with jitter)
        assert!(d0.as_millis() <= 150);
        assert!(d1.as_millis() <= 250);
        assert!(d2.as_millis() <= 500);
    }

    #[test]
    fn delay_capped_at_max() {
        let config = RetryConfig {
            max_retries: 10,
            base_delay_ms: 1000,
            max_delay_ms: 5000,
        };
        let err = SgrError::EmptyResponse;

        // Attempt far past the cap threshold: delay stays near max_delay_ms.
        let d10 = delay_for_attempt(10, &config, &err);
        assert!(d10.as_millis() <= 5500); // max + jitter
    }
}