use rustc_hash::FxHashMap;
use std::sync::Arc;
use weavegraph::event_bus::EventBus;
use weavegraph::schedulers::scheduler::{Scheduler, SchedulerState, StepRunResult};
use weavegraph::types::NodeKind;
mod common;
use common::{FailingNode, create_test_snapshot, make_delayed_registry, make_test_registry};
/// A node failure inside `superstep` must surface to the caller as
/// `SchedulerError::NodeRun`, carrying the node's own `NodeError` as `source`.
#[tokio::test]
async fn test_superstep_propagates_node_error() {
    use weavegraph::node::NodeError;
    use weavegraph::schedulers::scheduler::SchedulerError;

    let failing = NodeKind::Custom("FAIL".into());

    // Registry holding only the always-failing fixture node.
    let mut registry: FxHashMap<NodeKind, Arc<dyn weavegraph::node::Node>> =
        FxHashMap::default();
    registry.insert(failing.clone(), Arc::new(FailingNode::default()));

    let scheduler = Scheduler::new(4);
    let mut sched_state = SchedulerState::default();
    let bus = EventBus::default();

    let outcome = scheduler
        .superstep(
            &mut sched_state,
            &registry,
            vec![failing],
            create_test_snapshot(1, 1),
            1,
            bus.get_emitter(),
        )
        .await;

    match outcome {
        // The `FailingNode` fixture is expected to report "test_key" as the missing input.
        Err(SchedulerError::NodeRun {
            source: NodeError::MissingInput { what },
            ..
        }) => assert_eq!(what, "test_key"),
        unexpected => panic!(
            "expected SchedulerError::NodeRun(MissingInput), got: {:?}",
            unexpected
        ),
    }
}
/// `should_run` / `record_seen` bookkeeping: an unseen id should run, and a
/// recorded (id, snapshot) pair should not. A snapshot built from different
/// `create_test_snapshot` arguments should make the id runnable again.
#[test]
fn test_should_run_and_record_seen() {
    let scheduler = Scheduler::new(4);
    let mut seen = SchedulerState::default();
    // NOTE(review): this id looks like the Debug form of a `NodeKind` variant
    // named `Other`, which appears nowhere else in this file (it uses
    // `Custom`). The test only needs a stable id, but confirm the intent.
    let node_id = "Other(\"A\")";

    // Nothing recorded yet -> must run. Once recorded, the same snapshot is a no-op.
    let initial = create_test_snapshot(1, 1);
    assert!(scheduler.should_run(&seen, node_id, &initial));
    scheduler.record_seen(&mut seen, node_id, &initial);
    assert!(!scheduler.should_run(&seen, node_id, &initial));

    // Changing the first snapshot argument makes the id runnable again.
    let first_bumped = create_test_snapshot(2, 1);
    assert!(scheduler.should_run(&seen, node_id, &first_bumped));
    scheduler.record_seen(&mut seen, node_id, &first_bumped);

    // After recording that, changing only the second argument is also enough.
    let second_bumped = create_test_snapshot(2, 3);
    assert!(scheduler.should_run(&seen, node_id, &second_bumped));
}
/// Exercises the skip paths of `superstep` over the frontier `[A, End, B]`:
///
/// 1. Fresh state: `A` and `B` run and produce two outputs. `End` never runs
///    and is reported in `skipped_nodes`.
/// 2. Same snapshot again: nothing runs, every frontier entry is skipped, and
///    there are no outputs.
/// 3. A snapshot built with a bumped first argument: `A` and `B` run again
///    and `End` is still skipped.
#[tokio::test]
async fn test_superstep_skips_end_and_nochange() {
    let sched = Scheduler::new(8);
    let mut state = SchedulerState::default();
    let nodes = make_test_registry();
    let node_a = NodeKind::Custom("A".into());
    let node_b = NodeKind::Custom("B".into());
    let frontier = vec![node_a.clone(), NodeKind::End, node_b.clone()];
    let event_bus = EventBus::default();
    let snap = create_test_snapshot(1, 1);

    // Step 1: nothing has been recorded yet, so every runnable node executes.
    let res1: StepRunResult = sched
        .superstep(
            &mut state,
            &nodes,
            frontier.clone(),
            snap.clone(),
            1,
            event_bus.get_emitter(),
        )
        .await
        .expect("step 1 superstep should succeed");
    let ran1: std::collections::HashSet<_> = res1.ran_nodes.iter().cloned().collect();
    assert!(ran1.contains(&node_a));
    assert!(ran1.contains(&node_b));
    assert!(!ran1.contains(&NodeKind::End));
    assert!(res1.skipped_nodes.contains(&NodeKind::End));
    assert_eq!(res1.outputs.len(), 2);

    // Step 2: identical snapshot, so nothing should run.
    // This is the last use of `snap`, so it is moved rather than cloned.
    let res2 = sched
        .superstep(
            &mut state,
            &nodes,
            frontier.clone(),
            snap,
            2,
            event_bus.get_emitter(),
        )
        .await
        .expect("step 2 superstep should succeed");
    assert!(res2.ran_nodes.is_empty());
    let skipped2: std::collections::HashSet<_> = res2.skipped_nodes.iter().cloned().collect();
    assert!(skipped2.contains(&node_a));
    assert!(skipped2.contains(&node_b));
    assert!(skipped2.contains(&NodeKind::End));
    assert!(res2.outputs.is_empty());

    // Step 3: a changed snapshot makes A and B runnable again.
    // This is the last use of `frontier`, so it is moved rather than cloned.
    let snap_bump = create_test_snapshot(2, 1);
    let res3 = sched
        .superstep(
            &mut state,
            &nodes,
            frontier,
            snap_bump,
            3,
            event_bus.get_emitter(),
        )
        .await
        .expect("step 3 superstep should succeed");
    let ran3: std::collections::HashSet<_> = res3.ran_nodes.iter().cloned().collect();
    assert!(ran3.contains(&node_a));
    assert!(ran3.contains(&node_b));
    // Step 1 skipped `End` even on a fresh state. A snapshot change should
    // not make it runnable either.
    assert!(!ran3.contains(&NodeKind::End));
    assert!(res3.skipped_nodes.contains(&NodeKind::End));
    assert_eq!(res3.outputs.len(), 2);
}
#[tokio::test]
async fn test_superstep_outputs_order_agnostic() {
let nodes = make_delayed_registry();
let frontier = vec![NodeKind::Custom("A".into()), NodeKind::Custom("B".into())];
let snap = create_test_snapshot(1, 1);
let sched = Scheduler::new(2);
let mut state = SchedulerState::default();
let event_bus = EventBus::default();
let res = sched
.superstep(
&mut state,
&nodes,
frontier.clone(),
snap,
1,
event_bus.get_emitter(),
)
.await
.unwrap();
assert_eq!(
res.ran_nodes,
vec![NodeKind::Custom("A".into()), NodeKind::Custom("B".into())]
);
let ids: std::collections::HashSet<_> = res.outputs.iter().map(|(id, _)| id.clone()).collect();
let expected: std::collections::HashSet<_> =
[NodeKind::Custom("A".into()), NodeKind::Custom("B".into())]
.into_iter()
.collect();
assert_eq!(ids, expected);
}
#[tokio::test]
async fn test_superstep_serialized_with_limit_1() {
let nodes = make_delayed_registry();
let frontier = vec![NodeKind::Custom("A".into()), NodeKind::Custom("B".into())];
let snap = create_test_snapshot(1, 1);
let sched = Scheduler::new(1);
let mut state = SchedulerState::default();
let event_bus = EventBus::default();
let res = sched
.superstep(
&mut state,
&nodes,
frontier.clone(),
snap,
1,
event_bus.get_emitter(),
)
.await
.unwrap();
assert_eq!(
res.ran_nodes,
vec![NodeKind::Custom("A".into()), NodeKind::Custom("B".into())]
);
let output_ids: Vec<_> = res.outputs.iter().map(|(id, _)| id.clone()).collect();
assert_eq!(output_ids, res.ran_nodes);
}