use std::collections::HashMap;
use std::sync::RwLock;
use std::time::{Duration, Instant};
use reqwest::{IntoUrl, Method, Request, RequestBuilder, Response};
use std::net::{IpAddr, SocketAddr};
/// Snapshot of one host's circuit-breaker bookkeeping.
///
/// Marked `#[non_exhaustive]` so fields can be added without breaking downstream crates.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct CircuitState {
    /// Where the host currently sits in the closed → open → half-open cycle.
    pub state: CircuitStatus,
    /// Failures counted while closed; compared against the failure threshold.
    pub failure_count: u32,
    /// Successes counted while half-open; compared against the success threshold.
    pub success_count: u32,
    /// When the circuit last (re)opened; `None` until it first opens and again once it closes.
    pub opened_at: Option<Instant>,
    /// How long an open circuit rejects requests before a probe is let through.
    pub current_backoff: Duration,
}

/// Position of a host's circuit breaker.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum CircuitStatus {
    /// Requests flow normally; failures are counted.
    Closed,
    /// Requests are rejected until the current backoff elapses.
    Open,
    /// Requests are let through as probes; enough successes close the circuit, a failure reopens it.
    HalfOpen,
}

impl Default for CircuitState {
    /// A closed circuit with zeroed counters, no open timestamp and a 30-second backoff.
    fn default() -> Self {
        const INITIAL_BACKOFF: Duration = Duration::from_secs(30);
        CircuitState {
            state: CircuitStatus::Closed,
            failure_count: 0,
            success_count: 0,
            opened_at: None,
            current_backoff: INITIAL_BACKOFF,
        }
    }
}
/// Tuning knobs for `CircuitBreakerClient`.
#[derive(Clone, Debug)]
pub struct CircuitBreakerConfig {
    /// Consecutive failures (counted while closed) that open the circuit.
    pub failure_threshold: u32,
    /// Successes while half-open required to close the circuit again.
    pub success_threshold: u32,
    /// Base open-circuit backoff; the backoff is reset to this value when a circuit closes.
    pub base_timeout: Duration,
    /// Upper bound on the backoff as it grows after failed half-open probes.
    pub max_backoff: Duration,
    /// Factor applied to the backoff each time a half-open probe fails.
    pub backoff_multiplier: f64,
    /// When `false`, the breaker neither records outcomes nor rejects requests.
    pub enabled: bool,
    /// When `false`, requests whose URL host is a private IP literal are refused.
    pub allow_private: bool,
}

impl Default for CircuitBreakerConfig {
    /// Five failures to open, two successes to close, a 30 s base and 10 min maximum
    /// backoff growing ×1.5, breaker enabled, private targets refused.
    fn default() -> Self {
        Self {
            enabled: true,
            allow_private: false,
            failure_threshold: 5,
            success_threshold: 2,
            base_timeout: Duration::from_secs(30),
            max_backoff: Duration::from_secs(10 * 60),
            backoff_multiplier: 1.5,
        }
    }
}
/// Error produced when a host's circuit is open and the request is rejected up front.
#[derive(Clone, Debug)]
pub struct CircuitBreakerOpen {
    /// Breaker key of the rejected host.
    pub host: String,
    /// Time left before the breaker will admit a probe request.
    pub retry_after: Duration,
}

impl std::fmt::Display for CircuitBreakerOpen {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { host, retry_after } = self;
        write!(f, "Circuit breaker open for {host}: retry after {retry_after:?}")
    }
}

impl std::error::Error for CircuitBreakerOpen {}
/// Returns `true` if `ip` is not a public unicast destination that outbound requests
/// should be allowed to reach.
///
/// IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are unwrapped and judged as IPv4. Other
/// IPv6 addresses are blocked when they are loopback (`::1`), unspecified (`::`),
/// link-local (`fe80::/10`), deprecated site-local (`fec0::/10`), unique-local
/// (`fc00::/7`) or documentation (`2001:db8::/32`, mirroring the IPv4 documentation ranges).
pub fn is_private_ip(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => is_private_v4(v4),
        IpAddr::V6(v6) => {
            if let Some(v4) = v6.to_ipv4_mapped() {
                return is_private_v4(v4);
            }
            // Array destructuring: no indexing, cannot panic.
            let [seg0, seg1, ..] = v6.segments();
            v6.is_loopback()
                || v6.is_unspecified()
                // fe80::/10 link-local and fec0::/10 site-local together form fe80::/9.
                || (seg0 & 0xff80) == 0xfe80
                // fc00::/7 unique local addresses.
                || (seg0 & 0xfe00) == 0xfc00
                // 2001:db8::/32 documentation prefix.
                || (seg0 == 0x2001 && seg1 == 0x0db8)
        }
    }
}

/// IPv4 half of [`is_private_ip`]: `true` for loopback, RFC 1918 private, link-local,
/// broadcast, unspecified, documentation and the other special-purpose ranges that are
/// never valid public destinations.
fn is_private_v4(v4: std::net::Ipv4Addr) -> bool {
    let [a, b, c, _] = v4.octets();
    v4.is_loopback()
        || v4.is_private()
        || v4.is_link_local()
        || v4.is_broadcast()
        || v4.is_unspecified()
        || v4.is_documentation()
        // 0.0.0.0/8 "this network"; `is_unspecified` only matches 0.0.0.0 itself.
        || a == 0
        // 100.64.0.0/10 shared address space (RFC 6598): CGNAT, and home of some cloud
        // metadata services (e.g. Alibaba Cloud's 100.100.100.200).
        || (a == 100 && (b & 0xc0) == 0x40)
        // 192.0.0.0/24 IETF protocol assignments (RFC 6890).
        || (a == 192 && b == 0 && c == 0)
        // 198.18.0.0/15 benchmarking (RFC 2544).
        || (a == 198 && (b & 0xfe) == 18)
        // 240.0.0.0/4 reserved for future use (covers broadcast as well).
        || a >= 240
}
/// DNS resolver for reqwest that drops private/internal addresses from lookup results.
///
/// Names are resolved with `tokio::net::lookup_host`. Every address for which
/// [`is_private_ip`] is true is discarded. If nothing public remains, resolution fails
/// with an error instead of returning an empty list.
struct SsrfSafeResolver;

impl reqwest::dns::Resolve for SsrfSafeResolver {
    fn resolve(&self, name: reqwest::dns::Name) -> reqwest::dns::Resolving {
        let host = name.as_str().to_string();
        Box::pin(async move {
            // `lookup_host` wants a `host:port` string; 0 is only a placeholder port.
            let resolved = tokio::net::lookup_host(format!("{host}:0")).await?;
            let public: Vec<SocketAddr> = resolved
                .filter(|candidate| !is_private_ip(candidate.ip()))
                .collect();
            if public.is_empty() {
                return Err(format!("DNS resolution for {host} returned only private IPs").into());
            }
            let addrs: reqwest::dns::Addrs = Box::new(public.into_iter());
            Ok(addrs)
        })
    }
}
pub fn build_ssrf_safe_client() -> reqwest::Client {
reqwest::Client::builder()
.dns_resolver(std::sync::Arc::new(SsrfSafeResolver))
.build()
.unwrap_or_else(|e| {
tracing::error!("Failed to build SSRF-safe HTTP client: {e}");
unreachable!("TLS backend required for HTTP client")
})
}
/// `reqwest` wrapper that keeps a circuit breaker per host key and can refuse requests
/// aimed at private IP literals.
///
/// Clones share the same breaker state map because it lives behind an `Arc`.
#[derive(Clone, Debug)]
pub struct CircuitBreakerClient {
    inner: reqwest::Client,
    states: std::sync::Arc<RwLock<HashMap<String, CircuitState>>>,
    config: CircuitBreakerConfig,
}
impl CircuitBreakerClient {
pub fn new(client: reqwest::Client, config: CircuitBreakerConfig) -> Self {
Self {
inner: client,
states: std::sync::Arc::new(RwLock::new(HashMap::new())),
config,
}
}
pub fn with_defaults(client: reqwest::Client) -> Self {
Self::new(client, CircuitBreakerConfig::default())
}
pub fn with_ssrf_protection() -> Self {
Self::new(build_ssrf_safe_client(), CircuitBreakerConfig::default())
}
pub fn inner(&self) -> &reqwest::Client {
&self.inner
}
pub fn with_timeout(&self, timeout: Option<Duration>) -> HttpClient {
HttpClient::new(self.clone(), timeout)
}
fn extract_host(url: &reqwest::Url) -> String {
format!(
"{}://{}{}",
url.scheme(),
url.host_str().unwrap_or("unknown"),
url.port().map(|p| format!(":{}", p)).unwrap_or_default()
)
}
fn url_targets_private_ip(url: &reqwest::Url) -> bool {
let Some(host) = url.host_str() else {
return false;
};
let trimmed = host.trim_start_matches('[').trim_end_matches(']');
let Ok(ip) = trimmed.parse::<IpAddr>() else {
return false;
};
is_private_ip(ip)
}
pub fn should_allow(&self, host: &str) -> Result<(), CircuitBreakerOpen> {
if !self.config.enabled {
return Ok(());
}
let states = self.states.read().unwrap_or_else(|e| {
tracing::error!("Circuit breaker lock was poisoned, recovering");
e.into_inner()
});
let state = match states.get(host) {
Some(s) => s,
None => return Ok(()), };
match state.state {
CircuitStatus::Closed => Ok(()),
CircuitStatus::HalfOpen => Ok(()), CircuitStatus::Open => {
let opened_at = state.opened_at.unwrap_or_else(Instant::now);
let elapsed = opened_at.elapsed();
if elapsed >= state.current_backoff {
Ok(())
} else {
Err(CircuitBreakerOpen {
host: host.to_string(),
retry_after: state.current_backoff - elapsed,
})
}
}
}
}
pub fn record_success(&self, host: &str) {
if !self.config.enabled {
return;
}
let mut states = self.states.write().unwrap_or_else(|e| {
tracing::error!("Circuit breaker lock was poisoned, recovering");
e.into_inner()
});
let state = states.entry(host.to_string()).or_default();
match state.state {
CircuitStatus::Closed => {
state.failure_count = 0;
}
CircuitStatus::HalfOpen => {
state.success_count += 1;
if state.success_count >= self.config.success_threshold {
tracing::info!(host = %host, "Circuit breaker closed, service recovered");
state.state = CircuitStatus::Closed;
state.failure_count = 0;
state.success_count = 0;
state.opened_at = None;
state.current_backoff = self.config.base_timeout;
}
}
CircuitStatus::Open => {
tracing::info!(host = %host, "Circuit breaker half-open, testing service");
state.state = CircuitStatus::HalfOpen;
state.success_count = 1;
}
}
}
pub fn record_failure(&self, host: &str) {
if !self.config.enabled {
return;
}
let mut states = self.states.write().unwrap_or_else(|e| {
tracing::error!("Circuit breaker lock was poisoned, recovering");
e.into_inner()
});
let state = states.entry(host.to_string()).or_default();
match state.state {
CircuitStatus::Closed => {
state.failure_count += 1;
if state.failure_count >= self.config.failure_threshold {
tracing::warn!(
host = %host,
failures = state.failure_count,
"Circuit breaker opened, service unhealthy"
);
state.state = CircuitStatus::Open;
state.opened_at = Some(Instant::now());
}
}
CircuitStatus::HalfOpen => {
let new_backoff = Duration::from_secs_f64(
(state.current_backoff.as_secs_f64() * self.config.backoff_multiplier)
.min(self.config.max_backoff.as_secs_f64()),
);
tracing::warn!(
host = %host,
backoff_secs = new_backoff.as_secs(),
"Circuit breaker reopened, service still unhealthy"
);
state.state = CircuitStatus::Open;
state.opened_at = Some(Instant::now());
state.current_backoff = new_backoff;
state.success_count = 0;
}
CircuitStatus::Open => {
state.opened_at = Some(Instant::now());
}
}
}
pub async fn execute(&self, request: Request) -> Result<Response, CircuitBreakerError> {
if !self.config.allow_private && Self::url_targets_private_ip(request.url()) {
return Err(CircuitBreakerError::PrivateHostBlocked(
request.url().host_str().unwrap_or("unknown").to_string(),
));
}
let host = Self::extract_host(request.url());
self.should_allow(&host)
.map_err(CircuitBreakerError::CircuitOpen)?;
{
let mut states = self.states.write().unwrap_or_else(|e| {
tracing::error!("Circuit breaker lock was poisoned, recovering");
e.into_inner()
});
if let Some(state) = states.get_mut(&host)
&& state.state == CircuitStatus::Open
&& let Some(opened_at) = state.opened_at
&& opened_at.elapsed() >= state.current_backoff
{
tracing::info!(host = %host, "Circuit breaker half-open, testing service");
state.state = CircuitStatus::HalfOpen;
state.success_count = 0;
}
}
match self.inner.execute(request).await {
Ok(response) => {
if response.status().is_server_error() {
self.record_failure(&host);
} else {
self.record_success(&host);
}
Ok(response)
}
Err(e) => {
self.record_failure(&host);
Err(CircuitBreakerError::Request(e))
}
}
}
pub fn get_state(&self, host: &str) -> Option<CircuitState> {
self.states
.read()
.unwrap_or_else(|e| {
tracing::error!("Circuit breaker lock was poisoned, recovering");
e.into_inner()
})
.get(host)
.cloned()
}
pub fn reset(&self, host: &str) {
self.states
.write()
.unwrap_or_else(|e| {
tracing::error!("Circuit breaker lock was poisoned, recovering");
e.into_inner()
})
.remove(host);
}
pub fn reset_all(&self) {
self.states
.write()
.unwrap_or_else(|e| {
tracing::error!("Circuit breaker lock was poisoned, recovering");
e.into_inner()
})
.clear();
}
}
#[derive(Debug)]
pub enum CircuitBreakerError {
CircuitOpen(CircuitBreakerOpen),
PrivateHostBlocked(String),
Request(reqwest::Error),
}
impl std::fmt::Display for CircuitBreakerError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
CircuitBreakerError::CircuitOpen(e) => write!(f, "{}", e),
CircuitBreakerError::PrivateHostBlocked(_host) => write!(
f,
"Outbound request blocked: target resolves to a private IP"
),
CircuitBreakerError::Request(e) => write!(f, "HTTP request failed: {}", e),
}
}
}
impl std::error::Error for CircuitBreakerError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
CircuitBreakerError::CircuitOpen(e) => Some(e),
CircuitBreakerError::PrivateHostBlocked(_) => None,
CircuitBreakerError::Request(e) => Some(e),
}
}
}
impl From<reqwest::Error> for CircuitBreakerError {
fn from(e: reqwest::Error) -> Self {
CircuitBreakerError::Request(e)
}
}
/// Convenience HTTP client. Every request goes through a [`CircuitBreakerClient`], and
/// requests without their own timeout get the client's default timeout (when one is set).
#[derive(Clone)]
pub struct HttpClient {
    circuit_breaker: CircuitBreakerClient,
    default_timeout: Option<Duration>,
}

impl HttpClient {
    /// Wraps `circuit_breaker`; `default_timeout` is applied to requests that set none.
    pub fn new(circuit_breaker: CircuitBreakerClient, default_timeout: Option<Duration>) -> Self {
        Self {
            circuit_breaker,
            default_timeout,
        }
    }

    /// The raw reqwest client. Requests executed on it directly skip both the breaker
    /// and the default timeout.
    pub fn inner(&self) -> &reqwest::Client {
        self.circuit_breaker().inner()
    }

    /// The circuit breaker that `execute` routes through.
    pub fn circuit_breaker(&self) -> &CircuitBreakerClient {
        &self.circuit_breaker
    }

    /// Timeout applied to requests that do not specify their own.
    pub fn default_timeout(&self) -> Option<Duration> {
        self.default_timeout
    }

    /// Starts a request. Calling `send` on the returned builder goes through this client.
    pub fn request<U: IntoUrl>(&self, method: Method, url: U) -> HttpRequestBuilder {
        let builder = self.inner().request(method, url);
        HttpRequestBuilder::new(self.clone(), builder)
    }

    /// Shorthand for `request(Method::GET, url)`.
    pub fn get<U: IntoUrl>(&self, url: U) -> HttpRequestBuilder {
        self.request(Method::GET, url)
    }

    /// Shorthand for `request(Method::POST, url)`.
    pub fn post<U: IntoUrl>(&self, url: U) -> HttpRequestBuilder {
        self.request(Method::POST, url)
    }

    /// Shorthand for `request(Method::PUT, url)`.
    pub fn put<U: IntoUrl>(&self, url: U) -> HttpRequestBuilder {
        self.request(Method::PUT, url)
    }

    /// Shorthand for `request(Method::PATCH, url)`.
    pub fn patch<U: IntoUrl>(&self, url: U) -> HttpRequestBuilder {
        self.request(Method::PATCH, url)
    }

    /// Shorthand for `request(Method::DELETE, url)`.
    pub fn delete<U: IntoUrl>(&self, url: U) -> HttpRequestBuilder {
        self.request(Method::DELETE, url)
    }

    /// Shorthand for `request(Method::HEAD, url)`.
    pub fn head<U: IntoUrl>(&self, url: U) -> HttpRequestBuilder {
        self.request(Method::HEAD, url)
    }

    /// Applies the default timeout, then sends `request` through the circuit breaker.
    ///
    /// # Errors
    ///
    /// Any `CircuitBreakerError` from the breaker, converted into the crate's error type.
    pub async fn execute(&self, mut request: Request) -> crate::Result<Response> {
        self.apply_default_timeout(&mut request);
        let outcome = self.circuit_breaker.execute(request).await;
        outcome.map_err(Into::into)
    }

    /// Fills in `default_timeout` only when `request` has no timeout of its own.
    fn apply_default_timeout(&self, request: &mut Request) {
        let slot = request.timeout_mut();
        if slot.is_none() {
            *slot = self.default_timeout;
        }
    }
}
/// Request builder returned by `HttpClient`'s verb methods. It is a thin wrapper over
/// `reqwest::RequestBuilder`, and its `send` goes through the owning `HttpClient`.
pub struct HttpRequestBuilder {
    client: HttpClient,
    request: RequestBuilder,
}

impl HttpRequestBuilder {
    fn new(client: HttpClient, request: RequestBuilder) -> Self {
        Self { client, request }
    }

    /// Applies `step` to the wrapped reqwest builder and keeps the same client.
    fn map_request(self, step: impl FnOnce(RequestBuilder) -> RequestBuilder) -> Self {
        let Self { client, request } = self;
        Self {
            client,
            request: step(request),
        }
    }

    /// Converts the wrapped reqwest builder into a `Request`, mapping failures into the
    /// crate error type.
    fn finish(request: RequestBuilder) -> crate::Result<Request> {
        request
            .build()
            .map_err(|e| crate::ForgeError::internal_with("Failed to build HTTP request", e))
    }

    /// Adds a single header.
    pub fn header(self, key: impl AsRef<str>, value: impl AsRef<str>) -> Self {
        self.map_request(|r| r.header(key.as_ref(), value.as_ref()))
    }

    /// Adds every header in `headers`.
    pub fn headers(self, headers: reqwest::header::HeaderMap) -> Self {
        self.map_request(|r| r.headers(headers))
    }

    /// Sets bearer-token authentication.
    pub fn bearer_auth(self, token: impl std::fmt::Display) -> Self {
        self.map_request(|r| r.bearer_auth(token))
    }

    /// Sets HTTP basic authentication.
    pub fn basic_auth(
        self,
        username: impl std::fmt::Display,
        password: Option<impl std::fmt::Display>,
    ) -> Self {
        self.map_request(|r| r.basic_auth(username, password))
    }

    /// Sets the request body.
    pub fn body(self, body: impl Into<reqwest::Body>) -> Self {
        self.map_request(|r| r.body(body))
    }

    /// Serializes `json` as the JSON request body.
    pub fn json(self, json: &impl serde::Serialize) -> Self {
        self.map_request(|r| r.json(json))
    }

    /// Serializes `form` as a URL-encoded form body.
    pub fn form(self, form: &impl serde::Serialize) -> Self {
        self.map_request(|r| r.form(form))
    }

    /// Appends `query` to the URL's query string.
    pub fn query(self, query: &impl serde::Serialize) -> Self {
        self.map_request(|r| r.query(query))
    }

    /// Sets a per-request timeout, which takes precedence over the client default.
    pub fn timeout(self, timeout: Duration) -> Self {
        self.map_request(|r| r.timeout(timeout))
    }

    /// Sets the HTTP version.
    pub fn version(self, version: reqwest::Version) -> Self {
        self.map_request(|r| r.version(version))
    }

    /// Clones the builder, or returns `None` when the reqwest builder cannot be cloned.
    pub fn try_clone(&self) -> Option<Self> {
        let request = self.request.try_clone()?;
        Some(Self::new(self.client.clone(), request))
    }

    /// Builds the request without sending it.
    ///
    /// # Errors
    ///
    /// Returns an internal error if reqwest fails to build the request.
    pub fn build(self) -> crate::Result<Request> {
        Self::finish(self.request)
    }

    /// Builds the request and sends it through the owning `HttpClient`.
    ///
    /// # Errors
    ///
    /// Build failures, plus any error from `HttpClient::execute`.
    pub async fn send(self) -> crate::Result<Response> {
        let Self { client, request } = self;
        let request = Self::finish(request)?;
        client.execute(request).await
    }
}
// Unit tests for the breaker state machine, the private-IP guards, error formatting and
// `HttpClient` default-timeout handling. No test sends real traffic: the only `execute`
// call is rejected by the private-IP guard before anything is sent.
#[cfg(test)]
// unwrap / indexing / panic lints are relaxed for test code.
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
use super::*;
// Pins the default thresholds and that the breaker is enabled by default.
#[test]
fn test_circuit_breaker_defaults() {
let config = CircuitBreakerConfig::default();
assert_eq!(config.failure_threshold, 5);
assert_eq!(config.success_threshold, 2);
assert!(config.enabled);
}
// Five failures (the default threshold) open the circuit; `reset` forgets the host so
// requests are allowed again.
#[test]
fn test_circuit_state_transitions() {
let client = reqwest::Client::new();
let breaker = CircuitBreakerClient::with_defaults(client);
let host = "https://api.example.com";
assert!(breaker.should_allow(host).is_ok());
for _ in 0..5 {
breaker.record_failure(host);
}
let state = breaker.get_state(host).unwrap();
assert_eq!(state.state, CircuitStatus::Open);
assert!(breaker.should_allow(host).is_err());
breaker.reset(host);
assert!(breaker.should_allow(host).is_ok());
}
// Breaker keys keep an explicit non-default port and have no port when the URL has none.
#[test]
fn test_extract_host() {
let url = reqwest::Url::parse("https://api.example.com:8080/path").unwrap();
assert_eq!(
CircuitBreakerClient::extract_host(&url),
"https://api.example.com:8080"
);
let url2 = reqwest::Url::parse("http://localhost/api").unwrap();
assert_eq!(
CircuitBreakerClient::extract_host(&url2),
"http://localhost"
);
}
// A request without its own timeout picks up the client's default.
#[test]
fn test_http_client_applies_default_timeout_when_missing() {
let breaker = CircuitBreakerClient::with_defaults(reqwest::Client::new());
let client = breaker.with_timeout(Some(Duration::from_secs(5)));
let mut request = reqwest::Request::new(
Method::GET,
reqwest::Url::parse("https://example.com").unwrap(),
);
client.apply_default_timeout(&mut request);
assert_eq!(request.timeout(), Some(&Duration::from_secs(5)));
}
// An explicit per-request timeout wins over the client default.
#[test]
fn test_http_client_preserves_explicit_timeout() {
let breaker = CircuitBreakerClient::with_defaults(reqwest::Client::new());
let client = breaker.with_timeout(Some(Duration::from_secs(5)));
let mut request = reqwest::Request::new(
Method::GET,
reqwest::Url::parse("https://example.com").unwrap(),
);
*request.timeout_mut() = Some(Duration::from_secs(1));
client.apply_default_timeout(&mut request);
assert_eq!(request.timeout(), Some(&Duration::from_secs(1)));
}
// Helper: parse a URL, panicking on malformed test input.
fn url(s: &str) -> reqwest::Url {
reqwest::Url::parse(s).expect("valid url")
}
// Helper: a breaker around a plain reqwest client with the given config.
fn breaker_with(config: CircuitBreakerConfig) -> CircuitBreakerClient {
CircuitBreakerClient::new(reqwest::Client::new(), config)
}
// Loopback and the link-local metadata address 169.254.169.254 are both blocked.
#[test]
fn private_ip_guard_blocks_ipv4_loopback_and_metadata_endpoint() {
assert!(CircuitBreakerClient::url_targets_private_ip(&url(
"http://127.0.0.1/"
)));
assert!(CircuitBreakerClient::url_targets_private_ip(&url(
"http://169.254.169.254/latest/meta-data/"
)));
}
// RFC 1918 private, link-local, unspecified, broadcast, then the three IPv4
// documentation ranges (192.0.2/24, 198.51.100/24, 203.0.113/24).
#[test]
fn private_ip_guard_blocks_all_ipv4_classes_doc_says_it_blocks() {
let blocked = [
"http://10.0.0.1/", "http://172.16.0.1/", "http://192.168.1.1/", "http://169.254.1.1/", "http://0.0.0.0/", "http://255.255.255.255/", "http://192.0.2.1/", "http://198.51.100.1/", "http://203.0.113.1/", ];
for u in blocked {
assert!(
CircuitBreakerClient::url_targets_private_ip(&url(u)),
"should block {u}"
);
}
}
// IPv6 loopback, unspecified, both ends of fe80::/10 (fe80 and febf) and both halves of
// fc00::/7 (fc00 and fd00); bracketed hosts must be unwrapped before parsing.
#[test]
fn private_ip_guard_blocks_ipv6_loopback_link_local_and_ula() {
let blocked = [
"http://[::1]/", "http://[::]/", "http://[fe80::1]/", "http://[febf::1]/", "http://[fc00::1]/", "http://[fd00::1]/", ];
for u in blocked {
assert!(
CircuitBreakerClient::url_targets_private_ip(&url(u)),
"should block {u}"
);
}
}
// Public IPs pass. Hostnames always pass because the URL-level guard does no DNS lookup,
// so even `localhost` is not caught here.
#[test]
fn private_ip_guard_allows_public_ips_and_dns_hostnames() {
let allowed = [
"http://1.1.1.1/",
"http://8.8.8.8/",
"http://[2001:4860:4860::8888]/", "http://api.example.com/",
"http://localhost/",
];
for u in allowed {
assert!(
!CircuitBreakerClient::url_targets_private_ip(&url(u)),
"should NOT block {u}"
);
}
}
// With `allow_private: false` the guard rejects before any I/O and the error carries the
// bare host string (no scheme).
#[tokio::test]
async fn execute_returns_private_host_blocked_error_when_guard_trips() {
let breaker = breaker_with(CircuitBreakerConfig {
allow_private: false,
..Default::default()
});
let req = reqwest::Request::new(Method::GET, url("http://127.0.0.1/"));
let err = breaker.execute(req).await.expect_err("loopback blocked");
match err {
CircuitBreakerError::PrivateHostBlocked(host) => {
assert_eq!(host, "127.0.0.1");
}
other => panic!("expected PrivateHostBlocked, got {other:?}"),
}
}
// Direct `is_private_ip` checks across IPv4 and IPv6 non-public ranges.
#[test]
fn is_private_ip_blocks_all_private_ranges() {
let blocked: Vec<IpAddr> = vec![
"127.0.0.1".parse().unwrap(),
"10.0.0.1".parse().unwrap(),
"172.16.0.1".parse().unwrap(),
"192.168.1.1".parse().unwrap(),
"169.254.169.254".parse().unwrap(),
"0.0.0.0".parse().unwrap(),
"255.255.255.255".parse().unwrap(),
"::1".parse().unwrap(),
"::".parse().unwrap(),
"fe80::1".parse().unwrap(),
"fc00::1".parse().unwrap(),
"fd00::1".parse().unwrap(),
];
for ip in blocked {
assert!(is_private_ip(ip), "should block {ip}");
}
}
// IPv4-mapped IPv6 addresses must be unwrapped and judged as IPv4.
#[test]
fn is_private_ip_blocks_ipv4_mapped_ipv6() {
let mapped: Vec<IpAddr> = vec![
"::ffff:127.0.0.1".parse().unwrap(),
"::ffff:10.0.0.1".parse().unwrap(),
"::ffff:169.254.169.254".parse().unwrap(),
"::ffff:192.168.1.1".parse().unwrap(),
];
for ip in mapped {
assert!(is_private_ip(ip), "should block IPv4-mapped {ip}");
}
}
// Ordinary public IPv4 and IPv6 addresses are allowed.
#[test]
fn is_private_ip_allows_public_addresses() {
let allowed: Vec<IpAddr> = vec![
"1.1.1.1".parse().unwrap(),
"8.8.8.8".parse().unwrap(),
"93.184.216.34".parse().unwrap(),
"2001:4860:4860::8888".parse().unwrap(),
];
for ip in allowed {
assert!(!is_private_ip(ip), "should allow {ip}");
}
}
// A success while open moves the circuit to half-open with success_count = 1. That is
// below the default success_threshold of 2, so it stays half-open.
#[test]
fn success_in_half_open_below_threshold_keeps_circuit_half_open() {
let breaker = CircuitBreakerClient::with_defaults(reqwest::Client::new());
let host = "https://flaky.example.com";
for _ in 0..5 {
breaker.record_failure(host);
}
assert_eq!(breaker.get_state(host).unwrap().state, CircuitStatus::Open);
breaker.record_success(host);
let s = breaker.get_state(host).unwrap();
assert_eq!(s.state, CircuitStatus::HalfOpen);
assert_eq!(s.success_count, 1);
}
// Two successes (open → half-open, then half-open → threshold) fully close the circuit
// and clear counters and opened_at.
#[test]
fn second_success_in_half_open_closes_circuit_and_resets_counters() {
let breaker = CircuitBreakerClient::with_defaults(reqwest::Client::new());
let host = "https://flaky2.example.com";
for _ in 0..5 {
breaker.record_failure(host);
}
breaker.record_success(host); breaker.record_success(host);
let s = breaker.get_state(host).unwrap();
assert_eq!(s.state, CircuitStatus::Closed);
assert_eq!(s.failure_count, 0);
assert_eq!(s.success_count, 0);
assert!(
s.opened_at.is_none(),
"opened_at must clear on full recovery"
);
}
// A failure while half-open reopens the circuit and multiplies the backoff. The expected
// value is derived from the stored initial backoff, so only the multiplier is pinned.
#[test]
fn failure_in_half_open_reopens_with_exponential_backoff() {
let breaker = breaker_with(CircuitBreakerConfig {
failure_threshold: 3,
success_threshold: 2,
base_timeout: Duration::from_secs(10),
max_backoff: Duration::from_secs(600),
backoff_multiplier: 2.0,
enabled: true,
allow_private: true,
});
let host = "https://still-down.example.com";
for _ in 0..3 {
breaker.record_failure(host);
}
let initial_backoff = breaker.get_state(host).unwrap().current_backoff;
breaker.record_success(host); breaker.record_failure(host);
let s = breaker.get_state(host).unwrap();
assert_eq!(s.state, CircuitStatus::Open);
assert_eq!(s.success_count, 0, "success_count must reset on reopen");
let expected = Duration::from_secs_f64(initial_backoff.as_secs_f64() * 2.0);
assert_eq!(
s.current_backoff, expected,
"backoff must scale by multiplier on reopen"
);
}
// 30 s × 10 would be 300 s, but it must be capped at max_backoff (45 s). The success only
// moves open → half-open (that step does not check success_threshold), so the next
// failure takes the half-open reopen path.
#[test]
fn failure_in_half_open_caps_backoff_at_max() {
let breaker = breaker_with(CircuitBreakerConfig {
failure_threshold: 1,
success_threshold: 1,
base_timeout: Duration::from_secs(30),
max_backoff: Duration::from_secs(45),
backoff_multiplier: 10.0,
enabled: true,
allow_private: true,
});
let host = "https://capped.example.com";
breaker.record_failure(host); breaker.record_success(host); breaker.record_failure(host);
let s = breaker.get_state(host).unwrap();
assert_eq!(s.current_backoff, Duration::from_secs(45));
}
// A failure while already open keeps the state and backoff; opened_at may only move
// forward. The short blocking sleep is fine because this is a synchronous test.
#[test]
fn record_failure_while_open_just_refreshes_opened_at_without_changing_state() {
let breaker = CircuitBreakerClient::with_defaults(reqwest::Client::new());
let host = "https://still-open.example.com";
for _ in 0..5 {
breaker.record_failure(host);
}
let before = breaker.get_state(host).unwrap();
assert_eq!(before.state, CircuitStatus::Open);
std::thread::sleep(Duration::from_millis(2));
breaker.record_failure(host);
let after = breaker.get_state(host).unwrap();
assert_eq!(after.state, CircuitStatus::Open);
assert!(
after.opened_at.unwrap() >= before.opened_at.unwrap(),
"opened_at should be refreshed or unchanged, not regressed"
);
assert_eq!(after.current_backoff, before.current_backoff);
}
// With `enabled: false`, record_* never create state and should_allow never blocks.
#[test]
fn disabled_breaker_never_blocks_and_never_records_state() {
let breaker = breaker_with(CircuitBreakerConfig {
enabled: false,
..Default::default()
});
let host = "https://noop.example.com";
for _ in 0..100 {
breaker.record_failure(host);
}
assert!(breaker.get_state(host).is_none());
assert!(breaker.should_allow(host).is_ok());
breaker.record_success(host);
assert!(breaker.get_state(host).is_none());
}
// reset_all drops every host's entry.
#[test]
fn reset_all_clears_state_for_every_host() {
let breaker = CircuitBreakerClient::with_defaults(reqwest::Client::new());
breaker.record_failure("https://a.example.com");
breaker.record_failure("https://b.example.com");
breaker.record_failure("https://c.example.com");
assert!(breaker.get_state("https://a.example.com").is_some());
breaker.reset_all();
assert!(breaker.get_state("https://a.example.com").is_none());
assert!(breaker.get_state("https://b.example.com").is_none());
assert!(breaker.get_state("https://c.example.com").is_none());
}
// Instead of sleeping, back-date opened_at through the state map so the backoff has
// clearly elapsed.
#[test]
fn should_allow_returns_ok_when_open_timeout_has_elapsed() {
let breaker = breaker_with(CircuitBreakerConfig {
failure_threshold: 1,
base_timeout: Duration::from_millis(10),
..Default::default()
});
let host = "https://ready.example.com";
breaker.record_failure(host);
{
let mut states = breaker.states.write().unwrap();
let s = states.get_mut(host).unwrap();
s.opened_at = Some(Instant::now() - Duration::from_secs(3600));
s.current_backoff = Duration::from_millis(10);
}
assert!(
breaker.should_allow(host).is_ok(),
"expired open circuit must allow the next request through"
);
}
// While open and inside the backoff, retry_after is positive and no larger than the
// stored backoff.
#[test]
fn should_allow_reports_retry_after_when_open_and_within_backoff() {
let breaker = breaker_with(CircuitBreakerConfig {
failure_threshold: 1,
base_timeout: Duration::from_secs(60),
..Default::default()
});
let host = "https://hot.example.com";
breaker.record_failure(host);
let err = breaker.should_allow(host).expect_err("still open");
assert_eq!(err.host, host);
let backoff = breaker.get_state(host).unwrap().current_backoff;
assert!(err.retry_after > Duration::ZERO);
assert!(err.retry_after <= backoff);
}
// Keys omit the port when the URL has none and keep an explicit non-default port.
#[test]
fn extract_host_handles_default_ports_and_no_port() {
assert_eq!(
CircuitBreakerClient::extract_host(&url("https://api.example.com/")),
"https://api.example.com"
);
assert_eq!(
CircuitBreakerClient::extract_host(&url("http://api.example.com/")),
"http://api.example.com"
);
assert_eq!(
CircuitBreakerClient::extract_host(&url("https://api.example.com:8443/")),
"https://api.example.com:8443"
);
}
// IPv6 hosts keep the address and port in the key.
#[test]
fn extract_host_includes_ipv6_brackets() {
let h = CircuitBreakerClient::extract_host(&url("http://[::1]:8080/"));
assert!(h.contains("::1"), "got: {h}");
assert!(h.ends_with(":8080"), "got: {h}");
}
// Display output mentions both the host and the retry delay.
#[test]
fn circuit_breaker_open_display_mentions_host_and_retry_after() {
let err = CircuitBreakerOpen {
host: "https://flaky.example.com".to_string(),
retry_after: Duration::from_secs(42),
};
let s = err.to_string();
assert!(s.contains("https://flaky.example.com"));
assert!(s.contains("42"));
}
// The blocked host must not appear in user-facing Display output.
#[test]
fn private_host_blocked_display_redacts_host() {
let err = CircuitBreakerError::PrivateHostBlocked("127.0.0.1".to_string());
let s = err.to_string();
assert!(
!s.contains("127.0.0.1"),
"host must not leak through Display"
);
assert!(s.contains("private IP"));
}
// CircuitOpen exposes its wrapped error as source; PrivateHostBlocked has none.
#[test]
fn circuit_breaker_error_source_chains_through_inner_variants() {
let inner = CircuitBreakerOpen {
host: "h".to_string(),
retry_after: Duration::from_secs(1),
};
let err = CircuitBreakerError::CircuitOpen(inner);
assert!(
std::error::Error::source(&err).is_some(),
"CircuitOpen should expose its wrapped error as source"
);
let err = CircuitBreakerError::PrivateHostBlocked("h".to_string());
assert!(
std::error::Error::source(&err).is_none(),
"PrivateHostBlocked has no source"
);
}
// With no default configured, a request without a timeout stays without one.
#[test]
fn http_client_apply_default_timeout_is_noop_when_default_unset() {
let breaker = CircuitBreakerClient::with_defaults(reqwest::Client::new());
let client = breaker.with_timeout(None);
let mut req = reqwest::Request::new(Method::GET, url("https://example.com/"));
client.apply_default_timeout(&mut req);
assert_eq!(req.timeout(), None);
}
// Smoke test for the accessor methods.
#[test]
fn http_client_accessors_expose_underlying_pieces() {
let breaker = CircuitBreakerClient::with_defaults(reqwest::Client::new());
let client = breaker.with_timeout(Some(Duration::from_secs(7)));
assert_eq!(client.default_timeout(), Some(Duration::from_secs(7)));
let _ = client.inner();
let _ = client.circuit_breaker();
}
}