// sgr_agent/retry.rs
//! RetryClient — wraps LlmClient with exponential backoff for transient errors.
//!
//! Retries on: rate limits (429), server errors (5xx), empty responses, network errors.
//! Honors `retry_after_secs` from rate limit headers when available.

use crate::client::LlmClient;
use crate::tool::ToolDef;
use crate::types::{Message, SgrError, ToolCall};
use serde_json::Value;
use std::time::Duration;

/// Tunable knobs for the retry policy.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// How many times to retry after the initial attempt; 0 disables retries.
    pub max_retries: usize,
    /// Delay before the first retry, in milliseconds (doubled each attempt).
    pub base_delay_ms: u64,
    /// Upper bound on any single backoff delay, in milliseconds.
    pub max_delay_ms: u64,
}

23impl Default for RetryConfig {
24    fn default() -> Self {
25        Self {
26            max_retries: 3,
27            base_delay_ms: 500,
28            max_delay_ms: 30_000,
29        }
30    }
31}
32
33/// Determine if an error is retryable (transient: rate limit, timeout, server errors).
34pub fn is_retryable(err: &SgrError) -> bool {
35    match err {
36        SgrError::RateLimit { .. } => true,
37        SgrError::EmptyResponse => true,
38        // reqwest::Error — retryable if timeout or connect error
39        SgrError::Http(e) => e.is_timeout() || e.is_connect() || e.is_request(),
40        SgrError::Api { status, .. } => {
41            *status == 0 || *status >= 500 || *status == 408 || *status == 429
42        }
43        // Empty response wrapped as Schema error — transient model behavior
44        SgrError::Schema(msg) => msg.contains("Empty response"),
45        // MaxOutputTokens and PromptTooLong are NOT retryable at this level —
46        // they are handled by the agent loop with special recovery logic
47        SgrError::MaxOutputTokens { .. } | SgrError::PromptTooLong(_) => false,
48        _ => false,
49    }
50}
51
52/// Calculate delay for attempt N, honoring rate limit headers.
53pub fn delay_for_attempt(attempt: usize, config: &RetryConfig, err: &SgrError) -> Duration {
54    // Honor retry-after header from rate limit
55    if let Some(info) = err.rate_limit_info()
56        && let Some(secs) = info.retry_after_secs
57    {
58        return Duration::from_secs(secs + 1); // +1s safety margin
59    }
60
61    // Exponential backoff: base * 2^attempt, capped at max
62    let delay_ms = (config.base_delay_ms * (1 << attempt)).min(config.max_delay_ms);
63    // Add jitter ±10%
64    let jitter = (delay_ms as f64 * 0.1 * (attempt as f64 % 2.0 - 0.5)) as u64;
65    Duration::from_millis(delay_ms.saturating_add(jitter))
66}
67
/// LLM client wrapper with automatic retry on transient errors.
///
/// Delegates every `LlmClient` call to `inner`, re-issuing it according to
/// `config` when the failure is transient (see `is_retryable`).
pub struct RetryClient<C: LlmClient> {
    // The wrapped client that performs the actual LLM calls.
    inner: C,
    // Retry policy: attempt budget and backoff timing.
    config: RetryConfig,
}

74impl<C: LlmClient> RetryClient<C> {
75    pub fn new(inner: C) -> Self {
76        Self {
77            inner,
78            config: RetryConfig::default(),
79        }
80    }
81
82    pub fn with_config(mut self, config: RetryConfig) -> Self {
83        self.config = config;
84        self
85    }
86}
87
88#[async_trait::async_trait]
89impl<C: LlmClient> LlmClient for RetryClient<C> {
90    async fn structured_call(
91        &self,
92        messages: &[Message],
93        schema: &Value,
94    ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
95        let mut last_err = None;
96        for attempt in 0..=self.config.max_retries {
97            match self.inner.structured_call(messages, schema).await {
98                Ok(result) => return Ok(result),
99                Err(e) if is_retryable(&e) && attempt < self.config.max_retries => {
100                    let delay = delay_for_attempt(attempt, &self.config, &e);
101                    tracing::warn!(
102                        attempt = attempt + 1,
103                        max = self.config.max_retries,
104                        delay_ms = delay.as_millis() as u64,
105                        "Retrying structured_call: {}",
106                        e
107                    );
108                    tokio::time::sleep(delay).await;
109                    last_err = Some(e);
110                }
111                Err(e) => return Err(e),
112            }
113        }
114        Err(last_err.unwrap())
115    }
116
117    async fn tools_call(
118        &self,
119        messages: &[Message],
120        tools: &[ToolDef],
121    ) -> Result<Vec<ToolCall>, SgrError> {
122        let mut last_err = None;
123        for attempt in 0..=self.config.max_retries {
124            match self.inner.tools_call(messages, tools).await {
125                Ok(result) => return Ok(result),
126                Err(e) if is_retryable(&e) && attempt < self.config.max_retries => {
127                    let delay = delay_for_attempt(attempt, &self.config, &e);
128                    tracing::warn!(
129                        attempt = attempt + 1,
130                        max = self.config.max_retries,
131                        delay_ms = delay.as_millis() as u64,
132                        "Retrying tools_call: {}",
133                        e
134                    );
135                    tokio::time::sleep(delay).await;
136                    last_err = Some(e);
137                }
138                Err(e) => return Err(e),
139            }
140        }
141        Err(last_err.unwrap())
142    }
143
144    async fn complete(&self, messages: &[Message]) -> Result<String, SgrError> {
145        let mut last_err = None;
146        for attempt in 0..=self.config.max_retries {
147            match self.inner.complete(messages).await {
148                Ok(result) => return Ok(result),
149                Err(e) if is_retryable(&e) && attempt < self.config.max_retries => {
150                    let delay = delay_for_attempt(attempt, &self.config, &e);
151                    tracing::warn!(
152                        attempt = attempt + 1,
153                        max = self.config.max_retries,
154                        delay_ms = delay.as_millis() as u64,
155                        "Retrying complete: {}",
156                        e
157                    );
158                    tokio::time::sleep(delay).await;
159                    last_err = Some(e);
160                }
161                Err(e) => return Err(e),
162            }
163        }
164        Err(last_err.unwrap())
165    }
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test double: fails the first `fail_count` calls, then succeeds.
    struct FailingClient {
        fail_count: usize,
        call_count: Arc<AtomicUsize>,
    }

    #[async_trait::async_trait]
    impl LlmClient for FailingClient {
        async fn structured_call(
            &self,
            _: &[Message],
            _: &Value,
        ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
            let calls_so_far = self.call_count.fetch_add(1, Ordering::SeqCst);
            if calls_so_far < self.fail_count {
                return Err(SgrError::EmptyResponse);
            }
            Ok((None, vec![], "ok".into()))
        }

        async fn tools_call(
            &self,
            _: &[Message],
            _: &[ToolDef],
        ) -> Result<Vec<ToolCall>, SgrError> {
            let calls_so_far = self.call_count.fetch_add(1, Ordering::SeqCst);
            if calls_so_far < self.fail_count {
                return Err(SgrError::Api {
                    status: 500,
                    body: "internal error".into(),
                });
            }
            Ok(vec![])
        }

        async fn complete(&self, _: &[Message]) -> Result<String, SgrError> {
            Ok("ok".into())
        }
    }

    /// Build a FailingClient wrapped in a RetryClient with millisecond delays.
    fn retrying_client(
        fail_count: usize,
        max_retries: usize,
        counter: &Arc<AtomicUsize>,
    ) -> RetryClient<FailingClient> {
        RetryClient::new(FailingClient {
            fail_count,
            call_count: Arc::clone(counter),
        })
        .with_config(RetryConfig {
            max_retries,
            base_delay_ms: 1,
            max_delay_ms: 10,
        })
    }

    #[tokio::test]
    async fn retries_on_empty_response() {
        let count = Arc::new(AtomicUsize::new(0));
        let client = retrying_client(2, 3, &count);

        let result = client
            .structured_call(&[Message::user("hi")], &serde_json::json!({}))
            .await;

        assert!(result.is_ok());
        // Two failed calls plus the successful third.
        assert_eq!(count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn retries_on_server_error() {
        let count = Arc::new(AtomicUsize::new(0));
        let client = retrying_client(1, 2, &count);

        let result = client.tools_call(&[Message::user("hi")], &[]).await;

        assert!(result.is_ok());
        assert_eq!(count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn fails_after_max_retries() {
        let count = Arc::new(AtomicUsize::new(0));
        let client = retrying_client(10, 2, &count);

        let result = client
            .structured_call(&[Message::user("hi")], &serde_json::json!({}))
            .await;

        assert!(result.is_err());
        // One initial attempt plus two retries.
        assert_eq!(count.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn non_retryable_errors() {
        let bad_request = SgrError::Api {
            status: 400,
            body: "bad request".into(),
        };
        assert!(!is_retryable(&bad_request));
        assert!(!is_retryable(&SgrError::Schema("parse".into())));

        let empty_as_schema =
            SgrError::Schema("Empty response from model (parts: text)".into());
        assert!(is_retryable(&empty_as_schema));
        assert!(is_retryable(&SgrError::EmptyResponse));

        let unavailable = SgrError::Api {
            status: 503,
            body: "server error".into(),
        };
        assert!(is_retryable(&unavailable));

        let rate_limited = SgrError::Api {
            status: 429,
            body: "rate limit".into(),
        };
        assert!(is_retryable(&rate_limited));
    }

    #[test]
    fn delay_exponential_backoff() {
        let config = RetryConfig {
            max_retries: 5,
            base_delay_ms: 100,
            max_delay_ms: 5000,
        };
        let err = SgrError::EmptyResponse;

        // Roughly 100ms, 200ms, 400ms — bounds leave headroom for jitter.
        assert!(delay_for_attempt(0, &config, &err).as_millis() <= 150);
        assert!(delay_for_attempt(1, &config, &err).as_millis() <= 250);
        assert!(delay_for_attempt(2, &config, &err).as_millis() <= 500);
    }

    #[test]
    fn delay_capped_at_max() {
        let config = RetryConfig {
            max_retries: 10,
            base_delay_ms: 1000,
            max_delay_ms: 5000,
        };
        let err = SgrError::EmptyResponse;

        // Cap plus at most 10% jitter.
        assert!(delay_for_attempt(10, &config, &err).as_millis() <= 5500);
    }
}