use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use futures::future::BoxFuture;
use rand::SeedableRng;
use rand::rngs::SmallRng;
use tower::{Layer, Service, ServiceExt};
use crate::backoff::{DEFAULT_MAX_ATTEMPTS, ExponentialBackoff};
use crate::error::{Error, Result};
/// Parses a `Retry-After` header given in the delta-seconds form.
///
/// Surrounding whitespace is ignored. Returns `None` when the header is absent or
/// not valid visible ASCII. Also returns `None` when the value is not a
/// non-negative integer, which includes the HTTP-date form, or when it is `0`.
#[must_use]
pub fn parse_retry_after(header: Option<&http::HeaderValue>) -> Option<Duration> {
    header?
        .to_str()
        .ok()?
        .trim()
        .parse::<u64>()
        .ok()
        // A zero hint carries no information, so treat it as "no hint".
        .filter(|&seconds| seconds != 0)
        .map(Duration::from_secs)
}
pub trait RetryClassifier: Send + Sync + std::fmt::Debug {
fn should_retry(&self, error: &Error, attempt: u32) -> RetryDecision;
}
/// The outcome of classifying one failed attempt.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct RetryDecision {
    /// Whether another attempt should be made.
    pub retry: bool,
    /// Explicit wait requested before the next attempt. `None` means "use the
    /// policy's backoff".
    pub after: Option<Duration>,
}
impl RetryDecision {
    /// Give up and surface the error.
    pub const STOP: Self = Self { retry: false, after: None };
    /// Retry after the policy's regular backoff delay.
    pub const RETRY: Self = Self { retry: true, after: None };
    /// Retry after exactly `after` (subject to the caller's own caps).
    #[must_use]
    pub const fn retry_after(after: Duration) -> Self {
        Self {
            after: Some(after),
            ..Self::RETRY
        }
    }
}
#[derive(Clone, Copy, Debug, Default)]
pub struct DefaultRetryClassifier;
impl RetryClassifier for DefaultRetryClassifier {
#[allow(clippy::match_same_arms)]
fn should_retry(&self, error: &Error, _attempt: u32) -> RetryDecision {
match error {
Error::Provider {
kind, retry_after, ..
} if is_transient_kind(*kind) => match retry_after {
Some(after) => RetryDecision::retry_after(*after),
None => RetryDecision::RETRY,
},
Error::UsageLimitExceeded(_) => RetryDecision::STOP,
_ => RetryDecision::STOP,
}
}
}
const fn is_transient_kind(kind: crate::error::ProviderErrorKind) -> bool {
use crate::error::ProviderErrorKind;
match kind {
ProviderErrorKind::Network | ProviderErrorKind::Tls | ProviderErrorKind::Dns => true,
ProviderErrorKind::Http(status) => matches!(status, 408 | 425 | 429 | 500..=599),
}
}
/// Retry configuration.
///
/// It holds the attempt budget, the backoff schedule, and the classifier that
/// decides which errors are retryable. The classifier sits behind an [`Arc`], so
/// cloning a policy shares it rather than copying it.
#[derive(Clone, Debug)]
pub struct RetryPolicy {
    max_attempts: u32,
    backoff: ExponentialBackoff,
    classifier: Arc<dyn RetryClassifier>,
}
impl RetryPolicy {
    /// Builds a policy from its three parts.
    ///
    /// `max_attempts` counts the first call as well as retries. The retry loop
    /// in this module treats `0` as `1`.
    #[must_use]
    pub fn new(
        max_attempts: u32,
        backoff: ExponentialBackoff,
        classifier: Arc<dyn RetryClassifier>,
    ) -> Self {
        Self {
            classifier,
            backoff,
            max_attempts,
        }
    }
    /// The default policy.
    ///
    /// It uses [`DEFAULT_MAX_ATTEMPTS`] attempts,
    /// `ExponentialBackoff::new(100 ms, 5 s)`, and [`DefaultRetryClassifier`].
    #[must_use]
    pub fn standard() -> Self {
        let backoff = ExponentialBackoff::new(Duration::from_millis(100), Duration::from_secs(5));
        Self::new(DEFAULT_MAX_ATTEMPTS, backoff, Arc::new(DefaultRetryClassifier))
    }
    /// Replaces the attempt budget.
    #[must_use]
    pub const fn with_max_attempts(mut self, attempts: u32) -> Self {
        self.max_attempts = attempts;
        self
    }
    /// Replaces the backoff schedule.
    #[must_use]
    pub const fn with_backoff(mut self, schedule: ExponentialBackoff) -> Self {
        self.backoff = schedule;
        self
    }
    /// Replaces the classifier.
    #[must_use]
    pub fn with_classifier(self, classifier: Arc<dyn RetryClassifier>) -> Self {
        Self { classifier, ..self }
    }
    /// The configured attempt budget, returned exactly as set (it may be `0`).
    #[must_use]
    pub const fn max_attempts(&self) -> u32 {
        self.max_attempts
    }
    /// A copy of the backoff schedule.
    #[must_use]
    pub const fn backoff(&self) -> ExponentialBackoff {
        self.backoff
    }
    /// The shared classifier.
    #[must_use]
    pub fn classifier(&self) -> &Arc<dyn RetryClassifier> {
        &self.classifier
    }
}
/// Tower [`Layer`] that wraps services in a [`RetryService`].
///
/// Each wrapped service gets its own clone of the layer's policy.
#[derive(Clone, Debug)]
pub struct RetryLayer {
    policy: RetryPolicy,
}
impl RetryLayer {
    /// Name reported through `crate::NamedLayer`.
    pub const NAME: &'static str = "retry";
    /// Creates a layer that applies `policy` to every service it wraps.
    #[must_use]
    pub const fn new(policy: RetryPolicy) -> Self {
        Self { policy }
    }
}
impl<S> Layer<S> for RetryLayer {
    type Service = RetryService<S>;
    fn layer(&self, service: S) -> Self::Service {
        let policy = self.policy.clone();
        RetryService {
            inner: service,
            policy,
        }
    }
}
impl crate::NamedLayer for RetryLayer {
    fn layer_name(&self) -> &'static str {
        RetryLayer::NAME
    }
}
/// Service produced by [`RetryLayer`].
///
/// Each call is replayed against the inner service according to the layer's
/// [`RetryPolicy`].
#[derive(Clone, Debug)]
pub struct RetryService<S> {
    inner: S,
    policy: RetryPolicy,
}
impl<S, Req, Resp> Service<Req> for RetryService<S>
where
    S: Service<Req, Response = Resp, Error = Error> + Clone + Send + 'static,
    S::Future: Send + 'static,
    Req: Retryable + Send + 'static,
    Resp: Send + 'static,
{
    type Response = Resp;
    type Error = Error;
    type Future = BoxFuture<'static, Result<Resp>>;
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
        self.inner.poll_ready(cx)
    }
    fn call(&mut self, request: Req) -> Self::Future {
        // `poll_ready` drove `self.inner` to readiness, and some services keep
        // that reservation (a concurrency permit, a buffer slot) inside the
        // instance. A fresh clone does not inherit it. Sending the clone would
        // waste the reservation, and with a capacity of one it would wait
        // forever for the permit `self.inner` still holds. So hand the ready
        // instance to the retry loop and keep the fresh clone for the next call.
        let replacement = self.inner.clone();
        let ready_inner = std::mem::replace(&mut self.inner, replacement);
        let policy = self.policy.clone();
        Box::pin(run_with_retry(ready_inner, request, policy))
    }
}
pub trait Retryable: Clone {
fn ctx(&self) -> &crate::context::ExecutionContext;
fn ctx_mut(&mut self) -> &mut crate::context::ExecutionContext;
}
impl Retryable for crate::service::ModelInvocation {
fn ctx(&self) -> &crate::context::ExecutionContext {
&self.ctx
}
fn ctx_mut(&mut self) -> &mut crate::context::ExecutionContext {
&mut self.ctx
}
}
impl Retryable for crate::service::ToolInvocation {
fn ctx(&self) -> &crate::context::ExecutionContext {
&self.ctx
}
fn ctx_mut(&mut self) -> &mut crate::context::ExecutionContext {
&mut self.ctx
}
}
impl Retryable for crate::service::StreamingModelInvocation {
fn ctx(&self) -> &crate::context::ExecutionContext {
&self.inner.ctx
}
fn ctx_mut(&mut self) -> &mut crate::context::ExecutionContext {
&mut self.inner.ctx
}
}
/// Calls `inner` until it succeeds, the policy gives up, the request is
/// cancelled, or the request's deadline passes.
///
/// Every attempt sends a clone of `request`. An idempotency key is stamped on
/// the request once, before the first attempt, so all clones carry the same key.
///
/// # Errors
///
/// - [`Error::Cancelled`] if the cancellation token has fired before an
///   attempt, or fires while waiting between attempts.
/// - [`Error::DeadlineExceeded`] in three cases:
///   - the deadline has passed before an attempt;
///   - the wait before the next attempt would reach the deadline;
///   - the wait overshot the deadline.
/// - Any readiness error from `inner`, returned without retrying.
/// - The failing attempt's error when the attempt budget is exhausted or the
///   classifier declines to retry.
async fn run_with_retry<S, Req, Resp>(
    mut inner: S,
    mut request: Req,
    policy: RetryPolicy,
) -> Result<Resp>
where
    S: Service<Req, Response = Resp, Error = Error> + Clone + Send,
    S::Future: Send,
    Req: Retryable + Send,
{
    let mut rng = SmallRng::seed_from_u64(seed_from_time());
    // Stamp the key before the first clone so every attempt carries the same one.
    request
        .ctx_mut()
        .ensure_idempotency_key(|| uuid::Uuid::new_v4().to_string());
    // A budget of zero still makes one call.
    let max_attempts = policy.max_attempts.max(1);
    let mut attempt: u32 = 0;
    loop {
        let ctx_token = request.ctx().cancellation();
        if ctx_token.is_cancelled() {
            return Err(Error::Cancelled);
        }
        // `request` is only cloned inside the loop, never mutated, so this
        // deadline stays valid for the whole iteration.
        let deadline = request.ctx().deadline();
        if deadline.is_some_and(|d| tokio::time::Instant::now() >= d) {
            return Err(Error::DeadlineExceeded);
        }
        let err = match inner.ready().await?.call(request.clone()).await {
            Ok(resp) => return Ok(resp),
            Err(err) => err,
        };
        attempt = attempt.saturating_add(1);
        let exhausted = attempt >= max_attempts;
        // The classifier sees the zero-based index of the attempt that just failed.
        let decision = policy.classifier.should_retry(&err, attempt - 1);
        if exhausted || !decision.retry {
            return Err(err);
        }
        let backoff_delay = policy.backoff.delay_for_attempt(attempt - 1, &mut rng);
        // An explicit hint from the classifier wins over the computed backoff.
        // It is capped at the backoff ceiling.
        let delay = decision
            .after
            .map_or(backoff_delay, |hint| hint.min(policy.backoff.max()));
        // Fail fast. If the wait alone would use up the rest of the deadline,
        // the next attempt can never start, so sleeping first would only delay
        // the inevitable `DeadlineExceeded`.
        if let Some(deadline) = deadline
            && delay >= deadline.saturating_duration_since(tokio::time::Instant::now())
        {
            return Err(Error::DeadlineExceeded);
        }
        tokio::select! {
            () = tokio::time::sleep(delay) => {
                // The timer may wake after the deadline even though `delay`
                // was shorter than the remaining budget.
                if deadline.is_some_and(|d| tokio::time::Instant::now() >= d) {
                    return Err(Error::DeadlineExceeded);
                }
            }
            () = ctx_token.cancelled() => return Err(Error::Cancelled),
        }
    }
}
/// Produces a jitter seed from the wall clock mixed with a process-wide counter.
///
/// The counter keeps seeds from consecutive calls apart even when the clock
/// reading does not change. A clock set before the Unix epoch contributes `0`.
fn seed_from_time() -> u64 {
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};
    static COUNTER: AtomicU64 = AtomicU64::new(0);
    // Only the low 64 bits of the nanosecond count matter for seeding.
    #[allow(clippy::cast_possible_truncation)]
    let clock_bits = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_nanos() as u64,
        Err(_) => 0,
    };
    clock_bits ^ COUNTER.fetch_add(1, Ordering::Relaxed)
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    #[test]
    fn default_classifier_retries_transient_http_status_codes() {
        let c = DefaultRetryClassifier;
        for status in [408_u16, 425, 429, 500, 502, 503, 504, 599] {
            let err = Error::provider_http(status, "x");
            assert!(c.should_retry(&err, 0).retry, "status {status} must retry");
        }
    }
    #[test]
    fn default_classifier_retries_transport_class_failures() {
        let c = DefaultRetryClassifier;
        assert!(
            c.should_retry(&Error::provider_network("connect refused"), 0)
                .retry
        );
        assert!(
            c.should_retry(&Error::provider_tls("handshake failed"), 0)
                .retry
        );
        assert!(
            c.should_retry(&Error::provider_dns("no such host"), 0)
                .retry
        );
    }
    #[test]
    fn default_classifier_does_not_retry_permanent_status_codes() {
        let c = DefaultRetryClassifier;
        for status in [400_u16, 401, 403, 404, 410, 422] {
            let err = Error::provider_http(status, "x");
            assert!(
                !c.should_retry(&err, 0).retry,
                "status {status} must NOT retry"
            );
        }
    }
    #[test]
    fn default_classifier_does_not_retry_caller_intent_or_programming_errors() {
        let c = DefaultRetryClassifier;
        assert!(!c.should_retry(&Error::Cancelled, 0).retry);
        assert!(!c.should_retry(&Error::DeadlineExceeded, 0).retry);
        assert!(!c.should_retry(&Error::invalid_request("nope"), 0).retry);
        assert!(!c.should_retry(&Error::config("bad"), 0).retry);
    }
    #[test]
    fn default_classifier_does_not_retry_usage_limit_exceeded() {
        use crate::run_budget::UsageLimitBreach;
        let c = DefaultRetryClassifier;
        let err = Error::UsageLimitExceeded(UsageLimitBreach::Requests {
            limit: 5,
            observed: 5,
        });
        let decision = c.should_retry(&err, 0);
        assert!(!decision.retry);
        assert_eq!(decision.after, None);
    }
    #[test]
    fn default_classifier_propagates_vendor_retry_after_hint() {
        let c = DefaultRetryClassifier;
        let err =
            Error::provider_http(429, "rate limited").with_retry_after(Duration::from_secs(7));
        let decision = c.should_retry(&err, 0);
        assert!(decision.retry);
        assert_eq!(decision.after, Some(Duration::from_secs(7)));
    }
    #[test]
    fn ensure_idempotency_key_stamps_once_and_subsequent_calls_observe_the_same_value() {
        use crate::context::ExecutionContext;
        let mut ctx = ExecutionContext::new();
        assert!(ctx.idempotency_key().is_none());
        let mut counter = 0u32;
        let first = ctx
            .ensure_idempotency_key(|| {
                counter += 1;
                "first-uuid".to_owned()
            })
            .to_owned();
        let second = ctx
            .ensure_idempotency_key(|| {
                counter += 1;
                "second-uuid".to_owned()
            })
            .to_owned();
        assert_eq!(first, "first-uuid");
        assert_eq!(second, "first-uuid", "stamp must be stable across calls");
        assert_eq!(counter, 1, "generator must run exactly once");
        let cloned = ctx.clone();
        assert_eq!(cloned.idempotency_key(), Some("first-uuid"));
    }
    #[test]
    fn default_classifier_does_not_attach_retry_after_when_vendor_does_not_supply_one() {
        let c = DefaultRetryClassifier;
        let err = Error::provider_http(503, "down");
        let decision = c.should_retry(&err, 0);
        assert!(decision.retry);
        assert!(decision.after.is_none());
    }
    #[test]
    fn retry_policy_standard_uses_default_max_attempts() {
        let p = RetryPolicy::standard();
        assert_eq!(p.max_attempts(), DEFAULT_MAX_ATTEMPTS);
    }
    #[test]
    fn retry_policy_overrides_compose() {
        let p = RetryPolicy::standard()
            .with_max_attempts(2)
            .with_backoff(ExponentialBackoff::new(
                Duration::from_millis(1),
                Duration::from_millis(10),
            ));
        assert_eq!(p.max_attempts(), 2);
        assert_eq!(p.backoff().base(), Duration::from_millis(1));
    }
    // `parse_retry_after` previously had no coverage. These tests pin the
    // accepted delta-seconds form and every input that yields no hint.
    #[test]
    fn parse_retry_after_reads_positive_delta_seconds() {
        let plain = http::HeaderValue::from_static("7");
        assert_eq!(parse_retry_after(Some(&plain)), Some(Duration::from_secs(7)));
        let padded = http::HeaderValue::from_static(" 12 ");
        assert_eq!(
            parse_retry_after(Some(&padded)),
            Some(Duration::from_secs(12))
        );
    }
    #[test]
    fn parse_retry_after_yields_none_for_missing_zero_and_unparseable_values() {
        assert_eq!(parse_retry_after(None), None);
        for raw in ["0", "-1", "1.5", "soon", "Wed, 21 Oct 2015 07:28:00 GMT"] {
            let value = http::HeaderValue::from_static(raw);
            assert_eq!(
                parse_retry_after(Some(&value)),
                None,
                "{raw:?} must not produce a hint"
            );
        }
    }
}