use crate::{
StorageBackend, bulkhead::BulkheadSemaphore, callbacks::Callbacks,
classifier::FailureClassifier, errors::CircuitError,
};
use state_machines::state_machine;
use std::sync::Arc;
/// Tuning knobs for a circuit breaker.
///
/// The circuit trips when *either* enabled threshold (absolute failure count or
/// failure rate) is reached within the sliding failure window.
#[derive(Debug, Clone)]
pub struct Config {
    /// Number of failures inside the window that trips the circuit; `None`
    /// disables the absolute check.
    pub failure_threshold: Option<usize>,
    /// Failure ratio (failures / total calls) inside the window that trips the
    /// circuit; `None` disables the rate check.
    pub failure_rate_threshold: Option<f64>,
    /// Minimum number of calls inside the window before the rate check applies.
    pub minimum_calls: usize,
    /// Length of the sliding window, in seconds, used for failure/success counts.
    pub failure_window_secs: f64,
    /// Seconds the circuit stays open before a half-open probe is allowed.
    pub half_open_timeout_secs: f64,
    /// Consecutive half-open successes required to close the circuit again.
    pub success_threshold: usize,
    /// Fraction by which the open period may be shortened by jitter; 0.0 disables it.
    pub jitter_factor: f64,
}

impl Default for Config {
    /// 5 failures per 60 s window trip the circuit, the rate check is off
    /// (20 calls minimum once enabled), the circuit stays open 30 s, 2 probe
    /// successes close it, and no jitter is applied.
    fn default() -> Self {
        let window_secs = 60.0;
        Config {
            failure_window_secs: window_secs,
            half_open_timeout_secs: 30.0,
            failure_threshold: Some(5),
            failure_rate_threshold: None,
            minimum_calls: 20,
            success_threshold: 2,
            jitter_factor: 0.0,
        }
    }
}
/// Information handed to a fallback when a call is short-circuited.
#[derive(Debug, Clone)]
pub struct FallbackContext {
    /// Name of the circuit that rejected the call.
    pub circuit_name: String,
    /// Storage-clock timestamp at which the circuit opened (0.0 if unknown).
    pub opened_at: f64,
    /// Name of the state that caused the short-circuit (e.g. `"Open"`).
    pub state: &'static str,
}

/// Boxed fallback invoked instead of the protected operation while the circuit is open.
pub type FallbackFn<T, E> = Box<dyn FnOnce(&FallbackContext) -> Result<T, E> + Send>;

/// Per-call options passed alongside the protected closure.
pub struct CallOptions<T, E> {
    /// Optional fallback used when the circuit rejects the call.
    pub fallback: Option<FallbackFn<T, E>>,
}

// Written by hand instead of derived so that no `T: Default` / `E: Default`
// bounds are imposed on callers.
impl<T, E> Default for CallOptions<T, E> {
    fn default() -> Self {
        CallOptions { fallback: None }
    }
}

impl<T, E> CallOptions<T, E> {
    /// Creates options with no fallback.
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns these options with `f` installed as the fallback, replacing any
    /// previously installed one.
    pub fn with_fallback<F>(self, f: F) -> Self
    where
        F: FnOnce(&FallbackContext) -> Result<T, E> + Send + 'static,
    {
        let mut options = self;
        let boxed: FallbackFn<T, E> = Box::new(f);
        options.fallback = Some(boxed);
        options
    }
}

/// Boxed protected operation executed by the circuit breaker.
pub type CallableFn<T, E> = Box<dyn FnOnce() -> Result<T, E>>;

/// Input accepted by `CircuitBreaker::call`: either a bare closure or a
/// `(closure, CallOptions)` pair.
pub trait IntoCallOptions<T, E> {
    /// Splits `self` into the boxed operation and its call options.
    fn into_call_options(self) -> (CallableFn<T, E>, CallOptions<T, E>);
}

/// A bare closure runs with default options (no fallback).
impl<T, E, F> IntoCallOptions<T, E> for F
where
    F: FnOnce() -> Result<T, E> + 'static,
{
    fn into_call_options(self) -> (CallableFn<T, E>, CallOptions<T, E>) {
        let operation: CallableFn<T, E> = Box::new(self);
        (operation, CallOptions::new())
    }
}

/// A `(closure, options)` pair forwards the caller-supplied options unchanged.
impl<T, E, F> IntoCallOptions<T, E> for (F, CallOptions<T, E>)
where
    F: FnOnce() -> Result<T, E> + 'static,
{
    fn into_call_options(self) -> (CallableFn<T, E>, CallOptions<T, E>) {
        let (operation, options) = self;
        (Box::new(operation), options)
    }
}
/// Context shared between `CircuitBreaker` and its state-machine guards.
#[derive(Clone)]
pub struct CircuitContext {
    /// Circuit name; also the key under which storage records calls.
    pub name: String,
    /// Thresholds and timings for this circuit.
    pub config: Config,
    /// Backend that records successes/failures and supplies the clock.
    pub storage: Arc<dyn StorageBackend>,
    /// Optional filter deciding which errors count toward tripping.
    pub failure_classifier: Option<Arc<dyn FailureClassifier>>,
    /// Optional concurrency limiter consulted by `CircuitBreaker::call`.
    pub bulkhead: Option<Arc<BulkheadSemaphore>>,
}

impl Default for CircuitContext {
    /// Unnamed circuit with default config, a fresh in-memory storage backend,
    /// and neither classifier nor bulkhead.
    fn default() -> Self {
        let storage: Arc<dyn StorageBackend> = Arc::new(crate::MemoryStorage::new());
        CircuitContext {
            name: String::default(),
            config: Default::default(),
            storage,
            failure_classifier: None,
            bulkhead: None,
        }
    }
}

// Trait objects are not `Debug`, so storage and classifier are rendered as
// fixed placeholder strings.
impl std::fmt::Debug for CircuitContext {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let classifier_label: Option<&str> = self
            .failure_classifier
            .is_some()
            .then_some("<dyn FailureClassifier>");
        let mut out = f.debug_struct("CircuitContext");
        out.field("name", &self.name);
        out.field("config", &self.config);
        out.field("storage", &"<dyn StorageBackend>");
        out.field("failure_classifier", &classifier_label);
        out.field("bulkhead", &self.bulkhead);
        out.finish()
    }
}
/// Payload carried by the state machine while the circuit is `Open`.
#[derive(Debug, Clone, Default)]
pub struct OpenData {
    /// Storage-clock time at which the circuit opened. It starts at 0.0
    /// (`Default`), and `CircuitBreaker` overwrites it with
    /// `storage.monotonic_time()` when a trip succeeds.
    pub opened_at: f64,
}
/// Payload carried by the state machine while the circuit is `HalfOpen`.
#[derive(Debug, Clone, Default)]
pub struct HalfOpenData {
    /// Successful probe calls since entering half-open. It is reset to 0 when a
    /// counted failure does not re-trip the circuit.
    pub consecutive_successes: usize,
}
// Three-state circuit:
//   Closed   --trip-->          Open
//   HalfOpen --trip-->          Open
//   Open     --attempt_reset--> HalfOpen
//   HalfOpen --close-->         Closed
// Each event is gated by the guard method of the listed name, implemented on
// the source state's `Circuit<State>` impl below. `Open` and `HalfOpen` carry
// `OpenData` / `HalfOpenData` payloads.
// NOTE(review): `dynamic: true` presumably generates the runtime
// `DynamicCircuit` wrapper and `CircuitEvent` enum used by `CircuitBreaker`;
// confirm against the `state_machines` crate docs.
state_machine! {
    name: Circuit,
    context: CircuitContext,
    dynamic: true,
    initial: Closed,
    states: [
        Closed,
        Open(OpenData),
        HalfOpen(HalfOpenData),
    ],
    events {
        // Open the circuit when the windowed failure thresholds are exceeded.
        trip {
            guards: [should_open],
            transition: { from: [Closed, HalfOpen], to: Open }
        }
        // Allow a probe once the open period has elapsed.
        attempt_reset {
            guards: [timeout_elapsed],
            transition: { from: Open, to: HalfOpen }
        }
        // Close after enough consecutive successful probes.
        close {
            guards: [should_close],
            transition: { from: HalfOpen, to: Closed }
        }
    }
}
impl Circuit<Closed> {
fn should_open(&self, ctx: &CircuitContext) -> bool {
let failures = ctx
.storage
.failure_count(&ctx.name, ctx.config.failure_window_secs);
if let Some(threshold) = ctx.config.failure_threshold
&& failures >= threshold
{
return true;
}
if let Some(rate_threshold) = ctx.config.failure_rate_threshold {
let successes = ctx
.storage
.success_count(&ctx.name, ctx.config.failure_window_secs);
let total = failures + successes;
if total >= ctx.config.minimum_calls {
let failure_rate = if total > 0 {
failures as f64 / total as f64
} else {
0.0
};
if failure_rate >= rate_threshold {
return true;
}
}
}
false
}
}
impl Circuit<HalfOpen> {
fn should_open(&self, ctx: &CircuitContext) -> bool {
let failures = ctx
.storage
.failure_count(&ctx.name, ctx.config.failure_window_secs);
if let Some(threshold) = ctx.config.failure_threshold
&& failures >= threshold
{
return true;
}
if let Some(rate_threshold) = ctx.config.failure_rate_threshold {
let successes = ctx
.storage
.success_count(&ctx.name, ctx.config.failure_window_secs);
let total = failures + successes;
if total >= ctx.config.minimum_calls {
let failure_rate = if total > 0 {
failures as f64 / total as f64
} else {
0.0
};
if failure_rate >= rate_threshold {
return true;
}
}
}
false
}
fn should_close(&self, ctx: &CircuitContext) -> bool {
let data = self
.state_data_half_open()
.expect("HalfOpen state must have data");
data.consecutive_successes >= ctx.config.success_threshold
}
}
impl Circuit<Open> {
fn timeout_elapsed(&self, ctx: &CircuitContext) -> bool {
let data = self.state_data_open().expect("Open state must have data");
let current_time = ctx.storage.monotonic_time();
let elapsed = current_time - data.opened_at;
let timeout_secs = if ctx.config.jitter_factor > 0.0 {
let policy = chrono_machines::Policy {
max_attempts: 1,
base_delay_ms: (ctx.config.half_open_timeout_secs * 1000.0) as u64,
multiplier: 1.0,
max_delay_ms: (ctx.config.half_open_timeout_secs * 1000.0) as u64,
};
let timeout_ms = policy.calculate_delay(1, ctx.config.jitter_factor);
(timeout_ms as f64) / 1000.0
} else {
ctx.config.half_open_timeout_secs
};
elapsed >= timeout_secs
}
}
/// Synchronous circuit breaker: drives the `DynamicCircuit` state machine from
/// call outcomes recorded in storage, with an optional bulkhead and lifecycle
/// callbacks.
pub struct CircuitBreaker {
    /// Runtime state machine (`Closed` / `Open` / `HalfOpen`).
    machine: DynamicCircuit,
    /// Context used by the breaker itself; the machine is built from a clone of it.
    context: CircuitContext,
    /// Hooks fired on open, half-open and close transitions.
    callbacks: Callbacks,
}
impl CircuitBreaker {
pub fn new(name: String, config: Config) -> Self {
let storage = Arc::new(crate::MemoryStorage::new());
let context = CircuitContext {
name,
config,
storage,
failure_classifier: None,
bulkhead: None,
};
let machine = DynamicCircuit::new(context.clone());
let callbacks = Callbacks::new();
Self {
machine,
context,
callbacks,
}
}
pub(crate) fn with_context_and_callbacks(
context: CircuitContext,
callbacks: Callbacks,
) -> Self {
let machine = DynamicCircuit::new(context.clone());
Self {
machine,
context,
callbacks,
}
}
pub fn builder(name: impl Into<String>) -> crate::builder::CircuitBuilder {
crate::builder::CircuitBuilder::new(name)
}
pub fn call<I, T, E: 'static>(&mut self, input: I) -> Result<T, CircuitError<E>>
where
I: IntoCallOptions<T, E>,
{
let (f, options) = input.into_call_options();
let _guard = if let Some(bulkhead) = &self.context.bulkhead {
match bulkhead.try_acquire() {
Some(guard) => Some(guard),
None => {
return Err(CircuitError::BulkheadFull {
circuit: self.context.name.clone(),
limit: bulkhead.limit(),
});
}
}
} else {
None
};
if self.machine.current_state() == "Open" {
let _ = self.machine.handle(CircuitEvent::AttemptReset);
if self.machine.current_state() == "HalfOpen" {
self.callbacks.trigger_half_open(&self.context.name);
}
}
match self.machine.current_state() {
"Open" => {
let opened_at = self.machine.open_data().map(|d| d.opened_at).unwrap_or(0.0);
if let Some(fallback) = options.fallback {
let ctx = FallbackContext {
circuit_name: self.context.name.clone(),
opened_at,
state: "Open",
};
return fallback(&ctx).map_err(CircuitError::Execution);
}
Err(CircuitError::Open {
circuit: self.context.name.clone(),
opened_at,
})
}
"HalfOpen" => {
if let Some(data) = self.machine.half_open_data()
&& data.consecutive_successes >= self.context.config.success_threshold
{
return Err(CircuitError::HalfOpenLimitReached {
circuit: self.context.name.clone(),
});
}
self.execute_call(f)
}
_ => self.execute_call(f),
}
}
fn execute_call<T, E: 'static>(
&mut self,
f: Box<dyn FnOnce() -> Result<T, E>>,
) -> Result<T, CircuitError<E>> {
let start = self.context.storage.monotonic_time();
match f() {
Ok(val) => {
let duration = self.context.storage.monotonic_time() - start;
self.context
.storage
.record_success(&self.context.name, duration);
if self.machine.current_state() == "HalfOpen" {
if let Some(data) = self.machine.half_open_data_mut() {
data.consecutive_successes += 1;
}
if self.machine.handle(CircuitEvent::Close).is_ok() {
self.callbacks.trigger_close(&self.context.name);
}
}
Ok(val)
}
Err(e) => {
let duration = self.context.storage.monotonic_time() - start;
let should_trip = if let Some(classifier) = &self.context.failure_classifier {
let ctx = crate::classifier::FailureContext {
circuit_name: &self.context.name,
error: &e as &dyn std::any::Any,
duration,
};
classifier.should_trip(&ctx)
} else {
true
};
if should_trip {
self.context
.storage
.record_failure(&self.context.name, duration);
let result = self.machine.handle(CircuitEvent::Trip);
if result.is_ok() {
self.mark_open();
} else if self.machine.current_state() == "HalfOpen" {
if let Some(data) = self.machine.half_open_data_mut() {
data.consecutive_successes = 0;
}
}
}
Err(CircuitError::Execution(e))
}
}
}
pub fn record_success_and_maybe_close(&mut self, duration: f64) {
self.context
.storage
.record_success(&self.context.name, duration);
if self.machine.current_state() == "HalfOpen" {
if let Some(data) = self.machine.half_open_data_mut() {
data.consecutive_successes += 1;
}
if self.machine.handle(CircuitEvent::Close).is_ok() {
self.callbacks.trigger_close(&self.context.name);
}
}
}
pub fn record_failure_and_maybe_trip(&mut self, duration: f64) {
self.context
.storage
.record_failure(&self.context.name, duration);
let result = self.machine.handle(CircuitEvent::Trip);
if result.is_ok() {
self.mark_open();
} else if self.machine.current_state() == "HalfOpen"
&& let Some(data) = self.machine.half_open_data_mut()
{
data.consecutive_successes = 0;
}
}
pub fn record_success(&self, duration: f64) {
self.context
.storage
.record_success(&self.context.name, duration);
}
pub fn record_failure(&self, duration: f64) {
self.context
.storage
.record_failure(&self.context.name, duration);
}
pub fn check_and_trip(&mut self) -> bool {
if self.machine.handle(CircuitEvent::Trip).is_ok() {
self.mark_open();
true
} else {
false
}
}
pub fn is_open(&self) -> bool {
self.machine.current_state() == "Open"
}
pub fn is_closed(&self) -> bool {
self.machine.current_state() == "Closed"
}
pub fn state_name(&self) -> &'static str {
self.machine.current_state()
}
pub fn reset(&mut self) {
self.context.storage.clear(&self.context.name);
self.machine = DynamicCircuit::new(self.context.clone());
}
fn mark_open(&mut self) {
if let Some(data) = self.machine.open_data_mut() {
data.opened_at = self.context.storage.monotonic_time();
}
self.callbacks.trigger_open(&self.context.name);
}
}
// Unit tests for the breaker and the underlying state machine. Several tests
// drive `DynamicCircuit` directly and stamp `opened_at` by hand, because only
// `CircuitBreaker` does that automatically.
#[cfg(test)]
mod tests {
    use super::*;

    // A new breaker starts Closed.
    #[test]
    fn test_circuit_breaker_creation() {
        let config = Config::default();
        let circuit = CircuitBreaker::new("test".to_string(), config);
        assert!(circuit.is_closed());
        assert!(!circuit.is_open());
    }

    // The absolute threshold trips exactly on the third failure.
    #[test]
    fn test_circuit_opens_after_threshold() {
        let config = Config {
            failure_threshold: Some(3),
            ..Default::default()
        };
        let mut circuit = CircuitBreaker::new("test".to_string(), config);
        let _ = circuit.call(|| Err::<(), _>("error 1"));
        let _ = circuit.call(|| Err::<(), _>("error 2"));
        assert!(circuit.is_closed());
        let _ = circuit.call(|| Err::<(), _>("error 3"));
        assert!(circuit.is_open());
    }

    // `reset` returns an open circuit to Closed.
    #[test]
    fn test_reset_clears_state() {
        let config = Config {
            failure_threshold: Some(2),
            ..Default::default()
        };
        let mut circuit = CircuitBreaker::new("test".to_string(), config);
        let _ = circuit.call(|| Err::<(), _>("error 1"));
        let _ = circuit.call(|| Err::<(), _>("error 2"));
        assert!(circuit.is_open());
        circuit.reset();
        assert!(circuit.is_closed());
    }

    // The raw machine refuses `Trip` until storage holds enough failures.
    #[test]
    fn test_state_machine_closed_to_open_transition() {
        let storage = Arc::new(crate::MemoryStorage::new());
        let config = Config {
            failure_threshold: Some(3),
            ..Default::default()
        };
        let ctx = CircuitContext {
            failure_classifier: None,
            bulkhead: None,
            name: "test_circuit".to_string(),
            config,
            storage: storage.clone(),
        };
        let mut circuit = DynamicCircuit::new(ctx.clone());
        let result = circuit.handle(CircuitEvent::Trip);
        assert!(result.is_err(), "Should fail guard when below threshold");
        storage.record_failure("test_circuit", 0.1);
        storage.record_failure("test_circuit", 0.1);
        storage.record_failure("test_circuit", 0.1);
        circuit
            .handle(CircuitEvent::Trip)
            .expect("Should open after reaching threshold");
        assert_eq!(circuit.current_state(), "Open");
    }

    // `AttemptReset` is rejected before the 1 ms timeout and accepted after it;
    // HalfOpen starts with a zero success streak.
    #[test]
    fn test_state_machine_open_to_half_open_transition() {
        let storage = Arc::new(crate::MemoryStorage::new());
        let config = Config {
            failure_threshold: Some(2),
            half_open_timeout_secs: 0.001,
            ..Default::default()
        };
        let ctx = CircuitContext {
            failure_classifier: None,
            bulkhead: None,
            name: "test_circuit".to_string(),
            config,
            storage: storage.clone(),
        };
        storage.record_failure("test_circuit", 0.1);
        storage.record_failure("test_circuit", 0.1);
        let mut circuit = DynamicCircuit::new(ctx.clone());
        circuit.handle(CircuitEvent::Trip).expect("Should open");
        // The raw machine does not stamp `opened_at`; do it by hand.
        if let Some(data) = circuit.open_data_mut() {
            data.opened_at = storage.monotonic_time();
        }
        let result = circuit.handle(CircuitEvent::AttemptReset);
        assert!(
            result.is_err(),
            "Should fail guard when timeout not elapsed"
        );
        std::thread::sleep(std::time::Duration::from_millis(5));
        circuit
            .handle(CircuitEvent::AttemptReset)
            .expect("Should reset after timeout");
        assert_eq!(circuit.current_state(), "HalfOpen");
        let data = circuit.half_open_data().expect("Should have HalfOpen data");
        assert_eq!(data.consecutive_successes, 0);
    }

    // `Close` is rejected while the HalfOpen success streak is below threshold.
    #[test]
    fn test_state_machine_half_open_to_closed_guard() {
        let storage = Arc::new(crate::MemoryStorage::new());
        let config = Config {
            failure_threshold: Some(2),
            half_open_timeout_secs: 0.001,
            ..Default::default()
        };
        let ctx = CircuitContext {
            failure_classifier: None,
            bulkhead: None,
            name: "test_circuit".to_string(),
            config,
            storage: storage.clone(),
        };
        storage.record_failure("test_circuit", 0.1);
        storage.record_failure("test_circuit", 0.1);
        let mut circuit = DynamicCircuit::new(ctx.clone());
        circuit.handle(CircuitEvent::Trip).expect("Should open");
        if let Some(data) = circuit.open_data_mut() {
            data.opened_at = storage.monotonic_time();
        }
        std::thread::sleep(std::time::Duration::from_millis(5));
        circuit
            .handle(CircuitEvent::AttemptReset)
            .expect("Should reset");
        let result = circuit.handle(CircuitEvent::Close);
        assert!(result.is_err(), "Should fail guard without successes");
    }

    // With jitter disabled, the reset is allowed after the full 1 s timeout.
    // NOTE(review): sleeps for a full second.
    #[test]
    fn test_jitter_disabled() {
        let storage = Arc::new(crate::MemoryStorage::new());
        let config = Config {
            failure_threshold: Some(1),
            half_open_timeout_secs: 1.0,
            jitter_factor: 0.0,
            ..Default::default()
        };
        let ctx = CircuitContext {
            failure_classifier: None,
            bulkhead: None,
            name: "test_circuit".to_string(),
            config,
            storage: storage.clone(),
        };
        storage.record_failure("test_circuit", 0.1);
        let mut circuit = DynamicCircuit::new(ctx.clone());
        circuit.handle(CircuitEvent::Trip).expect("Should open");
        if let Some(data) = circuit.open_data_mut() {
            data.opened_at = storage.monotonic_time();
        }
        std::thread::sleep(std::time::Duration::from_secs(1));
        circuit
            .handle(CircuitEvent::AttemptReset)
            .expect("Should reset after exact timeout");
        assert_eq!(circuit.current_state(), "HalfOpen");
    }

    // With 10% jitter, at least one of up to 10 episodes should allow a reset
    // after only 950 ms.
    // NOTE(review): probabilistic, and sleeps ~950 ms per iteration (up to ~9.5 s).
    #[test]
    fn test_jitter_enabled() {
        let storage = Arc::new(crate::MemoryStorage::new());
        let config = Config {
            failure_threshold: Some(1),
            half_open_timeout_secs: 1.0,
            jitter_factor: 0.1,
            ..Default::default()
        };
        let ctx = CircuitContext {
            failure_classifier: None,
            bulkhead: None,
            name: "test_circuit".to_string(),
            config,
            storage: storage.clone(),
        };
        let mut found_early_reset = false;
        for _ in 0..10 {
            storage.record_failure("test_circuit", 0.1);
            let mut circuit = DynamicCircuit::new(ctx.clone());
            circuit.handle(CircuitEvent::Trip).expect("Should open");
            if let Some(data) = circuit.open_data_mut() {
                data.opened_at = storage.monotonic_time();
            }
            std::thread::sleep(std::time::Duration::from_millis(950));
            if circuit.handle(CircuitEvent::AttemptReset).is_ok() {
                found_early_reset = true;
                break;
            }
            storage.clear("test_circuit");
        }
        assert!(
            found_early_reset,
            "Jitter should occasionally allow reset before full timeout"
        );
    }

    // The builder propagates `jitter_factor` into the context config.
    #[test]
    fn test_builder_with_jitter() {
        let mut circuit = CircuitBreaker::builder("test")
            .failure_threshold(2)
            .half_open_timeout_secs(1.0)
            .jitter_factor(0.5)
            .build();
        let _ = circuit.call(|| Err::<(), _>("error 1"));
        let _ = circuit.call(|| Err::<(), _>("error 2"));
        assert!(circuit.is_open());
        assert_eq!(circuit.context.config.jitter_factor, 0.5);
    }

    // While Open, the fallback runs instead of the operation and sees the context.
    #[test]
    fn test_fallback_when_open() {
        let mut circuit = CircuitBreaker::builder("test").failure_threshold(2).build();
        let _ = circuit.call(|| Err::<(), _>("error 1"));
        let _ = circuit.call(|| Err::<(), _>("error 2"));
        assert!(circuit.is_open());
        let result = circuit.call((
            || Err::<String, _>("should not execute"),
            CallOptions::new().with_fallback(|ctx| {
                assert_eq!(ctx.circuit_name, "test");
                assert_eq!(ctx.state, "Open");
                Ok("fallback response".to_string())
            }),
        ));
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "fallback response");
    }

    // A failing fallback surfaces as `CircuitError::Execution`.
    #[test]
    fn test_fallback_error_propagation() {
        let mut circuit = CircuitBreaker::builder("test").failure_threshold(1).build();
        let _ = circuit.call(|| Err::<(), _>("error"));
        assert!(circuit.is_open());
        let result = circuit.call((
            || Ok::<String, _>("should not execute".to_string()),
            CallOptions::new().with_fallback(|_ctx| Err::<String, _>("fallback error")),
        ));
        assert!(result.is_err());
        match result {
            Err(CircuitError::Execution(e)) => assert_eq!(e, "fallback error"),
            _ => panic!("Expected CircuitError::Execution"),
        }
    }

    // Calls 0..9 give 5 successes (even i) and 4 failures (odd i), so the rate
    // check is skipped below 10 calls. The 10th call (a failure) makes the rate
    // 5/10 = 0.5, which trips.
    #[test]
    fn test_rate_based_threshold() {
        let mut circuit = CircuitBreaker::builder("test")
            .disable_failure_threshold()
            .failure_rate(0.5)
            .minimum_calls(10)
            .build();
        for i in 0..9 {
            let _result = if i % 2 == 0 {
                circuit.call(|| Ok::<(), _>(()))
            } else {
                circuit.call(|| Err::<(), _>("error"))
            };
            assert!(circuit.is_closed(), "Circuit opened before minimum calls");
        }
        let _ = circuit.call(|| Err::<(), _>("error"));
        assert!(circuit.is_open(), "Circuit did not open at rate threshold");
    }

    // The absolute threshold (3) trips even though the rate check still needs 10 calls.
    #[test]
    fn test_rate_and_absolute_threshold_both_active() {
        let mut circuit = CircuitBreaker::builder("test")
            .failure_threshold(3)
            .failure_rate(0.8)
            .minimum_calls(10)
            .build();
        let _ = circuit.call(|| Err::<(), _>("error 1"));
        let _ = circuit.call(|| Err::<(), _>("error 2"));
        assert!(circuit.is_closed());
        let _ = circuit.call(|| Err::<(), _>("error 3"));
        assert!(
            circuit.is_open(),
            "Circuit did not open at absolute threshold"
        );
    }

    // A 100% failure rate does not trip before `minimum_calls` is reached.
    #[test]
    fn test_minimum_calls_prevents_premature_trip() {
        let mut circuit = CircuitBreaker::builder("test")
            .disable_failure_threshold()
            .failure_rate(0.5)
            .minimum_calls(20)
            .build();
        for _ in 0..10 {
            let _ = circuit.call(|| Err::<(), _>("error"));
        }
        assert!(
            circuit.is_closed(),
            "Circuit opened before reaching minimum_calls"
        );
    }

    // Only `&str` errors containing "server" count; a non-`&str` error would
    // count too (`unwrap_or(true)`).
    #[test]
    fn test_failure_classifier_filters_errors() {
        use crate::classifier::PredicateClassifier;
        let classifier = Arc::new(PredicateClassifier::new(|ctx| {
            ctx.error
                .downcast_ref::<&str>()
                .map(|e| e.contains("server"))
                .unwrap_or(true)
        }));
        let mut circuit = CircuitBreaker::builder("test")
            .failure_threshold(2)
            .failure_classifier(classifier)
            .build();
        for _ in 0..5 {
            let _ = circuit.call(|| Err::<(), _>("client_error"));
        }
        assert!(
            circuit.is_closed(),
            "Circuit should not trip on filtered errors"
        );
        let _ = circuit.call(|| Err::<(), _>("server_error_1"));
        let _ = circuit.call(|| Err::<(), _>("server_error_2"));
        assert!(circuit.is_open(), "Circuit should trip on server errors");
    }

    // Errors faster than 0.5 s are ignored by a duration-based classifier.
    #[test]
    fn test_failure_classifier_with_slow_errors() {
        use crate::classifier::PredicateClassifier;
        let classifier = Arc::new(PredicateClassifier::new(|ctx| ctx.duration > 0.5));
        let mut circuit = CircuitBreaker::builder("test")
            .failure_threshold(2)
            .failure_classifier(classifier)
            .build();
        for _ in 0..10 {
            let _ = circuit.call(|| Err::<(), _>("fast error"));
        }
        assert!(
            circuit.is_closed(),
            "Circuit should not trip on fast errors"
        );
    }

    // Without a classifier, every error counts.
    #[test]
    fn test_no_classifier_default_behavior() {
        let mut circuit = CircuitBreaker::builder("test").failure_threshold(3).build();
        let _ = circuit.call(|| Err::<(), _>("error 1"));
        let _ = circuit.call(|| Err::<(), _>("error 2"));
        assert!(circuit.is_closed());
        let _ = circuit.call(|| Err::<(), _>("error 3"));
        assert!(
            circuit.is_open(),
            "All errors should trip circuit without classifier"
        );
    }

    // The classifier can downcast a user-defined error enum. Any variant with a
    // code >= 500 counts, so `ClientError(404)` is ignored.
    #[test]
    fn test_classifier_with_custom_error_type() {
        use crate::classifier::PredicateClassifier;
        #[derive(Debug)]
        enum ApiError {
            ClientError(u16),
            ServerError(u16),
        }
        let classifier = Arc::new(PredicateClassifier::new(|ctx| {
            ctx.error
                .downcast_ref::<ApiError>()
                .map(|e| match e {
                    ApiError::ServerError(code) => *code >= 500,
                    ApiError::ClientError(code) => *code >= 500,
                })
                .unwrap_or(true)
        }));
        let mut circuit = CircuitBreaker::builder("test")
            .failure_threshold(2)
            .failure_classifier(classifier)
            .build();
        for _ in 0..10 {
            let _ = circuit.call(|| Err::<(), _>(ApiError::ClientError(404)));
        }
        assert!(circuit.is_closed(), "Client errors should not trip circuit");
        let _ = circuit.call(|| Err::<(), _>(ApiError::ServerError(500)));
        let _ = circuit.call(|| Err::<(), _>(ApiError::ServerError(503)));
        assert!(circuit.is_open(), "Server errors should trip circuit");
    }

    // NOTE(review): despite the name, these calls run sequentially and each
    // releases its permit before the next, so capacity is never reached. Actual
    // rejection is covered by `test_bulkhead_error_contains_limit`.
    #[test]
    fn test_bulkhead_rejects_at_capacity() {
        let mut circuit = CircuitBreaker::builder("test").max_concurrency(2).build();
        let result1 = circuit.call(|| Ok::<_, String>("success 1"));
        let result2 = circuit.call(|| Ok::<_, String>("success 2"));
        assert!(result1.is_ok());
        assert!(result2.is_ok());
    }

    // With a single permit, a second call succeeds only if the first released it.
    #[test]
    fn test_bulkhead_releases_on_success() {
        use std::sync::{Arc, Mutex};
        let circuit = Arc::new(Mutex::new(
            CircuitBreaker::builder("test").max_concurrency(1).build(),
        ));
        let result1 = circuit.lock().unwrap().call(|| Ok::<_, String>("success"));
        assert!(result1.is_ok());
        let result2 = circuit.lock().unwrap().call(|| Ok::<_, String>("success"));
        assert!(result2.is_ok());
    }

    // The permit is also released when the operation fails.
    #[test]
    fn test_bulkhead_releases_on_failure() {
        use std::sync::{Arc, Mutex};
        let circuit = Arc::new(Mutex::new(
            CircuitBreaker::builder("test")
                .max_concurrency(1)
                .failure_threshold(10)
                .build(),
        ));
        let result1 = circuit.lock().unwrap().call(|| Err::<(), _>("error"));
        assert!(result1.is_err());
        let result2 = circuit.lock().unwrap().call(|| Ok::<_, String>("success"));
        assert!(result2.is_ok());
    }

    // Without `max_concurrency`, calls are never rejected by a bulkhead.
    #[test]
    fn test_bulkhead_without_limit() {
        let mut circuit = CircuitBreaker::builder("test").build();
        for _ in 0..100 {
            let result = circuit.call(|| Ok::<_, String>("success"));
            assert!(result.is_ok());
        }
    }

    // Exhaust an externally held bulkhead, check the error payload, then
    // release the permits and confirm calls succeed again.
    #[test]
    fn test_bulkhead_error_contains_limit() {
        use std::sync::Arc;
        let bulkhead = Arc::new(BulkheadSemaphore::new(2));
        let mut circuit = CircuitBreaker::builder("test").build();
        circuit.context.bulkhead = Some(bulkhead.clone());
        let _guard1 = bulkhead.try_acquire().unwrap();
        let _guard2 = bulkhead.try_acquire().unwrap();
        let result = circuit.call(|| Ok::<_, String>("should fail"));
        match result {
            Err(CircuitError::BulkheadFull {
                circuit: name,
                limit,
            }) => {
                assert_eq!(name, "test");
                assert_eq!(limit, 2);
            }
            _ => panic!("Expected BulkheadFull error, got: {:?}", result),
        }
        drop(_guard1);
        drop(_guard2);
        let result = circuit.call(|| Ok::<_, String>("success"));
        assert!(result.is_ok());
    }

    // With a bulkhead that has free capacity, an Open circuit still returns
    // `CircuitError::Open`.
    #[test]
    fn test_bulkhead_with_circuit_breaker() {
        let mut circuit = CircuitBreaker::builder("test")
            .max_concurrency(5)
            .failure_threshold(3)
            .build();
        let result = circuit.call(|| Ok::<_, String>("success"));
        assert!(result.is_ok());
        for _ in 0..3 {
            let _ = circuit.call(|| Err::<(), _>("error"));
        }
        assert!(circuit.is_open());
        let result = circuit.call(|| Ok::<_, String>("should fail"));
        assert!(matches!(result, Err(CircuitError::Open { .. })));
    }

    // The manual `record_failure` + `check_and_trip` path stamps `opened_at`
    // and fires `on_open`.
    #[test]
    fn test_check_and_trip_sets_opened_at_and_callback() {
        use std::sync::atomic::{AtomicBool, Ordering};
        let opened = Arc::new(AtomicBool::new(false));
        let opened_clone = opened.clone();
        let mut circuit = CircuitBreaker::builder("test")
            .failure_threshold(1)
            .on_open(move |_name| {
                opened_clone.store(true, Ordering::SeqCst);
            })
            .build();
        circuit.record_failure(0.1);
        let tripped = circuit.check_and_trip();
        assert!(tripped, "Trip should succeed");
        assert!(circuit.is_open(), "Circuit should be open after trip");
        let opened_at = circuit
            .machine
            .open_data()
            .expect("Open data should be present")
            .opened_at;
        assert!(opened_at > 0.0, "opened_at should be set");
        assert!(
            opened.load(Ordering::SeqCst),
            "on_open callback should fire"
        );
    }

    // A counted failure in HalfOpen that does not re-trip resets the success
    // streak. Storage is cleared first so the windowed failure count stays
    // below threshold and the failure does not send the circuit back to Open.
    #[test]
    fn test_half_open_failure_resets_consecutive_successes() {
        let mut circuit = CircuitBreaker::builder("test")
            .failure_threshold(2)
            .half_open_timeout_secs(0.001)
            .success_threshold(2)
            .build();
        let _ = circuit.call(|| Err::<(), _>("error 1"));
        let _ = circuit.call(|| Err::<(), _>("error 2"));
        assert!(circuit.is_open());
        if let Some(data) = circuit.machine.open_data_mut() {
            data.opened_at = circuit.context.storage.monotonic_time();
        }
        std::thread::sleep(std::time::Duration::from_millis(2));
        circuit
            .machine
            .handle(CircuitEvent::AttemptReset)
            .expect("Should transition to HalfOpen");
        assert_eq!(circuit.machine.current_state(), "HalfOpen");
        circuit.context.storage.clear("test");
        let _ = circuit.call(|| Ok::<_, String>("ok"));
        assert_eq!(
            circuit
                .machine
                .half_open_data()
                .expect("HalfOpen data")
                .consecutive_successes,
            1
        );
        let _ = circuit.call(|| Err::<(), _>("fail"));
        assert_eq!(circuit.machine.current_state(), "HalfOpen");
        assert_eq!(
            circuit
                .machine
                .half_open_data()
                .expect("HalfOpen data")
                .consecutive_successes,
            0
        );
        let _ = circuit.call(|| Ok::<_, String>("ok2"));
        assert_eq!(
            circuit
                .machine
                .half_open_data()
                .expect("HalfOpen data")
                .consecutive_successes,
            1
        );
    }

    // Checks that `chrono_machines::Policy::calculate_delay` stays within
    // [base * (1 - jitter), base].
    // NOTE(review): the circuits built in the loop are not used by the bound
    // check; the bounds come solely from calling `Policy` directly.
    #[test]
    fn test_jitter_distribution_within_bounds() {
        let storage = Arc::new(crate::MemoryStorage::new());
        let base_timeout = 1.0;
        let jitter_factor = 0.25;
        let config = Config {
            failure_threshold: Some(1),
            half_open_timeout_secs: base_timeout,
            jitter_factor,
            ..Default::default()
        };
        let ctx = CircuitContext {
            failure_classifier: None,
            bulkhead: None,
            name: "jitter_test".to_string(),
            config,
            storage: storage.clone(),
        };
        let mut min_seen = f64::MAX;
        let mut max_seen = f64::MIN;
        for _ in 0..50 {
            storage.record_failure("jitter_test", 0.1);
            let mut circuit = DynamicCircuit::new(ctx.clone());
            circuit.handle(CircuitEvent::Trip).expect("Should open");
            if let Some(data) = circuit.open_data_mut() {
                data.opened_at = storage.monotonic_time();
            }
            let policy = chrono_machines::Policy {
                max_attempts: 1,
                base_delay_ms: (base_timeout * 1000.0) as u64,
                multiplier: 1.0,
                max_delay_ms: (base_timeout * 1000.0) as u64,
            };
            let timeout_ms = policy.calculate_delay(1, jitter_factor);
            let timeout_secs = (timeout_ms as f64) / 1000.0;
            min_seen = min_seen.min(timeout_secs);
            max_seen = max_seen.max(timeout_secs);
            storage.clear("jitter_test");
        }
        let min_expected = base_timeout * (1.0 - jitter_factor);
        let max_expected = base_timeout;
        assert!(
            min_seen >= min_expected - 0.01,
            "Minimum jittered timeout {} should be >= {}",
            min_seen,
            min_expected
        );
        assert!(
            max_seen <= max_expected + 0.01,
            "Maximum jittered timeout {} should be <= {}",
            max_seen,
            max_expected
        );
    }

    // `chrono_machines` jitter should produce at least two distinct delays in 20 draws.
    // NOTE(review): `_ctx` is unused; only `Policy` is exercised here.
    #[test]
    fn test_jitter_produces_variance() {
        let storage = Arc::new(crate::MemoryStorage::new());
        let config = Config {
            failure_threshold: Some(1),
            half_open_timeout_secs: 1.0,
            jitter_factor: 0.5,
            ..Default::default()
        };
        let _ctx = CircuitContext {
            failure_classifier: None,
            bulkhead: None,
            name: "jitter_variance".to_string(),
            config,
            storage: storage.clone(),
        };
        let mut values = std::collections::HashSet::new();
        for _ in 0..20 {
            let policy = chrono_machines::Policy {
                max_attempts: 1,
                base_delay_ms: 1000,
                multiplier: 1.0,
                max_delay_ms: 1000,
            };
            let timeout_ms = policy.calculate_delay(1, 0.5);
            values.insert(timeout_ms);
        }
        assert!(
            values.len() >= 2,
            "Jitter should produce variance, got {} unique values",
            values.len()
        );
    }

    // With zero jitter, `Policy::calculate_delay` returns exactly the base delay.
    #[test]
    fn test_zero_jitter_produces_constant_timeout() {
        let policy = chrono_machines::Policy {
            max_attempts: 1,
            base_delay_ms: 1000,
            multiplier: 1.0,
            max_delay_ms: 1000,
        };
        let mut values = std::collections::HashSet::new();
        for _ in 0..10 {
            let timeout_ms = policy.calculate_delay(1, 0.0);
            values.insert(timeout_ms);
        }
        assert_eq!(
            values.len(),
            1,
            "Zero jitter should produce constant timeout"
        );
        assert!(values.contains(&1000), "Timeout should be exactly 1000ms");
    }
}