1use std::collections::HashMap;
7use std::sync::RwLock;
8use std::time::{Duration, Instant};
9
10use reqwest::{Request, Response};
11
12#[derive(Debug, Clone)]
14pub struct CircuitState {
15 pub state: CircuitStatus,
17 pub failure_count: u32,
19 pub success_count: u32,
21 pub opened_at: Option<Instant>,
23 pub current_backoff: Duration,
25}
26
/// Phase of a per-host circuit.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitStatus {
    /// Requests flow normally; failures are being counted.
    Closed,
    /// Requests are rejected until the current backoff has elapsed.
    Open,
    /// Requests are let through as probes to test whether the host recovered.
    HalfOpen,
}
37
38impl Default for CircuitState {
39 fn default() -> Self {
40 Self {
41 state: CircuitStatus::Closed,
42 failure_count: 0,
43 success_count: 0,
44 opened_at: None,
45 current_backoff: Duration::from_secs(30),
46 }
47 }
48}
49
/// Tuning knobs shared by every per-host circuit of a breaker.
#[derive(Clone, Debug)]
pub struct CircuitBreakerConfig {
    /// Failures (without an intervening success) after which a closed circuit opens.
    pub failure_threshold: u32,
    /// Successes required while half-open before the circuit closes again.
    pub success_threshold: u32,
    /// Baseline open-circuit wait; a host's backoff is reset to this value
    /// when its circuit closes.
    pub base_timeout: Duration,
    /// Upper bound for the escalated backoff.
    pub max_backoff: Duration,
    /// Factor applied to a host's backoff each time a half-open probe fails.
    pub backoff_multiplier: f64,
    /// When `false`, permission checks always pass and outcomes are not recorded.
    pub enabled: bool,
}
66
67impl Default for CircuitBreakerConfig {
68 fn default() -> Self {
69 Self {
70 failure_threshold: 5,
71 success_threshold: 2,
72 base_timeout: Duration::from_secs(30),
73 max_backoff: Duration::from_secs(600), backoff_multiplier: 1.5,
75 enabled: true,
76 }
77 }
78}
79
/// Rejection returned when a host's circuit is open and its backoff has not
/// yet elapsed.
#[derive(Clone, Debug)]
pub struct CircuitBreakerOpen {
    /// Key of the host whose circuit is open.
    pub host: String,
    /// Time remaining until a request to this host will be allowed again.
    pub retry_after: Duration,
}
88
89impl std::fmt::Display for CircuitBreakerOpen {
90 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
91 write!(
92 f,
93 "Circuit breaker open for {}: retry after {:?}",
94 self.host, self.retry_after
95 )
96 }
97}
98
99impl std::error::Error for CircuitBreakerOpen {}
100
101#[derive(Clone)]
105pub struct CircuitBreakerClient {
106 inner: reqwest::Client,
107 states: std::sync::Arc<RwLock<HashMap<String, CircuitState>>>,
108 config: CircuitBreakerConfig,
109}
110
111impl CircuitBreakerClient {
112 pub fn new(client: reqwest::Client, config: CircuitBreakerConfig) -> Self {
114 Self {
115 inner: client,
116 states: std::sync::Arc::new(RwLock::new(HashMap::new())),
117 config,
118 }
119 }
120
121 pub fn with_defaults(client: reqwest::Client) -> Self {
123 Self::new(client, CircuitBreakerConfig::default())
124 }
125
126 pub fn inner(&self) -> &reqwest::Client {
128 &self.inner
129 }
130
131 fn extract_host(url: &reqwest::Url) -> String {
133 format!(
134 "{}://{}{}",
135 url.scheme(),
136 url.host_str().unwrap_or("unknown"),
137 url.port().map(|p| format!(":{}", p)).unwrap_or_default()
138 )
139 }
140
141 pub fn should_allow(&self, host: &str) -> Result<(), CircuitBreakerOpen> {
143 if !self.config.enabled {
144 return Ok(());
145 }
146
147 let states = self.states.read().expect("circuit breaker lock poisoned");
148 let state = match states.get(host) {
149 Some(s) => s,
150 None => return Ok(()), };
152
153 match state.state {
154 CircuitStatus::Closed => Ok(()),
155 CircuitStatus::HalfOpen => Ok(()), CircuitStatus::Open => {
157 let opened_at = state.opened_at.unwrap_or_else(Instant::now);
158 let elapsed = opened_at.elapsed();
159
160 if elapsed >= state.current_backoff {
161 Ok(())
163 } else {
164 Err(CircuitBreakerOpen {
165 host: host.to_string(),
166 retry_after: state.current_backoff - elapsed,
167 })
168 }
169 }
170 }
171 }
172
173 pub fn record_success(&self, host: &str) {
175 if !self.config.enabled {
176 return;
177 }
178
179 let mut states = self.states.write().expect("circuit breaker lock poisoned");
180 let state = states.entry(host.to_string()).or_default();
181
182 match state.state {
183 CircuitStatus::Closed => {
184 state.failure_count = 0;
186 }
187 CircuitStatus::HalfOpen => {
188 state.success_count += 1;
189 if state.success_count >= self.config.success_threshold {
190 tracing::info!(host = %host, "Circuit breaker closed, service recovered");
192 state.state = CircuitStatus::Closed;
193 state.failure_count = 0;
194 state.success_count = 0;
195 state.opened_at = None;
196 state.current_backoff = self.config.base_timeout;
197 }
198 }
199 CircuitStatus::Open => {
200 tracing::info!(host = %host, "Circuit breaker half-open, testing service");
202 state.state = CircuitStatus::HalfOpen;
203 state.success_count = 1;
204 }
205 }
206 }
207
208 pub fn record_failure(&self, host: &str) {
210 if !self.config.enabled {
211 return;
212 }
213
214 let mut states = self.states.write().expect("circuit breaker lock poisoned");
215 let state = states.entry(host.to_string()).or_default();
216
217 match state.state {
218 CircuitStatus::Closed => {
219 state.failure_count += 1;
220 if state.failure_count >= self.config.failure_threshold {
221 tracing::warn!(
223 host = %host,
224 failures = state.failure_count,
225 "Circuit breaker opened, service unhealthy"
226 );
227 state.state = CircuitStatus::Open;
228 state.opened_at = Some(Instant::now());
229 }
230 }
231 CircuitStatus::HalfOpen => {
232 let new_backoff = Duration::from_secs_f64(
234 (state.current_backoff.as_secs_f64() * self.config.backoff_multiplier)
235 .min(self.config.max_backoff.as_secs_f64()),
236 );
237 tracing::warn!(
238 host = %host,
239 backoff_secs = new_backoff.as_secs(),
240 "Circuit breaker reopened, service still unhealthy"
241 );
242 state.state = CircuitStatus::Open;
243 state.opened_at = Some(Instant::now());
244 state.current_backoff = new_backoff;
245 state.success_count = 0;
246 }
247 CircuitStatus::Open => {
248 state.opened_at = Some(Instant::now());
250 }
251 }
252 }
253
254 pub async fn execute(&self, request: Request) -> Result<Response, CircuitBreakerError> {
256 let host = Self::extract_host(request.url());
257
258 self.should_allow(&host)
260 .map_err(CircuitBreakerError::CircuitOpen)?;
261
262 {
264 let mut states = self.states.write().expect("circuit breaker lock poisoned");
265 if let Some(state) = states.get_mut(&host)
266 && state.state == CircuitStatus::Open
267 && let Some(opened_at) = state.opened_at
268 && opened_at.elapsed() >= state.current_backoff
269 {
270 tracing::info!(host = %host, "Circuit breaker half-open, testing service");
271 state.state = CircuitStatus::HalfOpen;
272 state.success_count = 0;
273 }
274 }
275
276 match self.inner.execute(request).await {
278 Ok(response) => {
279 if response.status().is_server_error() {
281 self.record_failure(&host);
282 } else {
283 self.record_success(&host);
284 }
285 Ok(response)
286 }
287 Err(e) => {
288 self.record_failure(&host);
289 Err(CircuitBreakerError::Request(e))
290 }
291 }
292 }
293
294 pub fn get_state(&self, host: &str) -> Option<CircuitState> {
296 self.states
297 .read()
298 .expect("circuit breaker lock poisoned")
299 .get(host)
300 .cloned()
301 }
302
303 pub fn reset(&self, host: &str) {
305 self.states
306 .write()
307 .expect("circuit breaker lock poisoned")
308 .remove(host);
309 }
310
311 pub fn reset_all(&self) {
313 self.states
314 .write()
315 .expect("circuit breaker lock poisoned")
316 .clear();
317 }
318}
319
320#[derive(Debug)]
322pub enum CircuitBreakerError {
323 CircuitOpen(CircuitBreakerOpen),
325 Request(reqwest::Error),
327}
328
329impl std::fmt::Display for CircuitBreakerError {
330 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
331 match self {
332 CircuitBreakerError::CircuitOpen(e) => write!(f, "{}", e),
333 CircuitBreakerError::Request(e) => write!(f, "HTTP request failed: {}", e),
334 }
335 }
336}
337
338impl std::error::Error for CircuitBreakerError {
339 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
340 match self {
341 CircuitBreakerError::CircuitOpen(e) => Some(e),
342 CircuitBreakerError::Request(e) => Some(e),
343 }
344 }
345}
346
347impl From<reqwest::Error> for CircuitBreakerError {
348 fn from(e: reqwest::Error) -> Self {
349 CircuitBreakerError::Request(e)
350 }
351}
352
#[cfg(test)]
// Panicking on unexpected values is the desired failure mode in tests.
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    /// The default configuration is enabled with thresholds of 5 and 2.
    #[test]
    fn test_circuit_breaker_defaults() {
        let CircuitBreakerConfig {
            failure_threshold,
            success_threshold,
            enabled,
            ..
        } = CircuitBreakerConfig::default();

        assert_eq!(failure_threshold, 5);
        assert_eq!(success_threshold, 2);
        assert!(enabled);
    }

    /// An untracked host is allowed, five failures open its circuit, and a
    /// reset makes it allowed again.
    #[test]
    fn test_circuit_state_transitions() {
        let breaker = CircuitBreakerClient::with_defaults(reqwest::Client::new());
        let host = "https://api.example.com";

        assert!(breaker.should_allow(host).is_ok());

        (0..5).for_each(|_| breaker.record_failure(host));

        let status = breaker.get_state(host).unwrap().state;
        assert_eq!(status, CircuitStatus::Open);

        assert!(breaker.should_allow(host).is_err());

        breaker.reset(host);
        assert!(breaker.should_allow(host).is_ok());
    }

    /// Host keys keep scheme, host and an explicit port, and drop the path.
    #[test]
    fn test_extract_host() {
        let cases = [
            (
                "https://api.example.com:8080/path",
                "https://api.example.com:8080",
            ),
            ("http://localhost/api", "http://localhost"),
        ];

        for (raw, expected) in cases {
            let url = reqwest::Url::parse(raw).unwrap();
            assert_eq!(CircuitBreakerClient::extract_host(&url), expected);
        }
    }
}