use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{mpsc, watch, RwLock};
use tokio::task::JoinHandle;
use tracing::{error, info, warn};
use super::accumulator::{health_channel, shutdown_signal, AccumulatorHealth};
use super::reactor::{
reactor_health_channel, CompiledGraphFn, InputStrategy, ReactionCriteria, Reactor,
ReactorHandle,
};
use super::registry::{AccumulatorAuthPolicy, EndpointRegistry, ReactorAuthPolicy};
use super::types::{GraphResult, InputCache, SourceName};
/// Everything the scheduler needs to load one computation graph: its accumulators, its
/// reactor configuration and graph function, and ownership metadata.
#[derive(Clone)]
pub struct ComputationGraphDeclaration {
    /// Graph name. Key of the scheduler's graph index; `load_graph` also passes it as a
    /// registry alias when it has to load the reactor itself.
    pub name: String,
    /// Accumulators that feed the reactor.
    pub accumulators: Vec<AccumulatorDeclaration>,
    /// Reaction criteria, input strategy and the graph function to run.
    pub reactor: ReactorDeclaration,
    /// Owning tenant. `Some` scopes the registry auth policies to this tenant; `None`
    /// installs allow-all policies.
    pub tenant_id: Option<String>,
    /// Reactor this graph runs on. `None` makes `load_graph` use a dedicated reactor named
    /// `__Reactor_{name}`.
    pub reactor_name: Option<String>,
}
/// A named accumulator plus the factory that spawns its task.
#[derive(Clone)]
pub struct AccumulatorDeclaration {
    /// Accumulator name: the endpoint-registry key and the reactor's `SourceName` for it.
    pub name: String,
    /// Factory invoked on load and again whenever supervision restarts the accumulator.
    pub factory: Arc<dyn AccumulatorFactory>,
}
/// Per-spawn configuration handed to [`AccumulatorFactory::spawn`].
pub struct AccumulatorSpawnConfig {
    /// The scheduler's DAL; `None` when the scheduler was built without one.
    pub dal: Option<crate::dal::unified::DAL>,
    /// Sender for this accumulator's health; the scheduler registers the paired receiver.
    pub health_tx: Option<watch::Sender<AccumulatorHealth>>,
    /// Name of the owning reactor (the scheduler passes the reactor name here).
    pub graph_name: String,
}
/// Spawns accumulator tasks on behalf of the scheduler.
///
/// `spawn` is called once per accumulator when its reactor is loaded and again each time
/// supervision restarts that accumulator.
pub trait AccumulatorFactory: Send + Sync {
    /// Spawns the accumulator called `name`.
    ///
    /// * `boundary_tx` — channel into the owning reactor, carrying `(SourceName, bytes)` pairs.
    /// * `shutdown_rx` — the scheduler sends `true` on the paired sender when it unloads the
    ///   reactor; implementations should stop once they observe it.
    /// * `config` — DAL, health sender and owning reactor name.
    ///
    /// Returns the sender the endpoint registry routes incoming events to, and the task handle
    /// the scheduler polls for crashes.
    fn spawn(
        &self,
        name: String,
        boundary_tx: mpsc::Sender<(SourceName, Vec<u8>)>,
        shutdown_rx: watch::Receiver<bool>,
        config: AccumulatorSpawnConfig,
    ) -> (mpsc::Sender<Vec<u8>>, JoinHandle<()>);
}
/// Reactor configuration plus the graph function it should run.
#[derive(Clone)]
pub struct ReactorDeclaration {
    /// When the reactor fires.
    pub criteria: ReactionCriteria,
    /// How buffered inputs are presented to the graph.
    pub strategy: InputStrategy,
    /// The compiled graph. The scheduler binds it as a subscriber of the reactor.
    pub graph_fn: CompiledGraphFn,
}
/// Snapshot of one bound graph, as reported by `list_graphs`.
#[derive(Debug, Clone)]
pub struct GraphStatus {
    /// Graph name.
    pub name: String,
    /// Accumulators the graph's reactor is currently tracking.
    pub accumulators: Vec<String>,
    /// Whether the reactor is paused.
    pub paused: bool,
    /// `true` while the reactor task has not finished.
    pub running: bool,
    /// Latest reactor health, if a health channel is attached.
    pub health: Option<super::reactor::ReactorHealth>,
}
/// Verifies that `new` describes the same reactor contract as the already-loaded `existing`.
///
/// The contract consists of:
/// - the accumulator names, compared as a multiset so declaration order does not matter;
/// - the reaction criteria;
/// - the input strategy;
/// - the owning tenant.
///
/// Graph names and graph functions are not part of the contract.
///
/// # Errors
/// Returns a short description of the first difference found.
fn check_reactor_contract_matches(
    existing: &ComputationGraphDeclaration,
    new: &ComputationGraphDeclaration,
) -> Result<(), String> {
    let existing_accs: Vec<&str> = existing
        .accumulators
        .iter()
        .map(|a| a.name.as_str())
        .collect();
    let new_accs: Vec<&str> = new.accumulators.iter().map(|a| a.name.as_str()).collect();
    // Declaration order is incidental: `load_graph_split` forwards accumulators in whatever
    // order the caller supplied them. Compare sorted copies so the same set in a different
    // order is accepted, while differing names or duplicate counts are still rejected.
    let mut existing_sorted = existing_accs.clone();
    existing_sorted.sort_unstable();
    let mut new_sorted = new_accs.clone();
    new_sorted.sort_unstable();
    if existing_sorted != new_sorted {
        return Err(format!(
            "accumulator set differs (existing: {:?}, new: {:?})",
            existing_accs, new_accs
        ));
    }
    if existing.reactor.criteria != new.reactor.criteria {
        return Err("reaction criteria differ".to_string());
    }
    if existing.reactor.strategy != new.reactor.strategy {
        return Err("input strategy differs".to_string());
    }
    if existing.tenant_id != new.tenant_id {
        return Err(format!(
            "tenant ownership differs (existing: {:?}, new: {:?})",
            existing.tenant_id, new.tenant_id
        ));
    }
    Ok(())
}
/// Placeholder graph function that completes immediately with no outputs.
///
/// Used wherever a `ReactorDeclaration` is required but the real graph functions live in the
/// reactor's subscriber map, e.g. contract probes and the declaration stored per reactor.
fn dummy_graph_fn() -> CompiledGraphFn {
    Arc::new(|_cache: InputCache| Box::pin(async move { GraphResult::completed(vec![]) }))
}
/// Graph functions bound to one reactor, keyed by graph name.
///
/// Written by bind/unbind on the scheduler; read by the reactor's dispatcher on every firing.
type ReactorSubscribers = Arc<RwLock<HashMap<String, CompiledGraphFn>>>;
fn make_subscriber_dispatcher(
reactor_name: String,
subscribers: ReactorSubscribers,
) -> CompiledGraphFn {
Arc::new(move |cache: InputCache| {
let reactor_name = reactor_name.clone();
let subscribers = subscribers.clone();
Box::pin(async move {
let snapshot: Vec<(String, CompiledGraphFn)> = subscribers
.read()
.await
.iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect();
let futures = snapshot.into_iter().map(|(graph_name, graph_fn)| {
let cache = cache.clone();
async move {
let result = graph_fn(cache).await;
(graph_name, result)
}
});
let results = futures::future::join_all(futures).await;
for (graph_name, result) in results {
if let GraphResult::Error(e) = result {
tracing::error!(
reactor = %reactor_name,
graph = %graph_name,
"subscriber graph failed: {}",
e
);
}
}
GraphResult::completed(vec![])
})
})
}
/// Live state of one loaded reactor and the accumulators that feed it.
struct RunningGraph {
    /// Sending `true` asks the reactor and its accumulators to stop (sent on unload).
    shutdown_tx: watch::Sender<bool>,
    /// Kept so individually restarted accumulators join the same shutdown signal.
    shutdown_rx: watch::Receiver<bool>,
    /// Kept so individually restarted accumulators feed the same reactor.
    boundary_tx: mpsc::Sender<(SourceName, Vec<u8>)>,
    /// `(accumulator name, task)` pairs. An entry is dropped when its circuit breaker opens.
    accumulator_handles: Vec<(String, JoinHandle<()>)>,
    /// The reactor task; supervision treats a finished task as a crash.
    reactor_handle: JoinHandle<()>,
    /// Shared reactor handle (pause state); also registered with the endpoint registry.
    reactor_shared: ReactorHandle,
    /// Latest reactor health, read by `list_graphs`.
    reactor_health_rx: Option<watch::Receiver<super::reactor::ReactorHealth>>,
    /// The reactor's declaration (placeholder graph function); used for contract checks
    /// and respawns.
    declaration: ComputationGraphDeclaration,
    /// Graphs currently bound to this reactor.
    subscribers: ReactorSubscribers,
    /// Every reactor-endpoint key registered for this reactor (its name plus aliases).
    endpoint_registry_keys: Vec<String>,
    /// Crash counts keyed `"{reactor}::reactor"` or `"{reactor}::{accumulator}"`.
    failure_counts: HashMap<String, u32>,
    /// Instant of each component's most recent supervised restart, same keys as above.
    last_success: HashMap<String, std::time::Instant>,
}
/// Restarts allowed per component before its circuit breaker opens.
const MAX_RECOVERY_ATTEMPTS: u32 = 5;
/// Backoff before the first restart; doubles with each subsequent attempt.
const BACKOFF_BASE_SECS: u64 = 1;
/// Upper bound on a single restart backoff.
const BACKOFF_MAX_SECS: u64 = 60;
/// Age of a component's last restart after which supervision clears its failure history.
const SUCCESS_RESET_SECS: u64 = 60;
/// Loads, supervises and unloads reactors and the computation graphs bound to them.
pub struct ComputationGraphScheduler {
    /// Endpoint registry where accumulator sockets, reactor endpoints, health receivers and
    /// auth policies are published.
    registry: EndpointRegistry,
    /// Loaded reactors keyed by reactor name.
    reactors: Arc<RwLock<HashMap<String, RunningGraph>>>,
    /// Graph name → name of the reactor it is bound to.
    graph_to_reactor: Arc<RwLock<HashMap<String, String>>>,
    /// Optional persistence handed to accumulators/reactors and used for recovery events.
    dal: Option<crate::dal::unified::DAL>,
}
impl ComputationGraphScheduler {
pub fn new(registry: EndpointRegistry) -> Self {
Self {
registry,
reactors: Arc::new(RwLock::new(HashMap::new())),
graph_to_reactor: Arc::new(RwLock::new(HashMap::new())),
dal: None,
}
}
pub fn with_dal(registry: EndpointRegistry, dal: crate::dal::unified::DAL) -> Self {
Self {
registry,
reactors: Arc::new(RwLock::new(HashMap::new())),
graph_to_reactor: Arc::new(RwLock::new(HashMap::new())),
dal: Some(dal),
}
}
pub async fn load_reactor(
&self,
reactor_name: String,
accumulators: Vec<AccumulatorDeclaration>,
criteria: ReactionCriteria,
strategy: InputStrategy,
tenant_id: Option<String>,
register_aliases: Vec<String>,
) -> Result<(), String> {
{
let reactors = self.reactors.read().await;
if let Some(existing) = reactors.get(&reactor_name) {
let probe = ComputationGraphDeclaration {
name: reactor_name.clone(),
accumulators: accumulators.clone(),
reactor: ReactorDeclaration {
criteria: criteria.clone(),
strategy: strategy.clone(),
graph_fn: dummy_graph_fn(),
},
tenant_id: tenant_id.clone(),
reactor_name: Some(reactor_name.clone()),
};
if let Err(e) = check_reactor_contract_matches(&existing.declaration, &probe) {
return Err(format!(
"reactor '{}' is already loaded with a different contract: {}",
reactor_name, e
));
}
return Ok(());
}
}
let (shutdown_tx, shutdown_rx) = shutdown_signal();
let stored_shutdown_rx = shutdown_rx.clone();
let (boundary_tx, boundary_rx) = mpsc::channel(256);
let stored_boundary_tx = boundary_tx.clone();
let expected_sources: Vec<SourceName> = accumulators
.iter()
.map(|a| SourceName::new(&a.name))
.collect();
let mut accumulator_handles = Vec::new();
let mut acc_health_rxs: Vec<(
String,
watch::Receiver<super::accumulator::AccumulatorHealth>,
)> = Vec::new();
for acc_decl in &accumulators {
let (health_tx, health_rx) = health_channel();
acc_health_rxs.push((acc_decl.name.clone(), health_rx.clone()));
let spawn_config = AccumulatorSpawnConfig {
dal: self.dal.clone(),
health_tx: Some(health_tx),
graph_name: reactor_name.clone(),
};
let (socket_tx, handle) = acc_decl.factory.spawn(
acc_decl.name.clone(),
boundary_tx.clone(),
shutdown_rx.clone(),
spawn_config,
);
self.registry
.register_accumulator(acc_decl.name.clone(), socket_tx)
.await;
self.registry
.register_accumulator_health(acc_decl.name.clone(), health_rx)
.await;
accumulator_handles.push((acc_decl.name.clone(), handle));
}
let (manual_tx, manual_rx) = mpsc::channel(64);
let (reactor_health_tx, reactor_health_rx) = reactor_health_channel();
let subscribers: ReactorSubscribers = Arc::new(RwLock::new(HashMap::new()));
let dispatcher = make_subscriber_dispatcher(reactor_name.clone(), subscribers.clone());
let mut reactor = Reactor::new(
dispatcher,
criteria.clone(),
strategy.clone(),
boundary_rx,
manual_rx,
shutdown_rx,
)
.with_graph_name(reactor_name.clone())
.with_health(reactor_health_tx)
.with_expected_sources(expected_sources)
.with_accumulator_health(acc_health_rxs);
if let Some(ref dal) = self.dal {
reactor = reactor.with_dal(dal.clone());
}
let reactor_shared = reactor.handle();
let mut endpoint_registry_keys = vec![reactor_name.clone()];
self.registry
.register_reactor(
reactor_name.clone(),
manual_tx.clone(),
reactor_shared.clone(),
)
.await;
for alias in ®ister_aliases {
if alias != &reactor_name {
self.registry
.register_reactor(alias.clone(), manual_tx.clone(), reactor_shared.clone())
.await;
endpoint_registry_keys.push(alias.clone());
}
}
let acc_policy = match &tenant_id {
Some(tid) => AccumulatorAuthPolicy::for_tenant(tid),
None => AccumulatorAuthPolicy::allow_all(),
};
let reactor_policy = match &tenant_id {
Some(tid) => ReactorAuthPolicy::for_tenant(tid),
None => ReactorAuthPolicy::allow_all(),
};
for acc_decl in &accumulators {
self.registry
.set_accumulator_policy(acc_decl.name.clone(), acc_policy.clone())
.await;
}
for key in &endpoint_registry_keys {
self.registry
.set_reactor_policy(key.clone(), reactor_policy.clone())
.await;
}
let reactor_handle = tokio::spawn(reactor.run());
info!(reactor = %reactor_name, "reactor loaded and running");
let anchor = ComputationGraphDeclaration {
name: reactor_name.clone(),
accumulators,
reactor: ReactorDeclaration {
criteria,
strategy,
graph_fn: dummy_graph_fn(),
},
tenant_id,
reactor_name: Some(reactor_name.clone()),
};
let running = RunningGraph {
shutdown_tx,
shutdown_rx: stored_shutdown_rx,
boundary_tx: stored_boundary_tx,
accumulator_handles,
reactor_handle,
reactor_shared,
reactor_health_rx: Some(reactor_health_rx),
declaration: anchor,
subscribers,
endpoint_registry_keys,
failure_counts: HashMap::new(),
last_success: HashMap::new(),
};
self.reactors.write().await.insert(reactor_name, running);
Ok(())
}
pub async fn bind_graph_to_reactor(
&self,
graph_name: String,
reactor_name: String,
graph_fn: CompiledGraphFn,
) -> Result<(), String> {
{
let g2r = self.graph_to_reactor.read().await;
if g2r.contains_key(&graph_name) {
return Err(format!("graph '{}' already loaded", graph_name));
}
}
{
let reactors = self.reactors.read().await;
let existing = reactors
.get(&reactor_name)
.ok_or_else(|| format!("reactor '{}' is not loaded", reactor_name))?;
existing
.subscribers
.write()
.await
.insert(graph_name.clone(), graph_fn);
}
self.graph_to_reactor
.write()
.await
.insert(graph_name.clone(), reactor_name.clone());
info!(
graph = %graph_name,
reactor = %reactor_name,
"graph bound to reactor"
);
Ok(())
}
pub async fn load_graph(&self, decl: ComputationGraphDeclaration) -> Result<(), String> {
let name = decl.name.clone();
let reactor_name = decl
.reactor_name
.clone()
.unwrap_or_else(|| format!("__Reactor_{}", name));
{
let g2r = self.graph_to_reactor.read().await;
if g2r.contains_key(&name) {
return Err(format!("graph '{}' already loaded", name));
}
}
if decl.reactor_name.is_some() && decl.accumulators.is_empty() {
let already_loaded = {
let reactors = self.reactors.read().await;
reactors.contains_key(&reactor_name)
};
if already_loaded {
return self
.bind_graph_to_reactor(name, reactor_name, decl.reactor.graph_fn)
.await;
}
}
self.load_reactor(
reactor_name.clone(),
decl.accumulators.clone(),
decl.reactor.criteria.clone(),
decl.reactor.strategy.clone(),
decl.tenant_id.clone(),
vec![name.clone()],
)
.await?;
self.bind_graph_to_reactor(name, reactor_name, decl.reactor.graph_fn)
.await
}
pub async fn load_graph_split(
&self,
graph_name: String,
graph_fn: CompiledGraphFn,
reactor: &cloacina_computation_graph::ReactorRegistration,
accumulators: Vec<AccumulatorDeclaration>,
tenant_id: Option<String>,
) -> Result<(), String> {
let supplied: std::collections::HashSet<&str> =
accumulators.iter().map(|a| a.name.as_str()).collect();
for name in &reactor.accumulator_names {
if !supplied.contains(name.as_str()) {
return Err(format!(
"reactor '{}' declares accumulator '{}' but no AccumulatorDeclaration was \
supplied for it",
reactor.name, name
));
}
}
let decl = ComputationGraphDeclaration {
name: graph_name,
accumulators,
reactor: ReactorDeclaration {
criteria: reactor.reaction_mode.into(),
strategy: InputStrategy::Latest,
graph_fn,
},
tenant_id,
reactor_name: Some(reactor.name.clone()),
};
self.load_graph(decl).await
}
pub async fn unbind_graph_from_reactor(&self, name: &str) -> Result<String, String> {
let reactor_name = {
let mut g2r = self.graph_to_reactor.write().await;
g2r.remove(name)
.ok_or_else(|| format!("graph '{}' not loaded", name))?
};
let remaining = {
let reactors = self.reactors.read().await;
if let Some(running) = reactors.get(&reactor_name) {
let mut subs = running.subscribers.write().await;
subs.remove(name);
subs.len()
} else {
return Err(format!(
"graph '{}' was bound to reactor '{}' but the reactor is not loaded",
name, reactor_name
));
}
};
info!(
graph = %name,
reactor = %reactor_name,
remaining_subscribers = remaining,
"graph unbound from reactor"
);
Ok(reactor_name)
}
pub async fn unload_reactor(&self, reactor_name: &str) -> Result<(), String> {
let subscriber_names: Vec<String> = {
let reactors = self.reactors.read().await;
match reactors.get(reactor_name) {
Some(running) => running.subscribers.read().await.keys().cloned().collect(),
None => return Err(format!("reactor '{}' not loaded", reactor_name)),
}
};
if !subscriber_names.is_empty() {
return Err(format!(
"reactor '{}' has {} bound subscriber(s): {:?}; unbind them first",
reactor_name,
subscriber_names.len(),
subscriber_names
));
}
let running = {
let mut reactors = self.reactors.write().await;
reactors
.remove(reactor_name)
.ok_or_else(|| format!("reactor '{}' not loaded", reactor_name))?
};
let _ = running.shutdown_tx.send(true);
let _ =
tokio::time::timeout(std::time::Duration::from_secs(5), running.reactor_handle).await;
for (acc_name, handle) in running.accumulator_handles {
let _ = tokio::time::timeout(std::time::Duration::from_secs(5), handle).await;
self.registry.deregister_accumulator(&acc_name).await;
}
for key in &running.endpoint_registry_keys {
self.registry.deregister_reactor(key).await;
}
info!(reactor = %reactor_name, "reactor unloaded");
Ok(())
}
pub async fn unload_graph(&self, name: &str) -> Result<(), String> {
let reactor_name = self.unbind_graph_from_reactor(name).await?;
let now_empty = {
let reactors = self.reactors.read().await;
match reactors.get(&reactor_name) {
Some(running) => running.subscribers.read().await.is_empty(),
None => false,
}
};
if now_empty {
self.unload_reactor(&reactor_name).await?;
}
info!(graph = %name, reactor = %reactor_name, "computation graph unloaded");
Ok(())
}
pub async fn reactor_accumulator_names(&self, reactor_name: &str) -> Option<Vec<String>> {
let reactors = self.reactors.read().await;
reactors.get(reactor_name).map(|running| {
running
.accumulator_handles
.iter()
.map(|(n, _)| n.clone())
.collect()
})
}
pub async fn list_graphs(&self) -> Vec<GraphStatus> {
let g2r = self.graph_to_reactor.read().await;
let reactors = self.reactors.read().await;
g2r.iter()
.filter_map(|(graph_name, reactor_name)| {
reactors.get(reactor_name).map(|running| GraphStatus {
name: graph_name.clone(),
accumulators: running
.accumulator_handles
.iter()
.map(|(n, _)| n.clone())
.collect(),
paused: running.reactor_shared.is_paused(),
running: !running.reactor_handle.is_finished(),
health: running
.reactor_health_rx
.as_ref()
.map(|rx| rx.borrow().clone()),
})
})
.collect()
}
pub async fn check_and_restart_failed(&self) -> usize {
let mut restarted = 0;
let mut graphs = self.reactors.write().await;
let now = std::time::Instant::now();
for (graph_name, running) in graphs.iter_mut() {
let success_threshold = std::time::Duration::from_secs(SUCCESS_RESET_SECS);
let names_to_reset: Vec<String> = running
.last_success
.iter()
.filter(|(_, ts)| now.duration_since(**ts) >= success_threshold)
.map(|(name, _)| name.clone())
.collect();
for name in names_to_reset {
running.failure_counts.remove(&name);
running.last_success.remove(&name);
}
if running.reactor_handle.is_finished() {
let reactor_key = format!("{}::reactor", graph_name);
let failures = running
.failure_counts
.entry(reactor_key.clone())
.or_insert(0);
*failures += 1;
if *failures > MAX_RECOVERY_ATTEMPTS {
error!(
graph = %graph_name,
failures = *failures,
"reactor permanently failed — circuit breaker open"
);
continue;
}
let backoff_secs =
(BACKOFF_BASE_SECS * 2u64.pow(*failures - 1)).min(BACKOFF_MAX_SECS);
let backoff = std::time::Duration::from_secs(backoff_secs);
warn!(
graph = %graph_name,
attempt = *failures,
backoff_secs = backoff_secs,
"reactor crashed, restarting (full graph restart)"
);
self.record_recovery_event(&reactor_key, *failures, backoff_secs)
.await;
tokio::time::sleep(backoff).await;
let (shutdown_tx, shutdown_rx) = shutdown_signal();
let stored_shutdown_rx = shutdown_rx.clone();
let (boundary_tx, boundary_rx) = mpsc::channel(256);
let stored_boundary_tx = boundary_tx.clone();
let expected_sources: Vec<SourceName> = running
.declaration
.accumulators
.iter()
.map(|a| SourceName::new(&a.name))
.collect();
let mut new_acc_handles = Vec::new();
let mut restart_acc_health_rxs: Vec<(
String,
watch::Receiver<super::accumulator::AccumulatorHealth>,
)> = Vec::new();
for acc_decl in &running.declaration.accumulators {
let (health_tx, health_rx) = health_channel();
restart_acc_health_rxs.push((acc_decl.name.clone(), health_rx.clone()));
let spawn_config = AccumulatorSpawnConfig {
dal: self.dal.clone(),
health_tx: Some(health_tx),
graph_name: graph_name.clone(),
};
let (socket_tx, handle) = acc_decl.factory.spawn(
acc_decl.name.clone(),
boundary_tx.clone(),
shutdown_rx.clone(),
spawn_config,
);
self.registry
.register_accumulator(acc_decl.name.clone(), socket_tx)
.await;
self.registry
.register_accumulator_health(acc_decl.name.clone(), health_rx)
.await;
new_acc_handles.push((acc_decl.name.clone(), handle));
}
let (manual_tx, manual_rx) = mpsc::channel(64);
let (reactor_health_tx, reactor_health_rx) = reactor_health_channel();
let restart_dispatcher =
make_subscriber_dispatcher(graph_name.clone(), running.subscribers.clone());
let mut reactor = Reactor::new(
restart_dispatcher,
running.declaration.reactor.criteria.clone(),
running.declaration.reactor.strategy.clone(),
boundary_rx,
manual_rx,
shutdown_rx,
)
.with_graph_name(graph_name.clone())
.with_health(reactor_health_tx)
.with_expected_sources(expected_sources)
.with_accumulator_health(restart_acc_health_rxs);
if let Some(ref dal) = self.dal {
reactor = reactor.with_dal(dal.clone());
}
let reactor_shared = reactor.handle();
let reactor_handle = tokio::spawn(reactor.run());
for key in &running.endpoint_registry_keys {
self.registry
.register_reactor(key.clone(), manual_tx.clone(), reactor_shared.clone())
.await;
}
let restart_acc_policy = match &running.declaration.tenant_id {
Some(tid) => AccumulatorAuthPolicy::for_tenant(tid),
None => AccumulatorAuthPolicy::allow_all(),
};
let restart_reactor_policy = match &running.declaration.tenant_id {
Some(tid) => ReactorAuthPolicy::for_tenant(tid),
None => ReactorAuthPolicy::allow_all(),
};
for acc_decl in &running.declaration.accumulators {
self.registry
.set_accumulator_policy(acc_decl.name.clone(), restart_acc_policy.clone())
.await;
}
for key in &running.endpoint_registry_keys {
self.registry
.set_reactor_policy(key.clone(), restart_reactor_policy.clone())
.await;
}
running.shutdown_tx = shutdown_tx;
running.shutdown_rx = stored_shutdown_rx;
running.boundary_tx = stored_boundary_tx;
running.accumulator_handles = new_acc_handles;
running.reactor_handle = reactor_handle;
running.reactor_shared = reactor_shared;
running.reactor_health_rx = Some(reactor_health_rx);
running.last_success.insert(reactor_key, now);
restarted += 1;
info!(graph = %graph_name, "reactor restarted successfully");
} else {
let mut new_handles = Vec::new();
let mut changed = false;
for (acc_name, handle) in running.accumulator_handles.drain(..) {
if handle.is_finished() {
let acc_key = format!("{}::{}", graph_name, acc_name);
let failures = running.failure_counts.entry(acc_key.clone()).or_insert(0);
*failures += 1;
if *failures > MAX_RECOVERY_ATTEMPTS {
error!(
graph = %graph_name,
accumulator = %acc_name,
failures = *failures,
"accumulator permanently failed — circuit breaker open"
);
continue;
}
let backoff_secs =
(BACKOFF_BASE_SECS * 2u64.pow(*failures - 1)).min(BACKOFF_MAX_SECS);
warn!(
graph = %graph_name,
accumulator = %acc_name,
attempt = *failures,
backoff_secs = backoff_secs,
"accumulator crashed, restarting individually"
);
self.record_recovery_event(&acc_key, *failures, backoff_secs)
.await;
tokio::time::sleep(std::time::Duration::from_secs(backoff_secs)).await;
if let Some(acc_decl) = running
.declaration
.accumulators
.iter()
.find(|d| d.name == *acc_name)
{
let (health_tx, health_rx) = health_channel();
let spawn_config = AccumulatorSpawnConfig {
dal: self.dal.clone(),
health_tx: Some(health_tx),
graph_name: graph_name.clone(),
};
let (socket_tx, new_handle) = acc_decl.factory.spawn(
acc_name.clone(),
running.boundary_tx.clone(),
running.shutdown_rx.clone(),
spawn_config,
);
self.registry
.register_accumulator(acc_name.clone(), socket_tx)
.await;
self.registry
.register_accumulator_health(acc_name.clone(), health_rx)
.await;
let ind_acc_policy = match &running.declaration.tenant_id {
Some(tid) => AccumulatorAuthPolicy::for_tenant(tid),
None => AccumulatorAuthPolicy::allow_all(),
};
self.registry
.set_accumulator_policy(acc_name.clone(), ind_acc_policy)
.await;
running.last_success.insert(acc_key, now);
let restarted_name = acc_name.clone();
new_handles.push((acc_name, new_handle));
restarted += 1;
changed = true;
info!(
graph = %graph_name,
accumulator = %restarted_name,
"accumulator restarted individually"
);
} else {
let lost_name = acc_name.clone();
error!(
graph = %graph_name,
accumulator = %lost_name,
"cannot restart: declaration not found"
);
}
} else {
new_handles.push((acc_name, handle));
}
}
running.accumulator_handles = new_handles;
if changed {
for (acc_name, _) in &running.accumulator_handles {
let acc_key = format!("{}::{}", graph_name, acc_name);
running.last_success.entry(acc_key).or_insert(now);
}
}
}
}
restarted
}
pub fn start_supervision(
self: &Arc<Self>,
mut shutdown_rx: watch::Receiver<bool>,
check_interval: std::time::Duration,
) -> JoinHandle<()> {
let scheduler = self.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(check_interval);
interval.tick().await;
loop {
tokio::select! {
_ = interval.tick() => {
let restarted = scheduler.check_and_restart_failed().await;
if restarted > 0 {
info!("supervision check: restarted {} tasks", restarted);
}
}
_ = shutdown_rx.changed() => {
tracing::debug!("supervision loop shutting down");
break;
}
}
}
})
}
async fn record_recovery_event(&self, component: &str, attempt: u32, backoff_secs: u64) {
let dal = match &self.dal {
Some(d) => d,
None => return,
};
use crate::database::universal_types::UniversalUuid;
use crate::models::recovery_event::NewRecoveryEvent;
let event = NewRecoveryEvent {
workflow_execution_id: UniversalUuid::new_v4(),
task_execution_id: None,
recovery_type: "graph_component_restart".to_string(),
details: Some(format!(
"component={}, attempt={}, backoff={}s",
component, attempt, backoff_secs
)),
};
if let Err(e) = dal.recovery_event().create(event).await {
warn!(component = %component, "failed to record recovery event: {}", e);
}
}
pub async fn shutdown_all(&self) {
let graph_names: Vec<String> = {
let g2r = self.graph_to_reactor.read().await;
g2r.keys().cloned().collect()
};
for name in graph_names {
if let Err(e) = self.unload_graph(&name).await {
warn!(graph = %name, error = %e, "failed to unload graph during shutdown");
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::computation_graph::accumulator::{
        accumulator_runtime, Accumulator, AccumulatorContext, AccumulatorRuntimeConfig,
        BoundarySender, CheckpointHandle,
    };
    use crate::computation_graph::types::{GraphResult, InputCache};
    use serde::{Deserialize, Serialize};
    use std::sync::atomic::{AtomicU32, Ordering};

    /// JSON payload pushed into the test accumulator.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct TestEvent {
        value: f64,
    }

    /// Factory for a pass-through accumulator.
    ///
    /// Each socket message is decoded as a [`TestEvent`]; messages that fail to parse are
    /// dropped.
    struct TestAccumulatorFactory;

    impl AccumulatorFactory for TestAccumulatorFactory {
        fn spawn(
            &self,
            name: String,
            boundary_tx: mpsc::Sender<(SourceName, Vec<u8>)>,
            shutdown_rx: watch::Receiver<bool>,
            config: AccumulatorSpawnConfig,
        ) -> (mpsc::Sender<Vec<u8>>, JoinHandle<()>) {
            struct Passthrough;

            #[async_trait::async_trait]
            impl Accumulator for Passthrough {
                type Output = TestEvent;
                fn process(&mut self, event: Vec<u8>) -> Option<TestEvent> {
                    serde_json::from_slice(&event).ok()
                }
            }

            let (inbox_tx, inbox_rx) = mpsc::channel(64);
            let AccumulatorSpawnConfig {
                dal,
                health_tx,
                graph_name,
            } = config;
            let checkpoint = dal.map(|dal| CheckpointHandle::new(dal, graph_name, name.clone()));
            let output = BoundarySender::new(boundary_tx, SourceName::new(&name));
            let ctx = AccumulatorContext {
                output,
                name: name.clone(),
                shutdown: shutdown_rx,
                checkpoint,
                health: health_tx,
            };
            let task = tokio::spawn(accumulator_runtime(
                Passthrough,
                ctx,
                inbox_rx,
                AccumulatorRuntimeConfig::default(),
            ));
            (inbox_tx, task)
        }
    }

    /// Builds a `WhenAny`/`Latest` declaration with an implicitly named reactor.
    ///
    /// Every listed accumulator is backed by [`TestAccumulatorFactory`].
    fn test_declaration(
        name: &str,
        accumulator_names: &[&str],
        graph_fn: CompiledGraphFn,
    ) -> ComputationGraphDeclaration {
        ComputationGraphDeclaration {
            name: name.to_string(),
            accumulators: accumulator_names
                .iter()
                .map(|acc_name| AccumulatorDeclaration {
                    name: (*acc_name).to_string(),
                    factory: Arc::new(TestAccumulatorFactory),
                })
                .collect(),
            reactor: ReactorDeclaration {
                criteria: ReactionCriteria::WhenAny,
                strategy: InputStrategy::Latest,
                graph_fn,
            },
            tenant_id: None,
            reactor_name: None,
        }
    }

    #[tokio::test]
    async fn test_load_graph_push_event_fires() {
        let registry = EndpointRegistry::new();
        let scheduler = ComputationGraphScheduler::new(registry.clone());
        let fires = Arc::new(AtomicU32::new(0));
        let counter = fires.clone();
        let graph_fn: CompiledGraphFn = Arc::new(move |_cache: InputCache| {
            let counter = counter.clone();
            Box::pin(async move {
                counter.fetch_add(1, Ordering::SeqCst);
                GraphResult::completed(vec![])
            })
        });

        scheduler
            .load_graph(test_declaration("test_graph", &["alpha"], graph_fn))
            .await
            .unwrap();

        let payload = serde_json::to_vec(&TestEvent { value: 42.0 }).unwrap();
        registry.send_to_accumulator("alpha", payload).await.unwrap();
        tokio::time::sleep(std::time::Duration::from_millis(200)).await;
        assert_eq!(fires.load(Ordering::SeqCst), 1, "graph should fire");

        let graphs = scheduler.list_graphs().await;
        assert_eq!(graphs.len(), 1);
        assert_eq!(graphs[0].name, "test_graph");
        assert!(!graphs[0].paused);
        scheduler.shutdown_all().await;
    }

    #[tokio::test]
    async fn test_unload_graph_deregisters() {
        let registry = EndpointRegistry::new();
        let scheduler = ComputationGraphScheduler::new(registry.clone());

        scheduler
            .load_graph(test_declaration("test_graph", &["alpha"], dummy_graph_fn()))
            .await
            .unwrap();
        assert_eq!(registry.accumulator_count("alpha").await, 1);
        assert!(registry
            .list_reactors()
            .await
            .contains(&"test_graph".to_string()));

        scheduler.unload_graph("test_graph").await.unwrap();
        assert_eq!(registry.accumulator_count("alpha").await, 0);
        assert!(registry.list_reactors().await.is_empty());
    }

    #[tokio::test]
    async fn test_duplicate_load_rejected() {
        let registry = EndpointRegistry::new();
        let scheduler = ComputationGraphScheduler::new(registry.clone());
        let decl = test_declaration("dup", &[], dummy_graph_fn());

        scheduler.load_graph(decl.clone()).await.unwrap();
        let err = scheduler.load_graph(decl).await.unwrap_err();
        assert!(err.contains("already loaded"));
        scheduler.shutdown_all().await;
    }
}