use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::time::Duration;
use crate::error::{ClientError, ClientResult};
use crate::streaming::EventStream;
use crate::transport::Transport;
/// Configuration for retrying failed transport requests with exponential
/// backoff.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Maximum number of retries *after* the initial attempt
    /// (so up to `max_retries + 1` attempts total).
    pub max_retries: u32,
    /// Backoff slept before the first retry.
    pub initial_backoff: Duration,
    /// Upper bound the growing backoff is clamped to.
    pub max_backoff: Duration,
    /// Factor applied to the backoff after each retry.
    pub backoff_multiplier: f64,
}
impl Default for RetryPolicy {
    /// Defaults: 3 retries, 500 ms initial backoff, doubling per retry,
    /// capped at 30 s.
    fn default() -> Self {
        Self {
            max_retries: 3,
            initial_backoff: Duration::from_millis(500),
            max_backoff: Duration::from_secs(30),
            backoff_multiplier: 2.0,
        }
    }
}
impl RetryPolicy {
    /// Sets the maximum number of retries (attempts beyond the first).
    #[must_use]
    pub const fn with_max_retries(mut self, max_retries: u32) -> Self {
        self.max_retries = max_retries;
        self
    }
    /// Sets the backoff slept before the first retry.
    #[must_use]
    pub const fn with_initial_backoff(mut self, backoff: Duration) -> Self {
        self.initial_backoff = backoff;
        self
    }
    /// Sets the upper bound the growing backoff is clamped to.
    #[must_use]
    pub const fn with_max_backoff(mut self, max: Duration) -> Self {
        self.max_backoff = max;
        self
    }
    /// Sets the factor applied to the backoff after each retry.
    #[must_use]
    pub const fn with_backoff_multiplier(mut self, multiplier: f64) -> Self {
        self.backoff_multiplier = multiplier;
        self
    }
}
impl ClientError {
    /// Returns `true` if a retry could plausibly succeed for this error.
    ///
    /// Transport-level failures (`Http`, `HttpClient`, `Timeout`) and the
    /// rate-limit / gateway statuses 429, 502, 503, 504 are retryable.
    /// Serialization, protocol, configuration, and auth errors are not —
    /// retrying those would fail identically.
    #[must_use]
    pub const fn is_retryable(&self) -> bool {
        match self {
            Self::Http(_) | Self::HttpClient(_) | Self::Timeout(_) => true,
            Self::UnexpectedStatus { status, .. } => {
                matches!(status, 429 | 502 | 503 | 504)
            }
            // Exhaustive listing (no `_` arm) so adding a new error variant
            // forces an explicit retryability decision here.
            Self::Serialization(_)
            | Self::Protocol(_)
            | Self::Transport(_)
            | Self::InvalidEndpoint(_)
            | Self::AuthRequired { .. }
            | Self::ProtocolBindingMismatch(_) => false,
        }
    }
}
/// Decorator that wraps another [`Transport`] and retries transient
/// failures according to a [`RetryPolicy`].
pub(crate) struct RetryTransport {
    /// The transport that actually performs the requests.
    inner: Box<dyn Transport>,
    /// Controls the attempt count and backoff timing.
    policy: RetryPolicy,
}
impl RetryTransport {
    /// Wraps `inner` so its requests are retried per `policy`.
    pub(crate) fn new(inner: Box<dyn Transport>, policy: RetryPolicy) -> Self {
        Self { inner, policy }
    }
}
impl Transport for RetryTransport {
fn send_request<'a>(
&'a self,
method: &'a str,
params: serde_json::Value,
extra_headers: &'a HashMap<String, String>,
) -> Pin<Box<dyn Future<Output = ClientResult<serde_json::Value>> + Send + 'a>> {
Box::pin(async move {
let mut last_err = None;
let mut backoff = self.policy.initial_backoff;
let serialized = serde_json::to_vec(¶ms).map_err(ClientError::Serialization)?;
for attempt in 0..=self.policy.max_retries {
if attempt > 0 {
let jittered_backoff = jittered(backoff);
trace_info!(method, attempt, ?jittered_backoff, "retrying after backoff");
tokio::time::sleep(jittered_backoff).await;
backoff = cap_backoff(
backoff,
self.policy.backoff_multiplier,
self.policy.max_backoff,
);
}
let attempt_params: serde_json::Value =
serde_json::from_slice(&serialized).map_err(ClientError::Serialization)?;
match self
.inner
.send_request(method, attempt_params, extra_headers)
.await
{
Ok(result) => return Ok(result),
Err(e) if e.is_retryable() => {
trace_warn!(method, attempt, error = %e, "transient error, will retry");
last_err = Some(e);
}
Err(e) => return Err(e),
}
}
Err(last_err.expect("at least one attempt was made"))
})
}
fn send_streaming_request<'a>(
&'a self,
method: &'a str,
params: serde_json::Value,
extra_headers: &'a HashMap<String, String>,
) -> Pin<Box<dyn Future<Output = ClientResult<EventStream>> + Send + 'a>> {
Box::pin(async move {
let mut last_err = None;
let mut backoff = self.policy.initial_backoff;
let serialized = serde_json::to_vec(¶ms).map_err(ClientError::Serialization)?;
for attempt in 0..=self.policy.max_retries {
if attempt > 0 {
let jittered_backoff = jittered(backoff);
trace_info!(
method,
attempt,
?jittered_backoff,
"retrying stream connect after backoff"
);
tokio::time::sleep(jittered_backoff).await;
backoff = cap_backoff(
backoff,
self.policy.backoff_multiplier,
self.policy.max_backoff,
);
}
let attempt_params: serde_json::Value =
serde_json::from_slice(&serialized).map_err(ClientError::Serialization)?;
match self
.inner
.send_streaming_request(method, attempt_params, extra_headers)
.await
{
Ok(stream) => return Ok(stream),
Err(e) if e.is_retryable() => {
trace_warn!(method, attempt, error = %e, "transient error, will retry");
last_err = Some(e);
}
Err(e) => return Err(e),
}
}
Err(last_err.expect("at least one attempt was made"))
})
}
}
/// Computes the next backoff by scaling `current` by `multiplier`,
/// clamped to `max`.
///
/// All degenerate products — NaN, negative, infinite, *or finite but too
/// large to represent as a `Duration`* — clamp to `max`. The previous
/// version only guarded non-finite/negative values, so a finite overflow
/// (e.g. `u64::MAX` seconds doubled) would panic inside
/// `Duration::from_secs_f64`; comparing against `max` in float space
/// before constructing the `Duration` closes that panic path.
fn cap_backoff(current: Duration, multiplier: f64, max: Duration) -> Duration {
    let next_secs = current.as_secs_f64() * multiplier;
    // NaN fails `is_finite`; anything at or above `max` (including +inf and
    // values that would overflow `Duration`) clamps to `max`.
    if !next_secs.is_finite() || next_secs < 0.0 || next_secs >= max.as_secs_f64() {
        return max;
    }
    Duration::from_secs_f64(next_secs)
}
/// Applies multiplicative jitter: scales `backoff` by a pseudo-random
/// factor in `[0.5, 1.0]`.
///
/// Randomness comes from hashing the backoff with a freshly seeded
/// `RandomState`, avoiding any RNG dependency. If the float math yields a
/// non-finite or negative result, the unjittered backoff is returned so
/// `Duration::from_secs_f64` can never panic.
fn jittered(backoff: Duration) -> Duration {
    use std::hash::{BuildHasher, Hasher};
    // Every `RandomState` carries a fresh random seed, so the finished
    // hash is effectively a random u64 even for a constant input.
    let mut hasher = std::collections::hash_map::RandomState::new().build_hasher();
    hasher.write_u128(backoff.as_nanos());
    let bits = hasher.finish();
    // Map the raw bits onto [0.0, 1.0], then into the factor range [0.5, 1.0].
    #[allow(clippy::cast_precision_loss)]
    let unit = bits as f64 / u64::MAX as f64;
    let factor = unit.mul_add(0.5, 0.5);
    let scaled_secs = backoff.as_secs_f64() * factor;
    if scaled_secs.is_finite() && scaled_secs >= 0.0 {
        Duration::from_secs_f64(scaled_secs)
    } else {
        backoff
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // ----- `ClientError::is_retryable` classification -----

    #[test]
    fn http_errors_are_retryable() {
        let e = ClientError::HttpClient("connection refused".into());
        assert!(e.is_retryable());
    }
    #[test]
    fn timeout_is_retryable() {
        let e = ClientError::Timeout("request timed out".into());
        assert!(e.is_retryable());
    }
    #[test]
    fn status_503_is_retryable() {
        let e = ClientError::UnexpectedStatus {
            status: 503,
            body: "Service Unavailable".into(),
        };
        assert!(e.is_retryable());
    }
    #[test]
    fn status_429_is_retryable() {
        let e = ClientError::UnexpectedStatus {
            status: 429,
            body: "Too Many Requests".into(),
        };
        assert!(e.is_retryable());
    }
    #[test]
    fn status_404_is_not_retryable() {
        let e = ClientError::UnexpectedStatus {
            status: 404,
            body: "Not Found".into(),
        };
        assert!(!e.is_retryable());
    }
    #[test]
    fn serialization_error_is_not_retryable() {
        let e = ClientError::Serialization(serde_json::from_str::<String>("not json").unwrap_err());
        assert!(!e.is_retryable());
    }
    #[test]
    fn protocol_error_is_not_retryable() {
        let e = ClientError::Protocol(a2a_protocol_types::A2aError::task_not_found("t1"));
        assert!(!e.is_retryable());
    }

    // ----- `RetryPolicy` defaults and builders -----

    #[test]
    fn default_retry_policy() {
        let p = RetryPolicy::default();
        assert_eq!(p.max_retries, 3);
        assert_eq!(p.initial_backoff, Duration::from_millis(500));
        assert_eq!(p.max_backoff, Duration::from_secs(30));
        assert!((p.backoff_multiplier - 2.0).abs() < f64::EPSILON);
    }
    #[test]
    fn cap_backoff_works() {
        let result = cap_backoff(Duration::from_secs(1), 2.0, Duration::from_secs(5));
        assert_eq!(result, Duration::from_secs(2));
        let result = cap_backoff(Duration::from_secs(4), 2.0, Duration::from_secs(5));
        assert_eq!(result, Duration::from_secs(5));
    }
    #[test]
    fn status_502_is_retryable() {
        let e = ClientError::UnexpectedStatus {
            status: 502,
            body: "Bad Gateway".into(),
        };
        assert!(e.is_retryable());
    }
    #[test]
    fn status_504_is_retryable() {
        let e = ClientError::UnexpectedStatus {
            status: 504,
            body: "Gateway Timeout".into(),
        };
        assert!(e.is_retryable());
    }
    #[test]
    fn status_boundary_not_retryable() {
        // Statuses adjacent to the retryable set (429, 502-504) must stay
        // non-retryable.
        for status in [428, 430, 500, 501, 505] {
            let e = ClientError::UnexpectedStatus {
                status,
                body: String::new(),
            };
            assert!(!e.is_retryable(), "status {status} should not be retryable");
        }
    }
    #[test]
    fn retry_policy_builder_methods() {
        let p = RetryPolicy::default()
            .with_max_retries(5)
            .with_initial_backoff(Duration::from_secs(1))
            .with_max_backoff(Duration::from_secs(60))
            .with_backoff_multiplier(3.0);
        assert_eq!(p.max_retries, 5);
        assert_eq!(p.initial_backoff, Duration::from_secs(1));
        assert_eq!(p.max_backoff, Duration::from_secs(60));
        assert!((p.backoff_multiplier - 3.0).abs() < f64::EPSILON);
    }

    // ----- backoff / jitter helpers -----

    #[test]
    fn cap_backoff_exact_boundary() {
        let result = cap_backoff(Duration::from_secs(5), 1.0, Duration::from_secs(5));
        assert_eq!(result, Duration::from_secs(5));
        let result = cap_backoff(Duration::from_millis(1), 2.0, Duration::from_secs(5));
        assert_eq!(result, Duration::from_millis(2));
    }
    #[test]
    fn cap_backoff_infinity_returns_max() {
        let max = Duration::from_secs(30);
        let result = cap_backoff(Duration::from_secs(u64::MAX / 2), f64::MAX, max);
        assert_eq!(result, max, "infinity should clamp to max");
    }
    #[test]
    fn jittered_backoff_in_expected_range() {
        let backoff = Duration::from_secs(2);
        // Statistical check: every sample must land in [backoff/2, backoff].
        for _ in 0..100 {
            let result = jittered(backoff);
            assert!(
                result >= Duration::from_secs(1),
                "jittered backoff should be >= backoff/2, got {result:?}"
            );
            assert!(
                result <= backoff,
                "jittered backoff should be <= backoff, got {result:?}"
            );
        }
    }
    #[test]
    fn jittered_zero_backoff() {
        let result = jittered(Duration::ZERO);
        assert_eq!(result, Duration::ZERO);
    }
    #[test]
    fn cap_backoff_nan_returns_max() {
        let max = Duration::from_secs(30);
        let result = cap_backoff(Duration::from_secs(0), f64::NAN, max);
        assert_eq!(result, max, "NaN should clamp to max");
    }

    // NOTE(review): several of these mid-module imports appear to duplicate
    // names already visible via `use super::*` above; only the Arc/atomics
    // imports add new names. Presumably harmless shadowing — confirm.
    use std::collections::HashMap;
    use std::future::Future;
    use std::pin::Pin;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use crate::streaming::EventStream;

    /// Mock transport: fails the first `fail_count` calls with a retryable
    /// `Timeout`, then (for unary requests) returns `success_response`;
    /// streaming requests return a non-retryable `Transport` error once the
    /// failures are exhausted. Every call is counted.
    struct FailNTransport {
        failures_remaining: Arc<AtomicUsize>,
        success_response: serde_json::Value,
        call_count: Arc<AtomicUsize>,
    }
    impl FailNTransport {
        fn new(fail_count: usize, response: serde_json::Value) -> Self {
            Self {
                failures_remaining: Arc::new(AtomicUsize::new(fail_count)),
                success_response: response,
                call_count: Arc::new(AtomicUsize::new(0)),
            }
        }
    }
    impl crate::transport::Transport for FailNTransport {
        fn send_request<'a>(
            &'a self,
            _method: &'a str,
            _params: serde_json::Value,
            _extra_headers: &'a HashMap<String, String>,
        ) -> Pin<Box<dyn Future<Output = ClientResult<serde_json::Value>> + Send + 'a>> {
            // Counters are updated eagerly (before the future is awaited).
            self.call_count.fetch_add(1, Ordering::SeqCst);
            let remaining = self.failures_remaining.fetch_sub(1, Ordering::SeqCst);
            let resp = self.success_response.clone();
            Box::pin(async move {
                if remaining > 0 {
                    Err(ClientError::Timeout("transient".into()))
                } else {
                    Ok(resp)
                }
            })
        }
        fn send_streaming_request<'a>(
            &'a self,
            _method: &'a str,
            _params: serde_json::Value,
            _extra_headers: &'a HashMap<String, String>,
        ) -> Pin<Box<dyn Future<Output = ClientResult<EventStream>> + Send + 'a>> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            let remaining = self.failures_remaining.fetch_sub(1, Ordering::SeqCst);
            Box::pin(async move {
                if remaining > 0 {
                    Err(ClientError::Timeout("transient".into()))
                } else {
                    Err(ClientError::Transport("streaming not mocked".into()))
                }
            })
        }
    }

    /// Mock transport that always fails with a non-retryable error.
    struct NonRetryableErrorTransport {
        call_count: Arc<AtomicUsize>,
    }
    impl NonRetryableErrorTransport {
        fn new() -> Self {
            Self {
                call_count: Arc::new(AtomicUsize::new(0)),
            }
        }
    }
    impl crate::transport::Transport for NonRetryableErrorTransport {
        fn send_request<'a>(
            &'a self,
            _method: &'a str,
            _params: serde_json::Value,
            _extra_headers: &'a HashMap<String, String>,
        ) -> Pin<Box<dyn Future<Output = ClientResult<serde_json::Value>> + Send + 'a>> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            Box::pin(async move { Err(ClientError::InvalidEndpoint("bad url".into())) })
        }
        fn send_streaming_request<'a>(
            &'a self,
            _method: &'a str,
            _params: serde_json::Value,
            _extra_headers: &'a HashMap<String, String>,
        ) -> Pin<Box<dyn Future<Output = ClientResult<EventStream>> + Send + 'a>> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            Box::pin(async move { Err(ClientError::InvalidEndpoint("bad url".into())) })
        }
    }

    // ----- RetryTransport behavior (attempt counting) -----

    #[tokio::test]
    async fn retry_transport_retries_on_transient_error() {
        let inner = FailNTransport::new(2, serde_json::json!({"ok": true}));
        let call_count = Arc::clone(&inner.call_count);
        let transport = RetryTransport::new(
            Box::new(inner),
            RetryPolicy::default()
                .with_initial_backoff(Duration::from_millis(1))
                .with_max_retries(3),
        );
        let headers = HashMap::new();
        let result = transport
            .send_request("test", serde_json::Value::Null, &headers)
            .await;
        assert!(result.is_ok(), "should succeed after retries");
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            3,
            "should have made 3 attempts (2 failures + 1 success)"
        );
    }
    #[tokio::test]
    async fn retry_transport_gives_up_after_max_retries() {
        let inner = FailNTransport::new(10, serde_json::json!({"ok": true}));
        let call_count = Arc::clone(&inner.call_count);
        let transport = RetryTransport::new(
            Box::new(inner),
            RetryPolicy::default()
                .with_initial_backoff(Duration::from_millis(1))
                .with_max_retries(2),
        );
        let headers = HashMap::new();
        let result = transport
            .send_request("test", serde_json::Value::Null, &headers)
            .await;
        assert!(result.is_err(), "should fail after exhausting retries");
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            3,
            "should have made 3 attempts (initial + 2 retries)"
        );
    }
    #[tokio::test]
    async fn retry_transport_no_retry_on_non_retryable() {
        let inner = NonRetryableErrorTransport::new();
        let call_count = Arc::clone(&inner.call_count);
        let transport = RetryTransport::new(
            Box::new(inner),
            RetryPolicy::default()
                .with_initial_backoff(Duration::from_millis(1))
                .with_max_retries(3),
        );
        let headers = HashMap::new();
        let result = transport
            .send_request("test", serde_json::Value::Null, &headers)
            .await;
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            ClientError::InvalidEndpoint(_)
        ));
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            1,
            "non-retryable error should not be retried"
        );
    }
    #[tokio::test]
    async fn retry_transport_streaming_retries() {
        let inner = FailNTransport::new(1, serde_json::json!(null));
        let call_count = Arc::clone(&inner.call_count);
        let transport = RetryTransport::new(
            Box::new(inner),
            RetryPolicy::default()
                .with_initial_backoff(Duration::from_millis(1))
                .with_max_retries(2),
        );
        let headers = HashMap::new();
        let result = transport
            .send_streaming_request("test", serde_json::Value::Null, &headers)
            .await;
        // FailNTransport's streaming success path is a non-retryable error,
        // so exactly one retry happens: Timeout, then Transport error.
        assert!(result.is_err());
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            2,
            "should have retried once for streaming"
        );
    }
    #[tokio::test]
    async fn retry_transport_streaming_no_retry_on_non_retryable() {
        let inner = NonRetryableErrorTransport::new();
        let call_count = Arc::clone(&inner.call_count);
        let transport = RetryTransport::new(
            Box::new(inner),
            RetryPolicy::default()
                .with_initial_backoff(Duration::from_millis(1))
                .with_max_retries(3),
        );
        let headers = HashMap::new();
        let result = transport
            .send_streaming_request("test", serde_json::Value::Null, &headers)
            .await;
        assert!(matches!(
            result.unwrap_err(),
            ClientError::InvalidEndpoint(_)
        ));
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            1,
            "non-retryable streaming error should not be retried"
        );
    }
    #[tokio::test]
    async fn retry_transport_streaming_succeeds_after_retry() {
        use tokio::sync::mpsc;
        /// Fails the first streaming attempt with a retryable timeout, then
        /// yields an (immediately closed) event stream.
        struct FailThenStreamTransport {
            call_count: Arc<AtomicUsize>,
        }
        impl crate::transport::Transport for FailThenStreamTransport {
            fn send_request<'a>(
                &'a self,
                _method: &'a str,
                _params: serde_json::Value,
                _extra_headers: &'a HashMap<String, String>,
            ) -> Pin<Box<dyn Future<Output = ClientResult<serde_json::Value>> + Send + 'a>>
            {
                Box::pin(async move { Ok(serde_json::Value::Null) })
            }
            fn send_streaming_request<'a>(
                &'a self,
                _method: &'a str,
                _params: serde_json::Value,
                _extra_headers: &'a HashMap<String, String>,
            ) -> Pin<Box<dyn Future<Output = ClientResult<EventStream>> + Send + 'a>> {
                let attempt = self.call_count.fetch_add(1, Ordering::SeqCst);
                Box::pin(async move {
                    if attempt == 0 {
                        Err(ClientError::Timeout("transient timeout".into()))
                    } else {
                        let (tx, rx) = mpsc::channel(8);
                        // Drop the sender so the stream is valid but empty.
                        drop(tx); Ok(EventStream::new(rx))
                    }
                })
            }
        }
        let call_count = Arc::new(AtomicUsize::new(0));
        let inner = FailThenStreamTransport {
            call_count: Arc::clone(&call_count),
        };
        let transport = RetryTransport::new(
            Box::new(inner),
            RetryPolicy::default()
                .with_initial_backoff(Duration::from_millis(1))
                .with_max_retries(2),
        );
        let headers = HashMap::new();
        let result = transport
            .send_streaming_request("test", serde_json::Value::Null, &headers)
            .await;
        assert!(result.is_ok(), "streaming should succeed after retry");
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            2,
            "should have made 2 attempts (1 failure + 1 success)"
        );
    }
    #[tokio::test]
    async fn retry_transport_streaming_exhausts_retries() {
        let inner = FailNTransport::new(10, serde_json::json!(null));
        let call_count = Arc::clone(&inner.call_count);
        let transport = RetryTransport::new(
            Box::new(inner),
            RetryPolicy::default()
                .with_initial_backoff(Duration::from_millis(1))
                .with_max_retries(2),
        );
        let headers = HashMap::new();
        let result = transport
            .send_streaming_request("test", serde_json::Value::Null, &headers)
            .await;
        assert!(result.is_err());
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            3,
            "should make 3 attempts total for streaming"
        );
    }
    #[tokio::test]
    async fn retry_transport_succeeds_without_retry_on_first_attempt() {
        let inner = FailNTransport::new(0, serde_json::json!({"ok": true}));
        let call_count = Arc::clone(&inner.call_count);
        let transport = RetryTransport::new(
            Box::new(inner),
            RetryPolicy::default()
                .with_initial_backoff(Duration::from_millis(1))
                .with_max_retries(3),
        );
        let headers = HashMap::new();
        let result = transport
            .send_request("test", serde_json::Value::Null, &headers)
            .await;
        assert!(result.is_ok());
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            1,
            "should succeed on first try"
        );
    }

    // ----- backoff timing (tokio paused clock auto-advances sleeps) -----

    #[tokio::test(start_paused = true)]
    async fn no_backoff_before_first_attempt() {
        let inner = FailNTransport::new(0, serde_json::json!({"ok": true}));
        let transport = RetryTransport::new(
            Box::new(inner),
            RetryPolicy::default()
                .with_initial_backoff(Duration::from_secs(100))
                .with_max_retries(1),
        );
        let start = tokio::time::Instant::now();
        let headers = HashMap::new();
        let result = transport
            .send_request("test", serde_json::Value::Null, &headers)
            .await;
        assert!(result.is_ok());
        assert!(
            start.elapsed() < Duration::from_secs(1),
            "first attempt must not sleep, elapsed: {:?}",
            start.elapsed()
        );
    }
    #[tokio::test(start_paused = true)]
    async fn backoff_applied_on_retry() {
        let inner = FailNTransport::new(1, serde_json::json!({"ok": true}));
        let transport = RetryTransport::new(
            Box::new(inner),
            RetryPolicy::default()
                .with_initial_backoff(Duration::from_secs(100))
                .with_max_retries(2),
        );
        let start = tokio::time::Instant::now();
        let headers = HashMap::new();
        let result = transport
            .send_request("test", serde_json::Value::Null, &headers)
            .await;
        assert!(result.is_ok());
        // Jitter halves the backoff at worst, so >= 50 s of virtual time.
        assert!(
            start.elapsed() >= Duration::from_secs(50),
            "retry should sleep (jittered backoff), elapsed: {:?}",
            start.elapsed()
        );
    }
    #[tokio::test(start_paused = true)]
    async fn no_backoff_before_first_streaming_attempt() {
        use tokio::sync::mpsc;
        /// Always succeeds immediately with an empty event stream.
        struct ImmediateStreamTransport;
        impl crate::transport::Transport for ImmediateStreamTransport {
            fn send_request<'a>(
                &'a self,
                _method: &'a str,
                _params: serde_json::Value,
                _extra_headers: &'a HashMap<String, String>,
            ) -> Pin<Box<dyn Future<Output = ClientResult<serde_json::Value>> + Send + 'a>>
            {
                Box::pin(async { Ok(serde_json::Value::Null) })
            }
            fn send_streaming_request<'a>(
                &'a self,
                _method: &'a str,
                _params: serde_json::Value,
                _extra_headers: &'a HashMap<String, String>,
            ) -> Pin<Box<dyn Future<Output = ClientResult<EventStream>> + Send + 'a>> {
                Box::pin(async {
                    let (tx, rx) = mpsc::channel(1);
                    drop(tx);
                    Ok(EventStream::new(rx))
                })
            }
        }
        let transport = RetryTransport::new(
            Box::new(ImmediateStreamTransport),
            RetryPolicy::default()
                .with_initial_backoff(Duration::from_secs(100))
                .with_max_retries(1),
        );
        let start = tokio::time::Instant::now();
        let headers = HashMap::new();
        let result = transport
            .send_streaming_request("test", serde_json::Value::Null, &headers)
            .await;
        assert!(result.is_ok());
        assert!(
            start.elapsed() < Duration::from_secs(1),
            "first streaming attempt must not sleep, elapsed: {:?}",
            start.elapsed()
        );
    }
    #[tokio::test(start_paused = true)]
    async fn backoff_applied_on_streaming_retry() {
        let inner = FailNTransport::new(1, serde_json::json!(null));
        let transport = RetryTransport::new(
            Box::new(inner),
            RetryPolicy::default()
                .with_initial_backoff(Duration::from_secs(100))
                .with_max_retries(2),
        );
        let start = tokio::time::Instant::now();
        let headers = HashMap::new();
        let _result = transport
            .send_streaming_request("test", serde_json::Value::Null, &headers)
            .await;
        assert!(
            start.elapsed() >= Duration::from_secs(50),
            "streaming retry should sleep, elapsed: {:?}",
            start.elapsed()
        );
    }
    #[test]
    fn cap_backoff_zero_multiplier_returns_zero() {
        let max = Duration::from_secs(30);
        let result = cap_backoff(Duration::from_secs(5), 0.0, max);
        assert_eq!(
            result,
            Duration::ZERO,
            "0 * any = 0, should not clamp to max"
        );
    }
}