async_openai/middleware/retry/
openai.rs1use std::{future::Future, pin::Pin, time::Duration};
2
3use reqwest::{header::HeaderMap, Response};
4
5use crate::{
6 error::{OpenAIError, WrappedError},
7 executor::HttpRequestFactory,
8};
9
10use super::log_rate_limit_headers;
11const INSUFFICIENT_QUOTA: &str = "insufficient_quota";
12
13#[cfg(not(target_family = "wasm"))]
14type RetryFuture = Pin<Box<dyn Future<Output = Result<Response, OpenAIError>> + Send + 'static>>;
15#[cfg(target_family = "wasm")]
16type RetryFuture = Pin<Box<dyn Future<Output = Result<Response, OpenAIError>> + 'static>>;
17
/// `tower` layer configuration holding the maximum number of retries to attempt.
///
/// The default configuration allows 3 retries.
#[derive(Clone)]
pub struct OpenAIRetryLayer {
    max_retries: usize,
}

impl std::fmt::Debug for OpenAIRetryLayer {
    /// Formats the layer as a non-exhaustive struct showing `max_retries`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { max_retries } = self;
        let mut builder = f.debug_struct("OpenAIRetryLayer");
        builder.field("max_retries", max_retries);
        builder.finish_non_exhaustive()
    }
}

impl OpenAIRetryLayer {
    /// Retry budget used by [`Default`].
    const DEFAULT_MAX_RETRIES: usize = 3;

    /// Creates a layer that allows up to `max_retries` retries.
    pub fn new(max_retries: usize) -> Self {
        OpenAIRetryLayer { max_retries }
    }

    /// Returns the configured maximum number of retries.
    pub fn max_retries(&self) -> usize {
        let Self { max_retries } = self;
        *max_retries
    }
}

impl Default for OpenAIRetryLayer {
    /// Builds a layer with a retry budget of 3.
    fn default() -> Self {
        Self {
            max_retries: Self::DEFAULT_MAX_RETRIES,
        }
    }
}
57
58impl<S> tower::Layer<S> for OpenAIRetryLayer {
59 type Service = OpenAIRetry<S>;
60
61 fn layer(&self, inner: S) -> Self::Service {
62 OpenAIRetry {
63 inner,
64 max_retries: self.max_retries,
65 }
66 }
67}
68
/// Service wrapper pairing an inner service with a maximum retry count.
#[derive(Clone)]
pub struct OpenAIRetry<S> {
    inner: S,
    max_retries: usize,
}

impl<S> std::fmt::Debug for OpenAIRetry<S>
where
    S: std::fmt::Debug,
{
    /// Formats the wrapper as a non-exhaustive struct showing `inner` and `max_retries`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { inner, max_retries } = self;
        let mut builder = f.debug_struct("OpenAIRetry");
        builder.field("inner", inner);
        builder.field("max_retries", max_retries);
        builder.finish_non_exhaustive()
    }
}
87
88#[cfg(not(target_family = "wasm"))]
89impl<S> tower::Service<HttpRequestFactory> for OpenAIRetry<S>
90where
91 S: tower::Service<HttpRequestFactory, Response = Response, Error = OpenAIError>
92 + Clone
93 + Send
94 + 'static,
95 S::Future: Send + 'static,
96{
97 type Response = Response;
98 type Error = OpenAIError;
99 type Future = RetryFuture;
100
101 fn poll_ready(
102 &mut self,
103 cx: &mut std::task::Context<'_>,
104 ) -> std::task::Poll<Result<(), Self::Error>> {
105 self.inner.poll_ready(cx)
106 }
107
108 fn call(&mut self, request: HttpRequestFactory) -> Self::Future {
109 let clone = self.inner.clone();
110 let mut service = std::mem::replace(&mut self.inner, clone);
111 let first_attempt = service.call(request.clone());
112 let max_retries = self.max_retries;
113
114 Box::pin(async move { retry_request(service, first_attempt, request, max_retries).await })
115 }
116}
117
/// Retry-enabled `tower::Service` for wasm targets; identical to the non-wasm implementation
/// except that no `Send` bounds are required.
#[cfg(target_family = "wasm")]
impl<S> tower::Service<HttpRequestFactory> for OpenAIRetry<S>
where
    S: tower::Service<HttpRequestFactory, Response = Response, Error = OpenAIError>
        + Clone
        + 'static,
    S::Future: 'static,
{
    type Response = Response;
    type Error = OpenAIError;
    type Future = RetryFuture;

    /// Readiness is delegated unchanged to the inner service.
    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Starts the first attempt immediately and returns a future that applies the retry policy.
    fn call(&mut self, request: HttpRequestFactory) -> Self::Future {
        // Keep the instance `poll_ready` was called on for this request and leave a fresh
        // clone in its place for subsequent calls.
        let replacement = self.inner.clone();
        let mut ready_service = std::mem::replace(&mut self.inner, replacement);
        let first_attempt = ready_service.call(request.clone());

        Box::pin(retry_request(
            ready_service,
            first_attempt,
            request,
            self.max_retries,
        ))
    }
}
146
147async fn retry_request<S>(
148 mut service: S,
149 first_attempt: S::Future,
150 request: HttpRequestFactory,
151 max_retries: usize,
152) -> Result<Response, OpenAIError>
153where
154 S: tower::Service<HttpRequestFactory, Response = Response, Error = OpenAIError>,
155{
156 use tower::ServiceExt;
157
158 let mut attempts = 0;
159 let mut backoff_attempt = 0;
160
161 let mut result = first_attempt.await;
162
163 loop {
164 let (final_result, headers, retry_after) = match result {
166 Ok(response) if response.status().is_success() => return Ok(response),
167 Ok(response) if response.status().as_u16() == 429 => {
168 let headers = response.headers().clone();
169 let retry_after = retry_after(&headers);
170 let bytes = match response.bytes().await {
171 Ok(bytes) => bytes,
172 Err(error) => return Err(OpenAIError::Reqwest(error)),
173 };
174
175 let error = match serde_json::from_slice::<WrappedError>(&bytes) {
176 Ok(wrapped_error) => {
177 if wrapped_error.error.r#type.as_deref() == Some(INSUFFICIENT_QUOTA) {
180 return Err(OpenAIError::ApiError(wrapped_error.error));
181 }
182
183 OpenAIError::ApiError(wrapped_error.error)
184 }
185 Err(error) => {
186 return Err(OpenAIError::JSONDeserialize(
187 error,
188 String::from_utf8_lossy(&bytes).into_owned(),
189 ));
190 }
191 };
192
193 (Err(error), Some(headers), retry_after)
194 }
195 Ok(response) if response.status().is_server_error() => {
196 let retry_after = retry_after(response.headers());
197 (Ok(response), None, retry_after)
198 }
199 Ok(response) => return Ok(response),
200 Err(error) if is_connection_error(&error) => (Err(error), None, None),
201 Err(error) => return Err(error),
202 };
203
204 if attempts >= max_retries {
205 return final_result;
206 }
207
208 if let Some(headers) = headers.as_ref() {
209 log_rate_limit_headers(headers);
210 }
211
212 let delay = retry_after.unwrap_or_else(|| {
213 let delay =
214 Duration::from_millis(100).saturating_mul(2_u32.saturating_pow(backoff_attempt));
215 backoff_attempt = backoff_attempt.saturating_add(1);
216 delay.min(Duration::from_secs(8))
217 });
218
219 attempts += 1;
220
221 #[cfg(not(target_family = "wasm"))]
223 tokio::time::sleep(delay).await;
224 #[cfg(target_family = "wasm")]
225 let _ = delay;
226
227 result = service.ready().await?.call(request.clone()).await;
231 }
232}
233
234fn is_connection_error(error: &OpenAIError) -> bool {
235 match error {
236 #[cfg(not(target_family = "wasm"))]
237 OpenAIError::Reqwest(error) => error.is_connect(),
238 #[cfg(target_family = "wasm")]
239 OpenAIError::Reqwest(_) => false,
240 _ => false,
241 }
242}
243
244fn retry_after(headers: &HeaderMap) -> Option<Duration> {
245 headers
246 .get(reqwest::header::RETRY_AFTER)
247 .and_then(|value| value.to_str().ok())
248 .and_then(|value| value.parse::<u64>().ok())
249 .map(Duration::from_secs)
250}