use crate::{
ERROR_RETURNED_DUE_TO_SHORT_CIRCUIT, MESSAGE_TOO_LARGE_KEY,
grpc::IsUserLongPoll,
request_extensions::{IsWorkerTaskLongPoll, NoRetryOnMatching, RetryConfigForCall},
};
use backoff::{
Clock, SystemClock,
backoff::Backoff,
exponential::{self, ExponentialBackoff},
};
use futures_retry::{ErrorHandler, FutureRetry, RetryPolicy};
use std::{
error::Error,
fmt::Debug,
future::Future,
time::{Duration, Instant},
};
use tonic::Code;
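/// The gRPC status codes that are considered retryable by default.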
#[doc(hidden)]
pub const RETRYABLE_ERROR_CODES: [Code; 7] = [
Code::DataLoss,
Code::Internal,
Code::Unknown,
Code::ResourceExhausted,
Code::Aborted,
Code::OutOfRange,
Code::Unavailable,
];
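/// How long task long-polls may keep retrying errors that would otherwise be fatal before the
/// error is surfaced to the caller.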
const LONG_POLL_FATAL_GRACE: Duration = Duration::from_secs(60);
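/// Retry configuration for gRPC calls made by the client.
///
/// A sketch of overriding selected defaults (the values are illustrative, not recommendations):
/// ```ignore
/// let opts = RetryOptions {
///     initial_interval: Duration::from_millis(50),
///     max_retries: 5,
///     ..Default::default()
/// };
/// ```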
#[derive(Clone, Debug, PartialEq)]
pub struct RetryOptions {
/// Initial wait before the first retry.
pub initial_interval: Duration,
/// Jitter applied to each interval: the wait is scaled up or down by at most this fraction.
pub randomization_factor: f64,
/// Factor by which the interval grows after each retry, up to `max_interval`.
pub multiplier: f64,
/// Upper bound on the wait between retries.
pub max_interval: Duration,
/// Upper bound on total time spent retrying; `None` means no limit.
pub max_elapsed_time: Option<Duration>,
/// Maximum number of attempts; 0 means unlimited.
pub max_retries: usize,
}
impl Default for RetryOptions {
fn default() -> Self {
Self {
initial_interval: Duration::from_millis(100),
randomization_factor: 0.2,
multiplier: 1.7,
max_interval: Duration::from_secs(5),
max_elapsed_time: Some(Duration::from_secs(10)),
max_retries: 10,
}
}
}
impl RetryOptions {
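/// The policy for worker task long-polls: no attempt or elapsed-time limit, since pollers
/// are expected to retry indefinitely.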
pub(crate) const fn task_poll_retry_policy() -> Self {
Self {
initial_interval: Duration::from_millis(200),
randomization_factor: 0.2,
multiplier: 2.0,
max_interval: Duration::from_secs(10),
max_elapsed_time: None,
max_retries: 0,
}
}
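/// The policy used to additionally slow down retries when the server reports
/// `ResourceExhausted`.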
pub(crate) const fn throttle_retry_policy() -> Self {
Self {
initial_interval: Duration::from_secs(1),
randomization_factor: 0.2,
multiplier: 2.0,
max_interval: Duration::from_secs(10),
max_elapsed_time: None,
max_retries: 0,
}
}
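/// A policy that allows exactly one attempt and never retries.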
pub const fn no_retries() -> Self {
Self {
initial_interval: Duration::from_secs(0),
randomization_factor: 0.0,
multiplier: 1.0,
max_interval: Duration::from_secs(0),
max_elapsed_time: None,
max_retries: 1,
}
}
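/// Resolves the call type and effective retry configuration for a call by inspecting the
/// request's extensions for long-poll markers, per-call retry overrides, and retry
/// short-circuit predicates.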
pub(crate) fn get_call_info<R>(
&self,
call_name: &'static str,
request: Option<&tonic::Request<R>>,
) -> CallInfo {
let mut call_type = CallType::Normal;
let mut retry_short_circuit = None;
let mut retry_cfg_override = None;
if let Some(r) = request {
let ext = r.extensions();
if ext.get::<IsUserLongPoll>().is_some() {
call_type = CallType::UserLongPoll;
} else if ext.get::<IsWorkerTaskLongPoll>().is_some() {
call_type = CallType::TaskLongPoll;
}
retry_short_circuit = ext.get::<NoRetryOnMatching>().cloned();
retry_cfg_override = ext.get::<RetryConfigForCall>().cloned();
}
let retry_cfg = if let Some(ovr) = retry_cfg_override {
ovr.0
} else if call_type == CallType::TaskLongPoll {
RetryOptions::task_poll_retry_policy()
} else {
self.clone()
};
CallInfo {
call_type,
call_name,
retry_cfg,
retry_short_circuit,
}
}
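/// Converts these options into an exponential backoff driven by the given clock.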
pub(crate) fn into_exp_backoff<C>(self, clock: C) -> exponential::ExponentialBackoff<C> {
exponential::ExponentialBackoff {
current_interval: self.initial_interval,
initial_interval: self.initial_interval,
randomization_factor: self.randomization_factor,
multiplier: self.multiplier,
max_interval: self.max_interval,
max_elapsed_time: self.max_elapsed_time,
clock,
start_time: Instant::now(),
}
}
}
impl From<RetryOptions> for backoff::ExponentialBackoff {
fn from(c: RetryOptions) -> Self {
c.into_exp_backoff(SystemClock::default())
}
}
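/// Wraps a factory producing call futures in retry logic configured by `info`.
///
/// A usage sketch; `MyRequest`, `client`, and `my_rpc` are hypothetical stand-ins for a real
/// generated request type and gRPC call:
/// ```ignore
/// let info = RetryOptions::default().get_call_info::<MyRequest>("my_rpc", None);
/// // Awaiting the returned future drives the call with retries applied.
/// let res = make_future_retry(info, || client.my_rpc()).await;
/// ```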
pub(crate) fn make_future_retry<R, F, Fut>(
info: CallInfo,
factory: F,
) -> FutureRetry<F, TonicErrorHandler<SystemClock>>
where
F: FnMut() -> Fut + Unpin,
Fut: Future<Output = Result<R, tonic::Status>>,
{
FutureRetry::new(
factory,
TonicErrorHandler::new(info, RetryOptions::throttle_retry_policy()),
)
}
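/// Decides, for each failed attempt of a gRPC call, whether the call should be retried and
/// how long to wait first.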
#[derive(Debug)]
pub(crate) struct TonicErrorHandler<C: Clock> {
/// Backoff schedule for ordinary retryable errors.
backoff: ExponentialBackoff<C>,
/// Additional, slower backoff schedule applied to `ResourceExhausted` errors.
throttle_backoff: ExponentialBackoff<C>,
/// Maximum number of attempts; 0 means unlimited.
max_retries: usize,
call_type: CallType,
call_name: &'static str,
/// Set once a cancellation caused by a dropped connection has been retried; such a retry
/// is only granted once.
have_retried_goaway_cancel: bool,
/// Optional predicate that, when it matches an error, forwards it without retrying.
retry_short_circuit: Option<NoRetryOnMatching>,
}
impl TonicErrorHandler<SystemClock> {
fn new(call_info: CallInfo, throttle_cfg: RetryOptions) -> Self {
Self::new_with_clock(
call_info,
throttle_cfg,
SystemClock::default(),
SystemClock::default(),
)
}
}
impl<C> TonicErrorHandler<C>
where
C: Clock,
{
fn new_with_clock(
call_info: CallInfo,
throttle_cfg: RetryOptions,
clock: C,
throttle_clock: C,
) -> Self {
Self {
call_type: call_info.call_type,
call_name: call_info.call_name,
max_retries: call_info.retry_cfg.max_retries,
backoff: call_info.retry_cfg.into_exp_backoff(clock),
throttle_backoff: throttle_cfg.into_exp_backoff(throttle_clock),
have_retried_goaway_cancel: false,
retry_short_circuit: call_info.retry_short_circuit,
}
}
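/// Logs retries once they become noteworthy: past half the retry budget when one exists,
/// or (with an unlimited budget) warning after 5 attempts and escalating to an error
/// after 15.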
fn maybe_log_retry(&self, cur_attempt: usize, err: &tonic::Status) {
let mut do_log = false;
if self.max_retries == 0 && cur_attempt > 5 {
do_log = true;
}
if self.max_retries > 0 && cur_attempt * 2 >= self.max_retries {
do_log = true;
}
if do_log {
if self.max_retries == 0 && cur_attempt > 15 {
error!(error=?err, "gRPC call {} retried {} times", self.call_name, cur_attempt);
} else {
warn!(error=?err, "gRPC call {} retried {} times", self.call_name, cur_attempt);
}
}
}
}
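/// Resolved per-call information used to drive retry behavior.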
#[derive(Clone, Debug)]
pub(crate) struct CallInfo {
pub call_type: CallType,
call_name: &'static str,
retry_cfg: RetryOptions,
retry_short_circuit: Option<NoRetryOnMatching>,
}
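/// Distinguishes normal calls from long-polls, which have different retry semantics.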
#[doc(hidden)]
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub enum CallType {
/// An ordinary request/response call.
Normal,
/// A long-poll made on behalf of a user request.
UserLongPoll,
/// A worker task-queue long-poll, expected to be retried indefinitely.
TaskLongPoll,
}
impl CallType {
pub(crate) fn is_long(&self) -> bool {
matches!(self, Self::UserLongPoll | Self::TaskLongPoll)
}
}
impl<C> ErrorHandler<tonic::Status> for TonicErrorHandler<C>
where
C: Clock,
{
type OutError = tonic::Status;
fn handle(
&mut self,
current_attempt: usize,
mut e: tonic::Status,
) -> RetryPolicy<tonic::Status> {
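// Stop when the attempt budget is exhausted (a budget of 0 means unlimited attempts).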
if self.max_retries > 0 && current_attempt >= self.max_retries {
return RetryPolicy::ForwardError(e);
}
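// A caller-installed short-circuit predicate forwards matching errors immediately, tagged
// so the caller can tell the retry was skipped deliberately.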
if let Some(sc) = self.retry_short_circuit.as_ref()
&& (sc.predicate)(&e)
{
e.metadata_mut().insert(
ERROR_RETURNED_DUE_TO_SHORT_CIRCUIT,
tonic::metadata::MetadataValue::from(0),
);
return RetryPolicy::ForwardError(e);
}
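// "Message too large" errors arrive as ResourceExhausted but cannot succeed on retry;
// tag them and forward immediately. The prefixes below match the server's error messages.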
if e.code() == Code::ResourceExhausted
&& (e
.message()
.starts_with("grpc: received message larger than max")
|| e.message()
.starts_with("grpc: message after decompression larger than max")
|| e.message()
.starts_with("grpc: received message after decompression larger than max"))
{
e.metadata_mut().insert(
MESSAGE_TOO_LARGE_KEY,
tonic::metadata::MetadataValue::from(0),
);
return RetryPolicy::ForwardError(e);
}
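// Task long-polls legitimately hit server-side timeouts and cancellations, so those
// codes are retryable for them.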
let long_poll_allowed = self.call_type == CallType::TaskLongPoll
&& [Code::Cancelled, Code::DeadlineExceeded].contains(&e.code());
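// A Cancelled status wrapping a closed-connection transport error typically indicates
// the server went away (e.g. an HTTP/2 GOAWAY); allow exactly one extra retry for it.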
let mut goaway_retry_allowed = false;
if !self.have_retried_goaway_cancel
&& e.code() == Code::Cancelled
&& let Some(e) = e
.source()
.and_then(|e| e.downcast_ref::<tonic::transport::Error>())
.and_then(|te| te.source())
.and_then(|tec| tec.downcast_ref::<hyper::Error>())
&& format!("{e:?}").contains("connection closed")
{
goaway_retry_allowed = true;
self.have_retried_goaway_cancel = true;
}
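// Retryable code, or one of the special long-poll / dropped-connection cases: back off
// and retry.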
if RETRYABLE_ERROR_CODES.contains(&e.code()) || long_poll_allowed || goaway_retry_allowed {
if current_attempt == 1 {
debug!(error=?e, "gRPC call {} failed on first attempt", self.call_name);
} else {
self.maybe_log_retry(current_attempt, &e);
}
match self.backoff.next_backoff() {
None => RetryPolicy::ForwardError(e),
Some(backoff) => {
if e.code() == Code::ResourceExhausted {
let extended_backoff =
backoff.max(self.throttle_backoff.next_backoff().unwrap_or_default());
RetryPolicy::WaitRetry(extended_backoff)
} else {
RetryPolicy::WaitRetry(backoff)
}
}
}
} else if self.call_type == CallType::TaskLongPoll
&& self.backoff.get_elapsed_time() <= LONG_POLL_FATAL_GRACE
{
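// Within the grace period, task long-polls retry even normally-fatal errors, waiting
// the maximum interval.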
RetryPolicy::WaitRetry(self.backoff.max_interval)
} else {
RetryPolicy::ForwardError(e)
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use assert_matches::assert_matches;
use backoff::Clock;
use std::{ops::Add, time::Instant};
use temporalio_common::protos::temporal::api::workflowservice::v1::{
PollActivityTaskQueueRequest, PollNexusTaskQueueRequest, PollWorkflowTaskQueueRequest,
};
use tonic::{IntoRequest, Status};
const TEST_RETRY_CONFIG: RetryOptions = RetryOptions {
initial_interval: Duration::from_millis(1),
randomization_factor: 0.0,
multiplier: 1.1,
max_interval: Duration::from_millis(2),
max_elapsed_time: None,
max_retries: 10,
};
const POLL_WORKFLOW_METH_NAME: &str = "poll_workflow_task_queue";
const POLL_ACTIVITY_METH_NAME: &str = "poll_activity_task_queue";
const POLL_NEXUS_METH_NAME: &str = "poll_nexus_task_queue";
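/// A test clock whose "now" is fixed and can be advanced manually.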
struct FixedClock(Instant);
impl Clock for FixedClock {
fn now(&self) -> Instant {
self.0
}
}
#[tokio::test]
async fn long_poll_non_retryable_errors() {
for code in [
Code::InvalidArgument,
Code::NotFound,
Code::AlreadyExists,
Code::PermissionDenied,
Code::FailedPrecondition,
Code::Unauthenticated,
Code::Unimplemented,
] {
for call_name in [POLL_WORKFLOW_METH_NAME, POLL_ACTIVITY_METH_NAME] {
let mut err_handler = TonicErrorHandler::new_with_clock(
CallInfo {
call_type: CallType::TaskLongPoll,
call_name,
retry_cfg: TEST_RETRY_CONFIG,
retry_short_circuit: None,
},
TEST_RETRY_CONFIG,
FixedClock(Instant::now()),
FixedClock(Instant::now()),
);
let result = err_handler.handle(1, Status::new(code, "Ahh"));
assert_matches!(result, RetryPolicy::WaitRetry(_));
err_handler.backoff.clock.0 = err_handler
.backoff
.clock
.0
.add(LONG_POLL_FATAL_GRACE + Duration::from_secs(1));
let result = err_handler.handle(2, Status::new(code, "Ahh"));
assert_matches!(result, RetryPolicy::ForwardError(_));
}
}
}
#[tokio::test]
async fn long_poll_retryable_errors_never_fatal() {
for code in RETRYABLE_ERROR_CODES {
for call_name in [POLL_WORKFLOW_METH_NAME, POLL_ACTIVITY_METH_NAME] {
let mut err_handler = TonicErrorHandler::new_with_clock(
CallInfo {
call_type: CallType::TaskLongPoll,
call_name,
retry_cfg: TEST_RETRY_CONFIG,
retry_short_circuit: None,
},
TEST_RETRY_CONFIG,
FixedClock(Instant::now()),
FixedClock(Instant::now()),
);
let result = err_handler.handle(1, Status::new(code, "Ahh"));
assert_matches!(result, RetryPolicy::WaitRetry(_));
err_handler.backoff.clock.0 = err_handler
.backoff
.clock
.0
.add(LONG_POLL_FATAL_GRACE + Duration::from_secs(1));
let result = err_handler.handle(2, Status::new(code, "Ahh"));
assert_matches!(result, RetryPolicy::WaitRetry(_));
}
}
}
#[tokio::test]
async fn retry_resource_exhausted() {
let mut err_handler = TonicErrorHandler::new_with_clock(
CallInfo {
call_type: CallType::TaskLongPoll,
call_name: POLL_WORKFLOW_METH_NAME,
retry_cfg: TEST_RETRY_CONFIG,
retry_short_circuit: None,
},
RetryOptions {
initial_interval: Duration::from_millis(2),
randomization_factor: 0.0,
multiplier: 4.0,
max_interval: Duration::from_millis(10),
max_elapsed_time: None,
max_retries: 10,
},
FixedClock(Instant::now()),
FixedClock(Instant::now()),
);
let result = err_handler.handle(1, Status::new(Code::ResourceExhausted, "leave me alone"));
match result {
RetryPolicy::WaitRetry(duration) => assert_eq!(duration, Duration::from_millis(2)),
_ => panic!(),
}
err_handler.backoff.clock.0 = err_handler.backoff.clock.0.add(Duration::from_millis(10));
err_handler.throttle_backoff.clock.0 = err_handler
.throttle_backoff
.clock
.0
.add(Duration::from_millis(10));
let result = err_handler.handle(2, Status::new(Code::ResourceExhausted, "leave me alone"));
match result {
RetryPolicy::WaitRetry(duration) => assert_eq!(duration, Duration::from_millis(8)),
_ => panic!(),
}
}
#[tokio::test]
async fn retry_short_circuit() {
let mut err_handler = TonicErrorHandler::new_with_clock(
CallInfo {
call_type: CallType::TaskLongPoll,
call_name: POLL_WORKFLOW_METH_NAME,
retry_cfg: TEST_RETRY_CONFIG,
retry_short_circuit: Some(NoRetryOnMatching {
predicate: |s: &Status| s.code() == Code::ResourceExhausted,
}),
},
TEST_RETRY_CONFIG,
FixedClock(Instant::now()),
FixedClock(Instant::now()),
);
let result = err_handler.handle(1, Status::new(Code::ResourceExhausted, "leave me alone"));
let e = assert_matches!(result, RetryPolicy::ForwardError(e) => e);
assert!(
e.metadata()
.get(ERROR_RETURNED_DUE_TO_SHORT_CIRCUIT)
.is_some()
);
}
#[tokio::test]
async fn message_too_large_not_retried() {
let mut err_handler = TonicErrorHandler::new_with_clock(
CallInfo {
call_type: CallType::TaskLongPoll,
call_name: POLL_WORKFLOW_METH_NAME,
retry_cfg: TEST_RETRY_CONFIG,
retry_short_circuit: None,
},
TEST_RETRY_CONFIG,
FixedClock(Instant::now()),
FixedClock(Instant::now()),
);
let result = err_handler.handle(
1,
Status::new(
Code::ResourceExhausted,
"grpc: received message larger than max",
),
);
assert_matches!(result, RetryPolicy::ForwardError(_));
let result = err_handler.handle(
1,
Status::new(
Code::ResourceExhausted,
"grpc: message after decompression larger than max",
),
);
assert_matches!(result, RetryPolicy::ForwardError(_));
let result = err_handler.handle(
1,
Status::new(
Code::ResourceExhausted,
"grpc: received message after decompression larger than max",
),
);
assert_matches!(result, RetryPolicy::ForwardError(_));
}
#[rstest::rstest]
#[tokio::test]
async fn task_poll_retries_forever<R>(
#[values(
(
POLL_WORKFLOW_METH_NAME,
PollWorkflowTaskQueueRequest::default(),
),
(
POLL_ACTIVITY_METH_NAME,
PollActivityTaskQueueRequest::default(),
),
(
POLL_NEXUS_METH_NAME,
PollNexusTaskQueueRequest::default(),
),
)]
(call_name, req): (&'static str, R),
) {
let mut req = req.into_request();
req.extensions_mut().insert(IsWorkerTaskLongPoll);
for i in 1..=50 {
let mut err_handler = TonicErrorHandler::new(
TEST_RETRY_CONFIG.get_call_info::<R>(call_name, Some(&req)),
RetryOptions::throttle_retry_policy(),
);
let result = err_handler.handle(i, Status::new(Code::Unknown, "Ahh"));
assert_matches!(result, RetryPolicy::WaitRetry(_));
}
}
#[rstest::rstest]
#[tokio::test]
async fn task_poll_retries_deadline_exceeded<R>(
#[values(
(
POLL_WORKFLOW_METH_NAME,
PollWorkflowTaskQueueRequest::default(),
),
(
POLL_ACTIVITY_METH_NAME,
PollActivityTaskQueueRequest::default(),
),
(
POLL_NEXUS_METH_NAME,
PollNexusTaskQueueRequest::default(),
),
)]
(call_name, req): (&'static str, R),
) {
let mut req = req.into_request();
req.extensions_mut().insert(IsWorkerTaskLongPoll);
for code in [Code::Cancelled, Code::DeadlineExceeded] {
let mut err_handler = TonicErrorHandler::new(
TEST_RETRY_CONFIG.get_call_info::<R>(call_name, Some(&req)),
RetryOptions::throttle_retry_policy(),
);
for i in 1..=5 {
let result = err_handler.handle(i, Status::new(code, "retryable failure"));
assert_matches!(result, RetryPolicy::WaitRetry(_));
}
}
}
}