Skip to main content

forge_core/http/
mod.rs

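//! HTTP client with circuit breaker pattern.
//!
//! Wraps `reqwest::Client` with automatic failure tracking per host.
//! After repeated failures, requests fail fast to prevent cascading failures.
//! Requests to private, loopback, or link-local IP literals are refused unless
//! explicitly allowed, and an optional DNS resolver extends that check to hostnames.
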
6use std::collections::HashMap;
7use std::sync::RwLock;
8use std::time::{Duration, Instant};
9
10use reqwest::{IntoUrl, Method, Request, RequestBuilder, Response};
11use std::net::{IpAddr, SocketAddr};
12
/// Circuit breaker state for a single host.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct CircuitState {
    /// Where this host currently sits in the Closed -> Open -> HalfOpen cycle.
    pub state: CircuitStatus,
    /// Failures recorded while `Closed`; cleared by a success and compared
    /// against the configured threshold to decide when to trip the circuit.
    pub failure_count: u32,
    /// Consecutive successes observed while half-open.
    pub success_count: u32,
    /// When the circuit last opened (refreshed by further failures while open);
    /// `None` while the circuit is closed.
    pub opened_at: Option<Instant>,
    /// How long the circuit stays open before a probe request is admitted;
    /// grows after failed half-open probes.
    pub current_backoff: Duration,
}

/// Circuit breaker status.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum CircuitStatus {
    /// Normal operation: requests pass through and failures are counted.
    Closed,
    /// Tripped: requests fail fast until the backoff has elapsed.
    Open,
    /// Recovery probe: requests are let through and their outcomes decide
    /// whether the circuit closes again or reopens.
    HalfOpen,
}

impl Default for CircuitState {
    /// A closed circuit with zeroed counters, no open timestamp, and a
    /// 30-second backoff.
    fn default() -> Self {
        let initial_backoff = Duration::from_secs(30);
        Self {
            state: CircuitStatus::Closed,
            failure_count: 0,
            success_count: 0,
            opened_at: None,
            current_backoff: initial_backoff,
        }
    }
}
48
/// Configuration for the circuit breaker.
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Failures counted while closed that trip the circuit open.
    pub failure_threshold: u32,
    /// Successes needed while half-open before the circuit closes again.
    pub success_threshold: u32,
    /// Backoff the breaker returns to when a circuit closes after recovering.
    pub base_timeout: Duration,
    /// Upper bound for the backoff as it grows after failed probes.
    pub max_backoff: Duration,
    /// Factor applied to the backoff each time a half-open probe fails.
    pub backoff_multiplier: f64,
    /// When `false`, the breaker neither tracks state nor blocks requests
    /// (the private-IP guard in `CircuitBreakerClient::execute` still applies).
    pub enabled: bool,
    /// Defaults to `false` to block SSRF targets (`127.0.0.1`, `169.254.169.254`, etc.).
    /// Set `true` in development or when intentionally calling private CIDRs.
    /// This flag only governs the IP-literal check in `CircuitBreakerClient::execute`;
    /// clients built via `build_ssrf_safe_client` filter hostnames resolving to
    /// private IPs at the DNS layer regardless of this setting.
    pub allow_private: bool,
}

impl Default for CircuitBreakerConfig {
    /// Five failures to trip, two successes to recover, 30s base backoff growing
    /// by 1.5x up to ten minutes, breaker enabled, private targets blocked.
    fn default() -> Self {
        Self {
            enabled: true,
            allow_private: false,
            failure_threshold: 5,
            success_threshold: 2,
            base_timeout: Duration::from_secs(30),
            max_backoff: Duration::from_secs(10 * 60),
            backoff_multiplier: 1.5,
        }
    }
}
77
/// Error returned when a request is rejected because its host's circuit is open.
#[derive(Debug, Clone)]
pub struct CircuitBreakerOpen {
    /// The host key whose circuit is open.
    pub host: String,
    /// Time remaining until the breaker will admit a probe request.
    pub retry_after: Duration,
}

impl std::fmt::Display for CircuitBreakerOpen {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { host, retry_after } = self;
        write!(
            f,
            "Circuit breaker open for {}: retry after {:?}",
            host, retry_after
        )
    }
}

impl std::error::Error for CircuitBreakerOpen {}
96
/// Returns true when `ip` must not be contacted by SSRF-protected requests.
///
/// IPv4 rules are applied by `is_private_v4` (loopback, RFC 1918 private,
/// RFC 6598 shared address space, link-local, broadcast, unspecified,
/// documentation).
///
/// IPv6 addresses are blocked when they are loopback (`::1`), unspecified
/// (`::`), link-local (`fe80::/10`), unique local (`fc00::/7`), or in the
/// documentation prefix (`2001:db8::/32`).
///
/// IPv4-mapped IPv6 addresses (`::ffff:x.x.x.x`) are unwrapped and checked
/// against the IPv4 rules, so they cannot be used to smuggle a private IPv4
/// target past the filter.
pub fn is_private_ip(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => is_private_v4(v4),
        IpAddr::V6(v6) => {
            // IPv4-mapped addresses (::ffff:0:0/96): judge by the inner IPv4.
            if let Some(v4) = v6.to_ipv4_mapped() {
                return is_private_v4(v4);
            }
            let [seg0, seg1, ..] = v6.segments();
            v6.is_loopback()
                || v6.is_unspecified()
                || (seg0 & 0xffc0) == 0xfe80 // link-local fe80::/10
                || (seg0 & 0xfe00) == 0xfc00 // ULA fc00::/7
                || (seg0 == 0x2001 && seg1 == 0x0db8) // documentation 2001:db8::/32
        }
    }
}

/// IPv4 half of [`is_private_ip`].
///
/// Blocks loopback (127/8), RFC 1918 private ranges, link-local (169.254/16,
/// which includes the 169.254.169.254 metadata endpoint), broadcast,
/// unspecified, the TEST-NET documentation ranges, and the RFC 6598 shared
/// address space 100.64.0.0/10.
fn is_private_v4(v4: std::net::Ipv4Addr) -> bool {
    let [a, b, ..] = v4.octets();
    v4.is_loopback()
        || v4.is_private()
        || v4.is_link_local()
        || v4.is_broadcast()
        || v4.is_unspecified()
        || v4.is_documentation()
        // Shared address space 100.64.0.0/10 (second octet 64..=127): not
        // publicly routable, and home to internal services such as Alibaba
        // Cloud's metadata endpoint at 100.100.100.200.
        || (a == 100 && (b & 0xc0) == 0x40)
}
125
126/// DNS resolver that filters out private/loopback/link-local addresses from
127/// resolution results. Prevents SSRF via DNS rebinding or hostnames that
128/// resolve to internal IPs (e.g. `metadata.internal` -> `169.254.169.254`).
129struct SsrfSafeResolver;
130
131impl reqwest::dns::Resolve for SsrfSafeResolver {
132    fn resolve(&self, name: reqwest::dns::Name) -> reqwest::dns::Resolving {
133        Box::pin(async move {
134            let host = name.as_str().to_string();
135            let addrs: Vec<SocketAddr> = tokio::net::lookup_host(format!("{host}:0"))
136                .await?
137                .collect();
138            let safe: Vec<SocketAddr> = addrs
139                .into_iter()
140                .filter(|addr| !is_private_ip(addr.ip()))
141                .collect();
142            if safe.is_empty() {
143                return Err(format!("DNS resolution for {host} returned only private IPs").into());
144            }
145            let addrs: reqwest::dns::Addrs = Box::new(safe.into_iter());
146            Ok(addrs)
147        })
148    }
149}
150
151/// Build a reqwest client with SSRF-safe DNS resolution. Hostnames that
152/// resolve to private/loopback/link-local IPs are rejected at the DNS layer.
153///
154/// # Panics
155///
156/// Panics if the TLS backend is unavailable, which is a fatal startup error.
157pub fn build_ssrf_safe_client() -> reqwest::Client {
158    reqwest::Client::builder()
159        .dns_resolver(std::sync::Arc::new(SsrfSafeResolver))
160        .build()
161        .unwrap_or_else(|e| {
162            tracing::error!("Failed to build SSRF-safe HTTP client: {e}");
163            // This only fails when the TLS backend is missing. Proceeding with
164            // an unprotected client would silently remove DNS-level SSRF guards,
165            // so we propagate the failure as a panic at startup.
166            unreachable!("TLS backend required for HTTP client")
167        })
168}
169
170/// HTTP client with circuit breaker pattern.
171///
172/// Tracks failure rates per host and fails fast when a host is unhealthy.
173#[derive(Clone)]
174pub struct CircuitBreakerClient {
175    inner: reqwest::Client,
176    states: std::sync::Arc<RwLock<HashMap<String, CircuitState>>>,
177    config: CircuitBreakerConfig,
178}
179
180impl CircuitBreakerClient {
181    /// Create a new circuit breaker client wrapping the given reqwest client.
182    pub fn new(client: reqwest::Client, config: CircuitBreakerConfig) -> Self {
183        Self {
184            inner: client,
185            states: std::sync::Arc::new(RwLock::new(HashMap::new())),
186            config,
187        }
188    }
189
190    /// Create with default configuration.
191    pub fn with_defaults(client: reqwest::Client) -> Self {
192        Self::new(client, CircuitBreakerConfig::default())
193    }
194
195    /// Create with default configuration and SSRF-safe DNS resolution.
196    /// Hostnames resolving to private/loopback/link-local IPs are blocked.
197    pub fn with_ssrf_protection() -> Self {
198        Self::new(build_ssrf_safe_client(), CircuitBreakerConfig::default())
199    }
200
201    /// Get the underlying reqwest client for building requests.
202    pub fn inner(&self) -> &reqwest::Client {
203        &self.inner
204    }
205
206    /// Create a request client view with an optional default request timeout.
207    pub fn with_timeout(&self, timeout: Option<Duration>) -> HttpClient {
208        HttpClient::new(self.clone(), timeout)
209    }
210
211    /// Extract host from URL for tracking.
212    fn extract_host(url: &reqwest::Url) -> String {
213        format!(
214            "{}://{}{}",
215            url.scheme(),
216            url.host_str().unwrap_or("unknown"),
217            url.port().map(|p| format!(":{}", p)).unwrap_or_default()
218        )
219    }
220
221    /// Returns true when the URL's host is a literal IP in a private range.
222    /// Hostnames are not resolved here — DNS-level SSRF protection is handled
223    /// by `SsrfSafeResolver` when the client is built via `build_ssrf_safe_client`.
224    fn url_targets_private_ip(url: &reqwest::Url) -> bool {
225        let Some(host) = url.host_str() else {
226            return false;
227        };
228        let trimmed = host.trim_start_matches('[').trim_end_matches(']');
229        let Ok(ip) = trimmed.parse::<IpAddr>() else {
230            return false;
231        };
232        is_private_ip(ip)
233    }
234
235    /// Check if a request to the given host should be allowed.
236    pub fn should_allow(&self, host: &str) -> Result<(), CircuitBreakerOpen> {
237        if !self.config.enabled {
238            return Ok(());
239        }
240
241        let states = self.states.read().unwrap_or_else(|e| {
242            tracing::error!("Circuit breaker lock was poisoned, recovering");
243            e.into_inner()
244        });
245        let state = match states.get(host) {
246            Some(s) => s,
247            None => return Ok(()), // No state = first request, allow
248        };
249
250        match state.state {
251            CircuitStatus::Closed => Ok(()),
252            CircuitStatus::HalfOpen => Ok(()), // Allow test requests
253            CircuitStatus::Open => {
254                let opened_at = state.opened_at.unwrap_or_else(Instant::now);
255                let elapsed = opened_at.elapsed();
256
257                if elapsed >= state.current_backoff {
258                    // Timeout expired, will transition to half-open
259                    Ok(())
260                } else {
261                    Err(CircuitBreakerOpen {
262                        host: host.to_string(),
263                        retry_after: state.current_backoff - elapsed,
264                    })
265                }
266            }
267        }
268    }
269
270    /// Record a successful request.
271    pub fn record_success(&self, host: &str) {
272        if !self.config.enabled {
273            return;
274        }
275
276        let mut states = self.states.write().unwrap_or_else(|e| {
277            tracing::error!("Circuit breaker lock was poisoned, recovering");
278            e.into_inner()
279        });
280        let state = states.entry(host.to_string()).or_default();
281
282        match state.state {
283            CircuitStatus::Closed => {
284                // Reset failure count on success
285                state.failure_count = 0;
286            }
287            CircuitStatus::HalfOpen => {
288                state.success_count += 1;
289                if state.success_count >= self.config.success_threshold {
290                    // Service recovered, close the circuit
291                    tracing::info!(host = %host, "Circuit breaker closed, service recovered");
292                    state.state = CircuitStatus::Closed;
293                    state.failure_count = 0;
294                    state.success_count = 0;
295                    state.opened_at = None;
296                    state.current_backoff = self.config.base_timeout;
297                }
298            }
299            CircuitStatus::Open => {
300                // Transition to half-open on first success after timeout
301                tracing::info!(host = %host, "Circuit breaker half-open, testing service");
302                state.state = CircuitStatus::HalfOpen;
303                state.success_count = 1;
304            }
305        }
306    }
307
308    /// Record a failed request.
309    pub fn record_failure(&self, host: &str) {
310        if !self.config.enabled {
311            return;
312        }
313
314        let mut states = self.states.write().unwrap_or_else(|e| {
315            tracing::error!("Circuit breaker lock was poisoned, recovering");
316            e.into_inner()
317        });
318        let state = states.entry(host.to_string()).or_default();
319
320        match state.state {
321            CircuitStatus::Closed => {
322                state.failure_count += 1;
323                if state.failure_count >= self.config.failure_threshold {
324                    // Trip the circuit
325                    tracing::warn!(
326                        host = %host,
327                        failures = state.failure_count,
328                        "Circuit breaker opened, service unhealthy"
329                    );
330                    state.state = CircuitStatus::Open;
331                    state.opened_at = Some(Instant::now());
332                }
333            }
334            CircuitStatus::HalfOpen => {
335                // Failed during test, reopen with increased backoff
336                let new_backoff = Duration::from_secs_f64(
337                    (state.current_backoff.as_secs_f64() * self.config.backoff_multiplier)
338                        .min(self.config.max_backoff.as_secs_f64()),
339                );
340                tracing::warn!(
341                    host = %host,
342                    backoff_secs = new_backoff.as_secs(),
343                    "Circuit breaker reopened, service still unhealthy"
344                );
345                state.state = CircuitStatus::Open;
346                state.opened_at = Some(Instant::now());
347                state.current_backoff = new_backoff;
348                state.success_count = 0;
349            }
350            CircuitStatus::Open => {
351                // Already open, just update timestamp
352                state.opened_at = Some(Instant::now());
353            }
354        }
355    }
356
357    /// Execute a request with circuit breaker protection.
358    pub async fn execute(&self, request: Request) -> Result<Response, CircuitBreakerError> {
359        // SSRF guard: refuse private/loopback/link-local IP literals unless
360        // the operator has opted in.
361        if !self.config.allow_private && Self::url_targets_private_ip(request.url()) {
362            return Err(CircuitBreakerError::PrivateHostBlocked(
363                request.url().host_str().unwrap_or("unknown").to_string(),
364            ));
365        }
366
367        let host = Self::extract_host(request.url());
368
369        // Check circuit state
370        self.should_allow(&host)
371            .map_err(CircuitBreakerError::CircuitOpen)?;
372
373        // If circuit is open but timeout expired, transition to half-open
374        {
375            let mut states = self.states.write().unwrap_or_else(|e| {
376                tracing::error!("Circuit breaker lock was poisoned, recovering");
377                e.into_inner()
378            });
379            if let Some(state) = states.get_mut(&host)
380                && state.state == CircuitStatus::Open
381                && let Some(opened_at) = state.opened_at
382                && opened_at.elapsed() >= state.current_backoff
383            {
384                tracing::info!(host = %host, "Circuit breaker half-open, testing service");
385                state.state = CircuitStatus::HalfOpen;
386                state.success_count = 0;
387            }
388        }
389
390        // Execute the request
391        match self.inner.execute(request).await {
392            Ok(response) => {
393                // Check if response indicates server error
394                if response.status().is_server_error() {
395                    self.record_failure(&host);
396                } else {
397                    self.record_success(&host);
398                }
399                Ok(response)
400            }
401            Err(e) => {
402                self.record_failure(&host);
403                Err(CircuitBreakerError::Request(e))
404            }
405        }
406    }
407
408    /// Get the current state for a host.
409    pub fn get_state(&self, host: &str) -> Option<CircuitState> {
410        self.states
411            .read()
412            .unwrap_or_else(|e| {
413                tracing::error!("Circuit breaker lock was poisoned, recovering");
414                e.into_inner()
415            })
416            .get(host)
417            .cloned()
418    }
419
420    /// Reset the circuit breaker state for a host.
421    pub fn reset(&self, host: &str) {
422        self.states
423            .write()
424            .unwrap_or_else(|e| {
425                tracing::error!("Circuit breaker lock was poisoned, recovering");
426                e.into_inner()
427            })
428            .remove(host);
429    }
430
431    /// Reset all circuit breaker states.
432    pub fn reset_all(&self) {
433        self.states
434            .write()
435            .unwrap_or_else(|e| {
436                tracing::error!("Circuit breaker lock was poisoned, recovering");
437                e.into_inner()
438            })
439            .clear();
440    }
441}
442
443/// Error type for circuit breaker operations.
444#[derive(Debug)]
445pub enum CircuitBreakerError {
446    /// The circuit is open, request was not attempted.
447    CircuitOpen(CircuitBreakerOpen),
448    /// Outbound request blocked because the URL host resolves to a
449    /// private/loopback/link-local IP and `allow_private` is false.
450    PrivateHostBlocked(String),
451    /// The request failed.
452    Request(reqwest::Error),
453}
454
455impl std::fmt::Display for CircuitBreakerError {
456    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
457        match self {
458            CircuitBreakerError::CircuitOpen(e) => write!(f, "{}", e),
459            CircuitBreakerError::PrivateHostBlocked(_host) => write!(
460                f,
461                "Outbound request blocked: target resolves to a private IP"
462            ),
463            CircuitBreakerError::Request(e) => write!(f, "HTTP request failed: {}", e),
464        }
465    }
466}
467
468impl std::error::Error for CircuitBreakerError {
469    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
470        match self {
471            CircuitBreakerError::CircuitOpen(e) => Some(e),
472            CircuitBreakerError::PrivateHostBlocked(_) => None,
473            CircuitBreakerError::Request(e) => Some(e),
474        }
475    }
476}
477
478impl From<reqwest::Error> for CircuitBreakerError {
479    fn from(e: reqwest::Error) -> Self {
480        CircuitBreakerError::Request(e)
481    }
482}
483
484/// HTTP client facade that routes requests through a circuit breaker and can
485/// apply a default timeout to requests that do not set one explicitly.
486#[derive(Clone)]
487pub struct HttpClient {
488    circuit_breaker: CircuitBreakerClient,
489    default_timeout: Option<Duration>,
490}
491
492impl HttpClient {
493    /// Create a new HTTP client facade.
494    pub fn new(circuit_breaker: CircuitBreakerClient, default_timeout: Option<Duration>) -> Self {
495        Self {
496            circuit_breaker,
497            default_timeout,
498        }
499    }
500
501    /// Get the underlying reqwest client.
502    pub fn inner(&self) -> &reqwest::Client {
503        self.circuit_breaker.inner()
504    }
505
506    /// Get the underlying circuit breaker client.
507    pub fn circuit_breaker(&self) -> &CircuitBreakerClient {
508        &self.circuit_breaker
509    }
510
511    /// Get the default timeout applied to requests that do not override it.
512    pub fn default_timeout(&self) -> Option<Duration> {
513        self.default_timeout
514    }
515
516    /// Create a request builder.
517    pub fn request<U: IntoUrl>(&self, method: Method, url: U) -> HttpRequestBuilder {
518        HttpRequestBuilder::new(self.clone(), self.inner().request(method, url))
519    }
520
521    pub fn get<U: IntoUrl>(&self, url: U) -> HttpRequestBuilder {
522        self.request(Method::GET, url)
523    }
524
525    pub fn post<U: IntoUrl>(&self, url: U) -> HttpRequestBuilder {
526        self.request(Method::POST, url)
527    }
528
529    pub fn put<U: IntoUrl>(&self, url: U) -> HttpRequestBuilder {
530        self.request(Method::PUT, url)
531    }
532
533    pub fn patch<U: IntoUrl>(&self, url: U) -> HttpRequestBuilder {
534        self.request(Method::PATCH, url)
535    }
536
537    pub fn delete<U: IntoUrl>(&self, url: U) -> HttpRequestBuilder {
538        self.request(Method::DELETE, url)
539    }
540
541    pub fn head<U: IntoUrl>(&self, url: U) -> HttpRequestBuilder {
542        self.request(Method::HEAD, url)
543    }
544
545    /// Execute a pre-built request through the circuit breaker.
546    pub async fn execute(&self, mut request: Request) -> crate::Result<Response> {
547        self.apply_default_timeout(&mut request);
548        self.circuit_breaker
549            .execute(request)
550            .await
551            .map_err(Into::into)
552    }
553
554    fn apply_default_timeout(&self, request: &mut Request) {
555        if request.timeout().is_none()
556            && let Some(timeout) = self.default_timeout
557        {
558            *request.timeout_mut() = Some(timeout);
559        }
560    }
561}
562
563/// Request builder paired with a circuit-breaker-backed HTTP client.
564pub struct HttpRequestBuilder {
565    client: HttpClient,
566    request: RequestBuilder,
567}
568
569impl HttpRequestBuilder {
570    fn new(client: HttpClient, request: RequestBuilder) -> Self {
571        Self { client, request }
572    }
573
574    pub fn header(self, key: impl AsRef<str>, value: impl AsRef<str>) -> Self {
575        Self {
576            request: self.request.header(key.as_ref(), value.as_ref()),
577            ..self
578        }
579    }
580
581    pub fn headers(self, headers: reqwest::header::HeaderMap) -> Self {
582        Self {
583            request: self.request.headers(headers),
584            ..self
585        }
586    }
587
588    pub fn bearer_auth(self, token: impl std::fmt::Display) -> Self {
589        Self {
590            request: self.request.bearer_auth(token),
591            ..self
592        }
593    }
594
595    pub fn basic_auth(
596        self,
597        username: impl std::fmt::Display,
598        password: Option<impl std::fmt::Display>,
599    ) -> Self {
600        Self {
601            request: self.request.basic_auth(username, password),
602            ..self
603        }
604    }
605
606    pub fn body(self, body: impl Into<reqwest::Body>) -> Self {
607        Self {
608            request: self.request.body(body),
609            ..self
610        }
611    }
612
613    pub fn json(self, json: &impl serde::Serialize) -> Self {
614        Self {
615            request: self.request.json(json),
616            ..self
617        }
618    }
619
620    pub fn form(self, form: &impl serde::Serialize) -> Self {
621        Self {
622            request: self.request.form(form),
623            ..self
624        }
625    }
626
627    pub fn query(self, query: &impl serde::Serialize) -> Self {
628        Self {
629            request: self.request.query(query),
630            ..self
631        }
632    }
633
634    pub fn timeout(self, timeout: Duration) -> Self {
635        Self {
636            request: self.request.timeout(timeout),
637            ..self
638        }
639    }
640
641    pub fn version(self, version: reqwest::Version) -> Self {
642        Self {
643            request: self.request.version(version),
644            ..self
645        }
646    }
647
648    pub fn try_clone(&self) -> Option<Self> {
649        self.request.try_clone().map(|request| Self {
650            client: self.client.clone(),
651            request,
652        })
653    }
654
655    pub fn build(self) -> crate::Result<Request> {
656        self.request
657            .build()
658            .map_err(|e| crate::ForgeError::internal_with("Failed to build HTTP request", e))
659    }
660
661    pub async fn send(self) -> crate::Result<Response> {
662        let client = self.client.clone();
663        let request = self.build()?;
664        client.execute(request).await
665    }
666}
667
668#[cfg(test)]
669#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
670mod tests {
671    use super::*;
672
673    #[test]
674    fn test_circuit_breaker_defaults() {
675        let config = CircuitBreakerConfig::default();
676        assert_eq!(config.failure_threshold, 5);
677        assert_eq!(config.success_threshold, 2);
678        assert!(config.enabled);
679    }
680
681    #[test]
682    fn test_circuit_state_transitions() {
683        let client = reqwest::Client::new();
684        let breaker = CircuitBreakerClient::with_defaults(client);
685        let host = "https://api.example.com";
686
687        // Initial state should allow
688        assert!(breaker.should_allow(host).is_ok());
689
690        // Record failures to trip the circuit
691        for _ in 0..5 {
692            breaker.record_failure(host);
693        }
694
695        // Circuit should be open
696        let state = breaker.get_state(host).unwrap();
697        assert_eq!(state.state, CircuitStatus::Open);
698
699        // Should be blocked
700        assert!(breaker.should_allow(host).is_err());
701
702        // Reset and verify
703        breaker.reset(host);
704        assert!(breaker.should_allow(host).is_ok());
705    }
706
707    #[test]
708    fn test_extract_host() {
709        let url = reqwest::Url::parse("https://api.example.com:8080/path").unwrap();
710        assert_eq!(
711            CircuitBreakerClient::extract_host(&url),
712            "https://api.example.com:8080"
713        );
714
715        let url2 = reqwest::Url::parse("http://localhost/api").unwrap();
716        assert_eq!(
717            CircuitBreakerClient::extract_host(&url2),
718            "http://localhost"
719        );
720    }
721
722    #[test]
723    fn test_http_client_applies_default_timeout_when_missing() {
724        let breaker = CircuitBreakerClient::with_defaults(reqwest::Client::new());
725        let client = breaker.with_timeout(Some(Duration::from_secs(5)));
726        let mut request = reqwest::Request::new(
727            Method::GET,
728            reqwest::Url::parse("https://example.com").unwrap(),
729        );
730
731        client.apply_default_timeout(&mut request);
732
733        assert_eq!(request.timeout(), Some(&Duration::from_secs(5)));
734    }
735
736    #[test]
737    fn test_http_client_preserves_explicit_timeout() {
738        let breaker = CircuitBreakerClient::with_defaults(reqwest::Client::new());
739        let client = breaker.with_timeout(Some(Duration::from_secs(5)));
740        let mut request = reqwest::Request::new(
741            Method::GET,
742            reqwest::Url::parse("https://example.com").unwrap(),
743        );
744        *request.timeout_mut() = Some(Duration::from_secs(1));
745
746        client.apply_default_timeout(&mut request);
747
748        assert_eq!(request.timeout(), Some(&Duration::from_secs(1)));
749    }
750
751    fn url(s: &str) -> reqwest::Url {
752        reqwest::Url::parse(s).expect("valid url")
753    }
754
755    fn breaker_with(config: CircuitBreakerConfig) -> CircuitBreakerClient {
756        CircuitBreakerClient::new(reqwest::Client::new(), config)
757    }
758
759    // ---- SSRF guard ----
760
761    #[test]
762    fn private_ip_guard_blocks_ipv4_loopback_and_metadata_endpoint() {
763        // These are the cases that most matter operationally — 127.0.0.1 and
764        // the AWS/GCE metadata IP 169.254.169.254 (link-local).
765        assert!(CircuitBreakerClient::url_targets_private_ip(&url(
766            "http://127.0.0.1/"
767        )));
768        assert!(CircuitBreakerClient::url_targets_private_ip(&url(
769            "http://169.254.169.254/latest/meta-data/"
770        )));
771    }
772
773    #[test]
774    fn private_ip_guard_blocks_all_ipv4_classes_doc_says_it_blocks() {
775        // Walk every RFC class the docstring promises to cover; if one slips
776        // the matrix, the SSRF guarantee is broken.
777        let blocked = [
778            "http://10.0.0.1/",        // private 10/8
779            "http://172.16.0.1/",      // private 172.16/12
780            "http://192.168.1.1/",     // private 192.168/16
781            "http://169.254.1.1/",     // link-local
782            "http://0.0.0.0/",         // unspecified
783            "http://255.255.255.255/", // broadcast
784            "http://192.0.2.1/",       // documentation TEST-NET-1
785            "http://198.51.100.1/",    // documentation TEST-NET-2
786            "http://203.0.113.1/",     // documentation TEST-NET-3
787        ];
788        for u in blocked {
789            assert!(
790                CircuitBreakerClient::url_targets_private_ip(&url(u)),
791                "should block {u}"
792            );
793        }
794    }
795
796    #[test]
797    fn private_ip_guard_blocks_ipv6_loopback_link_local_and_ula() {
798        // IPv6 mirror of the IPv4 cases. The bracket-trimming logic must
799        // strip the URL-encoded brackets before parsing.
800        let blocked = [
801            "http://[::1]/",     // loopback
802            "http://[::]/",      // unspecified
803            "http://[fe80::1]/", // link-local fe80::/10
804            "http://[febf::1]/", // link-local upper edge
805            "http://[fc00::1]/", // ULA fc00::/7
806            "http://[fd00::1]/", // ULA upper half
807        ];
808        for u in blocked {
809            assert!(
810                CircuitBreakerClient::url_targets_private_ip(&url(u)),
811                "should block {u}"
812            );
813        }
814    }
815
816    #[test]
817    fn private_ip_guard_allows_public_ips_and_dns_hostnames() {
818        // Public IP literals must pass — the guard is opt-out via
819        // allow_private, not a blanket "no IPs at all" filter.
820        let allowed = [
821            "http://1.1.1.1/",
822            "http://8.8.8.8/",
823            "http://[2001:4860:4860::8888]/", // Google public DNS v6
824            // Hostnames pass the URL-literal check; DNS-level blocking is
825            // handled by SsrfSafeResolver at connect time.
826            "http://api.example.com/",
827            "http://localhost/",
828        ];
829        for u in allowed {
830            assert!(
831                !CircuitBreakerClient::url_targets_private_ip(&url(u)),
832                "should NOT block {u}"
833            );
834        }
835    }
836
837    #[tokio::test]
838    async fn execute_returns_private_host_blocked_error_when_guard_trips() {
839        // Verify the guard is wired into execute() — not just the helper.
840        let breaker = breaker_with(CircuitBreakerConfig {
841            allow_private: false,
842            ..Default::default()
843        });
844        let req = reqwest::Request::new(Method::GET, url("http://127.0.0.1/"));
845        let err = breaker.execute(req).await.expect_err("loopback blocked");
846        match err {
847            CircuitBreakerError::PrivateHostBlocked(host) => {
848                assert_eq!(host, "127.0.0.1");
849            }
850            other => panic!("expected PrivateHostBlocked, got {other:?}"),
851        }
852    }
853
854    // ---- is_private_ip ----
855
856    #[test]
857    fn is_private_ip_blocks_all_private_ranges() {
858        let blocked: Vec<IpAddr> = vec![
859            "127.0.0.1".parse().unwrap(),
860            "10.0.0.1".parse().unwrap(),
861            "172.16.0.1".parse().unwrap(),
862            "192.168.1.1".parse().unwrap(),
863            "169.254.169.254".parse().unwrap(),
864            "0.0.0.0".parse().unwrap(),
865            "255.255.255.255".parse().unwrap(),
866            "::1".parse().unwrap(),
867            "::".parse().unwrap(),
868            "fe80::1".parse().unwrap(),
869            "fc00::1".parse().unwrap(),
870            "fd00::1".parse().unwrap(),
871        ];
872        for ip in blocked {
873            assert!(is_private_ip(ip), "should block {ip}");
874        }
875    }
876
877    #[test]
878    fn is_private_ip_blocks_ipv4_mapped_ipv6() {
879        let mapped: Vec<IpAddr> = vec![
880            "::ffff:127.0.0.1".parse().unwrap(),
881            "::ffff:10.0.0.1".parse().unwrap(),
882            "::ffff:169.254.169.254".parse().unwrap(),
883            "::ffff:192.168.1.1".parse().unwrap(),
884        ];
885        for ip in mapped {
886            assert!(is_private_ip(ip), "should block IPv4-mapped {ip}");
887        }
888    }
889
890    #[test]
891    fn is_private_ip_allows_public_addresses() {
892        let allowed: Vec<IpAddr> = vec![
893            "1.1.1.1".parse().unwrap(),
894            "8.8.8.8".parse().unwrap(),
895            "93.184.216.34".parse().unwrap(),
896            "2001:4860:4860::8888".parse().unwrap(),
897        ];
898        for ip in allowed {
899            assert!(!is_private_ip(ip), "should allow {ip}");
900        }
901    }
902
903    // ---- Half-open state machine ----
904
905    #[test]
906    fn success_in_half_open_below_threshold_keeps_circuit_half_open() {
907        // success_threshold defaults to 2 — so a single success after
908        // open->half-open must NOT yet close the circuit.
909        let breaker = CircuitBreakerClient::with_defaults(reqwest::Client::new());
910        let host = "https://flaky.example.com";
911
912        for _ in 0..5 {
913            breaker.record_failure(host);
914        }
915        assert_eq!(breaker.get_state(host).unwrap().state, CircuitStatus::Open);
916
917        // Open->HalfOpen on first success after timeout.
918        breaker.record_success(host);
919        let s = breaker.get_state(host).unwrap();
920        assert_eq!(s.state, CircuitStatus::HalfOpen);
921        assert_eq!(s.success_count, 1);
922
923        // One more success would meet threshold — but we stop at one to
924        // pin "below threshold stays half-open."
925    }
926
927    #[test]
928    fn second_success_in_half_open_closes_circuit_and_resets_counters() {
929        let breaker = CircuitBreakerClient::with_defaults(reqwest::Client::new());
930        let host = "https://flaky2.example.com";
931
932        for _ in 0..5 {
933            breaker.record_failure(host);
934        }
935        breaker.record_success(host); // -> HalfOpen
936        breaker.record_success(host); // -> Closed (threshold = 2)
937
938        let s = breaker.get_state(host).unwrap();
939        assert_eq!(s.state, CircuitStatus::Closed);
940        assert_eq!(s.failure_count, 0);
941        assert_eq!(s.success_count, 0);
942        assert!(
943            s.opened_at.is_none(),
944            "opened_at must clear on full recovery"
945        );
946    }
947
948    #[test]
949    fn failure_in_half_open_reopens_with_exponential_backoff() {
950        let breaker = breaker_with(CircuitBreakerConfig {
951            failure_threshold: 3,
952            success_threshold: 2,
953            base_timeout: Duration::from_secs(10),
954            max_backoff: Duration::from_secs(600),
955            backoff_multiplier: 2.0,
956            enabled: true,
957            allow_private: true,
958        });
959        let host = "https://still-down.example.com";
960
961        // Trip and partially recover.
962        for _ in 0..3 {
963            breaker.record_failure(host);
964        }
965        // current_backoff defaults to 30s from CircuitState::default(); the
966        // first failure-in-half-open multiplies that by backoff_multiplier.
967        let initial_backoff = breaker.get_state(host).unwrap().current_backoff;
968        breaker.record_success(host); // -> HalfOpen
969        breaker.record_failure(host); // -> Open with backoff * multiplier
970
971        let s = breaker.get_state(host).unwrap();
972        assert_eq!(s.state, CircuitStatus::Open);
973        assert_eq!(s.success_count, 0, "success_count must reset on reopen");
974        let expected = Duration::from_secs_f64(initial_backoff.as_secs_f64() * 2.0);
975        assert_eq!(
976            s.current_backoff, expected,
977            "backoff must scale by multiplier on reopen"
978        );
979    }
980
981    #[test]
982    fn failure_in_half_open_caps_backoff_at_max() {
983        // Pick a max well below what the multiplier would otherwise produce
984        // and verify saturation kicks in.
985        let breaker = breaker_with(CircuitBreakerConfig {
986            failure_threshold: 1,
987            success_threshold: 1,
988            base_timeout: Duration::from_secs(30),
989            max_backoff: Duration::from_secs(45),
990            backoff_multiplier: 10.0,
991            enabled: true,
992            allow_private: true,
993        });
994        let host = "https://capped.example.com";
995
996        breaker.record_failure(host); // -> Open
997        breaker.record_success(host); // -> HalfOpen
998        breaker.record_failure(host); // -> Open, attempted backoff = 30 * 10 = 300s, capped at 45
999
1000        let s = breaker.get_state(host).unwrap();
1001        assert_eq!(s.current_backoff, Duration::from_secs(45));
1002    }
1003
1004    #[test]
1005    fn record_failure_while_open_just_refreshes_opened_at_without_changing_state() {
1006        // No transition; just confirms the Open branch in record_failure
1007        // doesn't accidentally trip a re-counted "extra failure" path.
1008        let breaker = CircuitBreakerClient::with_defaults(reqwest::Client::new());
1009        let host = "https://still-open.example.com";
1010        for _ in 0..5 {
1011            breaker.record_failure(host);
1012        }
1013        let before = breaker.get_state(host).unwrap();
1014        assert_eq!(before.state, CircuitStatus::Open);
1015
1016        // Sleep a tick so opened_at can advance even at coarse clock res.
1017        std::thread::sleep(Duration::from_millis(2));
1018        breaker.record_failure(host);
1019
1020        let after = breaker.get_state(host).unwrap();
1021        assert_eq!(after.state, CircuitStatus::Open);
1022        assert!(
1023            after.opened_at.unwrap() >= before.opened_at.unwrap(),
1024            "opened_at should be refreshed or unchanged, not regressed"
1025        );
1026        assert_eq!(after.current_backoff, before.current_backoff);
1027    }
1028
1029    // ---- enabled = false short-circuits ----
1030
1031    #[test]
1032    fn disabled_breaker_never_blocks_and_never_records_state() {
1033        let breaker = breaker_with(CircuitBreakerConfig {
1034            enabled: false,
1035            ..Default::default()
1036        });
1037        let host = "https://noop.example.com";
1038
1039        for _ in 0..100 {
1040            breaker.record_failure(host);
1041        }
1042        // Nothing should have been stored — the early-return in
1043        // record_failure must skip even creating a state entry.
1044        assert!(breaker.get_state(host).is_none());
1045        assert!(breaker.should_allow(host).is_ok());
1046
1047        // record_success is similarly a no-op.
1048        breaker.record_success(host);
1049        assert!(breaker.get_state(host).is_none());
1050    }
1051
1052    // ---- reset / reset_all ----
1053
1054    #[test]
1055    fn reset_all_clears_state_for_every_host() {
1056        let breaker = CircuitBreakerClient::with_defaults(reqwest::Client::new());
1057        breaker.record_failure("https://a.example.com");
1058        breaker.record_failure("https://b.example.com");
1059        breaker.record_failure("https://c.example.com");
1060        assert!(breaker.get_state("https://a.example.com").is_some());
1061
1062        breaker.reset_all();
1063        assert!(breaker.get_state("https://a.example.com").is_none());
1064        assert!(breaker.get_state("https://b.example.com").is_none());
1065        assert!(breaker.get_state("https://c.example.com").is_none());
1066    }
1067
1068    // ---- should_allow with expired timeout ----
1069
1070    #[test]
1071    fn should_allow_returns_ok_when_open_timeout_has_elapsed() {
1072        // Stuff an already-expired opened_at into the map and verify
1073        // should_allow lets the request through to drive the transition.
1074        let breaker = breaker_with(CircuitBreakerConfig {
1075            failure_threshold: 1,
1076            base_timeout: Duration::from_millis(10),
1077            ..Default::default()
1078        });
1079        let host = "https://ready.example.com";
1080        breaker.record_failure(host);
1081        // Force an opened_at well in the past so elapsed >= current_backoff.
1082        {
1083            let mut states = breaker.states.write().unwrap();
1084            let s = states.get_mut(host).unwrap();
1085            s.opened_at = Some(Instant::now() - Duration::from_secs(3600));
1086            s.current_backoff = Duration::from_millis(10);
1087        }
1088        assert!(
1089            breaker.should_allow(host).is_ok(),
1090            "expired open circuit must allow the next request through"
1091        );
1092    }
1093
1094    #[test]
1095    fn should_allow_reports_retry_after_when_open_and_within_backoff() {
1096        let breaker = breaker_with(CircuitBreakerConfig {
1097            failure_threshold: 1,
1098            base_timeout: Duration::from_secs(60),
1099            ..Default::default()
1100        });
1101        let host = "https://hot.example.com";
1102        breaker.record_failure(host);
1103
1104        let err = breaker.should_allow(host).expect_err("still open");
1105        assert_eq!(err.host, host);
1106        // retry_after must be > 0 and <= current_backoff.
1107        let backoff = breaker.get_state(host).unwrap().current_backoff;
1108        assert!(err.retry_after > Duration::ZERO);
1109        assert!(err.retry_after <= backoff);
1110    }
1111
1112    // ---- extract_host edges ----
1113
1114    #[test]
1115    fn extract_host_handles_default_ports_and_no_port() {
1116        // Default 443 / 80 are elided by the URL parser, so they should not
1117        // appear in the extracted host string.
1118        assert_eq!(
1119            CircuitBreakerClient::extract_host(&url("https://api.example.com/")),
1120            "https://api.example.com"
1121        );
1122        assert_eq!(
1123            CircuitBreakerClient::extract_host(&url("http://api.example.com/")),
1124            "http://api.example.com"
1125        );
1126        // Non-default port appears.
1127        assert_eq!(
1128            CircuitBreakerClient::extract_host(&url("https://api.example.com:8443/")),
1129            "https://api.example.com:8443"
1130        );
1131    }
1132
1133    #[test]
1134    fn extract_host_includes_ipv6_brackets() {
1135        // host_str() returns the bare IPv6 without brackets — the formatter
1136        // produces a host string that round-trips through later URL parsing.
1137        let h = CircuitBreakerClient::extract_host(&url("http://[::1]:8080/"));
1138        assert!(h.contains("::1"), "got: {h}");
1139        assert!(h.ends_with(":8080"), "got: {h}");
1140    }
1141
1142    // ---- error Display / source ----
1143
1144    #[test]
1145    fn circuit_breaker_open_display_mentions_host_and_retry_after() {
1146        let err = CircuitBreakerOpen {
1147            host: "https://flaky.example.com".to_string(),
1148            retry_after: Duration::from_secs(42),
1149        };
1150        let s = err.to_string();
1151        assert!(s.contains("https://flaky.example.com"));
1152        assert!(s.contains("42"));
1153    }
1154
1155    #[test]
1156    fn private_host_blocked_display_redacts_host() {
1157        let err = CircuitBreakerError::PrivateHostBlocked("127.0.0.1".to_string());
1158        let s = err.to_string();
1159        assert!(
1160            !s.contains("127.0.0.1"),
1161            "host must not leak through Display"
1162        );
1163        assert!(s.contains("private IP"));
1164    }
1165
1166    #[test]
1167    fn circuit_breaker_error_source_chains_through_inner_variants() {
1168        // CircuitOpen wraps CircuitBreakerOpen — must surface as source.
1169        let inner = CircuitBreakerOpen {
1170            host: "h".to_string(),
1171            retry_after: Duration::from_secs(1),
1172        };
1173        let err = CircuitBreakerError::CircuitOpen(inner);
1174        assert!(
1175            std::error::Error::source(&err).is_some(),
1176            "CircuitOpen should expose its wrapped error as source"
1177        );
1178
1179        // PrivateHostBlocked has no upstream cause.
1180        let err = CircuitBreakerError::PrivateHostBlocked("h".to_string());
1181        assert!(
1182            std::error::Error::source(&err).is_none(),
1183            "PrivateHostBlocked has no source"
1184        );
1185    }
1186
1187    // ---- HttpClient defaults ----
1188
1189    #[test]
1190    fn http_client_apply_default_timeout_is_noop_when_default_unset() {
1191        let breaker = CircuitBreakerClient::with_defaults(reqwest::Client::new());
1192        let client = breaker.with_timeout(None);
1193        let mut req = reqwest::Request::new(Method::GET, url("https://example.com/"));
1194        client.apply_default_timeout(&mut req);
1195        assert_eq!(req.timeout(), None);
1196    }
1197
1198    #[test]
1199    fn http_client_accessors_expose_underlying_pieces() {
1200        let breaker = CircuitBreakerClient::with_defaults(reqwest::Client::new());
1201        let client = breaker.with_timeout(Some(Duration::from_secs(7)));
1202        assert_eq!(client.default_timeout(), Some(Duration::from_secs(7)));
1203        // inner() and circuit_breaker() exist as load-bearing public API;
1204        // calling them confirms they don't panic and return a usable handle.
1205        let _ = client.inner();
1206        let _ = client.circuit_breaker();
1207    }
1208}