1use crate::{
8 JunctureError, State,
9 edge::{CompiledEdge, TriggerSource, TriggerTable},
10 pregel::types::{PendingTask, SuperstepResult, TaskOutput},
11 state::FieldsChanged,
12};
13use indexmap::IndexMap;
14use std::{collections::HashMap, collections::HashSet};
15
16#[derive(Clone, Debug)]
21pub struct FieldVersionTracker {
22 versions: Vec<u64>,
24
25 global_max: u64,
27}
28
29impl FieldVersionTracker {
30 #[must_use]
46 pub fn new(num_fields: usize) -> Self {
47 assert!(
48 num_fields <= 64,
49 "Cannot track more than 64 fields (got {num_fields})"
50 );
51
52 Self {
53 versions: vec![0; num_fields],
54 global_max: 0,
55 }
56 }
57
58 pub fn bump_all(&mut self, changed: &FieldsChanged) {
74 for field_idx in 0..self.versions.len() {
75 if changed.has_field(field_idx) {
76 self.bump(field_idx);
77 }
78 }
79 }
80
81 pub fn bump(&mut self, field_idx: usize) {
94 self.global_max = self.global_max.saturating_add(1);
95 self.versions[field_idx] = self.global_max;
96 }
97
98 #[must_use]
114 pub fn get(&self, field_idx: usize) -> u64 {
115 self.versions[field_idx]
116 }
117
118 #[must_use]
130 pub fn versions(&self) -> &[u64] {
131 &self.versions
132 }
133
134 #[must_use]
145 pub const fn len(&self) -> usize {
146 self.versions.len()
147 }
148
149 #[must_use]
151 pub const fn is_empty(&self) -> bool {
152 self.versions.is_empty()
153 }
154
155 #[must_use]
166 pub fn as_slice(&self) -> &[u64] {
167 self.versions()
168 }
169
170 #[must_use]
183 pub const fn global_max(&self) -> u64 {
184 self.global_max
185 }
186}
187
188#[derive(Clone, Debug)]
193pub struct VersionsSeen {
194 seen: IndexMap<String, Vec<u64>>,
198}
199
200impl VersionsSeen {
201 #[must_use]
213 pub fn new(node_names: &[String], num_fields: usize) -> Self {
214 let seen = node_names
215 .iter()
216 .map(|name| (name.clone(), vec![0; num_fields]))
217 .collect();
218
219 Self { seen }
220 }
221
222 #[must_use]
241 pub fn should_activate(
242 &self,
243 node_name: &str,
244 trigger_fields: &[usize],
245 current: &[u64],
246 ) -> bool {
247 let Some(seen_versions) = self.seen.get(node_name) else {
248 return true; };
250
251 for &field_idx in trigger_fields {
252 if current[field_idx] > seen_versions[field_idx] {
253 return true;
254 }
255 }
256
257 false
258 }
259
260 pub fn mark_consumed(&mut self, node_name: &str, current: &[u64]) {
277 if let Some(seen_versions) = self.seen.get_mut(node_name) {
278 seen_versions.copy_from_slice(current);
279 }
280 }
281
282 #[must_use]
286 pub fn get_seen(&self, node_name: &str) -> &[u64] {
287 self.seen.get(node_name).map_or(&[], Vec::as_slice)
288 }
289
290 #[must_use]
294 pub fn get_versions(&self, node_name: &str) -> &[u64] {
295 self.get_seen(node_name)
296 }
297
298 #[must_use]
327 pub fn compute_triggered_fields(
328 &self,
329 node_name: &str,
330 trigger_fields: &[usize],
331 current_versions: &[u64],
332 ) -> Vec<usize> {
333 let Some(seen_versions) = self.seen.get(node_name) else {
334 return trigger_fields.to_vec();
336 };
337
338 trigger_fields
339 .iter()
340 .filter(|&&field_idx| current_versions[field_idx] > seen_versions[field_idx])
341 .copied()
342 .collect()
343 }
344}
345
346pub async fn compute_next_tasks<S: State>(
382 completed_tasks: &[TaskOutput<S>],
383 trigger_table: &TriggerTable<S>,
384 trigger_to_nodes: &TriggerToNodes,
385 state: &S,
386) -> Result<Vec<PendingTask<S>>, JunctureError> {
387 let mut next_tasks = Vec::new();
388 let mut seen_nodes = HashSet::new();
389
390 for task_output in completed_tasks {
392 let command = &task_output.command;
393
394 match &command.goto {
395 crate::Goto::None => {
396 let triggered =
399 trigger_to_nodes.triggered_nodes(std::slice::from_ref(&task_output.node_name));
400
401 if let Some(edges) = trigger_table.outgoing.get(&task_output.node_name) {
403 for edge in edges {
404 if should_process_edge(edge, state, &triggered).await? {
406 process_edge(
407 edge,
408 state,
409 &mut next_tasks,
410 &mut seen_nodes,
411 &task_output.node_name,
412 )
413 .await?;
414 }
415 }
416 }
417 }
418 crate::Goto::Next(target) => {
419 if !seen_nodes.contains(target) {
421 seen_nodes.insert(target.clone());
422 next_tasks.push(PendingTask::pull(
423 uuid::Uuid::new_v4().to_string(),
424 target.clone(),
425 ));
426 }
427 }
428 crate::Goto::Multiple(targets) => {
429 for target in targets {
431 if !seen_nodes.contains(target) {
432 seen_nodes.insert(target.clone());
433 next_tasks.push(PendingTask::pull(
434 uuid::Uuid::new_v4().to_string(),
435 target.clone(),
436 ));
437 }
438 }
439 }
440 crate::Goto::Send(send_targets) => {
441 for (idx, target) in send_targets.iter().enumerate() {
445 next_tasks.push(PendingTask::push(
446 uuid::Uuid::new_v4().to_string(),
447 target.node.clone(),
448 idx,
449 target.state.clone(),
450 ));
451 }
452 }
453 crate::Goto::End => {
454 }
456 }
457 }
458
459 Ok(next_tasks)
460}
461
462async fn should_process_edge<S: State>(
467 edge: &CompiledEdge<S>,
468 state: &S,
469 triggered_nodes: &HashSet<String>,
470) -> Result<bool, JunctureError> {
471 match edge {
472 CompiledEdge::Fixed { target } => Ok(triggered_nodes.contains(target)),
473 CompiledEdge::Conditional { router, .. } => {
474 let route_result = router.route(state).await?;
475 Ok(route_result
476 .as_target()
477 .is_some_and(|t| triggered_nodes.contains(t)))
478 }
479 }
480}
481
482async fn process_edge<S: State>(
484 edge: &CompiledEdge<S>,
485 state: &S,
486 next_tasks: &mut Vec<PendingTask<S>>,
487 seen_nodes: &mut HashSet<String>,
488 from_node: &str,
489) -> Result<(), JunctureError> {
490 match edge {
491 CompiledEdge::Fixed { target } => {
492 if target != crate::edge::END && !seen_nodes.contains(target) {
493 seen_nodes.insert(target.clone());
494 next_tasks.push(PendingTask::pull(
495 uuid::Uuid::new_v4().to_string(),
496 target.clone(),
497 ));
498 }
499 }
500 CompiledEdge::Conditional { router, .. } => {
501 let route_result = router.route(state).await?;
502 let target = route_result.as_target().ok_or_else(|| {
503 JunctureError::execution(format!(
504 "Conditional edge from '{from_node}' returned no target: {route_result:?}"
505 ))
506 })?;
507
508 if target != crate::edge::END && !seen_nodes.contains(target) {
509 seen_nodes.insert(target.to_string());
510 next_tasks.push(PendingTask::pull(
511 uuid::Uuid::new_v4().to_string(),
512 target.to_string(),
513 ));
514 }
515 }
516 }
517
518 Ok(())
519}
520
521pub fn apply_writes<S: State>(
551 state: &mut S,
552 task_outputs: &[crate::pregel::types::TaskOutput<S>],
553 field_versions: &mut FieldVersionTracker,
554) -> Result<FieldsChanged, JunctureError> {
555 check_replace_conflicts_from_state::<S>(task_outputs)?;
559
560 let mut total_changed = FieldsChanged(0);
561
562 let mut sorted_indices: Vec<usize> = (0..task_outputs.len()).collect();
566 sorted_indices.sort_by(|&a, &b| {
567 let task_a = &task_outputs[a];
568 let task_b = &task_outputs[b];
569 match (&task_a.trigger, &task_b.trigger) {
570 (crate::pregel::types::TaskTrigger::Pull, crate::pregel::types::TaskTrigger::Pull) => {
571 task_a.node_name.cmp(&task_b.node_name)
572 }
573 (
574 crate::pregel::types::TaskTrigger::Push { index: idx_a },
575 crate::pregel::types::TaskTrigger::Push { index: idx_b },
576 ) => idx_a.cmp(idx_b),
577 (
578 crate::pregel::types::TaskTrigger::Pull,
579 crate::pregel::types::TaskTrigger::Push { .. },
580 ) => std::cmp::Ordering::Less,
581 (
582 crate::pregel::types::TaskTrigger::Push { .. },
583 crate::pregel::types::TaskTrigger::Pull,
584 ) => std::cmp::Ordering::Greater,
585 }
586 });
587
588 for idx in sorted_indices {
589 let output = &task_outputs[idx];
590 if let Some(ref update) = output.command.update {
591 let changed = state
592 .try_apply(update.clone())
593 .map_err(|e| JunctureError::invalid_update(e.to_string()))?;
594 total_changed.merge(&changed);
595 }
596 }
597
598 field_versions.bump_all(&total_changed);
600
601 Ok(total_changed)
602}
603
604pub struct TriggerToNodes {
619 mapping: HashMap<String, HashSet<String>>,
620}
621
622impl TriggerToNodes {
623 #[must_use]
628 pub fn from_trigger_table<S: State>(table: &TriggerTable<S>) -> Self {
629 let mut mapping: HashMap<String, HashSet<String>> = HashMap::new();
630 for (node_name, sources) in &table.incoming {
631 for source in sources {
632 match source {
633 TriggerSource::Edge { from } | TriggerSource::Send { from } => {
634 mapping
635 .entry(from.clone())
636 .or_default()
637 .insert(node_name.clone());
638 }
639 }
640 }
641 }
642 Self { mapping }
643 }
644
645 #[must_use]
650 pub fn triggered_nodes(&self, updated_channels: &[String]) -> HashSet<String> {
651 updated_channels
652 .iter()
653 .filter_map(|ch| self.mapping.get(ch))
654 .flatten()
655 .cloned()
656 .collect()
657 }
658}
659
660impl std::fmt::Debug for TriggerToNodes {
663 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
664 f.debug_struct("TriggerToNodes")
665 .field("mapping_len", &self.mapping.len())
666 .finish()
667 }
668}
669
670pub fn check_replace_conflicts<S: State>(
699 superstep_result: &SuperstepResult<S>,
700 replace_fields: &[usize],
701) -> Result<(), JunctureError> {
702 for &field_idx in replace_fields {
703 let writers: Vec<&str> = superstep_result
704 .task_outputs
705 .iter()
706 .filter(|o| {
707 o.command
708 .update
709 .as_ref()
710 .is_some_and(|u| S::field_is_set(u, field_idx))
711 })
712 .map(|o| o.node_name.as_str())
713 .collect();
714
715 if writers.len() > 1 {
716 return Err(JunctureError::execution(format!(
717 "Multiple writers for replace field {field_idx}: {writers:?}"
718 )));
719 }
720 }
721 Ok(())
722}
723
724fn check_replace_conflicts_from_state<S: State>(
735 task_outputs: &[crate::pregel::types::TaskOutput<S>],
736) -> Result<(), JunctureError> {
737 let replace_fields = S::replace_field_indices();
738 for &field_idx in replace_fields {
739 let writers: Vec<&str> = task_outputs
740 .iter()
741 .filter(|o| {
742 o.command
743 .update
744 .as_ref()
745 .is_some_and(|u| S::field_is_set(u, field_idx))
746 })
747 .map(|o| o.node_name.as_str())
748 .collect();
749
750 if writers.len() > 1 {
751 return Err(JunctureError::multiple_writers(
752 field_idx,
753 writers.into_iter().map(String::from).collect(),
754 ));
755 }
756 }
757 Ok(())
758}
759
760pub fn consume_triggered_channels<S: State>(state: &mut S, triggered_channels: &[usize]) {
785 for &field_idx in triggered_channels {
786 state.consume_field(field_idx);
787 }
788}
789
790#[expect(
826 clippy::implicit_hasher,
827 reason = "public API accepts std HashMap; callers typically construct from builder metadata"
828)]
829#[expect(
830 clippy::cognitive_complexity,
831 reason = "function has multiple early-return guards (circuit_blocked, error, fallback_map, missing node, self-reference) that are individually simple but add up"
832)]
833pub fn schedule_fallback_tasks<S: State>(
834 task_outputs: &[TaskOutput<S>],
835 nodes: &indexmap::IndexMap<String, std::sync::Arc<dyn crate::Node<S>>>,
836 fallback_map: &std::collections::HashMap<String, String>,
837) -> (Vec<PendingTask<S>>, std::collections::HashSet<String>) {
838 let mut recovery_tasks = Vec::new();
839 let mut handled_nodes = std::collections::HashSet::new();
840
841 for output in task_outputs {
842 if output.circuit_blocked {
845 continue;
846 }
847
848 let Some(ref error) = output.error else {
849 continue;
850 };
851
852 let Some(fallback_name) = fallback_map.get(&output.node_name) else {
853 continue;
854 };
855
856 if !nodes.contains_key(fallback_name) {
858 tracing::warn!(
859 name: "juncture.fallback.missing_node",
860 node_name = %output.node_name,
861 fallback_name = %fallback_name,
862 error = %error,
863 "Fallback node not found in graph, skipping fallback"
864 );
865 continue;
866 }
867
868 if fallback_name == &output.node_name {
870 tracing::warn!(
871 name: "juncture.fallback.self_reference",
872 node_name = %output.node_name,
873 error = %error,
874 "Node configured as its own fallback, skipping to prevent infinite loop"
875 );
876 continue;
877 }
878
879 if handled_nodes.contains(fallback_name) {
884 tracing::warn!(
885 name: "juncture.fallback.cycle_detected",
886 node_name = %output.node_name,
887 fallback_name = %fallback_name,
888 error = %error,
889 "Fallback cycle detected, skipping to prevent infinite loop"
890 );
891 continue;
892 }
893
894 tracing::info!(
895 name: "juncture.fallback.scheduled",
896 node_name = %output.node_name,
897 fallback_name = %fallback_name,
898 error = %error,
899 "Scheduling fallback node for failed task"
900 );
901
902 recovery_tasks.push(PendingTask::pull(
903 uuid::Uuid::new_v4().to_string(),
904 fallback_name.clone(),
905 ));
906 handled_nodes.insert(output.node_name.clone());
907 }
908
909 (recovery_tasks, handled_nodes)
910}
911
912#[expect(
917 clippy::implicit_hasher,
918 reason = "public API accepts std HashMap; callers typically construct from builder metadata"
919)]
920pub fn schedule_error_handlers_filtered<S: State>(
921 task_outputs: &[TaskOutput<S>],
922 nodes: &indexmap::IndexMap<String, std::sync::Arc<dyn crate::Node<S>>>,
923 error_handler_map: &std::collections::HashMap<String, String>,
924 fallback_handled: &std::collections::HashSet<String>,
925) -> Vec<PendingTask<S>> {
926 let mut recovery_tasks = Vec::new();
927
928 for output in task_outputs {
929 if output.circuit_blocked {
932 continue;
933 }
934
935 let Some(ref error) = output.error else {
936 continue;
937 };
938
939 if fallback_handled.contains(&output.node_name) {
941 continue;
942 }
943
944 let Some(handler_name) = error_handler_map.get(&output.node_name) else {
945 continue;
946 };
947
948 if !nodes.contains_key(handler_name) {
950 tracing::warn!(
951 name: "juncture.error_handler.missing_node",
952 node_name = %output.node_name,
953 handler_name = %handler_name,
954 error = %error,
955 "Error handler node not found in graph, skipping recovery"
956 );
957 continue;
958 }
959
960 recovery_tasks.push(PendingTask::pull(
961 uuid::Uuid::new_v4().to_string(),
962 handler_name.clone(),
963 ));
964 }
965
966 recovery_tasks
967}
968
/// Looks up the error-handler node registered for `node_name`.
///
/// Returns an owned copy of the handler name, or `None` when
/// `error_handler_map` has no entry for the node. This does not check that the
/// handler node exists in the graph.
#[must_use]
#[allow(
    dead_code,
    reason = "tested via unit tests; public API awaiting external consumers"
)]
pub fn get_error_handler_node(
    node_name: &str,
    error_handler_map: &std::collections::HashMap<String, String>,
) -> Option<String> {
    let handler = error_handler_map.get(node_name)?;
    Some(handler.clone())
}
993
#[cfg(test)]
mod scheduler_tests {
    use super::*;
    use crate::node::IntoNode;
    use crate::state::FieldVersions;
    use crate::Command;

    type NodeMap = indexmap::IndexMap<String, std::sync::Arc<dyn crate::Node<TestState>>>;

    #[derive(Clone, Debug, Default)]
    struct TestState;

    impl State for TestState {
        type Update = TestUpdate;
        type FieldVersions = FieldVersions;

        fn apply(&mut self, _: Self::Update) -> FieldsChanged {
            FieldsChanged(0)
        }

        fn reset_ephemeral(&mut self) {}
    }

    #[derive(Clone, Debug, Default, serde::Serialize)]
    struct TestUpdate;

    /// Builds a node that ends the graph as soon as it runs.
    fn noop_node(name: &'static str) -> std::sync::Arc<dyn crate::Node<TestState>> {
        crate::node::NodeFnCommand(|_s: &TestState| async move { Ok(Command::end()) })
            .into_node(name)
    }

    /// Builds a node map containing one no-op node per name, in order.
    fn node_map(names: &[&'static str]) -> NodeMap {
        names
            .iter()
            .map(|&name| (name.to_string(), noop_node(name)))
            .collect()
    }

    /// Builds a pull-triggered task output for `node_name` with a default
    /// command, zero duration and no circuit block.
    fn task_output(
        node_name: &str,
        error: Option<crate::JunctureError>,
    ) -> TaskOutput<TestState> {
        TaskOutput {
            triggered_fields: vec![],
            task_id: format!("task-{node_name}"),
            node_name: node_name.to_string(),
            trigger: crate::pregel::types::TaskTrigger::Pull,
            command: Command::default(),
            duration: std::time::Duration::ZERO,
            error,
            circuit_blocked: false,
        }
    }

    /// Builds a task output for `node_name` that failed with an execution error.
    fn failed_output(node_name: &str, message: &'static str) -> TaskOutput<TestState> {
        task_output(node_name, Some(crate::JunctureError::execution(message)))
    }

    #[test]
    fn test_trigger_to_nodes_from_empty_table() {
        let table: TriggerTable<TestState> = TriggerTable::default();
        let ttn = TriggerToNodes::from_trigger_table(&table);
        assert!(ttn.triggered_nodes(&["node_a".to_string()]).is_empty());
    }

    #[test]
    fn test_trigger_to_nodes_with_sources() {
        let mut table: TriggerTable<TestState> = TriggerTable::default();
        for (target, from) in [
            ("node_b", "node_a"),
            ("node_c", "node_a"),
            ("node_c", "node_d"),
        ] {
            table.add_incoming(
                target.to_string(),
                TriggerSource::Edge {
                    from: from.to_string(),
                },
            );
        }

        let ttn = TriggerToNodes::from_trigger_table(&table);
        let triggered = ttn.triggered_nodes(&["node_a".to_string()]);
        assert!(triggered.contains("node_b"));
        assert!(triggered.contains("node_c"));
        assert!(!triggered.contains("node_d"));

        let triggered_d = ttn.triggered_nodes(&["node_d".to_string()]);
        assert!(triggered_d.contains("node_c"));
        assert!(!triggered_d.contains("node_b"));
    }

    #[test]
    fn test_trigger_to_nodes_debug() {
        let table: TriggerTable<TestState> = TriggerTable::default();
        let ttn = TriggerToNodes::from_trigger_table(&table);
        let debug = format!("{ttn:?}");
        assert!(debug.contains("TriggerToNodes"));
    }

    #[test]
    fn test_field_version_tracker_bumps_from_shared_counter() {
        let mut tracker = FieldVersionTracker::new(3);
        assert_eq!(tracker.versions(), &[0, 0, 0]);

        tracker.bump(2);
        tracker.bump(0);
        assert_eq!(tracker.versions(), &[2, 0, 1]);
        assert_eq!(tracker.get(0), 2);
        assert_eq!(tracker.global_max(), 2);

        tracker.bump_all(&FieldsChanged(0));
        assert_eq!(tracker.as_slice(), &[2, 0, 1]);
    }

    #[test]
    #[should_panic(expected = "Cannot track more than 64 fields")]
    fn test_field_version_tracker_rejects_more_than_64_fields() {
        let _ = FieldVersionTracker::new(65);
    }

    #[test]
    fn test_versions_seen_activation_and_consumption() {
        let names = vec!["node_a".to_string()];
        let mut seen = VersionsSeen::new(&names, 3);
        let current = [2u64, 0, 1];

        assert!(seen.should_activate("node_a", &[0], &current));
        assert!(!seen.should_activate("node_a", &[1], &current));
        assert!(seen.should_activate("unknown", &[1], &current));
        assert_eq!(
            seen.compute_triggered_fields("node_a", &[0, 1, 2], &current),
            vec![0, 2]
        );

        seen.mark_consumed("node_a", &current);
        assert_eq!(seen.get_seen("node_a"), &[2, 0, 1]);
        assert!(!seen.should_activate("node_a", &[0, 1, 2], &current));
        assert!(seen.get_versions("unknown").is_empty());
    }

    #[test]
    fn test_apply_writes_empty_outputs() {
        let mut state = TestState;
        let mut tracker = FieldVersionTracker::new(3);
        let outputs: Vec<TaskOutput<TestState>> = Vec::new();

        let changed =
            apply_writes(&mut state, &outputs, &mut tracker).expect("empty outputs should succeed");
        assert_eq!(changed.0, 0);
    }

    #[test]
    fn test_check_replace_conflicts_empty() {
        let result: SuperstepResult<TestState> = SuperstepResult {
            task_outputs: Vec::new(),
            bubble_ups: Vec::new(),
        };
        check_replace_conflicts(&result, &[0, 1]).unwrap();
    }

    #[test]
    fn test_check_replace_conflicts_no_conflicts() {
        let result: SuperstepResult<TestState> = SuperstepResult {
            task_outputs: vec![TaskOutput {
                command: Command::end(),
                duration: std::time::Duration::from_millis(10),
                ..task_output("node_a", None)
            }],
            bubble_ups: Vec::new(),
        };
        check_replace_conflicts(&result, &[0, 1]).unwrap();
    }

    #[test]
    fn test_consume_triggered_channels_empty() {
        let mut state = TestState;
        consume_triggered_channels(&mut state, &[]);
    }

    #[test]
    fn test_consume_triggered_channels_some() {
        let mut state = TestState;
        consume_triggered_channels(&mut state, &[0, 2]);
    }

    #[test]
    fn test_schedule_error_handlers_no_failures() {
        let nodes = node_map(&[]);
        let task_outputs: Vec<TaskOutput<TestState>> = Vec::new();

        let recovery_tasks = schedule_error_handlers_filtered(
            &task_outputs,
            &nodes,
            &HashMap::new(),
            &HashSet::new(),
        );
        assert!(recovery_tasks.is_empty());
    }

    #[test]
    fn test_schedule_error_handlers_with_failure() {
        let nodes = node_map(&["error_handler_a"]);
        let task_outputs = vec![failed_output("failing_node", "test failure")];
        let error_handler_map = HashMap::from([(
            "failing_node".to_string(),
            "error_handler_a".to_string(),
        )]);

        let recovery_tasks = schedule_error_handlers_filtered(
            &task_outputs,
            &nodes,
            &error_handler_map,
            &HashSet::new(),
        );
        assert_eq!(recovery_tasks.len(), 1);
        assert_eq!(recovery_tasks[0].node_name, "error_handler_a");
    }

    #[test]
    fn test_schedule_error_handlers_missing_handler_node() {
        let nodes = node_map(&[]);
        let task_outputs = vec![failed_output("failing_node", "test failure")];
        let error_handler_map = HashMap::from([(
            "failing_node".to_string(),
            "nonexistent_handler".to_string(),
        )]);

        let recovery_tasks = schedule_error_handlers_filtered(
            &task_outputs,
            &nodes,
            &error_handler_map,
            &HashSet::new(),
        );
        assert!(
            recovery_tasks.is_empty(),
            "handler node not in graph, no recovery task"
        );
    }

    #[test]
    fn test_schedule_error_handlers_no_handler_registered() {
        let nodes = node_map(&[]);
        let task_outputs = vec![failed_output("failing_node", "test failure")];

        let recovery_tasks = schedule_error_handlers_filtered(
            &task_outputs,
            &nodes,
            &HashMap::new(),
            &HashSet::new(),
        );
        assert!(recovery_tasks.is_empty());
    }

    #[test]
    fn test_schedule_error_handlers_skips_fallback_handled_and_circuit_blocked() {
        let nodes = node_map(&["handler"]);
        let task_outputs = vec![
            failed_output("handled_by_fallback", "boom"),
            TaskOutput {
                circuit_blocked: true,
                ..failed_output("blocked", "circuit open")
            },
        ];
        let error_handler_map = HashMap::from([
            ("handled_by_fallback".to_string(), "handler".to_string()),
            ("blocked".to_string(), "handler".to_string()),
        ]);
        let fallback_handled = HashSet::from(["handled_by_fallback".to_string()]);

        let recovery_tasks = schedule_error_handlers_filtered(
            &task_outputs,
            &nodes,
            &error_handler_map,
            &fallback_handled,
        );
        assert!(recovery_tasks.is_empty());
    }

    #[test]
    fn test_get_error_handler_node_found() {
        let error_handler_map =
            HashMap::from([("node_a".to_string(), "handler_a".to_string())]);

        let handler = get_error_handler_node("node_a", &error_handler_map);
        assert_eq!(handler, Some("handler_a".to_string()));
    }

    #[test]
    fn test_get_error_handler_node_not_found() {
        let error_handler_map = HashMap::new();

        let handler = get_error_handler_node("node_a", &error_handler_map);
        assert!(handler.is_none());
    }

    #[test]
    fn test_schedule_fallback_tasks_no_errors() {
        let nodes = node_map(&["node_a"]);
        let task_outputs = vec![TaskOutput {
            command: Command::end(),
            ..task_output("node_a", None)
        }];

        let (tasks, handled) = schedule_fallback_tasks(&task_outputs, &nodes, &HashMap::new());
        assert!(tasks.is_empty());
        assert!(handled.is_empty());
    }

    #[test]
    fn test_schedule_fallback_tasks_with_fallback() {
        let nodes = node_map(&["node_a", "fallback_a"]);
        let task_outputs = vec![failed_output("node_a", "test error")];
        let fallback_map = HashMap::from([("node_a".to_string(), "fallback_a".to_string())]);

        let (tasks, handled) = schedule_fallback_tasks(&task_outputs, &nodes, &fallback_map);
        assert_eq!(tasks.len(), 1);
        assert_eq!(tasks[0].node_name, "fallback_a");
        assert!(handled.contains("node_a"));
    }

    #[test]
    fn test_schedule_fallback_tasks_skips_circuit_blocked() {
        let nodes = node_map(&["node_a", "fallback_a"]);
        let task_outputs = vec![TaskOutput {
            circuit_blocked: true,
            ..failed_output("node_a", "circuit open")
        }];
        let fallback_map = HashMap::from([("node_a".to_string(), "fallback_a".to_string())]);

        let (tasks, handled) = schedule_fallback_tasks(&task_outputs, &nodes, &fallback_map);
        assert!(tasks.is_empty());
        assert!(handled.is_empty());
    }

    #[test]
    fn test_schedule_fallback_tasks_self_reference_guard() {
        let nodes = node_map(&["node_a"]);
        let task_outputs = vec![failed_output("node_a", "test error")];
        let fallback_map = HashMap::from([("node_a".to_string(), "node_a".to_string())]);

        let (tasks, handled) = schedule_fallback_tasks(&task_outputs, &nodes, &fallback_map);
        assert!(tasks.is_empty());
        assert!(handled.is_empty());
    }

    #[test]
    fn test_schedule_fallback_tasks_cycle_guard() {
        let nodes = node_map(&["node_a", "node_b"]);
        // node_a and node_b both fail and name each other as fallback.
        let task_outputs = vec![
            failed_output("node_a", "node_a failed"),
            failed_output("node_b", "node_b failed"),
        ];
        let fallback_map = HashMap::from([
            ("node_a".to_string(), "node_b".to_string()),
            ("node_b".to_string(), "node_a".to_string()),
        ]);

        let (tasks, handled) = schedule_fallback_tasks(&task_outputs, &nodes, &fallback_map);
        // Only the first failure (node_a -> node_b) is scheduled; node_b -> node_a
        // is rejected because node_a already received a fallback.
        assert_eq!(tasks.len(), 1);
        assert_eq!(tasks[0].node_name, "node_b");
        assert!(handled.contains("node_a"));
    }
}
1441
1442