Skip to main content

imp_core/
retry.rs

1use std::pin::Pin;
2use std::time::Duration;
3
4use futures::StreamExt;
5use imp_llm::{provider::RetryPolicy, StreamEvent};
6
7/// Determines whether an `imp_llm::Error` is transient and worth retrying.
8///
9/// Non-retryable errors (auth failures, bad requests) should propagate
10/// immediately — retrying them wastes time and may make things worse.
11pub fn is_retryable(err: &imp_llm::Error) -> bool {
12    match err {
13        // Rate limit — always retry; the provider says to wait.
14        imp_llm::Error::RateLimited { .. } => true,
15        // HTTP transport/body failures: check what kind of reqwest error it is.
16        // `bytes_stream()` surfaces truncated or malformed compressed response
17        // bodies as decode errors (for example: "error decoding response body").
18        // Those are usually transient provider/proxy failures and are safe to
19        // retry before any meaningful stream event has been emitted.
20        imp_llm::Error::Http(e) => {
21            e.is_connect() || e.is_timeout() || e.is_request() || e.is_decode() || e.is_body()
22        }
23        // Stream errors are transient (connection reset, partial read, etc.).
24        imp_llm::Error::Stream(_) => true,
25        // Provider errors may carry an HTTP status in the message. Check for 5xx.
26        imp_llm::Error::Provider(msg) => {
27            msg.contains("HTTP 500")
28                || msg.contains("HTTP 502")
29                || msg.contains("HTTP 503")
30                || msg.contains("HTTP 529")
31        }
32        // Auth errors (401, 403) and bad request (400) are permanent.
33        imp_llm::Error::Auth(_) => false,
34        // Serialization, IO, context-too-long: not transient.
35        imp_llm::Error::Serialization(_)
36        | imp_llm::Error::Io(_)
37        | imp_llm::Error::ContextTooLong { .. } => false,
38    }
39}
40
41/// Compute how long to wait before a retry attempt.
42///
43/// Uses exponential backoff with random jitter in [0, base_delay / 2).
44/// If the error carries a `Retry-After` hint that is within `max_delay`,
45/// that takes precedence.
46pub fn backoff_delay(
47    attempt: u32,
48    policy: &RetryPolicy,
49    retry_after_secs: Option<u64>,
50) -> Option<Duration> {
51    // If the provider told us exactly when to retry, respect it — unless it
52    // exceeds our maximum, in which case we give up immediately.
53    if let Some(secs) = retry_after_secs {
54        let suggested = Duration::from_secs(secs);
55        if suggested > policy.max_delay {
56            return None; // signal: abort, don't retry
57        }
58        return Some(suggested);
59    }
60
61    // Exponential backoff: base * 2^attempt, capped at max_delay.
62    let base_ms = policy.base_delay.as_millis() as u64;
63    let exp_ms = base_ms.saturating_mul(1u64 << attempt.min(10));
64    let capped_ms = exp_ms.min(policy.max_delay.as_millis() as u64);
65
66    // Jitter: add up to 50% of the capped delay to spread retries.
67    // Use timestamp + attempt for cheap pseudo-randomness without a rand dependency.
68    let seed = std::time::SystemTime::now()
69        .duration_since(std::time::UNIX_EPOCH)
70        .unwrap_or_default()
71        .as_nanos() as u64
72        ^ (attempt as u64).wrapping_mul(0x517cc1b727220a95);
73    let jitter_ms = seed % (capped_ms / 2 + 1);
74
75    Some(Duration::from_millis(capped_ms + jitter_ms))
76}
77
78/// Stream an LLM call with automatic retry on transient startup errors.
79///
80/// This preserves true streaming semantics: successful events are forwarded to
81/// the caller as soon as they arrive.
82///
83/// Retry is only transparent before the first meaningful event is emitted.
84/// Leading `MessageStart` events are buffered so we can still retry if the
85/// connection dies before the first text delta / tool call / completed message.
86/// Once any non-`MessageStart` event is forwarded, further errors are surfaced
87/// directly instead of replaying the stream.
88pub fn stream_with_retry<F, S>(
89    mut make_stream: F,
90    policy: RetryPolicy,
91) -> Pin<Box<dyn futures_core::Stream<Item = imp_llm::Result<StreamEvent>> + Send>>
92where
93    F: FnMut() -> S + Send + 'static,
94    S: futures_core::Stream<Item = imp_llm::Result<StreamEvent>> + Unpin + Send + 'static,
95{
96    let (tx, rx) = futures::channel::mpsc::unbounded();
97
98    tokio::spawn(async move {
99        let mut attempt = 0u32;
100
101        'attempt: loop {
102            let mut stream = make_stream();
103            let mut buffered_starts: Vec<StreamEvent> = Vec::new();
104            let mut emitted_meaningful_event = false;
105
106            while let Some(item) = stream.next().await {
107                match item {
108                    Ok(event) => {
109                        if !emitted_meaningful_event
110                            && matches!(event, StreamEvent::MessageStart { .. })
111                        {
112                            buffered_starts.push(event);
113                            continue;
114                        }
115
116                        if !emitted_meaningful_event {
117                            emitted_meaningful_event = true;
118                            for buffered in buffered_starts.drain(..) {
119                                if tx.unbounded_send(Ok(buffered)).is_err() {
120                                    return;
121                                }
122                            }
123                        }
124
125                        if tx.unbounded_send(Ok(event)).is_err() {
126                            return;
127                        }
128                    }
129                    Err(err) => {
130                        let retry_after =
131                            if let imp_llm::Error::RateLimited { retry_after_secs } = &err {
132                                *retry_after_secs
133                            } else {
134                                None
135                            };
136
137                        if !emitted_meaningful_event
138                            && is_retryable(&err)
139                            && attempt < policy.max_retries
140                        {
141                            match backoff_delay(attempt, &policy, retry_after) {
142                                None => {
143                                    let _ = tx.unbounded_send(Err(err));
144                                    return;
145                                }
146                                Some(delay) => {
147                                    tokio::time::sleep(delay).await;
148                                    attempt += 1;
149                                    continue 'attempt;
150                                }
151                            }
152                        }
153
154                        let _ = tx.unbounded_send(Err(err));
155                        return;
156                    }
157                }
158            }
159
160            for buffered in buffered_starts {
161                if tx.unbounded_send(Ok(buffered)).is_err() {
162                    return;
163                }
164            }
165
166            return;
167        }
168    });
169
170    Box::pin(rx)
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use imp_llm::provider::RetryCondition;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// Policy with millisecond-scale delays: base 10ms, max 100ms, 3 retries.
    fn default_policy() -> RetryPolicy {
        RetryPolicy {
            max_retries: 3,
            base_delay: Duration::from_millis(10), // fast for tests
            max_delay: Duration::from_millis(100),
            retry_on: vec![
                RetryCondition::RateLimit,
                RetryCondition::ServerError,
                RetryCondition::Timeout,
                RetryCondition::ConnectionError,
            ],
        }
    }

    /// Policy for the stream tests: 1ms base, 50ms max, configurable retry budget.
    fn fast_policy(max_retries: u32) -> RetryPolicy {
        RetryPolicy {
            max_retries,
            base_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(50),
            retry_on: vec![RetryCondition::ServerError],
        }
    }

    fn ms(n: u64) -> Duration {
        Duration::from_millis(n)
    }

    // ── is_retryable ──────────────────────────────────────────────

    #[test]
    fn rate_limited_is_retryable() {
        let err = imp_llm::Error::RateLimited {
            retry_after_secs: Some(5),
        };
        assert!(is_retryable(&err));
    }

    #[test]
    fn stream_error_is_retryable() {
        let err = imp_llm::Error::Stream("connection reset".into());
        assert!(is_retryable(&err));
    }

    #[tokio::test]
    async fn http_decode_error_is_retryable() {
        use tokio::io::{AsyncReadExt, AsyncWriteExt};

        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        // Promise 999 body bytes but send 5, then close: reqwest reports a
        // body/decode error when reading the truncated response.
        tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            let mut request_buf = [0u8; 1024];
            let _ = socket.read(&mut request_buf).await;
            socket
                .write_all(b"HTTP/1.1 200 OK\r\ncontent-length: 999\r\n\r\nnot-g")
                .await
                .unwrap();
        });

        let err = reqwest::get(format!("http://{addr}"))
            .await
            .unwrap()
            .bytes()
            .await
            .unwrap_err();

        assert!(err.is_decode() || err.is_body());
        assert!(is_retryable(&imp_llm::Error::Http(err)));
    }

    #[test]
    fn auth_error_is_not_retryable() {
        let err = imp_llm::Error::Auth("invalid key".into());
        assert!(!is_retryable(&err));
    }

    #[test]
    fn provider_5xx_is_retryable() {
        for msg in [
            "HTTP 500: internal error",
            "HTTP 502: bad gateway",
            "HTTP 503: overloaded",
            "HTTP 529: overloaded",
        ] {
            assert!(is_retryable(&imp_llm::Error::Provider(msg.into())), "{msg}");
        }
    }

    #[test]
    fn provider_4xx_is_not_retryable() {
        for msg in ["HTTP 400: bad request", "HTTP 401: unauthorized"] {
            assert!(!is_retryable(&imp_llm::Error::Provider(msg.into())), "{msg}");
        }
    }

    // ── backoff_delay ─────────────────────────────────────────────

    #[test]
    fn backoff_grows_exponentially() {
        let policy = default_policy(); // base 10ms, max 100ms
        // Below the cap, attempt n waits base * 2^n plus jitter in [0, half of that].
        let d0 = backoff_delay(0, &policy, None).unwrap();
        let d1 = backoff_delay(1, &policy, None).unwrap();
        let d2 = backoff_delay(2, &policy, None).unwrap();
        assert!((ms(10)..=ms(15)).contains(&d0), "{d0:?}");
        assert!((ms(20)..=ms(30)).contains(&d1), "{d1:?}");
        assert!((ms(40)..=ms(60)).contains(&d2), "{d2:?}");
    }

    #[test]
    fn backoff_capped_at_max_delay() {
        let policy = default_policy(); // max 100ms
        // Attempt 10 would be base(10ms) * 2^10 = 10_240ms → capped at 100ms,
        // then up to 50% of the cap is added as jitter.
        let delay = backoff_delay(10, &policy, None).unwrap();
        assert!((ms(100)..=ms(150)).contains(&delay), "{delay:?}");
    }

    #[test]
    fn backoff_huge_attempt_does_not_overflow() {
        let policy = default_policy();
        // The exponent is clamped, so even u32::MAX stays at the capped delay.
        let delay = backoff_delay(u32::MAX, &policy, None).unwrap();
        assert!((ms(100)..=ms(150)).contains(&delay), "{delay:?}");
    }

    #[test]
    fn retry_after_respected_within_limit() {
        let policy = default_policy(); // max 100ms
        let delay = backoff_delay(0, &policy, Some(0)).unwrap();
        assert_eq!(delay, Duration::from_secs(0));
    }

    #[test]
    fn retry_after_exceeds_max_delay_returns_none() {
        let policy = default_policy(); // max 100ms
        let result = backoff_delay(0, &policy, Some(10)); // 10s > 100ms
        assert!(result.is_none());
    }

    // ── stream_with_retry ────────────────────────────────────────

    #[tokio::test]
    async fn retry_succeeds_after_transient_failures_before_first_meaningful_event() {
        let calls = Arc::new(AtomicU32::new(0));
        let factory_calls = Arc::clone(&calls);

        let stream = stream_with_retry(
            move || {
                let attempt = factory_calls.fetch_add(1, Ordering::SeqCst) + 1;
                // The first two attempts die after MessageStart; the third succeeds.
                let events: Vec<imp_llm::Result<StreamEvent>> = if attempt < 3 {
                    vec![
                        Ok(StreamEvent::MessageStart {
                            model: "test".into(),
                        }),
                        Err(imp_llm::Error::Stream("transient".into())),
                    ]
                } else {
                    vec![
                        Ok(StreamEvent::MessageStart {
                            model: "test".into(),
                        }),
                        Ok(StreamEvent::TextDelta {
                            text: "hello".into(),
                        }),
                    ]
                };
                futures::stream::iter(events)
            },
            fast_policy(3),
        );

        let result: Vec<_> = stream.collect().await;

        assert_eq!(calls.load(Ordering::SeqCst), 3);
        assert_eq!(result.len(), 2);
        assert!(matches!(result[0], Ok(StreamEvent::MessageStart { .. })));
        assert!(matches!(result[1], Ok(StreamEvent::TextDelta { .. })));
    }

    #[tokio::test]
    async fn retry_exhausts_max_retries_before_first_meaningful_event() {
        let calls = Arc::new(AtomicU32::new(0));
        let factory_calls = Arc::clone(&calls);

        let stream = stream_with_retry(
            move || {
                factory_calls.fetch_add(1, Ordering::SeqCst);
                let events: Vec<imp_llm::Result<StreamEvent>> =
                    vec![Err(imp_llm::Error::Stream("always fails".into()))];
                futures::stream::iter(events)
            },
            fast_policy(2),
        );

        let result: Vec<_> = stream.collect().await;

        // One initial attempt plus two retries.
        assert_eq!(calls.load(Ordering::SeqCst), 3);
        assert_eq!(result.len(), 1);
        assert!(matches!(result[0], Err(imp_llm::Error::Stream(_))));
    }

    #[tokio::test]
    async fn retry_skips_non_retryable_errors() {
        let calls = Arc::new(AtomicU32::new(0));
        let factory_calls = Arc::clone(&calls);

        let stream = stream_with_retry(
            move || {
                factory_calls.fetch_add(1, Ordering::SeqCst);
                let events: Vec<imp_llm::Result<StreamEvent>> =
                    vec![Err(imp_llm::Error::Auth("invalid key".into()))];
                futures::stream::iter(events)
            },
            fast_policy(3),
        );

        let result: Vec<_> = stream.collect().await;

        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert_eq!(result.len(), 1);
        assert!(matches!(result[0], Err(imp_llm::Error::Auth(_))));
    }

    #[tokio::test]
    async fn retry_aborts_when_retry_after_exceeds_max_delay() {
        let calls = Arc::new(AtomicU32::new(0));
        let factory_calls = Arc::clone(&calls);

        let stream = stream_with_retry(
            move || {
                factory_calls.fetch_add(1, Ordering::SeqCst);
                // 10s Retry-After is far beyond the 50ms max_delay.
                let events: Vec<imp_llm::Result<StreamEvent>> =
                    vec![Err(imp_llm::Error::RateLimited {
                        retry_after_secs: Some(10),
                    })];
                futures::stream::iter(events)
            },
            fast_policy(3),
        );

        let result: Vec<_> = stream.collect().await;

        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert_eq!(result.len(), 1);
        assert!(matches!(result[0], Err(imp_llm::Error::RateLimited { .. })));
    }

    #[tokio::test]
    async fn retry_no_error_passes_through() {
        let stream = stream_with_retry(
            || {
                let events: Vec<imp_llm::Result<StreamEvent>> = vec![
                    Ok(StreamEvent::MessageStart {
                        model: "test".into(),
                    }),
                    Ok(StreamEvent::TextDelta { text: "ok".into() }),
                ];
                futures::stream::iter(events)
            },
            default_policy(),
        );

        let result: Vec<_> = stream.collect().await;

        assert_eq!(result.len(), 2);
        assert!(matches!(result[0], Ok(StreamEvent::MessageStart { .. })));
        assert!(matches!(result[1], Ok(StreamEvent::TextDelta { .. })));
    }

    #[tokio::test]
    async fn retry_flushes_buffered_starts_when_stream_ends_without_meaningful_event() {
        let stream = stream_with_retry(
            || {
                let events: Vec<imp_llm::Result<StreamEvent>> = vec![Ok(
                    StreamEvent::MessageStart {
                        model: "test".into(),
                    },
                )];
                futures::stream::iter(events)
            },
            default_policy(),
        );

        let result: Vec<_> = stream.collect().await;

        assert_eq!(result.len(), 1);
        assert!(matches!(result[0], Ok(StreamEvent::MessageStart { .. })));
    }

    #[tokio::test]
    async fn retry_does_not_replay_after_meaningful_event_has_streamed() {
        let calls = Arc::new(AtomicU32::new(0));
        let factory_calls = Arc::clone(&calls);

        let stream = stream_with_retry(
            move || {
                factory_calls.fetch_add(1, Ordering::SeqCst);
                let events: Vec<imp_llm::Result<StreamEvent>> = vec![
                    Ok(StreamEvent::TextDelta {
                        text: "partial".into(),
                    }),
                    Err(imp_llm::Error::Stream("boom".into())),
                ];
                futures::stream::iter(events)
            },
            default_policy(),
        );

        let result: Vec<_> = stream.collect().await;

        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert_eq!(result.len(), 2);
        assert!(matches!(result[0], Ok(StreamEvent::TextDelta { .. })));
        assert!(matches!(result[1], Err(imp_llm::Error::Stream(_))));
    }
}