use std::{future::Future, pin::Pin, time::Duration};
use reqwest::{header::HeaderMap, Response};
use crate::{
error::{OpenAIError, WrappedError},
executor::HttpRequestFactory,
};
use super::log_rate_limit_headers;
/// API error `type` value that makes a 429 response fail immediately instead of being retried.
const INSUFFICIENT_QUOTA: &str = "insufficient_quota";
/// Boxed future used as the retry service's `Future` type on non-wasm targets; it must be `Send`.
#[cfg(not(target_family = "wasm"))]
type RetryFuture = Pin<Box<dyn Future<Output = Result<Response, OpenAIError>> + Send + 'static>>;
/// Boxed future used as the retry service's `Future` type on wasm targets; `Send` is not required.
#[cfg(target_family = "wasm")]
type RetryFuture = Pin<Box<dyn Future<Output = Result<Response, OpenAIError>> + 'static>>;
/// Layer configuration holding the retry budget handed to each wrapped service.
///
/// The default budget is 3 retries.
#[derive(Clone)]
pub struct OpenAIRetryLayer {
    /// Maximum number of retries copied into every service this layer produces.
    max_retries: usize,
}

impl std::fmt::Debug for OpenAIRetryLayer {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { max_retries } = self;
        formatter
            .debug_struct("OpenAIRetryLayer")
            .field("max_retries", max_retries)
            .finish_non_exhaustive()
    }
}

impl OpenAIRetryLayer {
    /// Creates a layer whose services retry at most `max_retries` times.
    pub fn new(max_retries: usize) -> Self {
        OpenAIRetryLayer { max_retries }
    }

    /// Returns the configured maximum number of retries.
    pub fn max_retries(&self) -> usize {
        self.max_retries
    }
}

impl Default for OpenAIRetryLayer {
    /// Builds a layer with a budget of 3 retries.
    fn default() -> Self {
        Self { max_retries: 3 }
    }
}
impl<S> tower::Layer<S> for OpenAIRetryLayer {
    type Service = OpenAIRetry<S>;

    /// Wraps `inner` in an [`OpenAIRetry`] carrying this layer's retry budget.
    fn layer(&self, inner: S) -> Self::Service {
        let max_retries = self.max_retries;
        OpenAIRetry { inner, max_retries }
    }
}
/// Service wrapper that pairs an inner service with a retry budget.
#[derive(Clone)]
pub struct OpenAIRetry<S> {
    /// The wrapped service that requests are dispatched to.
    inner: S,
    /// Maximum number of retries allowed per request.
    max_retries: usize,
}

impl<S: std::fmt::Debug> std::fmt::Debug for OpenAIRetry<S> {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { inner, max_retries } = self;
        formatter
            .debug_struct("OpenAIRetry")
            .field("inner", inner)
            .field("max_retries", max_retries)
            .finish_non_exhaustive()
    }
}
/// Retry-aware [`tower::Service`] implementation for non-wasm targets.
///
/// The inner service and its future must be `Send` because the returned future is boxed as a
/// `Send` trait object.
#[cfg(not(target_family = "wasm"))]
impl<S> tower::Service<HttpRequestFactory> for OpenAIRetry<S>
where
    S: tower::Service<HttpRequestFactory, Response = Response, Error = OpenAIError>
        + Clone
        + Send
        + 'static,
    S::Future: Send + 'static,
{
    type Response = Response;
    type Error = OpenAIError;
    type Future = RetryFuture;

    /// Readiness is delegated directly to the inner service.
    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Starts the first attempt and returns a boxed future that retries it as needed.
    fn call(&mut self, request: HttpRequestFactory) -> Self::Future {
        let max_retries = self.max_retries;
        // Leave a fresh clone in `self` and take the instance that was stored there (the one
        // `poll_ready` was driven on) for this request.
        let standby = self.inner.clone();
        let mut ready = std::mem::replace(&mut self.inner, standby);
        // The request is cloned so the original stays available for later attempts.
        let first_attempt = ready.call(request.clone());
        Box::pin(retry_request(ready, first_attempt, request, max_retries))
    }
}
/// Retry-aware [`tower::Service`] implementation for wasm targets.
///
/// Unlike the non-wasm implementation, neither the inner service nor its future has to be `Send`.
#[cfg(target_family = "wasm")]
impl<S> tower::Service<HttpRequestFactory> for OpenAIRetry<S>
where
    S: tower::Service<HttpRequestFactory, Response = Response, Error = OpenAIError>
        + Clone
        + 'static,
    S::Future: 'static,
{
    type Response = Response;
    type Error = OpenAIError;
    type Future = RetryFuture;

    /// Readiness is delegated directly to the inner service.
    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Starts the first attempt and returns a boxed future that retries it as needed.
    fn call(&mut self, request: HttpRequestFactory) -> Self::Future {
        let max_retries = self.max_retries;
        // Leave a fresh clone in `self` and take the instance that was stored there (the one
        // `poll_ready` was driven on) for this request.
        let standby = self.inner.clone();
        let mut ready = std::mem::replace(&mut self.inner, standby);
        // The request is cloned so the original stays available for later attempts.
        let first_attempt = ready.call(request.clone());
        Box::pin(retry_request(ready, first_attempt, request, max_retries))
    }
}
async fn retry_request<S>(
mut service: S,
first_attempt: S::Future,
request: HttpRequestFactory,
max_retries: usize,
) -> Result<Response, OpenAIError>
where
S: tower::Service<HttpRequestFactory, Response = Response, Error = OpenAIError>,
{
use tower::ServiceExt;
let mut attempts = 0;
let mut backoff_attempt = 0;
let mut result = first_attempt.await;
loop {
let (final_result, headers, retry_after) = match result {
Ok(response) if response.status().is_success() => return Ok(response),
Ok(response) if response.status().as_u16() == 429 => {
let headers = response.headers().clone();
let retry_after = retry_after(&headers);
let bytes = match response.bytes().await {
Ok(bytes) => bytes,
Err(error) => return Err(OpenAIError::Reqwest(error)),
};
let error = match serde_json::from_slice::<WrappedError>(&bytes) {
Ok(wrapped_error) => {
if wrapped_error.error.r#type.as_deref() == Some(INSUFFICIENT_QUOTA) {
return Err(OpenAIError::ApiError(wrapped_error.error));
}
OpenAIError::ApiError(wrapped_error.error)
}
Err(error) => {
return Err(OpenAIError::JSONDeserialize(
error,
String::from_utf8_lossy(&bytes).into_owned(),
));
}
};
(Err(error), Some(headers), retry_after)
}
Ok(response) if response.status().is_server_error() => {
let retry_after = retry_after(response.headers());
(Ok(response), None, retry_after)
}
Ok(response) => return Ok(response),
Err(error) if is_connection_error(&error) => (Err(error), None, None),
Err(error) => return Err(error),
};
if attempts >= max_retries {
return final_result;
}
if let Some(headers) = headers.as_ref() {
log_rate_limit_headers(headers);
}
let delay = retry_after.unwrap_or_else(|| {
let delay =
Duration::from_millis(100).saturating_mul(2_u32.saturating_pow(backoff_attempt));
backoff_attempt = backoff_attempt.saturating_add(1);
delay.min(Duration::from_secs(8))
});
attempts += 1;
#[cfg(not(target_family = "wasm"))]
tokio::time::sleep(delay).await;
#[cfg(target_family = "wasm")]
let _ = delay;
result = service.ready().await?.call(request.clone()).await;
}
}
fn is_connection_error(error: &OpenAIError) -> bool {
match error {
#[cfg(not(target_family = "wasm"))]
OpenAIError::Reqwest(error) => error.is_connect(),
#[cfg(target_family = "wasm")]
OpenAIError::Reqwest(_) => false,
_ => false,
}
}
/// Reads the `Retry-After` header as a whole number of seconds.
///
/// Returns `None` when the header is absent, is not valid visible ASCII, or does not parse as a
/// `u64`.
fn retry_after(headers: &HeaderMap) -> Option<Duration> {
    let raw = headers.get(reqwest::header::RETRY_AFTER)?;
    let text = raw.to_str().ok()?;
    let seconds: u64 = text.parse().ok()?;
    Some(Duration::from_secs(seconds))
}