1use crate::{
2 AnthropicRequest, AnthropicResponse, AnthropicStreamEvent, AudioSpeechRequest, BoxStream,
3 ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse, EmbeddingRequest,
4 EmbeddingResponse, Error, GeminiRequest, GeminiResponse, ImageRequest, MultipartField,
5 Provider,
6};
7use rand::Rng;
8use std::{future::Future, time::Duration};
9
/// Default number of retries allowed after the first attempt (three attempts in total).
const DEFAULT_MAX_RETRIES: u32 = 2;
/// Default time limit applied to each individual attempt against the wrapped provider.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
/// Default cap on an error's `retry_after` delay; errors asking for longer are not retried.
const DEFAULT_MAX_RETRY_AFTER: Duration = Duration::from_secs(60);
/// Starting backoff; doubled after every retry and jittered when an error has no `retry_after`.
const INITIAL_BACKOFF: Duration = Duration::from_millis(100);
14
#[derive(Debug, Clone)]
/// Wrapper around a [`Provider`] that adds a per-attempt timeout and retries transient errors.
///
/// Construct it with [`Retrying::new`], which fills every setting from the module defaults.
pub struct Retrying<P: Provider> {
    /// The wrapped provider that every call is delegated to.
    inner: P,
    /// Number of retries allowed after the first attempt.
    max_retries: u32,
    /// Time limit for a single attempt; a zero duration disables the limit.
    timeout: Duration,
    /// Largest `retry_after` delay still honoured; errors asking for more are not retried.
    max_retry_after: Duration,
}
28
29impl<P: Provider> Retrying<P> {
30 pub fn new(inner: P) -> Self {
33 Self {
34 inner,
35 max_retries: DEFAULT_MAX_RETRIES,
36 timeout: DEFAULT_TIMEOUT,
37 max_retry_after: DEFAULT_MAX_RETRY_AFTER,
38 }
39 }
40
41 pub fn max_retry_after(mut self, d: Duration) -> Self {
44 self.max_retry_after = d;
45 self
46 }
47
48 fn should_retry(&self, e: &Error) -> bool {
51 if !e.is_transient() {
52 return false;
53 }
54 !matches!(e.retry_after(), Some(ra) if ra > self.max_retry_after)
55 }
56
57 async fn timed<T>(
58 &self,
59 fut: impl Future<Output = Result<T, Error>> + Send,
60 ) -> Result<T, Error> {
61 if self.timeout.is_zero() {
62 return fut.await;
63 }
64 let Ok(result) = tokio::time::timeout(self.timeout, fut).await else {
65 return Err(Error::Timeout);
66 };
67 result
68 }
69}
70
71impl<P: Provider> Provider for Retrying<P> {
72 async fn chat_completion(
73 &self,
74 request: &ChatCompletionRequest,
75 ) -> Result<ChatCompletionResponse, Error> {
76 let mut backoff = INITIAL_BACKOFF;
77 let mut last_err = None;
78 for _ in 0..=self.max_retries {
79 match self.timed(self.inner.chat_completion(request)).await {
80 Ok(resp) => return Ok(resp),
81 Err(e) if self.should_retry(&e) => {
82 let sleep = e.retry_after().unwrap_or_else(|| jittered(backoff));
83 last_err = Some(e);
84 tokio::time::sleep(sleep).await;
85 backoff *= 2;
86 }
87 Err(e) => return Err(e),
88 }
89 }
90 Err(last_err.expect("retry loop exited without producing an error"))
91 }
92
93 async fn chat_completion_stream(
94 &self,
95 request: &ChatCompletionRequest,
96 ) -> Result<BoxStream<'static, Result<ChatCompletionChunk, Error>>, Error> {
97 let mut backoff = INITIAL_BACKOFF;
98 let mut last_err = None;
99 for _ in 0..=self.max_retries {
100 match self.timed(self.inner.chat_completion_stream(request)).await {
101 Ok(stream) => return Ok(stream),
102 Err(e) if self.should_retry(&e) => {
103 let sleep = e.retry_after().unwrap_or_else(|| jittered(backoff));
104 last_err = Some(e);
105 tokio::time::sleep(sleep).await;
106 backoff *= 2;
107 }
108 Err(e) => return Err(e),
109 }
110 }
111 Err(last_err.expect("retry loop exited without producing an error"))
112 }
113
114 async fn anthropic_messages(
115 &self,
116 request: &AnthropicRequest,
117 ) -> Result<AnthropicResponse, Error> {
118 let mut backoff = INITIAL_BACKOFF;
119 let mut last_err = None;
120 for _ in 0..=self.max_retries {
121 match self.timed(self.inner.anthropic_messages(request)).await {
122 Ok(resp) => return Ok(resp),
123 Err(e) if self.should_retry(&e) => {
124 let sleep = e.retry_after().unwrap_or_else(|| jittered(backoff));
125 last_err = Some(e);
126 tokio::time::sleep(sleep).await;
127 backoff *= 2;
128 }
129 Err(e) => return Err(e),
130 }
131 }
132 Err(last_err.expect("retry loop exited without producing an error"))
133 }
134
135 async fn anthropic_messages_stream(
136 &self,
137 request: &AnthropicRequest,
138 ) -> Result<BoxStream<'static, Result<AnthropicStreamEvent, Error>>, Error> {
139 let mut backoff = INITIAL_BACKOFF;
140 let mut last_err = None;
141 for _ in 0..=self.max_retries {
142 match self
143 .timed(self.inner.anthropic_messages_stream(request))
144 .await
145 {
146 Ok(stream) => return Ok(stream),
147 Err(e) if self.should_retry(&e) => {
148 let sleep = e.retry_after().unwrap_or_else(|| jittered(backoff));
149 last_err = Some(e);
150 tokio::time::sleep(sleep).await;
151 backoff *= 2;
152 }
153 Err(e) => return Err(e),
154 }
155 }
156 Err(last_err.expect("retry loop exited without producing an error"))
157 }
158
159 async fn gemini_generate_content_stream(
160 &self,
161 model: &str,
162 request: &GeminiRequest,
163 ) -> Result<BoxStream<'static, Result<GeminiResponse, Error>>, Error> {
164 let mut backoff = INITIAL_BACKOFF;
165 let mut last_err = None;
166 for _ in 0..=self.max_retries {
167 match self
168 .timed(self.inner.gemini_generate_content_stream(model, request))
169 .await
170 {
171 Ok(stream) => return Ok(stream),
172 Err(e) if self.should_retry(&e) => {
173 let sleep = e.retry_after().unwrap_or_else(|| jittered(backoff));
174 last_err = Some(e);
175 tokio::time::sleep(sleep).await;
176 backoff *= 2;
177 }
178 Err(e) => return Err(e),
179 }
180 }
181 Err(last_err.expect("retry loop exited without producing an error"))
182 }
183
184 async fn embedding(&self, request: &EmbeddingRequest) -> Result<EmbeddingResponse, Error> {
185 self.inner.embedding(request).await
186 }
187
188 async fn image_generation(
189 &self,
190 request: &ImageRequest,
191 ) -> Result<(bytes::Bytes, String), Error> {
192 self.inner.image_generation(request).await
193 }
194
195 async fn audio_speech(
196 &self,
197 request: &AudioSpeechRequest,
198 ) -> Result<(bytes::Bytes, String), Error> {
199 self.inner.audio_speech(request).await
200 }
201
202 async fn audio_transcription(
203 &self,
204 model: &str,
205 fields: &[MultipartField],
206 ) -> Result<(bytes::Bytes, String), Error> {
207 self.inner.audio_transcription(model, fields).await
208 }
209}
210
211fn jittered(backoff: Duration) -> Duration {
213 let lo = backoff.as_millis() as u64 / 2;
214 let hi = backoff.as_millis() as u64;
215 if lo >= hi {
216 return backoff;
217 }
218 Duration::from_millis(rand::rng().random_range(lo..=hi))
219}