use std::future::Future;
use std::time::Duration;
use tokio_util::sync::CancellationToken;
use crate::error::{Error as GrpcError, TransportError};
/// Settings that control how [`retry_connect`] retries a connection attempt.
///
/// The [`Default`] value has no deadline, no cancellation token and a zero
/// `backoff`; `retry_connect` replaces a zero `backoff` with its built-in
/// 50ms pause.
#[derive(Clone, Debug, Default)]
pub struct ConnectOptions {
    /// Total time budget for the whole retry loop, measured from the moment
    /// `retry_connect` starts. `None` keeps retrying until an attempt succeeds
    /// or the operation is cancelled.
    pub deadline: Option<Duration>,
    /// Pause inserted after each failed attempt. A zero duration selects the
    /// built-in default pause.
    pub backoff: Duration,
    /// Token that aborts the retry loop once cancelled. `None` disables
    /// cancellation.
    pub cancellation: Option<CancellationToken>,
}
impl ConnectOptions {
    /// Creates options whose only non-default setting is an overall `deadline`
    /// for the retry loop.
    #[must_use]
    pub fn with_deadline(deadline: Duration) -> Self {
        Self {
            deadline: Some(deadline),
            ..Self::default()
        }
    }

    /// Sets the pause between failed attempts and returns the updated options.
    ///
    /// A zero `backoff` makes `retry_connect` fall back to its built-in pause.
    #[must_use]
    pub fn backoff(mut self, backoff: Duration) -> Self {
        self.backoff = backoff;
        self
    }

    /// Attaches a cancellation token and returns the updated options.
    ///
    /// Any previously attached token is replaced.
    #[must_use]
    pub fn cancellation(mut self, token: CancellationToken) -> Self {
        self.cancellation = Some(token);
        self
    }
}
/// Pause between attempts that `retry_connect` uses when
/// `ConnectOptions::backoff` is zero.
const DEFAULT_BACKOFF: Duration = Duration::from_millis(50);
pub async fn retry_connect<F, Fut, T>(
options: &ConnectOptions,
mut attempt: F,
) -> Result<T, GrpcError>
where
F: FnMut() -> Fut,
Fut: Future<Output = Result<T, GrpcError>>,
{
let backoff = if options.backoff.is_zero() {
DEFAULT_BACKOFF
} else {
options.backoff
};
let deadline = options
.deadline
.map(|budget| tokio::time::Instant::now() + budget);
let mut last_error: GrpcError =
TransportError::ConnectFailed("connect deadline exceeded".to_string()).into();
loop {
if let Some(token) = &options.cancellation
&& token.is_cancelled()
{
return Err(GrpcError::Cancelled("connect cancelled".to_string()));
}
let remaining = match deadline {
Some(deadline) => {
let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
if remaining.is_zero() {
return Err(last_error);
}
Some(remaining)
}
None => None,
};
let attempt_result = match (&options.cancellation, remaining) {
(Some(token), Some(remaining)) => tokio::select! {
_ = token.cancelled() => {
return Err(GrpcError::Cancelled("connect cancelled".to_string()));
}
result = tokio::time::timeout(remaining, attempt()) => match result {
Err(_elapsed) => return Err(last_error),
Ok(result) => result,
},
},
(Some(token), None) => tokio::select! {
_ = token.cancelled() => {
return Err(GrpcError::Cancelled("connect cancelled".to_string()));
}
result = attempt() => result,
},
(None, Some(remaining)) => match tokio::time::timeout(remaining, attempt()).await {
Err(_elapsed) => return Err(last_error),
Ok(result) => result,
},
(None, None) => attempt().await,
};
last_error = match attempt_result {
Ok(value) => return Ok(value),
Err(error) => error,
};
if let Some(deadline) = deadline
&& tokio::time::Instant::now() + backoff >= deadline
{
return Err(last_error);
}
match &options.cancellation {
Some(token) => {
tokio::select! {
_ = tokio::time::sleep(backoff) => {}
_ = token.cancelled() => {
return Err(GrpcError::Cancelled("connect cancelled".to_string()));
}
}
}
None => tokio::time::sleep(backoff).await,
}
}
}
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use super::*;

    /// Builds the transport failure reported by the fake connect attempts.
    fn connect_failed(reason: &str) -> GrpcError {
        TransportError::ConnectFailed(reason.to_string()).into()
    }

    /// Two failing calls followed by a success yield the third call's value.
    #[tokio::test]
    async fn retries_until_success() {
        let calls = AtomicUsize::new(0);
        let options = ConnectOptions::default().backoff(Duration::from_millis(1));

        let outcome = retry_connect(&options, || {
            let call_index = calls.fetch_add(1, Ordering::SeqCst);
            async move {
                match call_index {
                    0 | 1 => Err(connect_failed("not yet")),
                    _ => Ok(call_index),
                }
            }
        })
        .await;

        assert_eq!(outcome.unwrap(), 2);
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    /// When the deadline ends the loop, the caller sees the latest attempt error.
    #[tokio::test]
    async fn surfaces_last_error_after_deadline() {
        let options = ConnectOptions::with_deadline(Duration::from_millis(15))
            .backoff(Duration::from_millis(10));

        let outcome: Result<(), GrpcError> =
            retry_connect(&options, || async { Err(connect_failed("down")) }).await;

        let error = outcome.unwrap_err();
        assert!(error.to_string().contains("down"), "got: {error}");
    }

    /// An attempt that never completes is cut short by the deadline.
    #[tokio::test]
    async fn deadline_clips_a_hanging_attempt() {
        let started_at = tokio::time::Instant::now();
        let options = ConnectOptions::with_deadline(Duration::from_millis(50));

        let outcome: Result<(), GrpcError> =
            retry_connect(&options, std::future::pending).await;

        let error = outcome.unwrap_err();
        let is_connect_failure = matches!(
            error,
            GrpcError::Transport(TransportError::ConnectFailed(_))
        );
        assert!(is_connect_failure, "got: {error}");

        let waited = started_at.elapsed();
        assert!(
            waited < Duration::from_secs(2),
            "retry_connect waited {:?}, past the 50ms deadline",
            waited
        );
    }

    /// A token cancelled up front is reported as a cancellation.
    #[tokio::test]
    async fn cancellation_aborts_retry() {
        let token = CancellationToken::new();
        token.cancel();
        let options = ConnectOptions::default().cancellation(token);

        let outcome: Result<(), GrpcError> =
            retry_connect(&options, || async { Err(connect_failed("down")) }).await;

        assert!(matches!(outcome, Err(GrpcError::Cancelled(_))));
    }

    /// Cancelling while an attempt hangs ends the call even without a deadline.
    #[tokio::test]
    async fn cancellation_aborts_a_hanging_attempt_without_deadline() {
        let token = CancellationToken::new();
        let trigger = token.clone();
        // Cancel from a separate task shortly after the retry loop has started.
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(20)).await;
            trigger.cancel();
        });

        let started_at = tokio::time::Instant::now();
        let options = ConnectOptions::default().cancellation(token);
        let outcome: Result<(), GrpcError> =
            retry_connect(&options, std::future::pending).await;

        let error = outcome.unwrap_err();
        assert!(matches!(error, GrpcError::Cancelled(_)), "got: {error}");

        let waited = started_at.elapsed();
        assert!(
            waited < Duration::from_secs(2),
            "retry_connect waited {:?}, past cancellation",
            waited
        );
    }
}