use crate::error::{Error, LlmError};
use crate::llm::{ChatRequest, ChatResponse, ChunkStream, LlmClient};
use std::future::Future;
use std::time::Duration;
use tokio_util::sync::CancellationToken;
use tracing::debug;
/// Starting delay for the exponential backoff between retries; it doubles with
/// each further failed attempt.
pub(crate) const LLM_RETRY_BASE_DELAY: Duration = Duration::from_millis(100);
/// How many times a failed LLM call is retried after the initial attempt, so a
/// call is tried at most `MAX_LLM_RETRIES + 1` times in total.
pub(crate) const MAX_LLM_RETRIES: u32 = 3;
/// Sends `req` as a non-streaming completion, retrying transient failures via
/// `retry_call`.
///
/// Every attempt gets its own clone of `req`, because `LlmClient::complete`
/// takes the request by value.
pub(crate) async fn complete_with_retry(
    llm: &dyn LlmClient,
    cancel: &CancellationToken,
    req: ChatRequest,
) -> Result<ChatResponse, Error> {
    let attempt = || llm.complete(req.clone());
    retry_call(cancel, attempt).await
}
/// Opens a streaming completion for `req`, retrying transient failures via
/// `retry_call`.
///
/// Retries only cover establishing the stream: once `LlmClient::stream`
/// returns a `ChunkStream`, it is handed back to the caller as-is. The request
/// is cloned for every attempt because `stream` consumes it.
pub(crate) async fn stream_with_retry(
    llm: &dyn LlmClient,
    cancel: &CancellationToken,
    req: ChatRequest,
) -> Result<ChunkStream, Error> {
    let attempt = || llm.stream(req.clone());
    retry_call(cancel, attempt).await
}
pub(crate) async fn retry_call<T, F, Fut>(cancel: &CancellationToken, mut op: F) -> Result<T, Error>
where
F: FnMut() -> Fut,
Fut: Future<Output = Result<T, LlmError>> + Send,
{
let max_attempts = MAX_LLM_RETRIES + 1;
let mut attempt: u32 = 0;
loop {
attempt += 1;
if cancel.is_cancelled() {
return Err(Error::Cancelled);
}
match op().await {
Ok(v) => return Ok(v),
Err(e) => {
let last_attempt = attempt >= max_attempts;
let err: Error = e.into();
if last_attempt || !err.retryable() {
return Err(err);
}
let delay = match &err {
Error::Llm(LlmError::RateLimit { retry_after_secs }) => {
Duration::from_secs(u64::from(*retry_after_secs))
}
_ => LLM_RETRY_BASE_DELAY * 2u32.pow(attempt - 1),
};
debug!(
attempt,
delay_ms = delay.as_millis() as u64,
reason = %err,
"llm call failed; retrying after backoff"
);
tokio::select! {
_ = cancel.cancelled() => return Err(Error::Cancelled),
_ = tokio::time::sleep(delay) => {}
}
}
}
}
}
// Unit tests for the retry helpers, driven through `complete_with_retry` with a
// scripted mock `LlmClient`. Every test uses `start_paused = true`, so backoff
// sleeps run on tokio's paused test clock instead of wall-clock time.
#[cfg(test)]
mod tests {
use super::*;
use crate::llm::{
Capabilities, ChatRequest, ChatResponse, ChunkStream, Embedding, FinishReason, LlmClient,
Message, Role,
};
use async_trait::async_trait;
use std::sync::atomic::{AtomicU32, Ordering};
// Mock client: `complete` fails for the first `fail_for` calls with an error
// produced by `make_err`, then succeeds with an assistant message "ok".
// `stream` and `embed` always return `LlmError::Unsupported`.
struct FailingThenSucceedingLlm {
// Number of leading `complete` calls that return an error.
fail_for: u32,
// Total number of `complete` calls observed (failed and successful).
calls: AtomicU32,
caps: Capabilities,
// Factory invoked once per failing call to build the returned error.
make_err: Box<dyn Fn() -> LlmError + Send + Sync>,
}
impl FailingThenSucceedingLlm {
fn new(fail_for: u32, make_err: impl Fn() -> LlmError + Send + Sync + 'static) -> Self {
Self {
fail_for,
calls: AtomicU32::new(0),
caps: Capabilities::default(),
make_err: Box::new(make_err),
}
}
// Number of times `complete` has been invoked so far.
fn call_count(&self) -> u32 {
self.calls.load(Ordering::SeqCst)
}
}
#[async_trait]
impl LlmClient for FailingThenSucceedingLlm {
fn name(&self) -> &str {
"failing-then-succeeding"
}
fn capabilities(&self) -> &Capabilities {
&self.caps
}
async fn complete(&self, _req: ChatRequest) -> Result<ChatResponse, LlmError> {
// `fetch_add` returns the pre-increment value, i.e. the 0-based call index.
let n = self.calls.fetch_add(1, Ordering::SeqCst);
if n < self.fail_for {
return Err((self.make_err)());
}
Ok(ChatResponse {
message: Message {
role: Role::Assistant,
content: "ok".into(),
tool_calls: vec![],
tool_call_id: None,
},
usage: Default::default(),
finish_reason: FinishReason::Stop,
})
}
async fn stream(&self, _req: ChatRequest) -> Result<ChunkStream, LlmError> {
Err(LlmError::Unsupported("streaming".into()))
}
async fn embed(&self, _texts: &[String]) -> Result<Vec<Embedding>, LlmError> {
Err(LlmError::Unsupported("embeddings".into()))
}
}
// Minimal request with no messages; the mock ignores its contents.
fn req() -> ChatRequest {
ChatRequest::new(vec![])
}
// Two server errors followed by success: the helper retries through them and
// returns the successful response after three attempts.
#[tokio::test(start_paused = true)]
async fn retries_server_5xx_then_succeeds() {
let llm = FailingThenSucceedingLlm::new(2, || LlmError::Server("503 unavailable".into()));
let cancel = CancellationToken::new();
let resp = complete_with_retry(&llm, &cancel, req())
.await
.expect("should succeed after retries");
assert_eq!(resp.message.content, "ok");
assert_eq!(
llm.call_count(),
3,
"two failures plus one success = three attempts"
);
}
// `Unauthorized` must be surfaced after a single attempt, without retrying.
#[tokio::test(start_paused = true)]
async fn unauthorized_does_not_retry() {
let llm = FailingThenSucceedingLlm::new(u32::MAX, || LlmError::Unauthorized);
let cancel = CancellationToken::new();
let err = complete_with_retry(&llm, &cancel, req())
.await
.expect_err("must error");
assert!(
matches!(err, Error::Llm(LlmError::Unauthorized)),
"expected Llm(Unauthorized), got {err:?}"
);
assert_eq!(llm.call_count(), 1, "non-retryable: exactly one attempt");
}
// The mock fails forever with `Timeout`. This test relies on `Timeout` being
// retryable, so the helper enters its backoff sleep after the first attempt.
// A spawned task cancels the token after yielding once, which should abort
// that sleep with `Error::Cancelled` after exactly one attempt.
// NOTE(review): correctness depends on task scheduling order on the paused
// test runtime (the cancel must land before the clock auto-advances past the
// backoff).
#[tokio::test(start_paused = true)]
async fn cancellation_during_backoff_aborts() {
let llm = FailingThenSucceedingLlm::new(u32::MAX, || LlmError::Timeout);
let cancel = CancellationToken::new();
let cancel_clone = cancel.clone();
let canceller = tokio::spawn(async move {
tokio::task::yield_now().await;
cancel_clone.cancel();
});
let err = complete_with_retry(&llm, &cancel, req())
.await
.expect_err("must error");
canceller.await.unwrap();
assert!(
matches!(err, Error::Cancelled),
"expected Cancelled, got {err:?}"
);
assert_eq!(llm.call_count(), 1);
}
// A permanently failing retryable error uses the whole budget
// (`MAX_LLM_RETRIES + 1` attempts), and the last error is propagated unchanged.
#[tokio::test(start_paused = true)]
async fn exhausts_attempts_and_propagates_last_error() {
let llm = FailingThenSucceedingLlm::new(u32::MAX, || LlmError::Server("boom".into()));
let cancel = CancellationToken::new();
let err = complete_with_retry(&llm, &cancel, req())
.await
.expect_err("must error");
assert!(
matches!(err, Error::Llm(LlmError::Server(ref m)) if m == "boom"),
"expected Llm(Server(\"boom\")), got {err:?}"
);
assert_eq!(llm.call_count(), MAX_LLM_RETRIES + 1);
}
// One rate-limit error (`retry_after_secs: 1`) followed by success: exactly two
// attempts are made.
// NOTE(review): only the attempt count is asserted, not the length of the wait.
#[tokio::test(start_paused = true)]
async fn rate_limit_uses_retry_after_then_succeeds() {
let llm = FailingThenSucceedingLlm::new(1, || LlmError::RateLimit {
retry_after_secs: 1,
});
let cancel = CancellationToken::new();
let resp = complete_with_retry(&llm, &cancel, req())
.await
.expect("should succeed after rate-limit retry");
assert_eq!(resp.message.content, "ok");
assert_eq!(llm.call_count(), 2);
}
}