use crate::config::{BulkheadConfig, CircuitBreakerConfig};
use prometheus::{Counter, Histogram, HistogramOpts, IntGauge, Opts, Registry};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant, SystemTime};
use tokio::sync::{broadcast, RwLock};
use tracing::{debug, error, info, warn};

#[cfg(feature = "distributed")]
use redis::{aio::ConnectionManager, AsyncCommands, Client as RedisClient};

/// State of a circuit breaker.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CircuitState {
    /// Requests flow normally.
    Closed,
    /// Requests are rejected without being attempted.
    Open,
    /// A limited number of trial requests are allowed through.
    HalfOpen,
}

/// Serializable snapshot of a circuit breaker's counters and state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CircuitBreakerSnapshot {
    pub state: CircuitState,
    pub consecutive_failures: u64,
    pub consecutive_successes: u64,
    pub total_requests: u64,
    pub successful_requests: u64,
    pub failed_requests: u64,
    pub rejected_requests: u64,
    pub last_state_change: Option<SystemTime>,
}

/// Point-in-time statistics reported by a circuit breaker.
#[derive(Debug, Clone)]
pub struct CircuitStats {
    pub total_requests: u64,
    pub successful_requests: u64,
    pub failed_requests: u64,
    pub rejected_requests: u64,
    pub state: CircuitState,
    pub last_state_change: Option<Instant>,
    pub consecutive_failures: u64,
    pub consecutive_successes: u64,
}

/// Event emitted whenever a circuit breaker changes state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CircuitStateChange {
    pub endpoint: String,
    pub old_state: CircuitState,
    pub new_state: CircuitState,
    pub timestamp: SystemTime,
    pub reason: String,
}

/// Redis-backed storage for sharing circuit breaker state across instances.
#[cfg(feature = "distributed")]
pub struct DistributedCircuitState {
    redis: ConnectionManager,
    key_prefix: String,
}

#[cfg(feature = "distributed")]
impl DistributedCircuitState {
    pub async fn new(
        redis_url: &str,
        key_prefix: impl Into<String>,
    ) -> Result<Self, redis::RedisError> {
        let client = RedisClient::open(redis_url)?;
        let conn = ConnectionManager::new(client).await?;
        Ok(Self {
            redis: conn,
            key_prefix: key_prefix.into(),
        })
    }

    async fn key(&self, endpoint: &str) -> String {
        format!("{}:circuit:{}", self.key_prefix, endpoint)
    }

    pub async fn save_state(
        &mut self,
        endpoint: &str,
        snapshot: &CircuitBreakerSnapshot,
    ) -> Result<(), redis::RedisError> {
        let key = self.key(endpoint).await;
        let data = bincode::serialize(snapshot).unwrap_or_default();
        self.redis.set_ex(&key, data, 3600).await
    }

    pub async fn load_state(&mut self, endpoint: &str) -> Option<CircuitBreakerSnapshot> {
        let key = self.key(endpoint).await;
        let data: Vec<u8> = self.redis.get(&key).await.ok()?;
        bincode::deserialize(&data).ok()
    }
}

/// Circuit breaker that tracks request outcomes and opens when failures exceed
/// the configured thresholds.
pub struct CircuitBreaker {
    config: Arc<RwLock<CircuitBreakerConfig>>,
    state: Arc<RwLock<CircuitState>>,
    consecutive_failures: Arc<AtomicU64>,
    consecutive_successes: Arc<AtomicU64>,
    total_requests: Arc<AtomicU64>,
    successful_requests: Arc<AtomicU64>,
    failed_requests: Arc<AtomicU64>,
    rejected_requests: Arc<AtomicU64>,
    last_state_change: Arc<RwLock<Option<Instant>>>,
    half_open_requests: Arc<AtomicUsize>,
    /// Optional file path used to persist state snapshots across restarts.
    persistence_path: Option<PathBuf>,
    /// Broadcast channel for state-change notifications.
    state_tx: broadcast::Sender<CircuitStateChange>,
    /// Optional Redis-backed shared state.
    #[cfg(feature = "distributed")]
    distributed_state: Option<Arc<RwLock<DistributedCircuitState>>>,
    /// Endpoint name this breaker protects.
    endpoint: String,
}

impl Clone for CircuitBreaker {
    fn clone(&self) -> Self {
        Self {
            config: self.config.clone(),
            state: self.state.clone(),
            consecutive_failures: self.consecutive_failures.clone(),
            consecutive_successes: self.consecutive_successes.clone(),
            total_requests: self.total_requests.clone(),
            successful_requests: self.successful_requests.clone(),
            failed_requests: self.failed_requests.clone(),
            rejected_requests: self.rejected_requests.clone(),
            last_state_change: self.last_state_change.clone(),
            half_open_requests: self.half_open_requests.clone(),
            persistence_path: self.persistence_path.clone(),
            state_tx: self.state_tx.clone(),
            #[cfg(feature = "distributed")]
            distributed_state: self.distributed_state.clone(),
            endpoint: self.endpoint.clone(),
        }
    }
}

impl CircuitBreaker {
    /// Create a circuit breaker with the given configuration.
    pub fn new(config: CircuitBreakerConfig) -> Self {
        let (state_tx, _) = broadcast::channel(100);
        Self {
            config: Arc::new(RwLock::new(config)),
            state: Arc::new(RwLock::new(CircuitState::Closed)),
            consecutive_failures: Arc::new(AtomicU64::new(0)),
            consecutive_successes: Arc::new(AtomicU64::new(0)),
            total_requests: Arc::new(AtomicU64::new(0)),
            successful_requests: Arc::new(AtomicU64::new(0)),
            failed_requests: Arc::new(AtomicU64::new(0)),
            rejected_requests: Arc::new(AtomicU64::new(0)),
            last_state_change: Arc::new(RwLock::new(None)),
            half_open_requests: Arc::new(AtomicUsize::new(0)),
            persistence_path: None,
            state_tx,
            #[cfg(feature = "distributed")]
            distributed_state: None,
            endpoint: "default".to_string(),
        }
    }

    /// Create a circuit breaker tied to a named endpoint.
    pub fn with_endpoint(config: CircuitBreakerConfig, endpoint: impl Into<String>) -> Self {
        let mut breaker = Self::new(config);
        breaker.endpoint = endpoint.into();
        breaker
    }

    /// Enable file-based persistence of state snapshots.
    pub fn with_persistence(mut self, path: PathBuf) -> Self {
        self.persistence_path = Some(path);
        self
    }

    /// Enable Redis-backed shared state (requires the `distributed` feature).
    #[cfg(feature = "distributed")]
    pub async fn with_distributed_state(
        mut self,
        redis_url: &str,
    ) -> Result<Self, redis::RedisError> {
        let dist_state = DistributedCircuitState::new(redis_url, "mockforge").await?;
        self.distributed_state = Some(Arc::new(RwLock::new(dist_state)));
        Ok(self)
    }

    /// Subscribe to state-change notifications.
    pub fn subscribe_state_changes(&self) -> broadcast::Receiver<CircuitStateChange> {
        self.state_tx.subscribe()
    }

    /// Persist the current state to disk and, if configured, to Redis.
    pub async fn save_state(&self) -> std::io::Result<()> {
        if let Some(path) = &self.persistence_path {
            let snapshot = self.create_snapshot().await;
            let data = bincode::serialize(&snapshot).map_err(std::io::Error::other)?;
            tokio::fs::write(path, data).await?;
            debug!("Circuit breaker state saved to {:?}", path);
        }

        #[cfg(feature = "distributed")]
        if let Some(dist_state) = &self.distributed_state {
            let snapshot = self.create_snapshot().await;
            if let Err(e) = dist_state.write().await.save_state(&self.endpoint, &snapshot).await {
                error!("Failed to save state to Redis: {}", e);
            }
        }

        Ok(())
    }

    /// Restore state from Redis (preferred) or from the persistence file.
    pub async fn load_state(&self) -> std::io::Result<()> {
        #[cfg(feature = "distributed")]
        if let Some(dist_state) = &self.distributed_state {
            if let Some(snapshot) = dist_state.write().await.load_state(&self.endpoint).await {
                self.restore_from_snapshot(snapshot).await;
                info!("Circuit breaker state loaded from Redis");
                return Ok(());
            }
        }

        if let Some(path) = &self.persistence_path {
            if path.exists() {
                let data = tokio::fs::read(path).await?;
                let snapshot: CircuitBreakerSnapshot =
                    bincode::deserialize(&data).map_err(std::io::Error::other)?;
                self.restore_from_snapshot(snapshot).await;
                info!("Circuit breaker state loaded from {:?}", path);
            }
        }

        Ok(())
    }

    async fn create_snapshot(&self) -> CircuitBreakerSnapshot {
        let state = *self.state.read().await;
        let last_change = self.last_state_change.read().await;
        let last_state_change = last_change.map(|instant| SystemTime::now() - instant.elapsed());

        CircuitBreakerSnapshot {
            state,
            consecutive_failures: self.consecutive_failures.load(Ordering::SeqCst),
            consecutive_successes: self.consecutive_successes.load(Ordering::SeqCst),
            total_requests: self.total_requests.load(Ordering::SeqCst),
            successful_requests: self.successful_requests.load(Ordering::SeqCst),
            failed_requests: self.failed_requests.load(Ordering::SeqCst),
            rejected_requests: self.rejected_requests.load(Ordering::SeqCst),
            last_state_change,
        }
    }

    async fn restore_from_snapshot(&self, snapshot: CircuitBreakerSnapshot) {
        *self.state.write().await = snapshot.state;
        self.consecutive_failures.store(snapshot.consecutive_failures, Ordering::SeqCst);
        self.consecutive_successes
            .store(snapshot.consecutive_successes, Ordering::SeqCst);
        self.total_requests.store(snapshot.total_requests, Ordering::SeqCst);
        self.successful_requests.store(snapshot.successful_requests, Ordering::SeqCst);
        self.failed_requests.store(snapshot.failed_requests, Ordering::SeqCst);
        self.rejected_requests.store(snapshot.rejected_requests, Ordering::SeqCst);

        if let Some(system_time) = snapshot.last_state_change {
            if let Ok(elapsed) = system_time.elapsed() {
                *self.last_state_change.write().await = Some(Instant::now() - elapsed);
            }
        }
    }

    /// Decide whether a request should be allowed through the breaker.
    pub async fn allow_request(&self) -> bool {
        let config = self.config.read().await;

        if !config.enabled {
            return true;
        }

        let state = *self.state.read().await;

        match state {
            CircuitState::Closed => true,
            CircuitState::Open => {
                let last_change = self.last_state_change.read().await;
                if let Some(last) = *last_change {
                    let elapsed = last.elapsed();
                    if elapsed >= Duration::from_millis(config.timeout_ms) {
                        drop(last_change);
                        drop(config);
                        self.transition_to_half_open().await;
                        return true;
                    }
                }

                self.rejected_requests.fetch_add(1, Ordering::SeqCst);
                debug!("Circuit breaker: Request rejected (circuit open)");
                false
            }
            CircuitState::HalfOpen => {
                let current = self.half_open_requests.load(Ordering::SeqCst);
                if current < config.half_open_max_requests as usize {
                    self.half_open_requests.fetch_add(1, Ordering::SeqCst);
                    debug!(
                        "Circuit breaker: Request allowed in half-open state ({}/{})",
                        current + 1,
                        config.half_open_max_requests
                    );
                    true
                } else {
                    self.rejected_requests.fetch_add(1, Ordering::SeqCst);
                    debug!("Circuit breaker: Request rejected (half-open limit reached)");
                    false
                }
            }
        }
    }

    /// Record a successful request, possibly closing a half-open circuit.
    pub async fn record_success(&self) {
        let config = self.config.read().await;

        if !config.enabled {
            return;
        }

        self.total_requests.fetch_add(1, Ordering::SeqCst);
        self.successful_requests.fetch_add(1, Ordering::SeqCst);
        self.consecutive_failures.store(0, Ordering::SeqCst);
        let consecutive_successes = self.consecutive_successes.fetch_add(1, Ordering::SeqCst) + 1;

        let state = *self.state.read().await;

        if state == CircuitState::HalfOpen {
            self.half_open_requests.fetch_sub(1, Ordering::SeqCst);

            if consecutive_successes >= config.success_threshold {
                drop(config);
                self.transition_to_closed().await;
            }
        }

        debug!("Circuit breaker: Success recorded (consecutive: {})", consecutive_successes);
    }

    /// Record a failed request, opening the circuit when thresholds are exceeded.
    pub async fn record_failure(&self) {
        let config = self.config.read().await;

        if !config.enabled {
            return;
        }

        self.total_requests.fetch_add(1, Ordering::SeqCst);
        self.failed_requests.fetch_add(1, Ordering::SeqCst);
        self.consecutive_successes.store(0, Ordering::SeqCst);
        let consecutive_failures = self.consecutive_failures.fetch_add(1, Ordering::SeqCst) + 1;

        let state = *self.state.read().await;

        if state == CircuitState::HalfOpen {
            self.half_open_requests.fetch_sub(1, Ordering::SeqCst);
            drop(config);
            self.transition_to_open().await;
        } else if state == CircuitState::Closed {
            // Open immediately once consecutive failures reach the threshold.
            if consecutive_failures >= config.failure_threshold {
                drop(config);
                self.transition_to_open().await;
                return;
            }

            // Otherwise, open if the overall failure rate exceeds the configured limit.
            let total = self.total_requests.load(Ordering::SeqCst);
            if total >= config.min_requests_for_rate {
                let failed = self.failed_requests.load(Ordering::SeqCst);
                let failure_rate = (failed as f64 / total as f64) * 100.0;

                if failure_rate >= config.failure_rate_threshold {
                    drop(config);
                    self.transition_to_open().await;
                    return;
                }
            }
        }

        debug!("Circuit breaker: Failure recorded (consecutive: {})", consecutive_failures);
    }

    async fn transition_to_open(&self) {
        let mut state = self.state.write().await;
        if *state != CircuitState::Open {
            let old_state = *state;
            *state = CircuitState::Open;
            *self.last_state_change.write().await = Some(Instant::now());
            warn!("Circuit breaker '{}': Transitioned to OPEN state", self.endpoint);

            let change = CircuitStateChange {
                endpoint: self.endpoint.clone(),
                old_state,
                new_state: CircuitState::Open,
                timestamp: SystemTime::now(),
                reason: "Failure threshold exceeded".to_string(),
            };
            let _ = self.state_tx.send(change);

            drop(state);
            if let Err(e) = self.save_state().await {
                error!("Failed to save circuit breaker state: {}", e);
            }
        }
    }

    async fn transition_to_half_open(&self) {
        let mut state = self.state.write().await;
        if *state != CircuitState::HalfOpen {
            let old_state = *state;
            *state = CircuitState::HalfOpen;
            *self.last_state_change.write().await = Some(Instant::now());
            self.half_open_requests.store(0, Ordering::SeqCst);
            info!("Circuit breaker '{}': Transitioned to HALF-OPEN state", self.endpoint);

            let change = CircuitStateChange {
                endpoint: self.endpoint.clone(),
                old_state,
                new_state: CircuitState::HalfOpen,
                timestamp: SystemTime::now(),
                reason: "Timeout elapsed, testing recovery".to_string(),
            };
            let _ = self.state_tx.send(change);

            drop(state);
            if let Err(e) = self.save_state().await {
                error!("Failed to save circuit breaker state: {}", e);
            }
        }
    }

    async fn transition_to_closed(&self) {
        let mut state = self.state.write().await;
        if *state != CircuitState::Closed {
            let old_state = *state;
            *state = CircuitState::Closed;
            *self.last_state_change.write().await = Some(Instant::now());
            self.consecutive_failures.store(0, Ordering::SeqCst);
            self.consecutive_successes.store(0, Ordering::SeqCst);
            info!("Circuit breaker '{}': Transitioned to CLOSED state", self.endpoint);

            let change = CircuitStateChange {
                endpoint: self.endpoint.clone(),
                old_state,
                new_state: CircuitState::Closed,
                timestamp: SystemTime::now(),
                reason: "Service recovered successfully".to_string(),
            };
            let _ = self.state_tx.send(change);

            drop(state);
            if let Err(e) = self.save_state().await {
                error!("Failed to save circuit breaker state: {}", e);
            }
        }
    }

    /// Reset all counters and return to the closed state.
    pub async fn reset(&self) {
        *self.state.write().await = CircuitState::Closed;
        *self.last_state_change.write().await = None;
        self.consecutive_failures.store(0, Ordering::SeqCst);
        self.consecutive_successes.store(0, Ordering::SeqCst);
        self.total_requests.store(0, Ordering::SeqCst);
        self.successful_requests.store(0, Ordering::SeqCst);
        self.failed_requests.store(0, Ordering::SeqCst);
        self.rejected_requests.store(0, Ordering::SeqCst);
        self.half_open_requests.store(0, Ordering::SeqCst);
        info!("Circuit breaker: Reset to initial state");
    }

    /// Return a snapshot of the current statistics.
    pub async fn stats(&self) -> CircuitStats {
        CircuitStats {
            total_requests: self.total_requests.load(Ordering::SeqCst),
            successful_requests: self.successful_requests.load(Ordering::SeqCst),
            failed_requests: self.failed_requests.load(Ordering::SeqCst),
            rejected_requests: self.rejected_requests.load(Ordering::SeqCst),
            state: *self.state.read().await,
            last_state_change: *self.last_state_change.read().await,
            consecutive_failures: self.consecutive_failures.load(Ordering::SeqCst),
            consecutive_successes: self.consecutive_successes.load(Ordering::SeqCst),
        }
    }

    /// Current state of the circuit.
    pub async fn state(&self) -> CircuitState {
        *self.state.read().await
    }

    /// Replace the configuration at runtime.
    pub async fn update_config(&self, config: CircuitBreakerConfig) {
        *self.config.write().await = config;
        info!("Circuit breaker: Configuration updated");
    }

    /// Return a clone of the current configuration.
    pub async fn config(&self) -> CircuitBreakerConfig {
        self.config.read().await.clone()
    }
}
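
// Hedged usage sketch: how a caller is expected to drive the breaker around a
// protected operation (gate with `allow_request`, then report the outcome).
// Only config fields this module already reads (`enabled`, `failure_threshold`)
// are set explicitly; the rest relies on `Default` and is illustrative.
#[cfg(test)]
mod circuit_breaker_usage_sketch {
    use super::*;

    #[tokio::test]
    async fn gate_and_report_outcome() {
        let cb = CircuitBreaker::with_endpoint(
            CircuitBreakerConfig {
                enabled: true,
                failure_threshold: 2,
                ..Default::default()
            },
            "payments",
        );

        // A caller first asks the breaker, then records what actually happened.
        if cb.allow_request().await {
            let call_succeeded = true; // stand-in for the real downstream call
            if call_succeeded {
                cb.record_success().await;
            } else {
                cb.record_failure().await;
            }
        }

        assert_eq!(cb.state().await, CircuitState::Closed);
        assert_eq!(cb.stats().await.successful_requests, 1);
    }
}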

/// Point-in-time statistics reported by a bulkhead.
#[derive(Debug, Clone)]
pub struct BulkheadStats {
    pub active_requests: u32,
    pub queued_requests: u32,
    pub total_requests: u64,
    pub rejected_requests: u64,
    pub timeout_requests: u64,
}

/// Bulkhead that limits concurrent requests and optionally queues the overflow.
pub struct Bulkhead {
    config: Arc<RwLock<BulkheadConfig>>,
    active_requests: Arc<AtomicUsize>,
    queued_requests: Arc<AtomicUsize>,
    total_requests: Arc<AtomicU64>,
    rejected_requests: Arc<AtomicU64>,
    timeout_requests: Arc<AtomicU64>,
}

impl Clone for Bulkhead {
    fn clone(&self) -> Self {
        Self {
            config: self.config.clone(),
            active_requests: self.active_requests.clone(),
            queued_requests: self.queued_requests.clone(),
            total_requests: self.total_requests.clone(),
            rejected_requests: self.rejected_requests.clone(),
            timeout_requests: self.timeout_requests.clone(),
        }
    }
}

impl Bulkhead {
    /// Create a bulkhead with the given configuration.
    pub fn new(config: BulkheadConfig) -> Self {
        Self {
            config: Arc::new(RwLock::new(config)),
            active_requests: Arc::new(AtomicUsize::new(0)),
            queued_requests: Arc::new(AtomicUsize::new(0)),
            total_requests: Arc::new(AtomicU64::new(0)),
            rejected_requests: Arc::new(AtomicU64::new(0)),
            timeout_requests: Arc::new(AtomicU64::new(0)),
        }
    }

    /// Try to acquire a slot, queueing (with a timeout) when the bulkhead is full.
    pub async fn try_acquire(&self) -> Result<BulkheadGuard, BulkheadError> {
        let config = self.config.read().await;

        if !config.enabled {
            return Ok(BulkheadGuard::new(self.clone(), false));
        }

        self.total_requests.fetch_add(1, Ordering::SeqCst);

        let active = self.active_requests.load(Ordering::SeqCst);

        // Fast path: a slot is free.
        if active < config.max_concurrent_requests as usize {
            self.active_requests.fetch_add(1, Ordering::SeqCst);
            debug!(
                "Bulkhead: Request accepted ({}/{})",
                active + 1,
                config.max_concurrent_requests
            );
            return Ok(BulkheadGuard::new(self.clone(), true));
        }

        // No queue configured: reject immediately.
        if config.max_queue_size == 0 {
            self.rejected_requests.fetch_add(1, Ordering::SeqCst);
            warn!("Bulkhead: Request rejected (no queue)");
            return Err(BulkheadError::Rejected);
        }

        let queued = self.queued_requests.load(Ordering::SeqCst);
        if queued >= config.max_queue_size as usize {
            self.rejected_requests.fetch_add(1, Ordering::SeqCst);
            warn!("Bulkhead: Request rejected (queue full: {}/{})", queued, config.max_queue_size);
            return Err(BulkheadError::Rejected);
        }

        self.queued_requests.fetch_add(1, Ordering::SeqCst);
        debug!("Bulkhead: Request queued ({}/{})", queued + 1, config.max_queue_size);

        let timeout = Duration::from_millis(config.queue_timeout_ms);
        drop(config);

        // Poll for a free slot until the queue timeout elapses.
        let start = Instant::now();
        loop {
            if start.elapsed() >= timeout {
                self.queued_requests.fetch_sub(1, Ordering::SeqCst);
                self.timeout_requests.fetch_add(1, Ordering::SeqCst);
                warn!("Bulkhead: Request timeout in queue");
                return Err(BulkheadError::Timeout);
            }

            let active = self.active_requests.load(Ordering::SeqCst);
            let config = self.config.read().await;

            if active < config.max_concurrent_requests as usize {
                self.active_requests.fetch_add(1, Ordering::SeqCst);
                self.queued_requests.fetch_sub(1, Ordering::SeqCst);
                debug!("Bulkhead: Queued request accepted");
                return Ok(BulkheadGuard::new(self.clone(), true));
            }

            drop(config);
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
    }

    fn release(&self) {
        let prev = self.active_requests.fetch_sub(1, Ordering::SeqCst);
        debug!("Bulkhead: Request completed ({}/{})", prev - 1, prev);
    }

    /// Return a snapshot of the current statistics.
    pub async fn stats(&self) -> BulkheadStats {
        BulkheadStats {
            active_requests: self.active_requests.load(Ordering::SeqCst) as u32,
            queued_requests: self.queued_requests.load(Ordering::SeqCst) as u32,
            total_requests: self.total_requests.load(Ordering::SeqCst),
            rejected_requests: self.rejected_requests.load(Ordering::SeqCst),
            timeout_requests: self.timeout_requests.load(Ordering::SeqCst),
        }
    }

    /// Reset the request counters.
    pub async fn reset(&self) {
        self.total_requests.store(0, Ordering::SeqCst);
        self.rejected_requests.store(0, Ordering::SeqCst);
        self.timeout_requests.store(0, Ordering::SeqCst);
        info!("Bulkhead: Statistics reset");
    }

    /// Replace the configuration at runtime.
    pub async fn update_config(&self, config: BulkheadConfig) {
        *self.config.write().await = config;
        info!("Bulkhead: Configuration updated");
    }

    /// Return a clone of the current configuration.
    pub async fn config(&self) -> BulkheadConfig {
        self.config.read().await.clone()
    }
}

/// RAII guard that releases its bulkhead slot when dropped.
pub struct BulkheadGuard {
    bulkhead: Option<Bulkhead>,
    should_release: bool,
}

impl BulkheadGuard {
    fn new(bulkhead: Bulkhead, should_release: bool) -> Self {
        Self {
            bulkhead: Some(bulkhead),
            should_release,
        }
    }
}

impl Drop for BulkheadGuard {
    fn drop(&mut self) {
        if self.should_release {
            if let Some(bulkhead) = &self.bulkhead {
                bulkhead.release();
            }
        }
    }
}

/// Errors returned by [`Bulkhead::try_acquire`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BulkheadError {
    /// The bulkhead (and its queue, if any) was full.
    Rejected,
    /// The request waited in the queue longer than the configured timeout.
    Timeout,
}

impl std::fmt::Display for BulkheadError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            BulkheadError::Rejected => write!(f, "Request rejected by bulkhead"),
            BulkheadError::Timeout => write!(f, "Request timed out in bulkhead queue"),
        }
    }
}

impl std::error::Error for BulkheadError {}
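
// Hedged usage sketch: acquiring a slot and letting the guard release it on drop.
// Only `BulkheadConfig` fields this module already reads are set explicitly; the
// concrete values are illustrative.
#[cfg(test)]
mod bulkhead_usage_sketch {
    use super::*;

    #[tokio::test]
    async fn guard_releases_slot_on_drop() {
        let bulkhead = Bulkhead::new(BulkheadConfig {
            enabled: true,
            max_concurrent_requests: 1,
            max_queue_size: 0,
            ..Default::default()
        });

        // Hold the only slot; a second acquire is rejected because no queue is configured.
        let guard = bulkhead.try_acquire().await.expect("first acquire succeeds");
        assert!(matches!(bulkhead.try_acquire().await, Err(BulkheadError::Rejected)));

        // Dropping the guard frees the slot for the next caller.
        drop(guard);
        assert!(bulkhead.try_acquire().await.is_ok());
    }
}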

/// Configuration for exponential-backoff retries.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Maximum number of attempts (including the first call).
    pub max_attempts: u32,
    /// Initial backoff in milliseconds.
    pub initial_backoff_ms: u64,
    /// Upper bound on the backoff in milliseconds.
    pub max_backoff_ms: u64,
    /// Multiplier applied to the backoff after each failed attempt.
    pub backoff_multiplier: f64,
    /// Fraction of the backoff used as the random jitter range.
    pub jitter_factor: f64,
}

impl Default for RetryConfig {
    fn default() -> Self {
        Self {
            max_attempts: 3,
            initial_backoff_ms: 100,
            max_backoff_ms: 30000,
            backoff_multiplier: 2.0,
            jitter_factor: 0.1,
        }
    }
}

/// Executes an async operation with exponential backoff and jitter.
pub struct RetryPolicy {
    config: RetryConfig,
}

impl RetryPolicy {
    pub fn new(config: RetryConfig) -> Self {
        Self { config }
    }

    /// Run `f`, retrying on `Err` up to `max_attempts` with exponential backoff.
    pub async fn execute<F, Fut, T, E>(&self, mut f: F) -> Result<T, E>
    where
        F: FnMut() -> Fut,
        Fut: std::future::Future<Output = Result<T, E>>,
        E: std::fmt::Debug,
    {
        let mut attempt = 0;
        let mut backoff = self.config.initial_backoff_ms;

        loop {
            attempt += 1;

            match f().await {
                Ok(result) => {
                    if attempt > 1 {
                        info!("Retry successful after {} attempts", attempt);
                    }
                    return Ok(result);
                }
                Err(err) => {
                    if attempt >= self.config.max_attempts {
                        warn!("Max retry attempts ({}) reached", self.config.max_attempts);
                        return Err(err);
                    }

                    // Symmetric jitter in [-range, +range] around the current backoff.
                    let jitter = if self.config.jitter_factor > 0.0 {
                        let range = backoff as f64 * self.config.jitter_factor;
                        (rand::random::<f64>() * range * 2.0 - range) as i64
                    } else {
                        0
                    };

                    let sleep_duration = backoff.saturating_add_signed(jitter);
                    debug!(
                        "Retry attempt {}/{} after {}ms (backoff: {}ms, jitter: {}ms)",
                        attempt, self.config.max_attempts, sleep_duration, backoff, jitter
                    );

                    tokio::time::sleep(Duration::from_millis(sleep_duration)).await;

                    backoff = ((backoff as f64 * self.config.backoff_multiplier) as u64)
                        .min(self.config.max_backoff_ms);
                }
            }
        }
    }
}
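
// Hedged usage sketch: retrying a flaky async operation with `RetryPolicy`.
// The closure, its error type, and the attempt counter are illustrative only.
#[cfg(test)]
mod retry_policy_usage_sketch {
    use super::*;
    use std::sync::atomic::AtomicU32;

    #[tokio::test]
    async fn succeeds_after_transient_failures() {
        let policy = RetryPolicy::new(RetryConfig {
            max_attempts: 3,
            initial_backoff_ms: 1,
            ..Default::default()
        });

        let calls = AtomicU32::new(0);
        let result: Result<&str, &str> = policy
            .execute(|| {
                // Fail the first two calls, then succeed on the third.
                let n = calls.fetch_add(1, Ordering::SeqCst);
                async move {
                    if n < 2 {
                        Err("transient failure")
                    } else {
                        Ok("done")
                    }
                }
            })
            .await;

        assert_eq!(result, Ok("done"));
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }
}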

/// Retry policy that consults a circuit breaker before and during retries.
pub struct CircuitBreakerAwareRetry {
    retry_config: RetryConfig,
    circuit_breaker: Option<Arc<CircuitBreaker>>,
}

impl CircuitBreakerAwareRetry {
    pub fn new(retry_config: RetryConfig) -> Self {
        Self {
            retry_config,
            circuit_breaker: None,
        }
    }

    pub fn with_circuit_breaker(mut self, circuit_breaker: Arc<CircuitBreaker>) -> Self {
        self.circuit_breaker = Some(circuit_breaker);
        self
    }

    /// Run `f` with retries; if the circuit breaker rejects the call, fall back to a
    /// single attempt without retrying.
    pub async fn execute<F, Fut, T, E>(&self, mut f: F) -> Result<T, E>
    where
        F: FnMut() -> Fut,
        Fut: std::future::Future<Output = Result<T, E>>,
        E: std::fmt::Debug,
    {
        if let Some(cb) = &self.circuit_breaker {
            if !cb.allow_request().await {
                debug!("Circuit breaker open, skipping retry");
                return f().await;
            }
        }

        let mut attempt = 0;
        let mut backoff = self.retry_config.initial_backoff_ms;

        loop {
            if let Some(cb) = &self.circuit_breaker {
                if !cb.allow_request().await {
                    debug!("Circuit breaker opened during retry, aborting");
                    return f().await;
                }
            }

            attempt += 1;

            match f().await {
                Ok(result) => {
                    if let Some(cb) = &self.circuit_breaker {
                        cb.record_success().await;
                    }
                    if attempt > 1 {
                        info!("Retry successful after {} attempts", attempt);
                    }
                    return Ok(result);
                }
                Err(err) => {
                    if let Some(cb) = &self.circuit_breaker {
                        cb.record_failure().await;
                    }

                    if attempt >= self.retry_config.max_attempts {
                        warn!("Max retry attempts ({}) reached", self.retry_config.max_attempts);
                        return Err(err);
                    }

                    // Symmetric jitter, as in `RetryPolicy::execute`.
                    let jitter = if self.retry_config.jitter_factor > 0.0 {
                        let range = backoff as f64 * self.retry_config.jitter_factor;
                        (rand::random::<f64>() * range * 2.0 - range) as i64
                    } else {
                        0
                    };

                    let sleep_duration = backoff.saturating_add_signed(jitter);
                    debug!(
                        "Retry attempt {}/{} after {}ms",
                        attempt, self.retry_config.max_attempts, sleep_duration
                    );

                    tokio::time::sleep(Duration::from_millis(sleep_duration)).await;

                    backoff = ((backoff as f64 * self.retry_config.backoff_multiplier) as u64)
                        .min(self.retry_config.max_backoff_ms);
                }
            }
        }
    }
}

/// Produces a canned response body to serve when a call is short-circuited.
pub trait FallbackHandler: Send + Sync {
    /// Return the fallback response body.
    fn handle(&self) -> Vec<u8>;
}

/// Fallback handler that serves a fixed JSON document.
pub struct JsonFallbackHandler {
    response: Vec<u8>,
}

impl JsonFallbackHandler {
    pub fn new(json: serde_json::Value) -> Self {
        let response = serde_json::to_vec(&json).unwrap_or_default();
        Self { response }
    }
}

impl FallbackHandler for JsonFallbackHandler {
    fn handle(&self) -> Vec<u8> {
        self.response.clone()
    }
}

/// Prometheus metrics exported for a single circuit breaker.
pub struct CircuitBreakerMetrics {
    pub state_gauge: IntGauge,
    pub total_requests: Counter,
    pub successful_requests: Counter,
    pub failed_requests: Counter,
    pub rejected_requests: Counter,
    pub state_transitions: Counter,
    pub request_duration: Histogram,
}

impl CircuitBreakerMetrics {
    pub fn new(registry: &Registry, endpoint: &str) -> Result<Self, prometheus::Error> {
        let state_gauge = IntGauge::with_opts(
            Opts::new(
                "circuit_breaker_state",
                "Circuit breaker state (0=Closed, 1=Open, 2=HalfOpen)",
            )
            .const_label("endpoint", endpoint),
        )?;
        registry.register(Box::new(state_gauge.clone()))?;

        let total_requests = Counter::with_opts(
            Opts::new("circuit_breaker_requests_total", "Total requests through circuit breaker")
                .const_label("endpoint", endpoint),
        )?;
        registry.register(Box::new(total_requests.clone()))?;

        let successful_requests = Counter::with_opts(
            Opts::new("circuit_breaker_requests_successful", "Successful requests")
                .const_label("endpoint", endpoint),
        )?;
        registry.register(Box::new(successful_requests.clone()))?;

        let failed_requests = Counter::with_opts(
            Opts::new("circuit_breaker_requests_failed", "Failed requests")
                .const_label("endpoint", endpoint),
        )?;
        registry.register(Box::new(failed_requests.clone()))?;

        let rejected_requests = Counter::with_opts(
            Opts::new("circuit_breaker_requests_rejected", "Rejected requests")
                .const_label("endpoint", endpoint),
        )?;
        registry.register(Box::new(rejected_requests.clone()))?;

        let state_transitions = Counter::with_opts(
            Opts::new("circuit_breaker_state_transitions", "Circuit breaker state transitions")
                .const_label("endpoint", endpoint),
        )?;
        registry.register(Box::new(state_transitions.clone()))?;

        let request_duration = Histogram::with_opts(
            HistogramOpts::new("circuit_breaker_request_duration_seconds", "Request duration")
                .const_label("endpoint", endpoint),
        )?;
        registry.register(Box::new(request_duration.clone()))?;

        Ok(Self {
            state_gauge,
            total_requests,
            successful_requests,
            failed_requests,
            rejected_requests,
            state_transitions,
            request_duration,
        })
    }

    /// Encode the circuit state into the gauge (0=Closed, 1=Open, 2=HalfOpen).
    pub fn update_state(&self, state: CircuitState) {
        let value = match state {
            CircuitState::Closed => 0,
            CircuitState::Open => 1,
            CircuitState::HalfOpen => 2,
        };
        self.state_gauge.set(value);
    }
}

pub struct BulkheadMetrics {
    pub active_requests: IntGauge,
    pub queued_requests: IntGauge,
    pub total_requests: Counter,
    pub rejected_requests: Counter,
    pub timeout_requests: Counter,
    pub queue_duration: Histogram,
}

impl BulkheadMetrics {
    pub fn new(registry: &Registry, service: &str) -> Result<Self, prometheus::Error> {
        let active_requests = IntGauge::with_opts(
            Opts::new("bulkhead_active_requests", "Active requests")
                .const_label("service", service),
        )?;
        registry.register(Box::new(active_requests.clone()))?;

        let queued_requests = IntGauge::with_opts(
            Opts::new("bulkhead_queued_requests", "Queued requests")
                .const_label("service", service),
        )?;
        registry.register(Box::new(queued_requests.clone()))?;

        let total_requests = Counter::with_opts(
            Opts::new("bulkhead_requests_total", "Total requests").const_label("service", service),
        )?;
        registry.register(Box::new(total_requests.clone()))?;

        let rejected_requests = Counter::with_opts(
            Opts::new("bulkhead_requests_rejected", "Rejected requests")
                .const_label("service", service),
        )?;
        registry.register(Box::new(rejected_requests.clone()))?;

        let timeout_requests = Counter::with_opts(
            Opts::new("bulkhead_requests_timeout", "Timeout requests")
                .const_label("service", service),
        )?;
        registry.register(Box::new(timeout_requests.clone()))?;

        let queue_duration = Histogram::with_opts(
            HistogramOpts::new("bulkhead_queue_duration_seconds", "Time spent in queue")
                .const_label("service", service),
        )?;
        registry.register(Box::new(queue_duration.clone()))?;

        Ok(Self {
            active_requests,
            queued_requests,
            total_requests,
            rejected_requests,
            timeout_requests,
            queue_duration,
        })
    }
}

/// Adjusts a circuit breaker's failure threshold based on the observed error rate
/// over a sliding time window.
pub struct DynamicThresholdAdjuster {
    /// Length of the sliding observation window.
    window_size: Duration,
    /// (timestamp, success) samples inside the window.
    history: Arc<RwLock<Vec<(Instant, bool)>>>,
    /// Lower bound for the adjusted threshold.
    min_threshold: u64,
    /// Upper bound for the adjusted threshold.
    max_threshold: u64,
    /// Error rate the adjuster tries to keep the breaker tuned to.
    target_error_rate: f64,
}

impl DynamicThresholdAdjuster {
    pub fn new(
        window_size: Duration,
        min_threshold: u64,
        max_threshold: u64,
        target_error_rate: f64,
    ) -> Self {
        Self {
            window_size,
            history: Arc::new(RwLock::new(Vec::new())),
            min_threshold,
            max_threshold,
            target_error_rate,
        }
    }

    /// Record an outcome and prune samples that fall outside the window.
    pub async fn record(&self, success: bool) {
        let mut history = self.history.write().await;
        history.push((Instant::now(), success));

        let cutoff = Instant::now() - self.window_size;
        history.retain(|(time, _)| *time > cutoff);
    }

    /// Compute a new threshold: tighten it when the error rate is above target,
    /// relax it when the error rate is well below target.
    pub async fn calculate_threshold(&self, current_threshold: u64) -> u64 {
        let history = self.history.read().await;

        if history.is_empty() {
            return current_threshold;
        }

        let total = history.len() as f64;
        let failures = history.iter().filter(|(_, success)| !success).count() as f64;
        let error_rate = failures / total;

        let adjustment_factor = if error_rate > self.target_error_rate {
            0.9
        } else if error_rate < self.target_error_rate * 0.5 {
            1.1
        } else {
            1.0
        };

        let new_threshold = (current_threshold as f64 * adjustment_factor) as u64;
        new_threshold.clamp(self.min_threshold, self.max_threshold)
    }
}
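
// Hedged worked example of the adjustment arithmetic above: with a 10% target
// error rate, a 50% observed error rate shrinks the threshold by the 0.9 factor
// (10 -> 9), while an all-success window relaxes it by 1.1 (10 -> 11), clamped
// to [min_threshold, max_threshold]. All values are illustrative.
#[cfg(test)]
mod dynamic_threshold_usage_sketch {
    use super::*;

    #[tokio::test]
    async fn tightens_and_relaxes_threshold() {
        let adjuster = DynamicThresholdAdjuster::new(Duration::from_secs(60), 2, 20, 0.1);

        // 50% errors in the window: threshold is multiplied by 0.9.
        adjuster.record(true).await;
        adjuster.record(false).await;
        assert_eq!(adjuster.calculate_threshold(10).await, 9);

        // A fresh adjuster with only successes (error rate below half the target):
        // threshold is multiplied by 1.1.
        let relaxed = DynamicThresholdAdjuster::new(Duration::from_secs(60), 2, 20, 0.1);
        relaxed.record(true).await;
        assert_eq!(relaxed.calculate_threshold(10).await, 11);
    }
}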

/// Manages one circuit breaker (plus metrics and a threshold adjuster) per endpoint.
pub struct CircuitBreakerManager {
    breakers: Arc<RwLock<HashMap<String, Arc<CircuitBreaker>>>>,
    default_config: CircuitBreakerConfig,
    registry: Arc<Registry>,
    metrics: Arc<RwLock<HashMap<String, Arc<CircuitBreakerMetrics>>>>,
    threshold_adjusters: Arc<RwLock<HashMap<String, Arc<DynamicThresholdAdjuster>>>>,
}

impl CircuitBreakerManager {
    pub fn new(default_config: CircuitBreakerConfig, registry: Arc<Registry>) -> Self {
        Self {
            breakers: Arc::new(RwLock::new(HashMap::new())),
            default_config,
            registry,
            metrics: Arc::new(RwLock::new(HashMap::new())),
            threshold_adjusters: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Get the breaker for an endpoint, creating it (with metrics and a threshold
    /// adjuster) on first use.
    pub async fn get_breaker(&self, endpoint: &str) -> Arc<CircuitBreaker> {
        let breakers = self.breakers.read().await;

        if let Some(breaker) = breakers.get(endpoint) {
            return breaker.clone();
        }

        drop(breakers);

        let mut breakers = self.breakers.write().await;

        // Re-check after re-acquiring the lock in case another task created it.
        if let Some(breaker) = breakers.get(endpoint) {
            return breaker.clone();
        }

        let breaker = Arc::new(CircuitBreaker::new(self.default_config.clone()));
        breakers.insert(endpoint.to_string(), breaker.clone());

        if let Ok(metrics) = CircuitBreakerMetrics::new(&self.registry, endpoint) {
            let mut metrics_map = self.metrics.write().await;
            metrics_map.insert(endpoint.to_string(), Arc::new(metrics));
        }

        let adjuster =
            Arc::new(DynamicThresholdAdjuster::new(Duration::from_secs(60), 2, 20, 0.1));
        let mut adjusters = self.threshold_adjusters.write().await;
        adjusters.insert(endpoint.to_string(), adjuster);

        info!("Created circuit breaker for endpoint: {}", endpoint);
        breaker
    }

    pub async fn get_metrics(&self, endpoint: &str) -> Option<Arc<CircuitBreakerMetrics>> {
        let metrics = self.metrics.read().await;
        metrics.get(endpoint).cloned()
    }

    pub async fn get_all_states(&self) -> HashMap<String, CircuitState> {
        let breakers = self.breakers.read().await;
        let mut states = HashMap::new();

        for (endpoint, breaker) in breakers.iter() {
            states.insert(endpoint.clone(), breaker.state().await);
        }

        states
    }

    /// Record an outcome and, if needed, adjust the breaker's failure threshold.
    pub async fn record_with_adjustment(&self, endpoint: &str, success: bool) {
        if let Some(adjuster) = self.threshold_adjusters.read().await.get(endpoint) {
            adjuster.record(success).await;

            if let Some(breaker) = self.breakers.read().await.get(endpoint) {
                let current_config = breaker.config().await;
                let new_threshold =
                    adjuster.calculate_threshold(current_config.failure_threshold).await;

                if new_threshold != current_config.failure_threshold {
                    let mut new_config = current_config;
                    new_config.failure_threshold = new_threshold;
                    breaker.update_config(new_config).await;
                    debug!("Adjusted threshold for {} to {}", endpoint, new_threshold);
                }
            }
        }
    }
}

impl Clone for CircuitBreakerManager {
    fn clone(&self) -> Self {
        Self {
            breakers: self.breakers.clone(),
            default_config: self.default_config.clone(),
            registry: self.registry.clone(),
            metrics: self.metrics.clone(),
            threshold_adjusters: self.threshold_adjusters.clone(),
        }
    }
}
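
// Hedged usage sketch: per-endpoint breakers are created lazily and shared via Arc.
// The endpoint name and the default `CircuitBreakerConfig` used here are illustrative.
#[cfg(test)]
mod circuit_breaker_manager_usage_sketch {
    use super::*;

    #[tokio::test]
    async fn reuses_breaker_per_endpoint() {
        let manager =
            CircuitBreakerManager::new(CircuitBreakerConfig::default(), Arc::new(Registry::new()));

        let a = manager.get_breaker("orders").await;
        let b = manager.get_breaker("orders").await;

        // The same breaker instance is returned for repeated lookups of one endpoint.
        assert!(Arc::ptr_eq(&a, &b));
        assert_eq!(manager.get_all_states().await.len(), 1);
    }
}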

/// Manages one bulkhead (plus metrics) per service.
pub struct BulkheadManager {
    bulkheads: Arc<RwLock<HashMap<String, Arc<Bulkhead>>>>,
    default_config: BulkheadConfig,
    registry: Arc<Registry>,
    metrics: Arc<RwLock<HashMap<String, Arc<BulkheadMetrics>>>>,
}

impl BulkheadManager {
    pub fn new(default_config: BulkheadConfig, registry: Arc<Registry>) -> Self {
        Self {
            bulkheads: Arc::new(RwLock::new(HashMap::new())),
            default_config,
            registry,
            metrics: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Get the bulkhead for a service, creating it (with metrics) on first use.
    pub async fn get_bulkhead(&self, service: &str) -> Arc<Bulkhead> {
        let bulkheads = self.bulkheads.read().await;

        if let Some(bulkhead) = bulkheads.get(service) {
            return bulkhead.clone();
        }

        drop(bulkheads);

        let mut bulkheads = self.bulkheads.write().await;

        // Re-check after re-acquiring the lock in case another task created it.
        if let Some(bulkhead) = bulkheads.get(service) {
            return bulkhead.clone();
        }

        let bulkhead = Arc::new(Bulkhead::new(self.default_config.clone()));
        bulkheads.insert(service.to_string(), bulkhead.clone());

        if let Ok(metrics) = BulkheadMetrics::new(&self.registry, service) {
            let mut metrics_map = self.metrics.write().await;
            metrics_map.insert(service.to_string(), Arc::new(metrics));
        }

        info!("Created bulkhead for service: {}", service);
        bulkhead
    }

    pub async fn get_metrics(&self, service: &str) -> Option<Arc<BulkheadMetrics>> {
        let metrics = self.metrics.read().await;
        metrics.get(service).cloned()
    }

    pub async fn get_all_stats(&self) -> HashMap<String, BulkheadStats> {
        let bulkheads = self.bulkheads.read().await;
        let mut stats = HashMap::new();

        for (service, bulkhead) in bulkheads.iter() {
            stats.insert(service.clone(), bulkhead.stats().await);
        }

        stats
    }
}

impl Clone for BulkheadManager {
    fn clone(&self) -> Self {
        Self {
            bulkheads: self.bulkheads.clone(),
            default_config: self.default_config.clone(),
            registry: self.registry.clone(),
            metrics: self.metrics.clone(),
        }
    }
}

/// Protocol used to probe a dependency's health.
#[derive(Clone)]
pub enum HealthCheckProtocol {
    Http {
        url: String,
    },
    Https {
        url: String,
    },
    Tcp {
        host: String,
        port: u16,
    },
    Grpc {
        endpoint: String,
    },
    WebSocket {
        url: String,
    },
    Custom {
        checker: Arc<dyn CustomHealthChecker>,
    },
}

impl std::fmt::Debug for HealthCheckProtocol {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            HealthCheckProtocol::Http { url } => write!(f, "Http {{ url: {:?} }}", url),
            HealthCheckProtocol::Https { url } => write!(f, "Https {{ url: {:?} }}", url),
            HealthCheckProtocol::Tcp { host, port } => {
                write!(f, "Tcp {{ host: {:?}, port: {} }}", host, port)
            }
            HealthCheckProtocol::Grpc { endpoint } => {
                write!(f, "Grpc {{ endpoint: {:?} }}", endpoint)
            }
            HealthCheckProtocol::WebSocket { url } => write!(f, "WebSocket {{ url: {:?} }}", url),
            HealthCheckProtocol::Custom { .. } => write!(f, "Custom {{ checker: <custom> }}"),
        }
    }
}

/// User-supplied async health check.
pub trait CustomHealthChecker: Send + Sync {
    /// Return `true` if the dependency is healthy.
    fn check(&self) -> std::pin::Pin<Box<dyn std::future::Future<Output = bool> + Send + '_>>;
}

/// Bridges health check results into the circuit breaker manager.
pub struct HealthCheckIntegration {
    circuit_manager: Arc<CircuitBreakerManager>,
}

impl HealthCheckIntegration {
    pub fn new(circuit_manager: Arc<CircuitBreakerManager>) -> Self {
        Self { circuit_manager }
    }

    /// Feed a health check result into the endpoint's circuit breaker.
    pub async fn update_from_health(&self, endpoint: &str, healthy: bool) {
        let breaker = self.circuit_manager.get_breaker(endpoint).await;

        if healthy {
            breaker.record_success().await;
        } else {
            breaker.record_failure().await;
        }

        info!("Updated circuit breaker for {} based on health check: {}", endpoint, healthy);
    }

    /// Probe the dependency once using the given protocol.
    pub async fn check_health(&self, protocol: &HealthCheckProtocol) -> bool {
        match protocol {
            HealthCheckProtocol::Http { url } | HealthCheckProtocol::Https { url } => {
                let client = reqwest::Client::new();
                match client.get(url).timeout(Duration::from_secs(5)).send().await {
                    Ok(response) => response.status().is_success(),
                    Err(_) => false,
                }
            }
            HealthCheckProtocol::Tcp { host, port } => {
                use tokio::net::TcpStream;
                TcpStream::connect(format!("{}:{}", host, port)).await.is_ok()
            }
            HealthCheckProtocol::Grpc { endpoint } => {
                let client = reqwest::Client::new();
                match client
                    .post(format!("{}/grpc.health.v1.Health/Check", endpoint))
                    .timeout(Duration::from_secs(5))
                    .send()
                    .await
                {
                    Ok(response) => response.status().is_success(),
                    Err(_) => false,
                }
            }
            HealthCheckProtocol::WebSocket { url } => {
                use tokio_tungstenite::connect_async;
                connect_async(url).await.is_ok()
            }
            HealthCheckProtocol::Custom { checker } => checker.check().await,
        }
    }

    /// Spawn a background task that periodically probes the dependency and feeds
    /// the result into the endpoint's circuit breaker.
    pub async fn start_monitoring(
        &self,
        endpoint: String,
        protocol: HealthCheckProtocol,
        interval: Duration,
    ) {
        let circuit_manager = self.circuit_manager.clone();
        let integration = self.clone();

        tokio::spawn(async move {
            loop {
                tokio::time::sleep(interval).await;

                let healthy = integration.check_health(&protocol).await;
                let breaker = circuit_manager.get_breaker(&endpoint).await;

                if healthy {
                    breaker.record_success().await;
                } else {
                    breaker.record_failure().await;
                }
            }
        });
    }
}

impl Clone for HealthCheckIntegration {
    fn clone(&self) -> Self {
        Self {
            circuit_manager: self.circuit_manager.clone(),
        }
    }
}

/// Fans circuit breaker events out to WebSocket subscribers as JSON strings.
pub struct ResilienceWebSocketNotifier {
    connections: Arc<RwLock<Vec<broadcast::Sender<String>>>>,
}

impl ResilienceWebSocketNotifier {
    pub fn new() -> Self {
        Self {
            connections: Arc::new(RwLock::new(Vec::new())),
        }
    }

    /// Register a new subscriber and return its receiver.
    pub async fn register(&self) -> broadcast::Receiver<String> {
        let (tx, rx) = broadcast::channel(100);
        self.connections.write().await.push(tx);
        rx
    }

    /// Send a message to every registered subscriber.
    pub async fn notify(&self, message: impl Into<String>) {
        let msg = message.into();
        let connections = self.connections.read().await;
        for tx in connections.iter() {
            let _ = tx.send(msg.clone());
        }
    }

    /// Forward a breaker's state changes to subscribers as serialized JSON.
    pub async fn monitor_circuit_breaker(&self, breaker: Arc<CircuitBreaker>) {
        let notifier = self.clone();
        let mut rx = breaker.subscribe_state_changes();

        tokio::spawn(async move {
            while let Ok(change) = rx.recv().await {
                let message = serde_json::to_string(&change).unwrap_or_default();
                notifier.notify(message).await;
            }
        });
    }
}

impl Clone for ResilienceWebSocketNotifier {
    fn clone(&self) -> Self {
        Self {
            connections: self.connections.clone(),
        }
    }
}

impl Default for ResilienceWebSocketNotifier {
    fn default() -> Self {
        Self::new()
    }
}

/// Fires alerts when a circuit breaker opens.
pub struct CircuitBreakerAlertHandler {
    alert_manager: Arc<crate::alerts::AlertManager>,
}

impl CircuitBreakerAlertHandler {
    pub fn new(alert_manager: Arc<crate::alerts::AlertManager>) -> Self {
        Self { alert_manager }
    }

    /// Watch a breaker's state changes and raise a critical alert when it opens.
    pub async fn monitor(&self, breaker: Arc<CircuitBreaker>) {
        let alert_manager = self.alert_manager.clone();
        let mut rx = breaker.subscribe_state_changes();

        tokio::spawn(async move {
            while let Ok(change) = rx.recv().await {
                if change.new_state == CircuitState::Open {
                    let alert = crate::alerts::Alert::new(
                        crate::alerts::AlertSeverity::Critical,
                        crate::alerts::AlertType::Custom {
                            message: format!("Circuit breaker opened for {}", change.endpoint),
                            metadata: {
                                let mut map = HashMap::new();
                                map.insert("endpoint".to_string(), change.endpoint.clone());
                                map.insert("reason".to_string(), change.reason.clone());
                                map.insert(
                                    "timestamp".to_string(),
                                    format!("{:?}", change.timestamp),
                                );
                                map
                            },
                        },
                        format!(
                            "Circuit breaker for endpoint '{}' has opened: {}",
                            change.endpoint, change.reason
                        ),
                    );
                    alert_manager.fire_alert(alert);
                } else if change.new_state == CircuitState::Closed
                    && change.old_state == CircuitState::Open
                {
                    info!("Circuit breaker for '{}' recovered and closed", change.endpoint);
                }
            }
        });
    }
}

/// Service-level objective configuration.
#[derive(Debug, Clone)]
pub struct SLOConfig {
    /// Target success rate (e.g. 0.99 for 99%).
    pub target_success_rate: f64,
    /// Sliding window over which the success rate is measured.
    pub window_duration: Duration,
    /// Error budget as a percentage of requests (e.g. 1.0 for 1%).
    pub error_budget_percent: f64,
}

impl Default for SLOConfig {
    fn default() -> Self {
        Self {
            target_success_rate: 0.99,
            window_duration: Duration::from_secs(300),
            error_budget_percent: 1.0,
        }
    }
}

/// Tracks request outcomes against an SLO over a sliding window.
pub struct SLOTracker {
    config: SLOConfig,
    history: Arc<RwLock<Vec<(Instant, bool)>>>,
}

impl SLOTracker {
    pub fn new(config: SLOConfig) -> Self {
        Self {
            config,
            history: Arc::new(RwLock::new(Vec::new())),
        }
    }

    /// Record an outcome and prune samples that fall outside the window.
    pub async fn record(&self, success: bool) {
        let mut history = self.history.write().await;
        history.push((Instant::now(), success));

        let cutoff = Instant::now() - self.config.window_duration;
        history.retain(|(time, _)| *time > cutoff);
    }

    /// Success rate over the current window (1.0 when there are no samples).
    pub async fn success_rate(&self) -> f64 {
        let history = self.history.read().await;
        if history.is_empty() {
            return 1.0;
        }

        let total = history.len() as f64;
        let successes = history.iter().filter(|(_, success)| *success).count() as f64;
        successes / total
    }

    /// Whether the success rate has fallen below the SLO target.
    pub async fn is_violated(&self) -> bool {
        let rate = self.success_rate().await;
        rate < self.config.target_success_rate
    }

    /// Percentage of the error budget still available (0.0 when exhausted).
    pub async fn error_budget_remaining(&self) -> f64 {
        let rate = self.success_rate().await;
        let error_rate = 1.0 - rate;
        let budget_used = (error_rate / (self.config.error_budget_percent / 100.0)) * 100.0;
        (100.0 - budget_used).max(0.0)
    }
}
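
// Hedged worked example of the error-budget arithmetic above: with a 1% budget,
// an observed error rate of 0.5% uses half the budget, so 50% remains
// (budget_used = (0.005 / 0.01) * 100 = 50). The sample counts are illustrative.
#[cfg(test)]
mod slo_tracker_usage_sketch {
    use super::*;

    #[tokio::test]
    async fn error_budget_reflects_observed_error_rate() {
        let tracker = SLOTracker::new(SLOConfig {
            target_success_rate: 0.99,
            window_duration: Duration::from_secs(300),
            error_budget_percent: 1.0,
        });

        // 199 successes and 1 failure: success rate 0.995, error rate 0.005.
        for _ in 0..199 {
            tracker.record(true).await;
        }
        tracker.record(false).await;

        assert!((tracker.success_rate().await - 0.995).abs() < 1e-9);
        assert!(!tracker.is_violated().await);
        assert!((tracker.error_budget_remaining().await - 50.0).abs() < 1e-6);
    }
}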

/// Couples SLO tracking with the circuit breaker manager: SLO violations are fed
/// into the endpoint's breaker as failures.
pub struct SLOCircuitBreakerIntegration {
    circuit_manager: Arc<CircuitBreakerManager>,
    slo_trackers: Arc<RwLock<HashMap<String, Arc<SLOTracker>>>>,
}

impl SLOCircuitBreakerIntegration {
    pub fn new(circuit_manager: Arc<CircuitBreakerManager>) -> Self {
        Self {
            circuit_manager,
            slo_trackers: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Get (or lazily create) the SLO tracker for an endpoint.
    pub async fn get_tracker(&self, endpoint: &str, config: SLOConfig) -> Arc<SLOTracker> {
        let mut trackers = self.slo_trackers.write().await;
        trackers
            .entry(endpoint.to_string())
            .or_insert_with(|| Arc::new(SLOTracker::new(config)))
            .clone()
    }

    /// Record an outcome; if the SLO is violated, register a failure with the breaker.
    pub async fn record_request(&self, endpoint: &str, success: bool, slo_config: SLOConfig) {
        let tracker = self.get_tracker(endpoint, slo_config).await;
        tracker.record(success).await;

        if tracker.is_violated().await {
            let breaker = self.circuit_manager.get_breaker(endpoint).await;
            breaker.record_failure().await;
            warn!("SLO violated for endpoint '{}', recording failure in circuit breaker", endpoint);
        }
    }

    /// Return (success_rate, error_budget_remaining, violated) for an endpoint, if tracked.
    pub async fn get_slo_status(&self, endpoint: &str) -> Option<(f64, f64, bool)> {
        let trackers = self.slo_trackers.read().await;
        if let Some(tracker) = trackers.get(endpoint) {
            let success_rate = tracker.success_rate().await;
            let budget_remaining = tracker.error_budget_remaining().await;
            let violated = tracker.is_violated().await;
            Some((success_rate, budget_remaining, violated))
        } else {
            None
        }
    }
}

/// Maintains a separate bulkhead per user to isolate noisy tenants.
pub struct PerUserBulkhead {
    bulkheads: Arc<RwLock<HashMap<String, Arc<Bulkhead>>>>,
    default_config: BulkheadConfig,
    registry: Arc<Registry>,
}

impl PerUserBulkhead {
    pub fn new(default_config: BulkheadConfig, registry: Arc<Registry>) -> Self {
        Self {
            bulkheads: Arc::new(RwLock::new(HashMap::new())),
            default_config,
            registry,
        }
    }

    /// Get the bulkhead for a user, creating it on first use.
    pub async fn get_bulkhead(&self, user_id: &str) -> Arc<Bulkhead> {
        let bulkheads = self.bulkheads.read().await;

        if let Some(bulkhead) = bulkheads.get(user_id) {
            return bulkhead.clone();
        }

        drop(bulkheads);

        let mut bulkheads = self.bulkheads.write().await;

        // Re-check after re-acquiring the lock in case another task created it.
        if let Some(bulkhead) = bulkheads.get(user_id) {
            return bulkhead.clone();
        }

        let bulkhead = Arc::new(Bulkhead::new(self.default_config.clone()));
        bulkheads.insert(user_id.to_string(), bulkhead.clone());

        info!("Created per-user bulkhead for user: {}", user_id);
        bulkhead
    }

    /// Try to acquire a slot in the user's bulkhead.
    pub async fn try_acquire(&self, user_id: &str) -> Result<BulkheadGuard, BulkheadError> {
        let bulkhead = self.get_bulkhead(user_id).await;
        bulkhead.try_acquire().await
    }

    pub async fn get_user_stats(&self, user_id: &str) -> Option<BulkheadStats> {
        let bulkheads = self.bulkheads.read().await;
        if let Some(bulkhead) = bulkheads.get(user_id) {
            Some(bulkhead.stats().await)
        } else {
            None
        }
    }

    pub async fn get_all_stats(&self) -> HashMap<String, BulkheadStats> {
        let bulkheads = self.bulkheads.read().await;
        let mut stats = HashMap::new();

        for (user_id, bulkhead) in bulkheads.iter() {
            stats.insert(user_id.clone(), bulkhead.stats().await);
        }

        stats
    }

    /// Remove a user's bulkhead; returns `true` if one existed.
    pub async fn remove_user(&self, user_id: &str) -> bool {
        let mut bulkheads = self.bulkheads.write().await;
        bulkheads.remove(user_id).is_some()
    }
}

impl Clone for PerUserBulkhead {
    fn clone(&self) -> Self {
        Self {
            bulkheads: self.bulkheads.clone(),
            default_config: self.default_config.clone(),
            registry: self.registry.clone(),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_circuit_breaker_closed_to_open() {
        let config = CircuitBreakerConfig {
            enabled: true,
            failure_threshold: 3,
            ..Default::default()
        };

        let cb = CircuitBreaker::new(config);

        assert_eq!(cb.state().await, CircuitState::Closed);

        // Two failures: still below the threshold of three.
        for _ in 0..2 {
            assert!(cb.allow_request().await);
            cb.record_failure().await;
            assert_eq!(cb.state().await, CircuitState::Closed);
        }

        // The third failure trips the breaker.
        assert!(cb.allow_request().await);
        cb.record_failure().await;
        assert_eq!(cb.state().await, CircuitState::Open);

        assert!(!cb.allow_request().await);
    }

    #[tokio::test]
    async fn test_circuit_breaker_half_open_to_closed() {
        let config = CircuitBreakerConfig {
            enabled: true,
            failure_threshold: 2,
            success_threshold: 2,
            timeout_ms: 100,
            ..Default::default()
        };

        let cb = CircuitBreaker::new(config);

        // Trip the breaker.
        for _ in 0..2 {
            cb.allow_request().await;
            cb.record_failure().await;
        }
        assert_eq!(cb.state().await, CircuitState::Open);

        // Wait past the open timeout so the next request moves it to half-open.
        tokio::time::sleep(Duration::from_millis(150)).await;

        assert!(cb.allow_request().await);
        assert_eq!(cb.state().await, CircuitState::HalfOpen);

        // One success is not enough to close the circuit yet.
        cb.record_success().await;
        assert_eq!(cb.state().await, CircuitState::HalfOpen);

        cb.allow_request().await;
        cb.record_success().await;
        assert_eq!(cb.state().await, CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_bulkhead_basic() {
        let config = BulkheadConfig {
            enabled: true,
            max_concurrent_requests: 2,
            max_queue_size: 0,
            ..Default::default()
        };

        let bulkhead = Bulkhead::new(config);

        let _guard1 = bulkhead.try_acquire().await.unwrap();
        let _guard2 = bulkhead.try_acquire().await.unwrap();

        // Both slots are taken and there is no queue, so the next acquire is rejected.
        assert!(matches!(bulkhead.try_acquire().await, Err(BulkheadError::Rejected)));

        drop(_guard1);

        let _guard3 = bulkhead.try_acquire().await.unwrap();
    }

    #[tokio::test]
    async fn test_bulkhead_with_queue() {
        let config = BulkheadConfig {
            enabled: true,
            max_concurrent_requests: 1,
            max_queue_size: 2,
            queue_timeout_ms: 1000,
        };

        let bulkhead = Bulkhead::new(config);

        let guard1 = bulkhead.try_acquire().await.unwrap();

        // The second acquisition has to wait in the queue.
        let bulkhead_clone = bulkhead.clone();
        let handle = tokio::spawn(async move { bulkhead_clone.try_acquire().await });

        tokio::time::sleep(Duration::from_millis(50)).await;

        let stats = bulkhead.stats().await;
        assert_eq!(stats.active_requests, 1);
        assert_eq!(stats.queued_requests, 1);

        // Releasing the first slot lets the queued request through.
        drop(guard1);

        let _guard2 = handle.await.unwrap().unwrap();
    }
}