1use std::collections::HashMap;
7use std::sync::RwLock;
8use std::time::{Duration, Instant};
9
10use reqwest::{IntoUrl, Method, Request, RequestBuilder, Response};
11
12#[derive(Debug, Clone)]
14pub struct CircuitState {
15 pub state: CircuitStatus,
17 pub failure_count: u32,
19 pub success_count: u32,
21 pub opened_at: Option<Instant>,
23 pub current_backoff: Duration,
25}
26
/// The three phases a per-host circuit can be in.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitStatus {
    /// Requests flow normally; failures are being counted.
    Closed,
    /// Requests are rejected until the current backoff has elapsed.
    Open,
    /// Requests are let through to test whether the host has recovered.
    HalfOpen,
}
37
38impl Default for CircuitState {
39 fn default() -> Self {
40 Self {
41 state: CircuitStatus::Closed,
42 failure_count: 0,
43 success_count: 0,
44 opened_at: None,
45 current_backoff: Duration::from_secs(30),
46 }
47 }
48}
49
/// Tunables controlling when a per-host circuit opens, probes and closes.
#[derive(Clone, Debug)]
pub struct CircuitBreakerConfig {
    /// Number of failures in the closed state at which the circuit opens.
    pub failure_threshold: u32,
    /// Number of successes in the half-open state at which the circuit closes.
    pub success_threshold: u32,
    /// Backoff a host's circuit is reset to once it closes again.
    pub base_timeout: Duration,
    /// Upper bound applied to the backoff after it has been multiplied.
    pub max_backoff: Duration,
    /// Factor the current backoff is multiplied by when a half-open probe fails.
    pub backoff_multiplier: f64,
    /// When `false`, requests are always allowed and outcomes are not recorded.
    pub enabled: bool,
}
66
67impl Default for CircuitBreakerConfig {
68 fn default() -> Self {
69 Self {
70 failure_threshold: 5,
71 success_threshold: 2,
72 base_timeout: Duration::from_secs(30),
73 max_backoff: Duration::from_secs(600), backoff_multiplier: 1.5,
75 enabled: true,
76 }
77 }
78}
79
/// Rejection returned when a request is refused because the host's circuit is open.
#[derive(Clone, Debug)]
pub struct CircuitBreakerOpen {
    /// Host key whose circuit is open.
    pub host: String,
    /// Time remaining before the circuit will let a request through again.
    pub retry_after: Duration,
}
88
89impl std::fmt::Display for CircuitBreakerOpen {
90 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
91 write!(
92 f,
93 "Circuit breaker open for {}: retry after {:?}",
94 self.host, self.retry_after
95 )
96 }
97}
98
99impl std::error::Error for CircuitBreakerOpen {}
100
101#[derive(Clone)]
105pub struct CircuitBreakerClient {
106 inner: reqwest::Client,
107 states: std::sync::Arc<RwLock<HashMap<String, CircuitState>>>,
108 config: CircuitBreakerConfig,
109}
110
111impl CircuitBreakerClient {
112 pub fn new(client: reqwest::Client, config: CircuitBreakerConfig) -> Self {
114 Self {
115 inner: client,
116 states: std::sync::Arc::new(RwLock::new(HashMap::new())),
117 config,
118 }
119 }
120
121 pub fn with_defaults(client: reqwest::Client) -> Self {
123 Self::new(client, CircuitBreakerConfig::default())
124 }
125
126 pub fn inner(&self) -> &reqwest::Client {
128 &self.inner
129 }
130
131 pub fn with_timeout(&self, timeout: Option<Duration>) -> HttpClient {
133 HttpClient::new(self.clone(), timeout)
134 }
135
136 fn extract_host(url: &reqwest::Url) -> String {
138 format!(
139 "{}://{}{}",
140 url.scheme(),
141 url.host_str().unwrap_or("unknown"),
142 url.port().map(|p| format!(":{}", p)).unwrap_or_default()
143 )
144 }
145
146 pub fn should_allow(&self, host: &str) -> Result<(), CircuitBreakerOpen> {
148 if !self.config.enabled {
149 return Ok(());
150 }
151
152 let states = self.states.read().unwrap_or_else(|e| {
153 tracing::error!("Circuit breaker lock was poisoned, recovering");
154 e.into_inner()
155 });
156 let state = match states.get(host) {
157 Some(s) => s,
158 None => return Ok(()), };
160
161 match state.state {
162 CircuitStatus::Closed => Ok(()),
163 CircuitStatus::HalfOpen => Ok(()), CircuitStatus::Open => {
165 let opened_at = state.opened_at.unwrap_or_else(Instant::now);
166 let elapsed = opened_at.elapsed();
167
168 if elapsed >= state.current_backoff {
169 Ok(())
171 } else {
172 Err(CircuitBreakerOpen {
173 host: host.to_string(),
174 retry_after: state.current_backoff - elapsed,
175 })
176 }
177 }
178 }
179 }
180
181 pub fn record_success(&self, host: &str) {
183 if !self.config.enabled {
184 return;
185 }
186
187 let mut states = self.states.write().unwrap_or_else(|e| {
188 tracing::error!("Circuit breaker lock was poisoned, recovering");
189 e.into_inner()
190 });
191 let state = states.entry(host.to_string()).or_default();
192
193 match state.state {
194 CircuitStatus::Closed => {
195 state.failure_count = 0;
197 }
198 CircuitStatus::HalfOpen => {
199 state.success_count += 1;
200 if state.success_count >= self.config.success_threshold {
201 tracing::info!(host = %host, "Circuit breaker closed, service recovered");
203 state.state = CircuitStatus::Closed;
204 state.failure_count = 0;
205 state.success_count = 0;
206 state.opened_at = None;
207 state.current_backoff = self.config.base_timeout;
208 }
209 }
210 CircuitStatus::Open => {
211 tracing::info!(host = %host, "Circuit breaker half-open, testing service");
213 state.state = CircuitStatus::HalfOpen;
214 state.success_count = 1;
215 }
216 }
217 }
218
219 pub fn record_failure(&self, host: &str) {
221 if !self.config.enabled {
222 return;
223 }
224
225 let mut states = self.states.write().unwrap_or_else(|e| {
226 tracing::error!("Circuit breaker lock was poisoned, recovering");
227 e.into_inner()
228 });
229 let state = states.entry(host.to_string()).or_default();
230
231 match state.state {
232 CircuitStatus::Closed => {
233 state.failure_count += 1;
234 if state.failure_count >= self.config.failure_threshold {
235 tracing::warn!(
237 host = %host,
238 failures = state.failure_count,
239 "Circuit breaker opened, service unhealthy"
240 );
241 state.state = CircuitStatus::Open;
242 state.opened_at = Some(Instant::now());
243 }
244 }
245 CircuitStatus::HalfOpen => {
246 let new_backoff = Duration::from_secs_f64(
248 (state.current_backoff.as_secs_f64() * self.config.backoff_multiplier)
249 .min(self.config.max_backoff.as_secs_f64()),
250 );
251 tracing::warn!(
252 host = %host,
253 backoff_secs = new_backoff.as_secs(),
254 "Circuit breaker reopened, service still unhealthy"
255 );
256 state.state = CircuitStatus::Open;
257 state.opened_at = Some(Instant::now());
258 state.current_backoff = new_backoff;
259 state.success_count = 0;
260 }
261 CircuitStatus::Open => {
262 state.opened_at = Some(Instant::now());
264 }
265 }
266 }
267
268 pub async fn execute(&self, request: Request) -> Result<Response, CircuitBreakerError> {
270 let host = Self::extract_host(request.url());
271
272 self.should_allow(&host)
274 .map_err(CircuitBreakerError::CircuitOpen)?;
275
276 {
278 let mut states = self.states.write().unwrap_or_else(|e| {
279 tracing::error!("Circuit breaker lock was poisoned, recovering");
280 e.into_inner()
281 });
282 if let Some(state) = states.get_mut(&host)
283 && state.state == CircuitStatus::Open
284 && let Some(opened_at) = state.opened_at
285 && opened_at.elapsed() >= state.current_backoff
286 {
287 tracing::info!(host = %host, "Circuit breaker half-open, testing service");
288 state.state = CircuitStatus::HalfOpen;
289 state.success_count = 0;
290 }
291 }
292
293 match self.inner.execute(request).await {
295 Ok(response) => {
296 if response.status().is_server_error() {
298 self.record_failure(&host);
299 } else {
300 self.record_success(&host);
301 }
302 Ok(response)
303 }
304 Err(e) => {
305 self.record_failure(&host);
306 Err(CircuitBreakerError::Request(e))
307 }
308 }
309 }
310
311 pub fn get_state(&self, host: &str) -> Option<CircuitState> {
313 self.states
314 .read()
315 .unwrap_or_else(|e| {
316 tracing::error!("Circuit breaker lock was poisoned, recovering");
317 e.into_inner()
318 })
319 .get(host)
320 .cloned()
321 }
322
323 pub fn reset(&self, host: &str) {
325 self.states
326 .write()
327 .unwrap_or_else(|e| {
328 tracing::error!("Circuit breaker lock was poisoned, recovering");
329 e.into_inner()
330 })
331 .remove(host);
332 }
333
334 pub fn reset_all(&self) {
336 self.states
337 .write()
338 .unwrap_or_else(|e| {
339 tracing::error!("Circuit breaker lock was poisoned, recovering");
340 e.into_inner()
341 })
342 .clear();
343 }
344}
345
346#[derive(Debug)]
348pub enum CircuitBreakerError {
349 CircuitOpen(CircuitBreakerOpen),
351 Request(reqwest::Error),
353}
354
355impl std::fmt::Display for CircuitBreakerError {
356 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
357 match self {
358 CircuitBreakerError::CircuitOpen(e) => write!(f, "{}", e),
359 CircuitBreakerError::Request(e) => write!(f, "HTTP request failed: {}", e),
360 }
361 }
362}
363
364impl std::error::Error for CircuitBreakerError {
365 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
366 match self {
367 CircuitBreakerError::CircuitOpen(e) => Some(e),
368 CircuitBreakerError::Request(e) => Some(e),
369 }
370 }
371}
372
373impl From<reqwest::Error> for CircuitBreakerError {
374 fn from(e: reqwest::Error) -> Self {
375 CircuitBreakerError::Request(e)
376 }
377}
378
379#[derive(Clone)]
382pub struct HttpClient {
383 circuit_breaker: CircuitBreakerClient,
384 default_timeout: Option<Duration>,
385}
386
387impl HttpClient {
388 pub fn new(circuit_breaker: CircuitBreakerClient, default_timeout: Option<Duration>) -> Self {
390 Self {
391 circuit_breaker,
392 default_timeout,
393 }
394 }
395
396 pub fn inner(&self) -> &reqwest::Client {
398 self.circuit_breaker.inner()
399 }
400
401 pub fn circuit_breaker(&self) -> &CircuitBreakerClient {
403 &self.circuit_breaker
404 }
405
406 pub fn default_timeout(&self) -> Option<Duration> {
408 self.default_timeout
409 }
410
411 pub fn request<U: IntoUrl>(&self, method: Method, url: U) -> HttpRequestBuilder {
413 HttpRequestBuilder::new(self.clone(), self.inner().request(method, url))
414 }
415
416 pub fn get<U: IntoUrl>(&self, url: U) -> HttpRequestBuilder {
417 self.request(Method::GET, url)
418 }
419
420 pub fn post<U: IntoUrl>(&self, url: U) -> HttpRequestBuilder {
421 self.request(Method::POST, url)
422 }
423
424 pub fn put<U: IntoUrl>(&self, url: U) -> HttpRequestBuilder {
425 self.request(Method::PUT, url)
426 }
427
428 pub fn patch<U: IntoUrl>(&self, url: U) -> HttpRequestBuilder {
429 self.request(Method::PATCH, url)
430 }
431
432 pub fn delete<U: IntoUrl>(&self, url: U) -> HttpRequestBuilder {
433 self.request(Method::DELETE, url)
434 }
435
436 pub fn head<U: IntoUrl>(&self, url: U) -> HttpRequestBuilder {
437 self.request(Method::HEAD, url)
438 }
439
440 pub async fn execute(&self, mut request: Request) -> crate::Result<Response> {
442 self.apply_default_timeout(&mut request);
443 self.circuit_breaker
444 .execute(request)
445 .await
446 .map_err(Into::into)
447 }
448
449 fn apply_default_timeout(&self, request: &mut Request) {
450 if request.timeout().is_none()
451 && let Some(timeout) = self.default_timeout
452 {
453 *request.timeout_mut() = Some(timeout);
454 }
455 }
456}
457
458pub struct HttpRequestBuilder {
460 client: HttpClient,
461 request: RequestBuilder,
462}
463
464impl HttpRequestBuilder {
465 fn new(client: HttpClient, request: RequestBuilder) -> Self {
466 Self { client, request }
467 }
468
469 pub fn header(self, key: impl AsRef<str>, value: impl AsRef<str>) -> Self {
470 Self {
471 request: self.request.header(key.as_ref(), value.as_ref()),
472 ..self
473 }
474 }
475
476 pub fn headers(self, headers: reqwest::header::HeaderMap) -> Self {
477 Self {
478 request: self.request.headers(headers),
479 ..self
480 }
481 }
482
483 pub fn bearer_auth(self, token: impl std::fmt::Display) -> Self {
484 Self {
485 request: self.request.bearer_auth(token),
486 ..self
487 }
488 }
489
490 pub fn basic_auth(
491 self,
492 username: impl std::fmt::Display,
493 password: Option<impl std::fmt::Display>,
494 ) -> Self {
495 Self {
496 request: self.request.basic_auth(username, password),
497 ..self
498 }
499 }
500
501 pub fn body(self, body: impl Into<reqwest::Body>) -> Self {
502 Self {
503 request: self.request.body(body),
504 ..self
505 }
506 }
507
508 pub fn json(self, json: &impl serde::Serialize) -> Self {
509 Self {
510 request: self.request.json(json),
511 ..self
512 }
513 }
514
515 pub fn form(self, form: &impl serde::Serialize) -> Self {
516 Self {
517 request: self.request.form(form),
518 ..self
519 }
520 }
521
522 pub fn query(self, query: &impl serde::Serialize) -> Self {
523 Self {
524 request: self.request.query(query),
525 ..self
526 }
527 }
528
529 pub fn timeout(self, timeout: Duration) -> Self {
530 Self {
531 request: self.request.timeout(timeout),
532 ..self
533 }
534 }
535
536 pub fn version(self, version: reqwest::Version) -> Self {
537 Self {
538 request: self.request.version(version),
539 ..self
540 }
541 }
542
543 pub fn try_clone(&self) -> Option<Self> {
544 self.request.try_clone().map(|request| Self {
545 client: self.client.clone(),
546 request,
547 })
548 }
549
550 pub fn build(self) -> crate::Result<Request> {
551 self.request
552 .build()
553 .map_err(|e| crate::ForgeError::Internal(e.to_string()))
554 }
555
556 pub async fn send(self) -> crate::Result<Response> {
557 let client = self.client.clone();
558 let request = self.build()?;
559 client.execute(request).await
560 }
561}
562
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    /// Breaker with default settings around a fresh `reqwest` client.
    fn default_breaker() -> CircuitBreakerClient {
        CircuitBreakerClient::with_defaults(reqwest::Client::new())
    }

    /// Bare `GET https://example.com` request with no timeout set.
    fn example_request() -> Request {
        let url = reqwest::Url::parse("https://example.com").unwrap();
        reqwest::Request::new(Method::GET, url)
    }

    #[test]
    fn test_circuit_breaker_defaults() {
        let CircuitBreakerConfig {
            failure_threshold,
            success_threshold,
            enabled,
            ..
        } = CircuitBreakerConfig::default();

        assert_eq!(failure_threshold, 5);
        assert_eq!(success_threshold, 2);
        assert!(enabled);
    }

    #[test]
    fn test_circuit_state_transitions() {
        let breaker = default_breaker();
        let host = "https://api.example.com";

        // An unknown host is always allowed.
        assert!(breaker.should_allow(host).is_ok());

        // Five failures reach the default threshold and open the circuit.
        (0..5).for_each(|_| breaker.record_failure(host));

        let state = breaker.get_state(host).unwrap();
        assert_eq!(state.state, CircuitStatus::Open);

        // While open, requests are rejected.
        assert!(breaker.should_allow(host).is_err());

        // Resetting forgets the host, so requests are allowed again.
        breaker.reset(host);
        assert!(breaker.should_allow(host).is_ok());
    }

    #[test]
    fn test_extract_host() {
        let cases = [
            (
                "https://api.example.com:8080/path",
                "https://api.example.com:8080",
            ),
            ("http://localhost/api", "http://localhost"),
        ];

        for (input, expected) in cases {
            let url = reqwest::Url::parse(input).unwrap();
            assert_eq!(CircuitBreakerClient::extract_host(&url), expected);
        }
    }

    #[test]
    fn test_http_client_applies_default_timeout_when_missing() {
        let client = default_breaker().with_timeout(Some(Duration::from_secs(5)));
        let mut request = example_request();

        client.apply_default_timeout(&mut request);

        assert_eq!(request.timeout(), Some(&Duration::from_secs(5)));
    }

    #[test]
    fn test_http_client_preserves_explicit_timeout() {
        let client = default_breaker().with_timeout(Some(Duration::from_secs(5)));
        let mut request = example_request();
        *request.timeout_mut() = Some(Duration::from_secs(1));

        client.apply_default_timeout(&mut request);

        assert_eq!(request.timeout(), Some(&Duration::from_secs(1)));
    }
}