#![allow(unused_assignments)]
use crate::event_bus::EventEmitter;
use crate::node::{Node, NodeContext, NodeError, NodePartial};
use crate::state::StateSnapshot;
use crate::types::NodeKind;
use crate::utils::clock::Clock;
use futures_util::stream::{self, StreamExt};
use rustc_hash::FxHashMap;
use std::sync::Arc;
use thiserror::Error;
use tracing::instrument;
/// Outcome of a single [`Scheduler::superstep`].
#[derive(Clone, Debug)]
pub struct StepRunResult {
    /// Nodes that were executed during the step.
    pub ran_nodes: Vec<NodeKind>,
    /// Frontier nodes that were not executed: `Start`/`End`, plus nodes whose
    /// observed channel versions had not advanced since they last ran.
    pub skipped_nodes: Vec<NodeKind>,
    /// `(node, partial update)` pairs, one per successfully executed node.
    pub outputs: Vec<(NodeKind, NodePartial)>,
}
/// Per-run values handed to every node's [`NodeContext`] during a superstep.
///
/// The struct is `#[non_exhaustive]`; build it with
/// [`SchedulerRunContext::new`] and the `with_*` methods.
#[derive(Clone)]
#[non_exhaustive]
pub struct SchedulerRunContext {
    /// Event emitter shared (via `Arc`) with every node in the step.
    pub event_emitter: Arc<dyn EventEmitter>,
    /// Optional clock forwarded to nodes; `new` leaves it unset.
    pub clock: Option<Arc<dyn Clock>>,
    /// Optional identifier of the current invocation, forwarded to nodes.
    pub invocation_id: Option<String>,
}
impl SchedulerRunContext {
    /// Builds a context around `event_emitter` with no clock and no invocation id.
    #[must_use]
    pub fn new(event_emitter: Arc<dyn EventEmitter>) -> Self {
        Self {
            clock: None,
            invocation_id: None,
            event_emitter,
        }
    }

    /// Returns this context with `clock` installed, replacing any previous clock.
    #[must_use]
    pub fn with_clock(self, clock: Arc<dyn Clock>) -> Self {
        Self {
            clock: Some(clock),
            ..self
        }
    }

    /// Returns this context tagged with `invocation_id`, replacing any previous id.
    #[must_use]
    pub fn with_invocation_id(self, invocation_id: impl Into<String>) -> Self {
        Self {
            invocation_id: Some(invocation_id.into()),
            ..self
        }
    }
}
/// Bookkeeping carried between supersteps: which channel versions each node
/// has already observed.
#[derive(Clone, Debug, Default)]
pub struct SchedulerState {
    /// Outer key: node id. Inner map: channel name -> last version recorded
    /// for that node.
    pub versions_seen: FxHashMap<String, FxHashMap<String, u64>>,
}
/// Executes graph frontiers one superstep at a time with bounded concurrency.
///
/// Note: the derived `Default` yields `concurrency_limit == 0`, whereas
/// `Scheduler::new` clamps the limit to at least `1`.
#[derive(Clone, Debug, Default)]
pub struct Scheduler {
    /// Maximum number of nodes run concurrently within one superstep.
    pub concurrency_limit: usize,
}
/// Errors produced while executing a superstep.
#[derive(Debug, Error)]
#[cfg_attr(feature = "diagnostics", derive(miette::Diagnostic))]
pub enum SchedulerError {
    /// A node selected to run has no entry in the node registry.
    #[error("node {kind:?} in frontier not found in registry at step {step}")]
    #[cfg_attr(
        feature = "diagnostics",
        diagnostic(
            code(weavegraph::scheduler::node_not_found),
            help("Ensure all nodes in the graph are registered before execution.")
        )
    )]
    NodeNotFound {
        /// The node that could not be resolved.
        kind: NodeKind,
        /// Superstep in which the lookup failed.
        step: u64,
    },
    /// A node's `run` returned an error.
    #[error("node run error at step {step} for {kind:?}: {source}")]
    #[cfg_attr(feature = "diagnostics", diagnostic(code(weavegraph::scheduler::node)))]
    NodeRun {
        /// The node that failed.
        kind: NodeKind,
        /// Superstep in which the node failed.
        step: u64,
        /// The error returned by the node.
        #[source]
        source: NodeError,
    },
    /// A Tokio task could not be joined.
    #[error("task join error: {0}")]
    #[cfg_attr(feature = "diagnostics", diagnostic(code(weavegraph::scheduler::join)))]
    Join(#[from] tokio::task::JoinError),
}
impl Scheduler {
    /// Creates a scheduler that runs at most `concurrency_limit` nodes at once
    /// within a superstep.
    ///
    /// A `concurrency_limit` of `0` is clamped to `1`.
    #[must_use]
    pub fn new(concurrency_limit: usize) -> Self {
        Self {
            concurrency_limit: concurrency_limit.max(1),
        }
    }

    /// Current channel versions of `snap` as `(channel_name, version)` pairs.
    ///
    /// These are the channels compared against `SchedulerState::versions_seen`
    /// to decide whether a node needs to run again.
    #[inline]
    fn channel_versions(snap: &StateSnapshot) -> [(&'static str, u64); 2] {
        [
            ("messages", snap.messages_version as u64),
            ("extra", snap.extra_version as u64),
        ]
    }

    /// Returns `true` when `node_id` should run against `snap`.
    ///
    /// Convenience wrapper over [`Scheduler::should_run_with`] using the
    /// snapshot's `messages` and `extra` channel versions.
    #[must_use]
    pub fn should_run(&self, state: &SchedulerState, node_id: &str, snap: &StateSnapshot) -> bool {
        self.should_run_with(state, node_id, &Self::channel_versions(snap))
    }

    /// Returns `true` when `node_id` has never been recorded in `state`, or when
    /// any channel in `channels` has a version strictly greater than the one
    /// last recorded for that node. A channel with no recorded version counts
    /// as version `0`.
    #[must_use]
    pub fn should_run_with(
        &self,
        state: &SchedulerState,
        node_id: &str,
        channels: &[(&str, u64)],
    ) -> bool {
        let Some(seen) = state.versions_seen.get(node_id) else {
            return true;
        };
        channels
            .iter()
            .any(|&(name, ver)| ver > seen.get(name).copied().unwrap_or(0))
    }

    /// Records the snapshot's current channel versions as seen by `node_id`.
    pub fn record_seen(&self, state: &mut SchedulerState, node_id: &str, snap: &StateSnapshot) {
        self.record_seen_with(state, node_id, &Self::channel_versions(snap));
    }

    /// Stores `channels` as the versions last seen by `node_id`.
    ///
    /// Overwrites previously recorded values for the listed channels and keeps
    /// entries for any other channels.
    pub fn record_seen_with(
        &self,
        state: &mut SchedulerState,
        node_id: &str,
        channels: &[(&str, u64)],
    ) {
        let entry = state.versions_seen.entry(node_id.to_owned()).or_default();
        for &(name, ver) in channels {
            entry.insert(name.to_owned(), ver);
        }
    }

    /// Executes one superstep over `frontier`.
    ///
    /// `Start`/`End` nodes, and nodes whose observed channels have not advanced
    /// (see [`Scheduler::should_run_with`]), are reported in `skipped_nodes`.
    /// The remaining nodes run concurrently, at most `concurrency_limit` at a
    /// time (a limit of `0` is treated as `1`). Each node gets its own clone of
    /// `snap` and a [`NodeContext`] whose `node_id` is the `Debug` rendering of
    /// its `NodeKind`. `outputs` is returned in the same order as `ran_nodes`.
    ///
    /// On success, the channel versions of `snap` are recorded in `state` for
    /// every node that ran.
    ///
    /// # Errors
    ///
    /// - [`SchedulerError::NodeNotFound`] if a node selected to run is missing
    ///   from `nodes`. No node is started in that case.
    /// - [`SchedulerError::NodeRun`] for the first node failure to complete. The
    ///   remaining in-flight node futures are dropped and `state` is not updated.
    #[instrument(skip(self, state, nodes, frontier, snap, run_context))]
    pub async fn superstep(
        &self,
        state: &mut SchedulerState,
        nodes: &FxHashMap<NodeKind, Arc<dyn Node>>,
        frontier: Vec<NodeKind>,
        snap: StateSnapshot,
        step: u64,
        run_context: SchedulerRunContext,
    ) -> Result<StepRunResult, SchedulerError> {
        let channels = Self::channel_versions(&snap);

        // Partition the frontier. Start/End are never executed; other nodes run
        // only when a channel moved past the version they last recorded.
        let mut to_run: Vec<NodeKind> = Vec::with_capacity(frontier.len());
        let mut to_run_ids: Vec<String> = Vec::with_capacity(frontier.len());
        let mut skipped_kinds: Vec<NodeKind> = Vec::new();
        for kind in frontier {
            if matches!(kind, NodeKind::Start | NodeKind::End) {
                skipped_kinds.push(kind);
                continue;
            }
            let id = format!("{kind:?}");
            if self.should_run_with(state, &id, &channels) {
                to_run_ids.push(id);
                to_run.push(kind);
            } else {
                skipped_kinds.push(kind);
            }
        }

        // Resolve every runnable node up front, so that a missing registration
        // fails the step before any node has started.
        let resolved: Vec<Arc<dyn Node>> = to_run
            .iter()
            .map(|kind| {
                nodes
                    .get(kind)
                    .cloned()
                    .ok_or_else(|| SchedulerError::NodeNotFound {
                        kind: kind.clone(),
                        step,
                    })
            })
            .collect::<Result<_, _>>()?;

        // Tag each task with its position in `to_run`, so results can be put
        // back into frontier order after completing in arbitrary order.
        let tasks: Vec<_> = resolved
            .into_iter()
            .zip(to_run.iter().zip(&to_run_ids))
            .enumerate()
            .map(|(position, (node, (kind, id)))| {
                let ctx = NodeContext {
                    node_id: id.clone(),
                    step,
                    event_emitter: Arc::clone(&run_context.event_emitter),
                    clock: run_context.clock.clone(),
                    invocation_id: run_context.invocation_id.clone(),
                };
                let node_snap = snap.clone();
                let kind = kind.clone();
                async move { (position, kind, node.run(node_snap, ctx).await) }
            })
            .collect();

        // Re-apply the clamp from `new`. The field is public and the derived
        // `Default` yields 0, and a limit of 0 must never reach
        // `buffer_unordered` (some futures-util versions then never start a task).
        let limit = self.concurrency_limit.max(1);

        let mut completed: Vec<(usize, NodeKind, NodePartial)> = Vec::with_capacity(to_run.len());
        let mut stream = stream::iter(tasks).buffer_unordered(limit);
        while let Some((position, kind, res)) = stream.next().await {
            match res {
                Ok(part) => completed.push((position, kind, part)),
                // Returning drops `stream`, and with it every still-pending node future.
                Err(e) => {
                    return Err(SchedulerError::NodeRun {
                        kind,
                        step,
                        source: e,
                    });
                }
            }
        }

        // Completion order depends on timing; restore frontier order so
        // `outputs` lines up with `ran_nodes` deterministically.
        completed.sort_unstable_by_key(|entry| entry.0);
        let outputs: Vec<(NodeKind, NodePartial)> = completed
            .into_iter()
            .map(|(_, kind, part)| (kind, part))
            .collect();

        // Only a fully successful step records what each node has seen; on
        // error (above) `state` is left untouched.
        for id in &to_run_ids {
            self.record_seen_with(state, id, &channels);
        }
        Ok(StepRunResult {
            ran_nodes: to_run,
            skipped_nodes: skipped_kinds,
            outputs,
        })
    }
}