Skip to main content

heliosdb_proxy/circuit_breaker/
manager.rs

1//! Circuit Breaker Manager
2//!
3//! Manages circuit breakers for multiple nodes with centralized configuration
4//! and monitoring.
5
6use std::sync::Arc;
7use std::time::Duration;
8
9use dashmap::DashMap;
10
11use super::breaker::{CircuitBreaker, CircuitOpen, RequestGuard};
12use super::config::{CircuitBreakerConfig, NodeOverride};
13use super::metrics::{CircuitMetrics, CircuitStats};
14use super::state::{CircuitBreakerListener, CircuitState};
15
/// Configuration for the circuit breaker manager.
///
/// Bundles the default breaker settings, the per-node overrides layered on
/// top of them, and two manager-level switches.
#[derive(Debug, Clone)]
pub struct ManagerConfig {
    /// Global default configuration, used for every node that has no
    /// matching entry in `node_overrides`
    pub default_config: CircuitBreakerConfig,

    /// Per-node configuration overrides, searched in order by `node_id`
    pub node_overrides: Vec<NodeOverride>,

    /// Enable manager-level metrics collection
    pub metrics_enabled: bool,

    /// Auto-create breakers for unknown nodes
    pub auto_create: bool,
}
31
32impl Default for ManagerConfig {
33    fn default() -> Self {
34        Self {
35            default_config: CircuitBreakerConfig::default(),
36            node_overrides: Vec::new(),
37            metrics_enabled: true,
38            auto_create: true,
39        }
40    }
41}
42
43impl ManagerConfig {
44    /// Create a new manager config
45    pub fn new(default_config: CircuitBreakerConfig) -> Self {
46        Self {
47            default_config,
48            ..Default::default()
49        }
50    }
51
52    /// Add a node override
53    pub fn with_node_override(mut self, override_: NodeOverride) -> Self {
54        self.node_overrides.push(override_);
55        self
56    }
57
58    /// Enable or disable metrics
59    pub fn with_metrics(mut self, enabled: bool) -> Self {
60        self.metrics_enabled = enabled;
61        self
62    }
63
64    /// Get effective config for a node
65    pub fn get_node_config(&self, node_id: &str) -> CircuitBreakerConfig {
66        for override_ in &self.node_overrides {
67            if override_.node_id == node_id {
68                return override_.apply_to(&self.default_config);
69            }
70        }
71        self.default_config.clone()
72    }
73}
74
/// Circuit Breaker Manager
///
/// Manages multiple circuit breakers for different nodes, providing centralized
/// configuration, monitoring, and node health filtering.
///
/// All state sits behind a concurrent map or a lock, so every method takes
/// `&self`.
pub struct CircuitBreakerManager {
    /// Circuit breakers per node, keyed by node ID
    breakers: DashMap<String, CircuitBreaker>,

    /// Manager configuration; replaced wholesale by `update_config`
    config: parking_lot::RwLock<ManagerConfig>,

    /// Shared listeners for all breakers; attached to existing breakers when
    /// registered and to new breakers when they are created
    shared_listeners: parking_lot::RwLock<Vec<Arc<dyn CircuitBreakerListener>>>,

    /// Metrics collector, written by `allow_request` when metrics are enabled
    metrics: CircuitMetrics,
}
92
93impl CircuitBreakerManager {
94    /// Create a new circuit breaker manager
95    pub fn new(config: ManagerConfig) -> Self {
96        Self {
97            breakers: DashMap::new(),
98            config: parking_lot::RwLock::new(config),
99            shared_listeners: parking_lot::RwLock::new(Vec::new()),
100            metrics: CircuitMetrics::new(),
101        }
102    }
103
104    /// Create with default configuration
105    pub fn with_defaults() -> Self {
106        Self::new(ManagerConfig::default())
107    }
108
109    /// Get or create a circuit breaker for a node
110    pub fn get_breaker(&self, node_id: &str) -> CircuitBreaker {
111        if let Some(breaker) = self.breakers.get(node_id) {
112            return breaker.clone();
113        }
114
115        let config = self.config.read();
116        if !config.auto_create {
117            // Return a permissive breaker if auto-create disabled
118            return CircuitBreaker::new(node_id, CircuitBreakerConfig::default());
119        }
120
121        let node_config = config.get_node_config(node_id);
122        drop(config);
123
124        let breaker = CircuitBreaker::new(node_id, node_config);
125
126        // Add shared listeners
127        let listeners = self.shared_listeners.read();
128        for listener in listeners.iter() {
129            breaker.add_listener(Arc::clone(listener));
130        }
131
132        self.breakers.insert(node_id.to_string(), breaker.clone());
133        breaker
134    }
135
136    /// Try to allow a request to a specific node
137    pub fn allow_request(&self, node_id: &str) -> Result<RequestGuard, CircuitOpen> {
138        let breaker = self.get_breaker(node_id);
139        let result = breaker.allow_request();
140
141        // Record metrics
142        let config = self.config.read();
143        if config.metrics_enabled {
144            drop(config);
145            match &result {
146                Ok(_) => self.metrics.record_allowed(node_id),
147                Err(_) => self.metrics.record_rejected(node_id),
148            }
149        }
150
151        result
152    }
153
154    /// Wrap a function with circuit breaker protection
155    pub fn wrap_request<F, T, E>(&self, node_id: &str, f: F) -> Result<T, WrapError<E>>
156    where
157        F: FnOnce() -> Result<T, E>,
158        E: std::fmt::Display,
159    {
160        let guard = self.allow_request(node_id).map_err(WrapError::CircuitOpen)?;
161
162        match f() {
163            Ok(result) => {
164                guard.success();
165                Ok(result)
166            }
167            Err(e) => {
168                guard.failure(&e.to_string());
169                Err(WrapError::Inner(e))
170            }
171        }
172    }
173
174    /// Async version of wrap_request
175    pub async fn wrap_request_async<F, Fut, T, E>(
176        &self,
177        node_id: &str,
178        f: F,
179    ) -> Result<T, WrapError<E>>
180    where
181        F: FnOnce() -> Fut,
182        Fut: std::future::Future<Output = Result<T, E>>,
183        E: std::fmt::Display,
184    {
185        let guard = self.allow_request(node_id).map_err(WrapError::CircuitOpen)?;
186
187        match f().await {
188            Ok(result) => {
189                guard.success();
190                Ok(result)
191            }
192            Err(e) => {
193                guard.failure(&e.to_string());
194                Err(WrapError::Inner(e))
195            }
196        }
197    }
198
199    /// Get healthy nodes from a list (filters out nodes with open circuits)
200    pub fn get_healthy_nodes<T: HasNodeId + Clone>(&self, nodes: &[T]) -> Vec<T> {
201        nodes
202            .iter()
203            .filter(|node| {
204                self.breakers
205                    .get(node.node_id())
206                    .map(|b| b.get_state() != CircuitState::Open)
207                    .unwrap_or(true) // Unknown nodes are considered healthy
208            })
209            .cloned()
210            .collect()
211    }
212
213    /// Get all node IDs with open circuits
214    pub fn get_open_circuits(&self) -> Vec<String> {
215        self.breakers
216            .iter()
217            .filter(|entry| entry.value().get_state() == CircuitState::Open)
218            .map(|entry| entry.key().clone())
219            .collect()
220    }
221
222    /// Get all node IDs with unhealthy circuits (open or half-open)
223    pub fn get_unhealthy_nodes(&self) -> Vec<String> {
224        self.breakers
225            .iter()
226            .filter(|entry| entry.value().get_state().is_unhealthy())
227            .map(|entry| entry.key().clone())
228            .collect()
229    }
230
231    /// Get state for all managed nodes
232    pub fn get_all_states(&self) -> Vec<(String, CircuitState)> {
233        self.breakers
234            .iter()
235            .map(|entry| (entry.key().clone(), entry.value().get_state()))
236            .collect()
237    }
238
239    /// Force open circuit for a node
240    pub fn force_open(&self, node_id: &str, admin: Option<&str>) {
241        let breaker = self.get_breaker(node_id);
242        breaker.force_open(admin);
243    }
244
245    /// Force close circuit for a node
246    pub fn force_close(&self, node_id: &str, admin: Option<&str>) {
247        if let Some(breaker) = self.breakers.get(node_id) {
248            breaker.force_close(admin);
249        }
250    }
251
252    /// Reset circuit for a node
253    pub fn reset(&self, node_id: &str) {
254        if let Some(breaker) = self.breakers.get(node_id) {
255            breaker.reset();
256        }
257    }
258
259    /// Reset all circuits
260    pub fn reset_all(&self) {
261        for entry in self.breakers.iter() {
262            entry.value().reset();
263        }
264    }
265
266    /// Remove a circuit breaker
267    pub fn remove(&self, node_id: &str) -> Option<CircuitBreaker> {
268        self.breakers.remove(node_id).map(|(_, b)| b)
269    }
270
271    /// Add a shared listener for all circuit breakers
272    pub fn add_listener(&self, listener: Arc<dyn CircuitBreakerListener>) {
273        // Add to existing breakers
274        for entry in self.breakers.iter() {
275            entry.value().add_listener(Arc::clone(&listener));
276        }
277
278        // Store for future breakers
279        self.shared_listeners.write().push(listener);
280    }
281
282    /// Update global configuration
283    pub fn update_config(&self, config: ManagerConfig) {
284        // Update existing breakers with new configs
285        for entry in self.breakers.iter() {
286            let node_config = config.get_node_config(entry.key());
287            entry.value().update_config(node_config);
288        }
289
290        *self.config.write() = config;
291    }
292
293    /// Get current configuration
294    pub fn config(&self) -> ManagerConfig {
295        self.config.read().clone()
296    }
297
    /// Borrow the manager-level metrics collector.
    ///
    /// `allow_request` records allowed and rejected requests here while
    /// `metrics_enabled` is set in the configuration.
    pub fn metrics(&self) -> &CircuitMetrics {
        &self.metrics
    }
302
303    /// Get statistics for all circuits
304    pub fn get_stats(&self) -> CircuitStats {
305        let mut stats = CircuitStats::default();
306
307        for entry in self.breakers.iter() {
308            let breaker = entry.value();
309            stats.add_node_stats(
310                entry.key(),
311                breaker.get_state(),
312                breaker.failure_count(),
313                breaker.open_count(),
314                breaker.total_failures(),
315                breaker.total_successes(),
316            );
317        }
318
319        stats
320    }
321
    /// Get the number of managed nodes, i.e. how many breakers are currently
    /// registered in the manager.
    pub fn node_count(&self) -> usize {
        self.breakers.len()
    }
326
    /// Check whether a breaker is registered for `node_id`.
    ///
    /// This is a pure lookup: unlike `get_breaker`, it never creates one.
    pub fn has_node(&self, node_id: &str) -> bool {
        self.breakers.contains_key(node_id)
    }
331}
332
/// Error type for wrapped requests.
///
/// Returned by `wrap_request` / `wrap_request_async`; separates a rejection
/// by the circuit breaker from a failure of the wrapped operation itself.
#[derive(Debug)]
pub enum WrapError<E> {
    /// Circuit is open: the request was rejected and the wrapped function
    /// was not called
    CircuitOpen(CircuitOpen),
    /// Inner function error: the wrapped function ran and returned this error
    Inner(E),
}
341
342impl<E: std::fmt::Display> std::fmt::Display for WrapError<E> {
343    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
344        match self {
345            WrapError::CircuitOpen(open) => write!(f, "{}", open),
346            WrapError::Inner(e) => write!(f, "{}", e),
347        }
348    }
349}
350
351impl<E: std::error::Error + 'static> std::error::Error for WrapError<E> {
352    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
353        match self {
354            WrapError::CircuitOpen(open) => Some(open),
355            WrapError::Inner(e) => Some(e),
356        }
357    }
358}
359
360impl<E> WrapError<E> {
361    /// Check if this is a circuit open error
362    pub fn is_circuit_open(&self) -> bool {
363        matches!(self, WrapError::CircuitOpen(_))
364    }
365
366    /// Get retry-after duration if circuit is open
367    pub fn retry_after(&self) -> Option<Duration> {
368        match self {
369            WrapError::CircuitOpen(open) => Some(open.retry_after),
370            WrapError::Inner(_) => None,
371        }
372    }
373}
374
/// Trait for types that have a node ID
pub trait HasNodeId {
    /// Borrow the node identifier used to look up the node's breaker.
    fn node_id(&self) -> &str;
}

/// An owned string is its own node ID.
impl HasNodeId for String {
    fn node_id(&self) -> &str {
        self.as_str()
    }
}

/// A string slice is its own node ID.
impl HasNodeId for &str {
    fn node_id(&self) -> &str {
        *self
    }
}
391
392/// Simple node info for testing
393#[derive(Debug, Clone)]
394pub struct SimpleNode {
395    pub id: String,
396}
397
398impl HasNodeId for SimpleNode {
399    fn node_id(&self) -> &str {
400        &self.id
401    }
402}
403
// Unit tests for the manager: breaker creation, request gating, health
// filtering, per-node overrides and bulk reset.
#[cfg(test)]
mod tests {
    use super::*;

    // A fresh manager tracks no nodes.
    #[test]
    fn test_manager_creation() {
        let manager = CircuitBreakerManager::with_defaults();
        assert_eq!(manager.node_count(), 0);
    }

    // Fetching an unknown node creates and registers a closed breaker.
    #[test]
    fn test_manager_get_breaker() {
        let manager = CircuitBreakerManager::with_defaults();

        let breaker = manager.get_breaker("node-1");
        assert_eq!(breaker.node_id(), "node-1");
        assert_eq!(breaker.get_state(), CircuitState::Closed);

        assert_eq!(manager.node_count(), 1);
        assert!(manager.has_node("node-1"));
    }

    // A success reported through the guard is visible on the breaker
    // fetched again from the manager.
    #[test]
    fn test_manager_allow_request() {
        let manager = CircuitBreakerManager::with_defaults();

        let guard = manager.allow_request("node-1").expect("should allow");
        guard.success();

        let breaker = manager.get_breaker("node-1");
        assert_eq!(breaker.total_successes(), 1);
    }

    // Nodes with an open circuit are filtered out; unknown nodes pass.
    #[test]
    fn test_manager_healthy_nodes() {
        let config = ManagerConfig::new(
            CircuitBreakerConfig::builder()
                .failure_threshold(2)
                .build(),
        );
        let manager = CircuitBreakerManager::new(config);

        // Create some nodes
        let nodes = vec![
            SimpleNode {
                id: "node-1".to_string(),
            },
            SimpleNode {
                id: "node-2".to_string(),
            },
            SimpleNode {
                id: "node-3".to_string(),
            },
        ];

        // Initially all healthy
        let healthy = manager.get_healthy_nodes(&nodes);
        assert_eq!(healthy.len(), 3);

        // Open circuit for node-2
        manager.force_open("node-2", None);

        let healthy = manager.get_healthy_nodes(&nodes);
        assert_eq!(healthy.len(), 2);
        assert!(healthy.iter().all(|n| n.id != "node-2"));
    }

    // wrap_request passes the closure's Ok value through and surfaces its
    // Err as an error.
    #[test]
    fn test_manager_wrap_request() {
        let manager = CircuitBreakerManager::with_defaults();

        let result = manager.wrap_request("node-1", || Ok::<i32, &str>(42));
        assert_eq!(result.unwrap(), 42);

        let result = manager.wrap_request("node-1", || Err::<i32, &str>("error"));
        assert!(result.is_err());
    }

    // A node override takes precedence over the default configuration for
    // its node only.
    #[test]
    fn test_manager_node_overrides() {
        let config = ManagerConfig::new(
            CircuitBreakerConfig::builder()
                .failure_threshold(5)
                .build(),
        )
        .with_node_override(
            NodeOverride::new("special-node").with_failure_threshold(10),
        );

        let manager = CircuitBreakerManager::new(config);

        let normal_breaker = manager.get_breaker("normal-node");
        assert_eq!(normal_breaker.config().failure_threshold, 5);

        let special_breaker = manager.get_breaker("special-node");
        assert_eq!(special_breaker.config().failure_threshold, 10);
    }

    // Only force-opened nodes are reported as open circuits.
    #[test]
    fn test_manager_get_open_circuits() {
        let manager = CircuitBreakerManager::with_defaults();

        manager.force_open("node-1", None);
        manager.force_open("node-3", None);
        let _ = manager.get_breaker("node-2"); // Closed

        let open = manager.get_open_circuits();
        assert_eq!(open.len(), 2);
        assert!(open.contains(&"node-1".to_string()));
        assert!(open.contains(&"node-3".to_string()));
    }

    // reset_all leaves no circuit open.
    #[test]
    fn test_manager_reset_all() {
        let config = ManagerConfig::new(
            CircuitBreakerConfig::builder()
                .failure_threshold(1)
                .build(),
        );
        let manager = CircuitBreakerManager::new(config);

        // Open some circuits
        manager.force_open("node-1", None);
        manager.force_open("node-2", None);

        assert_eq!(manager.get_open_circuits().len(), 2);

        manager.reset_all();
        assert_eq!(manager.get_open_circuits().len(), 0);
    }

    // The async wrapper passes the future's Ok value through.
    #[tokio::test]
    async fn test_manager_wrap_async() {
        let manager = CircuitBreakerManager::with_defaults();

        let result = manager
            .wrap_request_async("node-1", || async { Ok::<i32, &str>(42) })
            .await;
        assert_eq!(result.unwrap(), 42);
    }
}