Skip to main content

forge_core/http/
mod.rs

1//! HTTP client with circuit breaker pattern.
2//!
3//! Wraps `reqwest::Client` with automatic failure tracking per host.
4//! After repeated failures, requests fail fast to prevent cascade failures.
5
6use std::collections::HashMap;
7use std::sync::RwLock;
8use std::time::{Duration, Instant};
9
10use reqwest::{IntoUrl, Method, Request, RequestBuilder, Response};
11
/// Per-host bookkeeping kept by the circuit breaker.
#[derive(Debug, Clone)]
pub struct CircuitState {
    /// Which phase of the breaker lifecycle this host is currently in.
    pub state: CircuitStatus,
    /// Consecutive failures counted so far.
    pub failure_count: u32,
    /// Consecutive successes counted while the circuit is half-open.
    pub success_count: u32,
    /// Instant at which the circuit was last opened, if it has been.
    pub opened_at: Option<Instant>,
    /// Backoff duration currently in effect for this host.
    pub current_backoff: Duration,
}

/// The phases a circuit can be in.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitStatus {
    /// Healthy: requests are passed through.
    Closed,
    /// Tripped: requests are rejected without being attempted.
    Open,
    /// Probing: requests are let through to test whether the service recovered.
    HalfOpen,
}

impl Default for CircuitState {
    /// A closed circuit with zeroed counters, no open timestamp and a
    /// 30-second backoff.
    fn default() -> Self {
        let initial_backoff = Duration::from_secs(30);
        CircuitState {
            current_backoff: initial_backoff,
            opened_at: None,
            success_count: 0,
            failure_count: 0,
            state: CircuitStatus::Closed,
        }
    }
}
49
/// Tunable parameters for the circuit breaker.
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Consecutive failures required to open a closed circuit.
    pub failure_threshold: u32,
    /// Successes required in the half-open state before the circuit closes.
    pub success_threshold: u32,
    /// Backoff restored on a circuit once it has recovered.
    pub base_timeout: Duration,
    /// Upper bound applied to the backoff as it grows.
    pub max_backoff: Duration,
    /// Factor the backoff is multiplied by each time a half-open probe fails.
    pub backoff_multiplier: f64,
    /// When `false`, the breaker neither blocks requests nor records outcomes.
    pub enabled: bool,
}

impl Default for CircuitBreakerConfig {
    /// Enabled breaker: 5 failures to open, 2 successes to close, 30-second
    /// base timeout, 10-minute backoff cap and a 1.5x backoff multiplier.
    fn default() -> Self {
        let base_timeout = Duration::from_secs(30);
        let ten_minutes = Duration::from_secs(600);
        CircuitBreakerConfig {
            enabled: true,
            failure_threshold: 5,
            success_threshold: 2,
            base_timeout,
            max_backoff: ten_minutes,
            backoff_multiplier: 1.5,
        }
    }
}
79
/// Rejection produced while a host's circuit is open.
#[derive(Debug, Clone)]
pub struct CircuitBreakerOpen {
    /// Tracking key of the host whose requests are being rejected.
    pub host: String,
    /// Remaining time before the circuit lets a request through again.
    pub retry_after: Duration,
}

impl std::fmt::Display for CircuitBreakerOpen {
    /// Renders the blocked host followed by the debug form of the remaining wait.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { host, retry_after } = self;
        write!(
            f,
            "Circuit breaker open for {}: retry after {:?}",
            host, retry_after
        )
    }
}

// No underlying cause to expose, so the default `Error` methods suffice.
impl std::error::Error for CircuitBreakerOpen {}
100
101/// HTTP client with circuit breaker pattern.
102///
103/// Tracks failure rates per host and fails fast when a host is unhealthy.
104#[derive(Clone)]
105pub struct CircuitBreakerClient {
106    inner: reqwest::Client,
107    states: std::sync::Arc<RwLock<HashMap<String, CircuitState>>>,
108    config: CircuitBreakerConfig,
109}
110
111impl CircuitBreakerClient {
112    /// Create a new circuit breaker client wrapping the given reqwest client.
113    pub fn new(client: reqwest::Client, config: CircuitBreakerConfig) -> Self {
114        Self {
115            inner: client,
116            states: std::sync::Arc::new(RwLock::new(HashMap::new())),
117            config,
118        }
119    }
120
121    /// Create with default configuration.
122    pub fn with_defaults(client: reqwest::Client) -> Self {
123        Self::new(client, CircuitBreakerConfig::default())
124    }
125
126    /// Get the underlying reqwest client for building requests.
127    pub fn inner(&self) -> &reqwest::Client {
128        &self.inner
129    }
130
131    /// Create a request client view with an optional default request timeout.
132    pub fn with_timeout(&self, timeout: Option<Duration>) -> HttpClient {
133        HttpClient::new(self.clone(), timeout)
134    }
135
136    /// Extract host from URL for tracking.
137    fn extract_host(url: &reqwest::Url) -> String {
138        format!(
139            "{}://{}{}",
140            url.scheme(),
141            url.host_str().unwrap_or("unknown"),
142            url.port().map(|p| format!(":{}", p)).unwrap_or_default()
143        )
144    }
145
146    /// Check if a request to the given host should be allowed.
147    pub fn should_allow(&self, host: &str) -> Result<(), CircuitBreakerOpen> {
148        if !self.config.enabled {
149            return Ok(());
150        }
151
152        let states = self.states.read().unwrap_or_else(|e| {
153            tracing::error!("Circuit breaker lock was poisoned, recovering");
154            e.into_inner()
155        });
156        let state = match states.get(host) {
157            Some(s) => s,
158            None => return Ok(()), // No state = first request, allow
159        };
160
161        match state.state {
162            CircuitStatus::Closed => Ok(()),
163            CircuitStatus::HalfOpen => Ok(()), // Allow test requests
164            CircuitStatus::Open => {
165                let opened_at = state.opened_at.unwrap_or_else(Instant::now);
166                let elapsed = opened_at.elapsed();
167
168                if elapsed >= state.current_backoff {
169                    // Timeout expired, will transition to half-open
170                    Ok(())
171                } else {
172                    Err(CircuitBreakerOpen {
173                        host: host.to_string(),
174                        retry_after: state.current_backoff - elapsed,
175                    })
176                }
177            }
178        }
179    }
180
181    /// Record a successful request.
182    pub fn record_success(&self, host: &str) {
183        if !self.config.enabled {
184            return;
185        }
186
187        let mut states = self.states.write().unwrap_or_else(|e| {
188            tracing::error!("Circuit breaker lock was poisoned, recovering");
189            e.into_inner()
190        });
191        let state = states.entry(host.to_string()).or_default();
192
193        match state.state {
194            CircuitStatus::Closed => {
195                // Reset failure count on success
196                state.failure_count = 0;
197            }
198            CircuitStatus::HalfOpen => {
199                state.success_count += 1;
200                if state.success_count >= self.config.success_threshold {
201                    // Service recovered, close the circuit
202                    tracing::info!(host = %host, "Circuit breaker closed, service recovered");
203                    state.state = CircuitStatus::Closed;
204                    state.failure_count = 0;
205                    state.success_count = 0;
206                    state.opened_at = None;
207                    state.current_backoff = self.config.base_timeout;
208                }
209            }
210            CircuitStatus::Open => {
211                // Transition to half-open on first success after timeout
212                tracing::info!(host = %host, "Circuit breaker half-open, testing service");
213                state.state = CircuitStatus::HalfOpen;
214                state.success_count = 1;
215            }
216        }
217    }
218
219    /// Record a failed request.
220    pub fn record_failure(&self, host: &str) {
221        if !self.config.enabled {
222            return;
223        }
224
225        let mut states = self.states.write().unwrap_or_else(|e| {
226            tracing::error!("Circuit breaker lock was poisoned, recovering");
227            e.into_inner()
228        });
229        let state = states.entry(host.to_string()).or_default();
230
231        match state.state {
232            CircuitStatus::Closed => {
233                state.failure_count += 1;
234                if state.failure_count >= self.config.failure_threshold {
235                    // Trip the circuit
236                    tracing::warn!(
237                        host = %host,
238                        failures = state.failure_count,
239                        "Circuit breaker opened, service unhealthy"
240                    );
241                    state.state = CircuitStatus::Open;
242                    state.opened_at = Some(Instant::now());
243                }
244            }
245            CircuitStatus::HalfOpen => {
246                // Failed during test, reopen with increased backoff
247                let new_backoff = Duration::from_secs_f64(
248                    (state.current_backoff.as_secs_f64() * self.config.backoff_multiplier)
249                        .min(self.config.max_backoff.as_secs_f64()),
250                );
251                tracing::warn!(
252                    host = %host,
253                    backoff_secs = new_backoff.as_secs(),
254                    "Circuit breaker reopened, service still unhealthy"
255                );
256                state.state = CircuitStatus::Open;
257                state.opened_at = Some(Instant::now());
258                state.current_backoff = new_backoff;
259                state.success_count = 0;
260            }
261            CircuitStatus::Open => {
262                // Already open, just update timestamp
263                state.opened_at = Some(Instant::now());
264            }
265        }
266    }
267
268    /// Execute a request with circuit breaker protection.
269    pub async fn execute(&self, request: Request) -> Result<Response, CircuitBreakerError> {
270        let host = Self::extract_host(request.url());
271
272        // Check circuit state
273        self.should_allow(&host)
274            .map_err(CircuitBreakerError::CircuitOpen)?;
275
276        // If circuit is open but timeout expired, transition to half-open
277        {
278            let mut states = self.states.write().unwrap_or_else(|e| {
279                tracing::error!("Circuit breaker lock was poisoned, recovering");
280                e.into_inner()
281            });
282            if let Some(state) = states.get_mut(&host)
283                && state.state == CircuitStatus::Open
284                && let Some(opened_at) = state.opened_at
285                && opened_at.elapsed() >= state.current_backoff
286            {
287                tracing::info!(host = %host, "Circuit breaker half-open, testing service");
288                state.state = CircuitStatus::HalfOpen;
289                state.success_count = 0;
290            }
291        }
292
293        // Execute the request
294        match self.inner.execute(request).await {
295            Ok(response) => {
296                // Check if response indicates server error
297                if response.status().is_server_error() {
298                    self.record_failure(&host);
299                } else {
300                    self.record_success(&host);
301                }
302                Ok(response)
303            }
304            Err(e) => {
305                self.record_failure(&host);
306                Err(CircuitBreakerError::Request(e))
307            }
308        }
309    }
310
311    /// Get the current state for a host.
312    pub fn get_state(&self, host: &str) -> Option<CircuitState> {
313        self.states
314            .read()
315            .unwrap_or_else(|e| {
316                tracing::error!("Circuit breaker lock was poisoned, recovering");
317                e.into_inner()
318            })
319            .get(host)
320            .cloned()
321    }
322
323    /// Reset the circuit breaker state for a host.
324    pub fn reset(&self, host: &str) {
325        self.states
326            .write()
327            .unwrap_or_else(|e| {
328                tracing::error!("Circuit breaker lock was poisoned, recovering");
329                e.into_inner()
330            })
331            .remove(host);
332    }
333
334    /// Reset all circuit breaker states.
335    pub fn reset_all(&self) {
336        self.states
337            .write()
338            .unwrap_or_else(|e| {
339                tracing::error!("Circuit breaker lock was poisoned, recovering");
340                e.into_inner()
341            })
342            .clear();
343    }
344}
345
346/// Error type for circuit breaker operations.
347#[derive(Debug)]
348pub enum CircuitBreakerError {
349    /// The circuit is open, request was not attempted.
350    CircuitOpen(CircuitBreakerOpen),
351    /// The request failed.
352    Request(reqwest::Error),
353}
354
355impl std::fmt::Display for CircuitBreakerError {
356    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
357        match self {
358            CircuitBreakerError::CircuitOpen(e) => write!(f, "{}", e),
359            CircuitBreakerError::Request(e) => write!(f, "HTTP request failed: {}", e),
360        }
361    }
362}
363
364impl std::error::Error for CircuitBreakerError {
365    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
366        match self {
367            CircuitBreakerError::CircuitOpen(e) => Some(e),
368            CircuitBreakerError::Request(e) => Some(e),
369        }
370    }
371}
372
373impl From<reqwest::Error> for CircuitBreakerError {
374    fn from(e: reqwest::Error) -> Self {
375        CircuitBreakerError::Request(e)
376    }
377}
378
379/// HTTP client facade that routes requests through a circuit breaker and can
380/// apply a default timeout to requests that do not set one explicitly.
381#[derive(Clone)]
382pub struct HttpClient {
383    circuit_breaker: CircuitBreakerClient,
384    default_timeout: Option<Duration>,
385}
386
387impl HttpClient {
388    /// Create a new HTTP client facade.
389    pub fn new(circuit_breaker: CircuitBreakerClient, default_timeout: Option<Duration>) -> Self {
390        Self {
391            circuit_breaker,
392            default_timeout,
393        }
394    }
395
396    /// Get the underlying reqwest client.
397    pub fn inner(&self) -> &reqwest::Client {
398        self.circuit_breaker.inner()
399    }
400
401    /// Get the underlying circuit breaker client.
402    pub fn circuit_breaker(&self) -> &CircuitBreakerClient {
403        &self.circuit_breaker
404    }
405
406    /// Get the default timeout applied to requests that do not override it.
407    pub fn default_timeout(&self) -> Option<Duration> {
408        self.default_timeout
409    }
410
411    /// Create a request builder.
412    pub fn request<U: IntoUrl>(&self, method: Method, url: U) -> HttpRequestBuilder {
413        HttpRequestBuilder::new(self.clone(), self.inner().request(method, url))
414    }
415
416    pub fn get<U: IntoUrl>(&self, url: U) -> HttpRequestBuilder {
417        self.request(Method::GET, url)
418    }
419
420    pub fn post<U: IntoUrl>(&self, url: U) -> HttpRequestBuilder {
421        self.request(Method::POST, url)
422    }
423
424    pub fn put<U: IntoUrl>(&self, url: U) -> HttpRequestBuilder {
425        self.request(Method::PUT, url)
426    }
427
428    pub fn patch<U: IntoUrl>(&self, url: U) -> HttpRequestBuilder {
429        self.request(Method::PATCH, url)
430    }
431
432    pub fn delete<U: IntoUrl>(&self, url: U) -> HttpRequestBuilder {
433        self.request(Method::DELETE, url)
434    }
435
436    pub fn head<U: IntoUrl>(&self, url: U) -> HttpRequestBuilder {
437        self.request(Method::HEAD, url)
438    }
439
440    /// Execute a pre-built request through the circuit breaker.
441    pub async fn execute(&self, mut request: Request) -> crate::Result<Response> {
442        self.apply_default_timeout(&mut request);
443        self.circuit_breaker
444            .execute(request)
445            .await
446            .map_err(Into::into)
447    }
448
449    fn apply_default_timeout(&self, request: &mut Request) {
450        if request.timeout().is_none()
451            && let Some(timeout) = self.default_timeout
452        {
453            *request.timeout_mut() = Some(timeout);
454        }
455    }
456}
457
458/// Request builder paired with a circuit-breaker-backed HTTP client.
459pub struct HttpRequestBuilder {
460    client: HttpClient,
461    request: RequestBuilder,
462}
463
464impl HttpRequestBuilder {
465    fn new(client: HttpClient, request: RequestBuilder) -> Self {
466        Self { client, request }
467    }
468
469    pub fn header(self, key: impl AsRef<str>, value: impl AsRef<str>) -> Self {
470        Self {
471            request: self.request.header(key.as_ref(), value.as_ref()),
472            ..self
473        }
474    }
475
476    pub fn headers(self, headers: reqwest::header::HeaderMap) -> Self {
477        Self {
478            request: self.request.headers(headers),
479            ..self
480        }
481    }
482
483    pub fn bearer_auth(self, token: impl std::fmt::Display) -> Self {
484        Self {
485            request: self.request.bearer_auth(token),
486            ..self
487        }
488    }
489
490    pub fn basic_auth(
491        self,
492        username: impl std::fmt::Display,
493        password: Option<impl std::fmt::Display>,
494    ) -> Self {
495        Self {
496            request: self.request.basic_auth(username, password),
497            ..self
498        }
499    }
500
501    pub fn body(self, body: impl Into<reqwest::Body>) -> Self {
502        Self {
503            request: self.request.body(body),
504            ..self
505        }
506    }
507
508    pub fn json(self, json: &impl serde::Serialize) -> Self {
509        Self {
510            request: self.request.json(json),
511            ..self
512        }
513    }
514
515    pub fn form(self, form: &impl serde::Serialize) -> Self {
516        Self {
517            request: self.request.form(form),
518            ..self
519        }
520    }
521
522    pub fn query(self, query: &impl serde::Serialize) -> Self {
523        Self {
524            request: self.request.query(query),
525            ..self
526        }
527    }
528
529    pub fn timeout(self, timeout: Duration) -> Self {
530        Self {
531            request: self.request.timeout(timeout),
532            ..self
533        }
534    }
535
536    pub fn version(self, version: reqwest::Version) -> Self {
537        Self {
538            request: self.request.version(version),
539            ..self
540        }
541    }
542
543    pub fn try_clone(&self) -> Option<Self> {
544        self.request.try_clone().map(|request| Self {
545            client: self.client.clone(),
546            request,
547        })
548    }
549
550    pub fn build(self) -> crate::Result<Request> {
551        self.request
552            .build()
553            .map_err(|e| crate::ForgeError::Internal(e.to_string()))
554    }
555
556    pub async fn send(self) -> crate::Result<Response> {
557        let client = self.client.clone();
558        let request = self.build()?;
559        client.execute(request).await
560    }
561}
562
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    /// Breaker with the default configuration around a fresh reqwest client.
    fn default_breaker() -> CircuitBreakerClient {
        CircuitBreakerClient::with_defaults(reqwest::Client::new())
    }

    /// Bare `GET https://example.com` request with no timeout set.
    fn bare_get() -> reqwest::Request {
        let url = reqwest::Url::parse("https://example.com").unwrap();
        reqwest::Request::new(Method::GET, url)
    }

    #[test]
    fn test_circuit_breaker_defaults() {
        let CircuitBreakerConfig {
            failure_threshold,
            success_threshold,
            enabled,
            ..
        } = CircuitBreakerConfig::default();

        assert_eq!(failure_threshold, 5);
        assert_eq!(success_threshold, 2);
        assert!(enabled);
    }

    #[test]
    fn test_circuit_state_transitions() {
        let breaker = default_breaker();
        let host = "https://api.example.com";

        // A host with no history is allowed.
        assert!(breaker.should_allow(host).is_ok());

        // Five failures reach the default threshold and trip the circuit.
        (0..5).for_each(|_| breaker.record_failure(host));

        let tripped = breaker.get_state(host).unwrap();
        assert_eq!(tripped.state, CircuitStatus::Open);

        // While open, requests are rejected.
        assert!(breaker.should_allow(host).is_err());

        // Resetting the host lets requests through again.
        breaker.reset(host);
        assert!(breaker.should_allow(host).is_ok());
    }

    #[test]
    fn test_extract_host() {
        let cases = [
            (
                "https://api.example.com:8080/path",
                "https://api.example.com:8080",
            ),
            ("http://localhost/api", "http://localhost"),
        ];

        for (raw, expected) in cases {
            let url = reqwest::Url::parse(raw).unwrap();
            assert_eq!(CircuitBreakerClient::extract_host(&url), expected);
        }
    }

    #[test]
    fn test_http_client_applies_default_timeout_when_missing() {
        let five_secs = Duration::from_secs(5);
        let client = default_breaker().with_timeout(Some(five_secs));
        let mut request = bare_get();

        client.apply_default_timeout(&mut request);

        assert_eq!(request.timeout(), Some(&five_secs));
    }

    #[test]
    fn test_http_client_preserves_explicit_timeout() {
        let explicit = Duration::from_secs(1);
        let client = default_breaker().with_timeout(Some(Duration::from_secs(5)));
        let mut request = bare_get();
        *request.timeout_mut() = Some(explicit);

        client.apply_default_timeout(&mut request);

        assert_eq!(request.timeout(), Some(&explicit));
    }
}