use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use llmkit_core::{
ChatRequest, ChatResponse, ChatStream, CostEstimate, EmbedRequest, EmbedResponse, LlmError,
LlmProvider, LlmResult,
};
use crate::layer::LlmLayer;
/// Retry policy: how many times the wrapped provider may be called for one
/// request and how long to pause between calls.
///
/// Build one with [`RetryLayer::exponential`] and optionally adjust the delay
/// ceiling with [`RetryLayer::max_backoff`].
#[derive(Debug, Clone, Copy)]
pub struct RetryLayer {
    /// Maximum number of calls per request, counting the very first one.
    attempts: u32,
    /// Pause before the first retry; it is doubled for every later retry.
    base: Duration,
    /// Ceiling applied to every pause, whatever its origin.
    max_backoff: Duration,
}
impl RetryLayer {
    /// Creates a policy that allows up to `attempts` calls per request and
    /// starts backing off at `base`.
    ///
    /// An `attempts` of zero is raised to one so that the wrapped provider is
    /// always called at least once. The delay ceiling starts at 30 seconds;
    /// change it with [`RetryLayer::max_backoff`].
    pub fn exponential(attempts: u32, base: Duration) -> Self {
        let default_cap = Duration::from_secs(30);
        let attempts = if attempts == 0 { 1 } else { attempts };
        Self {
            attempts,
            base,
            max_backoff: default_cap,
        }
    }

    /// Returns the same policy with its delay ceiling replaced by `max`.
    pub fn max_backoff(self, max: Duration) -> Self {
        Self {
            max_backoff: max,
            ..self
        }
    }
}
impl LlmLayer for RetryLayer {
    type Provider = Retry;

    /// Wraps `inner` in a [`Retry`] provider governed by this policy.
    fn layer(self, inner: Arc<dyn LlmProvider>) -> Retry {
        let cfg = self;
        Retry { cfg, inner }
    }
}
/// Provider wrapper that re-issues failed calls to `inner` according to `cfg`.
pub struct Retry {
    /// The provider every call is forwarded to.
    inner: Arc<dyn LlmProvider>,
    /// Attempt limit and backoff settings.
    cfg: RetryLayer,
}
impl Retry {
    /// Computes how long to pause after the failure of attempt number
    /// `attempt` (zero-based).
    ///
    /// A `RateLimited` error carrying a `retry_after` value supplies the pause
    /// directly. Any other error yields `base * 2^attempt`, computed with
    /// saturating arithmetic so large attempt numbers cannot overflow. Either
    /// way the result never exceeds the configured `max_backoff`.
    fn backoff(&self, attempt: u32, err: &LlmError) -> Duration {
        let uncapped = match err {
            LlmError::RateLimited {
                retry_after: Some(hint),
                ..
            } => *hint,
            _ => {
                let factor = 2u32.saturating_pow(attempt);
                self.cfg.base.saturating_mul(factor)
            }
        };
        uncapped.min(self.cfg.max_backoff)
    }
}
#[async_trait]
impl LlmProvider for Retry {
async fn chat(&self, req: ChatRequest) -> LlmResult<ChatResponse> {
let mut last = None;
for attempt in 0..self.cfg.attempts {
match self.inner.chat(req.clone()).await {
Ok(resp) => return Ok(resp),
Err(e) if e.is_retryable() && attempt + 1 < self.cfg.attempts => {
tracing::debug!(provider = self.inner.name(), attempt, error = %e, "retrying");
tokio::time::sleep(self.backoff(attempt, &e)).await;
last = Some(e);
}
Err(e) => return Err(e),
}
}
Err(last.unwrap_or_else(|| LlmError::Other("retry exhausted".into())))
}
async fn chat_stream(&self, req: ChatRequest) -> LlmResult<ChatStream> {
let mut last = None;
for attempt in 0..self.cfg.attempts {
match self.inner.chat_stream(req.clone()).await {
Ok(s) => return Ok(s),
Err(e) if e.is_retryable() && attempt + 1 < self.cfg.attempts => {
tokio::time::sleep(self.backoff(attempt, &e)).await;
last = Some(e);
}
Err(e) => return Err(e),
}
}
Err(last.unwrap_or_else(|| LlmError::Other("retry exhausted".into())))
}
async fn embed(&self, req: EmbedRequest) -> LlmResult<EmbedResponse> {
let mut last = None;
for attempt in 0..self.cfg.attempts {
match self.inner.embed(req.clone()).await {
Ok(r) => return Ok(r),
Err(e) if e.is_retryable() && attempt + 1 < self.cfg.attempts => {
tokio::time::sleep(self.backoff(attempt, &e)).await;
last = Some(e);
}
Err(e) => return Err(e),
}
}
Err(last.unwrap_or_else(|| LlmError::Other("retry exhausted".into())))
}
fn name(&self) -> &'static str {
self.inner.name()
}
fn model(&self) -> &str {
self.inner.model()
}
fn estimate_cost(&self, req: &ChatRequest) -> Option<CostEstimate> {
self.inner.estimate_cost(req)
}
}