use std::{
sync::Arc,
time::{Duration, Instant},
};
use dashmap::DashMap;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use tracing::{info, warn};
/// Numeric code reported (via `state_code` / `collect_states`) for a breaker in `Closed`.
pub const STATE_CLOSED: u64 = 0;
/// Numeric code reported (via `state_code` / `collect_states`) for a breaker in `Open`.
pub const STATE_OPEN: u64 = 1;
/// Numeric code reported (via `state_code` / `collect_states`) for a breaker in `HalfOpen`.
pub const STATE_HALF_OPEN: u64 = 2;
/// Externally visible state of one circuit breaker, as exposed by `health_snapshot`.
///
/// Serialized in snake case: `closed`, `open`, `half_open`.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum CircuitHealthState {
    /// Requests are admitted normally.
    Closed,
    /// Requests are rejected until the recovery timeout has elapsed.
    Open,
    /// Recovery probes are admitted one at a time to decide whether to close again.
    HalfOpen,
}
/// One entry of the manager's health snapshot: a breaker key and its current state.
#[derive(Debug, Clone, Serialize)]
pub struct SubgraphCircuitHealth {
    /// Key the breaker is registered under. `health_snapshot` fills this with the
    /// entity type name, despite the field being called `subgraph`.
    pub subgraph: String,
    /// State of that breaker at the time the snapshot was taken.
    pub state: CircuitHealthState,
}
/// Internal state machine of a single entity's circuit breaker.
#[derive(Debug)]
enum CircuitState {
    /// Normal operation; counts failures toward `failure_threshold`.
    Closed {
        /// Failures counted since the breaker last (re)entered `Closed`.
        consecutive_failures: u32,
    },
    /// Tripped: requests are rejected until `recovery_timeout` has elapsed since `opened_at`.
    Open {
        /// Moment the breaker tripped.
        opened_at: Instant,
        /// How long to stay open before admitting a probe.
        recovery_timeout: Duration,
    },
    /// Recovering: probes are admitted one at a time.
    HalfOpen {
        /// Failed probes counted since entering `HalfOpen`.
        consecutive_failures: u32,
        /// Whether an admitted probe has not yet reported its outcome.
        probe_in_flight: bool,
        /// Successful probes counted since entering `HalfOpen`.
        successes: u32,
    },
}
/// Tunables for a single entity's circuit breaker.
///
/// Derives `PartialEq`/`Eq` so effective configurations (defaults vs. per-entity overrides)
/// can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CircuitBreakerConfig {
    /// Failures (counted while `Closed` or `HalfOpen`) at which the breaker trips to `Open`.
    pub failure_threshold: u32,
    /// Seconds the breaker stays `Open` before admitting a probe. The breaker also returns
    /// this value to rejected callers.
    pub recovery_timeout_secs: u64,
    /// Successful `HalfOpen` probes required before the breaker closes again.
    pub success_threshold: u32,
}

impl Default for CircuitBreakerConfig {
    /// Open after 5 failures, stay open for 30 seconds, close after 2 successful probes.
    fn default() -> Self {
        Self {
            failure_threshold: 5,
            recovery_timeout_secs: 30,
            success_threshold: 2,
        }
    }
}
/// Circuit breaker for one entity type: fixed tunables plus mutex-guarded state.
struct EntityCircuitBreaker {
    /// Thresholds and timeout this breaker was created with; never changed afterwards.
    config: CircuitBreakerConfig,
    /// Current position in the state machine; every transition happens under this lock.
    state: Mutex<CircuitState>,
}
impl EntityCircuitBreaker {
const fn new(config: CircuitBreakerConfig) -> Self {
Self {
config,
state: Mutex::new(CircuitState::Closed {
consecutive_failures: 0,
}),
}
}
fn check(&self) -> Option<u64> {
let mut state = self.state.lock();
match &*state {
CircuitState::Closed { .. } => return None,
CircuitState::Open {
opened_at,
recovery_timeout,
} => {
if opened_at.elapsed() < *recovery_timeout {
return Some(self.config.recovery_timeout_secs);
}
},
CircuitState::HalfOpen {
probe_in_flight: true,
..
} => {
return Some(self.config.recovery_timeout_secs);
},
CircuitState::HalfOpen {
probe_in_flight: false,
..
} => {
},
}
match &mut *state {
CircuitState::Open { .. } => {
*state = CircuitState::HalfOpen {
consecutive_failures: 0,
probe_in_flight: true,
successes: 0,
};
},
CircuitState::HalfOpen {
probe_in_flight, ..
} => {
*probe_in_flight = true;
},
CircuitState::Closed { .. } => {
},
}
None
}
fn record_success(&self) {
let mut state = self.state.lock();
let new_successes = match &*state {
CircuitState::HalfOpen { successes, .. } => *successes + 1,
CircuitState::Closed { .. } | CircuitState::Open { .. } => return,
};
if new_successes >= self.config.success_threshold {
*state = CircuitState::Closed {
consecutive_failures: 0,
};
info!("Federation circuit breaker closed after successful recovery");
} else if let CircuitState::HalfOpen {
successes,
probe_in_flight,
..
} = &mut *state
{
*successes = new_successes;
*probe_in_flight = false;
}
}
fn record_failure(&self) {
let mut state = self.state.lock();
let new_count = match &*state {
CircuitState::Open { .. } => return,
CircuitState::Closed {
consecutive_failures,
}
| CircuitState::HalfOpen {
consecutive_failures,
..
} => *consecutive_failures + 1,
};
if new_count >= self.config.failure_threshold {
let from_half_open = matches!(*state, CircuitState::HalfOpen { .. });
*state = CircuitState::Open {
opened_at: Instant::now(),
recovery_timeout: Duration::from_secs(self.config.recovery_timeout_secs),
};
if from_half_open {
info!(
consecutive_failures = new_count,
recovery_timeout_secs = self.config.recovery_timeout_secs,
"Federation circuit breaker re-opened from HalfOpen"
);
} else {
info!(
consecutive_failures = new_count,
recovery_timeout_secs = self.config.recovery_timeout_secs,
"Federation circuit breaker opened"
);
}
} else {
match &mut *state {
CircuitState::Closed {
consecutive_failures,
} => {
*consecutive_failures = new_count;
},
CircuitState::HalfOpen {
consecutive_failures,
probe_in_flight,
..
} => {
*consecutive_failures = new_count;
*probe_in_flight = false;
},
CircuitState::Open { .. } => {
},
}
}
}
fn state_code(&self) -> u64 {
let state = self.state.lock();
match &*state {
CircuitState::Closed { .. } => STATE_CLOSED,
CircuitState::Open { .. } => STATE_OPEN,
CircuitState::HalfOpen { .. } => STATE_HALF_OPEN,
}
}
fn state_for_health(&self) -> CircuitHealthState {
let state = self.state.lock();
match &*state {
CircuitState::Closed { .. } => CircuitHealthState::Closed,
CircuitState::Open { .. } => CircuitHealthState::Open,
CircuitState::HalfOpen { .. } => CircuitHealthState::HalfOpen,
}
}
}
/// Shape of the `circuit_breaker` object inside the federation schema JSON.
///
/// Only `enabled` and `per_entity` have serde defaults. The unset thresholds stay `None`
/// here and are filled in by `from_schema_json`.
#[derive(Debug, Deserialize)]
struct CircuitBreakerJson {
    /// Master switch; an absent key deserializes as `false`.
    #[serde(default)]
    enabled: bool,
    /// Top-level failure threshold, if given.
    failure_threshold: Option<u32>,
    /// Top-level recovery timeout in seconds, if given.
    recovery_timeout_secs: Option<u64>,
    /// Top-level success threshold, if given.
    success_threshold: Option<u32>,
    /// Per-entity overrides. Also accepted under the key `per_database`; empty when absent.
    #[serde(default, alias = "per_database")]
    per_entity: Vec<PerEntityJson>,
}
/// One per-entity override entry. Unset fields fall back to the top-level settings.
#[derive(Debug, Deserialize)]
struct PerEntityJson {
    /// Entity type name the override applies to; also accepted under the key `database`.
    #[serde(alias = "database")]
    entity_type: String,
    /// Overrides the failure threshold when present.
    failure_threshold: Option<u32>,
    /// Overrides the recovery timeout (seconds) when present.
    recovery_timeout_secs: Option<u64>,
    /// Overrides the success threshold when present.
    success_threshold: Option<u32>,
}
/// Registry of per-entity circuit breakers for federation requests.
///
/// Breakers are created when an entity is first used, and up front for entities that have
/// a configured override. Nothing in this module removes entries.
pub struct FederationCircuitBreakerManager {
    /// Live breakers keyed by entity type name.
    breakers: DashMap<String, Arc<EntityCircuitBreaker>>,
    /// Settings used for entities without an override.
    default_config: CircuitBreakerConfig,
    /// Per-entity settings that take precedence over `default_config`.
    per_entity_config: DashMap<String, CircuitBreakerConfig>,
}
impl FederationCircuitBreakerManager {
fn new(default_config: CircuitBreakerConfig) -> Self {
Self {
breakers: DashMap::new(),
default_config,
per_entity_config: DashMap::new(),
}
}
#[must_use]
pub fn from_config(fed: &fraiseql_core::schema::FederationConfig) -> Option<Arc<Self>> {
let cb = fed.circuit_breaker.as_ref()?;
if !cb.enabled {
return None;
}
let default_config = CircuitBreakerConfig {
failure_threshold: cb.failure_threshold,
recovery_timeout_secs: cb.recovery_timeout_secs,
success_threshold: cb.success_threshold,
};
let manager = Arc::new(Self::new(default_config));
for override_entry in &cb.per_entity {
let entity_config = CircuitBreakerConfig {
failure_threshold: override_entry
.failure_threshold
.unwrap_or(manager.default_config.failure_threshold),
recovery_timeout_secs: override_entry
.recovery_timeout
.unwrap_or(manager.default_config.recovery_timeout_secs),
success_threshold: override_entry
.success_threshold
.unwrap_or(manager.default_config.success_threshold),
};
manager.per_entity_config.insert(override_entry.entity.clone(), entity_config);
}
let override_keys: Vec<String> =
manager.per_entity_config.iter().map(|r| r.key().clone()).collect();
for entity_type in override_keys {
manager.get_or_create(&entity_type);
}
info!(
failure_threshold = manager.default_config.failure_threshold,
recovery_timeout_secs = manager.default_config.recovery_timeout_secs,
success_threshold = manager.default_config.success_threshold,
per_entity_overrides = manager.per_entity_config.len(),
"Federation circuit breaker initialized"
);
Some(manager)
}
#[must_use]
pub fn from_schema_json(federation_json: &serde_json::Value) -> Option<Arc<Self>> {
let cb_json: CircuitBreakerJson = match federation_json.get("circuit_breaker") {
None => return None,
Some(v) => match serde_json::from_value(v.clone()) {
Ok(j) => j,
Err(e) => {
warn!(
error = %e,
"circuit_breaker config present but malformed — circuit breaker disabled"
);
return None;
},
},
};
if !cb_json.enabled {
return None;
}
let default_config = CircuitBreakerConfig {
failure_threshold: cb_json.failure_threshold.unwrap_or(5),
recovery_timeout_secs: cb_json.recovery_timeout_secs.unwrap_or(30),
success_threshold: cb_json.success_threshold.unwrap_or(2),
};
let manager = Arc::new(Self::new(default_config));
for override_entry in cb_json.per_entity {
let entity_config = CircuitBreakerConfig {
failure_threshold: override_entry
.failure_threshold
.unwrap_or(manager.default_config.failure_threshold),
recovery_timeout_secs: override_entry
.recovery_timeout_secs
.unwrap_or(manager.default_config.recovery_timeout_secs),
success_threshold: override_entry
.success_threshold
.unwrap_or(manager.default_config.success_threshold),
};
manager.per_entity_config.insert(override_entry.entity_type, entity_config);
}
let override_keys: Vec<String> =
manager.per_entity_config.iter().map(|r| r.key().clone()).collect();
for entity_type in override_keys {
manager.get_or_create(&entity_type);
}
info!(
failure_threshold = manager.default_config.failure_threshold,
recovery_timeout_secs = manager.default_config.recovery_timeout_secs,
success_threshold = manager.default_config.success_threshold,
per_entity_overrides = manager.per_entity_config.len(),
"Federation circuit breaker initialized"
);
Some(manager)
}
fn get_or_create(&self, entity: &str) -> Arc<EntityCircuitBreaker> {
self.breakers
.entry(entity.to_string())
.or_insert_with(|| {
let config = self
.per_entity_config
.get(entity)
.map_or_else(|| self.default_config.clone(), |r| r.value().clone());
Arc::new(EntityCircuitBreaker::new(config))
})
.clone()
}
pub fn check(&self, entity: &str) -> Option<u64> {
self.get_or_create(entity).check()
}
pub fn record_success(&self, entity: &str) {
self.get_or_create(entity).record_success();
}
pub fn record_failure(&self, entity: &str) {
self.get_or_create(entity).record_failure();
}
#[must_use]
pub fn collect_states(&self) -> Vec<(String, u64)> {
self.breakers
.iter()
.map(|entry| (entry.key().clone(), entry.value().state_code()))
.collect()
}
#[must_use]
pub fn health_snapshot(&self) -> Vec<SubgraphCircuitHealth> {
self.breakers
.iter()
.map(|entry| SubgraphCircuitHealth {
subgraph: entry.key().clone(),
state: entry.value().state_for_health(),
})
.collect()
}
}
/// Collects the distinct `__typename` values from the `representations` array of request
/// variables.
///
/// The result is sorted ascending with duplicates removed. It is empty when `variables` is
/// `None` or has no `representations` array. A representation without a string `__typename`
/// is skipped, and a warning is logged for each one.
#[must_use]
pub fn extract_entity_types(variables: Option<&serde_json::Value>) -> Vec<String> {
    let Some(representations) = variables
        .and_then(|vars| vars.get("representations"))
        .and_then(serde_json::Value::as_array)
    else {
        return Vec::new();
    };

    let mut names: Vec<String> = representations
        .iter()
        .filter_map(|rep| {
            let typename = rep.get("__typename").and_then(serde_json::Value::as_str);
            if typename.is_none() {
                warn!(
                    "Federation representation missing __typename field; entity skipped for circuit \
                     breaker"
                );
            }
            typename.map(str::to_owned)
        })
        .collect();
    names.sort_unstable();
    names.dedup();
    names
}
#[cfg(test)]
mod tests {
    //! Unit tests covering three areas: the per-entity breaker state machine, manager
    //! construction from schema JSON, and `extract_entity_types`.
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::cast_precision_loss)]
    #![allow(clippy::cast_sign_loss)]
    #![allow(clippy::cast_possible_truncation)]
    #![allow(clippy::cast_possible_wrap)]
    #![allow(clippy::missing_panics_doc)]
    #![allow(clippy::missing_errors_doc)]
    #![allow(missing_docs)]
    #![allow(clippy::items_after_statements)]
    use super::*;

    #[test]
    fn test_state_for_health_returns_closed_initially() {
        let breaker = EntityCircuitBreaker::new(CircuitBreakerConfig::default());
        assert!(matches!(breaker.state_for_health(), CircuitHealthState::Closed));
    }

    #[test]
    fn test_state_for_health_returns_open_after_threshold() {
        // failure_threshold 1: a single failure trips the breaker; the long timeout keeps it Open.
        let config = CircuitBreakerConfig {
            failure_threshold: 1,
            recovery_timeout_secs: 3600,
            success_threshold: 2,
        };
        let breaker = EntityCircuitBreaker::new(config);
        breaker.record_failure();
        assert!(matches!(breaker.state_for_health(), CircuitHealthState::Open));
    }

    #[test]
    fn test_state_for_health_returns_half_open_after_timeout() {
        let config = CircuitBreakerConfig {
            failure_threshold: 1,
            // Zero timeout: the very next check() finds the recovery window already elapsed.
            recovery_timeout_secs: 0,
            success_threshold: 5,
        };
        let breaker = EntityCircuitBreaker::new(config);
        breaker.record_failure();
        // check() performs the Open -> HalfOpen transition.
        breaker.check();
        assert!(matches!(breaker.state_for_health(), CircuitHealthState::HalfOpen));
    }

    #[test]
    fn test_health_snapshot_returns_entries_for_all_breakers() {
        // Both overridden entities are pre-seeded, so both appear even though only
        // Product ever sees traffic.
        let json = serde_json::json!({
            "circuit_breaker": {
                "enabled": true,
                "failure_threshold": 1,
                "recovery_timeout_secs": 3600,
                "success_threshold": 2,
                "per_entity": [
                    { "entity_type": "Product", "failure_threshold": 1 },
                    { "entity_type": "User", "failure_threshold": 1 }
                ]
            }
        });
        let manager = FederationCircuitBreakerManager::from_schema_json(&json).unwrap();
        manager.record_failure("Product");
        let snapshot = manager.health_snapshot();
        assert_eq!(snapshot.len(), 2, "should have one entry per configured entity");
        let product = snapshot.iter().find(|s| s.subgraph == "Product").unwrap();
        assert!(matches!(product.state, CircuitHealthState::Open));
        let user = snapshot.iter().find(|s| s.subgraph == "User").unwrap();
        assert!(matches!(user.state, CircuitHealthState::Closed));
    }

    #[test]
    fn test_circuit_starts_closed() {
        let breaker = EntityCircuitBreaker::new(CircuitBreakerConfig::default());
        assert!(breaker.check().is_none());
        assert_eq!(breaker.state_code(), STATE_CLOSED);
    }

    #[test]
    fn test_circuit_opens_after_threshold() {
        let config = CircuitBreakerConfig {
            failure_threshold: 3,
            recovery_timeout_secs: 60,
            success_threshold: 2,
        };
        let breaker = EntityCircuitBreaker::new(config);
        // Failures 1 and 2 stay below the threshold; the third trips the breaker.
        breaker.record_failure();
        assert!(breaker.check().is_none());
        breaker.record_failure();
        assert!(breaker.check().is_none());
        breaker.record_failure();
        // A rejection reports the configured recovery timeout.
        assert_eq!(breaker.check(), Some(60));
        assert_eq!(breaker.state_code(), STATE_OPEN);
    }

    #[test]
    fn test_circuit_stays_open_before_timeout() {
        let config = CircuitBreakerConfig {
            failure_threshold: 1,
            // Long enough that the timeout cannot elapse during the test.
            recovery_timeout_secs: 3600,
            success_threshold: 2,
        };
        let breaker = EntityCircuitBreaker::new(config);
        breaker.record_failure();
        assert_eq!(breaker.check(), Some(3600));
        assert_eq!(breaker.state_code(), STATE_OPEN);
    }

    #[test]
    fn test_circuit_half_open_after_timeout() {
        let config = CircuitBreakerConfig {
            failure_threshold: 1,
            // Zero timeout: the first check() after opening is admitted as a probe.
            recovery_timeout_secs: 0,
            success_threshold: 2,
        };
        let breaker = EntityCircuitBreaker::new(config);
        breaker.record_failure();
        assert!(breaker.check().is_none());
        assert_eq!(breaker.state_code(), STATE_HALF_OPEN);
    }

    #[test]
    fn test_circuit_half_open_blocks_concurrent_probes() {
        let config = CircuitBreakerConfig {
            failure_threshold: 1,
            recovery_timeout_secs: 0,
            // High enough that no success in this test could close the breaker.
            success_threshold: 5,
        };
        let breaker = EntityCircuitBreaker::new(config);
        breaker.record_failure();
        // No outcome is reported between the two checks, so the first probe is still in flight.
        assert!(breaker.check().is_none(), "first probe should be allowed");
        assert!(breaker.check().is_some(), "second concurrent probe should be rejected");
        assert_eq!(breaker.state_code(), STATE_HALF_OPEN);
    }

    #[test]
    fn test_circuit_closes_after_recovery() {
        let config = CircuitBreakerConfig {
            failure_threshold: 1,
            recovery_timeout_secs: 0,
            success_threshold: 2,
        };
        let breaker = EntityCircuitBreaker::new(config);
        breaker.record_failure();
        // Open -> HalfOpen (the zero timeout has elapsed).
        breaker.check();
        assert_eq!(breaker.state_code(), STATE_HALF_OPEN);
        // One success is below success_threshold (2): still recovering.
        breaker.record_success();
        assert_eq!(breaker.state_code(), STATE_HALF_OPEN);
        // The second success reaches the threshold and closes the breaker.
        breaker.record_success();
        assert_eq!(breaker.state_code(), STATE_CLOSED);
    }

    #[test]
    fn test_circuit_half_open_probe_cleared_after_success() {
        let config = CircuitBreakerConfig {
            failure_threshold: 1,
            recovery_timeout_secs: 0,
            success_threshold: 3,
        };
        let breaker = EntityCircuitBreaker::new(config);
        breaker.record_failure();
        // First check admits the probe (Open -> HalfOpen, probe in flight).
        breaker.check();
        assert!(
            breaker.check().is_some(),
            "second check should return backoff while probe is in flight"
        );
        // A success below the threshold clears the in-flight flag...
        breaker.record_success();
        // ...so the next check admits another probe.
        assert!(breaker.check().is_none());
    }

    #[test]
    fn test_extract_entity_types_from_representations() {
        let vars = serde_json::json!({
            "representations": [
                {"__typename": "Product", "id": "1"},
                {"__typename": "User", "id": "2"},
                {"__typename": "Product", "id": "3"},
            ]
        });
        let types = extract_entity_types(Some(&vars));
        // Sorted and de-duplicated.
        assert_eq!(types, vec!["Product", "User"]);
    }

    #[test]
    fn test_extract_entity_types_missing_representations() {
        let vars = serde_json::json!({ "other": "data" });
        assert!(extract_entity_types(Some(&vars)).is_empty());
    }

    #[test]
    fn test_extract_entity_types_no_variables() {
        assert!(extract_entity_types(None).is_empty());
    }

    #[test]
    fn test_extract_entity_types_missing_typename_skipped() {
        // The first representation has no __typename and must be skipped.
        let vars = serde_json::json!({
            "representations": [
                {"id": "1"},
                {"__typename": "User", "id": "2"},
            ]
        });
        let types = extract_entity_types(Some(&vars));
        assert_eq!(types, vec!["User"]);
    }

    #[test]
    fn test_manager_from_schema_json_disabled() {
        let json = serde_json::json!({ "circuit_breaker": { "enabled": false } });
        assert!(FederationCircuitBreakerManager::from_schema_json(&json).is_none());
    }

    #[test]
    fn test_manager_from_schema_json_missing_section() {
        let json = serde_json::json!({ "enabled": true, "entities": [] });
        assert!(FederationCircuitBreakerManager::from_schema_json(&json).is_none());
    }

    #[test]
    fn test_manager_from_schema_json_malformed_config() {
        // A string where u32 is expected makes deserialization fail; the breaker is disabled.
        let json = serde_json::json!({
            "circuit_breaker": {
                "enabled": true,
                "failure_threshold": "five"
            }
        });
        assert!(FederationCircuitBreakerManager::from_schema_json(&json).is_none());
    }

    #[test]
    fn test_manager_from_schema_json_enabled() {
        // Uses the `per_database` alias for the overrides list.
        let json = serde_json::json!({
            "circuit_breaker": {
                "enabled": true,
                "failure_threshold": 3,
                "recovery_timeout_secs": 30,
                "success_threshold": 2,
                "per_database": []
            }
        });
        let manager = FederationCircuitBreakerManager::from_schema_json(&json).unwrap();
        assert_eq!(manager.default_config.failure_threshold, 3);
    }

    #[test]
    fn test_manager_from_schema_json_per_entity_new_key() {
        // `per_entity` / `entity_type` keys; top-level thresholds fall back to built-in defaults.
        let json = serde_json::json!({
            "circuit_breaker": {
                "enabled": true,
                "per_entity": [
                    { "entity_type": "Product", "failure_threshold": 2 }
                ]
            }
        });
        let manager = FederationCircuitBreakerManager::from_schema_json(&json).unwrap();
        manager.record_failure("Product");
        manager.record_failure("Product");
        assert!(manager.check("Product").is_some());
    }

    #[test]
    fn test_manager_from_schema_json_per_entity_override() {
        // Legacy `per_database` / `database` keys; Product overrides the threshold, User does not.
        let json = serde_json::json!({
            "circuit_breaker": {
                "enabled": true,
                "failure_threshold": 5,
                "recovery_timeout_secs": 30,
                "success_threshold": 2,
                "per_database": [
                    {
                        "database": "Product",
                        "failure_threshold": 2
                    }
                ]
            }
        });
        let manager = FederationCircuitBreakerManager::from_schema_json(&json).unwrap();
        manager.record_failure("Product");
        manager.record_failure("Product");
        assert!(manager.check("Product").is_some());
        // User uses the default threshold of 5, so one failure leaves it closed.
        manager.record_failure("User");
        assert!(manager.check("User").is_none());
    }

    #[test]
    fn test_manager_pre_seeds_overridden_entities() {
        let json = serde_json::json!({
            "circuit_breaker": {
                "enabled": true,
                "per_entity": [
                    { "entity_type": "Product", "failure_threshold": 2 }
                ]
            }
        });
        let manager = FederationCircuitBreakerManager::from_schema_json(&json).unwrap();
        // No traffic has been recorded; Product must exist purely from pre-seeding.
        let states = manager.collect_states();
        assert!(
            states.iter().any(|(e, _)| e == "Product"),
            "Product should be pre-seeded in the breakers map"
        );
    }

    #[test]
    fn test_manager_collect_states() {
        let json = serde_json::json!({
            "circuit_breaker": {
                "enabled": true,
                "failure_threshold": 1,
                "recovery_timeout_secs": 60,
                "success_threshold": 1,
                "per_database": []
            }
        });
        let manager = FederationCircuitBreakerManager::from_schema_json(&json).unwrap();
        // Creates the Product breaker lazily and trips it (threshold 1).
        manager.record_failure("Product");
        let states = manager.collect_states();
        let product_state = states.iter().find(|(e, _)| e == "Product").map(|(_, s)| *s);
        assert_eq!(product_state, Some(STATE_OPEN));
    }

    #[test]
    fn test_concurrent_failures_no_spurious_open() {
        use std::{sync::Arc as StdArc, thread};
        let config = CircuitBreakerConfig {
            failure_threshold: 10,
            recovery_timeout_secs: 60,
            success_threshold: 2,
        };
        let breaker = StdArc::new(EntityCircuitBreaker::new(config));
        // 8 threads each record one failure. The state mutex serializes the increments,
        // so the count ends at exactly 8, below the threshold of 10.
        let handles: Vec<_> = (0..8)
            .map(|_| {
                let b = StdArc::clone(&breaker);
                thread::spawn(move || b.record_failure())
            })
            .collect();
        for handle in handles {
            handle.join().expect("thread panicked");
        }
        assert!(breaker.check().is_none(), "circuit should remain closed after 8 < 10 failures");
        assert_eq!(breaker.state_code(), STATE_CLOSED);
    }
}