use crabllm_core::{
AudioSpeechRequest, BoxStream, ChatCompletionChunk, ChatCompletionRequest,
ChatCompletionResponse, EmbeddingRequest, EmbeddingResponse, Error, ImageRequest,
MultipartField, Provider,
};
use rand::Rng;
use std::time::Duration;
/// Number of extra attempts `Retrying::new` allows after a transient
/// `chat_completion` failure (so up to three attempts in total).
const DEFAULT_MAX_RETRIES: u32 = 2;
/// Per-attempt time limit installed by `Retrying::new` (thirty seconds).
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Base delay before the first retry; it doubles after every further
/// transient failure and has random jitter applied on top.
const INITIAL_BACKOFF: Duration = Duration::from_micros(100_000);
/// A [`Provider`] decorator that time-limits chat-completion calls and
/// retries transient `chat_completion` failures with jittered exponential
/// backoff. Other endpoints are forwarded to the wrapped provider as-is.
///
/// The `Provider` bound is placed on the `impl` blocks that need it rather
/// than on the type, so the struct itself imposes no requirements on `P`.
#[derive(Debug, Clone)]
pub struct Retrying<P> {
    /// The provider that actually serves the requests.
    inner: P,
    /// Extra attempts allowed after the first transient failure; the total
    /// number of `chat_completion` attempts is `max_retries + 1`.
    max_retries: u32,
    /// Limit applied to each chat-completion attempt; `Duration::ZERO`
    /// disables the limit.
    timeout: Duration,
}
impl<P: Provider> Retrying<P> {
pub fn new(inner: P) -> Self {
Self {
inner,
max_retries: DEFAULT_MAX_RETRIES,
timeout: DEFAULT_TIMEOUT,
}
}
}
impl<P: Provider> Provider for Retrying<P> {
async fn chat_completion(
&self,
request: &ChatCompletionRequest,
) -> Result<ChatCompletionResponse, Error> {
let mut backoff = INITIAL_BACKOFF;
let mut last_err = None;
for _ in 0..=self.max_retries {
let result = if self.timeout.is_zero() {
self.inner.chat_completion(request).await
} else {
match tokio::time::timeout(self.timeout, self.inner.chat_completion(request)).await
{
Ok(r) => r,
Err(_) => Err(Error::Timeout),
}
};
match result {
Ok(resp) => return Ok(resp),
Err(e) if e.is_transient() => {
last_err = Some(e);
tokio::time::sleep(jittered(backoff)).await;
backoff *= 2;
}
Err(e) => return Err(e),
}
}
Err(last_err.expect("retry loop exited without producing an error"))
}
async fn chat_completion_stream(
&self,
request: &ChatCompletionRequest,
) -> Result<BoxStream<'static, Result<ChatCompletionChunk, Error>>, Error> {
if self.timeout.is_zero() {
self.inner.chat_completion_stream(request).await
} else {
match tokio::time::timeout(self.timeout, self.inner.chat_completion_stream(request))
.await
{
Ok(r) => r,
Err(_) => Err(Error::Timeout),
}
}
}
async fn embedding(&self, request: &EmbeddingRequest) -> Result<EmbeddingResponse, Error> {
self.inner.embedding(request).await
}
async fn image_generation(
&self,
request: &ImageRequest,
) -> Result<(bytes::Bytes, String), Error> {
self.inner.image_generation(request).await
}
async fn audio_speech(
&self,
request: &AudioSpeechRequest,
) -> Result<(bytes::Bytes, String), Error> {
self.inner.audio_speech(request).await
}
async fn audio_transcription(
&self,
model: &str,
fields: &[MultipartField],
) -> Result<(bytes::Bytes, String), Error> {
self.inner.audio_transcription(model, fields).await
}
}
/// Applies random jitter to a retry backoff.
///
/// Returns a uniformly random whole number of milliseconds in
/// `[backoff / 2, backoff]`, computed at millisecond granularity. When
/// `backoff` is under one millisecond, that range is empty and `backoff` is
/// returned unchanged.
fn jittered(backoff: Duration) -> Duration {
    // Saturate rather than silently truncating the `u128` millisecond count,
    // so enormous backoffs still produce a sensible range.
    let hi = u64::try_from(backoff.as_millis()).unwrap_or(u64::MAX);
    let lo = hi / 2;
    if lo >= hi {
        return backoff;
    }
    Duration::from_millis(rand::rng().random_range(lo..=hi))
}