use google_cloud_gax::error::Error as GaxError;
use http::StatusCode;
use std::error::Error;
pub use google_cloud_gax::error::CredentialsError;
/// An error type that can report whether it is transient.
///
/// Implementors must be standard errors that are `Send + Sync + 'static`.
pub trait SubjectTokenProviderError: Error + Send + Sync + 'static {
/// Returns `true` when the error is transient.
///
/// NOTE(review): presumably a transient error is one the caller may retry;
/// confirm against the code consuming this trait.
fn is_transient(&self) -> bool;
}
impl SubjectTokenProviderError for CredentialsError {
fn is_transient(&self) -> bool {
self.is_transient()
}
}
/// Wraps a [`GaxError`] in a [`CredentialsError`] carrying `msg`.
///
/// The transient flag of the result is whatever
/// [`is_gax_error_retryable`] reports for `err`; `err` becomes the source of
/// the returned error.
pub(crate) fn from_gax_error(err: GaxError, msg: &str) -> CredentialsError {
    // The borrow taken for classification ends before `err` is moved in.
    CredentialsError::new(is_gax_error_retryable(&err), msg, err)
}
pub(crate) fn from_http_error(err: reqwest::Error, msg: &str) -> CredentialsError {
let transient = self::is_retryable(&err);
CredentialsError::new(transient, msg, err)
}
pub(crate) async fn from_http_response(response: reqwest::Response, msg: &str) -> CredentialsError {
let err = response
.error_for_status_ref()
.expect_err("this function is only called on errors");
let body = response.text().await;
let transient = crate::errors::is_retryable(&err);
match body {
Err(e) => CredentialsError::new(transient, msg, e),
Ok(b) => CredentialsError::new(transient, format!("{msg}, body=<{b}>"), err),
}
}
/// Builds a [`CredentialsError`] from `source` with the transient flag set to
/// `false`.
pub(crate) fn non_retryable<T>(source: T) -> CredentialsError
where
    T: Error + Send + Sync + 'static,
{
    let transient = false;
    CredentialsError::from_source(transient, source)
}
/// Builds a [`CredentialsError`] from a plain message with the transient flag
/// set to `false`.
pub(crate) fn non_retryable_from_str<T>(message: T) -> CredentialsError
where
    T: Into<String>,
{
    let transient = false;
    CredentialsError::from_msg(transient, message)
}
pub(crate) fn is_gax_error_retryable(err: &GaxError) -> bool {
if err
.http_status_code()
.and_then(|c| StatusCode::from_u16(c).ok())
.is_some_and(is_retryable_code)
{
return true;
}
let Some(s) = err.source() else { return false };
if let Some(cred_err) = s.downcast_ref::<CredentialsError>() {
return cred_err.is_transient();
}
if let Some(req_err) = s.downcast_ref::<reqwest::Error>() {
return is_retryable(req_err);
}
false
}
fn is_retryable(err: &reqwest::Error) -> bool {
if err.is_connect() {
return true;
}
match err.status() {
Some(code) => is_retryable_code(code),
None => false,
}
}
/// Returns `true` for the HTTP status codes this module treats as retryable:
/// 500 Internal Server Error, 503 Service Unavailable, 408 Request Timeout
/// and 429 Too Many Requests. Every other status is not retryable.
fn is_retryable_code(code: StatusCode) -> bool {
    matches!(
        code,
        StatusCode::INTERNAL_SERVER_ERROR
            | StatusCode::SERVICE_UNAVAILABLE
            | StatusCode::REQUEST_TIMEOUT
            | StatusCode::TOO_MANY_REQUESTS
    )
}
// Unit tests for the error classification helpers in this module.
#[cfg(test)]
mod tests {
use super::*;
use bytes::Bytes;
use http::HeaderMap;
use std::num::ParseIntError;
use test_case::test_case;
// Each of these status codes must be classified as retryable.
#[test_case(StatusCode::INTERNAL_SERVER_ERROR)]
#[test_case(StatusCode::SERVICE_UNAVAILABLE)]
#[test_case(StatusCode::REQUEST_TIMEOUT)]
#[test_case(StatusCode::TOO_MANY_REQUESTS)]
fn retryable(c: StatusCode) {
assert!(is_retryable_code(c));
}
// Each of these status codes must be classified as not retryable.
#[test_case(StatusCode::NOT_FOUND)]
#[test_case(StatusCode::UNAUTHORIZED)]
#[test_case(StatusCode::BAD_REQUEST)]
#[test_case(StatusCode::BAD_GATEWAY)]
#[test_case(StatusCode::PRECONDITION_FAILED)]
fn non_retryable(c: StatusCode) {
assert!(!is_retryable_code(c));
}
// Exercises the two `non_retryable*` constructors.
#[test]
fn helpers() {
// A message-only error is not transient and its display output contains the message.
let e = super::non_retryable_from_str("test-only-err-123");
assert!(!e.is_transient(), "{e}");
let got = format!("{e}");
assert!(got.contains("test-only-err-123"), "{got}");
// A wrapped error is not transient and stays reachable through `source()`.
let input = "NaN".parse::<u32>().unwrap_err();
let e = super::non_retryable(input.clone());
assert!(!e.is_transient(), "{e:?}");
let source = e.source().and_then(|e| e.downcast_ref::<ParseIntError>());
assert!(matches!(source, Some(ParseIntError { .. })), "{e:?}");
}
// Checks the transient flag that `from_gax_error` derives for several kinds of `GaxError`.
#[test_case(GaxError::http(503, HeaderMap::new(), Bytes::from("test")), true ; "retryable http status")]
#[test_case(GaxError::http(404, HeaderMap::new(), Bytes::from("test")), false ; "non-retryable http status")]
#[test_case(GaxError::authentication(CredentialsError::new(true, "msg", "NaN".parse::<u32>().unwrap_err())), true ; "transient credentials error")]
#[test_case(GaxError::authentication(CredentialsError::new(false, "msg", "NaN".parse::<u32>().unwrap_err())), false ; "permanent credentials error")]
#[test_case(GaxError::io("some io error"), false ; "io error fallback")]
#[test_case(GaxError::timeout("timeout"), false ; "timeout fallback")]
fn test_is_gax_error_retryable(gax_err: GaxError, expected: bool) {
// The classification is observed through the resulting `CredentialsError`.
let cred_err = from_gax_error(gax_err, "some msg");
assert_eq!(cred_err.is_transient(), expected);
}
}