weavegraph 0.7.0

Graph-driven, concurrent agent workflow framework with versioned state, deterministic barrier merges, and rich diagnostics.
Documentation
// miette::Diagnostic derive triggers unused_assignments on Rust 1.93+.
#![allow(unused_assignments)]
//! Frontier-based scheduler with version gating and bounded concurrency.
//!
//! Manages concurrent node execution; nodes are skipped when they have already
//! processed the current channel versions. The scheduler is stateless — all
//! tracking lives in [`SchedulerState`].
//!
//! ```rust
//! use weavegraph::channels::Channel;
//! use weavegraph::schedulers::{Scheduler, SchedulerState};
//! use weavegraph::state::VersionedState;
//!
//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
//! let scheduler = Scheduler::new(4);
//! let mut state = SchedulerState::default();
//! let mut vs = VersionedState::builder().build();
//! vs.messages.set_version(2);
//! assert!(scheduler.should_run(&state, "my_node", &vs.snapshot()));
//! # Ok(())
//! # }
//! ```

use crate::event_bus::EventEmitter;
use crate::node::{Node, NodeContext, NodeError, NodePartial};
use crate::state::StateSnapshot;
use crate::types::NodeKind;
use crate::utils::clock::Clock;
use futures_util::stream::{self, StreamExt};
use rustc_hash::FxHashMap;
use std::sync::Arc;
use thiserror::Error;
use tracing::instrument;

/// Execution summary for a single superstep.
///
/// Returned by `Scheduler::superstep` on success. Every frontier entry ends up
/// in exactly one of `ran_nodes` or `skipped_nodes`.
#[derive(Debug, Clone)]
pub struct StepRunResult {
    /// Nodes that executed this step, in scheduling order (the order in which
    /// they appeared in the frontier).
    pub ran_nodes: Vec<NodeKind>,
    /// Nodes skipped this step: structural nodes (`Start`, `End`) and nodes
    /// whose recorded versions already match the snapshot, in frontier order.
    pub skipped_nodes: Vec<NodeKind>,
    /// Outputs from executed nodes as `(node_kind, partial)` pairs.
    ///
    /// Pairs are collected as node futures complete, so their order is not
    /// guaranteed to match `ran_nodes`; use the `NodeKind` to correlate.
    pub outputs: Vec<(NodeKind, NodePartial)>,
}

/// Runtime context injected into each superstep.
///
/// The scheduler copies these values into the `NodeContext` of every node it
/// runs during the step, so all nodes of one superstep share the same emitter,
/// clock, and invocation identifier.
///
/// The struct is `#[non_exhaustive]`: code outside this crate cannot build it
/// with a struct literal and should use [`SchedulerRunContext::new`] together
/// with the `with_*` methods instead.
#[derive(Clone)]
#[non_exhaustive]
pub struct SchedulerRunContext {
    /// Event emitter forwarded to each node context (shared via `Arc`).
    pub event_emitter: Arc<dyn EventEmitter>,
    /// Optional clock forwarded to each node context; `None` when not supplied.
    pub clock: Option<Arc<dyn Clock>>,
    /// Optional invocation identifier forwarded to each node context; the
    /// string is cloned once per executed node.
    pub invocation_id: Option<String>,
}

impl SchedulerRunContext {
    /// Start from an event emitter alone, leaving the clock and the invocation
    /// ID unset (`None`).
    #[must_use]
    pub fn new(event_emitter: Arc<dyn EventEmitter>) -> Self {
        let clock = None;
        let invocation_id = None;
        Self {
            event_emitter,
            clock,
            invocation_id,
        }
    }

    /// Return the context with `clock` installed, replacing any previous clock.
    #[must_use]
    pub fn with_clock(self, clock: Arc<dyn Clock>) -> Self {
        Self {
            clock: Some(clock),
            ..self
        }
    }

    /// Return the context with the given invocation identifier, replacing any
    /// previous one. Accepts anything convertible into a `String`.
    #[must_use]
    pub fn with_invocation_id(self, invocation_id: impl Into<String>) -> Self {
        Self {
            invocation_id: Some(invocation_id.into()),
            ..self
        }
    }
}

/// Version-tracking state that drives the scheduler's execution gate.
///
/// `versions_seen[node_id][channel]` records the last version a node consumed
/// per channel. When a snapshot's version exceeds that value the node runs
/// again; otherwise it is skipped.
///
/// A node with no entry at all is always considered runnable. A channel that
/// is missing from an existing node entry is treated as seen at version `0`.
///
/// ```rust
/// use weavegraph::channels::Channel;
/// use weavegraph::schedulers::{Scheduler, SchedulerState};
/// use weavegraph::state::VersionedState;
///
/// let scheduler = Scheduler::new(2);
/// let mut state = SchedulerState::default();
/// let mut vs = VersionedState::builder().build();
/// vs.messages.set_version(3);
/// let snap = vs.snapshot();
/// scheduler.record_seen(&mut state, "node_a", &snap);
/// assert!(!scheduler.should_run(&state, "node_a", &snap));
/// ```
#[derive(Debug, Default, Clone)]
pub struct SchedulerState {
    /// `versions_seen[node_id][channel]` — last version the node processed.
    ///
    /// `Scheduler::superstep` keys the outer map by the `Debug` rendering of
    /// the node's `NodeKind`; snapshot-based recording uses the channel names
    /// `"messages"` and `"extra"` for the inner map.
    pub versions_seen: FxHashMap<String, FxHashMap<String, u64>>,
}

/// Frontier scheduler with version gating and bounded concurrency.
///
/// Stateless execution engine — all tracking lives in [`SchedulerState`].
/// Eligible nodes run concurrently up to `concurrency_limit`; structural nodes
/// and nodes that have already processed the current state are skipped.
///
/// Both [`Scheduler::new`] and [`Scheduler::default`] produce a limit of at
/// least 1: a limit of 0 would admit no node at all and a superstep could
/// never make progress.
///
/// ```rust
/// use weavegraph::schedulers::Scheduler;
///
/// assert_eq!(Scheduler::new(8).concurrency_limit, 8);
/// assert_eq!(Scheduler::new(0).concurrency_limit, 1); // zero clamps to 1
/// assert_eq!(Scheduler::default().concurrency_limit, 1);
/// ```
#[derive(Debug, Clone)]
pub struct Scheduler {
    /// Maximum nodes that may run concurrently in a single superstep.
    ///
    /// Expected to be at least 1. The field is public, so writing `0` here
    /// directly bypasses the clamping done by the constructors.
    pub concurrency_limit: usize,
}

impl Default for Scheduler {
    /// Sequential scheduler (`concurrency_limit == 1`).
    ///
    /// Implemented by hand rather than derived: the derived value would be 0,
    /// which contradicts the "zero clamps to 1" rule applied by `new`.
    fn default() -> Self {
        Self {
            concurrency_limit: 1,
        }
    }
}

/// Errors raised during scheduler execution.
///
/// With the `diagnostics` feature enabled the enum additionally derives
/// `miette::Diagnostic`, attaching a diagnostic code (and, where given, a help
/// text) to each variant.
#[derive(Debug, Error)]
#[cfg_attr(feature = "diagnostics", derive(miette::Diagnostic))]
pub enum SchedulerError {
    /// A frontier node was absent from the node registry.
    ///
    /// `Scheduler::superstep` reports this before any node of the step is
    /// executed. Only nodes that passed the skip checks are looked up, so a
    /// skipped node missing from the registry does not trigger it.
    #[error("node {kind:?} in frontier not found in registry at step {step}")]
    #[cfg_attr(
        feature = "diagnostics",
        diagnostic(
            code(weavegraph::scheduler::node_not_found),
            help("Ensure all nodes in the graph are registered before execution.")
        )
    )]
    NodeNotFound {
        /// The node kind that was missing.
        kind: NodeKind,
        /// Workflow step at which the lookup failed.
        step: u64,
    },

    /// A node's `run` method returned an error.
    ///
    /// `Scheduler::superstep` returns the first such error it observes; the
    /// node's own error is kept as the `source` of this variant and is also
    /// included in the display message.
    #[error("node run error at step {step} for {kind:?}: {source}")]
    #[cfg_attr(feature = "diagnostics", diagnostic(code(weavegraph::scheduler::node)))]
    NodeRun {
        /// The node kind that failed.
        kind: NodeKind,
        /// Workflow step at which the failure occurred.
        step: u64,
        /// Underlying node error.
        #[source]
        source: NodeError,
    },

    /// An async task join failed (panic or cancellation).
    ///
    /// `#[from]` provides `From<tokio::task::JoinError>`, so `?` converts a
    /// join error into this variant.
    // NOTE(review): no code visible in this module constructs `Join`; confirm
    // which caller relies on the conversion.
    #[error("task join error: {0}")]
    #[cfg_attr(feature = "diagnostics", diagnostic(code(weavegraph::scheduler::join)))]
    Join(#[from] tokio::task::JoinError),
}

impl Scheduler {
    /// Create a scheduler; `concurrency_limit` of 0 is clamped to 1.
    #[must_use]
    pub fn new(concurrency_limit: usize) -> Self {
        Self {
            concurrency_limit: concurrency_limit.max(1),
        }
    }

    /// Project a snapshot onto the `(channel, version)` pairs used for gating.
    ///
    /// The channel names `"messages"` and `"extra"` are the keys stored in
    /// `SchedulerState::versions_seen`.
    #[inline]
    fn channel_versions(snap: &StateSnapshot) -> [(&'static str, u64); 2] {
        [
            ("messages", snap.messages_version as u64),
            ("extra", snap.extra_version as u64),
        ]
    }

    /// Return `true` if the node should run given the snapshot's current versions.
    ///
    /// Convenience wrapper around [`Scheduler::should_run_with`] using the
    /// snapshot's `messages` and `extra` channel versions.
    #[must_use]
    pub fn should_run(&self, state: &SchedulerState, node_id: &str, snap: &StateSnapshot) -> bool {
        self.should_run_with(state, node_id, &Self::channel_versions(snap))
    }

    /// Return `true` if any channel version exceeds what `node_id` last processed.
    ///
    /// A node with no recorded entry always runs. A channel missing from an
    /// existing entry counts as seen at version 0, so a channel still at
    /// version 0 never triggers a run by itself.
    #[must_use]
    pub fn should_run_with(
        &self,
        state: &SchedulerState,
        node_id: &str,
        channels: &[(&str, u64)],
    ) -> bool {
        let Some(seen) = state.versions_seen.get(node_id) else {
            return true;
        };
        channels
            .iter()
            .any(|&(name, ver)| ver > seen.get(name).copied().unwrap_or(0))
    }

    /// Record that `node_id` has processed the snapshot's channel versions.
    pub fn record_seen(&self, state: &mut SchedulerState, node_id: &str, snap: &StateSnapshot) {
        self.record_seen_with(state, node_id, &Self::channel_versions(snap));
    }

    /// Record arbitrary `(channel, version)` pairs as seen by `node_id`.
    ///
    /// Existing values for the listed channels are overwritten (even by a
    /// lower version); channels not listed are left untouched.
    ///
    /// ```rust
    /// use weavegraph::schedulers::{Scheduler, SchedulerState};
    ///
    /// let mut state = SchedulerState::default();
    /// Scheduler::new(1).record_seen_with(&mut state, "n", &[("messages", 7)]);
    /// assert_eq!(state.versions_seen["n"]["messages"], 7);
    /// ```
    pub fn record_seen_with(
        &self,
        state: &mut SchedulerState,
        node_id: &str,
        channels: &[(&str, u64)],
    ) {
        let entry = state.versions_seen.entry(node_id.to_owned()).or_default();
        for &(name, ver) in channels {
            entry.insert(name.to_owned(), ver);
        }
    }

    /// Execute one superstep over `frontier` with bounded concurrency.
    ///
    /// Structural nodes (`Start`, `End`) and version-gated nodes are skipped.
    /// Remaining nodes run concurrently up to `concurrency_limit`; a limit of
    /// 0 is treated as 1 so the step can always make progress. Each node
    /// receives its own clone of `snap` and a `NodeContext` whose `node_id` is
    /// the `Debug` rendering of its `NodeKind`.
    ///
    /// On success the versions in `snap` are recorded as seen for every node
    /// that ran. `outputs` are collected in completion order, which may differ
    /// from the order of `ran_nodes`.
    ///
    /// # Errors
    ///
    /// * [`SchedulerError::NodeNotFound`] if a node selected to run is missing
    ///   from `nodes`; reported for the first such node in frontier order,
    ///   before anything is executed.
    /// * [`SchedulerError::NodeRun`] for the first node failure observed. The
    ///   step is abandoned at that point and no versions are recorded, not
    ///   even for nodes that had already completed successfully.
    #[instrument(skip(self, state, nodes, frontier, snap, run_context))]
    pub async fn superstep(
        &self,
        state: &mut SchedulerState,
        nodes: &FxHashMap<NodeKind, Arc<dyn Node>>,
        frontier: Vec<NodeKind>,
        snap: StateSnapshot,
        step: u64,
        run_context: SchedulerRunContext,
    ) -> Result<StepRunResult, SchedulerError> {
        let channels = Self::channel_versions(&snap);
        let mut to_run: Vec<NodeKind> = Vec::with_capacity(frontier.len());
        let mut to_run_ids: Vec<String> = Vec::with_capacity(frontier.len());
        let mut skipped_kinds: Vec<NodeKind> = Vec::new();

        // Partition the frontier into runnable and skipped nodes.
        for kind in frontier {
            if matches!(kind, NodeKind::Start | NodeKind::End) {
                skipped_kinds.push(kind);
                continue;
            }
            let id = format!("{kind:?}");
            if self.should_run_with(state, &id, &channels) {
                to_run_ids.push(id);
                to_run.push(kind);
            } else {
                skipped_kinds.push(kind);
            }
        }

        // Resolve every node handle before starting any work, so a missing
        // registration fails the step without running part of it.
        let handles = to_run
            .iter()
            .map(|kind| {
                nodes
                    .get(kind)
                    .cloned()
                    .ok_or_else(|| SchedulerError::NodeNotFound {
                        kind: kind.clone(),
                        step,
                    })
            })
            .collect::<Result<Vec<Arc<dyn Node>>, SchedulerError>>()?;

        // Futures are lazy: nothing executes until the buffered stream polls them.
        let tasks: Vec<_> = to_run
            .iter()
            .zip(&to_run_ids)
            .zip(handles)
            .map(|((kind, id), node)| {
                let ctx = NodeContext {
                    node_id: id.clone(),
                    step,
                    event_emitter: Arc::clone(&run_context.event_emitter),
                    clock: run_context.clock.clone(),
                    invocation_id: run_context.invocation_id.clone(),
                };
                let s = snap.clone();
                let kind = kind.clone();
                async move { (kind, node.run(s, ctx).await) }
            })
            .collect();

        // `concurrency_limit` is a public field and may have been set to 0;
        // a zero-sized buffer would never poll any task and the step would
        // stall forever, so enforce the documented minimum of 1 here.
        let limit = self.concurrency_limit.max(1);

        let mut outputs: Vec<(NodeKind, NodePartial)> = Vec::with_capacity(tasks.len());
        let mut stream = stream::iter(tasks).buffer_unordered(limit);
        while let Some((kind, res)) = stream.next().await {
            match res {
                Ok(part) => outputs.push((kind, part)),
                Err(source) => {
                    return Err(SchedulerError::NodeRun { kind, step, source });
                }
            }
        }

        // Only reached when every node succeeded.
        for id in &to_run_ids {
            self.record_seen_with(state, id, &channels);
        }

        Ok(StepRunResult {
            ran_nodes: to_run,
            skipped_nodes: skipped_kinds,
            outputs,
        })
    }
}