Skip to main content

async_openai/middleware/retry/
openai.rs

1use std::{future::Future, pin::Pin, time::Duration};
2
3use reqwest::{header::HeaderMap, Response};
4
5use crate::{
6    error::{OpenAIError, WrappedError},
7    executor::HttpRequestFactory,
8};
9
10use super::log_rate_limit_headers;
11const INSUFFICIENT_QUOTA: &str = "insufficient_quota";
12
13#[cfg(not(target_family = "wasm"))]
14type RetryFuture = Pin<Box<dyn Future<Output = Result<Response, OpenAIError>> + Send + 'static>>;
15#[cfg(target_family = "wasm")]
16type RetryFuture = Pin<Box<dyn Future<Output = Result<Response, OpenAIError>> + 'static>>;
17
/// Tower layer that retries `429` responses, `5xx` responses, and (on native
/// targets) reqwest connection errors, waiting between attempts with exponential
/// backoff or the server-provided `Retry-After` delay.
///
/// To tell a retryable rate limit apart from a permanent `insufficient_quota`
/// error, this layer reads the body of every `429` response. Layers stacked above
/// it may therefore receive an [`OpenAIError`] produced from that body instead of
/// the raw response.
///
/// When the response body must reach upper layers untouched, use
/// [`crate::middleware::retry::SimpleRetryPolicy`], which does not consume it.
#[derive(Clone)]
pub struct OpenAIRetryLayer {
    /// Retries permitted after the initial request.
    max_retries: usize,
}

impl std::fmt::Debug for OpenAIRetryLayer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut debug = f.debug_struct("OpenAIRetryLayer");
        debug.field("max_retries", &self.max_retries);
        debug.finish_non_exhaustive()
    }
}

impl OpenAIRetryLayer {
    /// Builds a layer that permits up to `max_retries` retries.
    ///
    /// Only attempts made after the initial request are counted, so `0` disables
    /// retrying and `3` allows at most four requests in total.
    pub fn new(max_retries: usize) -> Self {
        OpenAIRetryLayer { max_retries }
    }

    /// Returns how many retries this layer permits after the initial request.
    pub fn max_retries(&self) -> usize {
        let configured = self.max_retries;
        configured
    }
}

impl Default for OpenAIRetryLayer {
    /// Three retries after the initial request.
    fn default() -> Self {
        OpenAIRetryLayer::new(3)
    }
}
57
58impl<S> tower::Layer<S> for OpenAIRetryLayer {
59    type Service = OpenAIRetry<S>;
60
61    fn layer(&self, inner: S) -> Self::Service {
62        OpenAIRetry {
63            inner,
64            max_retries: self.max_retries,
65        }
66    }
67}
68
/// Tower service produced by [`OpenAIRetryLayer`].
///
/// Forwards each request to `inner` and re-issues it while the outcome is
/// retryable, at most `max_retries` additional times.
#[derive(Clone)]
pub struct OpenAIRetry<S> {
    /// Wrapped service that performs each HTTP attempt.
    inner: S,
    /// Retries permitted after the initial request.
    max_retries: usize,
}

impl<S: std::fmt::Debug> std::fmt::Debug for OpenAIRetry<S> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut debug = f.debug_struct("OpenAIRetry");
        debug
            .field("inner", &self.inner)
            .field("max_retries", &self.max_retries);
        debug.finish_non_exhaustive()
    }
}
87
88#[cfg(not(target_family = "wasm"))]
89impl<S> tower::Service<HttpRequestFactory> for OpenAIRetry<S>
90where
91    S: tower::Service<HttpRequestFactory, Response = Response, Error = OpenAIError>
92        + Clone
93        + Send
94        + 'static,
95    S::Future: Send + 'static,
96{
97    type Response = Response;
98    type Error = OpenAIError;
99    type Future = RetryFuture;
100
101    fn poll_ready(
102        &mut self,
103        cx: &mut std::task::Context<'_>,
104    ) -> std::task::Poll<Result<(), Self::Error>> {
105        self.inner.poll_ready(cx)
106    }
107
108    fn call(&mut self, request: HttpRequestFactory) -> Self::Future {
109        let clone = self.inner.clone();
110        let mut service = std::mem::replace(&mut self.inner, clone);
111        let first_attempt = service.call(request.clone());
112        let max_retries = self.max_retries;
113
114        Box::pin(async move { retry_request(service, first_attempt, request, max_retries).await })
115    }
116}
117
#[cfg(target_family = "wasm")]
impl<S> tower::Service<HttpRequestFactory> for OpenAIRetry<S>
where
    S: tower::Service<HttpRequestFactory, Response = Response, Error = OpenAIError>
        + Clone
        + 'static,
    S::Future: 'static,
{
    type Response = Response;
    type Error = OpenAIError;
    type Future = RetryFuture;

    /// Readiness is delegated entirely to the wrapped service.
    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Starts the first attempt right away and returns a future that drives it
    /// plus any retries.
    fn call(&mut self, request: HttpRequestFactory) -> Self::Future {
        let max_retries = self.max_retries;
        // `poll_ready` readied `self.inner` itself, not a clone of it. Move that
        // ready instance into the future and leave a fresh clone in its place.
        let fresh = self.inner.clone();
        let mut ready = std::mem::replace(&mut self.inner, fresh);
        let first_attempt = ready.call(request.clone());

        Box::pin(retry_request(ready, first_attempt, request, max_retries))
    }
}
146
147async fn retry_request<S>(
148    mut service: S,
149    first_attempt: S::Future,
150    request: HttpRequestFactory,
151    max_retries: usize,
152) -> Result<Response, OpenAIError>
153where
154    S: tower::Service<HttpRequestFactory, Response = Response, Error = OpenAIError>,
155{
156    use tower::ServiceExt;
157
158    let mut attempts = 0;
159    let mut backoff_attempt = 0;
160
161    let mut result = first_attempt.await;
162
163    loop {
164        // In this match satatement return early if the error is not retryable.
165        let (final_result, headers, retry_after) = match result {
166            Ok(response) if response.status().is_success() => return Ok(response),
167            Ok(response) if response.status().as_u16() == 429 => {
168                let headers = response.headers().clone();
169                let retry_after = retry_after(&headers);
170                let bytes = match response.bytes().await {
171                    Ok(bytes) => bytes,
172                    Err(error) => return Err(OpenAIError::Reqwest(error)),
173                };
174
175                let error = match serde_json::from_slice::<WrappedError>(&bytes) {
176                    Ok(wrapped_error) => {
177                        // 429 and insufficient_quota are treated as permanent error.
178                        // https://developers.openai.com/api/docs/guides/error-codes
179                        if wrapped_error.error.r#type.as_deref() == Some(INSUFFICIENT_QUOTA) {
180                            return Err(OpenAIError::ApiError(wrapped_error.error));
181                        }
182
183                        OpenAIError::ApiError(wrapped_error.error)
184                    }
185                    Err(error) => {
186                        return Err(OpenAIError::JSONDeserialize(
187                            error,
188                            String::from_utf8_lossy(&bytes).into_owned(),
189                        ));
190                    }
191                };
192
193                (Err(error), Some(headers), retry_after)
194            }
195            Ok(response) if response.status().is_server_error() => {
196                let retry_after = retry_after(response.headers());
197                (Ok(response), None, retry_after)
198            }
199            Ok(response) => return Ok(response),
200            Err(error) if is_connection_error(&error) => (Err(error), None, None),
201            Err(error) => return Err(error),
202        };
203
204        if attempts >= max_retries {
205            return final_result;
206        }
207
208        if let Some(headers) = headers.as_ref() {
209            log_rate_limit_headers(headers);
210        }
211
212        let delay = retry_after.unwrap_or_else(|| {
213            let delay =
214                Duration::from_millis(100).saturating_mul(2_u32.saturating_pow(backoff_attempt));
215            backoff_attempt = backoff_attempt.saturating_add(1);
216            delay.min(Duration::from_secs(8))
217        });
218
219        attempts += 1;
220
221        // on wasm there is no standard sleep so we retry immediately
222        #[cfg(not(target_family = "wasm"))]
223        tokio::time::sleep(delay).await;
224        #[cfg(target_family = "wasm")]
225        let _ = delay;
226
227        // The service moved into this future was already made ready before the
228        // first call. For retries we must poll readiness again before each
229        // additional call, matching tower::retry's service contract.
230        result = service.ready().await?.call(request.clone()).await;
231    }
232}
233
234fn is_connection_error(error: &OpenAIError) -> bool {
235    match error {
236        #[cfg(not(target_family = "wasm"))]
237        OpenAIError::Reqwest(error) => error.is_connect(),
238        #[cfg(target_family = "wasm")]
239        OpenAIError::Reqwest(_) => false,
240        _ => false,
241    }
242}
243
244fn retry_after(headers: &HeaderMap) -> Option<Duration> {
245    headers
246        .get(reqwest::header::RETRY_AFTER)
247        .and_then(|value| value.to_str().ok())
248        .and_then(|value| value.parse::<u64>().ok())
249        .map(Duration::from_secs)
250}