use crate::metrics::MetricsCollector;
use std::sync::Arc;
use std::sync::atomic::{AtomicU8, AtomicU32, Ordering};
use std::time::{Duration, SystemTime};
use vtcode_commons::ErrorCategory;
/// Operating mode of the circuit breaker.
///
/// The explicit `u8` discriminants are the values kept in the breaker's atomic
/// state field and written to the persisted state file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum CircuitState {
    /// Requests are allowed and failures are counted towards the threshold.
    Closed = 0,
    /// Requests are rejected until the current timeout has elapsed.
    Open = 1,
    /// Requests are allowed as probes; successes are counted towards closing.
    HalfOpen = 2,
}
impl From<u8> for CircuitState {
fn from(val: u8) -> Self {
match val {
0 => CircuitState::Closed,
1 => CircuitState::Open,
2 => CircuitState::HalfOpen,
_ => CircuitState::Closed, }
}
}
use std::fs;
use std::path::PathBuf;
/// Serialized form of the breaker, written by `McpCircuitBreaker::persist` and
/// read back by `McpCircuitBreaker::with_persistence` as JSON.
#[derive(serde::Serialize, serde::Deserialize)]
struct PersistedState {
    /// Raw `CircuitState` discriminant.
    state: u8,
    /// Value of the consecutive-failure counter at the time of writing.
    consecutive_failures: u32,
    /// Time of the last recorded failure as whole seconds since the Unix
    /// epoch, or `None` when no failure timestamp was set.
    last_failure_epoch_secs: Option<u64>,
}
/// Circuit breaker whose counters live in atomics, so every operation takes
/// `&self`.
pub struct McpCircuitBreaker {
    /// Current `CircuitState`, stored as its `u8` discriminant.
    state: AtomicU8,
    /// Failures recorded since the counter was last cleared; also drives the
    /// timeout back-off.
    consecutive_failures: AtomicU32,
    /// Successes recorded since the breaker last entered the half-open state.
    half_open_successes: AtomicU32,
    /// Wall-clock time of the most recent tripping failure, if any.
    last_failure_time: parking_lot::Mutex<Option<SystemTime>>,
    /// Number of requests rejected while the breaker was open.
    blocked_requests: AtomicU32,
    /// Thresholds and timeouts governing state transitions.
    config: CircuitBreakerConfig,
    /// File the state is mirrored to; `None` disables persistence.
    persistence_path: Option<PathBuf>,
    /// Optional sink for open, half-open and denial events.
    metrics: Option<Arc<MetricsCollector>>,
}
/// Tunable thresholds and timeouts for a circuit breaker.
#[derive(Debug, Clone, Copy)]
pub struct CircuitBreakerConfig {
    /// Consecutive failures in the closed state that open the breaker.
    pub failure_threshold: u32,
    /// Successes in the half-open state that close the breaker again.
    pub success_threshold: u32,
    /// Timeout applied while the failure count has not exceeded
    /// `failure_threshold`; it is doubled for each failure beyond that.
    pub base_timeout: Duration,
    /// Upper bound for the backed-off timeout.
    pub max_timeout: Duration,
}
impl Default for CircuitBreakerConfig {
fn default() -> Self {
Self {
failure_threshold: 3, success_threshold: 2, base_timeout: Duration::from_secs(10),
max_timeout: Duration::from_secs(60),
}
}
}
// NOTE(review): blanket allowance kept from the original; presumably some of
// these entry points have no in-crate callers yet — confirm before removing.
#[allow(dead_code)]
impl McpCircuitBreaker {
    /// Creates a closed breaker with the default configuration, no metrics
    /// and no persistence.
    pub fn new() -> Self {
        Self::with_config(CircuitBreakerConfig::default())
    }

    /// Creates a closed breaker with the default configuration that reports
    /// events to `metrics`.
    pub fn with_metrics(metrics: Arc<MetricsCollector>) -> Self {
        Self::with_config_and_metrics(CircuitBreakerConfig::default(), metrics)
    }

    /// Creates a closed breaker with a custom configuration, no metrics and
    /// no persistence.
    pub fn with_config(config: CircuitBreakerConfig) -> Self {
        Self::build(config, None, None)
    }

    /// Creates a closed breaker with a custom configuration that reports
    /// events to `metrics`.
    pub fn with_config_and_metrics(
        config: CircuitBreakerConfig,
        metrics: Arc<MetricsCollector>,
    ) -> Self {
        Self::build(config, None, Some(metrics))
    }

    /// Shared constructor: every breaker starts closed with zeroed counters
    /// and no failure timestamp.
    fn build(
        config: CircuitBreakerConfig,
        persistence_path: Option<PathBuf>,
        metrics: Option<Arc<MetricsCollector>>,
    ) -> Self {
        Self {
            state: AtomicU8::new(CircuitState::Closed as u8),
            consecutive_failures: AtomicU32::new(0),
            half_open_successes: AtomicU32::new(0),
            last_failure_time: parking_lot::Mutex::new(None),
            blocked_requests: AtomicU32::new(0),
            config,
            persistence_path,
            metrics,
        }
    }

    /// Creates a breaker that mirrors its state to `path`, restoring whatever
    /// a previous run saved there.
    ///
    /// The breaker always uses the default configuration and no metrics. A
    /// missing, unreadable or unparsable file leaves the breaker in its
    /// freshly built (closed) state. A saved failure timestamp that lies in
    /// the future relative to the current system clock is discarded.
    pub fn with_persistence(path: PathBuf) -> Self {
        // Read first so the path can be moved into the breaker without a clone.
        let restored = fs::read_to_string(&path)
            .ok()
            .and_then(|raw| serde_json::from_str::<PersistedState>(&raw).ok());
        let breaker = Self::build(CircuitBreakerConfig::default(), Some(path), None);

        if let Some(saved) = restored {
            // Normalise through `CircuitState` so an unrecognised discriminant
            // in the file is neither kept in memory nor written back later.
            let state = CircuitState::from(saved.state);
            breaker.state.store(state as u8, Ordering::Release);
            breaker
                .consecutive_failures
                .store(saved.consecutive_failures, Ordering::Relaxed);

            if let Some(epoch) = saved.last_failure_epoch_secs {
                let now = SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .unwrap_or_default()
                    .as_secs();
                if epoch <= now {
                    *breaker.last_failure_time.lock() =
                        Some(std::time::UNIX_EPOCH + Duration::from_secs(epoch));
                }
            }
        }
        breaker
    }

    /// Reports a half-open transition to the metrics collector, if any.
    #[inline]
    fn record_half_open_metric(&self) {
        if let Some(metrics) = &self.metrics {
            metrics.record_half_open();
        }
    }

    /// Reports a rejected request to the metrics collector, if any.
    #[inline]
    fn record_breaker_denial_metric(&self) {
        if let Some(metrics) = &self.metrics {
            metrics.record_breaker_denial();
        }
    }

    /// Reports a transition to open to the metrics collector, if any.
    #[inline]
    fn record_circuit_open_metric(&self) {
        if let Some(metrics) = &self.metrics {
            metrics.record_circuit_open();
        }
    }

    /// Writes the current state to `persistence_path`, if one is configured.
    ///
    /// This is best-effort: serialization and I/O errors are ignored so that
    /// a persistence problem never affects request handling.
    fn persist(&self) {
        if let Some(path) = &self.persistence_path {
            // The guard is a temporary, so the lock is released before any I/O.
            let last_failure = *self.last_failure_time.lock();
            let last_failure_epoch_secs = last_failure.map(|at| {
                at.duration_since(std::time::UNIX_EPOCH)
                    .unwrap_or_default()
                    .as_secs()
            });
            let snapshot = PersistedState {
                state: self.state.load(Ordering::Acquire),
                consecutive_failures: self.consecutive_failures.load(Ordering::Acquire),
                last_failure_epoch_secs,
            };
            if let Ok(data) = serde_json::to_string(&snapshot) {
                let _ = fs::write(path, data);
            }
        }
    }

    /// Returns the current state.
    pub fn state(&self) -> CircuitState {
        // Acquire pairs with the Release stores performed on every transition.
        self.state.load(Ordering::Acquire).into()
    }

    /// Returns `true` once an open breaker has waited out its current timeout.
    fn open_timeout_elapsed(&self) -> bool {
        let last_failure = *self.last_failure_time.lock();
        match last_failure {
            // Nothing to wait on (e.g. a restored state without a timestamp).
            None => true,
            Some(failure_time) => match failure_time.elapsed() {
                Ok(elapsed) => elapsed >= self.calculate_timeout(),
                // The timestamp is ahead of the system clock: keep blocking.
                Err(_) => false,
            },
        }
    }

    /// Decides whether a request may proceed.
    ///
    /// Closed and half-open breakers always allow the request. An open
    /// breaker rejects it (counting the denial) until the current timeout has
    /// elapsed since the last failure; the first request after that moves the
    /// breaker to half-open and is allowed through as a probe.
    pub fn allow_request(&self) -> bool {
        match self.state() {
            CircuitState::Closed | CircuitState::HalfOpen => true,
            CircuitState::Open => {
                if !self.open_timeout_elapsed() {
                    self.blocked_requests.fetch_add(1, Ordering::Relaxed);
                    self.record_breaker_denial_metric();
                    return false;
                }
                // Compare-and-swap so that, when several callers race past the
                // timeout check, exactly one performs the transition (metric,
                // counter reset, persistence) and a state written concurrently
                // by another thread is never overwritten.
                let transitioned = self
                    .state
                    .compare_exchange(
                        CircuitState::Open as u8,
                        CircuitState::HalfOpen as u8,
                        Ordering::AcqRel,
                        Ordering::Acquire,
                    )
                    .is_ok();
                if transitioned {
                    self.half_open_successes.store(0, Ordering::Relaxed);
                    self.record_half_open_metric();
                    self.persist();
                }
                // If the swap lost the race the breaker has already left the
                // open state, and both other states admit requests.
                true
            }
        }
    }

    /// Records a successful request.
    ///
    /// * Closed: clears the consecutive-failure counter.
    /// * Half-open: counts the success and closes the breaker once
    ///   `success_threshold` is reached.
    /// * Open: moves the breaker to half-open and counts this success as the
    ///   first probe success.
    pub fn record_success(&self) {
        match self.state() {
            CircuitState::Closed => {
                // Persist only when a streak was actually cleared: the saved
                // count then matches memory after a restart, and the common
                // all-success path performs no I/O.
                if self.consecutive_failures.swap(0, Ordering::AcqRel) != 0 {
                    self.persist();
                }
            }
            CircuitState::HalfOpen => {
                let successes = self.half_open_successes.fetch_add(1, Ordering::AcqRel) + 1;
                if successes >= self.config.success_threshold {
                    self.state
                        .store(CircuitState::Closed as u8, Ordering::Release);
                    self.consecutive_failures.store(0, Ordering::Relaxed);
                    self.half_open_successes.store(0, Ordering::Relaxed);
                    *self.last_failure_time.lock() = None;
                    self.persist();
                }
            }
            CircuitState::Open => {
                self.state
                    .store(CircuitState::HalfOpen as u8, Ordering::Release);
                self.half_open_successes.store(1, Ordering::Relaxed);
                self.persist();
            }
        }
    }

    /// Records a failure categorised as `ErrorCategory::ExecutionError`.
    pub fn record_failure(&self) {
        self.record_failure_category(ErrorCategory::ExecutionError);
    }

    /// Records a failure of the given category.
    ///
    /// Categories for which `should_trip_circuit_breaker()` is `false` are
    /// ignored entirely. Otherwise the failure timestamp is refreshed and:
    ///
    /// * Closed: the failure is counted and the breaker opens once
    ///   `failure_threshold` is reached.
    /// * Half-open: the breaker reopens immediately.
    /// * Open: the failure is counted, which lengthens the back-off.
    pub fn record_failure_category(&self, category: ErrorCategory) {
        if !category.should_trip_circuit_breaker() {
            return;
        }
        let state = self.state();
        *self.last_failure_time.lock() = Some(SystemTime::now());
        match state {
            CircuitState::Closed => {
                let failures = self.consecutive_failures.fetch_add(1, Ordering::AcqRel) + 1;
                if failures >= self.config.failure_threshold {
                    self.state
                        .store(CircuitState::Open as u8, Ordering::Release);
                    self.record_circuit_open_metric();
                }
            }
            CircuitState::HalfOpen => {
                self.state
                    .store(CircuitState::Open as u8, Ordering::Release);
                self.consecutive_failures.fetch_add(1, Ordering::AcqRel);
                self.half_open_successes.store(0, Ordering::Relaxed);
                self.record_circuit_open_metric();
            }
            CircuitState::Open => {
                self.consecutive_failures.fetch_add(1, Ordering::Relaxed);
            }
        }
        self.persist();
    }

    /// Computes how long an open breaker currently waits before probing.
    ///
    /// The base timeout is doubled for every failure beyond
    /// `failure_threshold` and capped at `max_timeout`; all arithmetic
    /// saturates instead of overflowing.
    fn calculate_timeout(&self) -> Duration {
        let failures = self.consecutive_failures.load(Ordering::Relaxed);
        let multiplier = if failures > self.config.failure_threshold {
            2u32.saturating_pow(failures.saturating_sub(self.config.failure_threshold))
        } else {
            1
        };
        let timeout = self.config.base_timeout.saturating_mul(multiplier);
        timeout.min(self.config.max_timeout)
    }

    /// Returns a snapshot of the breaker for diagnostics.
    ///
    /// `retry_after` is only set while the breaker is open, a failure
    /// timestamp exists, and the current timeout has not yet elapsed.
    pub fn diagnostics(&self) -> CircuitBreakerDiagnostics {
        // Read each shared value once so the fields of the snapshot cannot
        // contradict one another (e.g. `state` versus `is_blocking`).
        let state = self.state();
        let last_failure_time = *self.last_failure_time.lock();
        let current_timeout = self.calculate_timeout();
        let is_blocking = state == CircuitState::Open;

        let retry_after = if is_blocking {
            last_failure_time
                .and_then(|failure_time| failure_time.elapsed().ok())
                .and_then(|elapsed| current_timeout.checked_sub(elapsed))
        } else {
            None
        };

        CircuitBreakerDiagnostics {
            state,
            consecutive_failures: self.consecutive_failures.load(Ordering::Relaxed),
            half_open_successes: self.half_open_successes.load(Ordering::Relaxed),
            last_failure_time,
            current_timeout,
            retry_after,
            blocked_requests: self.blocked_requests.load(Ordering::Relaxed),
            is_blocking,
        }
    }

    /// Forces the breaker back to closed, clears all counters and the failure
    /// timestamp, and persists the result.
    pub fn reset(&self) {
        self.state
            .store(CircuitState::Closed as u8, Ordering::Release);
        self.consecutive_failures.store(0, Ordering::Relaxed);
        self.half_open_successes.store(0, Ordering::Relaxed);
        self.blocked_requests.store(0, Ordering::Relaxed);
        *self.last_failure_time.lock() = None;
        self.persist();
    }
}
impl Default for McpCircuitBreaker {
fn default() -> Self {
Self::new()
}
}
/// Snapshot of a circuit breaker as returned by `McpCircuitBreaker::diagnostics`.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct CircuitBreakerDiagnostics {
    /// State of the breaker when the snapshot was taken.
    pub state: CircuitState,
    /// Value of the consecutive-failure counter.
    pub consecutive_failures: u32,
    /// Successes counted since the breaker last became half-open.
    #[allow(dead_code)]
    pub half_open_successes: u32,
    /// Time of the most recent tripping failure, if one is recorded.
    pub last_failure_time: Option<SystemTime>,
    /// Timeout currently in effect after back-off and capping.
    pub current_timeout: Duration,
    /// Time left before an open breaker allows a probe; `None` when the
    /// breaker is not open or the timeout has already elapsed.
    pub retry_after: Option<Duration>,
    /// Number of requests rejected while the breaker was open.
    pub blocked_requests: u32,
    /// `true` when the breaker is in the open state.
    pub is_blocking: bool,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::metrics::MetricsCollector;
    use std::thread;

    /// Builds a metrics-less breaker with the given failure threshold and
    /// base timeout; every other setting keeps its default.
    fn breaker_with(failure_threshold: u32, base_timeout: Duration) -> McpCircuitBreaker {
        McpCircuitBreaker::with_config(CircuitBreakerConfig {
            failure_threshold,
            base_timeout,
            ..Default::default()
        })
    }

    /// Records `count` execution failures on `breaker`.
    fn fail_times(breaker: &McpCircuitBreaker, count: usize) {
        for _ in 0..count {
            breaker.record_failure();
        }
    }

    /// A new breaker is closed and lets requests through.
    #[test]
    fn test_circuit_breaker_closed_state() {
        let breaker = McpCircuitBreaker::new();

        assert_eq!(breaker.state(), CircuitState::Closed);
        assert!(breaker.allow_request());
    }

    /// The breaker opens exactly on the failure that reaches the threshold
    /// and then counts rejected requests.
    #[test]
    fn test_circuit_breaker_opens_after_threshold() {
        let breaker = McpCircuitBreaker::with_config(CircuitBreakerConfig {
            failure_threshold: 3,
            ..Default::default()
        });

        let expected_states = [
            CircuitState::Closed,
            CircuitState::Closed,
            CircuitState::Open,
        ];
        for expected in expected_states {
            breaker.record_failure();
            assert_eq!(breaker.state(), expected);
        }

        assert!(!breaker.allow_request());
        assert!(breaker.diagnostics().blocked_requests > 0);
    }

    /// After the timeout elapses, the next request is allowed and moves the
    /// breaker to half-open.
    #[test]
    fn test_circuit_breaker_half_open_transition() {
        let breaker = breaker_with(2, Duration::from_millis(100));

        fail_times(&breaker, 2);
        assert_eq!(breaker.state(), CircuitState::Open);

        thread::sleep(Duration::from_millis(150));
        assert!(breaker.allow_request());
        assert_eq!(breaker.state(), CircuitState::HalfOpen);
    }

    /// Reaching the success threshold while half-open closes the breaker.
    #[test]
    fn test_circuit_breaker_closes_after_successes() {
        let breaker = McpCircuitBreaker::with_config(CircuitBreakerConfig {
            failure_threshold: 2,
            success_threshold: 2,
            base_timeout: Duration::from_millis(50),
            ..Default::default()
        });

        fail_times(&breaker, 2);
        assert_eq!(breaker.state(), CircuitState::Open);

        thread::sleep(Duration::from_millis(60));
        assert!(breaker.allow_request());
        assert_eq!(breaker.state(), CircuitState::HalfOpen);

        breaker.record_success();
        assert_eq!(breaker.state(), CircuitState::HalfOpen);
        breaker.record_success();
        assert_eq!(breaker.state(), CircuitState::Closed);
    }

    /// A failure while half-open reopens the breaker.
    #[test]
    fn test_circuit_breaker_failure_in_half_open() {
        let breaker = breaker_with(2, Duration::from_millis(50));

        fail_times(&breaker, 2);
        thread::sleep(Duration::from_millis(60));
        breaker.allow_request();
        assert_eq!(breaker.state(), CircuitState::HalfOpen);

        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitState::Open);
    }

    /// Failures beyond the threshold back the timeout off up to the maximum.
    #[test]
    fn test_exponential_backoff() {
        let breaker = McpCircuitBreaker::with_config(CircuitBreakerConfig {
            failure_threshold: 2,
            base_timeout: Duration::from_secs(10),
            max_timeout: Duration::from_secs(60),
            ..Default::default()
        });

        fail_times(&breaker, 5);

        let diag = breaker.diagnostics();
        assert_eq!(diag.current_timeout, Duration::from_secs(60));
    }

    /// Authentication failures are not counted against the breaker.
    #[test]
    fn authentication_failure_does_not_trip_breaker() {
        let breaker = McpCircuitBreaker::new();

        breaker.record_failure_category(ErrorCategory::Authentication);

        assert_eq!(breaker.state(), CircuitState::Closed);
        assert_eq!(breaker.diagnostics().consecutive_failures, 0);
    }

    /// Open, denial and half-open events are each reported once.
    #[test]
    fn reliability_metrics_capture_open_half_open_and_denials() {
        let metrics = Arc::new(MetricsCollector::new());
        let config = CircuitBreakerConfig {
            failure_threshold: 1,
            base_timeout: Duration::from_millis(10),
            ..Default::default()
        };
        let breaker = McpCircuitBreaker::with_config_and_metrics(config, metrics.clone());

        breaker.record_failure_category(ErrorCategory::ExecutionError);
        assert_eq!(breaker.state(), CircuitState::Open);
        assert!(!breaker.allow_request());

        thread::sleep(Duration::from_millis(20));
        assert!(breaker.allow_request());
        assert_eq!(breaker.state(), CircuitState::HalfOpen);

        let execution = metrics.get_execution_metrics();
        assert_eq!(execution.circuit_open_events, 1);
        assert_eq!(execution.breaker_denials, 1);
        assert_eq!(execution.half_open_events, 1);
    }
}