use std::{
sync::Arc,
time::{Duration, Instant},
};
use dashmap::DashMap;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use tracing::{info, warn};
/// Numeric code reported by `EntityCircuitBreaker::state_code` for a closed circuit
/// (requests are allowed through).
pub const STATE_CLOSED: u64 = 0;
/// Numeric code reported by `EntityCircuitBreaker::state_code` for an open circuit
/// (requests are rejected until the recovery timeout elapses).
pub const STATE_OPEN: u64 = 1;
/// Numeric code reported by `EntityCircuitBreaker::state_code` for a half-open circuit
/// (probe requests are admitted one at a time).
pub const STATE_HALF_OPEN: u64 = 2;
/// Circuit state as exposed in health reports.
///
/// Serialized in `snake_case`: `"closed"`, `"open"`, `"half_open"`. Derives `PartialEq`/`Eq`
/// so callers can compare a reported state against an expected one.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum CircuitHealthState {
    /// Requests are allowed through.
    Closed,
    /// Requests are rejected until the recovery timeout elapses.
    Open,
    /// Probe requests are admitted one at a time to test for recovery.
    HalfOpen,
}
/// One entry of the circuit-breaker health snapshot.
#[derive(Debug, Clone, Serialize)]
pub struct SubgraphCircuitHealth {
    /// Key the breaker is registered under. The manager keys breakers by entity type name,
    /// despite the field's name.
    pub subgraph: String,
    /// Current state of that breaker.
    pub state: CircuitHealthState,
}
/// Internal state machine of a single breaker.
#[derive(Debug)]
enum CircuitState {
    /// Normal operation; `consecutive_failures` counts the current failure run.
    Closed { consecutive_failures: u32 },
    /// Tripped; requests are rejected until `opened_at + recovery_timeout`.
    Open {
        opened_at: Instant,
        recovery_timeout: Duration,
    },
    /// Testing recovery: at most one probe is admitted at a time (`probe_in_flight`),
    /// `successes` counts successful probes, and `consecutive_failures` counts failures that
    /// can re-open the circuit.
    HalfOpen {
        consecutive_failures: u32,
        probe_in_flight: bool,
        successes: u32,
    },
}
/// Thresholds governing a single circuit breaker.
///
/// Derives `PartialEq`/`Eq` so configurations (for example, a resolved per-entity override
/// versus the default) can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CircuitBreakerConfig {
    /// Failures in a row (in `Closed` or `HalfOpen`) that open the circuit.
    pub failure_threshold: u32,
    /// Seconds an open circuit waits before admitting a half-open probe.
    pub recovery_timeout_secs: u64,
    /// Successful half-open probes required to close the circuit again.
    pub success_threshold: u32,
}
impl Default for CircuitBreakerConfig {
    /// Defaults: open after 5 failures, wait 30 s before probing, and require 2 successful
    /// probes to close again.
    fn default() -> Self {
        let failure_threshold = 5;
        let recovery_timeout_secs = 30;
        let success_threshold = 2;
        Self {
            failure_threshold,
            recovery_timeout_secs,
            success_threshold,
        }
    }
}
/// A single circuit breaker guarding calls for one entity type.
///
/// All state transitions happen under one mutex, so each `check`/`record_*` call observes and
/// updates the state atomically.
pub(crate) struct EntityCircuitBreaker {
    /// Thresholds this breaker was created with.
    pub(crate) config: CircuitBreakerConfig,
    /// Current position in the Closed → Open → HalfOpen state machine.
    state: Mutex<CircuitState>,
}
impl EntityCircuitBreaker {
pub(crate) const fn new(config: CircuitBreakerConfig) -> Self {
Self {
config,
state: Mutex::new(CircuitState::Closed {
consecutive_failures: 0,
}),
}
}
pub(crate) fn check(&self) -> Option<u64> {
let mut state = self.state.lock();
match &*state {
CircuitState::Closed { .. } => return None,
CircuitState::Open {
opened_at,
recovery_timeout,
} => {
if opened_at.elapsed() < *recovery_timeout {
return Some(self.config.recovery_timeout_secs);
}
},
CircuitState::HalfOpen {
probe_in_flight: true,
..
} => {
return Some(self.config.recovery_timeout_secs);
},
CircuitState::HalfOpen {
probe_in_flight: false,
..
} => {
},
}
match &mut *state {
CircuitState::Open { .. } => {
*state = CircuitState::HalfOpen {
consecutive_failures: 0,
probe_in_flight: true,
successes: 0,
};
},
CircuitState::HalfOpen {
probe_in_flight, ..
} => {
*probe_in_flight = true;
},
CircuitState::Closed { .. } => {
},
}
None
}
pub(crate) fn record_success(&self) {
let mut state = self.state.lock();
let new_successes = match &*state {
CircuitState::HalfOpen { successes, .. } => *successes + 1,
CircuitState::Closed { .. } | CircuitState::Open { .. } => return,
};
if new_successes >= self.config.success_threshold {
*state = CircuitState::Closed {
consecutive_failures: 0,
};
info!("Federation circuit breaker closed after successful recovery");
} else if let CircuitState::HalfOpen {
successes,
probe_in_flight,
..
} = &mut *state
{
*successes = new_successes;
*probe_in_flight = false;
}
}
pub(crate) fn record_failure(&self) {
let mut state = self.state.lock();
let new_count = match &*state {
CircuitState::Open { .. } => return,
CircuitState::Closed {
consecutive_failures,
}
| CircuitState::HalfOpen {
consecutive_failures,
..
} => *consecutive_failures + 1,
};
if new_count >= self.config.failure_threshold {
let from_half_open = matches!(*state, CircuitState::HalfOpen { .. });
*state = CircuitState::Open {
opened_at: Instant::now(),
recovery_timeout: Duration::from_secs(self.config.recovery_timeout_secs),
};
if from_half_open {
info!(
consecutive_failures = new_count,
recovery_timeout_secs = self.config.recovery_timeout_secs,
"Federation circuit breaker re-opened from HalfOpen"
);
} else {
info!(
consecutive_failures = new_count,
recovery_timeout_secs = self.config.recovery_timeout_secs,
"Federation circuit breaker opened"
);
}
} else {
match &mut *state {
CircuitState::Closed {
consecutive_failures,
} => {
*consecutive_failures = new_count;
},
CircuitState::HalfOpen {
consecutive_failures,
probe_in_flight,
..
} => {
*consecutive_failures = new_count;
*probe_in_flight = false;
},
CircuitState::Open { .. } => {
},
}
}
}
pub(crate) fn state_code(&self) -> u64 {
let state = self.state.lock();
match &*state {
CircuitState::Closed { .. } => STATE_CLOSED,
CircuitState::Open { .. } => STATE_OPEN,
CircuitState::HalfOpen { .. } => STATE_HALF_OPEN,
}
}
pub(crate) fn state_for_health(&self) -> CircuitHealthState {
let state = self.state.lock();
match &*state {
CircuitState::Closed { .. } => CircuitHealthState::Closed,
CircuitState::Open { .. } => CircuitHealthState::Open,
CircuitState::HalfOpen { .. } => CircuitHealthState::HalfOpen,
}
}
}
/// Wire shape of the `circuit_breaker` object in the compiled schema JSON.
///
/// Missing thresholds fall back to defaults in `from_schema_json`.
#[derive(Deserialize, Debug)]
struct CircuitBreakerJson {
    /// Whether the breaker is active; absent means `false`.
    #[serde(default)]
    enabled: bool,
    failure_threshold: Option<u32>,
    recovery_timeout_secs: Option<u64>,
    success_threshold: Option<u32>,
    /// Per-entity overrides; the legacy key `per_database` is also accepted.
    #[serde(default, alias = "per_database")]
    per_entity: Vec<PerEntityJson>,
}
/// One per-entity override entry; unset fields inherit the manager's defaults.
#[derive(Deserialize, Debug)]
struct PerEntityJson {
    /// Entity type name; the key `database` is also accepted.
    #[serde(alias = "database")]
    entity_type: String,
    failure_threshold: Option<u32>,
    recovery_timeout_secs: Option<u64>,
    success_threshold: Option<u32>,
}
/// Registry of per-entity circuit breakers for federated entity resolution.
///
/// Breakers are created lazily on first use, and eagerly for every entity that has an
/// override. Nothing in this module removes them.
pub struct FederationCircuitBreakerManager {
    /// One breaker per entity type name.
    breakers: DashMap<String, Arc<EntityCircuitBreaker>>,
    /// Configuration for entities without an override.
    pub(crate) default_config: CircuitBreakerConfig,
    /// Resolved per-entity overrides, keyed by entity type name.
    per_entity_config: DashMap<String, CircuitBreakerConfig>,
}
impl FederationCircuitBreakerManager {
fn new(default_config: CircuitBreakerConfig) -> Self {
Self {
breakers: DashMap::new(),
default_config,
per_entity_config: DashMap::new(),
}
}
#[must_use]
pub fn from_config(fed: &fraiseql_core::schema::FederationConfig) -> Option<Arc<Self>> {
let cb = fed.circuit_breaker.as_ref()?;
if !cb.enabled {
return None;
}
let default_config = CircuitBreakerConfig {
failure_threshold: cb.failure_threshold,
recovery_timeout_secs: cb.recovery_timeout_secs,
success_threshold: cb.success_threshold,
};
let manager = Arc::new(Self::new(default_config));
for override_entry in &cb.per_entity {
let entity_config = CircuitBreakerConfig {
failure_threshold: override_entry
.failure_threshold
.unwrap_or(manager.default_config.failure_threshold),
recovery_timeout_secs: override_entry
.recovery_timeout
.unwrap_or(manager.default_config.recovery_timeout_secs),
success_threshold: override_entry
.success_threshold
.unwrap_or(manager.default_config.success_threshold),
};
manager.per_entity_config.insert(override_entry.entity.clone(), entity_config);
}
let override_keys: Vec<String> =
manager.per_entity_config.iter().map(|r| r.key().clone()).collect();
for entity_type in override_keys {
manager.get_or_create(&entity_type);
}
info!(
failure_threshold = manager.default_config.failure_threshold,
recovery_timeout_secs = manager.default_config.recovery_timeout_secs,
success_threshold = manager.default_config.success_threshold,
per_entity_overrides = manager.per_entity_config.len(),
"Federation circuit breaker initialized"
);
Some(manager)
}
#[must_use]
pub fn from_schema_json(federation_json: &serde_json::Value) -> Option<Arc<Self>> {
let cb_json: CircuitBreakerJson = match federation_json.get("circuit_breaker") {
None => return None,
Some(v) => match serde_json::from_value(v.clone()) {
Ok(j) => j,
Err(e) => {
warn!(
error = %e,
"circuit_breaker config present but malformed — circuit breaker disabled"
);
return None;
},
},
};
if !cb_json.enabled {
return None;
}
let default_config = CircuitBreakerConfig {
failure_threshold: cb_json.failure_threshold.unwrap_or(5),
recovery_timeout_secs: cb_json.recovery_timeout_secs.unwrap_or(30),
success_threshold: cb_json.success_threshold.unwrap_or(2),
};
let manager = Arc::new(Self::new(default_config));
for override_entry in cb_json.per_entity {
let entity_config = CircuitBreakerConfig {
failure_threshold: override_entry
.failure_threshold
.unwrap_or(manager.default_config.failure_threshold),
recovery_timeout_secs: override_entry
.recovery_timeout_secs
.unwrap_or(manager.default_config.recovery_timeout_secs),
success_threshold: override_entry
.success_threshold
.unwrap_or(manager.default_config.success_threshold),
};
manager.per_entity_config.insert(override_entry.entity_type, entity_config);
}
let override_keys: Vec<String> =
manager.per_entity_config.iter().map(|r| r.key().clone()).collect();
for entity_type in override_keys {
manager.get_or_create(&entity_type);
}
info!(
failure_threshold = manager.default_config.failure_threshold,
recovery_timeout_secs = manager.default_config.recovery_timeout_secs,
success_threshold = manager.default_config.success_threshold,
per_entity_overrides = manager.per_entity_config.len(),
"Federation circuit breaker initialized"
);
Some(manager)
}
fn get_or_create(&self, entity: &str) -> Arc<EntityCircuitBreaker> {
self.breakers
.entry(entity.to_string())
.or_insert_with(|| {
let config = self
.per_entity_config
.get(entity)
.map_or_else(|| self.default_config.clone(), |r| r.value().clone());
Arc::new(EntityCircuitBreaker::new(config))
})
.clone()
}
#[must_use]
pub fn check(&self, entity: &str) -> Option<u64> {
self.get_or_create(entity).check()
}
pub fn record_success(&self, entity: &str) {
self.get_or_create(entity).record_success();
}
pub fn record_failure(&self, entity: &str) {
self.get_or_create(entity).record_failure();
}
#[must_use]
pub fn collect_states(&self) -> Vec<(String, u64)> {
self.breakers
.iter()
.map(|entry| (entry.key().clone(), entry.value().state_code()))
.collect()
}
#[must_use]
pub fn health_snapshot(&self) -> Vec<SubgraphCircuitHealth> {
self.breakers
.iter()
.map(|entry| SubgraphCircuitHealth {
subgraph: entry.key().clone(),
state: entry.value().state_for_health(),
})
.collect()
}
}
/// Collects the distinct `__typename` values from the `representations` array of a
/// federation request's variables.
///
/// Returns an empty list when `variables` is `None` or has no `representations` array.
/// Representations without a string `__typename` are skipped, with a warning for each one.
/// The result is sorted and deduplicated.
#[must_use]
pub fn extract_entity_types(variables: Option<&serde_json::Value>) -> Vec<String> {
    let representations = match variables
        .and_then(|vars| vars.get("representations"))
        .and_then(|reps| reps.as_array())
    {
        Some(reps) => reps,
        None => return Vec::new(),
    };

    let mut names: Vec<String> = representations
        .iter()
        .filter_map(|rep| {
            let name = rep.get("__typename").and_then(|t| t.as_str());
            if name.is_none() {
                warn!(
                    "Federation representation missing __typename field; entity skipped for circuit \
                     breaker"
                );
            }
            name.map(str::to_owned)
        })
        .collect();

    names.sort_unstable();
    names.dedup();
    names
}