Skip to main content

heliosdb_proxy/circuit_breaker/
manager.rs

1//! Circuit Breaker Manager
2//!
3//! Manages circuit breakers for multiple nodes with centralized configuration
4//! and monitoring.
5
6use std::sync::Arc;
7use std::time::Duration;
8
9use dashmap::DashMap;
10
11use super::breaker::{CircuitBreaker, CircuitOpen, RequestGuard};
12use super::config::{CircuitBreakerConfig, NodeOverride};
13use super::metrics::{CircuitMetrics, CircuitStats};
14use super::state::{CircuitBreakerListener, CircuitState};
15
16/// Configuration for the circuit breaker manager
17#[derive(Debug, Clone)]
18pub struct ManagerConfig {
19    /// Global default configuration
20    pub default_config: CircuitBreakerConfig,
21
22    /// Per-node configuration overrides
23    pub node_overrides: Vec<NodeOverride>,
24
25    /// Enable manager-level metrics collection
26    pub metrics_enabled: bool,
27
28    /// Auto-create breakers for unknown nodes
29    pub auto_create: bool,
30}
31
32impl Default for ManagerConfig {
33    fn default() -> Self {
34        Self {
35            default_config: CircuitBreakerConfig::default(),
36            node_overrides: Vec::new(),
37            metrics_enabled: true,
38            auto_create: true,
39        }
40    }
41}
42
43impl ManagerConfig {
44    /// Create a new manager config
45    pub fn new(default_config: CircuitBreakerConfig) -> Self {
46        Self {
47            default_config,
48            ..Default::default()
49        }
50    }
51
52    /// Add a node override
53    pub fn with_node_override(mut self, override_: NodeOverride) -> Self {
54        self.node_overrides.push(override_);
55        self
56    }
57
58    /// Enable or disable metrics
59    pub fn with_metrics(mut self, enabled: bool) -> Self {
60        self.metrics_enabled = enabled;
61        self
62    }
63
64    /// Get effective config for a node
65    pub fn get_node_config(&self, node_id: &str) -> CircuitBreakerConfig {
66        for override_ in &self.node_overrides {
67            if override_.node_id == node_id {
68                return override_.apply_to(&self.default_config);
69            }
70        }
71        self.default_config.clone()
72    }
73}
74
75/// Circuit Breaker Manager
76///
77/// Manages multiple circuit breakers for different nodes, providing centralized
78/// configuration, monitoring, and node health filtering.
79pub struct CircuitBreakerManager {
80    /// Circuit breakers per node
81    breakers: DashMap<String, CircuitBreaker>,
82
83    /// Configuration
84    config: parking_lot::RwLock<ManagerConfig>,
85
86    /// Shared listeners for all breakers
87    shared_listeners: parking_lot::RwLock<Vec<Arc<dyn CircuitBreakerListener>>>,
88
89    /// Metrics collector
90    metrics: CircuitMetrics,
91}
92
93impl CircuitBreakerManager {
94    /// Create a new circuit breaker manager
95    pub fn new(config: ManagerConfig) -> Self {
96        Self {
97            breakers: DashMap::new(),
98            config: parking_lot::RwLock::new(config),
99            shared_listeners: parking_lot::RwLock::new(Vec::new()),
100            metrics: CircuitMetrics::new(),
101        }
102    }
103
104    /// Create with default configuration
105    pub fn with_defaults() -> Self {
106        Self::new(ManagerConfig::default())
107    }
108
109    /// Get or create a circuit breaker for a node
110    pub fn get_breaker(&self, node_id: &str) -> CircuitBreaker {
111        if let Some(breaker) = self.breakers.get(node_id) {
112            return breaker.clone();
113        }
114
115        let config = self.config.read();
116        if !config.auto_create {
117            // Return a permissive breaker if auto-create disabled
118            return CircuitBreaker::new(node_id, CircuitBreakerConfig::default());
119        }
120
121        let node_config = config.get_node_config(node_id);
122        drop(config);
123
124        let breaker = CircuitBreaker::new(node_id, node_config);
125
126        // Add shared listeners
127        let listeners = self.shared_listeners.read();
128        for listener in listeners.iter() {
129            breaker.add_listener(Arc::clone(listener));
130        }
131
132        self.breakers.insert(node_id.to_string(), breaker.clone());
133        breaker
134    }
135
136    /// Try to allow a request to a specific node
137    pub fn allow_request(&self, node_id: &str) -> Result<RequestGuard, CircuitOpen> {
138        let breaker = self.get_breaker(node_id);
139        let result = breaker.allow_request();
140
141        // Record metrics
142        let config = self.config.read();
143        if config.metrics_enabled {
144            drop(config);
145            match &result {
146                Ok(_) => self.metrics.record_allowed(node_id),
147                Err(_) => self.metrics.record_rejected(node_id),
148            }
149        }
150
151        result
152    }
153
154    /// Wrap a function with circuit breaker protection
155    pub fn wrap_request<F, T, E>(&self, node_id: &str, f: F) -> Result<T, WrapError<E>>
156    where
157        F: FnOnce() -> Result<T, E>,
158        E: std::fmt::Display,
159    {
160        let guard = self
161            .allow_request(node_id)
162            .map_err(WrapError::CircuitOpen)?;
163
164        match f() {
165            Ok(result) => {
166                guard.success();
167                Ok(result)
168            }
169            Err(e) => {
170                guard.failure(&e.to_string());
171                Err(WrapError::Inner(e))
172            }
173        }
174    }
175
176    /// Async version of wrap_request
177    pub async fn wrap_request_async<F, Fut, T, E>(
178        &self,
179        node_id: &str,
180        f: F,
181    ) -> Result<T, WrapError<E>>
182    where
183        F: FnOnce() -> Fut,
184        Fut: std::future::Future<Output = Result<T, E>>,
185        E: std::fmt::Display,
186    {
187        let guard = self
188            .allow_request(node_id)
189            .map_err(WrapError::CircuitOpen)?;
190
191        match f().await {
192            Ok(result) => {
193                guard.success();
194                Ok(result)
195            }
196            Err(e) => {
197                guard.failure(&e.to_string());
198                Err(WrapError::Inner(e))
199            }
200        }
201    }
202
203    /// Get healthy nodes from a list (filters out nodes with open circuits)
204    pub fn get_healthy_nodes<T: HasNodeId + Clone>(&self, nodes: &[T]) -> Vec<T> {
205        nodes
206            .iter()
207            .filter(|node| {
208                self.breakers
209                    .get(node.node_id())
210                    .map(|b| b.get_state() != CircuitState::Open)
211                    .unwrap_or(true) // Unknown nodes are considered healthy
212            })
213            .cloned()
214            .collect()
215    }
216
217    /// Get all node IDs with open circuits
218    pub fn get_open_circuits(&self) -> Vec<String> {
219        self.breakers
220            .iter()
221            .filter(|entry| entry.value().get_state() == CircuitState::Open)
222            .map(|entry| entry.key().clone())
223            .collect()
224    }
225
226    /// Get all node IDs with unhealthy circuits (open or half-open)
227    pub fn get_unhealthy_nodes(&self) -> Vec<String> {
228        self.breakers
229            .iter()
230            .filter(|entry| entry.value().get_state().is_unhealthy())
231            .map(|entry| entry.key().clone())
232            .collect()
233    }
234
235    /// Get state for all managed nodes
236    pub fn get_all_states(&self) -> Vec<(String, CircuitState)> {
237        self.breakers
238            .iter()
239            .map(|entry| (entry.key().clone(), entry.value().get_state()))
240            .collect()
241    }
242
243    /// Force open circuit for a node
244    pub fn force_open(&self, node_id: &str, admin: Option<&str>) {
245        let breaker = self.get_breaker(node_id);
246        breaker.force_open(admin);
247    }
248
249    /// Force close circuit for a node
250    pub fn force_close(&self, node_id: &str, admin: Option<&str>) {
251        if let Some(breaker) = self.breakers.get(node_id) {
252            breaker.force_close(admin);
253        }
254    }
255
256    /// Reset circuit for a node
257    pub fn reset(&self, node_id: &str) {
258        if let Some(breaker) = self.breakers.get(node_id) {
259            breaker.reset();
260        }
261    }
262
263    /// Reset all circuits
264    pub fn reset_all(&self) {
265        for entry in self.breakers.iter() {
266            entry.value().reset();
267        }
268    }
269
270    /// Remove a circuit breaker
271    pub fn remove(&self, node_id: &str) -> Option<CircuitBreaker> {
272        self.breakers.remove(node_id).map(|(_, b)| b)
273    }
274
275    /// Add a shared listener for all circuit breakers
276    pub fn add_listener(&self, listener: Arc<dyn CircuitBreakerListener>) {
277        // Add to existing breakers
278        for entry in self.breakers.iter() {
279            entry.value().add_listener(Arc::clone(&listener));
280        }
281
282        // Store for future breakers
283        self.shared_listeners.write().push(listener);
284    }
285
286    /// Update global configuration
287    pub fn update_config(&self, config: ManagerConfig) {
288        // Update existing breakers with new configs
289        for entry in self.breakers.iter() {
290            let node_config = config.get_node_config(entry.key());
291            entry.value().update_config(node_config);
292        }
293
294        *self.config.write() = config;
295    }
296
297    /// Get current configuration
298    pub fn config(&self) -> ManagerConfig {
299        self.config.read().clone()
300    }
301
302    /// Get metrics
303    pub fn metrics(&self) -> &CircuitMetrics {
304        &self.metrics
305    }
306
307    /// Get statistics for all circuits
308    pub fn get_stats(&self) -> CircuitStats {
309        let mut stats = CircuitStats::default();
310
311        for entry in self.breakers.iter() {
312            let breaker = entry.value();
313            stats.add_node_stats(
314                entry.key(),
315                breaker.get_state(),
316                breaker.failure_count(),
317                breaker.open_count(),
318                breaker.total_failures(),
319                breaker.total_successes(),
320            );
321        }
322
323        stats
324    }
325
326    /// Get number of managed nodes
327    pub fn node_count(&self) -> usize {
328        self.breakers.len()
329    }
330
331    /// Check if a specific node exists
332    pub fn has_node(&self, node_id: &str) -> bool {
333        self.breakers.contains_key(node_id)
334    }
335}
336
337/// Error type for wrapped requests
338#[derive(Debug)]
339pub enum WrapError<E> {
340    /// Circuit is open
341    CircuitOpen(CircuitOpen),
342    /// Inner function error
343    Inner(E),
344}
345
346impl<E: std::fmt::Display> std::fmt::Display for WrapError<E> {
347    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
348        match self {
349            WrapError::CircuitOpen(open) => write!(f, "{}", open),
350            WrapError::Inner(e) => write!(f, "{}", e),
351        }
352    }
353}
354
355impl<E: std::error::Error + 'static> std::error::Error for WrapError<E> {
356    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
357        match self {
358            WrapError::CircuitOpen(open) => Some(open),
359            WrapError::Inner(e) => Some(e),
360        }
361    }
362}
363
364impl<E> WrapError<E> {
365    /// Check if this is a circuit open error
366    pub fn is_circuit_open(&self) -> bool {
367        matches!(self, WrapError::CircuitOpen(_))
368    }
369
370    /// Get retry-after duration if circuit is open
371    pub fn retry_after(&self) -> Option<Duration> {
372        match self {
373            WrapError::CircuitOpen(open) => Some(open.retry_after),
374            WrapError::Inner(_) => None,
375        }
376    }
377}
378
/// Implemented by anything that can name the node it refers to.
///
/// Used by the manager's health filter to look up the breaker that belongs
/// to each candidate.
pub trait HasNodeId {
    /// The node identifier used as the breaker lookup key.
    fn node_id(&self) -> &str;
}

/// An owned string is its own node id.
impl HasNodeId for String {
    fn node_id(&self) -> &str {
        self.as_str()
    }
}

/// A string slice is its own node id.
impl HasNodeId for &str {
    fn node_id(&self) -> &str {
        *self
    }
}
395
396/// Simple node info for testing
397#[derive(Debug, Clone)]
398pub struct SimpleNode {
399    pub id: String,
400}
401
402impl HasNodeId for SimpleNode {
403    fn node_id(&self) -> &str {
404        &self.id
405    }
406}
407
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh manager tracks no nodes.
    #[test]
    fn test_manager_creation() {
        let mgr = CircuitBreakerManager::with_defaults();
        assert_eq!(mgr.node_count(), 0);
    }

    /// Requesting a breaker registers it, closed, under the requested id.
    #[test]
    fn test_manager_get_breaker() {
        let mgr = CircuitBreakerManager::with_defaults();

        let created = mgr.get_breaker("node-1");
        assert_eq!(created.node_id(), "node-1");
        assert_eq!(created.get_state(), CircuitState::Closed);

        assert_eq!(mgr.node_count(), 1);
        assert!(mgr.has_node("node-1"));
    }

    /// A success reported through the guard is visible on the node's breaker.
    #[test]
    fn test_manager_allow_request() {
        let mgr = CircuitBreakerManager::with_defaults();

        let permit = mgr.allow_request("node-1").expect("should allow");
        permit.success();

        let tracked = mgr.get_breaker("node-1");
        assert_eq!(tracked.total_successes(), 1);
    }

    /// Nodes with an open circuit are filtered out; unknown nodes pass.
    #[test]
    fn test_manager_healthy_nodes() {
        let breaker_config = CircuitBreakerConfig::builder().failure_threshold(2).build();
        let mgr = CircuitBreakerManager::new(ManagerConfig::new(breaker_config));

        // Create some nodes
        let candidates: Vec<SimpleNode> = ["node-1", "node-2", "node-3"]
            .iter()
            .map(|id| SimpleNode { id: id.to_string() })
            .collect();

        // Initially all healthy
        let before = mgr.get_healthy_nodes(&candidates);
        assert_eq!(before.len(), 3);

        // Open circuit for node-2
        mgr.force_open("node-2", None);

        let after = mgr.get_healthy_nodes(&candidates);
        assert_eq!(after.len(), 2);
        assert!(after.iter().all(|n| n.id != "node-2"));
    }

    /// The wrapper passes through both the success value and the failure.
    #[test]
    fn test_manager_wrap_request() {
        let mgr = CircuitBreakerManager::with_defaults();

        let succeeded = mgr.wrap_request("node-1", || Ok::<i32, &str>(42));
        assert_eq!(succeeded.unwrap(), 42);

        let failed = mgr.wrap_request("node-1", || Err::<i32, &str>("error"));
        assert!(failed.is_err());
    }

    /// A per-node override replaces the default threshold for that node only.
    #[test]
    fn test_manager_node_overrides() {
        let defaults = CircuitBreakerConfig::builder().failure_threshold(5).build();
        let special = NodeOverride::new("special-node").with_failure_threshold(10);
        let mgr_config = ManagerConfig::new(defaults).with_node_override(special);

        let mgr = CircuitBreakerManager::new(mgr_config);

        let regular = mgr.get_breaker("normal-node");
        assert_eq!(regular.config().failure_threshold, 5);

        let overridden = mgr.get_breaker("special-node");
        assert_eq!(overridden.config().failure_threshold, 10);
    }

    /// Only the forced-open nodes are reported as open.
    #[test]
    fn test_manager_get_open_circuits() {
        let mgr = CircuitBreakerManager::with_defaults();

        mgr.force_open("node-1", None);
        mgr.force_open("node-3", None);
        let _ = mgr.get_breaker("node-2"); // Closed

        let open_ids = mgr.get_open_circuits();
        assert_eq!(open_ids.len(), 2);
        assert!(open_ids.contains(&"node-1".to_string()));
        assert!(open_ids.contains(&"node-3".to_string()));
    }

    /// Resetting everything leaves no circuit open.
    #[test]
    fn test_manager_reset_all() {
        let breaker_config = CircuitBreakerConfig::builder().failure_threshold(1).build();
        let mgr = CircuitBreakerManager::new(ManagerConfig::new(breaker_config));

        // Open some circuits
        for node in ["node-1", "node-2"] {
            mgr.force_open(node, None);
        }

        assert_eq!(mgr.get_open_circuits().len(), 2);

        mgr.reset_all();
        assert_eq!(mgr.get_open_circuits().len(), 0);
    }

    /// The async wrapper yields the value produced by the future.
    #[tokio::test]
    async fn test_manager_wrap_async() {
        let mgr = CircuitBreakerManager::with_defaults();

        let outcome = mgr
            .wrap_request_async("node-1", || async { Ok::<i32, &str>(42) })
            .await;
        assert_eq!(outcome.unwrap(), 42);
    }
}