// heliosdb_proxy/circuit_breaker/manager.rs
use std::error::Error;
use std::sync::Arc;
use std::time::Duration;

use dashmap::DashMap;

use super::breaker::{CircuitBreaker, CircuitOpen, RequestGuard};
use super::config::{CircuitBreakerConfig, NodeOverride};
use super::metrics::{CircuitMetrics, CircuitStats};
use super::state::{CircuitBreakerListener, CircuitState};

16#[derive(Debug, Clone)]
18pub struct ManagerConfig {
19 pub default_config: CircuitBreakerConfig,
21
22 pub node_overrides: Vec<NodeOverride>,
24
25 pub metrics_enabled: bool,
27
28 pub auto_create: bool,
30}
31
32impl Default for ManagerConfig {
33 fn default() -> Self {
34 Self {
35 default_config: CircuitBreakerConfig::default(),
36 node_overrides: Vec::new(),
37 metrics_enabled: true,
38 auto_create: true,
39 }
40 }
41}
42
43impl ManagerConfig {
44 pub fn new(default_config: CircuitBreakerConfig) -> Self {
46 Self {
47 default_config,
48 ..Default::default()
49 }
50 }
51
52 pub fn with_node_override(mut self, override_: NodeOverride) -> Self {
54 self.node_overrides.push(override_);
55 self
56 }
57
58 pub fn with_metrics(mut self, enabled: bool) -> Self {
60 self.metrics_enabled = enabled;
61 self
62 }
63
64 pub fn get_node_config(&self, node_id: &str) -> CircuitBreakerConfig {
66 for override_ in &self.node_overrides {
67 if override_.node_id == node_id {
68 return override_.apply_to(&self.default_config);
69 }
70 }
71 self.default_config.clone()
72 }
73}
74
75pub struct CircuitBreakerManager {
80 breakers: DashMap<String, CircuitBreaker>,
82
83 config: parking_lot::RwLock<ManagerConfig>,
85
86 shared_listeners: parking_lot::RwLock<Vec<Arc<dyn CircuitBreakerListener>>>,
88
89 metrics: CircuitMetrics,
91}
92
93impl CircuitBreakerManager {
94 pub fn new(config: ManagerConfig) -> Self {
96 Self {
97 breakers: DashMap::new(),
98 config: parking_lot::RwLock::new(config),
99 shared_listeners: parking_lot::RwLock::new(Vec::new()),
100 metrics: CircuitMetrics::new(),
101 }
102 }
103
104 pub fn with_defaults() -> Self {
106 Self::new(ManagerConfig::default())
107 }
108
109 pub fn get_breaker(&self, node_id: &str) -> CircuitBreaker {
111 if let Some(breaker) = self.breakers.get(node_id) {
112 return breaker.clone();
113 }
114
115 let config = self.config.read();
116 if !config.auto_create {
117 return CircuitBreaker::new(node_id, CircuitBreakerConfig::default());
119 }
120
121 let node_config = config.get_node_config(node_id);
122 drop(config);
123
124 let breaker = CircuitBreaker::new(node_id, node_config);
125
126 let listeners = self.shared_listeners.read();
128 for listener in listeners.iter() {
129 breaker.add_listener(Arc::clone(listener));
130 }
131
132 self.breakers.insert(node_id.to_string(), breaker.clone());
133 breaker
134 }
135
136 pub fn allow_request(&self, node_id: &str) -> Result<RequestGuard, CircuitOpen> {
138 let breaker = self.get_breaker(node_id);
139 let result = breaker.allow_request();
140
141 let config = self.config.read();
143 if config.metrics_enabled {
144 drop(config);
145 match &result {
146 Ok(_) => self.metrics.record_allowed(node_id),
147 Err(_) => self.metrics.record_rejected(node_id),
148 }
149 }
150
151 result
152 }
153
154 pub fn wrap_request<F, T, E>(&self, node_id: &str, f: F) -> Result<T, WrapError<E>>
156 where
157 F: FnOnce() -> Result<T, E>,
158 E: std::fmt::Display,
159 {
160 let guard = self.allow_request(node_id).map_err(WrapError::CircuitOpen)?;
161
162 match f() {
163 Ok(result) => {
164 guard.success();
165 Ok(result)
166 }
167 Err(e) => {
168 guard.failure(&e.to_string());
169 Err(WrapError::Inner(e))
170 }
171 }
172 }
173
174 pub async fn wrap_request_async<F, Fut, T, E>(
176 &self,
177 node_id: &str,
178 f: F,
179 ) -> Result<T, WrapError<E>>
180 where
181 F: FnOnce() -> Fut,
182 Fut: std::future::Future<Output = Result<T, E>>,
183 E: std::fmt::Display,
184 {
185 let guard = self.allow_request(node_id).map_err(WrapError::CircuitOpen)?;
186
187 match f().await {
188 Ok(result) => {
189 guard.success();
190 Ok(result)
191 }
192 Err(e) => {
193 guard.failure(&e.to_string());
194 Err(WrapError::Inner(e))
195 }
196 }
197 }
198
199 pub fn get_healthy_nodes<T: HasNodeId + Clone>(&self, nodes: &[T]) -> Vec<T> {
201 nodes
202 .iter()
203 .filter(|node| {
204 self.breakers
205 .get(node.node_id())
206 .map(|b| b.get_state() != CircuitState::Open)
207 .unwrap_or(true) })
209 .cloned()
210 .collect()
211 }
212
213 pub fn get_open_circuits(&self) -> Vec<String> {
215 self.breakers
216 .iter()
217 .filter(|entry| entry.value().get_state() == CircuitState::Open)
218 .map(|entry| entry.key().clone())
219 .collect()
220 }
221
222 pub fn get_unhealthy_nodes(&self) -> Vec<String> {
224 self.breakers
225 .iter()
226 .filter(|entry| entry.value().get_state().is_unhealthy())
227 .map(|entry| entry.key().clone())
228 .collect()
229 }
230
231 pub fn get_all_states(&self) -> Vec<(String, CircuitState)> {
233 self.breakers
234 .iter()
235 .map(|entry| (entry.key().clone(), entry.value().get_state()))
236 .collect()
237 }
238
239 pub fn force_open(&self, node_id: &str, admin: Option<&str>) {
241 let breaker = self.get_breaker(node_id);
242 breaker.force_open(admin);
243 }
244
245 pub fn force_close(&self, node_id: &str, admin: Option<&str>) {
247 if let Some(breaker) = self.breakers.get(node_id) {
248 breaker.force_close(admin);
249 }
250 }
251
252 pub fn reset(&self, node_id: &str) {
254 if let Some(breaker) = self.breakers.get(node_id) {
255 breaker.reset();
256 }
257 }
258
259 pub fn reset_all(&self) {
261 for entry in self.breakers.iter() {
262 entry.value().reset();
263 }
264 }
265
266 pub fn remove(&self, node_id: &str) -> Option<CircuitBreaker> {
268 self.breakers.remove(node_id).map(|(_, b)| b)
269 }
270
271 pub fn add_listener(&self, listener: Arc<dyn CircuitBreakerListener>) {
273 for entry in self.breakers.iter() {
275 entry.value().add_listener(Arc::clone(&listener));
276 }
277
278 self.shared_listeners.write().push(listener);
280 }
281
282 pub fn update_config(&self, config: ManagerConfig) {
284 for entry in self.breakers.iter() {
286 let node_config = config.get_node_config(entry.key());
287 entry.value().update_config(node_config);
288 }
289
290 *self.config.write() = config;
291 }
292
293 pub fn config(&self) -> ManagerConfig {
295 self.config.read().clone()
296 }
297
298 pub fn metrics(&self) -> &CircuitMetrics {
300 &self.metrics
301 }
302
303 pub fn get_stats(&self) -> CircuitStats {
305 let mut stats = CircuitStats::default();
306
307 for entry in self.breakers.iter() {
308 let breaker = entry.value();
309 stats.add_node_stats(
310 entry.key(),
311 breaker.get_state(),
312 breaker.failure_count(),
313 breaker.open_count(),
314 breaker.total_failures(),
315 breaker.total_successes(),
316 );
317 }
318
319 stats
320 }
321
322 pub fn node_count(&self) -> usize {
324 self.breakers.len()
325 }
326
327 pub fn has_node(&self, node_id: &str) -> bool {
329 self.breakers.contains_key(node_id)
330 }
331}
332
333#[derive(Debug)]
335pub enum WrapError<E> {
336 CircuitOpen(CircuitOpen),
338 Inner(E),
340}
341
342impl<E: std::fmt::Display> std::fmt::Display for WrapError<E> {
343 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
344 match self {
345 WrapError::CircuitOpen(open) => write!(f, "{}", open),
346 WrapError::Inner(e) => write!(f, "{}", e),
347 }
348 }
349}
350
351impl<E: std::error::Error + 'static> std::error::Error for WrapError<E> {
352 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
353 match self {
354 WrapError::CircuitOpen(open) => Some(open),
355 WrapError::Inner(e) => Some(e),
356 }
357 }
358}
359
360impl<E> WrapError<E> {
361 pub fn is_circuit_open(&self) -> bool {
363 matches!(self, WrapError::CircuitOpen(_))
364 }
365
366 pub fn retry_after(&self) -> Option<Duration> {
368 match self {
369 WrapError::CircuitOpen(open) => Some(open.retry_after),
370 WrapError::Inner(_) => None,
371 }
372 }
373}
374
/// Anything that can report the node id its circuit breaker is keyed by.
///
/// Used by [`CircuitBreakerManager::get_healthy_nodes`] to filter
/// caller-defined node types.
pub trait HasNodeId {
    /// Identifier used as this node's breaker key.
    fn node_id(&self) -> &str;
}

380impl HasNodeId for String {
381 fn node_id(&self) -> &str {
382 self
383 }
384}
385
386impl HasNodeId for &str {
387 fn node_id(&self) -> &str {
388 self
389 }
390}
391
/// Minimal node type that carries only its id; the manager's
/// `HasNodeId` impl for it returns that id.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SimpleNode {
    /// Node identifier, used as the circuit breaker key.
    pub id: String,
}

impl SimpleNode {
    /// Creates a node with the given id.
    pub fn new(id: impl Into<String>) -> Self {
        Self { id: id.into() }
    }
}

398impl HasNodeId for SimpleNode {
399 fn node_id(&self) -> &str {
400 &self.id
401 }
402}
403
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `SimpleNode`s for the given ids, preserving order.
    fn simple_nodes(ids: &[&str]) -> Vec<SimpleNode> {
        ids.iter()
            .map(|&id| SimpleNode { id: id.to_owned() })
            .collect()
    }

    #[test]
    fn test_manager_creation() {
        let manager = CircuitBreakerManager::with_defaults();
        assert_eq!(manager.node_count(), 0);
    }

    #[test]
    fn test_manager_get_breaker() {
        let manager = CircuitBreakerManager::with_defaults();

        let breaker = manager.get_breaker("node-1");
        assert_eq!(breaker.node_id(), "node-1");
        assert_eq!(breaker.get_state(), CircuitState::Closed);

        // The lookup registered the node with the manager.
        assert!(manager.has_node("node-1"));
        assert_eq!(manager.node_count(), 1);
    }

    #[test]
    fn test_manager_allow_request() {
        let manager = CircuitBreakerManager::with_defaults();

        let guard = manager.allow_request("node-1").expect("should allow");
        guard.success();

        // A second handle observes the success reported through the guard.
        let breaker = manager.get_breaker("node-1");
        assert_eq!(breaker.total_successes(), 1);
    }

    #[test]
    fn test_manager_healthy_nodes() {
        let manager = CircuitBreakerManager::new(ManagerConfig::new(
            CircuitBreakerConfig::builder().failure_threshold(2).build(),
        ));
        let nodes = simple_nodes(&["node-1", "node-2", "node-3"]);

        // Nodes without a breaker count as healthy.
        assert_eq!(manager.get_healthy_nodes(&nodes).len(), 3);

        manager.force_open("node-2", None);

        let healthy = manager.get_healthy_nodes(&nodes);
        assert_eq!(healthy.len(), 2);
        assert!(!healthy.iter().any(|node| node.id == "node-2"));
    }

    #[test]
    fn test_manager_wrap_request() {
        let manager = CircuitBreakerManager::with_defaults();

        let succeeded = manager.wrap_request("node-1", || Ok::<i32, &str>(42));
        assert_eq!(succeeded.unwrap(), 42);

        let failed = manager.wrap_request("node-1", || Err::<i32, &str>("error"));
        assert!(failed.is_err());
    }

    #[test]
    fn test_manager_node_overrides() {
        let base = CircuitBreakerConfig::builder().failure_threshold(5).build();
        let special = NodeOverride::new("special-node").with_failure_threshold(10);
        let manager = CircuitBreakerManager::new(ManagerConfig::new(base).with_node_override(special));

        let normal_breaker = manager.get_breaker("normal-node");
        assert_eq!(normal_breaker.config().failure_threshold, 5);

        let special_breaker = manager.get_breaker("special-node");
        assert_eq!(special_breaker.config().failure_threshold, 10);
    }

    #[test]
    fn test_manager_get_open_circuits() {
        let manager = CircuitBreakerManager::with_defaults();

        manager.force_open("node-1", None);
        manager.force_open("node-3", None);
        // Tracked but closed: must not be reported as open.
        let _ = manager.get_breaker("node-2");

        let open = manager.get_open_circuits();
        assert_eq!(open.len(), 2);
        for expected in ["node-1", "node-3"] {
            assert!(open.contains(&expected.to_string()));
        }
    }

    #[test]
    fn test_manager_reset_all() {
        let manager = CircuitBreakerManager::new(ManagerConfig::new(
            CircuitBreakerConfig::builder().failure_threshold(1).build(),
        ));

        for node in ["node-1", "node-2"] {
            manager.force_open(node, None);
        }
        assert_eq!(manager.get_open_circuits().len(), 2);

        manager.reset_all();
        assert!(manager.get_open_circuits().is_empty());
    }

    #[tokio::test]
    async fn test_manager_wrap_async() {
        let manager = CircuitBreakerManager::with_defaults();

        let value = manager
            .wrap_request_async("node-1", || async { Ok::<i32, &str>(42) })
            .await
            .unwrap();
        assert_eq!(value, 42);
    }
}