1use std::pin::Pin;
2use std::time::Duration;
3
4use futures::StreamExt;
5use imp_llm::{provider::RetryPolicy, StreamEvent};
6
7pub fn is_retryable(err: &imp_llm::Error) -> bool {
12 match err {
13 imp_llm::Error::RateLimited { .. } => true,
15 imp_llm::Error::Http(e) => {
21 e.is_connect() || e.is_timeout() || e.is_request() || e.is_decode() || e.is_body()
22 }
23 imp_llm::Error::Stream(_) => true,
25 imp_llm::Error::Provider(msg) => {
27 msg.contains("HTTP 500")
28 || msg.contains("HTTP 502")
29 || msg.contains("HTTP 503")
30 || msg.contains("HTTP 529")
31 }
32 imp_llm::Error::Auth(_) => false,
34 imp_llm::Error::Serialization(_)
36 | imp_llm::Error::Io(_)
37 | imp_llm::Error::ContextTooLong { .. } => false,
38 }
39}
40
41pub fn backoff_delay(
47 attempt: u32,
48 policy: &RetryPolicy,
49 retry_after_secs: Option<u64>,
50) -> Option<Duration> {
51 if let Some(secs) = retry_after_secs {
54 let suggested = Duration::from_secs(secs);
55 if suggested > policy.max_delay {
56 return None; }
58 return Some(suggested);
59 }
60
61 let base_ms = policy.base_delay.as_millis() as u64;
63 let exp_ms = base_ms.saturating_mul(1u64 << attempt.min(10));
64 let capped_ms = exp_ms.min(policy.max_delay.as_millis() as u64);
65
66 let seed = std::time::SystemTime::now()
69 .duration_since(std::time::UNIX_EPOCH)
70 .unwrap_or_default()
71 .as_nanos() as u64
72 ^ (attempt as u64).wrapping_mul(0x517cc1b727220a95);
73 let jitter_ms = seed % (capped_ms / 2 + 1);
74
75 Some(Duration::from_millis(capped_ms + jitter_ms))
76}
77
78pub fn stream_with_retry<F, S>(
89 mut make_stream: F,
90 policy: RetryPolicy,
91) -> Pin<Box<dyn futures_core::Stream<Item = imp_llm::Result<StreamEvent>> + Send>>
92where
93 F: FnMut() -> S + Send + 'static,
94 S: futures_core::Stream<Item = imp_llm::Result<StreamEvent>> + Unpin + Send + 'static,
95{
96 let (tx, rx) = futures::channel::mpsc::unbounded();
97
98 tokio::spawn(async move {
99 let mut attempt = 0u32;
100
101 'attempt: loop {
102 let mut stream = make_stream();
103 let mut buffered_starts: Vec<StreamEvent> = Vec::new();
104 let mut emitted_meaningful_event = false;
105
106 while let Some(item) = stream.next().await {
107 match item {
108 Ok(event) => {
109 if !emitted_meaningful_event
110 && matches!(event, StreamEvent::MessageStart { .. })
111 {
112 buffered_starts.push(event);
113 continue;
114 }
115
116 if !emitted_meaningful_event {
117 emitted_meaningful_event = true;
118 for buffered in buffered_starts.drain(..) {
119 if tx.unbounded_send(Ok(buffered)).is_err() {
120 return;
121 }
122 }
123 }
124
125 if tx.unbounded_send(Ok(event)).is_err() {
126 return;
127 }
128 }
129 Err(err) => {
130 let retry_after =
131 if let imp_llm::Error::RateLimited { retry_after_secs } = &err {
132 *retry_after_secs
133 } else {
134 None
135 };
136
137 if !emitted_meaningful_event
138 && is_retryable(&err)
139 && attempt < policy.max_retries
140 {
141 match backoff_delay(attempt, &policy, retry_after) {
142 None => {
143 let _ = tx.unbounded_send(Err(err));
144 return;
145 }
146 Some(delay) => {
147 tokio::time::sleep(delay).await;
148 attempt += 1;
149 continue 'attempt;
150 }
151 }
152 }
153
154 let _ = tx.unbounded_send(Err(err));
155 return;
156 }
157 }
158 }
159
160 for buffered in buffered_starts {
161 if tx.unbounded_send(Ok(buffered)).is_err() {
162 return;
163 }
164 }
165
166 return;
167 }
168 });
169
170 Box::pin(rx)
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use imp_llm::provider::RetryCondition;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// Policy for the `backoff_delay` tests: 10ms base delay, 100ms cap.
    fn default_policy() -> RetryPolicy {
        RetryPolicy {
            max_retries: 3,
            base_delay: Duration::from_millis(10),
            max_delay: Duration::from_millis(100),
            retry_on: vec![
                RetryCondition::RateLimit,
                RetryCondition::ServerError,
                RetryCondition::Timeout,
                RetryCondition::ConnectionError,
            ],
        }
    }

    /// Policy for the streaming tests: millisecond delays keep them fast.
    fn fast_policy(max_retries: u32) -> RetryPolicy {
        RetryPolicy {
            max_retries,
            base_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(50),
            retry_on: vec![RetryCondition::ServerError],
        }
    }

    /// Asserts that `delay` lies in the inclusive range `[lo_ms, hi_ms]` milliseconds.
    fn assert_delay_between(delay: Duration, lo_ms: u64, hi_ms: u64) {
        assert!(
            delay >= Duration::from_millis(lo_ms) && delay <= Duration::from_millis(hi_ms),
            "expected delay in [{lo_ms}ms, {hi_ms}ms], got {delay:?}"
        );
    }

    fn message_start() -> imp_llm::Result<StreamEvent> {
        Ok(StreamEvent::MessageStart {
            model: "test".into(),
        })
    }

    fn text_delta(text: &'static str) -> imp_llm::Result<StreamEvent> {
        Ok(StreamEvent::TextDelta { text: text.into() })
    }

    // ---- is_retryable -------------------------------------------------

    #[test]
    fn rate_limited_is_retryable() {
        let err = imp_llm::Error::RateLimited {
            retry_after_secs: Some(5),
        };
        assert!(is_retryable(&err));

        let err = imp_llm::Error::RateLimited {
            retry_after_secs: None,
        };
        assert!(is_retryable(&err));
    }

    #[test]
    fn stream_error_is_retryable() {
        let err = imp_llm::Error::Stream("connection reset".into());
        assert!(is_retryable(&err));
    }

    #[tokio::test]
    async fn http_decode_error_is_retryable() {
        use tokio::io::{AsyncReadExt, AsyncWriteExt};

        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        // Announce a 999-byte body, send a few bytes, then hang up.
        tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            let mut request_buf = [0u8; 1024];
            let _ = socket.read(&mut request_buf).await;
            socket
                .write_all(b"HTTP/1.1 200 OK\r\ncontent-length: 999\r\n\r\nnot-g")
                .await
                .unwrap();
        });

        let err = reqwest::get(format!("http://{addr}"))
            .await
            .unwrap()
            .bytes()
            .await
            .unwrap_err();

        assert!(err.is_decode() || err.is_body());
        assert!(is_retryable(&imp_llm::Error::Http(err)));
    }

    #[test]
    fn auth_error_is_not_retryable() {
        let err = imp_llm::Error::Auth("invalid key".into());
        assert!(!is_retryable(&err));
    }

    #[test]
    fn provider_5xx_is_retryable() {
        for msg in [
            "HTTP 500: internal error",
            "HTTP 502: bad gateway",
            "HTTP 503: overloaded",
            "HTTP 529: overloaded",
        ] {
            assert!(is_retryable(&imp_llm::Error::Provider(msg.into())), "{msg}");
        }
    }

    #[test]
    fn provider_4xx_is_not_retryable() {
        let err = imp_llm::Error::Provider("HTTP 400: bad request".into());
        assert!(!is_retryable(&err));
    }

    #[test]
    fn provider_401_is_not_retryable() {
        let err = imp_llm::Error::Provider("HTTP 401: unauthorized".into());
        assert!(!is_retryable(&err));
    }

    // ---- backoff_delay ------------------------------------------------

    #[test]
    fn backoff_grows_exponentially() {
        // base 10ms doubling per attempt, plus jitter of at most half the
        // exponential term.
        let policy = default_policy();
        assert_delay_between(backoff_delay(0, &policy, None).unwrap(), 10, 15);
        assert_delay_between(backoff_delay(1, &policy, None).unwrap(), 20, 30);
        assert_delay_between(backoff_delay(2, &policy, None).unwrap(), 40, 60);
    }

    #[test]
    fn backoff_capped_at_max_delay() {
        // Attempt 10 saturates the exponential term at max_delay (100ms);
        // jitter never exceeds half of the capped term.
        let policy = default_policy();
        let delay = backoff_delay(10, &policy, None).unwrap();
        assert_delay_between(delay, 100, 150);
    }

    #[test]
    fn retry_after_respected_within_limit() {
        let policy = default_policy();
        let delay = backoff_delay(0, &policy, Some(0)).unwrap();
        assert_eq!(delay, Duration::from_secs(0));
    }

    #[test]
    fn retry_after_equal_to_max_delay_is_accepted() {
        let policy = RetryPolicy {
            max_delay: Duration::from_secs(1),
            ..default_policy()
        };
        assert_eq!(backoff_delay(0, &policy, Some(1)), Some(Duration::from_secs(1)));
    }

    #[test]
    fn retry_after_exceeds_max_delay_returns_none() {
        let policy = default_policy();
        assert!(backoff_delay(0, &policy, Some(10)).is_none());
    }

    // ---- stream_with_retry -------------------------------------------

    #[tokio::test]
    async fn retry_succeeds_after_transient_failures_before_first_meaningful_event() {
        let calls = Arc::new(AtomicU32::new(0));
        let calls_in_factory = Arc::clone(&calls);

        let result: Vec<_> = stream_with_retry(
            move || {
                let attempt = calls_in_factory.fetch_add(1, Ordering::SeqCst) + 1;
                let events: Vec<imp_llm::Result<StreamEvent>> = if attempt < 3 {
                    vec![
                        message_start(),
                        Err(imp_llm::Error::Stream("transient".into())),
                    ]
                } else {
                    vec![message_start(), text_delta("hello")]
                };
                futures::stream::iter(events)
            },
            fast_policy(3),
        )
        .collect()
        .await;

        assert_eq!(calls.load(Ordering::SeqCst), 3);
        // Starts from the failed attempts must not leak through.
        assert_eq!(result.len(), 2);
        assert!(matches!(result[0], Ok(StreamEvent::MessageStart { .. })));
        assert!(matches!(result[1], Ok(StreamEvent::TextDelta { .. })));
    }

    #[tokio::test]
    async fn retry_exhausts_max_retries_before_first_meaningful_event() {
        let calls = Arc::new(AtomicU32::new(0));
        let calls_in_factory = Arc::clone(&calls);

        let result: Vec<_> = stream_with_retry(
            move || {
                calls_in_factory.fetch_add(1, Ordering::SeqCst);
                let events: Vec<imp_llm::Result<StreamEvent>> =
                    vec![Err(imp_llm::Error::Stream("always fails".into()))];
                futures::stream::iter(events)
            },
            fast_policy(2),
        )
        .collect()
        .await;

        // One initial attempt plus two retries.
        assert_eq!(calls.load(Ordering::SeqCst), 3);
        assert_eq!(result.len(), 1);
        assert!(matches!(result[0], Err(imp_llm::Error::Stream(_))));
    }

    #[tokio::test]
    async fn retry_skips_non_retryable_errors() {
        let calls = Arc::new(AtomicU32::new(0));
        let calls_in_factory = Arc::clone(&calls);

        let result: Vec<_> = stream_with_retry(
            move || {
                calls_in_factory.fetch_add(1, Ordering::SeqCst);
                let events: Vec<imp_llm::Result<StreamEvent>> =
                    vec![Err(imp_llm::Error::Auth("invalid key".into()))];
                futures::stream::iter(events)
            },
            fast_policy(3),
        )
        .collect()
        .await;

        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert_eq!(result.len(), 1);
        assert!(matches!(result[0], Err(imp_llm::Error::Auth(_))));
    }

    #[tokio::test]
    async fn retry_honours_zero_retry_after_for_rate_limits() {
        let calls = Arc::new(AtomicU32::new(0));
        let calls_in_factory = Arc::clone(&calls);

        let result: Vec<_> = stream_with_retry(
            move || {
                let attempt = calls_in_factory.fetch_add(1, Ordering::SeqCst) + 1;
                let events: Vec<imp_llm::Result<StreamEvent>> = if attempt == 1 {
                    vec![Err(imp_llm::Error::RateLimited {
                        retry_after_secs: Some(0),
                    })]
                } else {
                    vec![text_delta("ok")]
                };
                futures::stream::iter(events)
            },
            fast_policy(3),
        )
        .collect()
        .await;

        assert_eq!(calls.load(Ordering::SeqCst), 2);
        assert_eq!(result.len(), 1);
        assert!(matches!(result[0], Ok(StreamEvent::TextDelta { .. })));
    }

    #[tokio::test]
    async fn retry_after_above_max_delay_is_not_retried() {
        let calls = Arc::new(AtomicU32::new(0));
        let calls_in_factory = Arc::clone(&calls);

        let result: Vec<_> = stream_with_retry(
            move || {
                calls_in_factory.fetch_add(1, Ordering::SeqCst);
                let events: Vec<imp_llm::Result<StreamEvent>> =
                    vec![Err(imp_llm::Error::RateLimited {
                        retry_after_secs: Some(10),
                    })];
                futures::stream::iter(events)
            },
            fast_policy(3),
        )
        .collect()
        .await;

        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert_eq!(result.len(), 1);
        assert!(matches!(result[0], Err(imp_llm::Error::RateLimited { .. })));
    }

    #[tokio::test]
    async fn retry_no_error_passes_through() {
        let result: Vec<_> = stream_with_retry(
            || futures::stream::iter(vec![message_start(), text_delta("ok")]),
            default_policy(),
        )
        .collect()
        .await;

        assert_eq!(result.len(), 2);
        assert!(matches!(result[0], Ok(StreamEvent::MessageStart { .. })));
        assert!(matches!(result[1], Ok(StreamEvent::TextDelta { .. })));
    }

    #[tokio::test]
    async fn message_start_only_stream_is_flushed_on_completion() {
        let result: Vec<_> = stream_with_retry(
            || futures::stream::iter(vec![message_start()]),
            default_policy(),
        )
        .collect()
        .await;

        assert_eq!(result.len(), 1);
        assert!(matches!(result[0], Ok(StreamEvent::MessageStart { .. })));
    }

    #[tokio::test]
    async fn retry_does_not_replay_after_meaningful_event_has_streamed() {
        let calls = Arc::new(AtomicU32::new(0));
        let calls_in_factory = Arc::clone(&calls);

        let result: Vec<_> = stream_with_retry(
            move || {
                calls_in_factory.fetch_add(1, Ordering::SeqCst);
                let events: Vec<imp_llm::Result<StreamEvent>> = vec![
                    text_delta("partial"),
                    Err(imp_llm::Error::Stream("boom".into())),
                ];
                futures::stream::iter(events)
            },
            default_policy(),
        )
        .collect()
        .await;

        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert_eq!(result.len(), 2);
        assert!(matches!(result[0], Ok(StreamEvent::TextDelta { .. })));
        assert!(matches!(result[1], Err(imp_llm::Error::Stream(_))));
    }
}