use crate::error::CanoError;
use crate::resource::Resource;
use cano_macros::resource;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
#[cfg(feature = "tracing")]
use tracing::info;
/// Tunable thresholds that decide when a circuit breaker trips and how it recovers.
#[derive(Debug, Clone)]
pub struct CircuitPolicy {
    /// Number of consecutive failures, recorded while closed, that opens the circuit.
    pub failure_threshold: u32,
    /// How long the circuit stays open before a half-open trial may be admitted.
    pub reset_timeout: Duration,
    /// Cap on concurrently in-flight half-open trials; also the number of trial
    /// successes needed before the circuit closes again.
    pub half_open_max_calls: u32,
}

impl Default for CircuitPolicy {
    /// Trips after five consecutive failures, stays open for 30 seconds and
    /// admits a single half-open trial.
    fn default() -> Self {
        let reset_timeout = Duration::from_secs(30);
        Self {
            failure_threshold: 5,
            half_open_max_calls: 1,
            reset_timeout,
        }
    }
}
/// Observable state of a circuit breaker.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
/// Calls are admitted; consecutive failures are counted toward the trip threshold.
Closed,
/// Calls are rejected until `until` has passed.
Open {
/// Earliest instant at which an acquisition attempt moves the breaker to `HalfOpen`.
until: Instant,
},
/// A bounded number of trial calls is admitted to probe whether the dependency recovered.
HalfOpen,
}
/// Mutable bookkeeping of a breaker; `CircuitBreaker` keeps it behind a mutex.
#[derive(Debug)]
struct Inner {
/// Current position in the Closed / Open / HalfOpen cycle.
state: CircuitState,
/// Failures recorded in a row while closed; zeroed by a success while closed
/// and when a half-open phase closes the circuit.
consecutive_failures: u32,
/// Half-open permits handed out whose outcome has not been recorded yet.
half_open_in_flight: u32,
/// Successful trials recorded during the current half-open phase.
half_open_successes: u32,
/// Generation counter, incremented (wrapping) on every state transition;
/// outcomes from permits carrying a different epoch are ignored.
epoch: u64,
}
impl Inner {
fn new() -> Self {
Self {
state: CircuitState::Closed,
consecutive_failures: 0,
half_open_in_flight: 0,
half_open_successes: 0,
epoch: 0,
}
}
fn reset_half_open(&mut self) {
self.half_open_in_flight = 0;
self.half_open_successes = 0;
}
fn bump_epoch(&mut self) {
self.epoch = self.epoch.wrapping_add(1);
}
fn transition(&mut self, new_state: CircuitState) {
self.state = new_state;
self.reset_half_open();
self.bump_epoch();
}
}
/// Circuit breaker handle.
///
/// Cloning copies the policy and shares the same underlying state, because the
/// state lives behind an `Arc`.
#[derive(Debug, Clone)]
pub struct CircuitBreaker {
/// Shared, mutex-protected state machine.
inner: Arc<Mutex<Inner>>,
/// Thresholds supplied at construction time.
policy: CircuitPolicy,
}
/// Largest `CircuitPolicy::half_open_max_calls` accepted by `CircuitBreaker::new`.
const HALF_OPEN_MAX_CALLS_LIMIT: u32 = u32::MAX / 2;
impl CircuitBreaker {
    /// Builds a breaker that starts in [`CircuitState::Closed`] and is governed by `policy`.
    ///
    /// # Panics
    ///
    /// Panics if `policy.failure_threshold` is 0, if `policy.half_open_max_calls` is 0,
    /// or if `policy.half_open_max_calls` exceeds `HALF_OPEN_MAX_CALLS_LIMIT`.
    pub fn new(policy: CircuitPolicy) -> Self {
        assert!(
            policy.failure_threshold >= 1,
            "CircuitPolicy::failure_threshold must be >= 1; 0 would make the breaker trip semantics nonsensical"
        );
        assert!(
            policy.half_open_max_calls >= 1,
            "CircuitPolicy::half_open_max_calls must be >= 1; 0 deadlocks the breaker in HalfOpen (no trial would ever be admitted)"
        );
        assert!(
            policy.half_open_max_calls <= HALF_OPEN_MAX_CALLS_LIMIT,
            "CircuitPolicy::half_open_max_calls must be <= {HALF_OPEN_MAX_CALLS_LIMIT}; got {} (values near u32::MAX may prevent the breaker from ever closing)",
            policy.half_open_max_calls
        );
        Self {
            inner: Arc::new(Mutex::new(Inner::new())),
            policy,
        }
    }

    /// Returns the policy this breaker was constructed with.
    pub fn policy(&self) -> &CircuitPolicy {
        &self.policy
    }

    /// Returns a snapshot of the current state.
    ///
    /// The snapshot is not advanced by time: an `Open` state whose deadline has
    /// already passed is still reported as `Open` until the next
    /// [`try_acquire`](Self::try_acquire) call performs the move to `HalfOpen`.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn state(&self) -> CircuitState {
        let inner = self.inner.lock().expect("circuit breaker mutex poisoned");
        inner.state
    }

    /// Asks for permission to perform one protected call.
    ///
    /// * `Closed`: always grants a permit.
    /// * `Open`: rejects the call, unless the open deadline has passed, in which
    ///   case the breaker first moves to `HalfOpen` and the rules below apply.
    /// * `HalfOpen`: grants a trial permit while fewer than
    ///   `half_open_max_calls` trials are in flight, otherwise rejects.
    ///
    /// # Errors
    ///
    /// Returns `CanoError::circuit_open` when the call is rejected.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn try_acquire(self: &Arc<Self>) -> Result<Permit, CanoError> {
        let mut inner = self.inner.lock().expect("circuit breaker mutex poisoned");
        if let CircuitState::Open { until } = inner.state
            && Instant::now() >= until
        {
            inner.transition(CircuitState::HalfOpen);
            #[cfg(feature = "tracing")]
            info!(
                half_open_max_calls = self.policy.half_open_max_calls,
                "circuit breaker transition: Open -> HalfOpen"
            );
        }
        // Read the epoch after the possible transition so the permit belongs to
        // the state it was actually issued in.
        let epoch = inner.epoch;
        match inner.state {
            CircuitState::Closed => Ok(Permit::new(Arc::clone(self), false, epoch)),
            CircuitState::Open { .. } => Err(CanoError::circuit_open(
                "circuit breaker open: rejecting call",
            )),
            CircuitState::HalfOpen => {
                if inner.half_open_in_flight >= self.policy.half_open_max_calls {
                    return Err(CanoError::circuit_open(
                        "circuit breaker half-open: trial slot exhausted",
                    ));
                }
                inner.half_open_in_flight += 1;
                Ok(Permit::new(Arc::clone(self), true, epoch))
            }
        }
    }

    /// Records that the call guarded by `permit` succeeded.
    ///
    /// Outcomes of permits issued in an earlier epoch are ignored. While closed,
    /// a success clears the consecutive-failure count. While half-open, the
    /// circuit closes once `half_open_max_calls` successes have been recorded.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn record_success(&self, mut permit: Permit) {
        // Mark the permit first so its `Drop` does not report a failure.
        permit.consumed = true;
        let mut inner = self.inner.lock().expect("circuit breaker mutex poisoned");
        if permit.epoch != inner.epoch {
            return;
        }
        if permit.was_half_open && inner.half_open_in_flight > 0 {
            inner.half_open_in_flight -= 1;
        }
        match inner.state {
            CircuitState::Closed => {
                inner.consecutive_failures = 0;
            }
            CircuitState::HalfOpen => {
                inner.half_open_successes = inner.half_open_successes.saturating_add(1);
                if inner.half_open_successes >= self.policy.half_open_max_calls {
                    inner.consecutive_failures = 0;
                    inner.transition(CircuitState::Closed);
                    #[cfg(feature = "tracing")]
                    info!(
                        successes = self.policy.half_open_max_calls,
                        "circuit breaker transition: HalfOpen -> Closed"
                    );
                }
            }
            CircuitState::Open { .. } => {
                debug_assert!(
                    false,
                    "record_success on Open with current epoch is unreachable; epoch tracking should have filtered the stale outcome"
                );
            }
        }
    }

    /// Records that the call guarded by `permit` failed.
    ///
    /// Outcomes of permits issued in an earlier epoch are ignored. While closed,
    /// the circuit opens once `failure_threshold` consecutive failures have been
    /// recorded. While half-open, a single failure reopens the circuit.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn record_failure(&self, mut permit: Permit) {
        // Mark the permit first so its `Drop` does not report the failure twice.
        permit.consumed = true;
        self.do_record_failure(permit.was_half_open, permit.epoch);
    }

    /// Failure bookkeeping shared by [`record_failure`](Self::record_failure) and
    /// the drop path of [`Permit`].
    fn do_record_failure(&self, was_half_open: bool, permit_epoch: u64) {
        let mut inner = self.inner.lock().expect("circuit breaker mutex poisoned");
        if permit_epoch != inner.epoch {
            return;
        }
        if was_half_open && inner.half_open_in_flight > 0 {
            inner.half_open_in_flight -= 1;
        }
        match inner.state {
            CircuitState::Closed => {
                inner.consecutive_failures = inner.consecutive_failures.saturating_add(1);
                if inner.consecutive_failures >= self.policy.failure_threshold {
                    inner.transition(CircuitState::Open {
                        until: self.reopen_deadline(),
                    });
                    #[cfg(feature = "tracing")]
                    info!(
                        failure_threshold = self.policy.failure_threshold,
                        // Saturate instead of truncating the u128 millisecond count.
                        reset_timeout_ms =
                            u64::try_from(self.policy.reset_timeout.as_millis()).unwrap_or(u64::MAX),
                        "circuit breaker transition: Closed -> Open"
                    );
                }
            }
            CircuitState::HalfOpen => {
                inner.transition(CircuitState::Open {
                    until: self.reopen_deadline(),
                });
                #[cfg(feature = "tracing")]
                info!(
                    // Saturate instead of truncating the u128 millisecond count.
                    reset_timeout_ms =
                        u64::try_from(self.policy.reset_timeout.as_millis()).unwrap_or(u64::MAX),
                    "circuit breaker transition: HalfOpen -> Open"
                );
            }
            CircuitState::Open { .. } => {
                debug_assert!(
                    false,
                    "record_failure on Open with current epoch is unreachable; epoch tracking should have filtered the stale outcome"
                );
            }
        }
    }

    /// Computes the instant until which a freshly opened circuit stays open.
    ///
    /// `Instant + Duration` panics when the sum is not representable (for
    /// example with `reset_timeout = Duration::MAX`). That panic would happen
    /// while the state mutex is held, poisoning the breaker for every later
    /// caller. Instead, the timeout is clamped to an offset that is
    /// representable.
    fn reopen_deadline(&self) -> Instant {
        let now = Instant::now();
        let mut timeout = self.policy.reset_timeout;
        loop {
            if let Some(deadline) = now.checked_add(timeout) {
                return deadline;
            }
            // Halving terminates: the timeout eventually reaches zero, and
            // `now + 0` is always representable.
            timeout /= 2;
        }
    }
}
// Registers `CircuitBreaker` as a `Resource` without overriding any trait item.
// NOTE(review): what `#[resource]` expands to is defined in `cano_macros` and is
// not visible in this file.
#[resource]
impl Resource for CircuitBreaker {}
/// Admission token returned by `CircuitBreaker::try_acquire`.
///
/// Pass it to `record_success` or `record_failure` to report the outcome of the
/// guarded call. Dropping it without doing so is recorded as a failure.
#[must_use = "drop a Permit only as a deliberate failure signal; pass it to record_success or record_failure to indicate the call outcome"]
pub struct Permit {
/// Breaker that issued the permit; the drop path reports to it.
breaker: Arc<CircuitBreaker>,
/// True when the permit was issued while the breaker was half-open.
was_half_open: bool,
/// Set once an outcome has been recorded explicitly; disables failure-on-drop.
consumed: bool,
/// Breaker epoch at the time the permit was issued.
epoch: u64,
}
impl Permit {
fn new(breaker: Arc<CircuitBreaker>, was_half_open: bool, epoch: u64) -> Self {
Self {
breaker,
was_half_open,
consumed: false,
epoch,
}
}
}
impl std::fmt::Debug for Permit {
    /// Prints the permit's own flags and epoch; the owning breaker is left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            breaker: _,
            was_half_open,
            consumed,
            epoch,
        } = self;
        let mut out = f.debug_struct("Permit");
        out.field("was_half_open", was_half_open);
        out.field("consumed", consumed);
        out.field("epoch", epoch);
        out.finish()
    }
}
impl Drop for Permit {
    /// Reports a failure to the issuing breaker when the permit is dropped
    /// without an explicitly recorded outcome.
    fn drop(&mut self) {
        if self.consumed {
            // Outcome was already recorded; nothing left to report.
            return;
        }
        let (trial, issued_epoch) = (self.was_half_open, self.epoch);
        self.breaker.do_record_failure(trial, issued_epoch);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sleep used by the async tests; longer than the 20 ms reset timeouts below.
    const COOLDOWN: Duration = Duration::from_millis(40);

    /// Policy with short timings: 3 failures to trip, 20 ms open, 2 half-open trials.
    fn fast_policy() -> CircuitPolicy {
        CircuitPolicy {
            failure_threshold: 3,
            reset_timeout: Duration::from_millis(20),
            half_open_max_calls: 2,
        }
    }

    /// Policy with 2 failures to trip, 20 ms open and a single half-open trial.
    fn single_trial_policy() -> CircuitPolicy {
        CircuitPolicy {
            failure_threshold: 2,
            reset_timeout: Duration::from_millis(20),
            half_open_max_calls: 1,
        }
    }

    /// Builds a shareable breaker from `policy`.
    fn shared(policy: CircuitPolicy) -> Arc<CircuitBreaker> {
        Arc::new(CircuitBreaker::new(policy))
    }

    /// Acquires a permit and records a failure, `count` times in a row.
    fn fail_n(breaker: &Arc<CircuitBreaker>, count: usize) {
        for _ in 0..count {
            let permit = breaker.try_acquire().unwrap();
            breaker.record_failure(permit);
        }
    }

    /// True when the breaker currently reports an `Open` state.
    fn is_open(breaker: &CircuitBreaker) -> bool {
        matches!(breaker.state(), CircuitState::Open { .. })
    }

    #[test]
    fn closed_initial_state() {
        let breaker = shared(fast_policy());
        assert_eq!(breaker.state(), CircuitState::Closed);
    }

    #[test]
    fn opens_after_threshold_failures() {
        let breaker = shared(fast_policy());
        fail_n(&breaker, 3);
        assert!(is_open(&breaker));
        let err = breaker.try_acquire().unwrap_err();
        assert_eq!(err.category(), "circuit_open");
    }

    #[test]
    fn success_resets_consecutive_failures() {
        let breaker = shared(fast_policy());
        fail_n(&breaker, 2);
        let ok = breaker.try_acquire().unwrap();
        breaker.record_success(ok);
        // Two more failures stay below the threshold because the count was reset.
        fail_n(&breaker, 2);
        assert_eq!(breaker.state(), CircuitState::Closed);
    }

    #[tokio::test]
    async fn open_to_half_open_after_reset_timeout() {
        let breaker = shared(fast_policy());
        fail_n(&breaker, 3);
        assert!(is_open(&breaker));
        tokio::time::sleep(COOLDOWN).await;
        let trial = breaker.try_acquire().unwrap();
        assert_eq!(breaker.state(), CircuitState::HalfOpen);
        breaker.record_success(trial);
    }

    #[tokio::test]
    async fn half_open_full_success_closes() {
        let breaker = shared(fast_policy());
        fail_n(&breaker, 3);
        tokio::time::sleep(COOLDOWN).await;
        for _ in 0..2 {
            let trial = breaker.try_acquire().unwrap();
            breaker.record_success(trial);
        }
        assert_eq!(breaker.state(), CircuitState::Closed);
    }

    #[tokio::test]
    async fn half_open_failure_reopens() {
        let breaker = shared(fast_policy());
        fail_n(&breaker, 3);
        tokio::time::sleep(COOLDOWN).await;
        let trial = breaker.try_acquire().unwrap();
        assert_eq!(breaker.state(), CircuitState::HalfOpen);
        breaker.record_failure(trial);
        assert!(is_open(&breaker));
    }

    #[tokio::test]
    async fn half_open_caps_concurrent_trials() {
        let breaker = shared(fast_policy());
        fail_n(&breaker, 3);
        tokio::time::sleep(COOLDOWN).await;
        let first = breaker.try_acquire().unwrap();
        let second = breaker.try_acquire().unwrap();
        // Both trial slots are taken, so a third attempt is rejected.
        let err = breaker.try_acquire().unwrap_err();
        assert_eq!(err.category(), "circuit_open");
        breaker.record_success(first);
        breaker.record_success(second);
    }

    #[test]
    fn dropped_permit_counts_as_failure() {
        let breaker = shared(fast_policy());
        for _ in 0..3 {
            drop(breaker.try_acquire().unwrap());
        }
        assert!(is_open(&breaker));
    }

    #[test]
    fn shared_breaker_trips_for_all_callers() {
        let breaker = shared(fast_policy());
        let breaker_a = Arc::clone(&breaker);
        let breaker_b = Arc::clone(&breaker);
        fail_n(&breaker_a, 3);
        let err = breaker_b.try_acquire().unwrap_err();
        assert_eq!(err.category(), "circuit_open");
    }

    #[test]
    #[should_panic(expected = "half_open_max_calls must be >= 1")]
    fn rejects_zero_half_open_max_calls() {
        let policy = CircuitPolicy {
            failure_threshold: 1,
            reset_timeout: Duration::from_millis(1),
            half_open_max_calls: 0,
        };
        let _ = CircuitBreaker::new(policy);
    }

    #[test]
    #[should_panic(expected = "failure_threshold must be >= 1")]
    fn rejects_zero_failure_threshold() {
        let policy = CircuitPolicy {
            failure_threshold: 0,
            reset_timeout: Duration::from_millis(1),
            half_open_max_calls: 1,
        };
        let _ = CircuitBreaker::new(policy);
    }

    #[tokio::test]
    async fn stale_closed_permit_does_not_close_a_later_half_open() {
        let breaker = shared(single_trial_policy());
        // Issued while closed; it becomes stale once the breaker trips.
        let stale_permit = breaker.try_acquire().unwrap();
        assert_eq!(breaker.state(), CircuitState::Closed);
        fail_n(&breaker, 2);
        assert!(is_open(&breaker));
        tokio::time::sleep(COOLDOWN).await;
        let probe_permit = breaker.try_acquire().unwrap();
        assert_eq!(breaker.state(), CircuitState::HalfOpen);
        breaker.record_success(stale_permit);
        assert_eq!(
            breaker.state(),
            CircuitState::HalfOpen,
            "stale Closed-epoch success must not close a later HalfOpen"
        );
        breaker.record_success(probe_permit);
        assert_eq!(breaker.state(), CircuitState::Closed);
    }

    #[test]
    fn open_state_is_debug_loggable() {
        let breaker = shared(fast_policy());
        fail_n(&breaker, 3);
        let dbg = format!("{:?}", breaker.state());
        assert!(
            dbg.starts_with("Open { until:"),
            "expected Debug to start with `Open {{ until:`, got: {dbg}"
        );
    }

    #[test]
    #[should_panic(expected = "must be <=")]
    fn rejects_pathological_half_open_max_calls() {
        let policy = CircuitPolicy {
            failure_threshold: 1,
            reset_timeout: Duration::from_millis(1),
            half_open_max_calls: u32::MAX,
        };
        let _ = CircuitBreaker::new(policy);
    }

    #[tokio::test]
    async fn open_half_open_open_round_trip() {
        let breaker = shared(single_trial_policy());

        // Closed -> Open.
        fail_n(&breaker, 2);
        assert!(is_open(&breaker));

        // Open -> HalfOpen -> Open (failed trial).
        tokio::time::sleep(COOLDOWN).await;
        let failed_trial = breaker.try_acquire().unwrap();
        assert_eq!(breaker.state(), CircuitState::HalfOpen);
        breaker.record_failure(failed_trial);
        assert!(is_open(&breaker));

        // Open -> HalfOpen -> Closed (successful trial).
        tokio::time::sleep(COOLDOWN).await;
        let good_trial = breaker.try_acquire().unwrap();
        assert_eq!(breaker.state(), CircuitState::HalfOpen);
        breaker.record_success(good_trial);
        assert_eq!(breaker.state(), CircuitState::Closed);

        // Closed -> Open again after the threshold is reached anew.
        fail_n(&breaker, 2);
        assert!(is_open(&breaker));
    }
}