use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use std::time::{Duration, Instant};
use async_trait::async_trait;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use super::base::Middleware;
use super::context_namespace::{enforce_context_key, namespace_keys::CIRCUIT_STATE, ContextWriter};
use crate::context::Context;
use crate::errors::ModuleError;
use crate::events::emitter::{ApCoreEvent, EventEmitter};
/// Default error-rate fraction at or above which a closed circuit trips open.
pub const DEFAULT_OPEN_THRESHOLD: f64 = 0.5;
/// Default number of most-recent call outcomes retained per circuit.
pub const DEFAULT_WINDOW_SIZE: usize = 20;
/// Default time, in milliseconds, an open circuit waits before a half-open probe is allowed.
pub const DEFAULT_RECOVERY_WINDOW_MS: u64 = 30 * 1_000;
/// Default minimum number of recorded outcomes before the error rate may trip the circuit.
pub const DEFAULT_MIN_SAMPLES: usize = 5;
/// Lifecycle state of a single `(module_id, caller_id)` circuit.
///
/// Serialized in SCREAMING_SNAKE_CASE (`"CLOSED"`, `"OPEN"`, `"HALF_OPEN"`), which matches
/// the strings returned by `CircuitBreakerState::as_str`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum CircuitBreakerState {
    /// Calls are admitted and their outcomes are recorded in the sliding window.
    Closed,
    /// Calls are rejected until the recovery window has elapsed since the circuit opened.
    Open,
    /// A single probe call is admitted; a success closes the circuit, a failure reopens it.
    HalfOpen,
}
impl CircuitBreakerState {
    /// Canonical upper-case name of the state.
    ///
    /// These are the same strings the enum serializes to. The middleware stores them in the
    /// request context.
    #[must_use]
    pub fn as_str(self) -> &'static str {
        match self {
            Self::HalfOpen => "HALF_OPEN",
            Self::Open => "OPEN",
            Self::Closed => "CLOSED",
        }
    }
}
/// Bookkeeping for one `(module_id, caller_id)` circuit.
#[derive(Debug)]
struct Circuit {
    /// Current lifecycle state.
    state: CircuitBreakerState,
    /// Most recent outcomes, oldest first; `true` is a success, `false` an error.
    window: VecDeque<bool>,
    /// When the circuit last tripped open.
    opened_at: Option<Instant>,
    /// Whether a half-open probe has been admitted and not yet resolved.
    probe_in_flight: bool,
    /// When the current half-open probe was admitted.
    probe_started_at: Option<Instant>,
}

impl Circuit {
    /// A closed circuit with an empty window and no probe reservation.
    fn new() -> Self {
        Self {
            probe_started_at: None,
            probe_in_flight: false,
            opened_at: None,
            window: VecDeque::new(),
            state: CircuitBreakerState::Closed,
        }
    }

    /// Appends one outcome, first evicting the oldest entries so that at most
    /// `window_size` (never less than 1) outcomes are kept.
    fn push(&mut self, success: bool, window_size: usize) {
        let limit = window_size.max(1);
        // Number of old samples to drop so that, after the push, len == min(len + 1, limit).
        let excess = (self.window.len() + 1).saturating_sub(limit);
        self.window.drain(..excess);
        self.window.push_back(success);
    }

    /// Fraction of retained outcomes that were errors, in `[0.0, 1.0]`.
    /// Returns `0.0` for an empty window.
    fn error_rate(&self) -> f64 {
        match self.window.len() {
            0 => 0.0,
            total => {
                let failures = self.window.iter().filter(|&&ok| !ok).count();
                // Window lengths are small; the usize -> f64 conversion is exact in practice.
                #[allow(clippy::cast_precision_loss)]
                let rate = failures as f64 / total as f64;
                rate
            }
        }
    }
}
/// Tunables for the circuit-breaker middleware.
#[derive(Debug, Clone, Copy)]
pub struct CircuitBreakerConfig {
    /// Error-rate fraction at or above which a closed circuit opens.
    pub open_threshold: f64,
    /// Number of most-recent outcomes retained per circuit.
    pub window_size: usize,
    /// Milliseconds an open circuit waits before admitting a half-open probe.
    pub recovery_window_ms: u64,
    /// Minimum retained outcomes before the error rate can open the circuit.
    pub min_samples: usize,
    /// Value reported as the middleware's priority.
    pub priority: u16,
}

/// Default middleware priority for the circuit breaker.
pub const DEFAULT_PRIORITY: u16 = 900;

impl Default for CircuitBreakerConfig {
    /// A config populated entirely from the `DEFAULT_*` constants.
    fn default() -> Self {
        Self {
            priority: DEFAULT_PRIORITY,
            min_samples: DEFAULT_MIN_SAMPLES,
            recovery_window_ms: DEFAULT_RECOVERY_WINDOW_MS,
            window_size: DEFAULT_WINDOW_SIZE,
            open_threshold: DEFAULT_OPEN_THRESHOLD,
        }
    }
}
#[derive(Default)]
pub struct CircuitBreakerBuilder {
config: CircuitBreakerConfig,
emitter: Option<Arc<EventEmitter>>,
clock: Option<Box<dyn Fn() -> Instant + Send + Sync>>,
}
impl std::fmt::Debug for CircuitBreakerBuilder {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("CircuitBreakerBuilder")
.field("config", &self.config)
.field("has_emitter", &self.emitter.is_some())
.field("has_clock", &self.clock.is_some())
.finish()
}
}
impl CircuitBreakerBuilder {
#[must_use]
pub fn open_threshold(mut self, value: f64) -> Self {
self.config.open_threshold = value;
self
}
#[must_use]
pub fn window_size(mut self, value: usize) -> Self {
self.config.window_size = value.max(1);
self
}
#[must_use]
pub fn recovery_window_ms(mut self, value: u64) -> Self {
self.config.recovery_window_ms = value;
self
}
#[must_use]
pub fn min_samples(mut self, value: usize) -> Self {
self.config.min_samples = value;
self
}
#[must_use]
pub fn priority(mut self, value: u16) -> Self {
self.config.priority = value;
self
}
#[must_use]
pub fn emitter(mut self, emitter: Arc<EventEmitter>) -> Self {
self.emitter = Some(emitter);
self
}
#[must_use]
pub fn clock<F>(mut self, clock: F) -> Self
where
F: Fn() -> Instant + Send + Sync + 'static,
{
self.clock = Some(Box::new(clock));
self
}
#[must_use]
pub fn build(self) -> CircuitBreakerMiddleware {
let mut mw = CircuitBreakerMiddleware::with_parts(self.config, self.emitter);
if let Some(clock) = self.clock {
mw.clock = clock;
}
mw
}
}
/// Middleware that keeps one circuit per `(module_id, caller_id)` pair and rejects calls
/// while that circuit is open.
pub struct CircuitBreakerMiddleware {
    /// Sanitised configuration (see `with_parts`).
    config: CircuitBreakerConfig,
    /// Optional sink for `apcore.circuit.opened` / `apcore.circuit.closed` events.
    emitter: Option<Arc<EventEmitter>>,
    /// Circuits keyed by `(module_id, caller_id)`, created lazily on first use.
    circuits: Mutex<HashMap<(String, String), Circuit>>,
    /// Time source; replaceable via the builder or `with_clock`.
    clock: Box<dyn Fn() -> Instant + Send + Sync>,
}

impl std::fmt::Debug for CircuitBreakerMiddleware {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Read the count first so the lock is released before any formatting happens.
        let circuit_count = self.circuits.lock().len();
        f.debug_struct("CircuitBreakerMiddleware")
            .field("config", &self.config)
            .field("circuit_count", &circuit_count)
            .finish_non_exhaustive()
    }
}
impl CircuitBreakerMiddleware {
    /// Returns a builder pre-populated with `CircuitBreakerConfig::default()`.
    #[must_use]
    pub fn builder() -> CircuitBreakerBuilder {
        CircuitBreakerBuilder::default()
    }

    /// Creates a breaker from an explicit config and an optional event emitter.
    ///
    /// Misconfigurations are reported through `tracing::warn!`:
    /// - `window_size == 0` is clamped to 1.
    /// - `min_samples > window_size` is kept as-is. The window can never hold enough
    ///   samples, so the breaker can never open from the closed state.
    /// - `open_threshold` greater than 1.0, or NaN, is kept as-is. The error rate never
    ///   exceeds 1.0, and every comparison with NaN is false, so the breaker can never
    ///   open from the closed state.
    ///
    /// The returned breaker reads time from `Instant::now`. Use `with_clock` to substitute
    /// another time source.
    #[must_use]
    pub fn with_parts(
        mut config: CircuitBreakerConfig,
        emitter: Option<Arc<EventEmitter>>,
    ) -> Self {
        if config.window_size == 0 {
            tracing::warn!(
                "CircuitBreakerConfig.window_size was 0; clamping to 1. \
                 The breaker will only retain the most recent sample."
            );
            config.window_size = 1;
        }
        if config.min_samples > config.window_size {
            tracing::warn!(
                min_samples = config.min_samples,
                window_size = config.window_size,
                "CircuitBreakerConfig.min_samples exceeds window_size; \
                 the breaker can never open. Reduce min_samples or grow window_size."
            );
        }
        // `Circuit::error_rate` yields a fraction in [0.0, 1.0], so `rate >= open_threshold`
        // can never hold for a threshold above 1.0 or for NaN.
        if config.open_threshold.is_nan() || config.open_threshold > 1.0 {
            tracing::warn!(
                open_threshold = config.open_threshold,
                "CircuitBreakerConfig.open_threshold is greater than 1.0 or NaN; \
                 the breaker can never open. Use a value between 0.0 and 1.0."
            );
        }
        Self {
            config,
            emitter,
            circuits: Mutex::new(HashMap::new()),
            clock: Box::new(Instant::now),
        }
    }

    /// Replaces the time source used for recovery-window and probe-timeout checks.
    #[must_use]
    pub fn with_clock<F>(mut self, clock: F) -> Self
    where
        F: Fn() -> Instant + Send + Sync + 'static,
    {
        self.clock = Box::new(clock);
        self
    }

    /// Current state of the `(module_id, caller_id)` circuit.
    ///
    /// Returns `Closed` for a pair that has never been seen.
    #[must_use]
    pub fn state(&self, module_id: &str, caller_id: &str) -> CircuitBreakerState {
        let key = (module_id.to_string(), caller_id.to_string());
        self.circuits
            .lock()
            .get(&key)
            .map_or(CircuitBreakerState::Closed, |c| c.state)
    }

    /// Overrides a circuit's state and `opened_at`, creating the circuit if needed.
    ///
    /// Any half-open probe reservation is cleared; the sliding window is left untouched.
    /// An `Open` circuit forced with `opened_at == None` is never moved to half-open by
    /// `before`, so it stays open until forced again.
    pub fn force_state(
        &self,
        module_id: &str,
        caller_id: &str,
        state: CircuitBreakerState,
        opened_at: Option<Instant>,
    ) {
        let key = (module_id.to_string(), caller_id.to_string());
        let mut circuits = self.circuits.lock();
        let entry = circuits.entry(key).or_insert_with(Circuit::new);
        entry.state = state;
        entry.opened_at = opened_at;
        entry.probe_in_flight = false;
        entry.probe_started_at = None;
    }

    /// Circuit key for a call: `(module_id, caller_id)`, with a missing caller mapped to `""`.
    fn key_of(module_id: &str, ctx: &Context<serde_json::Value>) -> (String, String) {
        (
            module_id.to_string(),
            ctx.caller_id.clone().unwrap_or_default(),
        )
    }

    /// Stores `state.as_str()` in the context data under the `CIRCUIT_STATE` key.
    fn write_state_to_context(ctx: &Context<serde_json::Value>, state: CircuitBreakerState) {
        // NOTE(review): the return value of this namespace check is discarded, so the write
        // below happens regardless of its outcome — confirm that is intended.
        let _ = enforce_context_key(ContextWriter::Framework, CIRCUIT_STATE);
        let mut data = ctx.data.write();
        data.insert(
            CIRCUIT_STATE.to_string(),
            serde_json::Value::String(state.as_str().to_string()),
        );
    }

    /// Builds an event via `ApCoreEvent::with_module`.
    ///
    /// The JSON payload carries `module_id`, `caller_id` and `error_rate`.
    fn build_event(
        event_type: &str,
        module_id: &str,
        caller_id: &str,
        error_rate: f64,
        severity: &str,
    ) -> ApCoreEvent {
        ApCoreEvent::with_module(
            event_type,
            serde_json::json!({
                "module_id": module_id,
                "caller_id": caller_id,
                "error_rate": error_rate,
            }),
            module_id,
            severity,
        )
    }

    /// Emits `event` if an emitter is configured; otherwise does nothing.
    async fn emit(&self, event: ApCoreEvent) {
        if let Some(emitter) = &self.emitter {
            emitter.emit(&event).await;
        }
    }
}
/// Circuit-breaker hooks.
///
/// - `before` gates admission.
/// - `after` records successes.
/// - `on_error` records failures and may trip the circuit.
///
/// Every state transition happens while `self.circuits` is locked. Events are built under
/// the lock but emitted only after the guard has been dropped, so nothing is awaited while
/// the lock is held.
#[async_trait]
impl Middleware for CircuitBreakerMiddleware {
    fn name(&self) -> &'static str {
        "circuit_breaker"
    }
    /// Returns `CircuitBreakerConfig::priority`.
    fn priority(&self) -> u16 {
        self.config.priority
    }
    /// Decides whether the call for `(module_id, caller_id)` may proceed.
    ///
    /// - `Open`: rejected until `recovery_window_ms` has elapsed since `opened_at`, at which
    ///   point the circuit becomes `HalfOpen`. An open circuit without `opened_at` never recovers.
    /// - `HalfOpen`: one probe is admitted; other calls are rejected while it is in flight.
    ///   A probe held for a full `recovery_window_ms` without resolving is discarded and the
    ///   slot is handed to the next caller.
    /// - `Closed`: always admitted.
    ///
    /// The resulting state is written to the context. Rejections are always recorded as
    /// `OPEN`, including rejections while half-open.
    ///
    /// # Errors
    /// Returns `ModuleError::circuit_breaker_open` when the call is rejected.
    async fn before(
        &self,
        module_id: &str,
        _inputs: serde_json::Value,
        ctx: &Context<serde_json::Value>,
    ) -> Result<Option<serde_json::Value>, ModuleError> {
        let key = Self::key_of(module_id, ctx);
        let now = (self.clock)();
        let recovery = Duration::from_millis(self.config.recovery_window_ms);
        let outcome = {
            let mut circuits = self.circuits.lock();
            let circuit = circuits.entry(key.clone()).or_insert_with(Circuit::new);
            // Step 1: an open circuit whose recovery window has elapsed moves to half-open
            // with no probe reserved yet.
            if circuit.state == CircuitBreakerState::Open {
                if let Some(opened_at) = circuit.opened_at {
                    let elapsed = now.saturating_duration_since(opened_at);
                    if elapsed >= recovery {
                        circuit.state = CircuitBreakerState::HalfOpen;
                        circuit.probe_in_flight = false;
                        circuit.probe_started_at = None;
                    }
                }
            }
            // Step 2: free a probe slot that has been held for a whole recovery window
            // (its after/on_error hook never ran). With recovery_window_ms == 0 this frees
            // every in-flight probe immediately.
            if circuit.state == CircuitBreakerState::HalfOpen && circuit.probe_in_flight {
                if let Some(started) = circuit.probe_started_at {
                    if now.saturating_duration_since(started) >= recovery {
                        circuit.probe_in_flight = false;
                        circuit.probe_started_at = None;
                    }
                }
            }
            // Step 3: admission decision. Admitting a half-open call reserves the probe slot.
            match circuit.state {
                CircuitBreakerState::Open => Outcome::Reject,
                CircuitBreakerState::HalfOpen => {
                    if circuit.probe_in_flight {
                        Outcome::Reject
                    } else {
                        circuit.probe_in_flight = true;
                        circuit.probe_started_at = Some(now);
                        Outcome::Allow(CircuitBreakerState::HalfOpen)
                    }
                }
                CircuitBreakerState::Closed => Outcome::Allow(CircuitBreakerState::Closed),
            }
        };
        // The circuits lock has been released here; only the context is touched below.
        match outcome {
            Outcome::Reject => {
                Self::write_state_to_context(ctx, CircuitBreakerState::Open);
                Err(ModuleError::circuit_breaker_open(&key.0, &key.1))
            }
            Outcome::Allow(state) => {
                Self::write_state_to_context(ctx, state);
                Ok(None)
            }
        }
    }
    /// Records a successful call.
    ///
    /// Any success observed while `HalfOpen`, not only the probe's, closes the circuit. The
    /// window is then cleared so earlier failures cannot immediately re-trip it, and an
    /// `apcore.circuit.closed` event is emitted.
    async fn after(
        &self,
        module_id: &str,
        _inputs: serde_json::Value,
        _output: serde_json::Value,
        ctx: &Context<serde_json::Value>,
    ) -> Result<Option<serde_json::Value>, ModuleError> {
        let key = Self::key_of(module_id, ctx);
        let event = {
            let mut circuits = self.circuits.lock();
            let circuit = circuits.entry(key.clone()).or_insert_with(Circuit::new);
            circuit.push(true, self.config.window_size);
            if circuit.state == CircuitBreakerState::HalfOpen {
                circuit.state = CircuitBreakerState::Closed;
                circuit.opened_at = None;
                circuit.probe_in_flight = false;
                circuit.probe_started_at = None;
                circuit.window.clear();
                Some(Self::build_event(
                    "apcore.circuit.closed",
                    &key.0,
                    &key.1,
                    0.0,
                    "info",
                ))
            } else {
                None
            }
        };
        // Emit only after the lock guard above has been dropped.
        if let Some(event) = event {
            self.emit(event).await;
        }
        Ok(None)
    }
    /// Records a failed call and opens the circuit when warranted.
    ///
    /// - `HalfOpen`: any failure reopens the circuit.
    /// - `Closed`: opens once the window holds at least `min_samples` outcomes and the error
    ///   rate is `>= open_threshold`.
    /// - `Open`: the failure is recorded, but `opened_at` is not refreshed.
    ///
    /// When the circuit opens, an `apcore.circuit.opened` event carrying the current error
    /// rate is emitted.
    async fn on_error(
        &self,
        module_id: &str,
        _inputs: serde_json::Value,
        error: &ModuleError,
        ctx: &Context<serde_json::Value>,
    ) -> Result<Option<serde_json::Value>, ModuleError> {
        // The breaker's own rejections are not counted as module failures.
        if error.code == crate::errors::ErrorCode::CircuitBreakerOpen {
            return Ok(None);
        }
        let key = Self::key_of(module_id, ctx);
        let now = (self.clock)();
        let event = {
            let mut circuits = self.circuits.lock();
            let circuit = circuits.entry(key.clone()).or_insert_with(Circuit::new);
            circuit.push(false, self.config.window_size);
            let rate = circuit.error_rate();
            let opens = match circuit.state {
                CircuitBreakerState::HalfOpen => true,
                CircuitBreakerState::Closed => {
                    circuit.window.len() >= self.config.min_samples
                        && rate >= self.config.open_threshold
                }
                CircuitBreakerState::Open => false,
            };
            if opens {
                circuit.state = CircuitBreakerState::Open;
                circuit.opened_at = Some(now);
                circuit.probe_in_flight = false;
                circuit.probe_started_at = None;
                Some(Self::build_event(
                    "apcore.circuit.opened",
                    &key.0,
                    &key.1,
                    rate,
                    "warn",
                ))
            } else {
                None
            }
        };
        // Emit only after the lock guard above has been dropped.
        if let Some(event) = event {
            self.emit(event).await;
        }
        Ok(None)
    }
}
/// Admission decision computed under the circuits lock in `before`.
enum Outcome {
    /// Refuse the call: the circuit is open, or a half-open probe is already in flight.
    Reject,
    /// Let the call proceed; carries the state to record in the request context.
    Allow(CircuitBreakerState),
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::{Context, Identity};
    use crate::errors::ErrorCode;
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Builds a context whose `caller_id` is `caller`.
    fn ctx_with_caller(caller: &str) -> Context<serde_json::Value> {
        let identity = Identity::new(
            "test".to_string(),
            "user".to_string(),
            vec![],
            std::collections::HashMap::new(),
        );
        let mut ctx: Context<serde_json::Value> = Context::new(identity);
        ctx.caller_id = Some(caller.to_string());
        ctx
    }

    #[tokio::test]
    async fn closed_circuit_allows_calls() {
        let mw = CircuitBreakerMiddleware::builder().build();
        let ctx = ctx_with_caller("orchestrator.x");
        let result = mw.before("executor.foo", serde_json::json!({}), &ctx).await;
        assert!(result.is_ok());
        assert_eq!(
            mw.state("executor.foo", "orchestrator.x"),
            CircuitBreakerState::Closed
        );
    }

    #[tokio::test]
    async fn opens_when_error_rate_exceeds_threshold() {
        let mw = CircuitBreakerMiddleware::builder()
            .open_threshold(0.5)
            .window_size(10)
            .min_samples(5)
            .build();
        let ctx = ctx_with_caller("orchestrator.billing");
        let module = "executor.payment.charge";
        let err = ModuleError::new(ErrorCode::ModuleExecuteError, "boom");
        for _ in 0..6 {
            mw.on_error(module, serde_json::json!({}), &err, &ctx)
                .await
                .unwrap();
        }
        assert_eq!(
            mw.state(module, "orchestrator.billing"),
            CircuitBreakerState::Open
        );
    }

    #[tokio::test]
    async fn stays_closed_below_min_samples() {
        let mw = CircuitBreakerMiddleware::builder()
            .open_threshold(0.5)
            .window_size(10)
            .min_samples(5)
            .build();
        let ctx = ctx_with_caller("orchestrator.billing");
        let module = "executor.payment.charge";
        let err = ModuleError::new(ErrorCode::ModuleExecuteError, "boom");
        // 100% errors, but one sample short of min_samples.
        for _ in 0..4 {
            mw.on_error(module, serde_json::json!({}), &err, &ctx)
                .await
                .unwrap();
        }
        assert_eq!(
            mw.state(module, "orchestrator.billing"),
            CircuitBreakerState::Closed
        );
    }

    #[tokio::test]
    async fn open_state_short_circuits_before() {
        let mw = CircuitBreakerMiddleware::builder().build();
        let ctx = ctx_with_caller("orchestrator.billing");
        let module = "executor.payment.charge";
        mw.force_state(
            module,
            "orchestrator.billing",
            CircuitBreakerState::Open,
            Some(Instant::now()),
        );
        let result = mw.before(module, serde_json::json!({}), &ctx).await;
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().code, ErrorCode::CircuitBreakerOpen);
    }

    #[tokio::test]
    async fn open_to_half_open_after_recovery_window() {
        let mw = CircuitBreakerMiddleware::builder()
            .recovery_window_ms(30_000)
            .build();
        let ctx = ctx_with_caller("orchestrator.billing");
        let module = "executor.payment.charge";
        let opened_at = Instant::now()
            .checked_sub(Duration::from_secs(35))
            .expect("test clock far enough from epoch");
        mw.force_state(
            module,
            "orchestrator.billing",
            CircuitBreakerState::Open,
            Some(opened_at),
        );
        let r = mw.before(module, serde_json::json!({}), &ctx).await;
        assert!(r.is_ok());
        assert_eq!(
            mw.state(module, "orchestrator.billing"),
            CircuitBreakerState::HalfOpen
        );
    }

    #[tokio::test]
    async fn injected_clock_drives_recovery_and_stale_probe_reclaim() {
        let base = Instant::now();
        let offset_ms = Arc::new(AtomicU64::new(0));
        let clock_offset = Arc::clone(&offset_ms);
        let mw = CircuitBreakerMiddleware::builder()
            .recovery_window_ms(1_000)
            .clock(move || {
                base + Duration::from_millis(clock_offset.load(Ordering::SeqCst))
            })
            .build();
        let ctx = ctx_with_caller("orchestrator.billing");
        let module = "executor.payment.charge";
        mw.force_state(
            module,
            "orchestrator.billing",
            CircuitBreakerState::Open,
            Some(base),
        );

        // Still inside the recovery window: rejected and still open.
        offset_ms.store(999, Ordering::SeqCst);
        let early = mw.before(module, serde_json::json!({}), &ctx).await;
        assert_eq!(early.unwrap_err().code, ErrorCode::CircuitBreakerOpen);
        assert_eq!(
            mw.state(module, "orchestrator.billing"),
            CircuitBreakerState::Open
        );

        // Recovery window elapsed: this call becomes the half-open probe.
        offset_ms.store(1_000, Ordering::SeqCst);
        assert!(mw.before(module, serde_json::json!({}), &ctx).await.is_ok());
        assert_eq!(
            mw.state(module, "orchestrator.billing"),
            CircuitBreakerState::HalfOpen
        );

        // Probe still fresh: other calls are rejected.
        offset_ms.store(1_500, Ordering::SeqCst);
        assert!(mw.before(module, serde_json::json!({}), &ctx).await.is_err());

        // Probe unresolved for a full recovery window: the slot is reclaimed.
        offset_ms.store(2_000, Ordering::SeqCst);
        assert!(mw.before(module, serde_json::json!({}), &ctx).await.is_ok());
        assert_eq!(
            mw.state(module, "orchestrator.billing"),
            CircuitBreakerState::HalfOpen
        );
    }

    #[tokio::test]
    async fn half_open_concurrent_probes_capped_at_one() {
        let mw = CircuitBreakerMiddleware::builder().build();
        let ctx = ctx_with_caller("orchestrator.billing");
        let module = "executor.payment.charge";
        mw.force_state(
            module,
            "orchestrator.billing",
            CircuitBreakerState::HalfOpen,
            Some(Instant::now()),
        );
        let first = mw.before(module, serde_json::json!({}), &ctx).await;
        assert!(first.is_ok());
        let second = mw.before(module, serde_json::json!({}), &ctx).await;
        assert!(second.is_err());
        assert_eq!(second.unwrap_err().code, ErrorCode::CircuitBreakerOpen);
    }

    #[tokio::test]
    async fn half_open_success_closes_circuit() {
        let mw = CircuitBreakerMiddleware::builder().build();
        let ctx = ctx_with_caller("orchestrator.billing");
        let module = "executor.payment.charge";
        mw.force_state(
            module,
            "orchestrator.billing",
            CircuitBreakerState::HalfOpen,
            Some(Instant::now()),
        );
        mw.before(module, serde_json::json!({}), &ctx)
            .await
            .unwrap();
        mw.after(module, serde_json::json!({}), serde_json::json!({}), &ctx)
            .await
            .unwrap();
        assert_eq!(
            mw.state(module, "orchestrator.billing"),
            CircuitBreakerState::Closed
        );
    }

    #[tokio::test]
    async fn half_open_failure_reopens_circuit() {
        let mw = CircuitBreakerMiddleware::builder().build();
        let ctx = ctx_with_caller("orchestrator.billing");
        let module = "executor.payment.charge";
        mw.force_state(
            module,
            "orchestrator.billing",
            CircuitBreakerState::HalfOpen,
            Some(Instant::now()),
        );
        mw.before(module, serde_json::json!({}), &ctx)
            .await
            .unwrap();
        let err = ModuleError::new(ErrorCode::ModuleExecuteError, "boom");
        mw.on_error(module, serde_json::json!({}), &err, &ctx)
            .await
            .unwrap();
        assert_eq!(
            mw.state(module, "orchestrator.billing"),
            CircuitBreakerState::Open
        );
    }

    #[tokio::test]
    async fn writes_state_to_context_data() {
        let mw = CircuitBreakerMiddleware::builder().build();
        let ctx = ctx_with_caller("orch");
        mw.before("mod.a", serde_json::json!({}), &ctx)
            .await
            .unwrap();
        let data = ctx.data.read();
        assert_eq!(
            data.get("_apcore.mw.circuit.state")
                .and_then(|v| v.as_str()),
            Some("CLOSED")
        );
    }

    #[tokio::test]
    async fn ignores_circuit_breaker_open_in_on_error() {
        let mw = CircuitBreakerMiddleware::builder()
            .open_threshold(0.5)
            .min_samples(2)
            .build();
        let ctx = ctx_with_caller("orch");
        let err = ModuleError::circuit_breaker_open("mod.a", "orch");
        mw.on_error("mod.a", serde_json::json!({}), &err, &ctx)
            .await
            .unwrap();
        assert_eq!(mw.state("mod.a", "orch"), CircuitBreakerState::Closed);
    }
}