mod openai;
pub use openai::{OpenAIRetry, OpenAIRetryLayer};
use std::{future::Future, pin::Pin};
use reqwest::{header::HeaderMap, Response};
use crate::{error::OpenAIError, executor::HttpRequestFactory};
// Names of the rate-limit response headers this module looks for.
// NOTE(review): the per-header meanings below are inferred from the header names
// (limit / remaining / reset, for requests and tokens) — confirm against the API's
// rate-limit documentation before relying on them.

/// Header name `x-ratelimit-limit-requests` (presumably the request quota).
pub const X_RATELIMIT_LIMIT_REQUESTS: &str = "x-ratelimit-limit-requests";
/// Header name `x-ratelimit-limit-tokens` (presumably the token quota).
pub const X_RATELIMIT_LIMIT_TOKENS: &str = "x-ratelimit-limit-tokens";
/// Header name `x-ratelimit-remaining-requests` (presumably requests left in the quota).
pub const X_RATELIMIT_REMAINING_REQUESTS: &str = "x-ratelimit-remaining-requests";
/// Header name `x-ratelimit-remaining-tokens` (presumably tokens left in the quota).
pub const X_RATELIMIT_REMAINING_TOKENS: &str = "x-ratelimit-remaining-tokens";
/// Header name `x-ratelimit-reset-requests` (presumably time until the request quota resets).
pub const X_RATELIMIT_RESET_REQUESTS: &str = "x-ratelimit-reset-requests";
/// Header name `x-ratelimit-reset-tokens` (presumably time until the token quota resets).
pub const X_RATELIMIT_RESET_TOKENS: &str = "x-ratelimit-reset-tokens";
/// All six rate-limit header names, in the order `log_rate_limit_headers` checks and logs them.
const RATE_LIMIT_HEADERS: [&str; 6] = [
X_RATELIMIT_LIMIT_REQUESTS,
X_RATELIMIT_LIMIT_TOKENS,
X_RATELIMIT_REMAINING_REQUESTS,
X_RATELIMIT_REMAINING_TOKENS,
X_RATELIMIT_RESET_REQUESTS,
X_RATELIMIT_RESET_TOKENS,
];
/// Emits a `tracing` warning for every rate-limit header present in `headers`, followed by
/// one for `Retry-After` if it is present.
///
/// Headers that are missing, or whose value is not valid visible ASCII (`to_str` fails), are
/// skipped silently.
fn log_rate_limit_headers(headers: &HeaderMap) {
    // Lazily pair each known header name with its textual value, dropping absent ones.
    let present = RATE_LIMIT_HEADERS.iter().copied().filter_map(|name| {
        let text = headers.get(name)?.to_str().ok()?;
        Some((name, text))
    });
    for (header, value) in present {
        tracing::warn!("rate-limit: {header} = {value}");
    }

    let retry_after = headers
        .get(reqwest::header::RETRY_AFTER)
        .and_then(|raw| raw.to_str().ok());
    if let Some(value) = retry_after {
        tracing::warn!("retry-after={value}");
    }
}
/// Decides whether the outcome of a request is worth retrying.
///
/// Returns `true` when:
/// - the response status is 429 or any 5xx server error, or
/// - (non-wasm targets only) the request failed with a `reqwest` error for which
///   `is_connect()` is true.
///
/// Every other outcome — including all `reqwest` errors on wasm targets and every other
/// `OpenAIError` variant — returns `false`.
// NOTE(review): no binding below looks unused under either cfg, so this allow may be
// removable; kept to preserve the existing lint configuration.
#[allow(unused_variables)]
pub fn should_retry(result: &Result<Response, OpenAIError>) -> bool {
    match result {
        Ok(response) => {
            let status = response.status();
            status.is_server_error() || status.as_u16() == 429
        }
        #[cfg(not(target_family = "wasm"))]
        Err(OpenAIError::Reqwest(error)) => error.is_connect(),
        #[cfg(target_family = "wasm")]
        Err(OpenAIError::Reqwest(_)) => false,
        Err(_) => false,
    }
}
/// Retry policy with a fixed retry budget and exponential-backoff bookkeeping.
///
/// It is driven through its `tower::retry::Policy` implementation; this type only stores the
/// counters and exposes read-only accessors for them.
#[derive(Clone, Debug)]
pub struct SimpleRetryPolicy {
    /// Maximum number of retries the policy will schedule.
    max_retries: usize,
    /// Number of retries scheduled so far.
    attempts: usize,
    /// Number of exponential-backoff delays computed so far; used as the power-of-two
    /// exponent for the next one.
    backoff_attempt: u32,
}

impl SimpleRetryPolicy {
    /// Retry budget used by [`Default`].
    const DEFAULT_MAX_RETRIES: usize = 3;

    /// Creates a policy that allows at most `max_retries` retries, with all counters at zero.
    pub fn new(max_retries: usize) -> Self {
        SimpleRetryPolicy {
            attempts: 0,
            backoff_attempt: 0,
            max_retries,
        }
    }

    /// Returns the maximum number of retries this policy allows.
    pub fn max_retries(&self) -> usize {
        self.max_retries
    }

    /// Returns how many retries have been scheduled so far.
    pub fn attempts(&self) -> usize {
        self.attempts
    }
}

impl Default for SimpleRetryPolicy {
    /// Builds a policy with a budget of three retries.
    fn default() -> Self {
        SimpleRetryPolicy::new(Self::DEFAULT_MAX_RETRIES)
    }
}
impl tower::retry::Policy<HttpRequestFactory, Response, OpenAIError> for SimpleRetryPolicy {
#[cfg(not(target_family = "wasm"))]
type Future = Pin<Box<dyn Future<Output = ()> + Send + 'static>>;
#[cfg(target_family = "wasm")]
type Future = Pin<Box<dyn Future<Output = ()> + 'static>>;
fn retry(
&mut self,
_req: &mut HttpRequestFactory,
result: &mut Result<Response, OpenAIError>,
) -> Option<Self::Future> {
if self.attempts >= self.max_retries || !should_retry(result) {
return None;
}
if let Ok(response) = result.as_ref() {
log_rate_limit_headers(response.headers());
}
let retry_after = result
.as_ref()
.ok()
.and_then(|response| response.headers().get(reqwest::header::RETRY_AFTER))
.and_then(|value| value.to_str().ok())
.and_then(|value| value.parse::<u64>().ok())
.map(std::time::Duration::from_secs);
let delay = retry_after.unwrap_or_else(|| {
let delay = std::time::Duration::from_millis(100)
.saturating_mul(2_u32.saturating_pow(self.backoff_attempt));
self.backoff_attempt = self.backoff_attempt.saturating_add(1);
delay.min(std::time::Duration::from_secs(8))
});
self.attempts += 1;
#[cfg(target_family = "wasm")]
{
let _ = delay;
return Some(Box::pin(std::future::ready(())));
}
#[cfg(not(target_family = "wasm"))]
Some(Box::pin(tokio::time::sleep(delay)))
}
fn clone_request(&mut self, req: &HttpRequestFactory) -> Option<HttpRequestFactory> {
Some(req.clone())
}
}