use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use bytes::Bytes;
use reqwest::Method;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
use serde::Serialize;
use serde::de::DeserializeOwned;
/// Errors produced by the outbound HTTP [`Client`] and its [`RequestBuilder`].
#[derive(Debug, thiserror::Error)]
pub enum ClientError {
    /// A transport-level failure reported by `reqwest` (connecting, timing out,
    /// or reading the response body).
    #[error("outbound HTTP request failed: {0}")]
    Request(#[from] reqwest::Error),
    /// Serializing a request body (`RequestBuilder::json`) or deserializing a
    /// response body (`Response::json`) failed.
    #[error("JSON error: {0}")]
    Json(#[from] serde_json::Error),
    /// A mock registry is attached but no registered mock matched the request.
    /// The fields are the HTTP method and the full request URL.
    #[error("no mock registered for {0} {1}")]
    NoMock(String, String),
    /// The circuit breaker for the request's target host rejected the call.
    #[error("outbound circuit breaker is open")]
    CircuitBreakerOpen,
}
/// A fully buffered HTTP response returned by `RequestBuilder::send`.
///
/// The body has already been read into memory, so none of the accessors
/// perform I/O.
pub struct Response {
    /// Status code returned by the server, or by a matched mock.
    status: reqwest::StatusCode,
    /// Response headers; always empty for mocked responses.
    headers: HeaderMap,
    /// The complete response body.
    body: Bytes,
    /// URL reported by `reqwest` for the response; `None` for mocked responses.
    url: Option<reqwest::Url>,
}
impl Response {
    /// Returns the HTTP status code.
    pub const fn status(&self) -> reqwest::StatusCode {
        self.status
    }

    /// Returns the response headers (empty for mocked responses).
    pub const fn headers(&self) -> &HeaderMap {
        &self.headers
    }

    /// Returns `true` when the status code is in the 2xx range.
    pub fn is_success(&self) -> bool {
        self.status().is_success()
    }

    /// Returns the URL reported for this response, if any (mocks have none).
    pub const fn url(&self) -> Option<&reqwest::Url> {
        self.url.as_ref()
    }

    /// Deserializes the body as JSON into `T`.
    ///
    /// # Errors
    ///
    /// Returns [`ClientError::Json`] if the body is not valid JSON for `T`.
    pub fn json<T: DeserializeOwned>(self) -> Result<T, ClientError> {
        let decoded = serde_json::from_slice::<T>(&self.body)?;
        Ok(decoded)
    }

    /// Returns the body as text, replacing invalid UTF-8 sequences with U+FFFD.
    pub fn text(self) -> String {
        let lossy = String::from_utf8_lossy(&self.body);
        lossy.into_owned()
    }

    /// Consumes the response and returns the raw body bytes.
    pub fn bytes(self) -> Bytes {
        let Self { body, .. } = self;
        body
    }
}
/// Controls how many times, and for which requests, the client retries.
#[derive(Clone, Debug)]
pub struct RetryPolicy {
    /// Retries after the first attempt; a request makes at most
    /// `max_retries + 1` attempts.
    pub max_retries: u32,
    /// When `true`, only idempotent methods (GET, HEAD, PUT, DELETE, OPTIONS,
    /// TRACE) are retried.
    pub retry_idempotent_only: bool,
    /// Upper bound on how long a `429` response's `Retry-After` is honoured.
    pub max_retry_after: Duration,
    /// Request timeout associated with this policy; when set it also caps the
    /// wait before retrying a `429`.
    pub request_timeout: Option<Duration>,
}

impl Default for RetryPolicy {
    /// Three retries for idempotent methods only, `Retry-After` capped at 10 s,
    /// and a 30 s request timeout.
    fn default() -> Self {
        const DEFAULT_MAX_RETRIES: u32 = 3;
        let max_retry_after = Duration::from_secs(10);
        let request_timeout = Some(Duration::from_secs(30));
        Self {
            max_retries: DEFAULT_MAX_RETRIES,
            retry_idempotent_only: true,
            max_retry_after,
            request_timeout,
        }
    }
}
/// A registered mock: the request it matches and the canned response it returns.
pub(crate) struct MockEntry {
    /// HTTP method to match; `None` matches any method.
    pub(crate) method: Option<Method>,
    /// Path to match, either exactly or as a suffix of the request's URL path
    /// (see `MockRegistry::find_match`).
    pub(crate) path: String,
    /// Client alias to match; `None` matches requests from any client.
    pub(crate) alias: Option<String>,
    /// Status code of the canned response.
    pub(crate) status: u16,
    /// JSON body of the canned response; `None` produces an empty body.
    pub(crate) body: Option<serde_json::Value>,
    /// Incremented each time this entry serves a request; shared with the
    /// `MockHandle` returned at registration.
    pub(crate) call_count: Arc<AtomicUsize>,
}
/// Canned response returned by `MockRegistry::find_match` for a matched request.
pub(crate) struct MockResponse {
    /// Status code to report.
    pub(crate) status: u16,
    /// JSON body to serialize; `None` means an empty body.
    pub(crate) body: Option<serde_json::Value>,
}
/// Mutex-protected list of HTTP mocks.
///
/// Entries are checked in registration order and the first match wins.
pub struct MockRegistry {
    /// Registered mocks, in registration order.
    entries: Mutex<Vec<MockEntry>>,
}
impl MockRegistry {
#[must_use]
pub const fn new() -> Self {
Self {
entries: Mutex::new(Vec::new()),
}
}
pub(crate) fn register(&self, entry: MockEntry) {
self.entries
.lock()
.expect("mock registry lock poisoned")
.push(entry);
}
pub(crate) fn find_match(
&self,
method: &Method,
url: &str,
alias: Option<&str>,
) -> Option<MockResponse> {
let url_path_owned: String = reqwest::Url::parse(url).map_or_else(
|_| {
let s = url.split_once('?').map_or(url, |(p, _)| p);
s.split_once('#').map_or(s, |(p, _)| p).to_owned()
},
|parsed| parsed.path().to_owned(),
);
let url_path = url_path_owned.as_str();
let found = {
let entries = self.entries.lock().expect("mock registry lock poisoned");
entries.iter().find_map(|entry| {
let method_ok = entry.method.as_ref().is_none_or(|m| m == method);
let path_ok = url_path == entry.path.as_str()
|| url_path
.strip_suffix(entry.path.as_str())
.is_some_and(|prefix| {
prefix.is_empty()
|| prefix.ends_with('/')
|| entry.path.starts_with('/')
});
let alias_ok = entry
.alias
.as_deref()
.is_none_or(|a| alias.is_some_and(|b| a == b));
if method_ok && path_ok && alias_ok {
Some((entry.call_count.clone(), entry.status, entry.body.clone()))
} else {
None
}
})
};
found.map(|(call_count, status, body)| {
call_count.fetch_add(1, Ordering::SeqCst);
MockResponse { status, body }
})
}
}
impl Default for MockRegistry {
    /// Same as [`MockRegistry::new`]: an empty registry.
    fn default() -> Self {
        Self::new()
    }
}
/// Newtype for storing a shared [`MockRegistry`] as an `AppState` extension.
/// When present, `Client::from_state` attaches the registry to every extracted client.
pub struct HttpMockRegistryExt(pub Arc<MockRegistry>);
/// Handle to a registered mock, used to assert how often it was hit.
pub struct MockHandle {
    /// Alias the mock was registered for (used in assertion messages).
    alias: String,
    /// Method label: the method name, or `*` when any method matches.
    method: String,
    /// Path the mock was registered with.
    path: String,
    /// Counter shared with the registry entry.
    call_count: Arc<AtomicUsize>,
}

impl MockHandle {
    /// Asserts that the mock has served exactly `expected` requests.
    ///
    /// # Panics
    ///
    /// Panics with a message naming the alias, method and path when the
    /// observed count differs from `expected`.
    pub fn expect_called(&self, expected: usize) {
        let Self {
            alias,
            method,
            path,
            ..
        } = self;
        let actual = self.call_count();
        assert_eq!(
            actual, expected,
            "http mock for {} {} {} expected {} call(s) but got {}",
            alias, method, path, expected, actual,
        );
    }

    /// Number of requests this mock has served so far.
    #[must_use]
    pub fn call_count(&self) -> usize {
        let counter: &AtomicUsize = &self.call_count;
        counter.load(Ordering::SeqCst)
    }
}
/// Fluent builder that registers one mock for an aliased service.
///
/// Pick a method and path with `get`/`post`/`put`/`patch`/`delete`, then call
/// `respond_with` or `respond_with_status`. Those register the mock and return a
/// [`MockHandle`].
pub struct MockSetupBuilder {
    /// Registry the finished mock is added to.
    pub(crate) registry: Arc<MockRegistry>,
    /// Client alias the mock is restricted to.
    pub(crate) alias: String,
    /// Method to match; left `None`, the mock matches any method.
    pub(crate) method: Option<Method>,
    /// Path to match; left `None`, it is registered as an empty path.
    pub(crate) path: Option<String>,
}
impl MockSetupBuilder {
#[must_use]
pub fn get(mut self, path: &str) -> Self {
self.method = Some(Method::GET);
self.path = Some(path.to_owned());
self
}
#[must_use]
pub fn post(mut self, path: &str) -> Self {
self.method = Some(Method::POST);
self.path = Some(path.to_owned());
self
}
#[must_use]
pub fn put(mut self, path: &str) -> Self {
self.method = Some(Method::PUT);
self.path = Some(path.to_owned());
self
}
#[must_use]
pub fn patch(mut self, path: &str) -> Self {
self.method = Some(Method::PATCH);
self.path = Some(path.to_owned());
self
}
#[must_use]
pub fn delete(mut self, path: &str) -> Self {
self.method = Some(Method::DELETE);
self.path = Some(path.to_owned());
self
}
#[must_use]
pub fn respond_with(self, status: u16, body: serde_json::Value) -> MockHandle {
let path = self.path.clone().unwrap_or_default();
let method_str = self
.method
.as_ref()
.map_or_else(|| "*".to_owned(), ToString::to_string);
let call_count = Arc::new(AtomicUsize::new(0));
self.registry.register(MockEntry {
method: self.method,
path: path.clone(),
alias: Some(self.alias.clone()),
status,
body: Some(body),
call_count: call_count.clone(),
});
MockHandle {
alias: self.alias,
method: method_str,
path,
call_count,
}
}
#[must_use]
pub fn respond_with_status(self, status: u16) -> MockHandle {
let path = self.path.clone().unwrap_or_default();
let method_str = self
.method
.as_ref()
.map_or_else(|| "*".to_owned(), ToString::to_string);
let call_count = Arc::new(AtomicUsize::new(0));
self.registry.register(MockEntry {
method: self.method,
path: path.clone(),
alias: Some(self.alias.clone()),
status,
body: None,
call_count: call_count.clone(),
});
MockHandle {
alias: self.alias,
method: method_str,
path,
call_count,
}
}
}
/// Outbound HTTP client with base-URL resolution, retries, per-host circuit
/// breaking, and optional request mocking.
#[derive(Clone)]
pub struct Client {
    /// Underlying `reqwest` client; cloned into every `RequestBuilder`.
    inner: reqwest::Client,
    /// Service alias set by `named`; mocks registered for an alias only match it.
    alias: Option<String>,
    /// Prefix joined onto relative request URLs.
    base_url: Option<String>,
    /// Per-alias base URLs, consulted by `named`.
    base_urls: HashMap<String, String>,
    /// Retry policy copied into each request.
    retry_policy: RetryPolicy,
    /// When set, requests are answered from this registry instead of the network.
    mock: Option<Arc<MockRegistry>>,
    /// Circuit-breaker configuration; `None` uses the default breaker policy.
    resilience_config: Option<Arc<crate::config::ResilienceConfig>>,
}
impl Client {
#[must_use]
pub fn new() -> Self {
Self::with_timeout(Duration::from_secs(30))
}
#[must_use]
pub fn with_timeout(timeout: Duration) -> Self {
let inner = reqwest::ClientBuilder::new()
.timeout(timeout)
.build()
.expect("failed to build reqwest client");
Self {
inner,
alias: None,
base_url: None,
base_urls: HashMap::new(),
retry_policy: RetryPolicy {
max_retries: 3,
retry_idempotent_only: true,
max_retry_after: Duration::from_secs(10),
request_timeout: Some(timeout),
},
mock: None,
resilience_config: None,
}
}
#[must_use]
pub fn from_config(config: &crate::config::HttpClientConfig) -> Self {
let timeout = Duration::from_secs(config.timeout_secs);
let inner = reqwest::ClientBuilder::new()
.timeout(timeout)
.build()
.expect("failed to build reqwest client");
Self {
inner,
alias: None,
base_url: None,
base_urls: config.base_urls.clone(),
retry_policy: RetryPolicy {
max_retries: config.max_retries,
retry_idempotent_only: true,
max_retry_after: Duration::from_secs(config.max_retry_after_secs),
request_timeout: Some(timeout),
},
mock: None,
resilience_config: None,
}
}
pub(crate) fn with_mock(mut self, registry: Arc<MockRegistry>) -> Self {
self.mock = Some(registry);
self
}
pub(crate) fn from_state(state: &crate::AppState) -> Self {
let config = state.extension::<crate::config::HttpConfig>().or_else(|| {
state
.extension::<crate::config::AutumnConfig>()
.map(|c| Arc::new(c.http.clone()))
});
let mut client = config.map_or_else(Self::new, |cfg| Self::from_config(&cfg.client));
let resilience = state
.extension::<crate::config::AutumnConfig>()
.map(|c| Arc::new(c.resilience.clone()));
client.resilience_config = resilience;
if let Some(ext) = state.extension::<HttpMockRegistryExt>() {
client = client.with_mock(ext.0.clone());
}
client
}
#[must_use]
pub fn named(&self, alias: &str) -> Self {
let base_url = self
.base_urls
.get(alias)
.cloned()
.or_else(|| self.base_url.clone());
Self {
inner: self.inner.clone(),
alias: Some(alias.to_owned()),
base_url,
base_urls: self.base_urls.clone(),
retry_policy: self.retry_policy.clone(),
mock: self.mock.clone(),
resilience_config: self.resilience_config.clone(),
}
}
#[must_use]
pub fn with_base_url(&self, base_url: impl Into<String>) -> Self {
Self {
inner: self.inner.clone(),
alias: self.alias.clone(),
base_url: Some(base_url.into()),
base_urls: self.base_urls.clone(),
retry_policy: self.retry_policy.clone(),
mock: self.mock.clone(),
resilience_config: self.resilience_config.clone(),
}
}
fn build_request(&self, method: Method, url: impl AsRef<str>) -> RequestBuilder {
let url_str = url.as_ref();
let full_url = if url_str.starts_with("http://") || url_str.starts_with("https://") {
url_str.to_owned()
} else if let Some(base) = &self.base_url {
format!(
"{}/{}",
base.trim_end_matches('/'),
url_str.trim_start_matches('/')
)
} else {
url_str.to_owned()
};
RequestBuilder {
client: self.inner.clone(),
method,
url: full_url,
extra_headers: HeaderMap::new(),
body: None,
retry_policy: self.retry_policy.clone(),
mock: self.mock.clone(),
alias: self.alias.clone(),
pending_error: None,
resilience_config: self.resilience_config.clone(),
}
}
#[must_use]
pub fn get(&self, url: impl AsRef<str>) -> RequestBuilder {
self.build_request(Method::GET, url)
}
#[must_use]
pub fn post(&self, url: impl AsRef<str>) -> RequestBuilder {
self.build_request(Method::POST, url)
}
#[must_use]
pub fn put(&self, url: impl AsRef<str>) -> RequestBuilder {
self.build_request(Method::PUT, url)
}
#[must_use]
pub fn patch(&self, url: impl AsRef<str>) -> RequestBuilder {
self.build_request(Method::PATCH, url)
}
#[must_use]
pub fn delete(&self, url: impl AsRef<str>) -> RequestBuilder {
self.build_request(Method::DELETE, url)
}
}
impl Default for Client {
    /// Same as [`Client::new`]: 30-second timeout, default retry policy, and no
    /// alias, base URL, or mock registry.
    fn default() -> Self {
        Self::new()
    }
}
/// Lets handlers take a [`Client`] argument configured from application state.
impl axum::extract::FromRequestParts<crate::AppState> for Client {
    type Rejection = std::convert::Infallible;

    /// Builds the client via `Client::from_state`; extraction never fails.
    async fn from_request_parts(
        _parts: &mut http::request::Parts,
        state: &crate::AppState,
    ) -> Result<Self, std::convert::Infallible> {
        let client = Self::from_state(state);
        Ok(client)
    }
}
/// A single outbound request being configured; finish it with `send`.
pub struct RequestBuilder {
    /// `reqwest` client used to perform the request.
    client: reqwest::Client,
    /// HTTP method.
    method: Method,
    /// Fully resolved request URL (base URL already applied).
    url: String,
    /// Headers added via `header`/`json`, sent on every attempt.
    extra_headers: HeaderMap,
    /// Request body, re-sent on every attempt.
    body: Option<Bytes>,
    /// Retry policy for this request (may be overridden per request).
    retry_policy: RetryPolicy,
    /// When set, the request is answered from this registry.
    mock: Option<Arc<MockRegistry>>,
    /// Alias of the originating client, used for mock matching.
    alias: Option<String>,
    /// Deferred error (e.g. JSON serialization failure) returned by `send`.
    pending_error: Option<ClientError>,
    /// Circuit-breaker configuration; `None` uses the default breaker policy.
    resilience_config: Option<Arc<crate::config::ResilienceConfig>>,
}
impl RequestBuilder {
#[must_use]
pub fn header(mut self, name: impl AsRef<str>, value: impl AsRef<str>) -> Self {
let name_str = name.as_ref();
let value_str = value.as_ref();
match (
HeaderName::from_bytes(name_str.as_bytes()),
HeaderValue::from_str(value_str),
) {
(Ok(n), Ok(v)) => {
self.extra_headers.insert(n, v);
}
(Err(e), _) => {
tracing::warn!(header.name = name_str, error = %e, "invalid header name — header skipped");
}
(_, Err(e)) => {
tracing::warn!(header.name = name_str, error = %e, "invalid header value — header skipped");
}
}
self
}
#[must_use]
pub fn json<T: Serialize>(mut self, body: &T) -> Self {
match serde_json::to_vec(body) {
Ok(bytes) => {
self.body = Some(Bytes::from(bytes));
self = self.header("content-type", "application/json");
}
Err(e) => {
self.pending_error = Some(ClientError::Json(e));
}
}
self
}
#[must_use]
pub fn text_body(mut self, body: impl Into<String>) -> Self {
self.body = Some(Bytes::from(body.into().into_bytes()));
self
}
#[must_use]
pub const fn retries(mut self, max: u32) -> Self {
self.retry_policy.max_retries = max;
self.retry_policy.retry_idempotent_only = false;
self
}
#[must_use]
pub const fn max_retry_after(mut self, max: Duration) -> Self {
self.retry_policy.max_retry_after = max;
self
}
#[must_use]
pub const fn no_retry(mut self) -> Self {
self.retry_policy.max_retries = 0;
self
}
pub async fn send(self) -> Result<Response, ClientError> {
if let Some(err) = self.pending_error {
return Err(err);
}
if self.mock.is_some() {
return self.send_inner(false).await;
}
let host = url::Url::parse(&self.url).ok().map_or_else(
|| "unknown".to_owned(),
|u| {
let h = u.host_str().unwrap_or("unknown");
u.port()
.map_or_else(|| h.to_owned(), |port| format!("{h}:{port}"))
},
);
let breaker = self.resilience_config.as_ref().map_or_else(
|| {
crate::circuit_breaker::global_registry().get_or_create(
&host,
crate::circuit_breaker::CircuitBreakerPolicy::default(),
)
},
|rc| {
let policy = crate::circuit_breaker::CircuitBreakerPolicy::from_config(rc, &host);
crate::circuit_breaker::global_registry().get_or_create_with_config(&host, policy)
},
);
if breaker.before_call().is_err() {
return Err(ClientError::CircuitBreakerOpen);
}
let guard = crate::circuit_breaker::CircuitBreakerGuard::new(breaker.clone());
let is_half_open = breaker.state() == crate::circuit_breaker::CircuitState::HalfOpen;
let res = self.send_inner(is_half_open).await;
match &res {
Ok(resp) => {
let success = resp.status().as_u16() < 500;
if success {
guard.success();
} else {
guard.failure();
}
}
Err(_) => {
guard.failure();
}
}
res
}
async fn send_inner(self, suppress_retries: bool) -> Result<Response, ClientError> {
if let Some(ref mock) = self.mock {
match mock.find_match(&self.method, &self.url, self.alias.as_deref()) {
Some(mock_resp) => {
let status = reqwest::StatusCode::from_u16(mock_resp.status)
.unwrap_or(reqwest::StatusCode::OK);
let body_bytes = mock_resp
.body
.as_ref()
.map(|v| serde_json::to_vec(v).unwrap_or_default())
.unwrap_or_default();
tracing::info!(
http.method = %self.method,
http.url = %self.url,
http.status = mock_resp.status,
"[mock] outbound request intercepted"
);
return Ok(Response {
status,
headers: HeaderMap::new(),
body: Bytes::from(body_bytes),
url: None,
});
}
None => {
return Err(ClientError::NoMock(
self.method.to_string(),
self.url.clone(),
));
}
}
}
let start = Instant::now();
let max_attempts = if suppress_retries {
1
} else if is_idempotent_method(&self.method) || !self.retry_policy.retry_idempotent_only {
self.retry_policy.max_retries.saturating_add(1)
} else {
1
};
for attempt in 0..max_attempts {
if attempt > 0 {
let exp = (attempt - 1).min(10);
let delay = Duration::from_millis(100 * (1_u64 << exp));
tokio::time::sleep(delay).await;
}
let mut req = self.client.request(self.method.clone(), &self.url);
req = inject_trace_context(req);
for (name, value) in &self.extra_headers {
req = req.header(name.clone(), value.clone());
}
if let Some(body) = &self.body {
req = req.body(body.clone());
}
match req.send().await {
Ok(resp) => {
let status = resp.status();
let headers = resp.headers().clone();
let url_used = resp.url().clone();
if status.as_u16() == 429 && attempt + 1 < max_attempts {
let mut sleep_delay =
parse_retry_after(&headers).unwrap_or(Duration::from_secs(1));
sleep_delay = sleep_delay.min(self.retry_policy.max_retry_after);
if let Some(req_timeout) = self.retry_policy.request_timeout {
sleep_delay = sleep_delay.min(req_timeout);
}
tokio::time::sleep(sleep_delay).await;
continue;
}
if is_retryable_status(status.as_u16()) && attempt + 1 < max_attempts {
continue;
}
let body = resp
.bytes()
.await
.map_err(|e| ClientError::Request(e.without_url()))?;
let elapsed = start.elapsed();
log_request(
self.method.as_str(),
&url_used,
status.as_u16(),
elapsed,
&self.extra_headers,
);
return Ok(Response {
status,
headers,
body,
url: Some(url_used),
});
}
Err(e) if (e.is_connect() || e.is_timeout()) && attempt + 1 < max_attempts => {}
Err(e) => return Err(ClientError::Request(e.without_url())),
}
}
unreachable!("retry loop exited without returning a result — this is a bug")
}
}
/// Returns `true` for methods that are idempotent per HTTP semantics.
///
/// These are the safe methods (GET, HEAD, OPTIONS, TRACE) plus PUT and DELETE.
/// POST and PATCH are excluded and therefore not retried by default.
const fn is_idempotent_method(method: &Method) -> bool {
    matches!(
        *method,
        Method::GET | Method::HEAD | Method::OPTIONS | Method::TRACE | Method::PUT | Method::DELETE
    )
}
/// Returns `true` for gateway-style server errors worth retrying:
/// 502 Bad Gateway, 503 Service Unavailable, and 504 Gateway Timeout.
///
/// 500 is deliberately excluded, and 429 is handled separately via `Retry-After`.
const fn is_retryable_status(status: u16) -> bool {
    matches!(status, 502 | 503 | 504)
}
/// Parses a `Retry-After` header into a delay.
///
/// Accepts either delay-seconds (`"120"`) or an RFC 2822 date. For a date,
/// the delay is the whole seconds remaining until then, and 0 if it is in the
/// past. Returns `None` when the header is absent, not ASCII-visible text, or
/// in neither format.
fn parse_retry_after(headers: &HeaderMap) -> Option<Duration> {
    let raw = headers.get("retry-after")?.to_str().ok()?;
    let secs = match raw.parse::<u64>() {
        Ok(delta) => delta,
        Err(_) => {
            let deadline = chrono::DateTime::parse_from_rfc2822(raw)
                .ok()?
                .with_timezone(&chrono::Utc);
            let remaining = (deadline - chrono::Utc::now()).num_seconds();
            // Negative (already elapsed) values fail the conversion and become 0.
            u64::try_from(remaining).unwrap_or(0)
        }
    };
    Some(Duration::from_secs(secs))
}
/// Header names that carry credentials or session state and are therefore
/// omitted from outbound request logs.
///
/// `proxy-authorization` carries proxy credentials in exactly the same way
/// `authorization` carries origin credentials, so it is redacted too.
const REDACTED_HEADERS: &[&str] = &[
    "authorization",
    "proxy-authorization",
    "cookie",
    "set-cookie",
];

/// Returns `true` if `name` is one of [`REDACTED_HEADERS`], compared
/// ASCII-case-insensitively (header names are case-insensitive).
fn is_sensitive_header(name: &str) -> bool {
    REDACTED_HEADERS
        .iter()
        .any(|h| h.eq_ignore_ascii_case(name))
}
/// Emits one `info`-level event for a completed outbound request.
///
/// Only the host and path are logged (no query string), and only the names
/// of sent headers, never their values. Credential-bearing header names are
/// dropped as well. Elapsed time is reported in milliseconds, saturating at
/// `u64::MAX`.
fn log_request(
    method: &str,
    url: &reqwest::Url,
    status: u16,
    elapsed: Duration,
    headers: &HeaderMap,
) {
    let elapsed_ms = u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX);
    let mut sent_headers: Vec<&str> = Vec::new();
    for key in headers.keys() {
        let key = key.as_str();
        if !is_sensitive_header(key) {
            sent_headers.push(key);
        }
    }
    tracing::info!(
        http.method = method,
        http.host = url.host_str().unwrap_or("unknown"),
        http.path = url.path(),
        http.status = status,
        http.elapsed_ms = elapsed_ms,
        http.sent_headers = ?sent_headers,
        "outbound request"
    );
}
/// Propagates the current tracing span's OpenTelemetry context into the
/// outbound request headers.
///
/// Without the `telemetry-otlp` feature this returns `builder` unchanged. With
/// the feature, the global text-map propagator writes its headers into a
/// temporary map. Each entry whose value is a valid `HeaderValue` is then added
/// to the request; invalid values are skipped silently.
// Allowed because the feature-off body is a trivial pass-through that clippy
// would flag as const-eligible, while the feature-on body cannot be `const`.
#[allow(clippy::missing_const_for_fn)]
fn inject_trace_context(builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
    #[cfg(not(feature = "telemetry-otlp"))]
    {
        builder
    }
    #[cfg(feature = "telemetry-otlp")]
    {
        use std::collections::HashMap;
        use tracing_opentelemetry::OpenTelemetrySpanExt as _;
        let cx = tracing::Span::current().context();
        let mut map = HashMap::<String, String>::new();
        opentelemetry::global::get_text_map_propagator(|propagator| {
            propagator.inject_context(&cx, &mut TraceHeaderInjector(&mut map));
        });
        let mut builder = builder;
        for (k, v) in map {
            // Values that are not valid header values are dropped rather than failing the request.
            if let Ok(value) = HeaderValue::from_str(&v) {
                builder = builder.header(k, value);
            }
        }
        builder
    }
}
/// Adapter that lets an OpenTelemetry propagator write its headers into a
/// plain `HashMap`.
#[cfg(feature = "telemetry-otlp")]
struct TraceHeaderInjector<'a>(&'a mut std::collections::HashMap<String, String>);
#[cfg(feature = "telemetry-otlp")]
impl opentelemetry::propagation::Injector for TraceHeaderInjector<'_> {
    /// Stores `key` → `value`, overwriting any previous value for `key`.
    fn set(&mut self, key: &str, value: String) {
        self.0.insert(key.to_owned(), value);
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::HttpClientConfig;
#[test]
fn client_constructs_with_defaults() {
let client = Client::new();
assert!(client.alias.is_none());
assert!(client.base_url.is_none());
assert_eq!(client.retry_policy.max_retries, 3);
}
#[test]
fn request_builder_fluent_api_compiles() {
let client = Client::new();
let _builder = client
.post("https://example.com/api")
.header("x-api-key", "secret")
.json(&serde_json::json!({"key": "value"}))
.retries(2);
}
#[test]
fn response_accessors_work() {
let payload = serde_json::json!({"id": 42, "name": "Alice"});
let body = serde_json::to_vec(&payload).unwrap();
let resp = Response {
status: reqwest::StatusCode::OK,
headers: HeaderMap::new(),
body: Bytes::from(body),
url: None,
};
assert_eq!(resp.status().as_u16(), 200);
assert!(resp.is_success());
}
#[test]
fn response_json_deserialises() {
#[derive(serde::Deserialize, PartialEq, Debug)]
struct User {
id: i32,
name: String,
}
let payload = serde_json::json!({"id": 1, "name": "Bob"});
let resp = Response {
status: reqwest::StatusCode::OK,
headers: HeaderMap::new(),
body: Bytes::from(serde_json::to_vec(&payload).unwrap()),
url: None,
};
let user: User = resp.json().unwrap();
assert_eq!(user.id, 1);
assert_eq!(user.name, "Bob");
}
#[test]
fn response_text_returns_string() {
let resp = Response {
status: reqwest::StatusCode::OK,
headers: HeaderMap::new(),
body: Bytes::from_static(b"hello world"),
url: None,
};
assert_eq!(resp.text(), "hello world");
}
#[test]
fn response_bytes_returns_raw() {
let resp = Response {
status: reqwest::StatusCode::CREATED,
headers: HeaderMap::new(),
body: Bytes::from_static(b"\x00\x01\x02"),
url: None,
};
assert_eq!(resp.bytes(), Bytes::from_static(b"\x00\x01\x02"));
}
#[test]
fn config_deserialises_from_toml() {
let toml = r#"
[client]
timeout_secs = 60
max_retries = 5
[client.base_urls]
stripe = "https://api.stripe.com"
sendgrid = "https://api.sendgrid.com"
"#;
let http_cfg: crate::config::HttpConfig = toml::from_str(toml).unwrap();
let config = &http_cfg.client;
assert_eq!(config.timeout_secs, 60);
assert_eq!(config.max_retries, 5);
assert_eq!(
config.base_urls.get("stripe").map(String::as_str),
Some("https://api.stripe.com")
);
assert_eq!(
config.base_urls.get("sendgrid").map(String::as_str),
Some("https://api.sendgrid.com")
);
}
#[test]
fn config_has_correct_defaults() {
let config = HttpClientConfig::default();
assert_eq!(config.timeout_secs, 30);
assert_eq!(config.max_retries, 3);
assert!(config.base_urls.is_empty());
}
#[test]
fn idempotent_method_classification() {
assert!(is_idempotent_method(&Method::GET));
assert!(is_idempotent_method(&Method::HEAD));
assert!(is_idempotent_method(&Method::PUT));
assert!(is_idempotent_method(&Method::DELETE));
assert!(is_idempotent_method(&Method::OPTIONS));
assert!(is_idempotent_method(&Method::TRACE));
assert!(!is_idempotent_method(&Method::POST));
assert!(!is_idempotent_method(&Method::PATCH));
}
#[test]
fn retryable_status_classification() {
assert!(is_retryable_status(502));
assert!(is_retryable_status(503));
assert!(is_retryable_status(504));
assert!(!is_retryable_status(200));
assert!(!is_retryable_status(400));
assert!(!is_retryable_status(404));
assert!(!is_retryable_status(500));
assert!(!is_retryable_status(429));
}
#[test]
fn retry_after_header_parsing() {
let mut headers = HeaderMap::new();
headers.insert(
reqwest::header::HeaderName::from_static("retry-after"),
HeaderValue::from_static("5"),
);
assert_eq!(parse_retry_after(&headers), Some(Duration::from_secs(5)));
let empty = HeaderMap::new();
assert_eq!(parse_retry_after(&empty), None);
}
#[test]
fn sensitive_header_detection() {
assert!(is_sensitive_header("authorization"));
assert!(is_sensitive_header("Authorization"));
assert!(is_sensitive_header("AUTHORIZATION"));
assert!(is_sensitive_header("cookie"));
assert!(is_sensitive_header("set-cookie"));
assert!(!is_sensitive_header("content-type"));
assert!(!is_sensitive_header("x-api-key"));
}
#[tokio::test]
async fn mock_registry_captures_calls() {
let registry = Arc::new(MockRegistry::new());
let call_count = Arc::new(AtomicUsize::new(0));
registry.register(MockEntry {
method: Some(Method::POST),
path: "/charges".to_owned(),
alias: Some("stripe".to_owned()),
status: 200,
body: Some(serde_json::json!({"id": "ch_123"})),
call_count: call_count.clone(),
});
let client = Client::new().with_mock(registry).named("stripe");
let resp = client
.post("https://api.stripe.com/charges")
.json(&serde_json::json!({"amount": 1000}))
.send()
.await
.unwrap();
assert_eq!(resp.status().as_u16(), 200);
let body: serde_json::Value = resp.json().unwrap();
assert_eq!(body["id"], "ch_123");
assert_eq!(call_count.load(Ordering::SeqCst), 1);
}
#[tokio::test]
async fn mock_handle_expect_called_passes() {
let registry = Arc::new(MockRegistry::new());
let call_count = Arc::new(AtomicUsize::new(0));
registry.register(MockEntry {
method: Some(Method::GET),
path: "/users/1".to_owned(),
alias: None,
status: 200,
body: Some(serde_json::json!({"name": "Alice"})),
call_count: call_count.clone(),
});
let handle = MockHandle {
alias: "test".to_owned(),
method: "GET".to_owned(),
path: "/users/1".to_owned(),
call_count: call_count.clone(),
};
let client = Client::new().with_mock(registry);
client
.get("https://api.example.com/users/1")
.send()
.await
.unwrap();
handle.expect_called(1);
assert_eq!(handle.call_count(), 1);
}
#[tokio::test]
async fn mock_matches_by_path_suffix() {
let registry = Arc::new(MockRegistry::new());
let call_count = Arc::new(AtomicUsize::new(0));
registry.register(MockEntry {
method: Some(Method::POST),
path: "/v1/charges".to_owned(),
alias: None,
status: 201,
body: Some(serde_json::json!({"created": true})),
call_count: call_count.clone(),
});
let client = Client::new().with_mock(registry);
let resp = client
.post("https://api.stripe.com/v1/charges")
.send()
.await
.unwrap();
assert_eq!(resp.status().as_u16(), 201);
assert_eq!(call_count.load(Ordering::SeqCst), 1);
}
#[tokio::test]
async fn no_mock_error_when_unmatched() {
let registry = Arc::new(MockRegistry::new());
let client = Client::new().with_mock(registry);
let result = client.post("https://api.example.com/unknown").send().await;
assert!(matches!(result, Err(ClientError::NoMock(_, _))));
}
#[tokio::test]
async fn mock_setup_builder_registers_entry() {
let registry = Arc::new(MockRegistry::new());
let builder = MockSetupBuilder {
registry: registry.clone(),
alias: "myservice".to_owned(),
method: None,
path: None,
};
let handle = builder
.post("/api/resource")
.respond_with(201, serde_json::json!({"ok": true}));
let client = Client::new().with_mock(registry).named("myservice");
client
.post("https://myservice.example.com/api/resource")
.send()
.await
.unwrap();
handle.expect_called(1);
}
#[test]
fn client_from_config() {
let config = HttpClientConfig {
timeout_secs: 10,
max_retries: 1,
max_retry_after_secs: 10,
base_urls: std::collections::HashMap::new(),
};
let client = Client::from_config(&config);
assert_eq!(client.retry_policy.max_retries, 1);
}
#[test]
fn named_client_preserves_mock_registry() {
let registry = Arc::new(MockRegistry::new());
let client = Client::new().with_mock(registry);
let named = client.named("stripe");
assert!(named.mock.is_some());
assert_eq!(named.alias.as_deref(), Some("stripe"));
}
#[test]
fn base_url_prepended_to_relative_path() {
let client = Client::new();
let client = client.with_base_url("https://api.stripe.com");
let builder = client.post("/v1/charges");
assert_eq!(builder.url, "https://api.stripe.com/v1/charges");
}
#[test]
fn absolute_url_bypasses_base_url() {
let client = Client::new().with_base_url("https://ignored.example.com");
let builder = client.get("https://actual.example.com/path");
assert_eq!(builder.url, "https://actual.example.com/path");
}
#[test]
fn retry_override_per_request() {
let client = Client::new(); let builder = client.get("https://example.com").retries(0);
assert_eq!(builder.retry_policy.max_retries, 0);
let no_retry = client.get("https://example.com").no_retry();
assert_eq!(no_retry.retry_policy.max_retries, 0);
}
#[tokio::test]
async fn client_extracts_from_state() {
use axum::extract::FromRequestParts;
let state = crate::AppState::for_test();
let mut parts = axum::http::Request::new(axum::body::Body::empty())
.into_parts()
.0;
let client = Client::from_request_parts(&mut parts, &state)
.await
.unwrap();
assert!(client.mock.is_none());
assert!(client.alias.is_none());
}
#[test]
fn mock_registry_ext_round_trips_through_state() {
let registry = Arc::new(MockRegistry::new());
let ext = HttpMockRegistryExt(registry);
let state = crate::AppState::for_test();
state.insert_extension(ext);
let retrieved = state.extension::<HttpMockRegistryExt>();
assert!(retrieved.is_some());
}
#[test]
fn named_client_resolves_base_url_from_config() {
let mut base_urls = std::collections::HashMap::new();
base_urls.insert("stripe".to_owned(), "https://api.stripe.com".to_owned());
let config = HttpClientConfig {
timeout_secs: 30,
max_retries: 3,
max_retry_after_secs: 10,
base_urls,
};
let client = Client::from_config(&config);
let stripe = client.named("stripe");
assert_eq!(stripe.base_url.as_deref(), Some("https://api.stripe.com"));
assert_eq!(stripe.alias.as_deref(), Some("stripe"));
let other = client.named("sendgrid");
assert!(other.base_url.is_none());
}
#[tokio::test]
async fn client_extracts_from_autumn_config_in_state() {
use axum::extract::FromRequestParts;
let mut cfg = crate::config::AutumnConfig::default();
cfg.http.client.max_retries = 7;
let state = crate::AppState::for_test();
state.insert_extension(cfg);
let mut parts = axum::http::Request::new(axum::body::Body::empty())
.into_parts()
.0;
let client = Client::from_request_parts(&mut parts, &state)
.await
.unwrap();
assert_eq!(client.retry_policy.max_retries, 7);
}
#[tokio::test]
async fn respond_with_status_produces_empty_body() {
let registry = Arc::new(MockRegistry::new());
let builder = MockSetupBuilder {
registry: registry.clone(),
alias: "svc".to_owned(),
method: None,
path: None,
};
let _handle = builder.delete("/items/1").respond_with_status(204);
let client = Client::new().with_mock(registry).named("svc");
let resp = client
.delete("https://svc.example.com/items/1")
.send()
.await
.unwrap();
assert_eq!(resp.status().as_u16(), 204);
assert_eq!(
resp.bytes(),
bytes::Bytes::new(),
"body must be empty, not \"null\""
);
}
#[test]
fn retry_after_http_date_parsing() {
let mut headers = HeaderMap::new();
headers.insert(
reqwest::header::HeaderName::from_static("retry-after"),
HeaderValue::from_static("Tue, 01 Jan 2030 00:00:00 GMT"),
);
let duration = parse_retry_after(&headers);
assert!(duration.is_some(), "should parse HTTP-date Retry-After");
assert!(
duration.unwrap().as_secs() > 0,
"future date should yield positive delay"
);
}
#[tokio::test]
async fn non_idempotent_post_no_retry() {
let registry = Arc::new(MockRegistry::new());
let call_count = Arc::new(AtomicUsize::new(0));
registry.register(MockEntry {
method: Some(Method::POST),
path: "/endpoint".to_owned(),
alias: None,
status: 503,
body: None,
call_count: call_count.clone(),
});
let client = Client::new().with_mock(registry);
let resp = client
.post("https://example.com/endpoint")
.send()
.await
.unwrap();
assert_eq!(resp.status().as_u16(), 503);
assert_eq!(call_count.load(Ordering::SeqCst), 1);
}
#[tokio::test]
async fn mock_strips_query_from_url_before_matching() {
let registry = Arc::new(MockRegistry::new());
let call_count = Arc::new(AtomicUsize::new(0));
registry.register(MockEntry {
method: Some(Method::GET),
path: "/v1/charges".to_owned(),
alias: None,
status: 200,
body: Some(serde_json::json!({"ok": true})),
call_count: call_count.clone(),
});
let client = Client::new().with_mock(registry);
let resp = client
.get("https://api.stripe.com/v1/charges?expand[]=balance_transaction")
.send()
.await
.unwrap();
assert_eq!(resp.status().as_u16(), 200);
assert_eq!(call_count.load(Ordering::SeqCst), 1);
}
#[tokio::test]
async fn mock_suffix_match_with_leading_slash_path() {
let registry = Arc::new(MockRegistry::new());
let call_count = Arc::new(AtomicUsize::new(0));
registry.register(MockEntry {
method: Some(Method::POST),
path: "/charges".to_owned(),
alias: None,
status: 201,
body: Some(serde_json::json!({"matched": true})),
call_count: call_count.clone(),
});
let client = Client::new().with_mock(registry);
let resp = client
.post("https://api.stripe.com/v1/charges")
.send()
.await
.unwrap();
assert_eq!(resp.status().as_u16(), 201);
assert_eq!(call_count.load(Ordering::SeqCst), 1);
}
#[test]
fn retries_clears_idempotent_only_flag() {
let client = Client::new();
let builder = client.post("https://example.com").retries(2);
assert_eq!(builder.retry_policy.max_retries, 2);
assert!(
!builder.retry_policy.retry_idempotent_only,
"explicit retries() call must allow non-idempotent methods to retry"
);
}
#[test]
fn log_request_completes_with_sensitive_headers() {
let url = reqwest::Url::parse("https://api.example.com/v1/resource?q=1").unwrap();
let mut headers = HeaderMap::new();
headers.insert(
HeaderName::from_static("content-type"),
HeaderValue::from_static("application/json"),
);
headers.insert(
HeaderName::from_static("authorization"),
HeaderValue::from_static("Bearer sk_test_xxx"),
);
log_request("POST", &url, 201, Duration::from_millis(12), &headers);
}
#[test]
fn inject_trace_context_passthrough_without_telemetry() {
let inner = reqwest::Client::new();
let builder = inner.get("https://example.com");
let _b = inject_trace_context(builder);
}
/// End-to-end GET against an in-process axum server bound to an OS-assigned
/// port on 127.0.0.1. This exercises the real (non-mock) send path with a
/// custom header, then checks the status, that a URL is reported, and the body.
#[tokio::test]
// The std `TEST_LOCK` guard is deliberately held across `.await`: it keeps
// tests that clear the global circuit-breaker registry from interleaving.
#[allow(clippy::await_holding_lock)]
async fn real_get_request_covers_network_path() {
    use axum::{Router, routing::get};
    use std::sync::PoisonError;

    let _guard = crate::circuit_breaker::TEST_LOCK
        .lock()
        .unwrap_or_else(PoisonError::into_inner);
    crate::circuit_breaker::global_registry().clear();

    let app = Router::new().route("/ping", get(|| async { "pong" }));
    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
    let port = listener.local_addr().unwrap().port();
    tokio::spawn(async move { axum::serve(listener, app).await.unwrap() });

    let url = format!("http://127.0.0.1:{port}/ping");
    let resp = Client::new()
        .get(url)
        .header("x-request-id", "test-35")
        .send()
        .await
        .unwrap();

    assert_eq!(resp.status().as_u16(), 200);
    assert!(resp.url().is_some());
    assert_eq!(resp.text(), "pong");

    crate::circuit_breaker::global_registry().clear();
}
/// End-to-end POST with a JSON body against an in-process axum echo
/// endpoint. The request body is serialized by `.json(..)`, the server sends
/// it back unchanged, and the response is parsed with `Response::json`.
#[tokio::test]
// The std `TEST_LOCK` guard is deliberately held across `.await`: it keeps
// tests that clear the global circuit-breaker registry from interleaving.
#[allow(clippy::await_holding_lock)]
async fn real_post_with_json_body_covers_body_path() {
    use axum::{Json, Router, routing::post};
    use serde_json::Value;

    let _guard = crate::circuit_breaker::TEST_LOCK
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    crate::circuit_breaker::global_registry().clear();

    // Echo handler: responds with exactly the JSON document it received.
    let app = Router::new().route(
        "/echo",
        post(|Json(payload): Json<Value>| async move { Json(payload) }),
    );
    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap();
    tokio::spawn(async move { axum::serve(listener, app).await.unwrap() });

    let payload = serde_json::json!({"hello": "world"});
    let client = Client::new();
    let resp = client
        .post(format!("http://{addr}/echo"))
        .json(&payload)
        .send()
        .await
        .unwrap();

    assert_eq!(resp.status().as_u16(), 200);
    let echoed: Value = resp.json().unwrap();
    assert_eq!(echoed["hello"], "world");

    crate::circuit_breaker::global_registry().clear();
}
/// A GET allowed one retry must recover from a single 503. The server fails
/// only its first hit and returns 200 afterwards, so the caller sees 200 and
/// the server is hit exactly twice.
#[tokio::test]
// The std `TEST_LOCK` guard is deliberately held across `.await`: it keeps
// tests that clear the global circuit-breaker registry from interleaving.
#[allow(clippy::await_holding_lock)]
async fn real_get_retries_on_503_then_succeeds() {
    use axum::{Router, http::StatusCode, routing::get};
    use std::sync::{
        Arc,
        atomic::{AtomicU32, Ordering as SeqOrdering},
    };

    let _guard = crate::circuit_breaker::TEST_LOCK
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    crate::circuit_breaker::global_registry().clear();

    let hits = Arc::new(AtomicU32::new(0));
    let server_hits = Arc::clone(&hits);
    let app = Router::new().route(
        "/flaky",
        get(move || {
            let counter = Arc::clone(&server_hits);
            async move {
                // Only the very first request observes a zero count and fails.
                let previous = counter.fetch_add(1, SeqOrdering::SeqCst);
                if previous == 0 {
                    StatusCode::SERVICE_UNAVAILABLE
                } else {
                    StatusCode::OK
                }
            }
        }),
    );
    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap();
    tokio::spawn(async move { axum::serve(listener, app).await.unwrap() });

    let resp = Client::new()
        .get(format!("http://{addr}/flaky"))
        .retries(1)
        .send()
        .await
        .unwrap();

    assert_eq!(resp.status().as_u16(), 200);
    assert_eq!(hits.load(SeqOrdering::SeqCst), 2);

    crate::circuit_breaker::global_registry().clear();
}
/// `text_body` must store the given string's bytes verbatim as the request body.
#[test]
fn text_body_sets_body() {
    let client = Client::new();
    let builder = client.post("https://example.com").text_body("hello");
    assert_eq!(builder.body.as_deref(), Some(&b"hello"[..]));
}
/// Pins the exact `Display` text produced by `ClientError`'s `#[error(..)]`
/// attributes.
///
/// A plain substring check would still pass if the method and path were
/// swapped or the wording regressed, so the full message is asserted instead.
#[test]
fn client_error_display() {
    let err = ClientError::NoMock("GET".to_owned(), "/path".to_owned());
    assert_eq!(err.to_string(), "no mock registered for GET /path");

    assert_eq!(
        ClientError::CircuitBreakerOpen.to_string(),
        "outbound circuit breaker is open"
    );
}
#[tokio::test]
#[allow(clippy::await_holding_lock)]
async fn test_http_client_circuit_breaker_integration() {
use axum::{Router, routing::get};
use std::sync::atomic::{AtomicU32, Ordering as SeqOrdering};
let _lock = crate::circuit_breaker::TEST_LOCK
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
crate::circuit_breaker::global_registry().clear();
let hit = Arc::new(AtomicU32::new(0));
let hit2 = hit.clone();
let app = Router::new().route(
"/flaky",
get(move || {
let c = hit2.clone();
async move {
c.fetch_add(1, SeqOrdering::SeqCst);
axum::http::StatusCode::INTERNAL_SERVER_ERROR
}
}),
);
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
tokio::spawn(async move { axum::serve(listener, app).await.unwrap() });
let mut rc = crate::config::ResilienceConfig::default();
rc.circuit_breaker.defaults.failure_ratio_threshold = Some(0.5);
rc.circuit_breaker.defaults.minimum_sample_count = Some(3);
rc.circuit_breaker.defaults.open_duration_secs = Some(10);
let client = Client::new();
let client = Client {
resilience_config: Some(Arc::new(rc)),
..client
};
let url = format!("http://127.0.0.1:{}/flaky", addr.port());
for _ in 0..3 {
let res = client.get(&url).send().await;
let res = res.unwrap();
assert_eq!(res.status().as_u16(), 500);
}
let res = client.get(&url).send().await;
assert!(matches!(res, Err(ClientError::CircuitBreakerOpen)));
assert_eq!(hit.load(SeqOrdering::SeqCst), 3);
crate::circuit_breaker::global_registry().clear();
}
}