use std::ops::ControlFlow;
use std::sync::Arc;
use std::time::Duration;
use layered::Service;
use tick::Clock;
use super::*;
use crate::utils::EnableIf;
use crate::{NotSet, RecoveryInfo, RecoveryKind};
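/// Retry middleware: wraps an inner [`Service`] and re-executes it while the
/// configured recovery policy classifies the output as recoverable.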
#[derive(Debug)]
pub struct Retry<In, Out, S> {
pub(super) shared: Arc<RetryShared<In, Out>>,
pub(super) inner: S,
}
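/// Configuration shared by all clones of a [`Retry`] service through an [`Arc`].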
#[derive(Debug)]
pub(crate) struct RetryShared<In, Out> {
pub(crate) clock: Clock,
pub(crate) max_attempts: u32,
pub(crate) backoff: DelayBackoff,
pub(crate) clone_input: CloneInput<In>,
pub(crate) should_recover: ShouldRecover<Out>,
pub(crate) on_retry: Option<OnRetry<Out>>,
pub(crate) enable_if: EnableIf<In>,
#[cfg(any(feature = "logs", feature = "metrics", test))]
pub(crate) telemetry: crate::utils::TelemetryHelper,
pub(crate) restore_input: Option<RestoreInput<In, Out>>,
pub(crate) handle_unavailable: bool,
}
impl<In, Out, S: Clone> Clone for Retry<In, Out, S> {
fn clone(&self) -> Self {
Self {
shared: Arc::clone(&self.shared),
inner: self.inner.clone(),
}
}
}
impl<In, Out> Retry<In, Out, ()> {
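/// Returns a builder for a retry layer bound to `name` and the pipeline
/// `context`. Recovery classification and input cloning must be configured
/// before the builder can wrap a service.
///
/// A minimal sketch, mirroring the builder calls exercised in the tests
/// below:
///
/// ```ignore
/// let context = ResilienceContext::<String, String>::new(Clock::new_frozen()).name("pipeline");
/// let service = Retry::layer("my_retry", &context)
///     .recovery_with(|_out, _args| RecoveryInfo::never())
///     .clone_input()
///     .layer(Execute::new(|v: String| async move { v }));
/// ```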
pub fn layer(
name: impl Into<std::borrow::Cow<'static, str>>,
context: &crate::ResilienceContext<In, Out>,
) -> crate::retry::RetryLayer<In, Out, NotSet, NotSet> {
crate::retry::RetryLayer::new(name.into(), context)
}
}
impl<In, Out: Send, S> Service<In> for Retry<In, Out, S>
where
In: Send,
S: Service<In, Out = Out>,
{
type Out = Out;
#[cfg_attr(test, mutants::skip)]
async fn execute(&self, mut input: In) -> Self::Out {
if !self.shared.enable_if.call(&input) {
return self.inner.execute(input).await;
}
let mut attempt = Attempt::first(self.shared.max_attempts);
let mut delays = self.shared.backoff.delays();
let mut previous_recovery = None;
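// Each pass clones the input when possible, runs the inner service, then
// asks `evaluate_attempt` whether to sleep and retry or return the output.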
loop {
let (original_input, attempt_input) = self.shared.clone_input(input, attempt, previous_recovery.clone());
let out = self.inner.execute(attempt_input).await;
match self.shared.evaluate_attempt(original_input, out, attempt, &mut delays) {
ControlFlow::Continue(state) => {
self.shared.clock.delay(state.delay).await;
input = state.input;
attempt = state.attempt;
previous_recovery = Some(state.recovery);
}
ControlFlow::Break(out) => return out,
}
}
}
}
impl<In, Out> RetryShared<In, Out> {
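/// Splits the input into (retained original, input for this attempt). When
/// the clone hook declines, the original is consumed by the attempt and a
/// retry is only possible if `restore_input` later recovers one.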
fn clone_input(&self, mut input: In, attempt: Attempt, previous_recovery: Option<RecoveryInfo>) -> (Option<In>, In) {
let args = CloneArgs {
attempt,
previous_recovery,
};
match self.clone_input.call(&mut input, args) {
Some(cloned) => (Some(input), cloned),
None => (None, input),
}
}
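/// Decides the fate of one attempt: `Continue` carries the input, delay, and
/// recovery info for the next attempt; `Break` hands the output back to the
/// caller.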
fn evaluate_attempt(
&self,
mut original_input: Option<In>,
mut out: Out,
attempt: Attempt,
delays: &mut impl Iterator<Item = Duration>,
) -> ControlFlow<Out, ContinueRetry<In>> {
let recovery = self.should_recover.call(
&out,
RecoveryArgs {
attempt,
clock: &self.clock,
},
);
if !self.is_recoverable(&recovery) {
return ControlFlow::Break(out);
}
let Some(next_attempt) = attempt.increment(self.max_attempts) else {
self.emit_telemetry(attempt, Duration::ZERO);
return ControlFlow::Break(out);
};
let retry_delay = compute_retry_delay(&recovery, delays);
self.emit_telemetry(attempt, retry_delay);
if let Some(input) = self.try_restore_input(original_input.as_ref(), &mut out, attempt, &recovery) {
original_input = Some(input);
}
match original_input {
Some(input) => {
self.invoke_on_retry(&out, attempt, retry_delay, &recovery);
ControlFlow::Continue(ContinueRetry {
input,
attempt: next_attempt,
recovery,
delay: retry_delay,
})
}
None => ControlFlow::Break(out),
}
}
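/// Maps the recovery classification to a retry decision; `Unavailable` is
/// retried only when `handle_unavailable` is set.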
#[cfg_attr(test, mutants::skip)]
fn is_recoverable(&self, recovery: &RecoveryInfo) -> bool {
match recovery.kind() {
RecoveryKind::Unavailable => self.handle_unavailable,
RecoveryKind::Retry => true,
// `Never`, `Unknown`, and any future kinds never trigger a retry.
_ => false,
}
}
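/// Tries to recover an input from the output when no clone was retained;
/// returns `None` if an original input already exists or no `restore_input`
/// hook is configured.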
fn try_restore_input(&self, original_input: Option<&In>, out: &mut Out, attempt: Attempt, recovery: &RecoveryInfo) -> Option<In> {
if original_input.is_some() {
return None;
}
match &self.restore_input {
Some(restore) => restore.call(
out,
RestoreInputArgs {
attempt,
recovery: recovery.clone(),
},
),
None => None,
}
}
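/// Fires the optional `on_retry` hook before the next attempt is scheduled.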
fn invoke_on_retry(&self, out: &Out, attempt: Attempt, retry_delay: Duration, recovery: &RecoveryInfo) {
if let Some(on_retry) = &self.on_retry {
on_retry.call(
out,
OnRetryArgs {
attempt,
retry_delay,
recovery: recovery.clone(),
},
);
}
}
#[cfg_attr(
not(any(feature = "logs", feature = "metrics", test)),
expect(unused_variables, reason = "unused when no telemetry feature is enabled")
)]
fn emit_telemetry(&self, attempt: Attempt, retry_delay: Duration) {
#[cfg(any(feature = "logs", test))]
if self.telemetry.logs_enabled {
tracing::event!(
name: "seatbelt.retry",
tracing::Level::WARN,
pipeline.name = %self.telemetry.pipeline_name,
strategy.name = %self.telemetry.strategy_name,
resilience.attempt.index = attempt.index(),
resilience.attempt.is_last = attempt.is_last(),
resilience.retry.delay = retry_delay.as_secs_f32(),
);
}
#[cfg(any(feature = "metrics", test))]
if self.telemetry.metrics_enabled() {
use super::telemetry::{ATTEMPT_INDEX, ATTEMPT_NUMBER_IS_LAST, RETRY_EVENT};
use crate::utils::{EVENT_NAME, PIPELINE_NAME, STRATEGY_NAME};
self.telemetry.report_metrics(&[
opentelemetry::KeyValue::new(PIPELINE_NAME, self.telemetry.pipeline_name.clone()),
opentelemetry::KeyValue::new(STRATEGY_NAME, self.telemetry.strategy_name.clone()),
opentelemetry::KeyValue::new(EVENT_NAME, RETRY_EVENT),
opentelemetry::KeyValue::new(ATTEMPT_INDEX, i64::from(attempt.index())),
opentelemetry::KeyValue::new(ATTEMPT_NUMBER_IS_LAST, attempt.is_last()),
]);
}
}
}
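/// A delay supplied by the recovery info takes precedence over the backoff
/// schedule; an exhausted schedule falls back to no delay.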
fn compute_retry_delay(recovery: &RecoveryInfo, delays: &mut impl Iterator<Item = Duration>) -> Duration {
let backoff_delay = delays.next().unwrap_or(Duration::ZERO);
recovery.get_delay().unwrap_or(backoff_delay)
}
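/// State threaded from one failed attempt into the next.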
struct ContinueRetry<In> {
input: In,
attempt: Attempt,
recovery: RecoveryInfo,
delay: Duration,
}
#[cfg_attr(coverage_nightly, coverage(off))]
#[cfg(test)]
#[cfg(not(miri))]
mod tests {
use layered::{Execute, Layer};
use opentelemetry::KeyValue;
use tick::ClockControl;
use super::*;
use crate::testing::MetricTester;
use crate::{ResilienceContext, Set};
#[test]
fn layer_ensure_defaults() {
let context = ResilienceContext::<String, String>::new(Clock::new_frozen()).name("test_pipeline");
let layer: RetryLayer<String, String, NotSet, NotSet> = Retry::layer("test_retry", &context);
let layer = layer.recovery_with(|_, _| RecoveryInfo::never()).clone_input();
let retry = layer.layer(Execute::new(|v: String| async move { v }));
assert_eq!(retry.shared.telemetry.pipeline_name.to_string(), "test_pipeline");
assert_eq!(retry.shared.telemetry.strategy_name.to_string(), "test_retry");
assert_eq!(retry.shared.max_attempts, 4);
assert_eq!(retry.shared.backoff.0.base_delay, Duration::from_millis(10));
assert_eq!(retry.shared.backoff.0.backoff_type, Backoff::Exponential);
assert!(retry.shared.backoff.0.use_jitter);
assert!(retry.shared.on_retry.is_none());
assert!(retry.shared.enable_if.call(&"str".to_string()));
}
#[tokio::test]
async fn retries_exhausted_ensure_telemetry_reported() {
let tester = MetricTester::new();
let context = ResilienceContext::<String, String>::new(ClockControl::default().auto_advance_timers(true).to_clock())
.name("test_pipeline")
.use_metrics(tester.meter_provider());
let service = create_ready_retry_layer_core(RecoveryInfo::retry(), &context)
.clone_input_with(move |input, _args| Some(input.clone()))
.max_retry_attempts(2)
.recovery_with(move |_input, _args| RecoveryInfo::retry())
.layer(Execute::new(move |v: String| async move { v }));
let _result = service.execute("test".to_string()).await;
tester.assert_attributes(
&[
KeyValue::new("resilience.attempt.index", 0),
KeyValue::new("resilience.attempt.index", 1),
KeyValue::new("resilience.attempt.is_last", false),
KeyValue::new("resilience.attempt.is_last", true),
KeyValue::new("resilience.pipeline.name", "test_pipeline"),
KeyValue::new("resilience.strategy.name", "test_retry"),
KeyValue::new("resilience.event.name", "retry"),
],
Some(15),
);
}
#[tokio::test]
async fn retry_emits_log() {
use tracing_subscriber::util::SubscriberInitExt;
use crate::testing::LogCapture;
let log_capture = LogCapture::new();
let _guard = log_capture.subscriber().set_default();
let clock = ClockControl::default().auto_advance_timers(true).to_clock();
let context = ResilienceContext::<String, String>::new(clock).name("log_test_pipeline").use_logs();
let service = Retry::layer("log_test_retry", &context)
.clone_input()
.recovery_with(|_, _| RecoveryInfo::retry())
.max_retry_attempts(2)
.layer(Execute::new(|v: String| async move { v }));
let _ = service.execute("test".to_string()).await;
log_capture.assert_contains("seatbelt::retry");
log_capture.assert_contains("log_test_pipeline");
log_capture.assert_contains("log_test_retry");
log_capture.assert_contains("resilience.attempt.index");
log_capture.assert_contains("resilience.retry.delay");
}
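// Test helper: a retry layer with recovery and input cloning preconfigured
// and a large `max_delay` cap.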
fn create_ready_retry_layer_core(
recover: RecoveryInfo,
context: &ResilienceContext<String, String>,
) -> RetryLayer<String, String, Set, Set> {
Retry::layer("test_retry", context)
.recovery_with(move |_, _| recover.clone())
.clone_input()
.max_delay(Duration::from_secs(9999))
}
}