Skip to main content

forge_core/http/
mod.rs

//! HTTP client with circuit breaker pattern.
//!
//! Wraps `reqwest::Client` with automatic failure tracking per host.
//! After repeated failures, requests fail fast to prevent cascading failures.
5
6use std::collections::HashMap;
7use std::sync::RwLock;
8use std::time::{Duration, Instant};
9
10use reqwest::{Request, Response};
11
/// Circuit breaker state for a single host.
///
/// Bundles the current [`CircuitStatus`] with the counters and timestamps
/// that decide when the circuit trips, when it may be probed again, and
/// when it closes.
#[derive(Debug, Clone)]
pub struct CircuitState {
    /// Current state of the circuit.
    pub state: CircuitStatus,
    /// Number of consecutive failures.
    pub failure_count: u32,
    /// Number of consecutive successes (used in half-open state).
    pub success_count: u32,
    /// When the circuit was opened (for timeout calculation).
    ///
    /// `None` in a freshly created state.
    pub opened_at: Option<Instant>,
    /// Current backoff duration, i.e. how long the circuit stays open
    /// before a probe request is let through.
    pub current_backoff: Duration,
}

/// Circuit breaker status.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitStatus {
    /// Normal operation, requests pass through.
    Closed,
    /// Circuit tripped, requests fail fast.
    Open,
    /// Testing if service recovered, limited requests allowed.
    HalfOpen,
}

impl Default for CircuitState {
    /// A healthy circuit: closed, zeroed counters, never opened, and a
    /// 30 second backoff.
    fn default() -> Self {
        Self {
            state: CircuitStatus::Closed,
            failure_count: 0,
            success_count: 0,
            opened_at: None,
            current_backoff: Duration::from_secs(30),
        }
    }
}
49
/// Configuration for the circuit breaker.
///
/// All fields are public so callers can start from [`Default::default`]
/// and override individual settings with struct-update syntax.
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Number of failures before opening the circuit.
    pub failure_threshold: u32,
    /// Number of successes in half-open state before closing.
    pub success_threshold: u32,
    /// Initial timeout before trying half-open.
    pub base_timeout: Duration,
    /// Maximum backoff duration.
    pub max_backoff: Duration,
    /// Backoff multiplier for exponential backoff.
    pub backoff_multiplier: f64,
    /// Whether the circuit breaker is enabled.
    pub enabled: bool,
}

impl Default for CircuitBreakerConfig {
    /// Enabled breaker: trips after 5 failures, closes after 2 successes,
    /// waits 30 s initially, grows the wait by 1.5x, capped at 10 minutes.
    fn default() -> Self {
        Self {
            failure_threshold: 5,
            success_threshold: 2,
            base_timeout: Duration::from_secs(30),
            max_backoff: Duration::from_secs(600), // 10 minutes
            backoff_multiplier: 1.5,
            enabled: true,
        }
    }
}
79
/// Error returned when circuit breaker is open.
///
/// Carries enough information for the caller to decide when to retry.
#[derive(Debug, Clone)]
pub struct CircuitBreakerOpen {
    /// The host that is being blocked.
    pub host: String,
    /// Time until the circuit may try again.
    pub retry_after: Duration,
}

impl std::fmt::Display for CircuitBreakerOpen {
    /// Formats as `Circuit breaker open for <host>: retry after <duration>`,
    /// with the duration rendered through its `Debug` form (e.g. `5s`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Circuit breaker open for {}: retry after {:?}",
            self.host, self.retry_after
        )
    }
}

impl std::error::Error for CircuitBreakerOpen {}
100
101/// HTTP client with circuit breaker pattern.
102///
103/// Tracks failure rates per host and fails fast when a host is unhealthy.
104#[derive(Clone)]
105pub struct CircuitBreakerClient {
106    inner: reqwest::Client,
107    states: std::sync::Arc<RwLock<HashMap<String, CircuitState>>>,
108    config: CircuitBreakerConfig,
109}
110
111impl CircuitBreakerClient {
112    /// Create a new circuit breaker client wrapping the given reqwest client.
113    pub fn new(client: reqwest::Client, config: CircuitBreakerConfig) -> Self {
114        Self {
115            inner: client,
116            states: std::sync::Arc::new(RwLock::new(HashMap::new())),
117            config,
118        }
119    }
120
121    /// Create with default configuration.
122    pub fn with_defaults(client: reqwest::Client) -> Self {
123        Self::new(client, CircuitBreakerConfig::default())
124    }
125
126    /// Get the underlying reqwest client for building requests.
127    pub fn inner(&self) -> &reqwest::Client {
128        &self.inner
129    }
130
131    /// Extract host from URL for tracking.
132    fn extract_host(url: &reqwest::Url) -> String {
133        format!(
134            "{}://{}{}",
135            url.scheme(),
136            url.host_str().unwrap_or("unknown"),
137            url.port().map(|p| format!(":{}", p)).unwrap_or_default()
138        )
139    }
140
141    /// Check if a request to the given host should be allowed.
142    pub fn should_allow(&self, host: &str) -> Result<(), CircuitBreakerOpen> {
143        if !self.config.enabled {
144            return Ok(());
145        }
146
147        let states = self.states.read().unwrap();
148        let state = match states.get(host) {
149            Some(s) => s,
150            None => return Ok(()), // No state = first request, allow
151        };
152
153        match state.state {
154            CircuitStatus::Closed => Ok(()),
155            CircuitStatus::HalfOpen => Ok(()), // Allow test requests
156            CircuitStatus::Open => {
157                let opened_at = state.opened_at.unwrap_or_else(Instant::now);
158                let elapsed = opened_at.elapsed();
159
160                if elapsed >= state.current_backoff {
161                    // Timeout expired, will transition to half-open
162                    Ok(())
163                } else {
164                    Err(CircuitBreakerOpen {
165                        host: host.to_string(),
166                        retry_after: state.current_backoff - elapsed,
167                    })
168                }
169            }
170        }
171    }
172
173    /// Record a successful request.
174    pub fn record_success(&self, host: &str) {
175        if !self.config.enabled {
176            return;
177        }
178
179        let mut states = self.states.write().unwrap();
180        let state = states.entry(host.to_string()).or_default();
181
182        match state.state {
183            CircuitStatus::Closed => {
184                // Reset failure count on success
185                state.failure_count = 0;
186            }
187            CircuitStatus::HalfOpen => {
188                state.success_count += 1;
189                if state.success_count >= self.config.success_threshold {
190                    // Service recovered, close the circuit
191                    tracing::info!(host = %host, "Circuit breaker closed, service recovered");
192                    state.state = CircuitStatus::Closed;
193                    state.failure_count = 0;
194                    state.success_count = 0;
195                    state.opened_at = None;
196                    state.current_backoff = self.config.base_timeout;
197                }
198            }
199            CircuitStatus::Open => {
200                // Transition to half-open on first success after timeout
201                tracing::info!(host = %host, "Circuit breaker half-open, testing service");
202                state.state = CircuitStatus::HalfOpen;
203                state.success_count = 1;
204            }
205        }
206    }
207
208    /// Record a failed request.
209    pub fn record_failure(&self, host: &str) {
210        if !self.config.enabled {
211            return;
212        }
213
214        let mut states = self.states.write().unwrap();
215        let state = states.entry(host.to_string()).or_default();
216
217        match state.state {
218            CircuitStatus::Closed => {
219                state.failure_count += 1;
220                if state.failure_count >= self.config.failure_threshold {
221                    // Trip the circuit
222                    tracing::warn!(
223                        host = %host,
224                        failures = state.failure_count,
225                        "Circuit breaker opened, service unhealthy"
226                    );
227                    state.state = CircuitStatus::Open;
228                    state.opened_at = Some(Instant::now());
229                }
230            }
231            CircuitStatus::HalfOpen => {
232                // Failed during test, reopen with increased backoff
233                let new_backoff = Duration::from_secs_f64(
234                    (state.current_backoff.as_secs_f64() * self.config.backoff_multiplier)
235                        .min(self.config.max_backoff.as_secs_f64()),
236                );
237                tracing::warn!(
238                    host = %host,
239                    backoff_secs = new_backoff.as_secs(),
240                    "Circuit breaker reopened, service still unhealthy"
241                );
242                state.state = CircuitStatus::Open;
243                state.opened_at = Some(Instant::now());
244                state.current_backoff = new_backoff;
245                state.success_count = 0;
246            }
247            CircuitStatus::Open => {
248                // Already open, just update timestamp
249                state.opened_at = Some(Instant::now());
250            }
251        }
252    }
253
254    /// Execute a request with circuit breaker protection.
255    pub async fn execute(&self, request: Request) -> Result<Response, CircuitBreakerError> {
256        let host = Self::extract_host(request.url());
257
258        // Check circuit state
259        self.should_allow(&host)
260            .map_err(CircuitBreakerError::CircuitOpen)?;
261
262        // If circuit is open but timeout expired, transition to half-open
263        {
264            let mut states = self.states.write().unwrap();
265            if let Some(state) = states.get_mut(&host) {
266                if state.state == CircuitStatus::Open {
267                    if let Some(opened_at) = state.opened_at {
268                        if opened_at.elapsed() >= state.current_backoff {
269                            tracing::info!(host = %host, "Circuit breaker half-open, testing service");
270                            state.state = CircuitStatus::HalfOpen;
271                            state.success_count = 0;
272                        }
273                    }
274                }
275            }
276        }
277
278        // Execute the request
279        match self.inner.execute(request).await {
280            Ok(response) => {
281                // Check if response indicates server error
282                if response.status().is_server_error() {
283                    self.record_failure(&host);
284                } else {
285                    self.record_success(&host);
286                }
287                Ok(response)
288            }
289            Err(e) => {
290                self.record_failure(&host);
291                Err(CircuitBreakerError::Request(e))
292            }
293        }
294    }
295
296    /// Get the current state for a host.
297    pub fn get_state(&self, host: &str) -> Option<CircuitState> {
298        self.states.read().unwrap().get(host).cloned()
299    }
300
301    /// Reset the circuit breaker state for a host.
302    pub fn reset(&self, host: &str) {
303        self.states.write().unwrap().remove(host);
304    }
305
306    /// Reset all circuit breaker states.
307    pub fn reset_all(&self) {
308        self.states.write().unwrap().clear();
309    }
310}
311
312/// Error type for circuit breaker operations.
313#[derive(Debug)]
314pub enum CircuitBreakerError {
315    /// The circuit is open, request was not attempted.
316    CircuitOpen(CircuitBreakerOpen),
317    /// The request failed.
318    Request(reqwest::Error),
319}
320
321impl std::fmt::Display for CircuitBreakerError {
322    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
323        match self {
324            CircuitBreakerError::CircuitOpen(e) => write!(f, "{}", e),
325            CircuitBreakerError::Request(e) => write!(f, "HTTP request failed: {}", e),
326        }
327    }
328}
329
330impl std::error::Error for CircuitBreakerError {
331    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
332        match self {
333            CircuitBreakerError::CircuitOpen(e) => Some(e),
334            CircuitBreakerError::Request(e) => Some(e),
335        }
336    }
337}
338
339impl From<reqwest::Error> for CircuitBreakerError {
340    fn from(e: reqwest::Error) -> Self {
341        CircuitBreakerError::Request(e)
342    }
343}
344
#[cfg(test)]
mod tests {
    use super::*;

    /// Host key used by the state-machine tests.
    const HOST: &str = "https://api.example.com";

    /// Breaker with default settings (5 failures to open, 2 successes to close).
    fn default_breaker() -> CircuitBreakerClient {
        CircuitBreakerClient::with_defaults(reqwest::Client::new())
    }

    /// Record enough failures to open the circuit under the default config.
    fn trip(breaker: &CircuitBreakerClient, host: &str) {
        for _ in 0..5 {
            breaker.record_failure(host);
        }
    }

    #[test]
    fn test_circuit_breaker_defaults() {
        let config = CircuitBreakerConfig::default();
        assert_eq!(config.failure_threshold, 5);
        assert_eq!(config.success_threshold, 2);
        assert!(config.enabled);
    }

    #[test]
    fn test_circuit_state_transitions() {
        let breaker = default_breaker();

        // Initial state should allow
        assert!(breaker.should_allow(HOST).is_ok());

        // Record failures to trip the circuit
        trip(&breaker, HOST);

        // Circuit should be open
        let state = breaker.get_state(HOST).unwrap();
        assert_eq!(state.state, CircuitStatus::Open);

        // Should be blocked
        assert!(breaker.should_allow(HOST).is_err());

        // Reset and verify
        breaker.reset(HOST);
        assert!(breaker.should_allow(HOST).is_ok());
    }

    #[test]
    fn test_recovery_through_half_open() {
        let breaker = default_breaker();
        trip(&breaker, HOST);

        // A success while open moves the circuit to half-open.
        breaker.record_success(HOST);
        let state = breaker.get_state(HOST).unwrap();
        assert_eq!(state.state, CircuitStatus::HalfOpen);
        assert_eq!(state.success_count, 1);

        // The second success reaches the threshold and closes the circuit.
        breaker.record_success(HOST);
        let state = breaker.get_state(HOST).unwrap();
        assert_eq!(state.state, CircuitStatus::Closed);
        assert_eq!(state.failure_count, 0);
        assert_eq!(state.success_count, 0);
        assert!(state.opened_at.is_none());
        assert!(breaker.should_allow(HOST).is_ok());
    }

    #[test]
    fn test_half_open_failure_increases_backoff() {
        let breaker = default_breaker();
        trip(&breaker, HOST);
        breaker.record_success(HOST); // -> half-open

        breaker.record_failure(HOST);
        let state = breaker.get_state(HOST).unwrap();
        assert_eq!(state.state, CircuitStatus::Open);
        assert_eq!(state.success_count, 0);
        // 30s base * 1.5 multiplier
        assert_eq!(state.current_backoff, Duration::from_secs(45));
        assert!(breaker.should_allow(HOST).is_err());
    }

    #[test]
    fn test_disabled_breaker_tracks_nothing() {
        let config = CircuitBreakerConfig {
            enabled: false,
            ..CircuitBreakerConfig::default()
        };
        let breaker = CircuitBreakerClient::new(reqwest::Client::new(), config);

        for _ in 0..10 {
            breaker.record_failure(HOST);
        }

        assert!(breaker.should_allow(HOST).is_ok());
        assert!(breaker.get_state(HOST).is_none());
    }

    #[test]
    fn test_reset_all_clears_every_host() {
        let breaker = default_breaker();
        let other = "http://localhost";
        trip(&breaker, HOST);
        trip(&breaker, other);

        breaker.reset_all();

        assert!(breaker.get_state(HOST).is_none());
        assert!(breaker.get_state(other).is_none());
    }

    #[test]
    fn test_extract_host() {
        let url = reqwest::Url::parse("https://api.example.com:8080/path").unwrap();
        assert_eq!(
            CircuitBreakerClient::extract_host(&url),
            "https://api.example.com:8080"
        );

        let url2 = reqwest::Url::parse("http://localhost/api").unwrap();
        assert_eq!(
            CircuitBreakerClient::extract_host(&url2),
            "http://localhost"
        );
    }
}