#![allow(unused_assignments)]
use crate::event_bus::EventEmitter;
use crate::node::{Node, NodeContext, NodeError, NodePartial};
use crate::state::StateSnapshot;
use crate::types::NodeKind;
use crate::utils::clock::Clock;
use futures_util::stream::{self, StreamExt};
use rustc_hash::FxHashMap;
use std::sync::Arc;
use thiserror::Error;
use tracing::instrument;
/// Outcome of a single [`Scheduler::superstep`].
#[derive(Debug, Clone)]
pub struct StepRunResult {
    /// Frontier nodes that were executed this step, in frontier order.
    pub ran_nodes: Vec<NodeKind>,
    /// Frontier nodes that were not executed: `NodeKind::Start` / `NodeKind::End`,
    /// plus nodes whose channel versions had not advanced since they last ran.
    pub skipped_nodes: Vec<NodeKind>,
    /// `(node, partial)` pairs produced by the executed nodes, collected in
    /// completion order (not frontier order).
    pub outputs: Vec<(NodeKind, NodePartial)>,
}
/// Per-run values that the scheduler copies into every [`NodeContext`] it
/// builds during a superstep.
///
/// The struct is `#[non_exhaustive]`, so code outside this crate must build it
/// with `SchedulerRunContext::new` and the `with_*` builder methods.
#[derive(Clone)]
#[non_exhaustive]
pub struct SchedulerRunContext {
    /// Event emitter shared (via `Arc::clone`) with every node run.
    pub event_emitter: Arc<dyn EventEmitter>,
    /// Optional clock passed through to each node's context.
    pub clock: Option<Arc<dyn Clock>>,
    /// Optional invocation identifier passed through to each node's context.
    pub invocation_id: Option<String>,
}
impl SchedulerRunContext {
    /// Creates a run context around `event_emitter`.
    ///
    /// The context starts with no clock and no invocation id; attach them with
    /// [`Self::with_clock`] and [`Self::with_invocation_id`].
    #[must_use]
    pub fn new(event_emitter: Arc<dyn EventEmitter>) -> Self {
        Self {
            clock: None,
            invocation_id: None,
            event_emitter,
        }
    }

    /// Returns this context with `clock` attached, replacing any previous clock.
    #[must_use]
    pub fn with_clock(self, clock: Arc<dyn Clock>) -> Self {
        Self {
            clock: Some(clock),
            ..self
        }
    }

    /// Returns this context tagged with `invocation_id`, replacing any previous id.
    #[must_use]
    pub fn with_invocation_id(self, invocation_id: impl Into<String>) -> Self {
        let id: String = invocation_id.into();
        Self {
            invocation_id: Some(id),
            ..self
        }
    }
}
/// Change-detection state carried between supersteps.
#[derive(Debug, Default, Clone)]
pub struct SchedulerState {
    /// `node_id -> (channel_name -> last channel version the node ran against)`.
    ///
    /// When written by `Scheduler::superstep`, `node_id` is the `Debug`
    /// rendering of the node's `NodeKind`, and channel names are `"messages"`
    /// and `"extra"`.
    pub versions_seen: FxHashMap<String, FxHashMap<String, u64>>,
}
/// Superstep scheduler: runs frontier nodes with bounded concurrency and skips
/// nodes whose input channel versions have not advanced since they last ran.
#[derive(Debug, Default, Clone)]
pub struct Scheduler {
    /// Maximum number of nodes run concurrently within one superstep.
    ///
    /// NOTE(review): the derived `Default` leaves this at `0`, while
    /// `Scheduler::new` clamps `0` to `1`; prefer `Scheduler::new` when
    /// constructing a scheduler.
    pub concurrency_limit: usize,
}
/// Errors produced while scheduling or running a superstep.
#[derive(Debug, Error)]
#[cfg_attr(feature = "diagnostics", derive(miette::Diagnostic))]
pub enum SchedulerError {
    /// A node selected to run in a superstep has no entry in the node registry.
    #[error("node {kind:?} in frontier not found in registry at step {step}")]
    #[cfg_attr(
        feature = "diagnostics",
        diagnostic(
            code(weavegraph::scheduler::node_not_found),
            help("Ensure all nodes in the graph are registered before execution.")
        )
    )]
    NodeNotFound {
        /// The node that could not be found.
        kind: NodeKind,
        /// Superstep index at which the lookup failed.
        step: u64,
    },
    /// A node's `run` returned an error.
    ///
    /// NOTE(review): the message embeds `{source}` while `source` is also
    /// exposed through `Error::source`, so error-chain reporters will print the
    /// underlying error twice.
    #[error("node run error at step {step} for {kind:?}: {source}")]
    #[cfg_attr(feature = "diagnostics", diagnostic(code(weavegraph::scheduler::node)))]
    NodeRun {
        /// The node whose run failed.
        kind: NodeKind,
        /// Superstep index at which the node failed.
        step: u64,
        /// The error returned by the node.
        #[source]
        source: NodeError,
    },
    /// Wraps a `tokio::task::JoinError` (convertible via `From`).
    #[error("task join error: {0}")]
    #[cfg_attr(feature = "diagnostics", diagnostic(code(weavegraph::scheduler::join)))]
    Join(#[from] tokio::task::JoinError),
}
impl Scheduler {
    /// Creates a scheduler that runs at most `concurrency_limit` nodes at once
    /// within a superstep. A limit of `0` is clamped to `1`.
    #[must_use]
    pub fn new(concurrency_limit: usize) -> Self {
        Self {
            concurrency_limit: concurrency_limit.max(1),
        }
    }

    /// Returns the `(channel_name, version)` pairs of `snap` used for change
    /// detection. The names are the keys stored in `SchedulerState::versions_seen`.
    #[inline]
    fn channel_versions(snap: &StateSnapshot) -> [(&'static str, u64); 2] {
        [
            ("messages", snap.messages_version as u64),
            ("extra", snap.extra_version as u64),
        ]
    }

    /// Returns `true` if `node_id` should run against `snap`.
    ///
    /// Convenience wrapper around [`Scheduler::should_run_with`] using the
    /// snapshot's channel versions.
    #[must_use]
    pub fn should_run(&self, state: &SchedulerState, node_id: &str, snap: &StateSnapshot) -> bool {
        self.should_run_with(state, node_id, &Self::channel_versions(snap))
    }

    /// Returns `true` if `node_id` has never been recorded in `state`, or if any
    /// channel in `channels` has a version strictly greater than the one last
    /// recorded for that node.
    ///
    /// A channel with no recorded version for the node is treated as seen at
    /// version `0`, so a channel still at version `0` never triggers a run.
    #[must_use]
    pub fn should_run_with(
        &self,
        state: &SchedulerState,
        node_id: &str,
        channels: &[(&str, u64)],
    ) -> bool {
        match state.versions_seen.get(node_id) {
            // Never recorded: the node has not run yet.
            None => true,
            Some(seen) => channels
                .iter()
                .any(|&(name, version)| version > seen.get(name).copied().unwrap_or(0)),
        }
    }

    /// Records the channel versions of `snap` as seen by `node_id`.
    ///
    /// Convenience wrapper around [`Scheduler::record_seen_with`].
    pub fn record_seen(&self, state: &mut SchedulerState, node_id: &str, snap: &StateSnapshot) {
        self.record_seen_with(state, node_id, &Self::channel_versions(snap));
    }

    /// Records `channels` as the versions `node_id` has now seen.
    ///
    /// Versions for the listed channels are overwritten; versions already
    /// recorded for channels not listed in `channels` are kept.
    pub fn record_seen_with(
        &self,
        state: &mut SchedulerState,
        node_id: &str,
        channels: &[(&str, u64)],
    ) {
        let seen = state.versions_seen.entry(node_id.to_string()).or_default();
        seen.extend(
            channels
                .iter()
                .map(|&(name, version)| (name.to_string(), version)),
        );
    }

    /// Executes one superstep over `frontier`.
    ///
    /// `NodeKind::Start` and `NodeKind::End` are always skipped. Every other
    /// node runs only if [`Scheduler::should_run_with`] reports a channel
    /// version change for it. Selected nodes receive a clone of `snap` and a
    /// fresh `NodeContext`, and run with at most `concurrency_limit` in flight.
    /// A limit of `0` is treated as `1`, matching [`Scheduler::new`].
    ///
    /// On success, the versions in `snap` are recorded for every node that ran.
    /// `outputs` in the result are in completion order, not frontier order.
    ///
    /// # Errors
    ///
    /// - [`SchedulerError::NodeNotFound`] if a node selected to run has no
    ///   entry in `nodes`. This is checked before any node is started.
    /// - [`SchedulerError::NodeRun`] for the first node whose `run` fails. Any
    ///   other in-flight node futures are dropped, and no versions are recorded
    ///   for this step.
    #[instrument(skip(self, state, nodes, frontier, snap, run_context))]
    pub async fn superstep(
        &self,
        state: &mut SchedulerState,
        nodes: &FxHashMap<NodeKind, Arc<dyn Node>>,
        frontier: Vec<NodeKind>,
        snap: StateSnapshot,
        step: u64,
        run_context: SchedulerRunContext,
    ) -> Result<StepRunResult, SchedulerError> {
        let channels = Self::channel_versions(&snap);

        // Partition the frontier. Each node's id string is computed once and
        // reused for its NodeContext and for recording seen versions.
        let mut to_run: Vec<NodeKind> = Vec::new();
        let mut to_run_ids: Vec<String> = Vec::new();
        let mut skipped_nodes: Vec<NodeKind> = Vec::new();
        for kind in frontier {
            if matches!(kind, NodeKind::Start | NodeKind::End) {
                skipped_nodes.push(kind);
                continue;
            }
            let id = format!("{:?}", kind);
            if self.should_run_with(state, &id, &channels) {
                to_run.push(kind);
                to_run_ids.push(id);
            } else {
                skipped_nodes.push(kind);
            }
        }

        // Resolve every node before starting any of them, so that a missing
        // registration fails the step up front (first missing node wins).
        let resolved = to_run
            .iter()
            .map(|kind| {
                nodes
                    .get(kind)
                    .cloned()
                    .ok_or_else(|| SchedulerError::NodeNotFound {
                        kind: kind.clone(),
                        step,
                    })
            })
            .collect::<Result<Vec<Arc<dyn Node>>, SchedulerError>>()?;

        // The public field (or the derived `Default`) can hold 0; clamp it the
        // same way `new` does so the stream always has room for one future.
        let limit = self.concurrency_limit.max(1);
        let mut outputs: Vec<(NodeKind, NodePartial)> = Vec::with_capacity(to_run.len());
        {
            let tasks = to_run_ids
                .iter()
                .zip(to_run.iter())
                .zip(resolved)
                .map(|((id, kind), node)| {
                    let ctx = NodeContext {
                        node_id: id.clone(),
                        step,
                        event_emitter: Arc::clone(&run_context.event_emitter),
                        clock: run_context.clock.clone(),
                        invocation_id: run_context.invocation_id.clone(),
                    };
                    let kind = kind.clone();
                    let snapshot = snap.clone();
                    async move {
                        let result = node.run(snapshot, ctx).await;
                        (kind, result)
                    }
                });
            let mut in_flight = stream::iter(tasks).buffer_unordered(limit);
            while let Some((kind, result)) = in_flight.next().await {
                match result {
                    Ok(partial) => outputs.push((kind, partial)),
                    // Returning drops `in_flight` (and any unfinished node
                    // futures) before versions are recorded below.
                    Err(source) => return Err(SchedulerError::NodeRun { kind, step, source }),
                }
            }
        }

        // Only mark versions as seen once every node in the step has succeeded.
        for id in &to_run_ids {
            self.record_seen_with(state, id, &channels);
        }

        Ok(StepRunResult {
            ran_nodes: to_run,
            skipped_nodes,
            outputs,
        })
    }
}