heliosdb_proxy/circuit_breaker/
manager.rs1use std::sync::Arc;
7use std::time::Duration;
8
9use dashmap::DashMap;
10
11use super::breaker::{CircuitBreaker, CircuitOpen, RequestGuard};
12use super::config::{CircuitBreakerConfig, NodeOverride};
13use super::metrics::{CircuitMetrics, CircuitStats};
14use super::state::{CircuitBreakerListener, CircuitState};
15
16#[derive(Debug, Clone)]
18pub struct ManagerConfig {
19 pub default_config: CircuitBreakerConfig,
21
22 pub node_overrides: Vec<NodeOverride>,
24
25 pub metrics_enabled: bool,
27
28 pub auto_create: bool,
30}
31
32impl Default for ManagerConfig {
33 fn default() -> Self {
34 Self {
35 default_config: CircuitBreakerConfig::default(),
36 node_overrides: Vec::new(),
37 metrics_enabled: true,
38 auto_create: true,
39 }
40 }
41}
42
43impl ManagerConfig {
44 pub fn new(default_config: CircuitBreakerConfig) -> Self {
46 Self {
47 default_config,
48 ..Default::default()
49 }
50 }
51
52 pub fn with_node_override(mut self, override_: NodeOverride) -> Self {
54 self.node_overrides.push(override_);
55 self
56 }
57
58 pub fn with_metrics(mut self, enabled: bool) -> Self {
60 self.metrics_enabled = enabled;
61 self
62 }
63
64 pub fn get_node_config(&self, node_id: &str) -> CircuitBreakerConfig {
66 for override_ in &self.node_overrides {
67 if override_.node_id == node_id {
68 return override_.apply_to(&self.default_config);
69 }
70 }
71 self.default_config.clone()
72 }
73}
74
75pub struct CircuitBreakerManager {
80 breakers: DashMap<String, CircuitBreaker>,
82
83 config: parking_lot::RwLock<ManagerConfig>,
85
86 shared_listeners: parking_lot::RwLock<Vec<Arc<dyn CircuitBreakerListener>>>,
88
89 metrics: CircuitMetrics,
91}
92
93impl CircuitBreakerManager {
94 pub fn new(config: ManagerConfig) -> Self {
96 Self {
97 breakers: DashMap::new(),
98 config: parking_lot::RwLock::new(config),
99 shared_listeners: parking_lot::RwLock::new(Vec::new()),
100 metrics: CircuitMetrics::new(),
101 }
102 }
103
104 pub fn with_defaults() -> Self {
106 Self::new(ManagerConfig::default())
107 }
108
109 pub fn get_breaker(&self, node_id: &str) -> CircuitBreaker {
111 if let Some(breaker) = self.breakers.get(node_id) {
112 return breaker.clone();
113 }
114
115 let config = self.config.read();
116 if !config.auto_create {
117 return CircuitBreaker::new(node_id, CircuitBreakerConfig::default());
119 }
120
121 let node_config = config.get_node_config(node_id);
122 drop(config);
123
124 let breaker = CircuitBreaker::new(node_id, node_config);
125
126 let listeners = self.shared_listeners.read();
128 for listener in listeners.iter() {
129 breaker.add_listener(Arc::clone(listener));
130 }
131
132 self.breakers.insert(node_id.to_string(), breaker.clone());
133 breaker
134 }
135
136 pub fn allow_request(&self, node_id: &str) -> Result<RequestGuard, CircuitOpen> {
138 let breaker = self.get_breaker(node_id);
139 let result = breaker.allow_request();
140
141 let config = self.config.read();
143 if config.metrics_enabled {
144 drop(config);
145 match &result {
146 Ok(_) => self.metrics.record_allowed(node_id),
147 Err(_) => self.metrics.record_rejected(node_id),
148 }
149 }
150
151 result
152 }
153
154 pub fn wrap_request<F, T, E>(&self, node_id: &str, f: F) -> Result<T, WrapError<E>>
156 where
157 F: FnOnce() -> Result<T, E>,
158 E: std::fmt::Display,
159 {
160 let guard = self
161 .allow_request(node_id)
162 .map_err(WrapError::CircuitOpen)?;
163
164 match f() {
165 Ok(result) => {
166 guard.success();
167 Ok(result)
168 }
169 Err(e) => {
170 guard.failure(&e.to_string());
171 Err(WrapError::Inner(e))
172 }
173 }
174 }
175
176 pub async fn wrap_request_async<F, Fut, T, E>(
178 &self,
179 node_id: &str,
180 f: F,
181 ) -> Result<T, WrapError<E>>
182 where
183 F: FnOnce() -> Fut,
184 Fut: std::future::Future<Output = Result<T, E>>,
185 E: std::fmt::Display,
186 {
187 let guard = self
188 .allow_request(node_id)
189 .map_err(WrapError::CircuitOpen)?;
190
191 match f().await {
192 Ok(result) => {
193 guard.success();
194 Ok(result)
195 }
196 Err(e) => {
197 guard.failure(&e.to_string());
198 Err(WrapError::Inner(e))
199 }
200 }
201 }
202
203 pub fn get_healthy_nodes<T: HasNodeId + Clone>(&self, nodes: &[T]) -> Vec<T> {
205 nodes
206 .iter()
207 .filter(|node| {
208 self.breakers
209 .get(node.node_id())
210 .map(|b| b.get_state() != CircuitState::Open)
211 .unwrap_or(true) })
213 .cloned()
214 .collect()
215 }
216
217 pub fn get_open_circuits(&self) -> Vec<String> {
219 self.breakers
220 .iter()
221 .filter(|entry| entry.value().get_state() == CircuitState::Open)
222 .map(|entry| entry.key().clone())
223 .collect()
224 }
225
226 pub fn get_unhealthy_nodes(&self) -> Vec<String> {
228 self.breakers
229 .iter()
230 .filter(|entry| entry.value().get_state().is_unhealthy())
231 .map(|entry| entry.key().clone())
232 .collect()
233 }
234
235 pub fn get_all_states(&self) -> Vec<(String, CircuitState)> {
237 self.breakers
238 .iter()
239 .map(|entry| (entry.key().clone(), entry.value().get_state()))
240 .collect()
241 }
242
243 pub fn force_open(&self, node_id: &str, admin: Option<&str>) {
245 let breaker = self.get_breaker(node_id);
246 breaker.force_open(admin);
247 }
248
249 pub fn force_close(&self, node_id: &str, admin: Option<&str>) {
251 if let Some(breaker) = self.breakers.get(node_id) {
252 breaker.force_close(admin);
253 }
254 }
255
256 pub fn reset(&self, node_id: &str) {
258 if let Some(breaker) = self.breakers.get(node_id) {
259 breaker.reset();
260 }
261 }
262
263 pub fn reset_all(&self) {
265 for entry in self.breakers.iter() {
266 entry.value().reset();
267 }
268 }
269
270 pub fn remove(&self, node_id: &str) -> Option<CircuitBreaker> {
272 self.breakers.remove(node_id).map(|(_, b)| b)
273 }
274
275 pub fn add_listener(&self, listener: Arc<dyn CircuitBreakerListener>) {
277 for entry in self.breakers.iter() {
279 entry.value().add_listener(Arc::clone(&listener));
280 }
281
282 self.shared_listeners.write().push(listener);
284 }
285
286 pub fn update_config(&self, config: ManagerConfig) {
288 for entry in self.breakers.iter() {
290 let node_config = config.get_node_config(entry.key());
291 entry.value().update_config(node_config);
292 }
293
294 *self.config.write() = config;
295 }
296
297 pub fn config(&self) -> ManagerConfig {
299 self.config.read().clone()
300 }
301
302 pub fn metrics(&self) -> &CircuitMetrics {
304 &self.metrics
305 }
306
307 pub fn get_stats(&self) -> CircuitStats {
309 let mut stats = CircuitStats::default();
310
311 for entry in self.breakers.iter() {
312 let breaker = entry.value();
313 stats.add_node_stats(
314 entry.key(),
315 breaker.get_state(),
316 breaker.failure_count(),
317 breaker.open_count(),
318 breaker.total_failures(),
319 breaker.total_successes(),
320 );
321 }
322
323 stats
324 }
325
326 pub fn node_count(&self) -> usize {
328 self.breakers.len()
329 }
330
331 pub fn has_node(&self, node_id: &str) -> bool {
333 self.breakers.contains_key(node_id)
334 }
335}
336
337#[derive(Debug)]
339pub enum WrapError<E> {
340 CircuitOpen(CircuitOpen),
342 Inner(E),
344}
345
346impl<E: std::fmt::Display> std::fmt::Display for WrapError<E> {
347 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
348 match self {
349 WrapError::CircuitOpen(open) => write!(f, "{}", open),
350 WrapError::Inner(e) => write!(f, "{}", e),
351 }
352 }
353}
354
355impl<E: std::error::Error + 'static> std::error::Error for WrapError<E> {
356 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
357 match self {
358 WrapError::CircuitOpen(open) => Some(open),
359 WrapError::Inner(e) => Some(e),
360 }
361 }
362}
363
364impl<E> WrapError<E> {
365 pub fn is_circuit_open(&self) -> bool {
367 matches!(self, WrapError::CircuitOpen(_))
368 }
369
370 pub fn retry_after(&self) -> Option<Duration> {
372 match self {
373 WrapError::CircuitOpen(open) => Some(open.retry_after),
374 WrapError::Inner(_) => None,
375 }
376 }
377}
378
/// Implemented by anything that can name the node it refers to, so it can be
/// matched against the manager's breakers.
pub trait HasNodeId {
    /// Returns the node id as a string slice.
    fn node_id(&self) -> &str;
}
383
384impl HasNodeId for String {
385 fn node_id(&self) -> &str {
386 self
387 }
388}
389
390impl HasNodeId for &str {
391 fn node_id(&self) -> &str {
392 self
393 }
394}
395
/// Minimal node description that carries only an id.
#[derive(Debug, Clone)]
pub struct SimpleNode {
    /// Identifier of the node.
    pub id: String,
}
401
402impl HasNodeId for SimpleNode {
403 fn node_id(&self) -> &str {
404 &self.id
405 }
406}
407
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built manager tracks no nodes.
    #[test]
    fn test_manager_creation() {
        let mgr = CircuitBreakerManager::with_defaults();
        assert_eq!(mgr.node_count(), 0);
    }

    /// Requesting a breaker creates it, closed, and registers the node.
    #[test]
    fn test_manager_get_breaker() {
        let mgr = CircuitBreakerManager::with_defaults();

        let created = mgr.get_breaker("node-1");
        assert_eq!(created.node_id(), "node-1");
        assert_eq!(created.get_state(), CircuitState::Closed);

        assert_eq!(mgr.node_count(), 1);
        assert!(mgr.has_node("node-1"));
    }

    /// A success reported through the guard is visible on the stored breaker.
    #[test]
    fn test_manager_allow_request() {
        let mgr = CircuitBreakerManager::with_defaults();

        let permit = mgr.allow_request("node-1").expect("should allow");
        permit.success();

        let stored = mgr.get_breaker("node-1");
        assert_eq!(stored.total_successes(), 1);
    }

    /// Nodes with an open breaker are filtered out; unknown nodes are kept.
    #[test]
    fn test_manager_healthy_nodes() {
        let breaker_config = CircuitBreakerConfig::builder().failure_threshold(2).build();
        let mgr = CircuitBreakerManager::new(ManagerConfig::new(breaker_config));

        let nodes: Vec<SimpleNode> = ["node-1", "node-2", "node-3"]
            .iter()
            .map(|id| SimpleNode { id: id.to_string() })
            .collect();

        // No breakers exist yet, so every node counts as healthy.
        let healthy = mgr.get_healthy_nodes(&nodes);
        assert_eq!(healthy.len(), 3);

        mgr.force_open("node-2", None);

        let healthy = mgr.get_healthy_nodes(&nodes);
        assert_eq!(healthy.len(), 2);
        assert!(healthy.iter().all(|n| n.id != "node-2"));
    }

    /// `wrap_request` passes through both the success value and the failure.
    #[test]
    fn test_manager_wrap_request() {
        let mgr = CircuitBreakerManager::with_defaults();

        let succeeded = mgr.wrap_request("node-1", || Ok::<i32, &str>(42));
        assert_eq!(succeeded.unwrap(), 42);

        let failed = mgr.wrap_request("node-1", || Err::<i32, &str>("error"));
        assert!(failed.is_err());
    }

    /// A node override replaces the default threshold for that node only.
    #[test]
    fn test_manager_node_overrides() {
        let base = CircuitBreakerConfig::builder().failure_threshold(5).build();
        let special = NodeOverride::new("special-node").with_failure_threshold(10);
        let mgr = CircuitBreakerManager::new(ManagerConfig::new(base).with_node_override(special));

        let normal_breaker = mgr.get_breaker("normal-node");
        assert_eq!(normal_breaker.config().failure_threshold, 5);

        let special_breaker = mgr.get_breaker("special-node");
        assert_eq!(special_breaker.config().failure_threshold, 10);
    }

    /// Only forced-open nodes are listed as open circuits.
    #[test]
    fn test_manager_get_open_circuits() {
        let mgr = CircuitBreakerManager::with_defaults();

        mgr.force_open("node-1", None);
        mgr.force_open("node-3", None);
        // Created but left closed.
        let _ = mgr.get_breaker("node-2");

        let open = mgr.get_open_circuits();
        assert_eq!(open.len(), 2);
        assert!(open.contains(&"node-1".to_string()));
        assert!(open.contains(&"node-3".to_string()));
    }

    /// `reset_all` leaves no breaker in the open state.
    #[test]
    fn test_manager_reset_all() {
        let breaker_config = CircuitBreakerConfig::builder().failure_threshold(1).build();
        let mgr = CircuitBreakerManager::new(ManagerConfig::new(breaker_config));

        mgr.force_open("node-1", None);
        mgr.force_open("node-2", None);

        assert_eq!(mgr.get_open_circuits().len(), 2);

        mgr.reset_all();
        assert_eq!(mgr.get_open_circuits().len(), 0);
    }

    /// The async wrapper returns the future's success value.
    #[tokio::test]
    async fn test_manager_wrap_async() {
        let mgr = CircuitBreakerManager::with_defaults();

        let outcome = mgr
            .wrap_request_async("node-1", || async { Ok::<i32, &str>(42) })
            .await;
        assert_eq!(outcome.unwrap(), 42);
    }
}