1use std::pin::Pin;
2use std::time::Duration;
3
4use futures::StreamExt;
5use imp_llm::{provider::RetryPolicy, StreamEvent};
6
7pub fn is_retryable(err: &imp_llm::Error) -> bool {
12 match err {
13 imp_llm::Error::RateLimited { .. } => true,
15 imp_llm::Error::Http(e) => e.is_connect() || e.is_timeout() || e.is_request(),
17 imp_llm::Error::Stream(_) => true,
19 imp_llm::Error::Provider(msg) => {
21 msg.contains("HTTP 500")
22 || msg.contains("HTTP 502")
23 || msg.contains("HTTP 503")
24 || msg.contains("HTTP 529")
25 }
26 imp_llm::Error::Auth(_) => false,
28 imp_llm::Error::Serialization(_)
30 | imp_llm::Error::Io(_)
31 | imp_llm::Error::ContextTooLong { .. } => false,
32 }
33}
34
35pub fn backoff_delay(
41 attempt: u32,
42 policy: &RetryPolicy,
43 retry_after_secs: Option<u64>,
44) -> Option<Duration> {
45 if let Some(secs) = retry_after_secs {
48 let suggested = Duration::from_secs(secs);
49 if suggested > policy.max_delay {
50 return None; }
52 return Some(suggested);
53 }
54
55 let base_ms = policy.base_delay.as_millis() as u64;
57 let exp_ms = base_ms.saturating_mul(1u64 << attempt.min(10));
58 let capped_ms = exp_ms.min(policy.max_delay.as_millis() as u64);
59
60 let seed = std::time::SystemTime::now()
63 .duration_since(std::time::UNIX_EPOCH)
64 .unwrap_or_default()
65 .as_nanos() as u64
66 ^ (attempt as u64).wrapping_mul(0x517cc1b727220a95);
67 let jitter_ms = seed % (capped_ms / 2 + 1);
68
69 Some(Duration::from_millis(capped_ms + jitter_ms))
70}
71
72pub fn stream_with_retry<F, S>(
83 mut make_stream: F,
84 policy: RetryPolicy,
85) -> Pin<Box<dyn futures_core::Stream<Item = imp_llm::Result<StreamEvent>> + Send>>
86where
87 F: FnMut() -> S + Send + 'static,
88 S: futures_core::Stream<Item = imp_llm::Result<StreamEvent>> + Unpin + Send + 'static,
89{
90 let (tx, rx) = futures::channel::mpsc::unbounded();
91
92 tokio::spawn(async move {
93 let mut attempt = 0u32;
94
95 'attempt: loop {
96 let mut stream = make_stream();
97 let mut buffered_starts: Vec<StreamEvent> = Vec::new();
98 let mut emitted_meaningful_event = false;
99
100 while let Some(item) = stream.next().await {
101 match item {
102 Ok(event) => {
103 if !emitted_meaningful_event
104 && matches!(event, StreamEvent::MessageStart { .. })
105 {
106 buffered_starts.push(event);
107 continue;
108 }
109
110 if !emitted_meaningful_event {
111 emitted_meaningful_event = true;
112 for buffered in buffered_starts.drain(..) {
113 if tx.unbounded_send(Ok(buffered)).is_err() {
114 return;
115 }
116 }
117 }
118
119 if tx.unbounded_send(Ok(event)).is_err() {
120 return;
121 }
122 }
123 Err(err) => {
124 let retry_after =
125 if let imp_llm::Error::RateLimited { retry_after_secs } = &err {
126 *retry_after_secs
127 } else {
128 None
129 };
130
131 if !emitted_meaningful_event
132 && is_retryable(&err)
133 && attempt < policy.max_retries
134 {
135 match backoff_delay(attempt, &policy, retry_after) {
136 None => {
137 let _ = tx.unbounded_send(Err(err));
138 return;
139 }
140 Some(delay) => {
141 tokio::time::sleep(delay).await;
142 attempt += 1;
143 continue 'attempt;
144 }
145 }
146 }
147
148 let _ = tx.unbounded_send(Err(err));
149 return;
150 }
151 }
152 }
153
154 for buffered in buffered_starts {
155 if tx.unbounded_send(Ok(buffered)).is_err() {
156 return;
157 }
158 }
159
160 return;
161 }
162 });
163
164 Box::pin(rx)
165}
166
#[cfg(test)]
mod tests {
    //! Tests for error classification, backoff computation and the retrying stream wrapper.
    //! Stream tests use millisecond-scale delays so they finish quickly.

    use super::*;
    use imp_llm::provider::RetryCondition;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// Policy used by the backoff tests: 10ms base, 100ms cap, three retries.
    fn default_policy() -> RetryPolicy {
        RetryPolicy {
            max_retries: 3,
            base_delay: Duration::from_millis(10),
            max_delay: Duration::from_millis(100),
            retry_on: vec![
                RetryCondition::RateLimit,
                RetryCondition::ServerError,
                RetryCondition::Timeout,
                RetryCondition::ConnectionError,
            ],
        }
    }

    /// Policy with a 1ms base delay and 50ms cap for the stream tests.
    fn fast_policy(max_retries: u32) -> RetryPolicy {
        RetryPolicy {
            max_retries,
            base_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(50),
            retry_on: vec![RetryCondition::ServerError],
        }
    }

    /// Asserts `delay` lies in `[floor_ms, floor_ms + floor_ms / 2]`: the clamped exponential
    /// delay plus at most 50% jitter.
    fn assert_jittered(delay: Duration, floor_ms: u64) {
        let lo = Duration::from_millis(floor_ms);
        let hi = Duration::from_millis(floor_ms + floor_ms / 2);
        assert!(
            lo <= delay && delay <= hi,
            "delay {:?} outside [{:?}, {:?}]",
            delay,
            lo,
            hi
        );
    }

    fn message_start() -> imp_llm::Result<StreamEvent> {
        Ok(StreamEvent::MessageStart {
            model: "test".into(),
        })
    }

    fn text_delta(text: &'static str) -> imp_llm::Result<StreamEvent> {
        Ok(StreamEvent::TextDelta { text: text.into() })
    }

    #[test]
    fn rate_limited_is_retryable() {
        let err = imp_llm::Error::RateLimited {
            retry_after_secs: Some(5),
        };
        assert!(is_retryable(&err));
    }

    #[test]
    fn stream_error_is_retryable() {
        let err = imp_llm::Error::Stream("connection reset".into());
        assert!(is_retryable(&err));
    }

    #[test]
    fn auth_error_is_not_retryable() {
        let err = imp_llm::Error::Auth("invalid key".into());
        assert!(!is_retryable(&err));
    }

    #[test]
    fn provider_5xx_is_retryable() {
        let err = imp_llm::Error::Provider("HTTP 503: overloaded".into());
        assert!(is_retryable(&err));
    }

    #[test]
    fn provider_4xx_is_not_retryable() {
        let err = imp_llm::Error::Provider("HTTP 400: bad request".into());
        assert!(!is_retryable(&err));
    }

    #[test]
    fn provider_401_is_not_retryable() {
        let err = imp_llm::Error::Provider("HTTP 401: unauthorized".into());
        assert!(!is_retryable(&err));
    }

    #[test]
    fn backoff_grows_exponentially() {
        let policy = default_policy();
        // The 10ms base doubles per attempt while still below the 100ms cap.
        assert_jittered(backoff_delay(0, &policy, None).unwrap(), 10);
        assert_jittered(backoff_delay(1, &policy, None).unwrap(), 20);
        assert_jittered(backoff_delay(2, &policy, None).unwrap(), 40);
    }

    #[test]
    fn backoff_capped_at_max_delay() {
        let policy = default_policy();
        // 10ms * 2^10 is far above the cap; only jitter may push past max_delay.
        assert_jittered(backoff_delay(10, &policy, None).unwrap(), 100);
    }

    #[test]
    fn retry_after_respected_within_limit() {
        let policy = default_policy();
        assert_eq!(
            backoff_delay(0, &policy, Some(0)),
            Some(Duration::from_secs(0))
        );
    }

    #[test]
    fn retry_after_exceeds_max_delay_returns_none() {
        let policy = default_policy();
        assert!(backoff_delay(0, &policy, Some(10)).is_none());
    }

    #[tokio::test]
    async fn retry_succeeds_after_transient_failures_before_first_meaningful_event() {
        let calls = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&calls);

        let result: Vec<_> = stream_with_retry(
            move || {
                let attempt = counter.fetch_add(1, Ordering::SeqCst) + 1;
                let events: Vec<imp_llm::Result<StreamEvent>> = if attempt < 3 {
                    vec![
                        message_start(),
                        Err(imp_llm::Error::Stream("transient".into())),
                    ]
                } else {
                    vec![message_start(), text_delta("hello")]
                };
                futures::stream::iter(events)
            },
            fast_policy(3),
        )
        .collect()
        .await;

        assert_eq!(calls.load(Ordering::SeqCst), 3);
        assert_eq!(result.len(), 2);
        assert!(matches!(result[0], Ok(StreamEvent::MessageStart { .. })));
        assert!(matches!(result[1], Ok(StreamEvent::TextDelta { .. })));
    }

    #[tokio::test]
    async fn retry_exhausts_max_retries_before_first_meaningful_event() {
        let calls = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&calls);

        let result: Vec<_> = stream_with_retry(
            move || {
                counter.fetch_add(1, Ordering::SeqCst);
                let events: Vec<imp_llm::Result<StreamEvent>> =
                    vec![Err(imp_llm::Error::Stream("always fails".into()))];
                futures::stream::iter(events)
            },
            fast_policy(2),
        )
        .collect()
        .await;

        // One initial attempt plus two retries.
        assert_eq!(calls.load(Ordering::SeqCst), 3);
        assert_eq!(result.len(), 1);
        assert!(matches!(result[0], Err(imp_llm::Error::Stream(_))));
    }

    #[tokio::test]
    async fn retry_skips_non_retryable_errors() {
        let calls = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&calls);

        let result: Vec<_> = stream_with_retry(
            move || {
                counter.fetch_add(1, Ordering::SeqCst);
                let events: Vec<imp_llm::Result<StreamEvent>> =
                    vec![Err(imp_llm::Error::Auth("invalid key".into()))];
                futures::stream::iter(events)
            },
            fast_policy(3),
        )
        .collect()
        .await;

        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert_eq!(result.len(), 1);
        assert!(matches!(result[0], Err(imp_llm::Error::Auth(_))));
    }

    #[tokio::test]
    async fn retry_after_longer_than_max_delay_surfaces_error_immediately() {
        let calls = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&calls);

        let result: Vec<_> = stream_with_retry(
            move || {
                counter.fetch_add(1, Ordering::SeqCst);
                // 10s hint vs. a 50ms policy cap: backoff_delay declines to wait.
                let events: Vec<imp_llm::Result<StreamEvent>> =
                    vec![Err(imp_llm::Error::RateLimited {
                        retry_after_secs: Some(10),
                    })];
                futures::stream::iter(events)
            },
            fast_policy(3),
        )
        .collect()
        .await;

        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert_eq!(result.len(), 1);
        assert!(matches!(result[0], Err(imp_llm::Error::RateLimited { .. })));
    }

    #[tokio::test]
    async fn retry_no_error_passes_through() {
        let result: Vec<_> = stream_with_retry(
            || {
                let events: Vec<imp_llm::Result<StreamEvent>> =
                    vec![message_start(), text_delta("ok")];
                futures::stream::iter(events)
            },
            default_policy(),
        )
        .collect()
        .await;

        assert_eq!(result.len(), 2);
        assert!(matches!(result[0], Ok(StreamEvent::MessageStart { .. })));
        assert!(matches!(result[1], Ok(StreamEvent::TextDelta { .. })));
    }

    #[tokio::test]
    async fn message_start_only_stream_is_flushed_on_clean_end() {
        let result: Vec<_> = stream_with_retry(
            || {
                let events: Vec<imp_llm::Result<StreamEvent>> = vec![message_start()];
                futures::stream::iter(events)
            },
            default_policy(),
        )
        .collect()
        .await;

        assert_eq!(result.len(), 1);
        assert!(matches!(result[0], Ok(StreamEvent::MessageStart { .. })));
    }

    #[tokio::test]
    async fn retry_does_not_replay_after_meaningful_event_has_streamed() {
        let calls = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&calls);

        let result: Vec<_> = stream_with_retry(
            move || {
                counter.fetch_add(1, Ordering::SeqCst);
                let events: Vec<imp_llm::Result<StreamEvent>> = vec![
                    text_delta("partial"),
                    Err(imp_llm::Error::Stream("boom".into())),
                ];
                futures::stream::iter(events)
            },
            default_policy(),
        )
        .collect()
        .await;

        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert_eq!(result.len(), 2);
        assert!(matches!(result[0], Ok(StreamEvent::TextDelta { .. })));
        assert!(matches!(result[1], Err(imp_llm::Error::Stream(_))));
    }
}
447}