use crate::{
AnthropicRequest, AnthropicResponse, AnthropicStreamEvent, AudioSpeechRequest, BoxStream,
ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse, EmbeddingRequest,
EmbeddingResponse, Error, GeminiRequest, GeminiResponse, ImageRequest, MultipartField,
Provider,
};
use rand::Rng;
use std::{future::Future, time::Duration};
/// Retry budget installed by `Retrying::new`: one initial attempt plus up to
/// this many retries.
const DEFAULT_MAX_RETRIES: u32 = 2;
/// Per-attempt deadline installed by `Retrying::new`. A zero timeout makes
/// `timed` await the future without any deadline.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
/// Longest `retry_after` hint that is still honoured; an error asking for a
/// longer wait is treated as non-retryable by `should_retry`.
const DEFAULT_MAX_RETRY_AFTER: Duration = Duration::from_secs(60);
/// Backoff used before the first retry when the error carries no
/// `retry_after` hint; it is jittered and doubled after every retry.
const INITIAL_BACKOFF: Duration = Duration::from_millis(100);
/// A [`Provider`] decorator that wraps another provider and applies a
/// per-attempt timeout plus retry-with-backoff to the calls it retries.
#[derive(Debug, Clone)]
pub struct Retrying<P: Provider> {
/// The wrapped provider every call is ultimately forwarded to.
inner: P,
/// Number of retries allowed after the initial attempt.
max_retries: u32,
/// Deadline applied to each individual attempt; zero disables it.
timeout: Duration,
/// Upper bound on an error's `retry_after` hint; larger hints are not retried.
max_retry_after: Duration,
}
impl<P: Provider> Retrying<P> {
pub fn new(inner: P) -> Self {
Self {
inner,
max_retries: DEFAULT_MAX_RETRIES,
timeout: DEFAULT_TIMEOUT,
max_retry_after: DEFAULT_MAX_RETRY_AFTER,
}
}
pub fn max_retry_after(mut self, d: Duration) -> Self {
self.max_retry_after = d;
self
}
fn should_retry(&self, e: &Error) -> bool {
if !e.is_transient() {
return false;
}
!matches!(e.retry_after(), Some(ra) if ra > self.max_retry_after)
}
async fn timed<T>(
&self,
fut: impl Future<Output = Result<T, Error>> + Send,
) -> Result<T, Error> {
if self.timeout.is_zero() {
return fut.await;
}
let Ok(result) = tokio::time::timeout(self.timeout, fut).await else {
return Err(Error::Timeout);
};
result
}
}
impl<P: Provider> Provider for Retrying<P> {
async fn chat_completion(
&self,
request: &ChatCompletionRequest,
) -> Result<ChatCompletionResponse, Error> {
let mut backoff = INITIAL_BACKOFF;
let mut last_err = None;
for _ in 0..=self.max_retries {
match self.timed(self.inner.chat_completion(request)).await {
Ok(resp) => return Ok(resp),
Err(e) if self.should_retry(&e) => {
let sleep = e.retry_after().unwrap_or_else(|| jittered(backoff));
last_err = Some(e);
tokio::time::sleep(sleep).await;
backoff *= 2;
}
Err(e) => return Err(e),
}
}
Err(last_err.expect("retry loop exited without producing an error"))
}
async fn chat_completion_stream(
&self,
request: &ChatCompletionRequest,
) -> Result<BoxStream<'static, Result<ChatCompletionChunk, Error>>, Error> {
let mut backoff = INITIAL_BACKOFF;
let mut last_err = None;
for _ in 0..=self.max_retries {
match self.timed(self.inner.chat_completion_stream(request)).await {
Ok(stream) => return Ok(stream),
Err(e) if self.should_retry(&e) => {
let sleep = e.retry_after().unwrap_or_else(|| jittered(backoff));
last_err = Some(e);
tokio::time::sleep(sleep).await;
backoff *= 2;
}
Err(e) => return Err(e),
}
}
Err(last_err.expect("retry loop exited without producing an error"))
}
async fn anthropic_messages(
&self,
request: &AnthropicRequest,
) -> Result<AnthropicResponse, Error> {
let mut backoff = INITIAL_BACKOFF;
let mut last_err = None;
for _ in 0..=self.max_retries {
match self.timed(self.inner.anthropic_messages(request)).await {
Ok(resp) => return Ok(resp),
Err(e) if self.should_retry(&e) => {
let sleep = e.retry_after().unwrap_or_else(|| jittered(backoff));
last_err = Some(e);
tokio::time::sleep(sleep).await;
backoff *= 2;
}
Err(e) => return Err(e),
}
}
Err(last_err.expect("retry loop exited without producing an error"))
}
async fn anthropic_messages_stream(
&self,
request: &AnthropicRequest,
) -> Result<BoxStream<'static, Result<AnthropicStreamEvent, Error>>, Error> {
let mut backoff = INITIAL_BACKOFF;
let mut last_err = None;
for _ in 0..=self.max_retries {
match self
.timed(self.inner.anthropic_messages_stream(request))
.await
{
Ok(stream) => return Ok(stream),
Err(e) if self.should_retry(&e) => {
let sleep = e.retry_after().unwrap_or_else(|| jittered(backoff));
last_err = Some(e);
tokio::time::sleep(sleep).await;
backoff *= 2;
}
Err(e) => return Err(e),
}
}
Err(last_err.expect("retry loop exited without producing an error"))
}
async fn gemini_generate_content_stream(
&self,
model: &str,
request: &GeminiRequest,
) -> Result<BoxStream<'static, Result<GeminiResponse, Error>>, Error> {
let mut backoff = INITIAL_BACKOFF;
let mut last_err = None;
for _ in 0..=self.max_retries {
match self
.timed(self.inner.gemini_generate_content_stream(model, request))
.await
{
Ok(stream) => return Ok(stream),
Err(e) if self.should_retry(&e) => {
let sleep = e.retry_after().unwrap_or_else(|| jittered(backoff));
last_err = Some(e);
tokio::time::sleep(sleep).await;
backoff *= 2;
}
Err(e) => return Err(e),
}
}
Err(last_err.expect("retry loop exited without producing an error"))
}
async fn embedding(&self, request: &EmbeddingRequest) -> Result<EmbeddingResponse, Error> {
self.inner.embedding(request).await
}
async fn image_generation(
&self,
request: &ImageRequest,
) -> Result<(bytes::Bytes, String), Error> {
self.inner.image_generation(request).await
}
async fn audio_speech(
&self,
request: &AudioSpeechRequest,
) -> Result<(bytes::Bytes, String), Error> {
self.inner.audio_speech(request).await
}
async fn audio_transcription(
&self,
model: &str,
fields: &[MultipartField],
) -> Result<(bytes::Bytes, String), Error> {
self.inner.audio_transcription(model, fields).await
}
}
/// Returns a random delay between half of `backoff` and `backoff`
/// (inclusive), measured in whole milliseconds.
///
/// A `backoff` shorter than one millisecond is returned unchanged, since the
/// millisecond range would be empty. Durations whose millisecond count does
/// not fit in a `u64` are clamped to `u64::MAX` milliseconds instead of being
/// silently truncated.
fn jittered(backoff: Duration) -> Duration {
    let hi = u64::try_from(backoff.as_millis()).unwrap_or(u64::MAX);
    // `hi / 2 >= hi` holds only for `hi == 0`, i.e. a sub-millisecond backoff.
    if hi == 0 {
        return backoff;
    }
    let lo = hi / 2;
    Duration::from_millis(rand::rng().random_range(lo..=hi))
}