async_openai/middleware/retry/
mod.rs1mod openai;
2
3pub use openai::{OpenAIRetry, OpenAIRetryLayer};
4
5use std::{future::Future, pin::Pin};
6
7use reqwest::{header::HeaderMap, Response};
8
9use crate::{error::OpenAIError, executor::HttpRequestFactory};
10
// Names of the `x-ratelimit-*` response headers this module looks at.
// The `*-requests` and `*-tokens` constants come in limit / remaining / reset pairs.

/// Header name `x-ratelimit-limit-requests`.
pub const X_RATELIMIT_LIMIT_REQUESTS: &str = "x-ratelimit-limit-requests";
/// Header name `x-ratelimit-limit-tokens`.
pub const X_RATELIMIT_LIMIT_TOKENS: &str = "x-ratelimit-limit-tokens";

/// Header name `x-ratelimit-remaining-requests`.
pub const X_RATELIMIT_REMAINING_REQUESTS: &str = "x-ratelimit-remaining-requests";
/// Header name `x-ratelimit-remaining-tokens`.
pub const X_RATELIMIT_REMAINING_TOKENS: &str = "x-ratelimit-remaining-tokens";

/// Header name `x-ratelimit-reset-requests`.
pub const X_RATELIMIT_RESET_REQUESTS: &str = "x-ratelimit-reset-requests";
/// Header name `x-ratelimit-reset-tokens`.
pub const X_RATELIMIT_RESET_TOKENS: &str = "x-ratelimit-reset-tokens";

/// All six `x-ratelimit-*` header names, in the order they are written to the
/// log when a retryable response is inspected.
const RATE_LIMIT_HEADERS: [&str; 6] = [
    X_RATELIMIT_LIMIT_REQUESTS,
    X_RATELIMIT_LIMIT_TOKENS,
    X_RATELIMIT_REMAINING_REQUESTS,
    X_RATELIMIT_REMAINING_TOKENS,
    X_RATELIMIT_RESET_REQUESTS,
    X_RATELIMIT_RESET_TOKENS,
];
32
33fn log_rate_limit_headers(headers: &HeaderMap) {
34 for header in RATE_LIMIT_HEADERS {
35 if let Some(value) = headers.get(header).and_then(|value| value.to_str().ok()) {
36 tracing::warn!("rate-limit: {header} = {value}");
37 }
38 }
39 if let Some(value) = headers
41 .get(reqwest::header::RETRY_AFTER)
42 .and_then(|value| value.to_str().ok())
43 {
44 tracing::warn!("retry-after={value}");
45 }
46}
47
48#[allow(unused_variables)]
58pub fn should_retry(result: &Result<Response, OpenAIError>) -> bool {
59 match result {
60 Ok(response) => response.status().as_u16() == 429 || response.status().is_server_error(),
61 #[cfg(not(target_family = "wasm"))]
62 Err(OpenAIError::Reqwest(error)) => error.is_connect(),
63 #[cfg(target_family = "wasm")]
64 Err(OpenAIError::Reqwest(_)) => false,
65 _ => false,
66 }
67}
68
/// Retry policy with a fixed retry budget and a per-policy backoff counter.
#[derive(Clone, Debug)]
pub struct SimpleRetryPolicy {
    /// Upper bound on the number of retries this policy will grant.
    max_retries: usize,
    /// Number of retries granted so far.
    attempts: usize,
    /// Exponent of the next exponential-backoff delay; it only advances when
    /// a backoff delay (rather than a server-provided `Retry-After`) is used.
    backoff_attempt: u32,
}

impl Default for SimpleRetryPolicy {
    /// Builds a policy with the default retry budget of 3.
    fn default() -> Self {
        Self::new(Self::DEFAULT_MAX_RETRIES)
    }
}

impl SimpleRetryPolicy {
    /// Retry budget used by [`Default::default`].
    const DEFAULT_MAX_RETRIES: usize = 3;

    /// Creates a policy that grants at most `max_retries` retries, with both
    /// internal counters starting at zero.
    pub fn new(max_retries: usize) -> Self {
        Self {
            max_retries,
            attempts: 0,
            backoff_attempt: 0,
        }
    }

    /// Returns the maximum number of retries this policy allows.
    pub fn max_retries(&self) -> usize {
        self.max_retries
    }

    /// Returns how many retries this policy has granted so far.
    pub fn attempts(&self) -> usize {
        self.attempts
    }
}
113
114impl tower::retry::Policy<HttpRequestFactory, Response, OpenAIError> for SimpleRetryPolicy {
115 #[cfg(not(target_family = "wasm"))]
116 type Future = Pin<Box<dyn Future<Output = ()> + Send + 'static>>;
117 #[cfg(target_family = "wasm")]
118 type Future = Pin<Box<dyn Future<Output = ()> + 'static>>;
119
120 fn retry(
121 &mut self,
122 _req: &mut HttpRequestFactory,
123 result: &mut Result<Response, OpenAIError>,
124 ) -> Option<Self::Future> {
125 if self.attempts >= self.max_retries || !should_retry(result) {
126 return None;
127 }
128
129 if let Ok(response) = result.as_ref() {
130 log_rate_limit_headers(response.headers());
131 }
132
133 let retry_after = result
134 .as_ref()
135 .ok()
136 .and_then(|response| response.headers().get(reqwest::header::RETRY_AFTER))
137 .and_then(|value| value.to_str().ok())
138 .and_then(|value| value.parse::<u64>().ok())
139 .map(std::time::Duration::from_secs);
140
141 let delay = retry_after.unwrap_or_else(|| {
142 let delay = std::time::Duration::from_millis(100)
143 .saturating_mul(2_u32.saturating_pow(self.backoff_attempt));
144 self.backoff_attempt = self.backoff_attempt.saturating_add(1);
145 delay.min(std::time::Duration::from_secs(8))
146 });
147
148 self.attempts += 1;
149
150 #[cfg(target_family = "wasm")]
151 {
152 let _ = delay;
153 return Some(Box::pin(std::future::ready(())));
154 }
155
156 #[cfg(not(target_family = "wasm"))]
157 Some(Box::pin(tokio::time::sleep(delay)))
158 }
159
160 fn clone_request(&mut self, req: &HttpRequestFactory) -> Option<HttpRequestFactory> {
161 Some(req.clone())
162 }
163}