juncture_core/pregel/scheduler.rs
//! Version tracking and task computation for the Pregel engine
//!
//! This module provides field version tracking, task scheduling logic,
//! state write application (including replace-field conflict checks),
//! trigger-to-node mapping, and error-handler scheduling for the Pregel
//! execution engine.
6
7use crate::{
8 JunctureError, State,
9 edge::{CompiledEdge, TriggerSource, TriggerTable},
10 pregel::types::{PendingTask, SuperstepResult, TaskOutput},
11 state::FieldsChanged,
12};
13use indexmap::IndexMap;
14use std::{collections::HashMap, collections::HashSet};
15
16/// Field version tracker for Pregel execution
17///
18/// Tracks version numbers for each field in the state to determine
19/// when nodes should be activated based on their trigger fields.
20#[derive(Clone, Debug)]
21pub struct FieldVersionTracker {
22 /// Version number for each field (index = field position)
23 versions: Vec<u64>,
24
25 /// Global maximum version across all fields
26 global_max: u64,
27}
28
29impl FieldVersionTracker {
30 /// Create a new version tracker for the given number of fields
31 ///
32 /// # Panics
33 ///
34 /// Panics if `num_fields` is greater than 64 (the maximum number of
35 /// fields that can be tracked in a `FieldsChanged` bitmask).
36 ///
37 /// # Examples
38 ///
39 /// ```ignore
40 /// use juncture_core::pregel::scheduler::FieldVersionTracker;
41 ///
42 /// let tracker = FieldVersionTracker::new(5);
43 /// assert_eq!(tracker.versions().len(), 5);
44 /// ```
45 #[must_use]
46 pub fn new(num_fields: usize) -> Self {
47 assert!(
48 num_fields <= 64,
49 "Cannot track more than 64 fields (got {num_fields})"
50 );
51
52 Self {
53 versions: vec![0; num_fields],
54 global_max: 0,
55 }
56 }
57
58 /// Bump all field versions (used when state changes globally)
59 ///
60 /// # Examples
61 ///
62 /// ```ignore
63 /// use juncture_core::pregel::scheduler::FieldVersionTracker;
64 /// use juncture_core::state::FieldsChanged;
65 ///
66 /// let mut tracker = FieldVersionTracker::new(3);
67 /// let changed = FieldsChanged(0b101); // fields 0 and 2 changed
68 /// tracker.bump_all(&changed);
69 /// assert_eq!(tracker.get(0), 1);
70 /// assert_eq!(tracker.get(1), 0);
71 /// assert_eq!(tracker.get(2), 1);
72 /// ```
73 pub fn bump_all(&mut self, changed: &FieldsChanged) {
74 for field_idx in 0..self.versions.len() {
75 if changed.has_field(field_idx) {
76 self.bump(field_idx);
77 }
78 }
79 }
80
81 /// Bump version for a specific field
82 ///
83 /// # Examples
84 ///
85 /// ```ignore
86 /// use juncture_core::pregel::scheduler::FieldVersionTracker;
87 ///
88 /// let mut tracker = FieldVersionTracker::new(3);
89 /// tracker.bump(1);
90 /// assert_eq!(tracker.get(1), 1);
91 /// assert_eq!(tracker.get(0), 0);
92 /// ```
93 pub fn bump(&mut self, field_idx: usize) {
94 self.global_max = self.global_max.saturating_add(1);
95 self.versions[field_idx] = self.global_max;
96 }
97
98 /// Get the current version of a field
99 ///
100 /// # Panics
101 ///
102 /// Panics if `field_idx` is out of bounds.
103 ///
104 /// # Examples
105 ///
106 /// ```ignore
107 /// use juncture_core::pregel::scheduler::FieldVersionTracker;
108 ///
109 /// let mut tracker = FieldVersionTracker::new(3);
110 /// tracker.bump(0);
111 /// assert_eq!(tracker.get(0), 1);
112 /// ```
113 #[must_use]
114 pub fn get(&self, field_idx: usize) -> u64 {
115 self.versions[field_idx]
116 }
117
118 /// Get all field versions as a slice
119 ///
120 /// # Examples
121 ///
122 /// ```ignore
123 /// use juncture_core::pregel::scheduler::FieldVersionTracker;
124 ///
125 /// let tracker = FieldVersionTracker::new(3);
126 /// let versions = tracker.versions();
127 /// assert_eq!(versions, &[0, 0, 0]);
128 /// ```
129 #[must_use]
130 pub fn versions(&self) -> &[u64] {
131 &self.versions
132 }
133
134 /// Get the number of fields being tracked
135 ///
136 /// # Examples
137 ///
138 /// ```ignore
139 /// use juncture_core::pregel::scheduler::FieldVersionTracker;
140 ///
141 /// let tracker = FieldVersionTracker::new(5);
142 /// assert_eq!(tracker.len(), 5);
143 /// ```
144 #[must_use]
145 pub const fn len(&self) -> usize {
146 self.versions.len()
147 }
148
149 /// Check if no fields are being tracked
150 #[must_use]
151 pub const fn is_empty(&self) -> bool {
152 self.versions.is_empty()
153 }
154
155 /// Get all field versions as a slice (alias for `versions()`)
156 ///
157 /// # Examples
158 ///
159 /// ```ignore
160 /// use juncture_core::pregel::scheduler::FieldVersionTracker;
161 ///
162 /// let tracker = FieldVersionTracker::new(3);
163 /// assert_eq!(tracker.as_slice(), &[0, 0, 0]);
164 /// ```
165 #[must_use]
166 pub fn as_slice(&self) -> &[u64] {
167 self.versions()
168 }
169
170 /// Get the global maximum version
171 ///
172 /// # Examples
173 ///
174 /// ```ignore
175 /// use juncture_core::pregel::scheduler::FieldVersionTracker;
176 ///
177 /// let mut tracker = FieldVersionTracker::new(3);
178 /// tracker.bump(0);
179 /// tracker.bump(1);
180 /// assert_eq!(tracker.global_max(), 2);
181 /// ```
182 #[must_use]
183 pub const fn global_max(&self) -> u64 {
184 self.global_max
185 }
186}
187
188/// Version tracking for node activation
189///
190/// Tracks which versions each node has seen to determine when it should
191/// be activated based on its trigger fields.
192#[derive(Clone, Debug)]
193pub struct VersionsSeen {
194 /// Map of node name to the field versions it has seen
195 ///
196 /// Uses `IndexMap` for deterministic iteration order, matching `LangGraph` semantics.
197 seen: IndexMap<String, Vec<u64>>,
198}
199
200impl VersionsSeen {
201 /// Create a new version tracker for the given nodes and fields
202 ///
203 /// # Examples
204 ///
205 /// ```ignore
206 /// use juncture_core::pregel::scheduler::VersionsSeen;
207 ///
208 /// let node_names = vec!["node_a".to_string(), "node_b".to_string()];
209 /// let seen = VersionsSeen::new(&node_names, 3);
210 /// assert_eq!(seen.get_seen("node_a"), &[0, 0, 0]);
211 /// ```
212 #[must_use]
213 pub fn new(node_names: &[String], num_fields: usize) -> Self {
214 let seen = node_names
215 .iter()
216 .map(|name| (name.clone(), vec![0; num_fields]))
217 .collect();
218
219 Self { seen }
220 }
221
222 /// Check if a node should be activated based on its trigger fields
223 ///
224 /// Returns `true` if any of the node's trigger fields have new versions
225 /// that the node hasn't seen yet.
226 ///
227 /// # Examples
228 ///
229 /// ```ignore
230 /// use juncture_core::pregel::scheduler::VersionsSeen;
231 ///
232 /// let node_names = vec!["node_a".to_string()];
233 /// let mut seen = VersionsSeen::new(&node_names, 3);
234 ///
235 /// // Node should activate if field 0 has version > what it has seen
236 /// let trigger_fields = vec![0]; // triggers on field 0
237 /// let current = vec![1, 0, 0]; // field 0 is at version 1
238 /// assert!(seen.should_activate("node_a", &trigger_fields, ¤t));
239 /// ```
240 #[must_use]
241 pub fn should_activate(
242 &self,
243 node_name: &str,
244 trigger_fields: &[usize],
245 current: &[u64],
246 ) -> bool {
247 let Some(seen_versions) = self.seen.get(node_name) else {
248 return true; // Node not yet tracked, should activate
249 };
250
251 for &field_idx in trigger_fields {
252 if current[field_idx] > seen_versions[field_idx] {
253 return true;
254 }
255 }
256
257 false
258 }
259
260 /// Mark that a node has consumed the current field versions
261 ///
262 /// # Examples
263 ///
264 /// ```ignore
265 /// use juncture_core::pregel::scheduler::VersionsSeen;
266 ///
267 /// let node_names = vec!["node_a".to_string()];
268 /// let mut seen = VersionsSeen::new(&node_names, 3);
269 ///
270 /// let current = vec![1, 0, 0];
271 /// seen.mark_consumed("node_a", ¤t);
272 ///
273 /// // Now node shouldn't activate for same versions
274 /// assert!(!seen.should_activate("node_a", &[0], ¤t));
275 /// ```
276 pub fn mark_consumed(&mut self, node_name: &str, current: &[u64]) {
277 if let Some(seen_versions) = self.seen.get_mut(node_name) {
278 seen_versions.copy_from_slice(current);
279 }
280 }
281
282 /// Get the versions a node has seen
283 ///
284 /// Returns an empty slice if the node is not tracked.
285 #[must_use]
286 pub fn get_seen(&self, node_name: &str) -> &[u64] {
287 self.seen.get(node_name).map_or(&[], Vec::as_slice)
288 }
289
290 /// Get the versions a node has seen (alias for `get_seen`)
291 ///
292 /// Returns an empty slice if the node is not tracked.
293 #[must_use]
294 pub fn get_versions(&self, node_name: &str) -> &[u64] {
295 self.get_seen(node_name)
296 }
297
298 /// Compute which fields triggered a node to activate
299 ///
300 /// Compares the node's seen versions with current field versions to determine
301 /// which specific fields had updates that caused the node to be scheduled.
302 ///
303 /// # Arguments
304 ///
305 /// * `node_name` - Name of the node to check
306 /// * `trigger_fields` - Field indices that the node subscribes to
307 /// * `current_versions` - Current field versions
308 ///
309 /// # Returns
310 ///
311 /// Vector of field indices that triggered this node (subset of `trigger_fields`)
312 ///
313 /// # Examples
314 ///
315 /// ```ignore
316 /// use juncture_core::pregel::scheduler::VersionsSeen;
317 ///
318 /// let node_names = vec!["node_a".to_string()];
319 /// let mut seen = VersionsSeen::new(&node_names, 3);
320 ///
321 /// let trigger_fields = vec![0, 2]; // node subscribes to fields 0 and 2
322 /// let current = vec![1, 0, 1]; // fields 0 and 2 have new versions
323 /// let triggered = seen.compute_triggered_fields("node_a", &trigger_fields, ¤t);
324 /// assert_eq!(triggered, vec![0, 2]); // both fields triggered
325 /// ```
326 #[must_use]
327 pub fn compute_triggered_fields(
328 &self,
329 node_name: &str,
330 trigger_fields: &[usize],
331 current_versions: &[u64],
332 ) -> Vec<usize> {
333 let Some(seen_versions) = self.seen.get(node_name) else {
334 // Node not yet tracked, all trigger fields are new
335 return trigger_fields.to_vec();
336 };
337
338 trigger_fields
339 .iter()
340 .filter(|&&field_idx| current_versions[field_idx] > seen_versions[field_idx])
341 .copied()
342 .collect()
343 }
344}
345
346/// Compute the next set of tasks to execute
347///
348/// This function determines which nodes should be activated in the next
349/// superstep based on:
350/// 1. Commands returned by completed tasks (highest priority)
351/// 2. Trigger table edges (Fixed and Conditional)
352///
353/// # Arguments
354///
355/// * `completed_tasks` - Tasks that completed in the previous superstep
356/// * `trigger_table` - Graph's trigger table
357/// * `state` - Current state
358///
359/// # Returns
360///
361/// A vector of pending tasks to execute in the next superstep.
362///
363/// # Errors
364///
365/// Returns an error if:
366/// - A conditional edge router fails to execute
367/// - A conditional edge returns no target
368///
369/// # Examples
370///
371/// ```ignore
372/// use juncture_core::pregel::scheduler::compute_next_tasks;
373/// use juncture_core::pregel::types::{TaskOutput, SuperstepResult};
374/// use std::time::Duration;
375///
376/// # let completed_tasks = vec![];
377/// # let trigger_table = TriggerTable::<MyState>::new();
378/// # let state = MyState;
379/// let next_tasks = compute_next_tasks(&completed_tasks, &trigger_table, &state)?;
380/// ```
381pub async fn compute_next_tasks<S: State>(
382 completed_tasks: &[TaskOutput<S>],
383 trigger_table: &TriggerTable<S>,
384 trigger_to_nodes: &TriggerToNodes,
385 state: &S,
386) -> Result<Vec<PendingTask<S>>, JunctureError> {
387 let mut next_tasks = Vec::new();
388 let mut seen_nodes = HashSet::new();
389
390 // First, check if any task returned a Command with explicit routing
391 for task_output in completed_tasks {
392 let command = &task_output.command;
393
394 match &command.goto {
395 crate::Goto::None => {
396 // No explicit routing, use trigger table with reverse mapping optimization
397 // Use TriggerToNodes to efficiently find which nodes should be triggered
398 let triggered =
399 trigger_to_nodes.triggered_nodes(std::slice::from_ref(&task_output.node_name));
400
401 // Filter outgoing edges to only those leading to triggered nodes
402 if let Some(edges) = trigger_table.outgoing.get(&task_output.node_name) {
403 for edge in edges {
404 // Only process edges that lead to triggered nodes
405 if should_process_edge(edge, state, &triggered).await? {
406 process_edge(
407 edge,
408 state,
409 &mut next_tasks,
410 &mut seen_nodes,
411 &task_output.node_name,
412 )
413 .await?;
414 }
415 }
416 }
417 }
418 crate::Goto::Next(target) => {
419 // Route to single target
420 if !seen_nodes.contains(target) {
421 seen_nodes.insert(target.clone());
422 next_tasks.push(PendingTask::pull(
423 uuid::Uuid::new_v4().to_string(),
424 target.clone(),
425 ));
426 }
427 }
428 crate::Goto::Multiple(targets) => {
429 // Route to multiple targets
430 for target in targets {
431 if !seen_nodes.contains(target) {
432 seen_nodes.insert(target.clone());
433 next_tasks.push(PendingTask::pull(
434 uuid::Uuid::new_v4().to_string(),
435 target.clone(),
436 ));
437 }
438 }
439 }
440 crate::Goto::Send(send_targets) => {
441 // Dynamic fan-out with state overrides.
442 // Each Send target creates a separate task even if multiple targets
443 // share the same node name, because each carries a distinct state override.
444 for (idx, target) in send_targets.iter().enumerate() {
445 next_tasks.push(PendingTask::push(
446 uuid::Uuid::new_v4().to_string(),
447 target.node.clone(),
448 idx,
449 target.state.clone(),
450 ));
451 }
452 }
453 crate::Goto::End => {
454 // Termination, no next tasks
455 }
456 }
457 }
458
459 Ok(next_tasks)
460}
461
462/// Check if an edge should be processed based on triggered nodes
463///
464/// For fixed edges, checks if the target is in the triggered set.
465/// For conditional edges, the router is executed to determine the actual target.
466async fn should_process_edge<S: State>(
467 edge: &CompiledEdge<S>,
468 state: &S,
469 triggered_nodes: &HashSet<String>,
470) -> Result<bool, JunctureError> {
471 match edge {
472 CompiledEdge::Fixed { target } => Ok(triggered_nodes.contains(target)),
473 CompiledEdge::Conditional { router, .. } => {
474 let route_result = router.route(state).await?;
475 Ok(route_result
476 .as_target()
477 .is_some_and(|t| triggered_nodes.contains(t)))
478 }
479 }
480}
481
482/// Process a single edge and add appropriate tasks
483async fn process_edge<S: State>(
484 edge: &CompiledEdge<S>,
485 state: &S,
486 next_tasks: &mut Vec<PendingTask<S>>,
487 seen_nodes: &mut HashSet<String>,
488 from_node: &str,
489) -> Result<(), JunctureError> {
490 match edge {
491 CompiledEdge::Fixed { target } => {
492 if target != crate::edge::END && !seen_nodes.contains(target) {
493 seen_nodes.insert(target.clone());
494 next_tasks.push(PendingTask::pull(
495 uuid::Uuid::new_v4().to_string(),
496 target.clone(),
497 ));
498 }
499 }
500 CompiledEdge::Conditional { router, .. } => {
501 let route_result = router.route(state).await?;
502 let target = route_result.as_target().ok_or_else(|| {
503 JunctureError::execution(format!(
504 "Conditional edge from '{from_node}' returned no target: {route_result:?}"
505 ))
506 })?;
507
508 if target != crate::edge::END && !seen_nodes.contains(target) {
509 seen_nodes.insert(target.to_string());
510 next_tasks.push(PendingTask::pull(
511 uuid::Uuid::new_v4().to_string(),
512 target.to_string(),
513 ));
514 }
515 }
516 }
517
518 Ok(())
519}
520
521/// Apply writes from completed tasks to the state
522///
523/// Takes outputs from a superstep and applies all updates to the state.
524/// Uses path-based sorting (PULL tasks sorted by node name, PUSH tasks
525/// sorted by send index) for deterministic merge order, matching the
526/// `LangGraph` merge semantics.
527///
528/// Returns [`FieldsChanged`] indicating which fields were modified.
529///
530/// # Arguments
531///
532/// * `state` - Mutable state to apply updates to
533/// * `task_outputs` - Outputs from completed tasks in the superstep
534/// * `field_versions` - Version tracker to bump for changed fields
535///
536/// # Errors
537///
538/// Returns `JunctureError` if a reducer constraint is violated, such as
539/// multiple nodes writing to a replace channel in the same superstep.
540///
541/// # Examples
542///
543/// ```ignore
544/// use juncture_core::pregel::scheduler::{apply_writes, FieldVersionTracker};
545///
546/// let mut state = MyState::default();
547/// let mut tracker = FieldVersionTracker::new(3);
548/// let changed = apply_writes(&mut state, &task_outputs, &mut tracker)?;
549/// ```
550pub fn apply_writes<S: State>(
551 state: &mut S,
552 task_outputs: &[crate::pregel::types::TaskOutput<S>],
553 field_versions: &mut FieldVersionTracker,
554) -> Result<FieldsChanged, JunctureError> {
555 // Check for multiple-writer conflicts on replace fields before applying any writes.
556 // This must happen first so that we reject the entire superstep rather than
557 // silently applying partial writes with last-write-wins semantics.
558 check_replace_conflicts_from_state::<S>(task_outputs)?;
559
560 let mut total_changed = FieldsChanged(0);
561
562 // Sort indices by path-based ordering for deterministic merge
563 // PULL tasks: alphabetical by node name
564 // PUSH tasks: by send index
565 let mut sorted_indices: Vec<usize> = (0..task_outputs.len()).collect();
566 sorted_indices.sort_by(|&a, &b| {
567 let task_a = &task_outputs[a];
568 let task_b = &task_outputs[b];
569 match (&task_a.trigger, &task_b.trigger) {
570 (crate::pregel::types::TaskTrigger::Pull, crate::pregel::types::TaskTrigger::Pull) => {
571 task_a.node_name.cmp(&task_b.node_name)
572 }
573 (
574 crate::pregel::types::TaskTrigger::Push { index: idx_a },
575 crate::pregel::types::TaskTrigger::Push { index: idx_b },
576 ) => idx_a.cmp(idx_b),
577 (
578 crate::pregel::types::TaskTrigger::Pull,
579 crate::pregel::types::TaskTrigger::Push { .. },
580 ) => std::cmp::Ordering::Less,
581 (
582 crate::pregel::types::TaskTrigger::Push { .. },
583 crate::pregel::types::TaskTrigger::Pull,
584 ) => std::cmp::Ordering::Greater,
585 }
586 });
587
588 for idx in sorted_indices {
589 let output = &task_outputs[idx];
590 if let Some(ref update) = output.command.update {
591 let changed = state
592 .try_apply(update.clone())
593 .map_err(|e| JunctureError::invalid_update(e.to_string()))?;
594 total_changed.merge(&changed);
595 }
596 }
597
598 // Bump field versions for all changed fields
599 field_versions.bump_all(&total_changed);
600
601 Ok(total_changed)
602}
603
/// Reverse mapping from a source node to the nodes it triggers
///
/// Built from the `incoming` side of a [`TriggerTable`]. For each node `n`,
/// every recorded `TriggerSource::Edge { from }` or `TriggerSource::Send { from }`
/// adds `n` to the set stored under `from`. Given the names of nodes that just
/// completed, the scheduler can look up their candidate successors directly
/// instead of scanning every node.
///
/// The keys are node names, not state field or channel names.
///
/// # Examples
///
/// ```ignore
/// use juncture_core::pregel::scheduler::TriggerToNodes;
///
/// // `trigger_table` contains an edge `node_a -> node_x`.
/// let trigger_to_nodes = TriggerToNodes::from_trigger_table(&trigger_table);
/// let triggered = trigger_to_nodes.triggered_nodes(&["node_a".to_string()]);
/// assert!(triggered.contains("node_x"));
/// ```
pub struct TriggerToNodes {
    /// Source node name -> names of the nodes that subscribe to it
    mapping: HashMap<String, HashSet<String>>,
}
621
622impl TriggerToNodes {
623 /// Build from the compiled [`TriggerTable`]
624 ///
625 /// Constructs a reverse mapping from trigger source names to the
626 /// set of nodes that subscribe to each source.
627 #[must_use]
628 pub fn from_trigger_table<S: State>(table: &TriggerTable<S>) -> Self {
629 let mut mapping: HashMap<String, HashSet<String>> = HashMap::new();
630 for (node_name, sources) in &table.incoming {
631 for source in sources {
632 match source {
633 TriggerSource::Edge { from } | TriggerSource::Send { from } => {
634 mapping
635 .entry(from.clone())
636 .or_default()
637 .insert(node_name.clone());
638 }
639 }
640 }
641 }
642 Self { mapping }
643 }
644
645 /// Given updated channel names, return the nodes that should be checked
646 ///
647 /// Returns the union of all node sets subscribed to any of the
648 /// given channels.
649 #[must_use]
650 pub fn triggered_nodes(&self, updated_channels: &[String]) -> HashSet<String> {
651 updated_channels
652 .iter()
653 .filter_map(|ch| self.mapping.get(ch))
654 .flatten()
655 .cloned()
656 .collect()
657 }
658}
659
660// Rust guideline compliant 2026-05-19
661
662impl std::fmt::Debug for TriggerToNodes {
663 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
664 f.debug_struct("TriggerToNodes")
665 .field("mapping_len", &self.mapping.len())
666 .finish()
667 }
668}
669
670/// Check for replace conflicts in superstep results
671///
672/// For fields using `ReplaceReducer`, only one node is allowed to write
673/// to that field in a single superstep. This function detects violations
674/// of that constraint.
675///
676/// # Arguments
677///
678/// * `superstep_result` - Results from the completed superstep
679/// * `replace_fields` - Field indices that use `ReplaceReducer`
680///
681/// # Returns
682///
683/// - `Ok(())` if no conflicts
684/// - `Err(JunctureError::Execution)` if conflicts exist
685///
686/// # Errors
687///
688/// Returns an error if multiple nodes wrote to the same replace field.
689///
690/// # Examples
691///
692/// ```ignore
693/// use juncture_core::pregel::scheduler::check_replace_conflicts;
694///
695/// let replace_fields = vec![0, 2]; // fields 0 and 2 use ReplaceReducer
696/// check_replace_conflicts(&superstep_result, &replace_fields)?;
697/// ```
698pub fn check_replace_conflicts<S: State>(
699 superstep_result: &SuperstepResult<S>,
700 replace_fields: &[usize],
701) -> Result<(), JunctureError> {
702 for &field_idx in replace_fields {
703 let writers: Vec<&str> = superstep_result
704 .task_outputs
705 .iter()
706 .filter(|o| {
707 o.command
708 .update
709 .as_ref()
710 .is_some_and(|u| S::field_is_set(u, field_idx))
711 })
712 .map(|o| o.node_name.as_str())
713 .collect();
714
715 if writers.len() > 1 {
716 return Err(JunctureError::execution(format!(
717 "Multiple writers for replace field {field_idx}: {writers:?}"
718 )));
719 }
720 }
721 Ok(())
722}
723
724/// Check for replace conflicts using the state's built-in field indices
725///
726/// Uses `S::replace_field_indices()` and `S::field_is_set()` generated by
727/// the proc-macro to detect multiple-writer violations. This is the preferred
728/// entry point for `apply_writes()` since it avoids the caller needing to
729/// track replace field indices separately.
730///
731/// # Errors
732///
733/// Returns an error if multiple nodes wrote to the same replace field.
734fn check_replace_conflicts_from_state<S: State>(
735 task_outputs: &[crate::pregel::types::TaskOutput<S>],
736) -> Result<(), JunctureError> {
737 let replace_fields = S::replace_field_indices();
738 for &field_idx in replace_fields {
739 let writers: Vec<&str> = task_outputs
740 .iter()
741 .filter(|o| {
742 o.command
743 .update
744 .as_ref()
745 .is_some_and(|u| S::field_is_set(u, field_idx))
746 })
747 .map(|o| o.node_name.as_str())
748 .collect();
749
750 if writers.len() > 1 {
751 return Err(JunctureError::multiple_writers(
752 field_idx,
753 writers.into_iter().map(String::from).collect(),
754 ));
755 }
756 }
757 Ok(())
758}
759
760/// Consume triggered channels after `apply_writes`
761///
762/// This function implements the `consume()` step that happens after
763/// `apply_writes` merges all writes but before `reset_ephemeral()`.
764///
765/// For `ephemeral` fields, this marks the channel's consumed flag, indicating
766/// that the value has been read by the framework. The consumed flag is reset
767/// on the next `update()` call. For other field types, `consume_field()` is
768/// a no-op, making it safe to call on any field index.
769///
770/// # Arguments
771///
772/// * `state` - Mutable state to consume channels on
773/// * `triggered_channels` - Field indices of channels that were triggered
774/// (changed) in the current superstep
775///
776/// # Examples
777///
778/// ```ignore
779/// use juncture_core::pregel::scheduler::consume_triggered_channels;
780///
781/// let triggered_channels = vec![0, 2]; // channels 0 and 2 were triggered
782/// consume_triggered_channels(&mut state, &triggered_channels);
783/// ```
784pub fn consume_triggered_channels<S: State>(state: &mut S, triggered_channels: &[usize]) {
785 for &field_idx in triggered_channels {
786 state.consume_field(field_idx);
787 }
788}
789
790/// Schedule error handler tasks for failed nodes
791///
792/// Scans task outputs for failures (indicated by a present `error` field) and
793/// creates recovery [`PendingTask`]s targeting each failed node's registered
794/// error handler. The error handler map is consulted to find the handler node
795/// name for each failed node.
796///
797/// The recovery tasks use [`TaskTrigger::Pull`] and are appended to the next
798/// superstep's pending task list by the caller (`PregelLoop::after_tick`).
799///
800/// # Arguments
801///
802/// * `task_outputs` - All task outputs from the completed superstep
803/// * `nodes` - All nodes in the graph (used to verify handler existence)
804/// * `error_handler_map` - Maps node names to their error handler node names
805///
806/// # Returns
807///
808/// Vector of pending tasks targeting error handler nodes.
809///
810/// # Examples
811///
812/// ```ignore
813/// use juncture_core::pregel::scheduler::schedule_error_handlers;
814///
815/// let recovery_tasks = schedule_error_handlers(&task_outputs, &nodes, &error_handler_map);
816/// for task in recovery_tasks {
817/// // Execute error handler task in next superstep
818/// }
819/// ```
820#[expect(
821 clippy::implicit_hasher,
822 reason = "public API accepts std HashMap; callers typically construct from builder metadata"
823)]
824pub fn schedule_error_handlers<S: State>(
825 task_outputs: &[TaskOutput<S>],
826 nodes: &indexmap::IndexMap<String, std::sync::Arc<dyn crate::Node<S>>>,
827 error_handler_map: &std::collections::HashMap<String, String>,
828) -> Vec<PendingTask<S>> {
829 let mut recovery_tasks = Vec::new();
830
831 for output in task_outputs {
832 let Some(ref error) = output.error else {
833 continue;
834 };
835
836 let Some(handler_name) = error_handler_map.get(&output.node_name) else {
837 continue;
838 };
839
840 // Verify the handler node actually exists in the graph
841 if !nodes.contains_key(handler_name) {
842 tracing::warn!(
843 name: "juncture.error_handler.missing_node",
844 node_name = %output.node_name,
845 handler_name = %handler_name,
846 error = %error,
847 "Error handler node not found in graph, skipping recovery"
848 );
849 continue;
850 }
851
852 recovery_tasks.push(PendingTask::pull(
853 uuid::Uuid::new_v4().to_string(),
854 handler_name.clone(),
855 ));
856 }
857
858 recovery_tasks
859}
860
/// Look up the error handler registered for a node
///
/// The map is consulted as-is. Whether the returned handler node actually
/// exists in the graph is not checked here.
///
/// # Arguments
///
/// * `node_name` - Name of the node that failed
/// * `error_handler_map` - Maps node names to error handler node names
///
/// # Returns
///
/// An owned copy of the handler node name when `node_name` has an entry in
/// `error_handler_map`, otherwise `None`.
#[must_use]
#[allow(
    dead_code,
    reason = "tested via unit tests; public API awaiting external consumers"
)]
pub fn get_error_handler_node(
    node_name: &str,
    error_handler_map: &std::collections::HashMap<String, String>,
) -> Option<String> {
    let handler = error_handler_map.get(node_name)?;
    Some(handler.clone())
}
885
#[cfg(test)]
mod scheduler_tests {
    use super::*;
    use crate::node::IntoNode;
    use crate::state::FieldVersions;

    /// Minimal state with no fields; `apply` never reports changes.
    #[derive(Clone, Debug, Default)]
    struct TestState;

    impl State for TestState {
        type Update = TestUpdate;
        type FieldVersions = FieldVersions;

        fn apply(&mut self, _: Self::Update) -> FieldsChanged {
            FieldsChanged(0)
        }

        fn reset_ephemeral(&mut self) {}
    }

    #[derive(Clone, Debug, Default, serde::Serialize)]
    struct TestUpdate;

    /// Build a failed PULL task output for `node_name`.
    fn failed_output(node_name: &str) -> TaskOutput<TestState> {
        TaskOutput {
            triggered_fields: vec![],
            task_id: "task-1".to_string(),
            node_name: node_name.to_string(),
            command: crate::Command::default(),
            duration: std::time::Duration::ZERO,
            trigger: crate::pregel::types::TaskTrigger::Pull,
            error: Some(crate::JunctureError::execution("test failure")),
        }
    }

    // ---- FieldVersionTracker ----

    #[test]
    fn test_field_version_tracker_starts_at_zero() {
        let tracker = FieldVersionTracker::new(3);
        assert_eq!(tracker.len(), 3);
        assert!(!tracker.is_empty());
        assert_eq!(tracker.versions(), &[0, 0, 0]);
        assert_eq!(tracker.as_slice(), tracker.versions());
        assert_eq!(tracker.global_max(), 0);
        assert!(FieldVersionTracker::new(0).is_empty());
    }

    #[test]
    fn test_field_version_tracker_bump_issues_increasing_versions() {
        let mut tracker = FieldVersionTracker::new(3);
        tracker.bump(2);
        tracker.bump(0);
        assert_eq!(tracker.versions(), &[2, 0, 1]);
        assert_eq!(tracker.get(2), 1);
        assert_eq!(tracker.global_max(), 2);
    }

    #[test]
    fn test_field_version_tracker_bump_all() {
        let mut tracker = FieldVersionTracker::new(3);
        tracker.bump_all(&FieldsChanged(0));
        assert_eq!(tracker.versions(), &[0, 0, 0]);

        // Every flagged field gets its own consecutive version.
        tracker.bump_all(&FieldsChanged(u64::MAX));
        assert_eq!(tracker.versions(), &[1, 2, 3]);
        assert_eq!(tracker.global_max(), 3);
    }

    // ---- VersionsSeen ----

    #[test]
    fn test_versions_seen_activation_cycle() {
        let names = vec!["node_a".to_string()];
        let mut seen = VersionsSeen::new(&names, 3);
        assert_eq!(seen.get_seen("node_a"), &[0, 0, 0]);
        assert!(seen.get_versions("unknown").is_empty());

        let current = vec![1, 0, 2];
        assert!(seen.should_activate("node_a", &[0], &current));
        assert!(!seen.should_activate("node_a", &[1], &current));
        assert_eq!(
            seen.compute_triggered_fields("node_a", &[0, 1, 2], &current),
            vec![0, 2]
        );

        seen.mark_consumed("node_a", &current);
        assert_eq!(seen.get_seen("node_a"), &[1, 0, 2]);
        assert!(!seen.should_activate("node_a", &[0, 1, 2], &current));
        assert!(
            seen.compute_triggered_fields("node_a", &[0, 1, 2], &current)
                .is_empty()
        );
    }

    #[test]
    fn test_versions_seen_untracked_node_activates() {
        let seen = VersionsSeen::new(&[], 2);
        assert!(seen.should_activate("ghost", &[0], &[0, 0]));
        assert_eq!(
            seen.compute_triggered_fields("ghost", &[1, 0], &[0, 0]),
            vec![1, 0]
        );
    }

    // ---- TriggerToNodes ----

    #[test]
    fn test_trigger_to_nodes_from_empty_table() {
        let table: TriggerTable<TestState> = TriggerTable::default();
        let ttn = TriggerToNodes::from_trigger_table(&table);
        assert!(ttn.triggered_nodes(&["node_a".to_string()]).is_empty());
    }

    #[test]
    fn test_trigger_to_nodes_with_sources() {
        let mut table: TriggerTable<TestState> = TriggerTable::default();
        table.add_incoming(
            "node_b".to_string(),
            TriggerSource::Edge {
                from: "node_a".to_string(),
            },
        );
        table.add_incoming(
            "node_c".to_string(),
            TriggerSource::Edge {
                from: "node_a".to_string(),
            },
        );
        table.add_incoming(
            "node_c".to_string(),
            TriggerSource::Edge {
                from: "node_d".to_string(),
            },
        );

        let ttn = TriggerToNodes::from_trigger_table(&table);
        let triggered = ttn.triggered_nodes(&["node_a".to_string()]);
        assert!(triggered.contains("node_b"));
        assert!(triggered.contains("node_c"));
        assert!(!triggered.contains("node_d"));

        let triggered_d = ttn.triggered_nodes(&["node_d".to_string()]);
        assert!(triggered_d.contains("node_c"));
        assert!(!triggered_d.contains("node_b"));
    }

    #[test]
    fn test_trigger_to_nodes_debug() {
        let table: TriggerTable<TestState> = TriggerTable::default();
        let ttn = TriggerToNodes::from_trigger_table(&table);
        let debug = format!("{ttn:?}");
        assert!(debug.contains("TriggerToNodes"));
    }

    // ---- apply_writes / replace conflicts ----

    #[test]
    fn test_apply_writes_empty_outputs() {
        let mut state = TestState;
        let mut tracker = FieldVersionTracker::new(3);
        let outputs: Vec<crate::pregel::types::TaskOutput<TestState>> = Vec::new();

        let changed =
            apply_writes(&mut state, &outputs, &mut tracker).expect("empty outputs should succeed");
        assert_eq!(changed.0, 0);
    }

    #[test]
    fn test_check_replace_conflicts_empty() {
        let result: SuperstepResult<TestState> = SuperstepResult {
            task_outputs: Vec::new(),
            bubble_ups: Vec::new(),
        };
        let replace_fields = vec![0, 1];
        check_replace_conflicts(&result, &replace_fields).unwrap();
    }

    #[test]
    fn test_check_replace_conflicts_no_conflicts() {
        use crate::Command;

        let task_output_a: crate::pregel::types::TaskOutput<TestState> =
            crate::pregel::types::TaskOutput {
                triggered_fields: vec![],
                task_id: "task_1".to_string(),
                node_name: "node_a".to_string(),
                trigger: crate::pregel::types::TaskTrigger::Pull,
                command: Command::end(),
                duration: std::time::Duration::from_millis(10),
                error: None,
            };

        let result: SuperstepResult<TestState> = SuperstepResult {
            task_outputs: vec![task_output_a],
            bubble_ups: Vec::new(),
        };
        let replace_fields = vec![0, 1];
        check_replace_conflicts(&result, &replace_fields).unwrap();
    }

    // ---- consume_triggered_channels (smoke tests: must not panic) ----

    #[test]
    fn test_consume_triggered_channels_empty() {
        let mut state = TestState;
        let triggered_channels = vec![0usize; 0];
        consume_triggered_channels(&mut state, &triggered_channels);
    }

    #[test]
    fn test_consume_triggered_channels_some() {
        let mut state = TestState;
        let triggered_channels = vec![0, 2];
        consume_triggered_channels(&mut state, &triggered_channels);
    }

    // ---- error handlers ----

    #[test]
    fn test_schedule_error_handlers_no_failures() {
        let nodes: indexmap::IndexMap<String, std::sync::Arc<dyn crate::Node<TestState>>> =
            indexmap::IndexMap::new();
        let task_outputs: Vec<TaskOutput<TestState>> = Vec::new();
        let error_handler_map = std::collections::HashMap::new();

        let recovery_tasks = schedule_error_handlers(&task_outputs, &nodes, &error_handler_map);
        assert!(recovery_tasks.is_empty());
    }

    #[test]
    fn test_schedule_error_handlers_with_failure() {
        use crate::Command;

        let mut nodes: indexmap::IndexMap<String, std::sync::Arc<dyn crate::Node<TestState>>> =
            indexmap::IndexMap::new();
        nodes.insert(
            "error_handler_a".to_string(),
            crate::node::NodeFnCommand(|_s: &TestState| async move { Ok(Command::end()) })
                .into_node("error_handler_a"),
        );

        let task_outputs = vec![failed_output("failing_node")];

        let mut error_handler_map = std::collections::HashMap::new();
        error_handler_map.insert("failing_node".to_string(), "error_handler_a".to_string());

        let recovery_tasks = schedule_error_handlers(&task_outputs, &nodes, &error_handler_map);
        assert_eq!(recovery_tasks.len(), 1);
        assert_eq!(recovery_tasks[0].node_name, "error_handler_a");
    }

    #[test]
    fn test_schedule_error_handlers_missing_handler_node() {
        let nodes: indexmap::IndexMap<String, std::sync::Arc<dyn crate::Node<TestState>>> =
            indexmap::IndexMap::new();

        let task_outputs = vec![failed_output("failing_node")];

        let mut error_handler_map = std::collections::HashMap::new();
        error_handler_map.insert(
            "failing_node".to_string(),
            "nonexistent_handler".to_string(),
        );

        let recovery_tasks = schedule_error_handlers(&task_outputs, &nodes, &error_handler_map);
        assert!(
            recovery_tasks.is_empty(),
            "handler node not in graph, no recovery task"
        );
    }

    #[test]
    fn test_schedule_error_handlers_no_handler_registered() {
        let nodes: indexmap::IndexMap<String, std::sync::Arc<dyn crate::Node<TestState>>> =
            indexmap::IndexMap::new();

        let task_outputs = vec![failed_output("failing_node")];

        let error_handler_map = std::collections::HashMap::new();

        let recovery_tasks = schedule_error_handlers(&task_outputs, &nodes, &error_handler_map);
        assert!(recovery_tasks.is_empty());
    }

    #[test]
    fn test_get_error_handler_node_found() {
        let mut error_handler_map = std::collections::HashMap::new();
        error_handler_map.insert("node_a".to_string(), "handler_a".to_string());

        let handler = get_error_handler_node("node_a", &error_handler_map);
        assert_eq!(handler, Some("handler_a".to_string()));
    }

    #[test]
    fn test_get_error_handler_node_not_found() {
        let error_handler_map = std::collections::HashMap::new();

        let handler = get_error_handler_node("node_a", &error_handler_map);
        assert!(handler.is_none());
    }
}
1126
1127// Rust guideline compliant 2026-05-20
1128
1129// Rust guideline compliant 2026-05-19
1130// Rust guideline compliant 2026-05-20