use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::task::{Context, Poll};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use futures::future::BoxFuture;
use rand::SeedableRng;
use rand::rngs::SmallRng;
use serde_json::Value;
use tower::{Layer, Service, ServiceExt};
use crate::backoff::ExponentialBackoff;
use crate::error::{Error, Result};
use crate::service::ToolInvocation;
use crate::transports::{DefaultRetryClassifier, RetryClassifier};
/// Default upper bound on the wait between retry attempts (30 seconds).
///
/// Used by [`RetryToolLayer::new`]; override with
/// [`RetryToolLayer::with_max_backoff`].
pub const DEFAULT_MAX_BACKOFF: Duration = Duration::from_millis(30_000);
/// Tower layer that wraps a tool service in a [`RetryToolService`].
///
/// The layer only holds configuration; every service it produces shares the
/// same classifier through an `Arc`.
#[derive(Clone)]
pub struct RetryToolLayer {
    /// Cap handed to the exponential backoff and used to clamp any delay
    /// suggested by the classifier.
    max_backoff: Duration,
    /// Decides, per failed attempt, whether another attempt should be made.
    classifier: Arc<dyn RetryClassifier>,
}
impl RetryToolLayer {
    /// Identifier reported through `NamedLayer::layer_name`.
    pub const NAME: &'static str = "tool_retry";

    /// Builds a layer using `DefaultRetryClassifier` and
    /// [`DEFAULT_MAX_BACKOFF`] as the delay cap.
    #[must_use]
    pub fn new() -> Self {
        let classifier: Arc<dyn RetryClassifier> = Arc::new(DefaultRetryClassifier);
        Self {
            max_backoff: DEFAULT_MAX_BACKOFF,
            classifier,
        }
    }

    /// Replaces the classifier that decides whether a failed attempt is retried.
    #[must_use]
    pub fn with_classifier(self, classifier: Arc<dyn RetryClassifier>) -> Self {
        Self { classifier, ..self }
    }

    /// Replaces the upper bound on the delay between attempts.
    #[must_use]
    pub const fn with_max_backoff(mut self, max: Duration) -> Self {
        self.max_backoff = max;
        self
    }
}
impl Default for RetryToolLayer {
fn default() -> Self {
Self::new()
}
}
/// Prints only the backoff cap; the classifier is a trait object without a
/// `Debug` bound, so it is elided (shown as `..`).
impl std::fmt::Debug for RetryToolLayer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { max_backoff, .. } = self;
        let mut builder = f.debug_struct("RetryToolLayer");
        builder.field("max_backoff", max_backoff);
        builder.finish_non_exhaustive()
    }
}
impl<S> Layer<S> for RetryToolLayer
where
    S: Service<ToolInvocation, Response = Value, Error = Error> + Clone + Send + 'static,
    S::Future: Send + 'static,
{
    type Service = RetryToolService<S>;

    /// Wraps `inner`, sharing this layer's classifier (reference-counted) and
    /// copying its backoff cap.
    fn layer(&self, inner: S) -> Self::Service {
        let Self {
            classifier,
            max_backoff,
        } = self;
        RetryToolService {
            max_backoff: *max_backoff,
            classifier: Arc::clone(classifier),
            inner,
        }
    }
}
impl crate::NamedLayer for RetryToolLayer {
fn layer_name(&self) -> &'static str {
Self::NAME
}
}
/// Service produced by [`RetryToolLayer`]: forwards tool invocations to
/// `inner`, re-issuing them on failure when the invocation carries a retry hint.
#[derive(Clone)]
pub struct RetryToolService<Inner> {
    /// Cap handed to the exponential backoff and used to clamp any delay
    /// suggested by the classifier.
    max_backoff: Duration,
    /// Decides, per failed attempt, whether another attempt should be made.
    classifier: Arc<dyn RetryClassifier>,
    /// The wrapped service that actually executes the tool.
    inner: Inner,
}
/// Prints only the backoff cap; neither the wrapped service nor the classifier
/// is required to implement `Debug`, so both are elided (shown as `..`).
impl<Inner> std::fmt::Debug for RetryToolService<Inner> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("RetryToolService");
        builder.field("max_backoff", &self.max_backoff);
        builder.finish_non_exhaustive()
    }
}
/// Retrying [`Service`] for tool invocations.
///
/// Invocations whose metadata carries no `retry_hint` are forwarded to the
/// inner service exactly once. With a hint, up to `max_attempts` attempts
/// (at least one) are made. After a failure that still leaves attempts, the
/// classifier decides whether to retry. The wait before the next attempt is the
/// classifier's suggested delay clamped to `max_backoff`, or otherwise
/// `ExponentialBackoff::delay_for_attempt`. In the hinted path the invocation's
/// cancellation token is checked before each attempt and while waiting; a
/// cancellation yields `Error::Cancelled`.
impl<Inner> Service<ToolInvocation> for RetryToolService<Inner>
where
    Inner: Service<ToolInvocation, Response = Value, Error = Error> + Clone + Send + 'static,
    Inner::Future: Send + 'static,
{
    type Response = Value;
    type Error = Error;
    type Future = BoxFuture<'static, Result<Value>>;

    /// Readiness is exactly the wrapped service's readiness.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, invocation: ToolInvocation) -> Self::Future {
        // `poll_ready` drove `self.inner` to readiness. Any capacity it reserved
        // (e.g. a concurrency permit or buffer slot) lives in that instance, not
        // in a clone. Move the ready instance into the future and keep the fresh
        // clone for the next `poll_ready`. Otherwise the reservation is stranded
        // while the clone has to acquire its own, which can deadlock a
        // capacity-1 inner service.
        let replacement = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, replacement);
        let classifier = Arc::clone(&self.classifier);
        let max_backoff = self.max_backoff;
        Box::pin(async move {
            let hint = invocation.metadata.retry_hint;
            let Some(hint) = hint else {
                // No retry policy requested: a single pass-through attempt.
                return inner.ready().await?.call(invocation).await;
            };
            let max_attempts = hint.max_attempts.max(1);
            let backoff = ExponentialBackoff::new(hint.initial_backoff, max_backoff);
            let mut rng = SmallRng::seed_from_u64(seed_from_time());
            // The invocation is never modified between attempts, so fetch its
            // cancellation handle once instead of on every iteration.
            let ctx_token = invocation.ctx.cancellation();
            // Number of attempts made so far.
            let mut attempt: u32 = 0;
            loop {
                if ctx_token.is_cancelled() {
                    return Err(Error::Cancelled);
                }
                let err = match inner.ready().await?.call(invocation.clone()).await {
                    Ok(value) => return Ok(value),
                    Err(err) => err,
                };
                // Zero-based index of the attempt that just failed.
                let failed_attempt = attempt;
                attempt = attempt.saturating_add(1);
                // Out of attempts: the classifier's verdict would be ignored anyway.
                if attempt >= max_attempts {
                    return Err(err);
                }
                let decision = classifier.should_retry(&err, failed_attempt);
                if !decision.retry {
                    return Err(err);
                }
                // A classifier-suggested delay takes precedence but never exceeds
                // the cap; the computed backoff is only evaluated when needed.
                let delay = decision.after.map_or_else(
                    || backoff.delay_for_attempt(failed_attempt, &mut rng),
                    |suggested| suggested.min(max_backoff),
                );
                tokio::select! {
                    () = tokio::time::sleep(delay) => {}
                    () = ctx_token.cancelled() => return Err(Error::Cancelled),
                }
            }
        })
    }
}
/// Produces a per-call RNG seed for backoff jitter.
///
/// The seed is the wall-clock time in nanoseconds since the Unix epoch, reduced
/// modulo 2^64 (0 if the clock reads before the epoch). It is XOR-ed with a
/// process-wide counter so that two calls landing on the same clock reading
/// still get different seeds. Not suitable for anything security-sensitive.
fn seed_from_time() -> u64 {
    static COUNTER: AtomicU64 = AtomicU64::new(0);
    // secs * 1e9 + subsec_nanos with wrapping arithmetic is exactly the low
    // 64 bits of `Duration::as_nanos()`, without a truncating cast.
    let wall_nanos = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch
            .as_secs()
            .wrapping_mul(1_000_000_000)
            .wrapping_add(u64::from(since_epoch.subsec_nanos())),
        Err(_) => 0,
    };
    wall_nanos ^ COUNTER.fetch_add(1, Ordering::Relaxed)
}