async_openai/middleware/retry/
openai.rs1use std::{future::Future, pin::Pin, time::Duration};
2
3use reqwest::{header::HeaderMap, Response};
4
5use crate::{
6 error::{ApiErrorResponse, OpenAIError, WrappedError},
7 executor::HttpRequestFactory,
8};
9
10use super::log_rate_limit_headers;
/// API error `type` value that marks a 429 response as not worth retrying:
/// the retry loop returns the error immediately when it sees this type.
const INSUFFICIENT_QUOTA: &str = "insufficient_quota";
12
13#[cfg(not(target_family = "wasm"))]
14type RetryFuture = Pin<Box<dyn Future<Output = Result<Response, OpenAIError>> + Send + 'static>>;
15#[cfg(target_family = "wasm")]
16type RetryFuture = Pin<Box<dyn Future<Output = Result<Response, OpenAIError>> + 'static>>;
17
/// Tower layer that wraps a service in the retrying `OpenAIRetry` service.
///
/// The only configuration is the retry budget, which is copied into every
/// service the layer produces.
#[derive(Clone)]
pub struct OpenAIRetryLayer {
    /// Retry budget handed to each wrapped service.
    max_retries: usize,
}

impl std::fmt::Debug for OpenAIRetryLayer {
    /// Formats the layer as a non-exhaustive debug struct showing `max_retries`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut repr = f.debug_struct("OpenAIRetryLayer");
        repr.field("max_retries", &self.max_retries);
        repr.finish_non_exhaustive()
    }
}

impl OpenAIRetryLayer {
    /// Builds a layer whose services use `max_retries` as their retry budget.
    pub fn new(max_retries: usize) -> Self {
        OpenAIRetryLayer { max_retries }
    }

    /// Returns the configured retry budget.
    pub fn max_retries(&self) -> usize {
        let Self { max_retries } = self;
        *max_retries
    }
}

impl Default for OpenAIRetryLayer {
    /// Returns a layer with a retry budget of 3.
    fn default() -> Self {
        OpenAIRetryLayer { max_retries: 3 }
    }
}
57
58impl<S> tower::Layer<S> for OpenAIRetryLayer {
59 type Service = OpenAIRetry<S>;
60
61 fn layer(&self, inner: S) -> Self::Service {
62 OpenAIRetry {
63 inner,
64 max_retries: self.max_retries,
65 }
66 }
67}
68
/// Service wrapper that pairs an inner service with a retry budget.
#[derive(Clone)]
pub struct OpenAIRetry<S> {
    /// The wrapped service that actually performs each attempt.
    inner: S,
    /// Retry budget applied to every request sent through this service.
    max_retries: usize,
}

impl<S: std::fmt::Debug> std::fmt::Debug for OpenAIRetry<S> {
    /// Formats the service as a non-exhaustive debug struct showing `inner`
    /// and `max_retries`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut repr = f.debug_struct("OpenAIRetry");
        repr.field("inner", &self.inner);
        repr.field("max_retries", &self.max_retries);
        repr.finish_non_exhaustive()
    }
}
87
88#[cfg(not(target_family = "wasm"))]
89impl<S> tower::Service<HttpRequestFactory> for OpenAIRetry<S>
90where
91 S: tower::Service<HttpRequestFactory, Response = Response, Error = OpenAIError>
92 + Clone
93 + Send
94 + 'static,
95 S::Future: Send + 'static,
96{
97 type Response = Response;
98 type Error = OpenAIError;
99 type Future = RetryFuture;
100
101 fn poll_ready(
102 &mut self,
103 cx: &mut std::task::Context<'_>,
104 ) -> std::task::Poll<Result<(), Self::Error>> {
105 self.inner.poll_ready(cx)
106 }
107
108 fn call(&mut self, request: HttpRequestFactory) -> Self::Future {
109 let clone = self.inner.clone();
110 let mut service = std::mem::replace(&mut self.inner, clone);
111 let first_attempt = service.call(request.clone());
112 let max_retries = self.max_retries;
113
114 Box::pin(async move { retry_request(service, first_attempt, request, max_retries).await })
115 }
116}
117
/// Wasm `tower::Service` implementation; identical to the non-wasm one except
/// that no `Send` bounds are required.
#[cfg(target_family = "wasm")]
impl<S> tower::Service<HttpRequestFactory> for OpenAIRetry<S>
where
    S: tower::Service<HttpRequestFactory, Response = Response, Error = OpenAIError>
        + Clone
        + 'static,
    S::Future: 'static,
{
    type Response = Response;
    type Error = OpenAIError;
    type Future = RetryFuture;

    /// Readiness is delegated entirely to the inner service.
    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Starts the first attempt right away and returns a future that keeps
    /// retrying through `retry_request`.
    fn call(&mut self, request: HttpRequestFactory) -> Self::Future {
        // Park a fresh clone in `self.inner` and take out the previous instance
        // (the one `poll_ready` was invoked on) to serve this request.
        let replacement = self.inner.clone();
        let mut active = std::mem::replace(&mut self.inner, replacement);

        let first_attempt = active.call(request.clone());
        let retry_budget = self.max_retries;

        Box::pin(async move {
            retry_request(active, first_attempt, request, retry_budget).await
        })
    }
}
146
147async fn retry_request<S>(
148 mut service: S,
149 first_attempt: S::Future,
150 request: HttpRequestFactory,
151 max_retries: usize,
152) -> Result<Response, OpenAIError>
153where
154 S: tower::Service<HttpRequestFactory, Response = Response, Error = OpenAIError>,
155{
156 use tower::ServiceExt;
157
158 let mut attempts = 0;
159 let mut backoff_attempt = 0;
160
161 let mut result = first_attempt.await;
162
163 loop {
164 let (final_result, headers, retry_after) = match result {
166 Ok(response) if response.status().is_success() => return Ok(response),
167 Ok(response) if response.status().as_u16() == 429 => {
168 let status_code = response.status();
169 let headers = response.headers().clone();
170 let retry_after = retry_after(&headers);
171 let bytes = match response.bytes().await {
172 Ok(bytes) => bytes,
173 Err(error) => return Err(OpenAIError::Reqwest(error)),
174 };
175
176 let error = match serde_json::from_slice::<WrappedError>(&bytes) {
177 Ok(wrapped_error) => {
178 if wrapped_error.error.r#type.as_deref() == Some(INSUFFICIENT_QUOTA) {
181 return Err(OpenAIError::ApiError(ApiErrorResponse {
182 status_code,
183 api_error: wrapped_error.error,
184 }));
185 }
186
187 OpenAIError::ApiError(ApiErrorResponse {
188 status_code,
189 api_error: wrapped_error.error,
190 })
191 }
192 Err(error) => {
193 return Err(OpenAIError::JSONDeserialize(
194 error,
195 String::from_utf8_lossy(&bytes).into_owned(),
196 ));
197 }
198 };
199
200 (Err(error), Some(headers), retry_after)
201 }
202 Ok(response) if response.status().is_server_error() => {
203 let retry_after = retry_after(response.headers());
204 (Ok(response), None, retry_after)
205 }
206 Ok(response) => return Ok(response),
207 Err(error) if is_connection_error(&error) => (Err(error), None, None),
208 Err(error) => return Err(error),
209 };
210
211 if attempts >= max_retries {
212 return final_result;
213 }
214
215 if let Some(headers) = headers.as_ref() {
216 log_rate_limit_headers(headers);
217 }
218
219 let delay = retry_after.unwrap_or_else(|| {
220 let delay =
221 Duration::from_millis(100).saturating_mul(2_u32.saturating_pow(backoff_attempt));
222 backoff_attempt = backoff_attempt.saturating_add(1);
223 delay.min(Duration::from_secs(8))
224 });
225
226 attempts += 1;
227
228 #[cfg(not(target_family = "wasm"))]
230 tokio::time::sleep(delay).await;
231 #[cfg(target_family = "wasm")]
232 let _ = delay;
233
234 result = service.ready().await?.call(request.clone()).await;
238 }
239}
240
241fn is_connection_error(error: &OpenAIError) -> bool {
242 match error {
243 #[cfg(not(target_family = "wasm"))]
244 OpenAIError::Reqwest(error) => error.is_connect(),
245 #[cfg(target_family = "wasm")]
246 OpenAIError::Reqwest(_) => false,
247 _ => false,
248 }
249}
250
251fn retry_after(headers: &HeaderMap) -> Option<Duration> {
252 headers
253 .get(reqwest::header::RETRY_AFTER)
254 .and_then(|value| value.to_str().ok())
255 .and_then(|value| value.parse::<u64>().ok())
256 .map(Duration::from_secs)
257}