use std::sync::Arc;
use std::time::Duration;
pub use tasker_shared::resilience::CircuitState;
use tasker_shared::resilience::{
CircuitBreaker, CircuitBreakerBehavior, CircuitBreakerConfig, CircuitBreakerMetrics,
};
/// Circuit breaker for the web layer's database access (per the type and default
/// component name `"web_database"`).
///
/// This is a thin wrapper around the shared [`CircuitBreaker`] from
/// `tasker_shared::resilience`. The breaker is held in an [`Arc`], so `Clone` is cheap
/// and every clone refers to the *same* underlying breaker. Failures and successes
/// recorded through one clone are therefore visible through all of them.
#[derive(Debug, Clone)]
pub struct WebDatabaseCircuitBreaker {
    // Shared breaker that owns the actual circuit state and metrics.
    breaker: Arc<CircuitBreaker>,
}
impl WebDatabaseCircuitBreaker {
    /// Success threshold that [`WebDatabaseCircuitBreaker::new`] passes as
    /// `CircuitBreakerConfig::success_threshold`.
    const DEFAULT_SUCCESS_THRESHOLD: u32 = 2;

    /// Creates a breaker with the given failure threshold and recovery timeout.
    ///
    /// * `failure_threshold` — number of consecutive failures at which the circuit opens.
    /// * `recovery_timeout` — forwarded as `CircuitBreakerConfig::timeout`. This is
    ///   presumably how long the circuit stays open before trial requests are let
    ///   through; confirm in `tasker_shared::resilience`.
    /// * `component_name` — name reported by [`Self::component_name`].
    ///
    /// The success threshold is fixed at 2. Use [`Self::from_config`] to control every
    /// setting.
    pub fn new(
        failure_threshold: u32,
        recovery_timeout: Duration,
        component_name: impl Into<String>,
    ) -> Self {
        Self::from_config(
            component_name,
            CircuitBreakerConfig {
                failure_threshold,
                timeout: recovery_timeout,
                success_threshold: Self::DEFAULT_SUCCESS_THRESHOLD,
            },
        )
    }

    /// Creates a breaker named `component_name` from a fully specified configuration.
    pub fn from_config(component_name: impl Into<String>, config: CircuitBreakerConfig) -> Self {
        Self {
            breaker: Arc::new(CircuitBreaker::new(component_name.into(), config)),
        }
    }

    /// Returns `true` when the breaker would currently reject a request.
    ///
    /// This is the negation of `should_allow`.
    ///
    /// NOTE(review): `should_allow` appears to advance an open circuit to half-open once
    /// its recovery timeout has elapsed, so this is probably not a read-only query.
    /// Confirm in `tasker_shared::resilience`.
    pub fn is_circuit_open(&self) -> bool {
        !self.breaker.should_allow()
    }

    /// Records a successful operation with a zero duration.
    ///
    /// Any duration-based metrics kept by the shared breaker will therefore not reflect
    /// real call time. To pass a measured duration, use the trait method
    /// `CircuitBreakerBehavior::record_success(&breaker, elapsed)`. Trait-qualified
    /// syntax is needed because this inherent method shadows it.
    pub fn record_success(&self) {
        self.breaker.record_success_manual(Duration::ZERO);
    }

    /// Records a failed operation with a zero duration.
    ///
    /// See [`Self::record_success`] for the duration caveat; the trait method
    /// `CircuitBreakerBehavior::record_failure` accepts a measured duration.
    pub fn record_failure(&self) {
        self.breaker.record_failure_manual(Duration::ZERO);
    }

    /// Current state of the circuit (closed, open or half-open).
    pub fn current_state(&self) -> CircuitState {
        self.breaker.state()
    }

    /// Number of consecutive failures currently recorded; a success resets it to 0.
    ///
    /// If the underlying counter is wider than `u32`, the value saturates at
    /// `u32::MAX` instead of silently wrapping around as an `as` cast would.
    pub fn current_failures(&self) -> u32 {
        u32::try_from(self.breaker.metrics().consecutive_failures).unwrap_or(u32::MAX)
    }

    /// Name this breaker was created with.
    pub fn component_name(&self) -> &str {
        self.breaker.name()
    }

    /// Forces the circuit into the open state.
    pub fn force_open(&self) {
        self.breaker.force_open();
    }

    /// Forces the circuit into the closed state.
    pub fn force_closed(&self) {
        self.breaker.force_closed();
    }

    /// Snapshot of the shared breaker's metrics.
    pub fn metrics(&self) -> CircuitBreakerMetrics {
        self.breaker.metrics()
    }
}
impl Default for WebDatabaseCircuitBreaker {
    /// Builds the stock web-database breaker.
    ///
    /// It is named `"web_database"`, opens after 5 consecutive failures and uses a
    /// 30-second recovery timeout.
    fn default() -> Self {
        const FAILURE_THRESHOLD: u32 = 5;
        const RECOVERY_TIMEOUT: Duration = Duration::from_secs(30);
        const COMPONENT_NAME: &str = "web_database";

        Self::new(FAILURE_THRESHOLD, RECOVERY_TIMEOUT, COMPONENT_NAME)
    }
}
impl CircuitBreakerBehavior for WebDatabaseCircuitBreaker {
fn name(&self) -> &str {
self.breaker.name()
}
fn state(&self) -> CircuitState {
self.breaker.state()
}
fn should_allow(&self) -> bool {
self.breaker.should_allow()
}
fn record_success(&self, duration: Duration) {
self.breaker.record_success_manual(duration);
}
fn record_failure(&self, duration: Duration) {
self.breaker.record_failure_manual(duration);
}
fn is_healthy(&self) -> bool {
self.breaker.is_healthy()
}
fn force_open(&self) {
self.breaker.force_open();
}
fn force_closed(&self) {
self.breaker.force_closed();
}
fn metrics(&self) -> CircuitBreakerMetrics {
self.breaker.metrics()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_circuit_breaker_starts_closed() {
        let cb = WebDatabaseCircuitBreaker::new(3, Duration::from_secs(5), "test");
        assert!(!cb.is_circuit_open());
        assert_eq!(cb.current_state(), CircuitState::Closed);
    }

    #[test]
    fn test_circuit_opens_after_threshold_failures() {
        let cb = WebDatabaseCircuitBreaker::new(3, Duration::from_secs(5), "test");
        cb.record_failure();
        cb.record_failure();
        assert!(!cb.is_circuit_open());
        assert_eq!(cb.current_state(), CircuitState::Closed);
        cb.record_failure();
        assert!(cb.is_circuit_open());
        assert_eq!(cb.current_state(), CircuitState::Open);
    }

    #[test]
    fn test_circuit_closes_on_success_via_half_open() {
        let cb = WebDatabaseCircuitBreaker::new(
            2,
            Duration::ZERO, // recovery timeout elapses immediately
            "test",
        );
        cb.record_failure();
        cb.record_failure();
        assert_eq!(cb.current_state(), CircuitState::Open);
        // With a zero timeout the probe is allowed through. This call is what lets the
        // open circuit proceed toward half-open, so it must stay before the successes.
        assert!(!cb.is_circuit_open());
        cb.record_success();
        cb.record_success();
        assert_eq!(cb.current_state(), CircuitState::Closed);
        assert_eq!(cb.current_failures(), 0);
    }

    #[test]
    fn test_circuit_state_from_u8_conversion() {
        assert_eq!(CircuitState::from(0), CircuitState::Closed);
        assert_eq!(CircuitState::from(1), CircuitState::Open);
        assert_eq!(CircuitState::from(2), CircuitState::HalfOpen);
        assert_eq!(CircuitState::from(3), CircuitState::Open);
        assert_eq!(CircuitState::from(255), CircuitState::Open);
    }

    #[test]
    fn test_default_circuit_breaker_configuration() {
        let cb = WebDatabaseCircuitBreaker::default();
        assert_eq!(cb.component_name(), "web_database");
        assert_eq!(cb.current_state(), CircuitState::Closed);
        assert_eq!(cb.current_failures(), 0);
        assert!(!cb.is_circuit_open());
    }

    #[test]
    fn test_component_name_accessor() {
        let cb = WebDatabaseCircuitBreaker::new(5, Duration::from_secs(30), "custom_component");
        assert_eq!(cb.component_name(), "custom_component");
    }

    #[test]
    fn test_from_config_uses_supplied_thresholds() {
        let config = CircuitBreakerConfig {
            failure_threshold: 1,
            timeout: Duration::from_secs(30),
            success_threshold: 1,
        };
        let cb = WebDatabaseCircuitBreaker::from_config("from_config", config);
        assert_eq!(cb.component_name(), "from_config");
        assert_eq!(cb.current_state(), CircuitState::Closed);
        cb.record_failure();
        assert_eq!(cb.current_state(), CircuitState::Open);
        assert!(cb.is_circuit_open());
    }

    #[test]
    fn test_clones_share_state() {
        // The breaker lives behind an `Arc`, so clones must observe each other's records.
        let cb = WebDatabaseCircuitBreaker::new(2, Duration::from_secs(30), "shared");
        let clone = cb.clone();
        clone.record_failure();
        assert_eq!(cb.current_failures(), 1);
        clone.record_failure();
        assert!(cb.is_circuit_open());
        assert_eq!(cb.current_state(), CircuitState::Open);
    }

    #[test]
    fn test_failure_count_increments_correctly() {
        let cb = WebDatabaseCircuitBreaker::new(10, Duration::from_secs(30), "test");
        assert_eq!(cb.current_failures(), 0);
        cb.record_failure();
        assert_eq!(cb.current_failures(), 1);
        cb.record_failure();
        assert_eq!(cb.current_failures(), 2);
        cb.record_failure();
        assert_eq!(cb.current_failures(), 3);
    }

    #[test]
    fn test_success_resets_failure_count() {
        let cb = WebDatabaseCircuitBreaker::new(10, Duration::from_secs(30), "test");
        cb.record_failure();
        cb.record_failure();
        cb.record_failure();
        assert_eq!(cb.current_failures(), 3);
        cb.record_success();
        assert_eq!(cb.current_failures(), 0);
    }

    #[test]
    fn test_circuit_breaker_exact_threshold() {
        let cb = WebDatabaseCircuitBreaker::new(5, Duration::from_secs(30), "test");
        for i in 1..5 {
            cb.record_failure();
            assert!(
                !cb.is_circuit_open(),
                "Circuit should be closed at {} failures (threshold is 5)",
                i
            );
        }
        cb.record_failure();
        assert!(
            cb.is_circuit_open(),
            "Circuit should be open at threshold (5 failures)"
        );
    }

    #[test]
    fn test_force_operations() {
        let cb = WebDatabaseCircuitBreaker::default();
        cb.force_open();
        assert_eq!(cb.current_state(), CircuitState::Open);
        cb.force_closed();
        assert_eq!(cb.current_state(), CircuitState::Closed);
    }

    #[test]
    fn test_behavior_trait_conformance() {
        let cb = WebDatabaseCircuitBreaker::new(3, Duration::from_secs(5), "trait_test");
        let behavior: &dyn CircuitBreakerBehavior = &cb;
        assert_eq!(behavior.name(), "trait_test");
        assert_eq!(behavior.state(), CircuitState::Closed);
        assert!(behavior.should_allow());
        behavior.record_failure(Duration::ZERO);
        behavior.record_failure(Duration::ZERO);
        behavior.record_failure(Duration::ZERO);
        assert_eq!(behavior.state(), CircuitState::Open);
        assert!(!behavior.should_allow());
    }
}