use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use std::time::SystemTime;
/// Lifecycle state of a [`CircuitBreaker`].
///
/// This enum derives `Serialize`/`Deserialize`, so renaming or reordering variants may
/// change the serialized representation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CircuitState {
/// Normal operation: requests are admitted and consecutive failures are counted.
Closed,
/// Tripped: requests are rejected until the configured timeout has elapsed.
Open,
/// Probing: trial requests decide whether the breaker closes again or reopens.
HalfOpen,
}
impl CircuitState {
    /// Reports whether a breaker in this state lets traffic through at all.
    ///
    /// Only [`CircuitState::Open`] rejects requests; `Closed` and `HalfOpen` both admit them.
    #[must_use]
    pub const fn allows_requests(&self) -> bool {
        !matches!(*self, Self::Open)
    }

    /// Human-readable label for this state.
    #[must_use]
    pub const fn name(&self) -> &'static str {
        match *self {
            Self::HalfOpen => "Half-Open",
            Self::Open => "Open",
            Self::Closed => "Closed",
        }
    }
}
/// Tuning parameters for a [`CircuitBreaker`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CircuitBreakerConfig {
/// Consecutive failures, recorded while closed, that open the breaker.
pub failure_threshold: u32,
/// Consecutive successes, recorded while half-open, that close the breaker again.
pub success_threshold: u32,
/// How long the breaker stays open before `allow_request` lets a probe through.
pub timeout: Duration,
/// Half-open budget: while half-open, `allow_request` admits requests only while the
/// breaker's `half_open_request_count` is below this value.
pub half_open_requests: u32,
}
impl Default for CircuitBreakerConfig {
    /// Default thresholds:
    /// - 5 consecutive failures open the breaker.
    /// - 2 consecutive half-open successes close it.
    /// - The open timeout is 60 s.
    /// - The half-open request budget is 3.
    fn default() -> Self {
        let timeout = Duration::from_secs(60);
        Self {
            half_open_requests: 3,
            timeout,
            success_threshold: 2,
            failure_threshold: 5,
        }
    }
}
/// Per-provider circuit breaker that tracks call outcomes and moves between
/// closed, open and half-open states.
///
/// All fields are public, so callers can read (and mutate) the raw counters directly.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CircuitBreaker {
/// Identifier of the provider this breaker guards.
pub provider_id: String,
/// Current lifecycle state.
pub state: CircuitState,
/// Thresholds and timeout in effect.
pub config: CircuitBreakerConfig,
/// Failures since the last success; cleared by a success and when entering closed or half-open.
pub consecutive_failures: u32,
/// Successes since the last failure; cleared by a failure and when entering half-open.
pub consecutive_successes: u32,
/// Lifetime count of recorded failures (used by `failure_rate`; not cleared by `reset`).
pub total_failures: u64,
/// Lifetime count of recorded successes (used by `failure_rate`; not cleared by `reset`).
pub total_successes: u64,
/// When the breaker last opened; `None` initially and after closing.
pub opened_at: Option<SystemTime>,
/// Time of the most recent state transition (creation time initially).
pub last_state_change: SystemTime,
/// Successes recorded in the current half-open window; `allow_request` compares it with
/// `config.half_open_requests`.
pub half_open_request_count: u32,
}
impl CircuitBreaker {
    /// Creates a closed breaker for `provider_id` using [`CircuitBreakerConfig::default`].
    #[must_use]
    pub fn new(provider_id: String) -> Self {
        Self::with_config(provider_id, CircuitBreakerConfig::default())
    }

    /// Creates a closed breaker for `provider_id` with an explicit configuration.
    #[must_use]
    pub fn with_config(provider_id: String, config: CircuitBreakerConfig) -> Self {
        Self {
            provider_id,
            state: CircuitState::Closed,
            config,
            consecutive_failures: 0,
            consecutive_successes: 0,
            total_failures: 0,
            total_successes: 0,
            opened_at: None,
            last_state_change: SystemTime::now(),
            half_open_request_count: 0,
        }
    }

    /// Decides whether a request may be sent now, advancing `Open -> HalfOpen` when due.
    ///
    /// * `Closed`: always `true`.
    /// * `Open`: `false` until `config.timeout` has elapsed since `opened_at`. The first call
    ///   after that switches to `HalfOpen` and returns `true`. The breaker stays open if
    ///   `opened_at` is `None`, or if the system clock now reads earlier than `opened_at`.
    /// * `HalfOpen`: `true` while `half_open_request_count` (successes recorded in this
    ///   half-open window) is below `config.half_open_requests`.
    #[must_use]
    pub fn allow_request(&mut self) -> bool {
        match self.state {
            CircuitState::Closed => true,
            CircuitState::Open => {
                // `elapsed()` fails when the clock moved backwards; treat that as "not yet".
                let timed_out = self
                    .opened_at
                    .and_then(|opened_at| opened_at.elapsed().ok())
                    .map_or(false, |elapsed| elapsed >= self.config.timeout);
                if timed_out {
                    self.transition_to_half_open();
                }
                timed_out
            }
            CircuitState::HalfOpen => {
                self.half_open_request_count < self.config.half_open_requests
            }
        }
    }

    /// Records a successful call.
    ///
    /// While half-open, the breaker closes when either of these holds:
    /// - `config.success_threshold` consecutive successes have been seen.
    /// - The half-open budget (`config.half_open_requests`) is used up.
    ///
    /// A success while open moves the breaker to half-open. The transition clears the
    /// counters, so that success does not count toward closing.
    pub fn record_success(&mut self) {
        self.total_successes += 1;
        self.consecutive_successes = self.consecutive_successes.saturating_add(1);
        self.consecutive_failures = 0;
        match self.state {
            CircuitState::HalfOpen => {
                self.half_open_request_count = self.half_open_request_count.saturating_add(1);
                // Once the budget is spent, `allow_request` admits nothing more. Waiting for
                // `success_threshold` past that point would strand the breaker half-open
                // forever when the budget is the smaller limit. Any failure reopens the
                // breaker, so reaching here means every probe so far succeeded.
                let threshold_met = self.consecutive_successes >= self.config.success_threshold;
                let budget_spent = self.half_open_request_count >= self.config.half_open_requests;
                if threshold_met || budget_spent {
                    self.transition_to_closed();
                }
            }
            CircuitState::Open => self.transition_to_half_open(),
            CircuitState::Closed => {}
        }
    }

    /// Records a failed call.
    ///
    /// - Closed: reaching `config.failure_threshold` consecutive failures opens the breaker.
    /// - Half-open: any failure reopens it and restarts the timeout.
    /// - Open: failures only update the counters.
    pub fn record_failure(&mut self) {
        self.total_failures += 1;
        // Saturate: callers may keep reporting failures indefinitely while the breaker is open.
        self.consecutive_failures = self.consecutive_failures.saturating_add(1);
        self.consecutive_successes = 0;
        match self.state {
            CircuitState::Closed => {
                if self.consecutive_failures >= self.config.failure_threshold {
                    self.transition_to_open();
                }
            }
            CircuitState::HalfOpen => self.transition_to_open(),
            CircuitState::Open => {}
        }
    }

    /// Enters `Closed`, clearing the failure streak, half-open budget and `opened_at`.
    fn transition_to_closed(&mut self) {
        self.state = CircuitState::Closed;
        self.consecutive_failures = 0;
        self.half_open_request_count = 0;
        self.opened_at = None;
        self.last_state_change = SystemTime::now();
    }

    /// Enters `Open` and stamps `opened_at`, which starts the timeout checked by `allow_request`.
    fn transition_to_open(&mut self) {
        self.state = CircuitState::Open;
        self.opened_at = Some(SystemTime::now());
        self.last_state_change = SystemTime::now();
        self.half_open_request_count = 0;
    }

    /// Enters `HalfOpen` with fresh streak counters and an unused probe budget.
    fn transition_to_half_open(&mut self) {
        self.state = CircuitState::HalfOpen;
        self.consecutive_successes = 0;
        self.consecutive_failures = 0;
        self.half_open_request_count = 0;
        self.last_state_change = SystemTime::now();
    }

    /// Forces the breaker closed and clears both streak counters.
    ///
    /// Lifetime totals (`total_failures`, `total_successes`) are kept.
    pub fn reset(&mut self) {
        // `transition_to_closed` already clears `consecutive_failures`.
        self.transition_to_closed();
        self.consecutive_successes = 0;
    }

    /// Fraction of all recorded calls that failed, or `0.0` when nothing has been recorded.
    #[must_use]
    pub fn failure_rate(&self) -> f64 {
        let total = self.total_failures + self.total_successes;
        if total == 0 {
            0.0
        } else {
            self.total_failures as f64 / total as f64
        }
    }

    /// `true` when the breaker is in [`CircuitState::Open`].
    #[must_use]
    pub const fn is_open(&self) -> bool {
        matches!(self.state, CircuitState::Open)
    }

    /// `true` when the breaker is in [`CircuitState::Closed`].
    #[must_use]
    pub const fn is_closed(&self) -> bool {
        matches!(self.state, CircuitState::Closed)
    }

    /// `true` when the breaker is in [`CircuitState::HalfOpen`].
    #[must_use]
    pub const fn is_half_open(&self) -> bool {
        matches!(self.state, CircuitState::HalfOpen)
    }
}
/// Exponential backoff parameters consumed by [`BackoffState`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BackoffConfig {
/// Delay handed out for the first attempt.
pub initial_interval: Duration,
/// Upper bound applied to each computed next delay.
pub max_interval: Duration,
/// Factor applied to the current delay to obtain the next one.
pub multiplier: f64,
/// Number of delays handed out before `BackoffState::next_backoff` returns `None`.
pub max_retries: u32,
}
impl Default for BackoffConfig {
    /// Default schedule:
    /// - Starts at 100 ms.
    /// - Multiplies by 2 on each step.
    /// - Caps computed delays at 60 s.
    /// - Allows 5 retries.
    fn default() -> Self {
        let initial_interval = Duration::from_millis(100);
        let max_interval = Duration::from_secs(60);
        Self {
            multiplier: 2.0,
            max_retries: 5,
            initial_interval,
            max_interval,
        }
    }
}
/// Progress through a [`BackoffConfig`] schedule for one provider.
#[derive(Debug, Clone)]
pub struct BackoffState {
/// Schedule parameters.
config: BackoffConfig,
/// Delays handed out since creation or the last `reset`.
attempt: u32,
/// Candidate delay for the next call to `next_backoff`.
next_interval: Duration,
}
impl BackoffState {
    /// Starts a fresh schedule: no attempts made, next delay = `config.initial_interval`.
    #[must_use]
    pub fn new(config: BackoffConfig) -> Self {
        let next_interval = config.initial_interval;
        Self {
            config,
            attempt: 0,
            next_interval,
        }
    }

    /// Returns the delay to wait before the next retry.
    ///
    /// Returns `None` once `max_retries` delays have been handed out.
    ///
    /// - Every returned delay is capped at `max_interval`. This includes the first one, even
    ///   when `initial_interval` is larger.
    /// - The following delay is the current one times `multiplier`, computed at nanosecond
    ///   precision so that sub-millisecond intervals still grow instead of truncating to zero.
    /// - A non-positive or NaN product yields a zero delay.
    #[must_use]
    pub fn next_backoff(&mut self) -> Option<Duration> {
        if self.attempt >= self.config.max_retries {
            return None;
        }
        let max = self.config.max_interval;
        let current = self.next_interval.min(max);
        self.attempt += 1;
        self.next_interval = Self::scaled(current, self.config.multiplier, max);
        Some(current)
    }

    /// Multiplies `interval` by `multiplier`, clamping the result to `[0, max]`.
    fn scaled(interval: Duration, multiplier: f64, max: Duration) -> Duration {
        let scaled_nanos = interval.as_nanos() as f64 * multiplier;
        if scaled_nanos.is_nan() || scaled_nanos <= 0.0 {
            Duration::ZERO
        } else if scaled_nanos >= max.as_nanos() as f64 {
            max
        } else {
            // Below `max` here; the float->int cast saturates, and `min` guards against
            // rounding nudging the value a nanosecond past the cap.
            Duration::from_nanos(scaled_nanos as u64).min(max)
        }
    }

    /// Restarts the schedule from the first attempt and `initial_interval`.
    pub fn reset(&mut self) {
        self.attempt = 0;
        self.next_interval = self.config.initial_interval;
    }

    /// Number of delays handed out since creation or the last `reset`.
    #[must_use]
    pub const fn attempt(&self) -> u32 {
        self.attempt
    }
}
/// Ordered list of providers to try: `primary` first, then each entry of `fallbacks`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FallbackChain {
/// Provider tried first.
pub primary: String,
/// Alternatives, tried in order after the primary.
pub fallbacks: Vec<String>,
/// Position in the chain: 0 is the primary, `i > 0` is `fallbacks[i - 1]`.
current_index: usize,
}
impl FallbackChain {
    /// Creates a chain positioned at `primary`, with `fallbacks` tried in order afterwards.
    #[must_use]
    pub fn new(primary: String, fallbacks: Vec<String>) -> Self {
        Self {
            primary,
            fallbacks,
            current_index: 0,
        }
    }

    /// Returns the provider the chain currently points at.
    ///
    /// The stored position can end up past the end of the list, because `fallbacks` is
    /// public and the position is deserialized. In that case this returns the last remaining
    /// fallback, or the primary if none remain, instead of panicking.
    #[must_use]
    pub fn current_provider(&self) -> &str {
        match self.current_index.checked_sub(1) {
            None => self.primary.as_str(),
            Some(fallback_index) => self
                .fallbacks
                .get(fallback_index)
                .or_else(|| self.fallbacks.last())
                .map_or(self.primary.as_str(), String::as_str),
        }
    }

    /// Advances to the next fallback and returns it.
    ///
    /// Returns `None`, leaving the position unchanged, once the chain is exhausted.
    pub fn next_fallback(&mut self) -> Option<&str> {
        if self.current_index < self.fallbacks.len() {
            self.current_index += 1;
            Some(self.current_provider())
        } else {
            None
        }
    }

    /// Moves the chain back to the primary provider.
    pub fn reset(&mut self) {
        self.current_index = 0;
    }

    /// `true` while the chain points at the primary provider.
    #[must_use]
    pub const fn is_primary(&self) -> bool {
        self.current_index == 0
    }

    /// All providers in try order: the primary followed by every fallback.
    #[must_use]
    pub fn all_providers(&self) -> Vec<&str> {
        std::iter::once(self.primary.as_str())
            .chain(self.fallbacks.iter().map(String::as_str))
            .collect()
    }
}
/// Settings for quality degradation.
///
/// NOTE(review): nothing in this section reads these fields; the per-field descriptions
/// below are inferred from the names and should be confirmed against the consumer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DegradationConfig {
/// Presumably toggles degradation as a whole.
pub enabled: bool,
/// Presumably the lowest quality level degradation may reach.
pub min_quality: u8,
/// Presumably whether an error triggers a quality reduction.
pub reduce_quality_on_error: bool,
/// Presumably the amount quality drops per reduction.
pub quality_step: u8,
}
impl Default for DegradationConfig {
    /// Default settings:
    /// - `enabled` and `reduce_quality_on_error` are both on.
    /// - `quality_step` is 10.
    /// - `min_quality` is 30.
    fn default() -> Self {
        let (min_quality, quality_step) = (30, 10);
        Self {
            quality_step,
            min_quality,
            reduce_quality_on_error: true,
            enabled: true,
        }
    }
}
/// Mutable state kept behind the [`FailoverManager`]'s lock.
///
/// Every map is keyed by provider ID (`fallback_chains` by primary provider ID).
struct FailoverState {
/// Breaker per provider, created on the first recorded success or failure.
circuit_breakers: HashMap<String, CircuitBreaker>,
/// Backoff schedule per provider, created on the first recorded failure.
backoff_states: HashMap<String, BackoffState>,
/// Fallback chains keyed by their primary provider.
fallback_chains: HashMap<String, FallbackChain>,
/// Operator overrides. `false` forces a provider unavailable. `true` is stored by
/// `close_circuit` but is not itself consulted by `is_available`.
manual_overrides: HashMap<String, bool>,
}
/// Registry of per-provider circuit breakers, backoff schedules, fallback chains and manual
/// overrides.
///
/// All mutable state sits behind a single `Arc<RwLock<_>>`, so every method takes `&self`.
pub struct FailoverManager {
/// Failure threshold applied to every breaker this manager creates.
failure_threshold: u32,
/// Open-state timeout applied to every breaker this manager creates.
circuit_timeout: Duration,
/// Shared mutable state.
state: Arc<RwLock<FailoverState>>,
/// Configuration for each provider's backoff schedule.
backoff_config: BackoffConfig,
}
impl FailoverManager {
#[must_use]
pub fn new(failure_threshold: u32, circuit_timeout: Duration) -> Self {
let state = FailoverState {
circuit_breakers: HashMap::new(),
backoff_states: HashMap::new(),
fallback_chains: HashMap::new(),
manual_overrides: HashMap::new(),
};
Self {
failure_threshold,
circuit_timeout,
state: Arc::new(RwLock::new(state)),
backoff_config: BackoffConfig::default(),
}
}
pub fn record_success(&self, provider_id: &str) {
let mut state = self.state.write();
let breaker = state
.circuit_breakers
.entry(provider_id.to_string())
.or_insert_with(|| {
CircuitBreaker::with_config(
provider_id.to_string(),
CircuitBreakerConfig {
failure_threshold: self.failure_threshold,
timeout: self.circuit_timeout,
..Default::default()
},
)
});
breaker.record_success();
if let Some(backoff) = state.backoff_states.get_mut(provider_id) {
backoff.reset();
}
}
pub fn record_failure(&self, provider_id: &str) {
let mut state = self.state.write();
let breaker = state
.circuit_breakers
.entry(provider_id.to_string())
.or_insert_with(|| {
CircuitBreaker::with_config(
provider_id.to_string(),
CircuitBreakerConfig {
failure_threshold: self.failure_threshold,
timeout: self.circuit_timeout,
..Default::default()
},
)
});
breaker.record_failure();
let backoff = state
.backoff_states
.entry(provider_id.to_string())
.or_insert_with(|| BackoffState::new(self.backoff_config.clone()));
let _next_backoff = backoff.next_backoff();
}
#[must_use]
pub fn is_available(&self, provider_id: &str) -> bool {
let state = self.state.read();
if let Some(&enabled) = state.manual_overrides.get(provider_id) {
if !enabled {
return false;
}
}
if let Some(breaker) = state.circuit_breakers.get(provider_id) {
!breaker.is_open()
} else {
true
}
}
#[must_use]
pub fn is_open(&self, provider_id: &str) -> bool {
self.state
.read()
.circuit_breakers
.get(provider_id)
.map_or(false, CircuitBreaker::is_open)
}
#[must_use]
pub fn get_circuit_state(&self, provider_id: &str) -> Option<CircuitState> {
self.state
.read()
.circuit_breakers
.get(provider_id)
.map(|b| b.state)
}
pub fn open_circuit(&self, provider_id: &str) {
let mut state = self.state.write();
state
.manual_overrides
.insert(provider_id.to_string(), false);
}
pub fn close_circuit(&self, provider_id: &str) {
let mut state = self.state.write();
state.manual_overrides.insert(provider_id.to_string(), true);
if let Some(breaker) = state.circuit_breakers.get_mut(provider_id) {
breaker.reset();
}
}
pub fn reset_all(&self) {
let mut state = self.state.write();
for breaker in state.circuit_breakers.values_mut() {
breaker.reset();
}
state.manual_overrides.clear();
}
#[must_use]
pub fn get_circuit_breaker(&self, provider_id: &str) -> Option<CircuitBreaker> {
self.state.read().circuit_breakers.get(provider_id).cloned()
}
pub fn add_fallback_chain(&self, primary: String, fallbacks: Vec<String>) {
let mut state = self.state.write();
let chain = FallbackChain::new(primary.clone(), fallbacks);
state.fallback_chains.insert(primary, chain);
}
pub fn get_next_fallback(&self, primary_id: &str) -> Option<String> {
let mut state = self.state.write();
state
.fallback_chains
.get_mut(primary_id)
.and_then(|chain| chain.next_fallback().map(String::from))
}
pub fn reset_fallback_chain(&self, primary_id: &str) {
let mut state = self.state.write();
if let Some(chain) = state.fallback_chains.get_mut(primary_id) {
chain.reset();
}
}
#[must_use]
pub fn get_all_states(&self) -> HashMap<String, CircuitState> {
self.state
.read()
.circuit_breakers
.iter()
.map(|(id, breaker)| (id.clone(), breaker.state))
.collect()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A default-configured breaker for `provider-1`.
    fn fresh_breaker() -> CircuitBreaker {
        CircuitBreaker::new("provider-1".to_string())
    }

    /// A manager that opens after three failures and waits a minute before probing.
    fn new_manager() -> FailoverManager {
        FailoverManager::new(3, Duration::from_secs(60))
    }

    #[test]
    fn test_circuit_state() {
        let admitting = [CircuitState::Closed, CircuitState::HalfOpen];
        assert!(admitting.iter().all(CircuitState::allows_requests));
        assert!(!CircuitState::Open.allows_requests());
    }

    #[test]
    fn test_circuit_breaker_creation() {
        let breaker = fresh_breaker();
        assert_eq!(breaker.provider_id, "provider-1");
        assert_eq!(breaker.state, CircuitState::Closed);
        assert!(breaker.is_closed());
    }

    #[test]
    fn test_circuit_breaker_failure() {
        let mut breaker = fresh_breaker();
        (0..5).for_each(|_| breaker.record_failure());
        assert!(breaker.is_open());
        assert!(!breaker.allow_request());
    }

    #[test]
    fn test_circuit_breaker_success() {
        let mut breaker = fresh_breaker();
        breaker.record_success();
        assert!(breaker.is_closed());
        assert_eq!(breaker.total_successes, 1);
        assert_eq!(breaker.consecutive_successes, 1);
    }

    #[test]
    fn test_circuit_breaker_reset() {
        let mut breaker = fresh_breaker();
        (0..5).for_each(|_| breaker.record_failure());
        assert!(breaker.is_open());
        breaker.reset();
        assert!(breaker.is_closed());
        assert_eq!(breaker.consecutive_failures, 0);
    }

    #[test]
    fn test_backoff_state() {
        let mut backoff = BackoffState::new(BackoffConfig::default());
        let first = backoff.next_backoff().expect("should succeed in test");
        assert_eq!(backoff.attempt(), 1);
        let second = backoff.next_backoff().expect("should succeed in test");
        assert!(second > first);
    }

    #[test]
    fn test_backoff_reset() {
        let mut backoff = BackoffState::new(BackoffConfig::default());
        let _ = backoff.next_backoff();
        assert_eq!(backoff.attempt(), 1);
        backoff.reset();
        assert_eq!(backoff.attempt(), 0);
    }

    #[test]
    fn test_fallback_chain() {
        let mut chain = FallbackChain::new(
            "primary".to_string(),
            vec!["fallback1".to_string(), "fallback2".to_string()],
        );
        assert!(chain.is_primary());
        assert_eq!(chain.current_provider(), "primary");
        for &expected in &["fallback1", "fallback2"] {
            chain.next_fallback();
            assert_eq!(chain.current_provider(), expected);
            assert!(!chain.is_primary());
        }
        chain.reset();
        assert!(chain.is_primary());
        assert_eq!(chain.current_provider(), "primary");
    }

    #[test]
    fn test_failover_manager() {
        let manager = new_manager();
        manager.record_success("provider-1");
        assert!(manager.is_available("provider-1"));
        (0..3).for_each(|_| manager.record_failure("provider-1"));
        assert!(manager.is_open("provider-1"));
        assert!(!manager.is_available("provider-1"));
    }

    #[test]
    fn test_failover_manager_manual_override() {
        let manager = new_manager();
        manager.open_circuit("provider-1");
        assert!(!manager.is_available("provider-1"));
        manager.close_circuit("provider-1");
        assert!(manager.is_available("provider-1"));
    }

    #[test]
    fn test_failover_manager_reset() {
        let manager = new_manager();
        (0..3).for_each(|_| manager.record_failure("provider-1"));
        assert!(!manager.is_available("provider-1"));
        manager.reset_all();
        assert!(manager.is_available("provider-1"));
    }
}