Skip to main content

imp_core/
retry.rs

1use std::pin::Pin;
2use std::time::Duration;
3
4use futures::StreamExt;
5use imp_llm::{provider::RetryPolicy, StreamEvent};
6
7/// Determines whether an `imp_llm::Error` is transient and worth retrying.
8///
9/// Non-retryable errors (auth failures, bad requests) should propagate
10/// immediately — retrying them wastes time and may make things worse.
11pub fn is_retryable(err: &imp_llm::Error) -> bool {
12    match err {
13        // Rate limit — always retry; the provider says to wait.
14        imp_llm::Error::RateLimited { .. } => true,
15        // HTTP transport failures: check what kind of reqwest error it is.
16        imp_llm::Error::Http(e) => e.is_connect() || e.is_timeout() || e.is_request(),
17        // Stream errors are transient (connection reset, partial read, etc.).
18        imp_llm::Error::Stream(_) => true,
19        // Provider errors may carry an HTTP status in the message. Check for 5xx.
20        imp_llm::Error::Provider(msg) => {
21            msg.contains("HTTP 500")
22                || msg.contains("HTTP 502")
23                || msg.contains("HTTP 503")
24                || msg.contains("HTTP 529")
25        }
26        // Auth errors (401, 403) and bad request (400) are permanent.
27        imp_llm::Error::Auth(_) => false,
28        // Serialization, IO, context-too-long: not transient.
29        imp_llm::Error::Serialization(_)
30        | imp_llm::Error::Io(_)
31        | imp_llm::Error::ContextTooLong { .. } => false,
32    }
33}
34
35/// Compute how long to wait before a retry attempt.
36///
37/// Uses exponential backoff with random jitter in [0, base_delay / 2).
38/// If the error carries a `Retry-After` hint that is within `max_delay`,
39/// that takes precedence.
40pub fn backoff_delay(
41    attempt: u32,
42    policy: &RetryPolicy,
43    retry_after_secs: Option<u64>,
44) -> Option<Duration> {
45    // If the provider told us exactly when to retry, respect it — unless it
46    // exceeds our maximum, in which case we give up immediately.
47    if let Some(secs) = retry_after_secs {
48        let suggested = Duration::from_secs(secs);
49        if suggested > policy.max_delay {
50            return None; // signal: abort, don't retry
51        }
52        return Some(suggested);
53    }
54
55    // Exponential backoff: base * 2^attempt, capped at max_delay.
56    let base_ms = policy.base_delay.as_millis() as u64;
57    let exp_ms = base_ms.saturating_mul(1u64 << attempt.min(10));
58    let capped_ms = exp_ms.min(policy.max_delay.as_millis() as u64);
59
60    // Jitter: add up to 50% of the capped delay to spread retries.
61    // Use timestamp + attempt for cheap pseudo-randomness without a rand dependency.
62    let seed = std::time::SystemTime::now()
63        .duration_since(std::time::UNIX_EPOCH)
64        .unwrap_or_default()
65        .as_nanos() as u64
66        ^ (attempt as u64).wrapping_mul(0x517cc1b727220a95);
67    let jitter_ms = seed % (capped_ms / 2 + 1);
68
69    Some(Duration::from_millis(capped_ms + jitter_ms))
70}
71
72/// Stream an LLM call with automatic retry on transient startup errors.
73///
74/// This preserves true streaming semantics: successful events are forwarded to
75/// the caller as soon as they arrive.
76///
77/// Retry is only transparent before the first meaningful event is emitted.
78/// Leading `MessageStart` events are buffered so we can still retry if the
79/// connection dies before the first text delta / tool call / completed message.
80/// Once any non-`MessageStart` event is forwarded, further errors are surfaced
81/// directly instead of replaying the stream.
82pub fn stream_with_retry<F, S>(
83    mut make_stream: F,
84    policy: RetryPolicy,
85) -> Pin<Box<dyn futures_core::Stream<Item = imp_llm::Result<StreamEvent>> + Send>>
86where
87    F: FnMut() -> S + Send + 'static,
88    S: futures_core::Stream<Item = imp_llm::Result<StreamEvent>> + Unpin + Send + 'static,
89{
90    let (tx, rx) = futures::channel::mpsc::unbounded();
91
92    tokio::spawn(async move {
93        let mut attempt = 0u32;
94
95        'attempt: loop {
96            let mut stream = make_stream();
97            let mut buffered_starts: Vec<StreamEvent> = Vec::new();
98            let mut emitted_meaningful_event = false;
99
100            while let Some(item) = stream.next().await {
101                match item {
102                    Ok(event) => {
103                        if !emitted_meaningful_event
104                            && matches!(event, StreamEvent::MessageStart { .. })
105                        {
106                            buffered_starts.push(event);
107                            continue;
108                        }
109
110                        if !emitted_meaningful_event {
111                            emitted_meaningful_event = true;
112                            for buffered in buffered_starts.drain(..) {
113                                if tx.unbounded_send(Ok(buffered)).is_err() {
114                                    return;
115                                }
116                            }
117                        }
118
119                        if tx.unbounded_send(Ok(event)).is_err() {
120                            return;
121                        }
122                    }
123                    Err(err) => {
124                        let retry_after =
125                            if let imp_llm::Error::RateLimited { retry_after_secs } = &err {
126                                *retry_after_secs
127                            } else {
128                                None
129                            };
130
131                        if !emitted_meaningful_event
132                            && is_retryable(&err)
133                            && attempt < policy.max_retries
134                        {
135                            match backoff_delay(attempt, &policy, retry_after) {
136                                None => {
137                                    let _ = tx.unbounded_send(Err(err));
138                                    return;
139                                }
140                                Some(delay) => {
141                                    tokio::time::sleep(delay).await;
142                                    attempt += 1;
143                                    continue 'attempt;
144                                }
145                            }
146                        }
147
148                        let _ = tx.unbounded_send(Err(err));
149                        return;
150                    }
151                }
152            }
153
154            for buffered in buffered_starts {
155                if tx.unbounded_send(Ok(buffered)).is_err() {
156                    return;
157                }
158            }
159
160            return;
161        }
162    });
163
164    Box::pin(rx)
165}
166
#[cfg(test)]
mod tests {
    use super::*;
    use imp_llm::provider::RetryCondition;
    use std::sync::{Arc, Mutex};

    fn default_policy() -> RetryPolicy {
        RetryPolicy {
            max_retries: 3,
            base_delay: Duration::from_millis(10), // fast for tests
            max_delay: Duration::from_millis(100),
            retry_on: vec![
                RetryCondition::RateLimit,
                RetryCondition::ServerError,
                RetryCondition::Timeout,
                RetryCondition::ConnectionError,
            ],
        }
    }

    /// Policy with ~1ms backoff so stream tests that really retry stay fast.
    fn fast_policy(max_retries: u32) -> RetryPolicy {
        RetryPolicy {
            max_retries,
            base_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(50),
            retry_on: vec![RetryCondition::ServerError],
        }
    }

    /// Assert `d` lies in the inclusive millisecond range `[lo_ms, hi_ms]`.
    fn assert_between(d: Duration, lo_ms: u64, hi_ms: u64) {
        assert!(
            d >= Duration::from_millis(lo_ms) && d <= Duration::from_millis(hi_ms),
            "{:?} not in [{}ms, {}ms]",
            d,
            lo_ms,
            hi_ms
        );
    }

    // ── is_retryable ──────────────────────────────────────────────

    #[test]
    fn rate_limited_is_retryable() {
        let err = imp_llm::Error::RateLimited {
            retry_after_secs: Some(5),
        };
        assert!(is_retryable(&err));
    }

    #[test]
    fn stream_error_is_retryable() {
        let err = imp_llm::Error::Stream("connection reset".into());
        assert!(is_retryable(&err));
    }

    #[test]
    fn auth_error_is_not_retryable() {
        let err = imp_llm::Error::Auth("invalid key".into());
        assert!(!is_retryable(&err));
    }

    #[test]
    fn provider_5xx_is_retryable() {
        let err = imp_llm::Error::Provider("HTTP 503: overloaded".into());
        assert!(is_retryable(&err));
    }

    #[test]
    fn provider_502_and_529_are_retryable() {
        for msg in ["HTTP 502: bad gateway", "HTTP 529: overloaded"] {
            let err = imp_llm::Error::Provider(msg.into());
            assert!(is_retryable(&err), "{} should be retryable", msg);
        }
    }

    #[test]
    fn provider_4xx_is_not_retryable() {
        let err = imp_llm::Error::Provider("HTTP 400: bad request".into());
        assert!(!is_retryable(&err));
    }

    #[test]
    fn provider_401_is_not_retryable() {
        let err = imp_llm::Error::Provider("HTTP 401: unauthorized".into());
        assert!(!is_retryable(&err));
    }

    // ── backoff_delay ─────────────────────────────────────────────

    #[test]
    fn backoff_grows_exponentially() {
        let policy = default_policy(); // base 10ms, max 100ms
        // Jitter only ever *adds* up to half of the exponential term, so each
        // delay lies in [base * 2^n, 1.5 * base * 2^n] while below the cap.
        assert_between(backoff_delay(0, &policy, None).unwrap(), 10, 15);
        assert_between(backoff_delay(1, &policy, None).unwrap(), 20, 30);
        assert_between(backoff_delay(2, &policy, None).unwrap(), 40, 60);
    }

    #[test]
    fn backoff_capped_at_max_delay() {
        let policy = default_policy(); // base 10ms, max 100ms
        // Attempt 10 would be 10ms * 2^10 = 10_240ms; the exponential term is
        // capped at 100ms and jitter adds at most 50% of the cap on top.
        let delay = backoff_delay(10, &policy, None).unwrap();
        assert_between(delay, 100, 150);
    }

    #[test]
    fn backoff_huge_attempt_does_not_overflow() {
        let policy = default_policy(); // max 100ms
        let delay = backoff_delay(u32::MAX, &policy, None).unwrap();
        assert_between(delay, 100, 150);
    }

    #[test]
    fn retry_after_respected_within_limit() {
        let policy = default_policy(); // max 100ms
        let delay = backoff_delay(0, &policy, Some(0)).unwrap();
        assert_eq!(delay, Duration::from_secs(0));
    }

    #[test]
    fn retry_after_equal_to_max_delay_is_allowed() {
        let policy = RetryPolicy {
            max_delay: Duration::from_secs(1),
            ..default_policy()
        };
        assert_eq!(
            backoff_delay(0, &policy, Some(1)),
            Some(Duration::from_secs(1))
        );
    }

    #[test]
    fn retry_after_exceeds_max_delay_returns_none() {
        let policy = default_policy(); // max 100ms
        let result = backoff_delay(0, &policy, Some(10)); // 10s > 100ms
        assert!(result.is_none());
    }

    // ── stream_with_retry ────────────────────────────────────────

    #[tokio::test]
    async fn retry_succeeds_after_transient_failures_before_first_meaningful_event() {
        let call_count = Arc::new(Mutex::new(0u32));
        let call_count_clone = call_count.clone();

        let stream = stream_with_retry(
            move || {
                let attempt = {
                    let mut count = call_count_clone.lock().unwrap();
                    *count += 1;
                    *count
                };

                let events: Vec<imp_llm::Result<StreamEvent>> = if attempt < 3 {
                    vec![
                        Ok(StreamEvent::MessageStart {
                            model: "test".into(),
                        }),
                        Err(imp_llm::Error::Stream("transient".into())),
                    ]
                } else {
                    vec![
                        Ok(StreamEvent::MessageStart {
                            model: "test".into(),
                        }),
                        Ok(StreamEvent::TextDelta {
                            text: "hello".into(),
                        }),
                    ]
                };
                futures::stream::iter(events)
            },
            fast_policy(3),
        );

        let result: Vec<_> = stream.collect().await;

        assert_eq!(*call_count.lock().unwrap(), 3);
        // Buffered starts from the failed attempts must not leak through.
        assert_eq!(result.len(), 2);
        assert!(matches!(result[0], Ok(StreamEvent::MessageStart { .. })));
        assert!(matches!(result[1], Ok(StreamEvent::TextDelta { .. })));
    }

    #[tokio::test]
    async fn retry_exhausts_max_retries_before_first_meaningful_event() {
        let call_count = Arc::new(Mutex::new(0u32));
        let call_count_clone = call_count.clone();

        let stream = stream_with_retry(
            move || {
                *call_count_clone.lock().unwrap() += 1;
                let events: Vec<imp_llm::Result<StreamEvent>> =
                    vec![Err(imp_llm::Error::Stream("always fails".into()))];
                futures::stream::iter(events)
            },
            fast_policy(2),
        );

        let result: Vec<_> = stream.collect().await;

        // One initial attempt plus `max_retries` retries.
        assert_eq!(*call_count.lock().unwrap(), 3);
        assert_eq!(result.len(), 1);
        assert!(matches!(result[0], Err(imp_llm::Error::Stream(_))));
    }

    #[tokio::test]
    async fn retry_skips_non_retryable_errors() {
        let call_count = Arc::new(Mutex::new(0u32));
        let call_count_clone = call_count.clone();

        let stream = stream_with_retry(
            move || {
                *call_count_clone.lock().unwrap() += 1;
                let events: Vec<imp_llm::Result<StreamEvent>> =
                    vec![Err(imp_llm::Error::Auth("invalid key".into()))];
                futures::stream::iter(events)
            },
            fast_policy(3),
        );

        let result: Vec<_> = stream.collect().await;

        assert_eq!(*call_count.lock().unwrap(), 1);
        assert_eq!(result.len(), 1);
        assert!(matches!(result[0], Err(imp_llm::Error::Auth(_))));
    }

    #[tokio::test]
    async fn retry_after_hint_beyond_max_delay_surfaces_error() {
        let call_count = Arc::new(Mutex::new(0u32));
        let call_count_clone = call_count.clone();

        let stream = stream_with_retry(
            move || {
                *call_count_clone.lock().unwrap() += 1;
                let events: Vec<imp_llm::Result<StreamEvent>> =
                    vec![Err(imp_llm::Error::RateLimited {
                        retry_after_secs: Some(10), // 10s > 100ms max_delay
                    })];
                futures::stream::iter(events)
            },
            default_policy(),
        );

        let result: Vec<_> = stream.collect().await;

        assert_eq!(*call_count.lock().unwrap(), 1);
        assert_eq!(result.len(), 1);
        assert!(matches!(
            result[0],
            Err(imp_llm::Error::RateLimited { .. })
        ));
    }

    #[tokio::test]
    async fn retry_no_error_passes_through() {
        let stream = stream_with_retry(
            || {
                let events: Vec<imp_llm::Result<StreamEvent>> = vec![
                    Ok(StreamEvent::MessageStart {
                        model: "test".into(),
                    }),
                    Ok(StreamEvent::TextDelta { text: "ok".into() }),
                ];
                futures::stream::iter(events)
            },
            default_policy(),
        );

        let result: Vec<_> = stream.collect().await;

        assert_eq!(result.len(), 2);
    }

    #[tokio::test]
    async fn stream_with_only_message_start_flushes_buffer_on_clean_end() {
        let stream = stream_with_retry(
            || {
                let events: Vec<imp_llm::Result<StreamEvent>> =
                    vec![Ok(StreamEvent::MessageStart {
                        model: "test".into(),
                    })];
                futures::stream::iter(events)
            },
            default_policy(),
        );

        let result: Vec<_> = stream.collect().await;

        assert_eq!(result.len(), 1);
        assert!(matches!(result[0], Ok(StreamEvent::MessageStart { .. })));
    }

    #[tokio::test]
    async fn retry_does_not_replay_after_meaningful_event_has_streamed() {
        let call_count = Arc::new(Mutex::new(0u32));
        let call_count_clone = call_count.clone();

        let stream = stream_with_retry(
            move || {
                *call_count_clone.lock().unwrap() += 1;
                let events: Vec<imp_llm::Result<StreamEvent>> = vec![
                    Ok(StreamEvent::TextDelta {
                        text: "partial".into(),
                    }),
                    Err(imp_llm::Error::Stream("boom".into())),
                ];
                futures::stream::iter(events)
            },
            default_policy(),
        );

        let result: Vec<_> = stream.collect().await;

        assert_eq!(*call_count.lock().unwrap(), 1);
        assert_eq!(result.len(), 2);
        assert!(matches!(result[0], Ok(StreamEvent::TextDelta { .. })));
        assert!(matches!(result[1], Err(imp_llm::Error::Stream(_))));
    }
}