1use std::collections::HashMap;
7use std::sync::RwLock;
8use std::time::{Duration, Instant};
9
10use reqwest::{Request, Response};
11
/// Breaker bookkeeping kept for a single host.
#[derive(Clone, Debug)]
pub struct CircuitState {
    /// Where the circuit currently sits in the Closed → Open → HalfOpen cycle.
    pub state: CircuitStatus,
    /// Failures counted while closed; a success while closed clears it.
    pub failure_count: u32,
    /// Successes counted while half-open.
    pub success_count: u32,
    /// When the circuit last entered (or refreshed) the open state; `None` until it first
    /// opens and again after it recovers.
    pub opened_at: Option<Instant>,
    /// How long the circuit stays open before a probe request is let through.
    pub current_backoff: Duration,
}

/// Position of a host's circuit in the breaker state machine.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitStatus {
    /// Requests flow normally while failures are counted.
    Closed,
    /// Requests are rejected until the current backoff has elapsed.
    Open,
    /// Probe requests are allowed; their outcome decides whether to close or reopen.
    HalfOpen,
}
37
38impl Default for CircuitState {
39 fn default() -> Self {
40 Self {
41 state: CircuitStatus::Closed,
42 failure_count: 0,
43 success_count: 0,
44 opened_at: None,
45 current_backoff: Duration::from_secs(30),
46 }
47 }
48}
49
/// Tuning knobs for [`CircuitBreakerClient`].
#[derive(Clone, Debug)]
pub struct CircuitBreakerConfig {
    /// Failures while closed (cleared by any success) that trip the circuit open.
    pub failure_threshold: u32,
    /// Successes while half-open required to close the circuit again.
    pub success_threshold: u32,
    /// Baseline open duration; the backoff returns to this value once the circuit recovers.
    pub base_timeout: Duration,
    /// Upper bound for the open duration as it grows.
    pub max_backoff: Duration,
    /// Factor applied to the open duration each time a half-open probe fails.
    pub backoff_multiplier: f64,
    /// When `false`, every request is allowed and no state is recorded.
    pub enabled: bool,
}

impl Default for CircuitBreakerConfig {
    /// 5 failures to open, 2 successes to close, 30 s base / 600 s max backoff
    /// growing by 1.5x per failed probe, breaker enabled.
    fn default() -> Self {
        let base_timeout = Duration::from_secs(30);
        let max_backoff = Duration::from_secs(10 * 60);
        Self {
            enabled: true,
            failure_threshold: 5,
            success_threshold: 2,
            backoff_multiplier: 1.5,
            base_timeout,
            max_backoff,
        }
    }
}
79
/// Rejection reported when a host's circuit is open.
#[derive(Clone, Debug)]
pub struct CircuitBreakerOpen {
    /// Host key the request was rejected for.
    pub host: String,
    /// Remaining time before the circuit will admit a probe request.
    pub retry_after: Duration,
}

impl std::fmt::Display for CircuitBreakerOpen {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { host, retry_after } = self;
        write!(
            f,
            "Circuit breaker open for {}: retry after {:?}",
            host, retry_after
        )
    }
}

/// The rejection originates in the breaker itself, so there is no underlying `source`.
impl std::error::Error for CircuitBreakerOpen {}
100
101#[derive(Clone)]
105pub struct CircuitBreakerClient {
106 inner: reqwest::Client,
107 states: std::sync::Arc<RwLock<HashMap<String, CircuitState>>>,
108 config: CircuitBreakerConfig,
109}
110
111impl CircuitBreakerClient {
112 pub fn new(client: reqwest::Client, config: CircuitBreakerConfig) -> Self {
114 Self {
115 inner: client,
116 states: std::sync::Arc::new(RwLock::new(HashMap::new())),
117 config,
118 }
119 }
120
121 pub fn with_defaults(client: reqwest::Client) -> Self {
123 Self::new(client, CircuitBreakerConfig::default())
124 }
125
126 pub fn inner(&self) -> &reqwest::Client {
128 &self.inner
129 }
130
131 fn extract_host(url: &reqwest::Url) -> String {
133 format!(
134 "{}://{}{}",
135 url.scheme(),
136 url.host_str().unwrap_or("unknown"),
137 url.port().map(|p| format!(":{}", p)).unwrap_or_default()
138 )
139 }
140
141 pub fn should_allow(&self, host: &str) -> Result<(), CircuitBreakerOpen> {
143 if !self.config.enabled {
144 return Ok(());
145 }
146
147 let states = self.states.read().unwrap();
148 let state = match states.get(host) {
149 Some(s) => s,
150 None => return Ok(()), };
152
153 match state.state {
154 CircuitStatus::Closed => Ok(()),
155 CircuitStatus::HalfOpen => Ok(()), CircuitStatus::Open => {
157 let opened_at = state.opened_at.unwrap_or_else(Instant::now);
158 let elapsed = opened_at.elapsed();
159
160 if elapsed >= state.current_backoff {
161 Ok(())
163 } else {
164 Err(CircuitBreakerOpen {
165 host: host.to_string(),
166 retry_after: state.current_backoff - elapsed,
167 })
168 }
169 }
170 }
171 }
172
173 pub fn record_success(&self, host: &str) {
175 if !self.config.enabled {
176 return;
177 }
178
179 let mut states = self.states.write().unwrap();
180 let state = states.entry(host.to_string()).or_default();
181
182 match state.state {
183 CircuitStatus::Closed => {
184 state.failure_count = 0;
186 }
187 CircuitStatus::HalfOpen => {
188 state.success_count += 1;
189 if state.success_count >= self.config.success_threshold {
190 tracing::info!(host = %host, "Circuit breaker closed, service recovered");
192 state.state = CircuitStatus::Closed;
193 state.failure_count = 0;
194 state.success_count = 0;
195 state.opened_at = None;
196 state.current_backoff = self.config.base_timeout;
197 }
198 }
199 CircuitStatus::Open => {
200 tracing::info!(host = %host, "Circuit breaker half-open, testing service");
202 state.state = CircuitStatus::HalfOpen;
203 state.success_count = 1;
204 }
205 }
206 }
207
208 pub fn record_failure(&self, host: &str) {
210 if !self.config.enabled {
211 return;
212 }
213
214 let mut states = self.states.write().unwrap();
215 let state = states.entry(host.to_string()).or_default();
216
217 match state.state {
218 CircuitStatus::Closed => {
219 state.failure_count += 1;
220 if state.failure_count >= self.config.failure_threshold {
221 tracing::warn!(
223 host = %host,
224 failures = state.failure_count,
225 "Circuit breaker opened, service unhealthy"
226 );
227 state.state = CircuitStatus::Open;
228 state.opened_at = Some(Instant::now());
229 }
230 }
231 CircuitStatus::HalfOpen => {
232 let new_backoff = Duration::from_secs_f64(
234 (state.current_backoff.as_secs_f64() * self.config.backoff_multiplier)
235 .min(self.config.max_backoff.as_secs_f64()),
236 );
237 tracing::warn!(
238 host = %host,
239 backoff_secs = new_backoff.as_secs(),
240 "Circuit breaker reopened, service still unhealthy"
241 );
242 state.state = CircuitStatus::Open;
243 state.opened_at = Some(Instant::now());
244 state.current_backoff = new_backoff;
245 state.success_count = 0;
246 }
247 CircuitStatus::Open => {
248 state.opened_at = Some(Instant::now());
250 }
251 }
252 }
253
254 pub async fn execute(&self, request: Request) -> Result<Response, CircuitBreakerError> {
256 let host = Self::extract_host(request.url());
257
258 self.should_allow(&host)
260 .map_err(CircuitBreakerError::CircuitOpen)?;
261
262 {
264 let mut states = self.states.write().unwrap();
265 if let Some(state) = states.get_mut(&host) {
266 if state.state == CircuitStatus::Open {
267 if let Some(opened_at) = state.opened_at {
268 if opened_at.elapsed() >= state.current_backoff {
269 tracing::info!(host = %host, "Circuit breaker half-open, testing service");
270 state.state = CircuitStatus::HalfOpen;
271 state.success_count = 0;
272 }
273 }
274 }
275 }
276 }
277
278 match self.inner.execute(request).await {
280 Ok(response) => {
281 if response.status().is_server_error() {
283 self.record_failure(&host);
284 } else {
285 self.record_success(&host);
286 }
287 Ok(response)
288 }
289 Err(e) => {
290 self.record_failure(&host);
291 Err(CircuitBreakerError::Request(e))
292 }
293 }
294 }
295
296 pub fn get_state(&self, host: &str) -> Option<CircuitState> {
298 self.states.read().unwrap().get(host).cloned()
299 }
300
301 pub fn reset(&self, host: &str) {
303 self.states.write().unwrap().remove(host);
304 }
305
306 pub fn reset_all(&self) {
308 self.states.write().unwrap().clear();
309 }
310}
311
312#[derive(Debug)]
314pub enum CircuitBreakerError {
315 CircuitOpen(CircuitBreakerOpen),
317 Request(reqwest::Error),
319}
320
321impl std::fmt::Display for CircuitBreakerError {
322 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
323 match self {
324 CircuitBreakerError::CircuitOpen(e) => write!(f, "{}", e),
325 CircuitBreakerError::Request(e) => write!(f, "HTTP request failed: {}", e),
326 }
327 }
328}
329
330impl std::error::Error for CircuitBreakerError {
331 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
332 match self {
333 CircuitBreakerError::CircuitOpen(e) => Some(e),
334 CircuitBreakerError::Request(e) => Some(e),
335 }
336 }
337}
338
339impl From<reqwest::Error> for CircuitBreakerError {
340 fn from(e: reqwest::Error) -> Self {
341 CircuitBreakerError::Request(e)
342 }
343}
344
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_circuit_breaker_defaults() {
        let CircuitBreakerConfig {
            failure_threshold,
            success_threshold,
            enabled,
            ..
        } = CircuitBreakerConfig::default();
        assert_eq!(failure_threshold, 5);
        assert_eq!(success_threshold, 2);
        assert!(enabled);
    }

    #[test]
    fn test_circuit_state_transitions() {
        const HOST: &str = "https://api.example.com";
        let breaker = CircuitBreakerClient::with_defaults(reqwest::Client::new());

        // An untracked host is always allowed.
        assert!(breaker.should_allow(HOST).is_ok());

        // Five failures reach the default threshold and trip the breaker.
        (0..5).for_each(|_| breaker.record_failure(HOST));

        let snapshot = breaker
            .get_state(HOST)
            .expect("host is tracked after recording failures");
        assert_eq!(snapshot.state, CircuitStatus::Open);
        assert!(breaker.should_allow(HOST).is_err());

        // Resetting forgets the host entirely, so it is allowed again.
        breaker.reset(HOST);
        assert!(breaker.should_allow(HOST).is_ok());
    }

    #[test]
    fn test_extract_host() {
        let cases = [
            (
                "https://api.example.com:8080/path",
                "https://api.example.com:8080",
            ),
            ("http://localhost/api", "http://localhost"),
        ];
        for &(raw, expected) in &cases {
            let url = reqwest::Url::parse(raw).expect("test URL parses");
            assert_eq!(CircuitBreakerClient::extract_host(&url), expected);
        }
    }
}