1use std::collections::HashMap;
7use std::sync::RwLock;
8use std::time::{Duration, Instant};
9
10use reqwest::{IntoUrl, Method, Request, RequestBuilder, Response};
11use std::net::{IpAddr, SocketAddr};
12
/// Snapshot of one origin's circuit-breaker bookkeeping.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct CircuitState {
    /// Which phase of the breaker this origin is in.
    pub state: CircuitStatus,
    /// Failures counted while `Closed`; any success in `Closed` resets it.
    pub failure_count: u32,
    /// Successes counted while probing in `HalfOpen`.
    pub success_count: u32,
    /// When the circuit last opened (or last saw a failure while `Open`).
    pub opened_at: Option<Instant>,
    /// How long an `Open` circuit rejects requests, measured from `opened_at`.
    pub current_backoff: Duration,
}

/// Phase of a circuit breaker.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum CircuitStatus {
    /// Healthy: requests flow and failures are counted.
    Closed,
    /// Tripped: requests are rejected until the backoff elapses.
    Open,
    /// Probing: requests flow; enough successes close it, a failure reopens it.
    HalfOpen,
}

impl Default for CircuitState {
    /// A healthy circuit with no history and a 30-second backoff.
    fn default() -> Self {
        CircuitState {
            current_backoff: Duration::from_secs(30),
            opened_at: None,
            success_count: 0,
            failure_count: 0,
            state: CircuitStatus::Closed,
        }
    }
}
48
/// Tuning knobs for [`CircuitBreakerClient`].
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Consecutive failures (while `Closed`) that trip the circuit `Open`.
    pub failure_threshold: u32,
    /// Successes (while `HalfOpen`) required to close the circuit again.
    pub success_threshold: u32,
    /// Backoff that `current_backoff` is reset to when a circuit fully recovers.
    pub base_timeout: Duration,
    /// Upper bound for the backoff after repeated failed probes.
    pub max_backoff: Duration,
    /// Factor applied to the backoff each time a `HalfOpen` probe fails.
    pub backoff_multiplier: f64,
    /// When `false`, no state is tracked and nothing is ever rejected.
    pub enabled: bool,
    /// When `false`, requests whose URL host is a private IP literal are refused.
    pub allow_private: bool,
}

impl Default for CircuitBreakerConfig {
    /// Trip after 5 failures and close after 2 successes.
    /// The backoff starts at 30 seconds and is capped at 10 minutes, growing 1.5x.
    fn default() -> Self {
        let base_timeout = Duration::from_secs(30);
        let max_backoff = Duration::from_secs(10 * 60);
        CircuitBreakerConfig {
            enabled: true,
            allow_private: false,
            failure_threshold: 5,
            success_threshold: 2,
            backoff_multiplier: 1.5,
            base_timeout,
            max_backoff,
        }
    }
}
77
/// Error returned when a request is rejected because its origin's circuit is open.
#[derive(Debug, Clone)]
pub struct CircuitBreakerOpen {
    /// Origin key (`scheme://host[:port]`) whose circuit is open.
    pub host: String,
    /// Time remaining until the circuit will let a probe request through.
    pub retry_after: Duration,
}

impl std::fmt::Display for CircuitBreakerOpen {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let CircuitBreakerOpen { host, retry_after } = self;
        write!(f, "Circuit breaker open for {host}: retry after {retry_after:?}")
    }
}

impl std::error::Error for CircuitBreakerOpen {}
96
/// Returns `true` if `ip` lies in address space that outbound requests must not
/// reach: loopback, private, link-local, or otherwise non-public ranges.
///
/// Some IPv6 addresses embed an IPv4 address: IPv4-mapped (`::ffff:a.b.c.d`)
/// and the NAT64 well-known prefix (`64:ff9b::a.b.c.d`, RFC 6052). Such an
/// address is judged by the IPv4 address it carries. This prevents an internal
/// IPv4 target from being smuggled in via an IPv6 literal or AAAA record.
///
/// Other IPv6 addresses are blocked when they are:
/// - loopback `::1` or unspecified `::`;
/// - link-local `fe80::/10`;
/// - deprecated site-local `fec0::/10`;
/// - unique-local `fc00::/7`.
pub fn is_private_ip(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => is_private_v4(v4),
        IpAddr::V6(v6) => {
            if let Some(v4) = v6.to_ipv4_mapped() {
                return is_private_v4(v4);
            }
            let [s0, s1, s2, s3, s4, s5, s6, s7] = v6.segments();
            // NAT64 well-known prefix 64:ff9b::/96. The low 32 bits are an
            // IPv4 address that a NAT64 gateway would translate to and connect to.
            if [s0, s1, s2, s3, s4, s5] == [0x0064, 0xff9b, 0, 0, 0, 0] {
                let embedded = (u32::from(s6) << 16) | u32::from(s7);
                return is_private_v4(std::net::Ipv4Addr::from(embedded));
            }
            v6.is_loopback()
                || v6.is_unspecified()
                // link-local fe80::/10
                || (s0 & 0xffc0) == 0xfe80
                // site-local fec0::/10 (deprecated, but still non-public)
                || (s0 & 0xffc0) == 0xfec0
                // unique-local fc00::/7
                || (s0 & 0xfe00) == 0xfc00
        }
    }
}

/// Returns `true` for IPv4 addresses in non-public ranges:
/// - loopback `127/8`;
/// - RFC 1918 private ranges;
/// - link-local `169.254/16`, which includes the cloud metadata IP;
/// - broadcast;
/// - documentation ranges;
/// - "this network" `0/8`;
/// - the RFC 6598 shared address space `100.64/10`.
fn is_private_v4(v4: std::net::Ipv4Addr) -> bool {
    let [a, b, _, _] = v4.octets();
    v4.is_loopback()
        || v4.is_private()
        || v4.is_link_local()
        || v4.is_broadcast()
        || v4.is_documentation()
        // 0.0.0.0/8 "this network"; also covers the unspecified 0.0.0.0.
        || a == 0
        // 100.64.0.0/10 shared address space (carrier-grade NAT). It also hosts
        // cloud-internal services such as Alibaba Cloud's metadata endpoint
        // at 100.100.100.200.
        || (a == 100 && (b & 0xc0) == 0x40)
}
125
126struct SsrfSafeResolver;
130
131impl reqwest::dns::Resolve for SsrfSafeResolver {
132 fn resolve(&self, name: reqwest::dns::Name) -> reqwest::dns::Resolving {
133 Box::pin(async move {
134 let host = name.as_str().to_string();
135 let addrs: Vec<SocketAddr> = tokio::net::lookup_host(format!("{host}:0"))
136 .await?
137 .collect();
138 let safe: Vec<SocketAddr> = addrs
139 .into_iter()
140 .filter(|addr| !is_private_ip(addr.ip()))
141 .collect();
142 if safe.is_empty() {
143 return Err(format!("DNS resolution for {host} returned only private IPs").into());
144 }
145 let addrs: reqwest::dns::Addrs = Box::new(safe.into_iter());
146 Ok(addrs)
147 })
148 }
149}
150
151pub fn build_ssrf_safe_client() -> reqwest::Client {
158 reqwest::Client::builder()
159 .dns_resolver(std::sync::Arc::new(SsrfSafeResolver))
160 .build()
161 .unwrap_or_else(|e| {
162 tracing::error!("Failed to build SSRF-safe HTTP client: {e}");
163 unreachable!("TLS backend required for HTTP client")
167 })
168}
169
170#[derive(Clone)]
174pub struct CircuitBreakerClient {
175 inner: reqwest::Client,
176 states: std::sync::Arc<RwLock<HashMap<String, CircuitState>>>,
177 config: CircuitBreakerConfig,
178}
179
180impl CircuitBreakerClient {
181 pub fn new(client: reqwest::Client, config: CircuitBreakerConfig) -> Self {
183 Self {
184 inner: client,
185 states: std::sync::Arc::new(RwLock::new(HashMap::new())),
186 config,
187 }
188 }
189
190 pub fn with_defaults(client: reqwest::Client) -> Self {
192 Self::new(client, CircuitBreakerConfig::default())
193 }
194
195 pub fn with_ssrf_protection() -> Self {
198 Self::new(build_ssrf_safe_client(), CircuitBreakerConfig::default())
199 }
200
201 pub fn inner(&self) -> &reqwest::Client {
203 &self.inner
204 }
205
206 pub fn with_timeout(&self, timeout: Option<Duration>) -> HttpClient {
208 HttpClient::new(self.clone(), timeout)
209 }
210
211 fn extract_host(url: &reqwest::Url) -> String {
213 format!(
214 "{}://{}{}",
215 url.scheme(),
216 url.host_str().unwrap_or("unknown"),
217 url.port().map(|p| format!(":{}", p)).unwrap_or_default()
218 )
219 }
220
221 fn url_targets_private_ip(url: &reqwest::Url) -> bool {
225 let Some(host) = url.host_str() else {
226 return false;
227 };
228 let trimmed = host.trim_start_matches('[').trim_end_matches(']');
229 let Ok(ip) = trimmed.parse::<IpAddr>() else {
230 return false;
231 };
232 is_private_ip(ip)
233 }
234
235 pub fn should_allow(&self, host: &str) -> Result<(), CircuitBreakerOpen> {
237 if !self.config.enabled {
238 return Ok(());
239 }
240
241 let states = self.states.read().unwrap_or_else(|e| {
242 tracing::error!("Circuit breaker lock was poisoned, recovering");
243 e.into_inner()
244 });
245 let state = match states.get(host) {
246 Some(s) => s,
247 None => return Ok(()), };
249
250 match state.state {
251 CircuitStatus::Closed => Ok(()),
252 CircuitStatus::HalfOpen => Ok(()), CircuitStatus::Open => {
254 let opened_at = state.opened_at.unwrap_or_else(Instant::now);
255 let elapsed = opened_at.elapsed();
256
257 if elapsed >= state.current_backoff {
258 Ok(())
260 } else {
261 Err(CircuitBreakerOpen {
262 host: host.to_string(),
263 retry_after: state.current_backoff - elapsed,
264 })
265 }
266 }
267 }
268 }
269
270 pub fn record_success(&self, host: &str) {
272 if !self.config.enabled {
273 return;
274 }
275
276 let mut states = self.states.write().unwrap_or_else(|e| {
277 tracing::error!("Circuit breaker lock was poisoned, recovering");
278 e.into_inner()
279 });
280 let state = states.entry(host.to_string()).or_default();
281
282 match state.state {
283 CircuitStatus::Closed => {
284 state.failure_count = 0;
286 }
287 CircuitStatus::HalfOpen => {
288 state.success_count += 1;
289 if state.success_count >= self.config.success_threshold {
290 tracing::info!(host = %host, "Circuit breaker closed, service recovered");
292 state.state = CircuitStatus::Closed;
293 state.failure_count = 0;
294 state.success_count = 0;
295 state.opened_at = None;
296 state.current_backoff = self.config.base_timeout;
297 }
298 }
299 CircuitStatus::Open => {
300 tracing::info!(host = %host, "Circuit breaker half-open, testing service");
302 state.state = CircuitStatus::HalfOpen;
303 state.success_count = 1;
304 }
305 }
306 }
307
308 pub fn record_failure(&self, host: &str) {
310 if !self.config.enabled {
311 return;
312 }
313
314 let mut states = self.states.write().unwrap_or_else(|e| {
315 tracing::error!("Circuit breaker lock was poisoned, recovering");
316 e.into_inner()
317 });
318 let state = states.entry(host.to_string()).or_default();
319
320 match state.state {
321 CircuitStatus::Closed => {
322 state.failure_count += 1;
323 if state.failure_count >= self.config.failure_threshold {
324 tracing::warn!(
326 host = %host,
327 failures = state.failure_count,
328 "Circuit breaker opened, service unhealthy"
329 );
330 state.state = CircuitStatus::Open;
331 state.opened_at = Some(Instant::now());
332 }
333 }
334 CircuitStatus::HalfOpen => {
335 let new_backoff = Duration::from_secs_f64(
337 (state.current_backoff.as_secs_f64() * self.config.backoff_multiplier)
338 .min(self.config.max_backoff.as_secs_f64()),
339 );
340 tracing::warn!(
341 host = %host,
342 backoff_secs = new_backoff.as_secs(),
343 "Circuit breaker reopened, service still unhealthy"
344 );
345 state.state = CircuitStatus::Open;
346 state.opened_at = Some(Instant::now());
347 state.current_backoff = new_backoff;
348 state.success_count = 0;
349 }
350 CircuitStatus::Open => {
351 state.opened_at = Some(Instant::now());
353 }
354 }
355 }
356
357 pub async fn execute(&self, request: Request) -> Result<Response, CircuitBreakerError> {
359 if !self.config.allow_private && Self::url_targets_private_ip(request.url()) {
362 return Err(CircuitBreakerError::PrivateHostBlocked(
363 request.url().host_str().unwrap_or("unknown").to_string(),
364 ));
365 }
366
367 let host = Self::extract_host(request.url());
368
369 self.should_allow(&host)
371 .map_err(CircuitBreakerError::CircuitOpen)?;
372
373 {
375 let mut states = self.states.write().unwrap_or_else(|e| {
376 tracing::error!("Circuit breaker lock was poisoned, recovering");
377 e.into_inner()
378 });
379 if let Some(state) = states.get_mut(&host)
380 && state.state == CircuitStatus::Open
381 && let Some(opened_at) = state.opened_at
382 && opened_at.elapsed() >= state.current_backoff
383 {
384 tracing::info!(host = %host, "Circuit breaker half-open, testing service");
385 state.state = CircuitStatus::HalfOpen;
386 state.success_count = 0;
387 }
388 }
389
390 match self.inner.execute(request).await {
392 Ok(response) => {
393 if response.status().is_server_error() {
395 self.record_failure(&host);
396 } else {
397 self.record_success(&host);
398 }
399 Ok(response)
400 }
401 Err(e) => {
402 self.record_failure(&host);
403 Err(CircuitBreakerError::Request(e))
404 }
405 }
406 }
407
408 pub fn get_state(&self, host: &str) -> Option<CircuitState> {
410 self.states
411 .read()
412 .unwrap_or_else(|e| {
413 tracing::error!("Circuit breaker lock was poisoned, recovering");
414 e.into_inner()
415 })
416 .get(host)
417 .cloned()
418 }
419
420 pub fn reset(&self, host: &str) {
422 self.states
423 .write()
424 .unwrap_or_else(|e| {
425 tracing::error!("Circuit breaker lock was poisoned, recovering");
426 e.into_inner()
427 })
428 .remove(host);
429 }
430
431 pub fn reset_all(&self) {
433 self.states
434 .write()
435 .unwrap_or_else(|e| {
436 tracing::error!("Circuit breaker lock was poisoned, recovering");
437 e.into_inner()
438 })
439 .clear();
440 }
441}
442
443#[derive(Debug)]
445pub enum CircuitBreakerError {
446 CircuitOpen(CircuitBreakerOpen),
448 PrivateHostBlocked(String),
451 Request(reqwest::Error),
453}
454
455impl std::fmt::Display for CircuitBreakerError {
456 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
457 match self {
458 CircuitBreakerError::CircuitOpen(e) => write!(f, "{}", e),
459 CircuitBreakerError::PrivateHostBlocked(_host) => write!(
460 f,
461 "Outbound request blocked: target resolves to a private IP"
462 ),
463 CircuitBreakerError::Request(e) => write!(f, "HTTP request failed: {}", e),
464 }
465 }
466}
467
468impl std::error::Error for CircuitBreakerError {
469 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
470 match self {
471 CircuitBreakerError::CircuitOpen(e) => Some(e),
472 CircuitBreakerError::PrivateHostBlocked(_) => None,
473 CircuitBreakerError::Request(e) => Some(e),
474 }
475 }
476}
477
478impl From<reqwest::Error> for CircuitBreakerError {
479 fn from(e: reqwest::Error) -> Self {
480 CircuitBreakerError::Request(e)
481 }
482}
483
484#[derive(Clone)]
487pub struct HttpClient {
488 circuit_breaker: CircuitBreakerClient,
489 default_timeout: Option<Duration>,
490}
491
492impl HttpClient {
493 pub fn new(circuit_breaker: CircuitBreakerClient, default_timeout: Option<Duration>) -> Self {
495 Self {
496 circuit_breaker,
497 default_timeout,
498 }
499 }
500
501 pub fn inner(&self) -> &reqwest::Client {
503 self.circuit_breaker.inner()
504 }
505
506 pub fn circuit_breaker(&self) -> &CircuitBreakerClient {
508 &self.circuit_breaker
509 }
510
511 pub fn default_timeout(&self) -> Option<Duration> {
513 self.default_timeout
514 }
515
516 pub fn request<U: IntoUrl>(&self, method: Method, url: U) -> HttpRequestBuilder {
518 HttpRequestBuilder::new(self.clone(), self.inner().request(method, url))
519 }
520
521 pub fn get<U: IntoUrl>(&self, url: U) -> HttpRequestBuilder {
522 self.request(Method::GET, url)
523 }
524
525 pub fn post<U: IntoUrl>(&self, url: U) -> HttpRequestBuilder {
526 self.request(Method::POST, url)
527 }
528
529 pub fn put<U: IntoUrl>(&self, url: U) -> HttpRequestBuilder {
530 self.request(Method::PUT, url)
531 }
532
533 pub fn patch<U: IntoUrl>(&self, url: U) -> HttpRequestBuilder {
534 self.request(Method::PATCH, url)
535 }
536
537 pub fn delete<U: IntoUrl>(&self, url: U) -> HttpRequestBuilder {
538 self.request(Method::DELETE, url)
539 }
540
541 pub fn head<U: IntoUrl>(&self, url: U) -> HttpRequestBuilder {
542 self.request(Method::HEAD, url)
543 }
544
545 pub async fn execute(&self, mut request: Request) -> crate::Result<Response> {
547 self.apply_default_timeout(&mut request);
548 self.circuit_breaker
549 .execute(request)
550 .await
551 .map_err(Into::into)
552 }
553
554 fn apply_default_timeout(&self, request: &mut Request) {
555 if request.timeout().is_none()
556 && let Some(timeout) = self.default_timeout
557 {
558 *request.timeout_mut() = Some(timeout);
559 }
560 }
561}
562
563pub struct HttpRequestBuilder {
565 client: HttpClient,
566 request: RequestBuilder,
567}
568
569impl HttpRequestBuilder {
570 fn new(client: HttpClient, request: RequestBuilder) -> Self {
571 Self { client, request }
572 }
573
574 pub fn header(self, key: impl AsRef<str>, value: impl AsRef<str>) -> Self {
575 Self {
576 request: self.request.header(key.as_ref(), value.as_ref()),
577 ..self
578 }
579 }
580
581 pub fn headers(self, headers: reqwest::header::HeaderMap) -> Self {
582 Self {
583 request: self.request.headers(headers),
584 ..self
585 }
586 }
587
588 pub fn bearer_auth(self, token: impl std::fmt::Display) -> Self {
589 Self {
590 request: self.request.bearer_auth(token),
591 ..self
592 }
593 }
594
595 pub fn basic_auth(
596 self,
597 username: impl std::fmt::Display,
598 password: Option<impl std::fmt::Display>,
599 ) -> Self {
600 Self {
601 request: self.request.basic_auth(username, password),
602 ..self
603 }
604 }
605
606 pub fn body(self, body: impl Into<reqwest::Body>) -> Self {
607 Self {
608 request: self.request.body(body),
609 ..self
610 }
611 }
612
613 pub fn json(self, json: &impl serde::Serialize) -> Self {
614 Self {
615 request: self.request.json(json),
616 ..self
617 }
618 }
619
620 pub fn form(self, form: &impl serde::Serialize) -> Self {
621 Self {
622 request: self.request.form(form),
623 ..self
624 }
625 }
626
627 pub fn query(self, query: &impl serde::Serialize) -> Self {
628 Self {
629 request: self.request.query(query),
630 ..self
631 }
632 }
633
634 pub fn timeout(self, timeout: Duration) -> Self {
635 Self {
636 request: self.request.timeout(timeout),
637 ..self
638 }
639 }
640
641 pub fn version(self, version: reqwest::Version) -> Self {
642 Self {
643 request: self.request.version(version),
644 ..self
645 }
646 }
647
648 pub fn try_clone(&self) -> Option<Self> {
649 self.request.try_clone().map(|request| Self {
650 client: self.client.clone(),
651 request,
652 })
653 }
654
655 pub fn build(self) -> crate::Result<Request> {
656 self.request
657 .build()
658 .map_err(|e| crate::ForgeError::internal_with("Failed to build HTTP request", e))
659 }
660
661 pub async fn send(self) -> crate::Result<Response> {
662 let client = self.client.clone();
663 let request = self.build()?;
664 client.execute(request).await
665 }
666}
667
// Unit tests covering the following:
// - breaker state transitions;
// - origin-key extraction;
// - the private-IP guards;
// - error formatting;
// - HttpClient's default-timeout handling.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    #[test]
    fn test_circuit_breaker_defaults() {
        let config = CircuitBreakerConfig::default();
        assert_eq!(config.failure_threshold, 5);
        assert_eq!(config.success_threshold, 2);
        assert!(config.enabled);
    }

    // Closed -> Open after `failure_threshold` failures, then reset() forgets the host.
    #[test]
    fn test_circuit_state_transitions() {
        let client = reqwest::Client::new();
        let breaker = CircuitBreakerClient::with_defaults(client);
        let host = "https://api.example.com";

        assert!(breaker.should_allow(host).is_ok());

        for _ in 0..5 {
            breaker.record_failure(host);
        }

        let state = breaker.get_state(host).unwrap();
        assert_eq!(state.state, CircuitStatus::Open);

        assert!(breaker.should_allow(host).is_err());

        breaker.reset(host);
        assert!(breaker.should_allow(host).is_ok());
    }

    #[test]
    fn test_extract_host() {
        let url = reqwest::Url::parse("https://api.example.com:8080/path").unwrap();
        assert_eq!(
            CircuitBreakerClient::extract_host(&url),
            "https://api.example.com:8080"
        );

        let url2 = reqwest::Url::parse("http://localhost/api").unwrap();
        assert_eq!(
            CircuitBreakerClient::extract_host(&url2),
            "http://localhost"
        );
    }

    #[test]
    fn test_http_client_applies_default_timeout_when_missing() {
        let breaker = CircuitBreakerClient::with_defaults(reqwest::Client::new());
        let client = breaker.with_timeout(Some(Duration::from_secs(5)));
        let mut request = reqwest::Request::new(
            Method::GET,
            reqwest::Url::parse("https://example.com").unwrap(),
        );

        client.apply_default_timeout(&mut request);

        assert_eq!(request.timeout(), Some(&Duration::from_secs(5)));
    }

    #[test]
    fn test_http_client_preserves_explicit_timeout() {
        let breaker = CircuitBreakerClient::with_defaults(reqwest::Client::new());
        let client = breaker.with_timeout(Some(Duration::from_secs(5)));
        let mut request = reqwest::Request::new(
            Method::GET,
            reqwest::Url::parse("https://example.com").unwrap(),
        );
        *request.timeout_mut() = Some(Duration::from_secs(1));

        client.apply_default_timeout(&mut request);

        assert_eq!(request.timeout(), Some(&Duration::from_secs(1)));
    }

    // Test helper: parse a URL literal, panicking on malformed input.
    fn url(s: &str) -> reqwest::Url {
        reqwest::Url::parse(s).expect("valid url")
    }

    // Test helper: breaker over a plain reqwest client with a custom config.
    fn breaker_with(config: CircuitBreakerConfig) -> CircuitBreakerClient {
        CircuitBreakerClient::new(reqwest::Client::new(), config)
    }

    #[test]
    fn private_ip_guard_blocks_ipv4_loopback_and_metadata_endpoint() {
        assert!(CircuitBreakerClient::url_targets_private_ip(&url(
            "http://127.0.0.1/"
        )));
        assert!(CircuitBreakerClient::url_targets_private_ip(&url(
            "http://169.254.169.254/latest/meta-data/"
        )));
    }

    #[test]
    fn private_ip_guard_blocks_all_ipv4_classes_doc_says_it_blocks() {
        let blocked = [
            "http://10.0.0.1/",
            "http://172.16.0.1/",
            "http://192.168.1.1/",
            "http://169.254.1.1/",
            "http://0.0.0.0/",
            "http://255.255.255.255/",
            "http://192.0.2.1/",
            "http://198.51.100.1/",
            "http://203.0.113.1/",
        ];
        for u in blocked {
            assert!(
                CircuitBreakerClient::url_targets_private_ip(&url(u)),
                "should block {u}"
            );
        }
    }

    #[test]
    fn private_ip_guard_blocks_ipv6_loopback_link_local_and_ula() {
        let blocked = [
            "http://[::1]/",
            "http://[::]/",
            "http://[fe80::1]/",
            "http://[febf::1]/",
            "http://[fc00::1]/",
            "http://[fd00::1]/",
        ];
        for u in blocked {
            assert!(
                CircuitBreakerClient::url_targets_private_ip(&url(u)),
                "should block {u}"
            );
        }
    }

    // Hostnames (even `localhost`) are not judged by the URL guard; only IP
    // literals are.
    #[test]
    fn private_ip_guard_allows_public_ips_and_dns_hostnames() {
        let allowed = [
            "http://1.1.1.1/",
            "http://8.8.8.8/",
            "http://[2001:4860:4860::8888]/",
            "http://api.example.com/",
            "http://localhost/",
        ];
        for u in allowed {
            assert!(
                !CircuitBreakerClient::url_targets_private_ip(&url(u)),
                "should NOT block {u}"
            );
        }
    }

    #[tokio::test]
    async fn execute_returns_private_host_blocked_error_when_guard_trips() {
        let breaker = breaker_with(CircuitBreakerConfig {
            allow_private: false,
            ..Default::default()
        });
        let req = reqwest::Request::new(Method::GET, url("http://127.0.0.1/"));
        let err = breaker.execute(req).await.expect_err("loopback blocked");
        match err {
            CircuitBreakerError::PrivateHostBlocked(host) => {
                assert_eq!(host, "127.0.0.1");
            }
            other => panic!("expected PrivateHostBlocked, got {other:?}"),
        }
    }

    #[test]
    fn is_private_ip_blocks_all_private_ranges() {
        let blocked: Vec<IpAddr> = vec![
            "127.0.0.1".parse().unwrap(),
            "10.0.0.1".parse().unwrap(),
            "172.16.0.1".parse().unwrap(),
            "192.168.1.1".parse().unwrap(),
            "169.254.169.254".parse().unwrap(),
            "0.0.0.0".parse().unwrap(),
            "255.255.255.255".parse().unwrap(),
            "::1".parse().unwrap(),
            "::".parse().unwrap(),
            "fe80::1".parse().unwrap(),
            "fc00::1".parse().unwrap(),
            "fd00::1".parse().unwrap(),
        ];
        for ip in blocked {
            assert!(is_private_ip(ip), "should block {ip}");
        }
    }

    #[test]
    fn is_private_ip_blocks_ipv4_mapped_ipv6() {
        let mapped: Vec<IpAddr> = vec![
            "::ffff:127.0.0.1".parse().unwrap(),
            "::ffff:10.0.0.1".parse().unwrap(),
            "::ffff:169.254.169.254".parse().unwrap(),
            "::ffff:192.168.1.1".parse().unwrap(),
        ];
        for ip in mapped {
            assert!(is_private_ip(ip), "should block IPv4-mapped {ip}");
        }
    }

    #[test]
    fn is_private_ip_allows_public_addresses() {
        let allowed: Vec<IpAddr> = vec![
            "1.1.1.1".parse().unwrap(),
            "8.8.8.8".parse().unwrap(),
            "93.184.216.34".parse().unwrap(),
            "2001:4860:4860::8888".parse().unwrap(),
        ];
        for ip in allowed {
            assert!(!is_private_ip(ip), "should allow {ip}");
        }
    }

    // record_success on an Open circuit moves it to HalfOpen with one success counted.
    #[test]
    fn success_in_half_open_below_threshold_keeps_circuit_half_open() {
        let breaker = CircuitBreakerClient::with_defaults(reqwest::Client::new());
        let host = "https://flaky.example.com";

        for _ in 0..5 {
            breaker.record_failure(host);
        }
        assert_eq!(breaker.get_state(host).unwrap().state, CircuitStatus::Open);

        breaker.record_success(host);
        let s = breaker.get_state(host).unwrap();
        assert_eq!(s.state, CircuitStatus::HalfOpen);
        assert_eq!(s.success_count, 1);
    }

    #[test]
    fn second_success_in_half_open_closes_circuit_and_resets_counters() {
        let breaker = CircuitBreakerClient::with_defaults(reqwest::Client::new());
        let host = "https://flaky2.example.com";

        for _ in 0..5 {
            breaker.record_failure(host);
        }
        // Open -> HalfOpen (count 1), then HalfOpen reaches success_threshold 2.
        breaker.record_success(host);
        breaker.record_success(host);
        let s = breaker.get_state(host).unwrap();
        assert_eq!(s.state, CircuitStatus::Closed);
        assert_eq!(s.failure_count, 0);
        assert_eq!(s.success_count, 0);
        assert!(
            s.opened_at.is_none(),
            "opened_at must clear on full recovery"
        );
    }

    #[test]
    fn failure_in_half_open_reopens_with_exponential_backoff() {
        let breaker = breaker_with(CircuitBreakerConfig {
            failure_threshold: 3,
            success_threshold: 2,
            base_timeout: Duration::from_secs(10),
            max_backoff: Duration::from_secs(600),
            backoff_multiplier: 2.0,
            enabled: true,
            allow_private: true,
        });
        let host = "https://still-down.example.com";

        for _ in 0..3 {
            breaker.record_failure(host);
        }
        let initial_backoff = breaker.get_state(host).unwrap().current_backoff;
        // Open -> HalfOpen, then a failed probe reopens with a larger backoff.
        breaker.record_success(host);
        breaker.record_failure(host);
        let s = breaker.get_state(host).unwrap();
        assert_eq!(s.state, CircuitStatus::Open);
        assert_eq!(s.success_count, 0, "success_count must reset on reopen");
        let expected = Duration::from_secs_f64(initial_backoff.as_secs_f64() * 2.0);
        assert_eq!(
            s.current_backoff, expected,
            "backoff must scale by multiplier on reopen"
        );
    }

    #[test]
    fn failure_in_half_open_caps_backoff_at_max() {
        let breaker = breaker_with(CircuitBreakerConfig {
            failure_threshold: 1,
            success_threshold: 1,
            base_timeout: Duration::from_secs(30),
            max_backoff: Duration::from_secs(45),
            backoff_multiplier: 10.0,
            enabled: true,
            allow_private: true,
        });
        let host = "https://capped.example.com";

        // Closed -> Open, Open -> HalfOpen, HalfOpen -> Open (30s * 10 capped at 45s).
        breaker.record_failure(host);
        breaker.record_success(host);
        breaker.record_failure(host);
        let s = breaker.get_state(host).unwrap();
        assert_eq!(s.current_backoff, Duration::from_secs(45));
    }

    #[test]
    fn record_failure_while_open_just_refreshes_opened_at_without_changing_state() {
        let breaker = CircuitBreakerClient::with_defaults(reqwest::Client::new());
        let host = "https://still-open.example.com";
        for _ in 0..5 {
            breaker.record_failure(host);
        }
        let before = breaker.get_state(host).unwrap();
        assert_eq!(before.state, CircuitStatus::Open);

        std::thread::sleep(Duration::from_millis(2));
        breaker.record_failure(host);

        let after = breaker.get_state(host).unwrap();
        assert_eq!(after.state, CircuitStatus::Open);
        assert!(
            after.opened_at.unwrap() >= before.opened_at.unwrap(),
            "opened_at should be refreshed or unchanged, not regressed"
        );
        assert_eq!(after.current_backoff, before.current_backoff);
    }

    #[test]
    fn disabled_breaker_never_blocks_and_never_records_state() {
        let breaker = breaker_with(CircuitBreakerConfig {
            enabled: false,
            ..Default::default()
        });
        let host = "https://noop.example.com";

        for _ in 0..100 {
            breaker.record_failure(host);
        }
        assert!(breaker.get_state(host).is_none());
        assert!(breaker.should_allow(host).is_ok());

        breaker.record_success(host);
        assert!(breaker.get_state(host).is_none());
    }

    #[test]
    fn reset_all_clears_state_for_every_host() {
        let breaker = CircuitBreakerClient::with_defaults(reqwest::Client::new());
        breaker.record_failure("https://a.example.com");
        breaker.record_failure("https://b.example.com");
        breaker.record_failure("https://c.example.com");
        assert!(breaker.get_state("https://a.example.com").is_some());

        breaker.reset_all();
        assert!(breaker.get_state("https://a.example.com").is_none());
        assert!(breaker.get_state("https://b.example.com").is_none());
        assert!(breaker.get_state("https://c.example.com").is_none());
    }

    #[test]
    fn should_allow_returns_ok_when_open_timeout_has_elapsed() {
        let breaker = breaker_with(CircuitBreakerConfig {
            failure_threshold: 1,
            base_timeout: Duration::from_millis(10),
            ..Default::default()
        });
        let host = "https://ready.example.com";
        breaker.record_failure(host);
        {
            // Backdate the Open circuit so its backoff has certainly elapsed.
            // NOTE(review): `Instant::now() - 3600s` panics if the monotonic
            // clock's origin is less than an hour ago (possible on freshly
            // booted CI machines on some platforms); `checked_sub` would be safer.
            let mut states = breaker.states.write().unwrap();
            let s = states.get_mut(host).unwrap();
            s.opened_at = Some(Instant::now() - Duration::from_secs(3600));
            s.current_backoff = Duration::from_millis(10);
        }
        assert!(
            breaker.should_allow(host).is_ok(),
            "expired open circuit must allow the next request through"
        );
    }

    #[test]
    fn should_allow_reports_retry_after_when_open_and_within_backoff() {
        let breaker = breaker_with(CircuitBreakerConfig {
            failure_threshold: 1,
            base_timeout: Duration::from_secs(60),
            ..Default::default()
        });
        let host = "https://hot.example.com";
        breaker.record_failure(host);

        let err = breaker.should_allow(host).expect_err("still open");
        assert_eq!(err.host, host);
        let backoff = breaker.get_state(host).unwrap().current_backoff;
        assert!(err.retry_after > Duration::ZERO);
        assert!(err.retry_after <= backoff);
    }

    #[test]
    fn extract_host_handles_default_ports_and_no_port() {
        assert_eq!(
            CircuitBreakerClient::extract_host(&url("https://api.example.com/")),
            "https://api.example.com"
        );
        assert_eq!(
            CircuitBreakerClient::extract_host(&url("http://api.example.com/")),
            "http://api.example.com"
        );
        assert_eq!(
            CircuitBreakerClient::extract_host(&url("https://api.example.com:8443/")),
            "https://api.example.com:8443"
        );
    }

    #[test]
    fn extract_host_includes_ipv6_brackets() {
        let h = CircuitBreakerClient::extract_host(&url("http://[::1]:8080/"));
        assert!(h.contains("::1"), "got: {h}");
        assert!(h.ends_with(":8080"), "got: {h}");
    }

    #[test]
    fn circuit_breaker_open_display_mentions_host_and_retry_after() {
        let err = CircuitBreakerOpen {
            host: "https://flaky.example.com".to_string(),
            retry_after: Duration::from_secs(42),
        };
        let s = err.to_string();
        assert!(s.contains("https://flaky.example.com"));
        assert!(s.contains("42"));
    }

    #[test]
    fn private_host_blocked_display_redacts_host() {
        let err = CircuitBreakerError::PrivateHostBlocked("127.0.0.1".to_string());
        let s = err.to_string();
        assert!(
            !s.contains("127.0.0.1"),
            "host must not leak through Display"
        );
        assert!(s.contains("private IP"));
    }

    #[test]
    fn circuit_breaker_error_source_chains_through_inner_variants() {
        let inner = CircuitBreakerOpen {
            host: "h".to_string(),
            retry_after: Duration::from_secs(1),
        };
        let err = CircuitBreakerError::CircuitOpen(inner);
        assert!(
            std::error::Error::source(&err).is_some(),
            "CircuitOpen should expose its wrapped error as source"
        );

        let err = CircuitBreakerError::PrivateHostBlocked("h".to_string());
        assert!(
            std::error::Error::source(&err).is_none(),
            "PrivateHostBlocked has no source"
        );
    }

    #[test]
    fn http_client_apply_default_timeout_is_noop_when_default_unset() {
        let breaker = CircuitBreakerClient::with_defaults(reqwest::Client::new());
        let client = breaker.with_timeout(None);
        let mut req = reqwest::Request::new(Method::GET, url("https://example.com/"));
        client.apply_default_timeout(&mut req);
        assert_eq!(req.timeout(), None);
    }

    #[test]
    fn http_client_accessors_expose_underlying_pieces() {
        let breaker = CircuitBreakerClient::with_defaults(reqwest::Client::new());
        let client = breaker.with_timeout(Some(Duration::from_secs(7)));
        assert_eq!(client.default_timeout(), Some(Duration::from_secs(7)));
        let _ = client.inner();
        let _ = client.circuit_breaker();
    }
}