use derive_more::Display;
use serde::{Deserialize, Serialize};
/// Error payload describing a failure reported by the server.
///
/// Its `Display` output has the form
/// `Server returned an error <error_code> (<message>)`.
#[derive(Debug, Clone, PartialEq)]
#[derive(Serialize, Deserialize)]
#[derive(Display)]
#[display("Server returned an error {error_code} ({message})")]
pub struct ApiError {
    /// Error code string; rendered first in the `Display` output.
    pub error_code: String,
    /// Free-form description; rendered in parentheses in the `Display` output.
    pub message: String,
}
pub fn is_retryable(no_retry_codes: &[u16], error: &anyhow::Error) -> bool {
for cause in error.chain() {
if let Some(reqwest_error) = cause.downcast_ref::<reqwest::Error>() {
if let Some(status) = reqwest_error.status() {
if no_retry_codes.contains(&status.as_u16()) {
return false;
}
}
}
}
true
}
#[cfg(test)]
mod test {
    use super::*;
    use rstest::rstest;

    /// Status codes treated as permanent (non-retryable) failures in these tests.
    const NO_RETRY_CODES: &[u16] = &[400, 401, 403, 404, 409, 422];

    /// Builds an `anyhow::Error` wrapping the `reqwest::Error` that
    /// `error_for_status` yields for an empty response with the given `status`.
    ///
    /// # Panics
    /// Panics if `status` is not a valid HTTP status code. It also panics if the
    /// response does not produce an error from `error_for_status`; use 4xx/5xx codes.
    fn make_status_error(status: u16) -> anyhow::Error {
        let response = http::Response::builder()
            .status(status)
            .body(bytes::Bytes::new())
            .expect("BUG: Cannot build http response");
        let response = reqwest::Response::from(response);
        let error = response
            .error_for_status()
            .expect_err("BUG: Expected error for non-2xx status");
        anyhow::Error::new(error)
    }

    /// Every code in the no-retry list blocks retries. Codes outside it do not,
    /// and an empty list never blocks.
    #[rstest]
    #[case(400, NO_RETRY_CODES, false)]
    #[case(401, NO_RETRY_CODES, false)]
    #[case(403, NO_RETRY_CODES, false)]
    #[case(404, NO_RETRY_CODES, false)]
    #[case(409, NO_RETRY_CODES, false)]
    #[case(422, NO_RETRY_CODES, false)]
    #[case(429, NO_RETRY_CODES, true)]
    #[case(500, NO_RETRY_CODES, true)]
    #[case(503, NO_RETRY_CODES, true)]
    #[case(404, &[], true)]
    #[case(500, &[], true)]
    fn test_is_retryable(
        #[case] status: u16,
        #[case] no_retry_codes: &[u16],
        #[case] expected: bool,
    ) {
        let error = make_status_error(status);
        assert_eq!(is_retryable(no_retry_codes, &error), expected);
    }

    /// Errors that contain no `reqwest::Error` are always retryable.
    #[rstest]
    fn test_is_retryable_connection_error() {
        let error = anyhow::anyhow!("connection refused");
        assert!(is_retryable(NO_RETRY_CODES, &error));
    }

    /// A blocking status is still found when wrapped in one layer of context.
    #[rstest]
    fn test_is_retryable_nested_error() {
        let error = make_status_error(404);
        let wrapped = error.context("some outer context");
        assert!(!is_retryable(&[404], &wrapped));
    }

    /// A blocking status is still found when wrapped in several layers of context.
    #[rstest]
    fn test_is_retryable_deeply_nested_error() {
        let wrapped = make_status_error(404)
            .context("inner context")
            .context("outer context");
        assert!(!is_retryable(NO_RETRY_CODES, &wrapped));
    }

    /// Wrapping a retryable status in context must not make it non-retryable.
    #[rstest]
    fn test_is_retryable_nested_retryable_error() {
        let wrapped = make_status_error(503).context("some outer context");
        assert!(is_retryable(NO_RETRY_CODES, &wrapped));
    }
}