Skip to main content

async_openai/middleware/retry/
mod.rs

1mod openai;
2
3pub use openai::{OpenAIRetry, OpenAIRetryLayer};
4
5use std::{future::Future, pin::Pin};
6
7use reqwest::{header::HeaderMap, Response};
8
9use crate::{error::OpenAIError, executor::HttpRequestFactory};
10
/// Response header: maximum number of requests permitted before the rate limit is exhausted.
pub const X_RATELIMIT_LIMIT_REQUESTS: &str = "x-ratelimit-limit-requests";
/// Response header: maximum number of tokens permitted before the rate limit is exhausted.
pub const X_RATELIMIT_LIMIT_TOKENS: &str = "x-ratelimit-limit-tokens";
/// Response header: number of requests still available before the rate limit is exhausted.
pub const X_RATELIMIT_REMAINING_REQUESTS: &str = "x-ratelimit-remaining-requests";
/// Response header: number of tokens still available before the rate limit is exhausted.
pub const X_RATELIMIT_REMAINING_TOKENS: &str = "x-ratelimit-remaining-tokens";
/// Response header: time remaining until the request-count rate limit resets.
pub const X_RATELIMIT_RESET_REQUESTS: &str = "x-ratelimit-reset-requests";
/// Response header: time remaining until the token-count rate limit resets.
pub const X_RATELIMIT_RESET_TOKENS: &str = "x-ratelimit-reset-tokens";

/// All rate-limit header names, in the order `log_rate_limit_headers` reports them.
const RATE_LIMIT_HEADERS: [&str; 6] = [
    X_RATELIMIT_LIMIT_REQUESTS,
    X_RATELIMIT_LIMIT_TOKENS,
    X_RATELIMIT_REMAINING_REQUESTS,
    X_RATELIMIT_REMAINING_TOKENS,
    X_RATELIMIT_RESET_REQUESTS,
    X_RATELIMIT_RESET_TOKENS,
];
32
33fn log_rate_limit_headers(headers: &HeaderMap) {
34    for header in RATE_LIMIT_HEADERS {
35        if let Some(value) = headers.get(header).and_then(|value| value.to_str().ok()) {
36            tracing::warn!("rate-limit: {header} = {value}");
37        }
38    }
39    // Also log the Retry-After header if present
40    if let Some(value) = headers
41        .get(reqwest::header::RETRY_AFTER)
42        .and_then(|value| value.to_str().ok())
43    {
44        tracing::warn!("retry-after={value}");
45    }
46}
47
48/// Return whether [SimpleRetryPolicy] should retry this result.
49///
50/// It retries only:
51///
52/// - HTTP `429 Too Many Requests`, because the server explicitly rate limited
53///   the request.
54/// - HTTP `5xx` server errors, because the server did not successfully process
55///   the request.
56/// - Native reqwest connect errors
57#[allow(unused_variables)]
58pub fn should_retry(result: &Result<Response, OpenAIError>) -> bool {
59    match result {
60        Ok(response) => response.status().as_u16() == 429 || response.status().is_server_error(),
61        #[cfg(not(target_family = "wasm"))]
62        Err(OpenAIError::Reqwest(error)) => error.is_connect(),
63        #[cfg(target_family = "wasm")]
64        Err(OpenAIError::Reqwest(_)) => false,
65        _ => false,
66    }
67}
68
/// Retry policy for OpenAI compatible APIs, implementing
/// [`tower::retry::Policy`].
///
/// Rate limits, server errors, and native connect errors are retried. The
/// policy can be passed straight to [`tower::ServiceBuilder::retry`] around
/// [`crate::middleware::ReqwestService`], or around any compatible tower
/// service whose request type is [`crate::middleware::HttpRequestFactory`].
///
/// The [`Default`] policy allows three retry attempts.
#[derive(Clone, Debug)]
pub struct SimpleRetryPolicy {
    /// Maximum number of retries allowed after the initial request.
    max_retries: usize,
    /// Retries consumed so far by this policy instance.
    attempts: usize,
    /// Exponent used for the next exponential-backoff delay.
    backoff_attempt: u32,
}

impl SimpleRetryPolicy {
    /// Retry budget used by the [`Default`] implementation.
    const DEFAULT_MAX_RETRIES: usize = 3;

    /// Build a policy permitting up to `max_retries` retries.
    ///
    /// `max_retries` counts the extra attempts made after the initial
    /// request; it is not the total number of requests.
    pub fn new(max_retries: usize) -> Self {
        // A fresh policy has consumed no retries and starts backoff at zero.
        Self {
            attempts: 0,
            backoff_attempt: 0,
            max_retries,
        }
    }

    /// The retry budget this policy was configured with.
    pub fn max_retries(&self) -> usize {
        self.max_retries
    }

    /// How many retries this policy instance has already used.
    pub fn attempts(&self) -> usize {
        self.attempts
    }
}

impl Default for SimpleRetryPolicy {
    /// Equivalent to [`SimpleRetryPolicy::new`] with a budget of three retries.
    fn default() -> Self {
        Self::new(Self::DEFAULT_MAX_RETRIES)
    }
}
113
114impl tower::retry::Policy<HttpRequestFactory, Response, OpenAIError> for SimpleRetryPolicy {
115    #[cfg(not(target_family = "wasm"))]
116    type Future = Pin<Box<dyn Future<Output = ()> + Send + 'static>>;
117    #[cfg(target_family = "wasm")]
118    type Future = Pin<Box<dyn Future<Output = ()> + 'static>>;
119
120    fn retry(
121        &mut self,
122        _req: &mut HttpRequestFactory,
123        result: &mut Result<Response, OpenAIError>,
124    ) -> Option<Self::Future> {
125        if self.attempts >= self.max_retries || !should_retry(result) {
126            return None;
127        }
128
129        if let Ok(response) = result.as_ref() {
130            log_rate_limit_headers(response.headers());
131        }
132
133        let retry_after = result
134            .as_ref()
135            .ok()
136            .and_then(|response| response.headers().get(reqwest::header::RETRY_AFTER))
137            .and_then(|value| value.to_str().ok())
138            .and_then(|value| value.parse::<u64>().ok())
139            .map(std::time::Duration::from_secs);
140
141        let delay = retry_after.unwrap_or_else(|| {
142            let delay = std::time::Duration::from_millis(100)
143                .saturating_mul(2_u32.saturating_pow(self.backoff_attempt));
144            self.backoff_attempt = self.backoff_attempt.saturating_add(1);
145            delay.min(std::time::Duration::from_secs(8))
146        });
147
148        self.attempts += 1;
149
150        #[cfg(target_family = "wasm")]
151        {
152            let _ = delay;
153            return Some(Box::pin(std::future::ready(())));
154        }
155
156        #[cfg(not(target_family = "wasm"))]
157        Some(Box::pin(tokio::time::sleep(delay)))
158    }
159
160    fn clone_request(&mut self, req: &HttpRequestFactory) -> Option<HttpRequestFactory> {
161        Some(req.clone())
162    }
163}