use reqwest::Response;
use reqwest_middleware::Error;
use reqwest_retry::{DefaultRetryableStrategy, Retryable, RetryableStrategy};
use crate::endpoint::{current_endpoint, EndpointKind};
/// Retry strategy that takes the kind of the current endpoint into account.
///
/// Its [`RetryableStrategy`] implementation never asks for a retry while a
/// `Recovery` or `Auth` endpoint is current, and otherwise defers to
/// [`DefaultRetryableStrategy`].
///
/// The type is a stateless unit struct, so it derives the common traits to
/// stay usable inside `#[derive(Debug)]` holders and generic builders.
#[derive(Debug, Clone, Copy, Default)]
pub struct KindAwareRetryStrategy;
impl RetryableStrategy for KindAwareRetryStrategy {
    /// Decides whether the outcome of a request should be retried.
    ///
    /// * With no current endpoint, or with a `Read`, `Write` or `Upload`
    ///   endpoint, the decision is delegated to [`DefaultRetryableStrategy`].
    /// * With a `Recovery` or `Auth` endpoint the result is always `None`
    ///   (no retry). A warning is logged when the outcome is an error, a 5xx
    ///   status or `429 Too Many Requests`, so the skipped retry is visible.
    fn handle(&self, res: &Result<Response, Error>) -> Option<Retryable> {
        // Nothing to special-case when no endpoint is in scope.
        let Some(endpoint) = current_endpoint() else {
            return DefaultRetryableStrategy.handle(res);
        };

        // Deliberately exhaustive: a new `EndpointKind` variant must be
        // classified here before the crate compiles again.
        match endpoint.kind {
            EndpointKind::Read | EndpointKind::Write | EndpointKind::Upload => {
                DefaultRetryableStrategy.handle(res)
            }
            EndpointKind::Recovery | EndpointKind::Auth => {
                match res {
                    Err(err) => {
                        tracing::warn!(
                            steam.endpoint.kind = %endpoint.kind,
                            steam.endpoint.path = %endpoint.path,
                            error = %err,
                            "skipping auto-retry for security-sensitive endpoint",
                        );
                    }
                    Ok(response) => {
                        let status = response.status();
                        let would_look_retryable = status
                            == reqwest::StatusCode::TOO_MANY_REQUESTS
                            || status.is_server_error();
                        if would_look_retryable {
                            tracing::warn!(
                                steam.endpoint.kind = %endpoint.kind,
                                steam.endpoint.path = %endpoint.path,
                                status = %status,
                                "skipping auto-retry for security-sensitive endpoint",
                            );
                        }
                    }
                }
                // Never retry these endpoints, whatever the outcome was.
                None
            }
        }
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for `KindAwareRetryStrategy::handle`, covering the
    // never-retry kinds, a kind that defers to the default strategy, and the
    // case where no endpoint is in scope.
    use super::*;
    use crate::endpoint::{EndpointInfo, Host, HttpMethod, CURRENT_ENDPOINT};
    use std::sync::OnceLock;

    /// Builds an empty-bodied response carrying the given status code.
    fn response_with_status(code: u16) -> Response {
        http::Response::builder()
            .status(code)
            .body("")
            .expect("valid http response")
            .into()
    }

    /// Builds a placeholder endpoint description of the requested kind.
    fn endpoint_of_kind(kind: EndpointKind) -> EndpointInfo {
        EndpointInfo {
            kind,
            name: "test",
            module: "test",
            path: "/test",
            host: Host::Help,
            method: HttpMethod::Get,
        }
    }

    /// Runs the strategy on an `Ok(response)` while `info` is installed as the
    /// current endpoint, and returns the strategy's decision.
    async fn decide_within(info: &'static EndpointInfo, response: Response) -> Option<Retryable> {
        CURRENT_ENDPOINT
            .scope(info, async move {
                KindAwareRetryStrategy.handle(&Ok(response))
            })
            .await
    }

    #[tokio::test]
    async fn recovery_endpoint_skips_retry_on_5xx() {
        // A `static` cell yields the `'static` reference that `scope` is given.
        static EP: OnceLock<EndpointInfo> = OnceLock::new();
        let info = EP.get_or_init(|| endpoint_of_kind(EndpointKind::Recovery));

        let decision = decide_within(info, response_with_status(500)).await;

        assert!(decision.is_none(), "Recovery must not retry on 5xx");
    }

    #[tokio::test]
    async fn auth_endpoint_skips_retry_on_5xx() {
        static EP: OnceLock<EndpointInfo> = OnceLock::new();
        let info = EP.get_or_init(|| endpoint_of_kind(EndpointKind::Auth));

        let decision = decide_within(info, response_with_status(500)).await;

        assert!(decision.is_none(), "Auth must not retry on 5xx");
    }

    #[tokio::test]
    async fn read_endpoint_uses_default_retry_on_5xx() {
        static EP: OnceLock<EndpointInfo> = OnceLock::new();
        let info = EP.get_or_init(|| endpoint_of_kind(EndpointKind::Read));

        let decision = decide_within(info, response_with_status(500)).await;

        assert!(matches!(decision, Some(Retryable::Transient)), "Read on 5xx should retry");
    }

    #[tokio::test]
    async fn read_endpoint_no_retry_on_200() {
        static EP: OnceLock<EndpointInfo> = OnceLock::new();
        let info = EP.get_or_init(|| endpoint_of_kind(EndpointKind::Read));

        let decision = decide_within(info, response_with_status(200)).await;

        assert!(decision.is_none(), "200 should not retry");
    }

    #[test]
    fn no_endpoint_falls_back_to_default() {
        // Called outside any `CURRENT_ENDPOINT.scope`, so no endpoint is set.
        let outcome = Ok(response_with_status(500));
        let decision = KindAwareRetryStrategy.handle(&outcome);
        assert!(matches!(decision, Some(Retryable::Transient)));
    }
}