use std::io;
use http::StatusCode;
use crate::{ErrorKind, Status};
/// Crate-internal extension trait that classifies a value (a status code,
/// an error, or a check result) as worth retrying or not.
pub(crate) trait RetryExt {
/// Returns `true` when the operation that produced `self` should be
/// attempted again, and `false` when a retry is not warranted.
fn should_retry(&self) -> bool;
}
impl RetryExt for reqwest::StatusCode {
    /// A status code is retryable when it is `408 Request Timeout`,
    /// `429 Too Many Requests`, or any server error as reported by
    /// `is_server_error()`.
    fn should_retry(&self) -> bool {
        matches!(
            *self,
            StatusCode::REQUEST_TIMEOUT | StatusCode::TOO_MANY_REQUESTS
        ) || self.is_server_error()
    }
}
impl RetryExt for reqwest::Error {
#[allow(clippy::if_same_then_else)]
fn should_retry(&self) -> bool {
if self.is_timeout() {
true
} else if self.is_connect() {
false
} else if self.is_body() || self.is_decode() || self.is_builder() || self.is_redirect() {
false
} else if self.is_request() {
if let Some(hyper_error) = get_source_error_type::<hyper::Error>(&self) {
if hyper_error.is_incomplete_message() || hyper_error.is_canceled() {
true
} else if let Some(io_error) = get_source_error_type::<io::Error>(hyper_error) {
should_retry_io(io_error)
} else {
false
}
} else {
false
}
} else if let Some(status) = self.status() {
status.should_retry()
} else {
false
}
}
}
impl RetryExt for http::Error {
fn should_retry(&self) -> bool {
let inner = self.get_ref();
inner
.source()
.and_then(<dyn std::error::Error + 'static>::downcast_ref)
.is_some_and(should_retry_io)
}
}
impl RetryExt for ErrorKind {
fn should_retry(&self) -> bool {
if let Some(r) = self.reqwest_error() {
r.should_retry()
} else if let Some(octocrab::Error::Http {
source,
backtrace: _,
}) = self.github_error()
{
source.should_retry()
} else {
matches!(
self,
Self::RejectedStatusCode(StatusCode::TOO_MANY_REQUESTS)
)
}
}
}
impl RetryExt for Status {
    /// Timeouts are always retried and errors defer to the wrapped error's
    /// own verdict; every other status is final.
    fn should_retry(&self) -> bool {
        // Kept exhaustive (no `_` arm) so that adding a new `Status` variant
        // forces an explicit retry decision here.
        match self {
            Self::Ok(_)
            | Self::RequestError(_)
            | Self::Redirected(_, _)
            | Self::UnknownStatusCode(_)
            | Self::UnknownMailStatus(_)
            | Self::Excluded
            | Self::Unsupported(_)
            | Self::Cached(_) => false,
            Self::Timeout(_) => true,
            Self::Error(inner) => inner.should_retry(),
        }
    }
}
/// Returns `true` when the I/O error's kind is one of the retryable kinds:
/// connection reset, connection aborted, or timed out.
fn should_retry_io(error: &io::Error) -> bool {
    const RETRYABLE_KINDS: [io::ErrorKind; 3] = [
        io::ErrorKind::ConnectionReset,
        io::ErrorKind::ConnectionAborted,
        io::ErrorKind::TimedOut,
    ];
    RETRYABLE_KINDS.contains(&error.kind())
}
/// Walks the `source()` chain of `err` and returns the first cause that
/// downcasts to `T`.
///
/// `err` itself is never tested, only its (transitive) sources. Returns
/// `None` when no cause in the chain is a `T`.
fn get_source_error_type<T: std::error::Error + 'static>(
    err: &dyn std::error::Error,
) -> Option<&T> {
    // The explicit deref keeps the returned borrow tied to the error chain
    // rather than to the closure's temporary reference.
    std::iter::successors(err.source(), |cause| (*cause).source())
        .find_map(|cause| cause.downcast_ref::<T>())
}
#[cfg(test)]
mod tests {
    use super::RetryExt;
    use http::StatusCode;

    /// Checks the retry verdict for a representative set of status codes.
    #[test]
    fn test_should_retry() {
        let cases = [
            (StatusCode::REQUEST_TIMEOUT, true),
            (StatusCode::TOO_MANY_REQUESTS, true),
            (StatusCode::FORBIDDEN, false),
            (StatusCode::INTERNAL_SERVER_ERROR, true),
        ];

        for (status, expected) in cases {
            assert_eq!(
                status.should_retry(),
                expected,
                "unexpected retry verdict for {status}"
            );
        }
    }
}