use std::sync::Arc;
use std::time::{Duration, Instant};
use anyhow::Result;
use async_trait::async_trait;
use futures::stream::BoxStream;
use rand::Rng;
use brainwires_core::message::{ChatResponse, Message, StreamChunk};
use brainwires_core::provider::{ChatOptions, Provider};
use brainwires_core::tool::Tool;
use crate::classify::{classify_error, parse_retry_after};
use crate::error::ResilienceError;
#[derive(Debug, Clone)]
pub struct RetryPolicy {
pub max_attempts: u32,
pub base: Duration,
pub max: Duration,
pub jitter: f64,
pub honor_retry_after: bool,
pub overall_deadline: Option<Duration>,
}
impl Default for RetryPolicy {
fn default() -> Self {
Self {
max_attempts: 4,
base: Duration::from_millis(500),
max: Duration::from_secs(30),
jitter: 0.2,
honor_retry_after: true,
overall_deadline: Some(Duration::from_secs(60)),
}
}
}
impl RetryPolicy {
pub fn none() -> Self {
Self {
max_attempts: 1,
..Self::default()
}
}
fn backoff_for(&self, attempt: u32) -> Duration {
let shift = attempt.saturating_sub(1).min(16);
let nominal = self.base.saturating_mul(1u32 << shift);
let capped = nominal.min(self.max);
apply_jitter(capped, self.jitter)
}
}
fn apply_jitter(base: Duration, factor: f64) -> Duration {
if factor <= 0.0 {
return base;
}
let ms = base.as_millis() as f64;
let spread = ms * factor.clamp(0.0, 1.0);
let delta: f64 = rand::thread_rng().gen_range(-spread..=spread);
let jittered = (ms + delta).max(0.0);
Duration::from_millis(jittered as u64)
}
pub struct RetryProvider<P: Provider + ?Sized> {
inner: Arc<P>,
policy: RetryPolicy,
}
impl<P: Provider + ?Sized> RetryProvider<P> {
pub fn new(inner: Arc<P>, policy: RetryPolicy) -> Self {
Self { inner, policy }
}
pub fn inner(&self) -> &Arc<P> {
&self.inner
}
}
#[async_trait]
impl<P: Provider + ?Sized + 'static> Provider for RetryProvider<P> {
fn name(&self) -> &str {
self.inner.name()
}
fn max_output_tokens(&self) -> Option<u32> {
self.inner.max_output_tokens()
}
async fn chat(
&self,
messages: &[Message],
tools: Option<&[Tool]>,
options: &ChatOptions,
) -> Result<ChatResponse> {
let mut last_err: Option<anyhow::Error> = None;
let started = Instant::now();
for attempt in 1..=self.policy.max_attempts {
match self.inner.chat(messages, tools, options).await {
Ok(resp) => return Ok(resp),
Err(e) => {
let class = classify_error(&e);
if !class.is_retryable() || attempt == self.policy.max_attempts {
if attempt > 1 {
return Err(ResilienceError::RetriesExhausted {
attempts: attempt,
source: e,
}
.into());
}
return Err(e);
}
let mut delay = if self.policy.honor_retry_after {
parse_retry_after(&e).unwrap_or_else(|| self.policy.backoff_for(attempt))
} else {
self.policy.backoff_for(attempt)
};
if let Some(deadline) = self.policy.overall_deadline {
let elapsed = started.elapsed();
if elapsed >= deadline || elapsed.saturating_add(delay) >= deadline {
tracing::warn!(
provider = self.inner.name(),
attempt,
elapsed_ms = elapsed.as_millis() as u64,
deadline_ms = deadline.as_millis() as u64,
"retry deadline reached, giving up"
);
return Err(ResilienceError::DeadlineExceeded {
attempts: attempt,
elapsed_ms: elapsed.as_millis() as u64,
source: e,
}
.into());
}
let remaining = deadline.saturating_sub(elapsed);
if delay > remaining {
delay = remaining;
}
}
tracing::warn!(
provider = self.inner.name(),
attempt,
delay_ms = delay.as_millis() as u64,
class = ?class,
"retrying provider call after transient error"
);
last_err = Some(e);
tokio::time::sleep(delay).await;
}
}
}
Err(ResilienceError::RetriesExhausted {
attempts: self.policy.max_attempts,
source: last_err.unwrap_or_else(|| anyhow::anyhow!("unknown retry failure")),
}
.into())
}
fn stream_chat<'a>(
&'a self,
messages: &'a [Message],
tools: Option<&'a [Tool]>,
options: &'a ChatOptions,
) -> BoxStream<'a, Result<StreamChunk>> {
self.inner.stream_chat(messages, tools, options)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn default_policy_is_sensible() {
let p = RetryPolicy::default();
assert_eq!(p.max_attempts, 4);
assert!(p.honor_retry_after);
}
#[test]
fn backoff_doubles_then_caps() {
let p = RetryPolicy {
max_attempts: 10,
base: Duration::from_millis(100),
max: Duration::from_millis(800),
jitter: 0.0,
honor_retry_after: false,
overall_deadline: None,
};
assert_eq!(p.backoff_for(1), Duration::from_millis(100));
assert_eq!(p.backoff_for(2), Duration::from_millis(200));
assert_eq!(p.backoff_for(3), Duration::from_millis(400));
assert_eq!(p.backoff_for(4), Duration::from_millis(800));
assert_eq!(p.backoff_for(5), Duration::from_millis(800));
}
#[test]
fn none_policy_disables_retry() {
assert_eq!(RetryPolicy::none().max_attempts, 1);
}
#[test]
fn jitter_stays_within_bounds() {
let base = Duration::from_millis(1000);
for _ in 0..50 {
let j = apply_jitter(base, 0.2);
assert!(j >= Duration::from_millis(800) && j <= Duration::from_millis(1200));
}
}
}