use crate::{
JunctureError, State,
edge::{CompiledEdge, TriggerSource, TriggerTable},
pregel::types::{PendingTask, SuperstepResult, TaskOutput},
state::FieldsChanged,
};
use indexmap::IndexMap;
use std::{collections::HashMap, collections::HashSet};
/// Per-field version counters backed by one shared, monotonically increasing
/// counter.
///
/// Every bump takes the next value of `global_max`, so a field's version is
/// the sequence number of its most recent bump across all fields. A version of
/// `0` means the field has never been bumped.
#[derive(Clone, Debug)]
pub struct FieldVersionTracker {
    /// Current version of each field, indexed by field index.
    versions: Vec<u64>,
    /// Highest version handed out so far (saturates at `u64::MAX`).
    global_max: u64,
}

impl FieldVersionTracker {
    /// Creates a tracker for `num_fields` fields, all at version `0`.
    ///
    /// # Panics
    ///
    /// Panics if `num_fields` is greater than 64.
    #[must_use]
    pub fn new(num_fields: usize) -> Self {
        assert!(
            num_fields <= 64,
            "Cannot track more than 64 fields (got {num_fields})"
        );
        Self {
            global_max: 0,
            versions: vec![0; num_fields],
        }
    }

    /// Bumps every tracked field for which `changed.has_field` is true, in
    /// ascending field-index order. Indices at or beyond `self.len()` are
    /// never queried.
    pub fn bump_all(&mut self, changed: &FieldsChanged) {
        let field_count = self.versions.len();
        for field_idx in (0..field_count).filter(|&idx| changed.has_field(idx)) {
            self.bump(field_idx);
        }
    }

    /// Gives `field_idx` the next version from the shared counter.
    ///
    /// # Panics
    ///
    /// Panics if `field_idx >= self.len()` (after the counter was advanced).
    pub fn bump(&mut self, field_idx: usize) {
        self.global_max = self.global_max.saturating_add(1);
        let next_version = self.global_max;
        self.versions[field_idx] = next_version;
    }

    /// Returns the current version of `field_idx`.
    ///
    /// # Panics
    ///
    /// Panics if `field_idx >= self.len()`.
    #[must_use]
    pub fn get(&self, field_idx: usize) -> u64 {
        let versions = &self.versions;
        versions[field_idx]
    }

    /// Returns all field versions, indexed by field index.
    #[must_use]
    pub fn versions(&self) -> &[u64] {
        self.versions.as_slice()
    }

    /// Number of tracked fields.
    #[must_use]
    pub const fn len(&self) -> usize {
        self.versions.len()
    }

    /// Returns `true` when no fields are tracked.
    #[must_use]
    pub const fn is_empty(&self) -> bool {
        self.versions.is_empty()
    }

    /// Same as [`Self::versions`].
    #[must_use]
    pub fn as_slice(&self) -> &[u64] {
        &self.versions
    }

    /// Highest version handed out so far (`0` if nothing was bumped).
    #[must_use]
    pub const fn global_max(&self) -> u64 {
        self.global_max
    }
}
#[derive(Clone, Debug)]
pub struct VersionsSeen {
seen: IndexMap<String, Vec<u64>>,
}
impl VersionsSeen {
#[must_use]
pub fn new(node_names: &[String], num_fields: usize) -> Self {
let seen = node_names
.iter()
.map(|name| (name.clone(), vec![0; num_fields]))
.collect();
Self { seen }
}
#[must_use]
pub fn should_activate(
&self,
node_name: &str,
trigger_fields: &[usize],
current: &[u64],
) -> bool {
let Some(seen_versions) = self.seen.get(node_name) else {
return true; };
for &field_idx in trigger_fields {
if current[field_idx] > seen_versions[field_idx] {
return true;
}
}
false
}
pub fn mark_consumed(&mut self, node_name: &str, current: &[u64]) {
if let Some(seen_versions) = self.seen.get_mut(node_name) {
seen_versions.copy_from_slice(current);
}
}
#[must_use]
pub fn get_seen(&self, node_name: &str) -> &[u64] {
self.seen.get(node_name).map_or(&[], Vec::as_slice)
}
#[must_use]
pub fn get_versions(&self, node_name: &str) -> &[u64] {
self.get_seen(node_name)
}
#[must_use]
pub fn compute_triggered_fields(
&self,
node_name: &str,
trigger_fields: &[usize],
current_versions: &[u64],
) -> Vec<usize> {
let Some(seen_versions) = self.seen.get(node_name) else {
return trigger_fields.to_vec();
};
trigger_fields
.iter()
.filter(|&&field_idx| current_versions[field_idx] > seen_versions[field_idx])
.copied()
.collect()
}
}
/// Computes the tasks to run in the next superstep from the outputs of the
/// tasks that just completed.
///
/// How successors are found depends on each completed task's `command.goto`:
///
/// - `Goto::None`: follow the node's outgoing edges in `trigger_table`. An
///   edge is taken only when its resolved target is one of the nodes that
///   `trigger_to_nodes` lists as triggered by the completed node, and the
///   target is not `crate::edge::END`.
/// - `Goto::Next` / `Goto::Multiple`: schedule the named node(s) directly.
/// - `Goto::Send`: schedule one push task per send target, carrying that
///   target's state and its position in the send list. Push tasks are not
///   deduplicated.
/// - `Goto::End`: schedule nothing.
///
/// Pull tasks are deduplicated across all completed tasks, so each node is
/// scheduled at most once as a pull task per call.
///
/// # Errors
///
/// Propagates any error returned by a conditional edge's router.
pub async fn compute_next_tasks<S: State>(
    completed_tasks: &[TaskOutput<S>],
    trigger_table: &TriggerTable<S>,
    trigger_to_nodes: &TriggerToNodes,
    state: &S,
) -> Result<Vec<PendingTask<S>>, JunctureError> {
    let mut next_tasks = Vec::new();
    let mut seen_nodes = HashSet::new();
    for task_output in completed_tasks {
        match &task_output.command.goto {
            crate::Goto::None => {
                let triggered =
                    trigger_to_nodes.triggered_nodes(std::slice::from_ref(&task_output.node_name));
                if let Some(edges) = trigger_table.outgoing.get(&task_output.node_name) {
                    for edge in edges {
                        follow_edge(edge, state, &triggered, &mut next_tasks, &mut seen_nodes)
                            .await?;
                    }
                }
            }
            crate::Goto::Next(target) => {
                push_pull_once(&mut next_tasks, &mut seen_nodes, target);
            }
            crate::Goto::Multiple(targets) => {
                for target in targets {
                    push_pull_once(&mut next_tasks, &mut seen_nodes, target);
                }
            }
            crate::Goto::Send(send_targets) => {
                for (idx, target) in send_targets.iter().enumerate() {
                    next_tasks.push(PendingTask::push(
                        uuid::Uuid::new_v4().to_string(),
                        target.node.clone(),
                        idx,
                        target.state.clone(),
                    ));
                }
            }
            crate::Goto::End => {}
        }
    }
    Ok(next_tasks)
}

/// Resolves `edge` to at most one target and schedules it if it was triggered.
///
/// A conditional edge's router is awaited exactly once. The target that
/// passes the trigger check is therefore the same target that gets scheduled,
/// even if the router is non-deterministic or expensive. A route result
/// without a single target (`as_target()` returning `None`) schedules nothing.
async fn follow_edge<S: State>(
    edge: &CompiledEdge<S>,
    state: &S,
    triggered_nodes: &HashSet<String>,
    next_tasks: &mut Vec<PendingTask<S>>,
    seen_nodes: &mut HashSet<String>,
) -> Result<(), JunctureError> {
    match edge {
        CompiledEdge::Fixed { target } => {
            schedule_if_triggered(target, triggered_nodes, next_tasks, seen_nodes);
        }
        CompiledEdge::Conditional { router, .. } => {
            let route_result = router.route(state).await?;
            if let Some(target) = route_result.as_target() {
                schedule_if_triggered(target, triggered_nodes, next_tasks, seen_nodes);
            }
        }
    }
    Ok(())
}

/// Schedules `target` as a pull task when it is in `triggered_nodes` and is
/// not the `END` sentinel.
fn schedule_if_triggered<S: State>(
    target: &str,
    triggered_nodes: &HashSet<String>,
    next_tasks: &mut Vec<PendingTask<S>>,
    seen_nodes: &mut HashSet<String>,
) {
    if target != crate::edge::END && triggered_nodes.contains(target) {
        push_pull_once(next_tasks, seen_nodes, target);
    }
}

/// Pushes a pull task for `target` unless one was already scheduled in this
/// pass, as recorded in `seen_nodes`.
fn push_pull_once<S: State>(
    next_tasks: &mut Vec<PendingTask<S>>,
    seen_nodes: &mut HashSet<String>,
    target: &str,
) {
    if seen_nodes.insert(target.to_owned()) {
        next_tasks.push(PendingTask::pull(
            uuid::Uuid::new_v4().to_string(),
            target.to_owned(),
        ));
    }
}
/// Applies every task's state update to `state` in a deterministic order, then
/// bumps the versions of all fields that changed.
///
/// Before anything is applied, the replace fields reported by
/// `S::replace_field_indices` are checked for multiple writers.
///
/// Updates are applied in this order:
///
/// 1. Pull-triggered tasks, sorted by node name.
/// 2. Push-triggered tasks, sorted by send index.
///
/// The sort is stable, so ties keep their relative order from `task_outputs`.
/// Tasks without an update are skipped. Field versions are bumped once, after
/// all updates have been applied.
///
/// # Errors
///
/// - The multiple-writers error from the replace-field check.
/// - `JunctureError::invalid_update` if `try_apply` rejects an update.
///
/// This function performs no rollback. Updates applied before the failing one
/// stay applied, and field versions are not bumped in that case.
pub fn apply_writes<S: State>(
    state: &mut S,
    task_outputs: &[crate::pregel::types::TaskOutput<S>],
    field_versions: &mut FieldVersionTracker,
) -> Result<FieldsChanged, JunctureError> {
    use crate::pregel::types::TaskTrigger;

    check_replace_conflicts_from_state::<S>(task_outputs)?;

    let mut ordered: Vec<&crate::pregel::types::TaskOutput<S>> = task_outputs.iter().collect();
    ordered.sort_by(|left, right| match (&left.trigger, &right.trigger) {
        (TaskTrigger::Pull, TaskTrigger::Pull) => left.node_name.cmp(&right.node_name),
        (TaskTrigger::Push { index: left_idx }, TaskTrigger::Push { index: right_idx }) => {
            left_idx.cmp(right_idx)
        }
        (TaskTrigger::Pull, TaskTrigger::Push { .. }) => std::cmp::Ordering::Less,
        (TaskTrigger::Push { .. }, TaskTrigger::Pull) => std::cmp::Ordering::Greater,
    });

    let mut total_changed = FieldsChanged(0);
    for output in ordered {
        let Some(update) = &output.command.update else {
            continue;
        };
        let changed = state
            .try_apply(update.clone())
            .map_err(|e| JunctureError::invalid_update(e.to_string()))?;
        total_changed.merge(&changed);
    }

    field_versions.bump_all(&total_changed);
    Ok(total_changed)
}
/// Reverse index from a source node to the set of nodes it can trigger.
///
/// Built from the `incoming` side of a [`TriggerTable`]. Every
/// `TriggerSource::Edge { from }` or `TriggerSource::Send { from }` recorded
/// for node `n` adds `n` to the set stored under `from`.
pub struct TriggerToNodes {
    /// Source node name -> names of the nodes it triggers.
    mapping: HashMap<String, HashSet<String>>,
}

impl TriggerToNodes {
    /// Builds the reverse index from `table.incoming`.
    #[must_use]
    pub fn from_trigger_table<S: State>(table: &TriggerTable<S>) -> Self {
        let mut mapping: HashMap<String, HashSet<String>> = HashMap::new();
        for (target, sources) in &table.incoming {
            for source in sources {
                let from = match source {
                    TriggerSource::Edge { from } | TriggerSource::Send { from } => from,
                };
                mapping
                    .entry(from.clone())
                    .or_default()
                    .insert(target.clone());
            }
        }
        Self { mapping }
    }

    /// Returns the union of all nodes triggered by any name in
    /// `updated_channels`. Names with no entry contribute nothing.
    #[must_use]
    pub fn triggered_nodes(&self, updated_channels: &[String]) -> HashSet<String> {
        let mut triggered = HashSet::new();
        for channel in updated_channels {
            if let Some(targets) = self.mapping.get(channel) {
                triggered.extend(targets.iter().cloned());
            }
        }
        triggered
    }
}

impl std::fmt::Debug for TriggerToNodes {
    /// Prints only the number of source entries, not the full mapping.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mapping_len = self.mapping.len();
        f.debug_struct("TriggerToNodes")
            .field("mapping_len", &mapping_len)
            .finish()
    }
}
/// Verifies that no field in `replace_fields` was written by more than one
/// task in `superstep_result`.
///
/// A task counts as a writer of a field when its update exists and
/// `S::field_is_set` reports the field as set. Fields are checked in the order
/// given.
///
/// # Errors
///
/// Returns a `JunctureError::execution` error for the first field with two or
/// more writers. The error message lists the writer node names.
pub fn check_replace_conflicts<S: State>(
    superstep_result: &SuperstepResult<S>,
    replace_fields: &[usize],
) -> Result<(), JunctureError> {
    replace_fields.iter().try_for_each(|&field_idx| {
        let writers: Vec<&str> = superstep_result
            .task_outputs
            .iter()
            .filter_map(|output| {
                let update = output.command.update.as_ref()?;
                S::field_is_set(update, field_idx).then_some(output.node_name.as_str())
            })
            .collect();
        if writers.len() > 1 {
            Err(JunctureError::execution(format!(
                "Multiple writers for replace field {field_idx}: {writers:?}"
            )))
        } else {
            Ok(())
        }
    })
}
/// Verifies that no field listed by `S::replace_field_indices` is set by more
/// than one task in `task_outputs`.
///
/// # Errors
///
/// Returns `JunctureError::multiple_writers` for the first conflicting field,
/// together with the names of the writing nodes.
fn check_replace_conflicts_from_state<S: State>(
    task_outputs: &[crate::pregel::types::TaskOutput<S>],
) -> Result<(), JunctureError> {
    for &field_idx in S::replace_field_indices() {
        let writers: Vec<&str> = task_outputs
            .iter()
            .filter(|output| {
                matches!(
                    &output.command.update,
                    Some(update) if S::field_is_set(update, field_idx)
                )
            })
            .map(|output| output.node_name.as_str())
            .collect();
        if writers.len() <= 1 {
            continue;
        }
        return Err(JunctureError::multiple_writers(
            field_idx,
            writers.into_iter().map(String::from).collect(),
        ));
    }
    Ok(())
}
/// Calls `consume_field` on `state` for each index in `triggered_channels`,
/// in the order given.
pub fn consume_triggered_channels<S: State>(state: &mut S, triggered_channels: &[usize]) {
    for channel in triggered_channels.iter().copied() {
        state.consume_field(channel);
    }
}
/// Schedules fallback nodes for tasks that failed in this superstep.
///
/// A pull task for the fallback is scheduled for each output that:
///
/// - has an `error`,
/// - is not `circuit_blocked`, and
/// - whose node has an entry in `fallback_map`.
///
/// A failure is skipped, with a `tracing::warn!`, when:
///
/// - the fallback node is not present in `nodes`;
/// - the node is configured as its own fallback;
/// - the fallback node is itself a failed node already handled earlier in
///   this call (fallback cycle).
///
/// When several failed nodes share the same fallback, that fallback is queued
/// only once. Queuing it twice would run the same node twice in one superstep,
/// and its writes to replace fields would then conflict with each other. Every
/// failed node covered by the shared fallback is still reported as handled.
///
/// # Returns
///
/// - The fallback pull tasks to run.
/// - The names of the failed nodes whose failure is covered by a scheduled
///   fallback, suitable as the `fallback_handled` argument of
///   `schedule_error_handlers_filtered`.
#[expect(
    clippy::implicit_hasher,
    reason = "public API accepts std HashMap; callers typically construct from builder metadata"
)]
#[expect(
    clippy::cognitive_complexity,
    reason = "function has multiple early-return guards (circuit_blocked, error, fallback_map, missing node, self-reference) that are individually simple but add up"
)]
pub fn schedule_fallback_tasks<S: State>(
    task_outputs: &[TaskOutput<S>],
    nodes: &indexmap::IndexMap<String, std::sync::Arc<dyn crate::Node<S>>>,
    fallback_map: &std::collections::HashMap<String, String>,
) -> (Vec<PendingTask<S>>, std::collections::HashSet<String>) {
    let mut recovery_tasks = Vec::new();
    let mut handled_nodes = std::collections::HashSet::new();
    // Fallback nodes already queued by this call; each is queued at most once.
    let mut scheduled_fallbacks: std::collections::HashSet<&str> =
        std::collections::HashSet::new();
    for output in task_outputs {
        if output.circuit_blocked {
            continue;
        }
        let Some(ref error) = output.error else {
            continue;
        };
        let Some(fallback_name) = fallback_map.get(&output.node_name) else {
            continue;
        };
        if !nodes.contains_key(fallback_name) {
            tracing::warn!(
                name: "juncture.fallback.missing_node",
                node_name = %output.node_name,
                fallback_name = %fallback_name,
                error = %error,
                "Fallback node not found in graph, skipping fallback"
            );
            continue;
        }
        if fallback_name == &output.node_name {
            tracing::warn!(
                name: "juncture.fallback.self_reference",
                node_name = %output.node_name,
                error = %error,
                "Node configured as its own fallback, skipping to prevent infinite loop"
            );
            continue;
        }
        // `handled_nodes` holds failed nodes already covered by a fallback.
        // Falling back *to* one of them would form a cycle.
        if handled_nodes.contains(fallback_name) {
            tracing::warn!(
                name: "juncture.fallback.cycle_detected",
                node_name = %output.node_name,
                fallback_name = %fallback_name,
                error = %error,
                "Fallback cycle detected, skipping to prevent infinite loop"
            );
            continue;
        }
        tracing::info!(
            name: "juncture.fallback.scheduled",
            node_name = %output.node_name,
            fallback_name = %fallback_name,
            error = %error,
            "Scheduling fallback node for failed task"
        );
        if scheduled_fallbacks.insert(fallback_name.as_str()) {
            recovery_tasks.push(PendingTask::pull(
                uuid::Uuid::new_v4().to_string(),
                fallback_name.clone(),
            ));
        }
        handled_nodes.insert(output.node_name.clone());
    }
    (recovery_tasks, handled_nodes)
}
/// Schedules error-handler nodes for failed tasks not already covered by a
/// fallback.
///
/// A pull task for the handler node is scheduled for each output that:
///
/// - has an `error`,
/// - is not `circuit_blocked`,
/// - whose node is not in `fallback_handled`, and
/// - has an entry in `error_handler_map`.
///
/// If the handler node is missing from `nodes`, a warning is logged and that
/// failure is skipped.
///
/// A handler shared by several failed nodes is queued only once per call. The
/// pull task carries no per-failure context, so a second copy would only run
/// the same node twice in one superstep.
#[expect(
    clippy::implicit_hasher,
    reason = "public API accepts std HashMap; callers typically construct from builder metadata"
)]
pub fn schedule_error_handlers_filtered<S: State>(
    task_outputs: &[TaskOutput<S>],
    nodes: &indexmap::IndexMap<String, std::sync::Arc<dyn crate::Node<S>>>,
    error_handler_map: &std::collections::HashMap<String, String>,
    fallback_handled: &std::collections::HashSet<String>,
) -> Vec<PendingTask<S>> {
    let mut recovery_tasks = Vec::new();
    // Handler nodes already queued by this call; each is queued at most once.
    let mut scheduled_handlers: std::collections::HashSet<&str> =
        std::collections::HashSet::new();
    for output in task_outputs {
        if output.circuit_blocked {
            continue;
        }
        let Some(ref error) = output.error else {
            continue;
        };
        if fallback_handled.contains(&output.node_name) {
            continue;
        }
        let Some(handler_name) = error_handler_map.get(&output.node_name) else {
            continue;
        };
        if !nodes.contains_key(handler_name) {
            tracing::warn!(
                name: "juncture.error_handler.missing_node",
                node_name = %output.node_name,
                handler_name = %handler_name,
                error = %error,
                "Error handler node not found in graph, skipping recovery"
            );
            continue;
        }
        if scheduled_handlers.insert(handler_name.as_str()) {
            recovery_tasks.push(PendingTask::pull(
                uuid::Uuid::new_v4().to_string(),
                handler_name.clone(),
            ));
        }
    }
    recovery_tasks
}
/// Returns an owned copy of the error handler registered for `node_name`, or
/// `None` if no handler is registered.
#[must_use]
#[allow(
    dead_code,
    reason = "tested via unit tests; public API awaiting external consumers"
)]
pub fn get_error_handler_node(
    node_name: &str,
    error_handler_map: &std::collections::HashMap<String, String>,
) -> Option<String> {
    let handler = error_handler_map.get(node_name)?;
    Some(handler.clone())
}
#[cfg(test)]
mod scheduler_tests {
    // Unit tests for the scheduling helpers in the parent module.
    use super::*;
    use crate::node::IntoNode;
    use crate::pregel::types::TaskTrigger;
    use crate::state::FieldVersions;
    use crate::Command;
    use std::sync::Arc;

    type NodeMap = indexmap::IndexMap<String, Arc<dyn crate::Node<TestState>>>;

    #[derive(Clone, Debug, Default)]
    struct TestState;

    impl State for TestState {
        type Update = TestUpdate;
        type FieldVersions = FieldVersions;

        fn apply(&mut self, _: Self::Update) -> FieldsChanged {
            FieldsChanged(0)
        }

        fn reset_ephemeral(&mut self) {}
    }

    #[derive(Clone, Debug, Default, serde::Serialize)]
    struct TestUpdate;

    /// Builds a pull-triggered task output.
    ///
    /// A successful task (`error == None`) carries `Command::end()`. A failed
    /// task carries `Command::default()` and an execution error with
    /// `error` as its message.
    fn task_output(
        task_id: &str,
        node_name: &str,
        error: Option<&'static str>,
        circuit_blocked: bool,
    ) -> TaskOutput<TestState> {
        let (command, error) = match error {
            Some(message) => (
                Command::default(),
                Some(crate::JunctureError::execution(message)),
            ),
            None => (Command::end(), None),
        };
        TaskOutput {
            triggered_fields: vec![],
            task_id: task_id.to_string(),
            node_name: node_name.to_string(),
            trigger: TaskTrigger::Pull,
            command,
            duration: std::time::Duration::ZERO,
            error,
            circuit_blocked,
        }
    }

    /// A node that immediately ends the graph.
    fn noop_node(name: &'static str) -> Arc<dyn crate::Node<TestState>> {
        crate::node::NodeFnCommand(|_s: &TestState| async move { Ok(Command::end()) })
            .into_node(name)
    }

    /// Builds a node map containing a no-op node for each name.
    fn node_map(names: &[&'static str]) -> NodeMap {
        names
            .iter()
            .map(|&name| (name.to_string(), noop_node(name)))
            .collect()
    }

    /// Builds a `String -> String` map from borrowed pairs.
    fn string_map(pairs: &[(&str, &str)]) -> std::collections::HashMap<String, String> {
        pairs
            .iter()
            .map(|&(key, value)| (key.to_string(), value.to_string()))
            .collect()
    }

    #[test]
    fn test_trigger_to_nodes_from_empty_table() {
        let ttn = TriggerToNodes::from_trigger_table(&TriggerTable::<TestState>::default());
        assert!(ttn.triggered_nodes(&["node_a".to_string()]).is_empty());
    }

    #[test]
    fn test_trigger_to_nodes_with_sources() {
        let mut table = TriggerTable::<TestState>::default();
        for (target, source) in [
            ("node_b", "node_a"),
            ("node_c", "node_a"),
            ("node_c", "node_d"),
        ] {
            table.add_incoming(
                target.to_string(),
                TriggerSource::Edge {
                    from: source.to_string(),
                },
            );
        }
        let ttn = TriggerToNodes::from_trigger_table(&table);

        let from_a = ttn.triggered_nodes(&["node_a".to_string()]);
        assert!(from_a.contains("node_b"));
        assert!(from_a.contains("node_c"));
        assert!(!from_a.contains("node_d"));

        let from_d = ttn.triggered_nodes(&["node_d".to_string()]);
        assert!(from_d.contains("node_c"));
        assert!(!from_d.contains("node_b"));
    }

    #[test]
    fn test_trigger_to_nodes_debug() {
        let ttn = TriggerToNodes::from_trigger_table(&TriggerTable::<TestState>::default());
        let rendered = format!("{ttn:?}");
        assert!(rendered.contains("TriggerToNodes"));
    }

    #[test]
    fn test_apply_writes_empty_outputs() {
        let mut state = TestState;
        let mut tracker = FieldVersionTracker::new(3);
        let changed =
            apply_writes(&mut state, &[], &mut tracker).expect("empty outputs should succeed");
        assert_eq!(changed.0, 0);
    }

    #[test]
    fn test_check_replace_conflicts_empty() {
        let result = SuperstepResult::<TestState> {
            task_outputs: Vec::new(),
            bubble_ups: Vec::new(),
        };
        check_replace_conflicts(&result, &[0, 1]).unwrap();
    }

    #[test]
    fn test_check_replace_conflicts_no_conflicts() {
        let result = SuperstepResult::<TestState> {
            task_outputs: vec![task_output("task_1", "node_a", None, false)],
            bubble_ups: Vec::new(),
        };
        check_replace_conflicts(&result, &[0, 1]).unwrap();
    }

    #[test]
    fn test_consume_triggered_channels_empty() {
        let mut state = TestState;
        consume_triggered_channels(&mut state, &[]);
    }

    #[test]
    fn test_consume_triggered_channels_some() {
        let mut state = TestState;
        consume_triggered_channels(&mut state, &[0, 2]);
    }

    #[test]
    fn test_schedule_error_handlers_no_failures() {
        let recovery_tasks = schedule_error_handlers_filtered(
            &[],
            &node_map(&[]),
            &string_map(&[]),
            &std::collections::HashSet::new(),
        );
        assert!(recovery_tasks.is_empty());
    }

    #[test]
    fn test_schedule_error_handlers_with_failure() {
        let recovery_tasks = schedule_error_handlers_filtered(
            &[task_output("task-1", "failing_node", Some("test failure"), false)],
            &node_map(&["error_handler_a"]),
            &string_map(&[("failing_node", "error_handler_a")]),
            &std::collections::HashSet::new(),
        );
        assert_eq!(recovery_tasks.len(), 1);
        assert_eq!(recovery_tasks[0].node_name, "error_handler_a");
    }

    #[test]
    fn test_schedule_error_handlers_missing_handler_node() {
        let recovery_tasks = schedule_error_handlers_filtered(
            &[task_output("task-1", "failing_node", Some("test failure"), false)],
            &node_map(&[]),
            &string_map(&[("failing_node", "nonexistent_handler")]),
            &std::collections::HashSet::new(),
        );
        assert!(
            recovery_tasks.is_empty(),
            "handler node not in graph, no recovery task"
        );
    }

    #[test]
    fn test_schedule_error_handlers_no_handler_registered() {
        let recovery_tasks = schedule_error_handlers_filtered(
            &[task_output("task-1", "failing_node", Some("test failure"), false)],
            &node_map(&[]),
            &string_map(&[]),
            &std::collections::HashSet::new(),
        );
        assert!(recovery_tasks.is_empty());
    }

    #[test]
    fn test_get_error_handler_node_found() {
        let error_handler_map = string_map(&[("node_a", "handler_a")]);
        let handler = get_error_handler_node("node_a", &error_handler_map);
        assert_eq!(handler, Some("handler_a".to_string()));
    }

    #[test]
    fn test_get_error_handler_node_not_found() {
        let handler = get_error_handler_node("node_a", &string_map(&[]));
        assert!(handler.is_none());
    }

    #[test]
    fn test_schedule_fallback_tasks_no_errors() {
        let (tasks, handled) = schedule_fallback_tasks(
            &[task_output("task-1", "node_a", None, false)],
            &node_map(&["node_a"]),
            &string_map(&[]),
        );
        assert!(tasks.is_empty());
        assert!(handled.is_empty());
    }

    #[test]
    fn test_schedule_fallback_tasks_with_fallback() {
        let (tasks, handled) = schedule_fallback_tasks(
            &[task_output("task-1", "node_a", Some("test error"), false)],
            &node_map(&["node_a", "fallback_a"]),
            &string_map(&[("node_a", "fallback_a")]),
        );
        assert_eq!(tasks.len(), 1);
        assert_eq!(tasks[0].node_name, "fallback_a");
        assert!(handled.contains("node_a"));
    }

    #[test]
    fn test_schedule_fallback_tasks_skips_circuit_blocked() {
        let (tasks, handled) = schedule_fallback_tasks(
            &[task_output("task-1", "node_a", Some("circuit open"), true)],
            &node_map(&["node_a", "fallback_a"]),
            &string_map(&[("node_a", "fallback_a")]),
        );
        assert!(tasks.is_empty());
        assert!(handled.is_empty());
    }

    #[test]
    fn test_schedule_fallback_tasks_self_reference_guard() {
        let (tasks, handled) = schedule_fallback_tasks(
            &[task_output("task-1", "node_a", Some("test error"), false)],
            &node_map(&["node_a"]),
            &string_map(&[("node_a", "node_a")]),
        );
        assert!(tasks.is_empty());
        assert!(handled.is_empty());
    }

    #[test]
    fn test_schedule_fallback_tasks_cycle_guard() {
        let outputs = [
            task_output("task-1", "node_a", Some("node_a failed"), false),
            task_output("task-2", "node_b", Some("node_b failed"), false),
        ];
        let (tasks, handled) = schedule_fallback_tasks(
            &outputs,
            &node_map(&["node_a", "node_b"]),
            &string_map(&[("node_a", "node_b"), ("node_b", "node_a")]),
        );
        assert_eq!(tasks.len(), 1);
        assert_eq!(tasks[0].node_name, "node_b");
        assert!(handled.contains("node_a"));
    }
}