Skip to main content

juncture_core/pregel/
scheduler.rs

1//! Version tracking and task computation for Pregel engine
2//!
3//! This module provides field version tracking, task scheduling logic,
4//! state write application, and trigger-to-node mapping for the Pregel
5//! execution engine.
6
7use crate::{
8    JunctureError, State,
9    edge::{CompiledEdge, TriggerSource, TriggerTable},
10    pregel::types::{PendingTask, SuperstepResult, TaskOutput},
11    state::FieldsChanged,
12};
13use indexmap::IndexMap;
14use std::{collections::HashMap, collections::HashSet};
15
16/// Field version tracker for Pregel execution
17///
18/// Tracks version numbers for each field in the state to determine
19/// when nodes should be activated based on their trigger fields.
20#[derive(Clone, Debug)]
21pub struct FieldVersionTracker {
22    /// Version number for each field (index = field position)
23    versions: Vec<u64>,
24
25    /// Global maximum version across all fields
26    global_max: u64,
27}
28
29impl FieldVersionTracker {
30    /// Create a new version tracker for the given number of fields
31    ///
32    /// # Panics
33    ///
34    /// Panics if `num_fields` is greater than 64 (the maximum number of
35    /// fields that can be tracked in a `FieldsChanged` bitmask).
36    ///
37    /// # Examples
38    ///
39    /// ```ignore
40    /// use juncture_core::pregel::scheduler::FieldVersionTracker;
41    ///
42    /// let tracker = FieldVersionTracker::new(5);
43    /// assert_eq!(tracker.versions().len(), 5);
44    /// ```
45    #[must_use]
46    pub fn new(num_fields: usize) -> Self {
47        assert!(
48            num_fields <= 64,
49            "Cannot track more than 64 fields (got {num_fields})"
50        );
51
52        Self {
53            versions: vec![0; num_fields],
54            global_max: 0,
55        }
56    }
57
58    /// Bump all field versions (used when state changes globally)
59    ///
60    /// # Examples
61    ///
62    /// ```ignore
63    /// use juncture_core::pregel::scheduler::FieldVersionTracker;
64    /// use juncture_core::state::FieldsChanged;
65    ///
66    /// let mut tracker = FieldVersionTracker::new(3);
67    /// let changed = FieldsChanged(0b101); // fields 0 and 2 changed
68    /// tracker.bump_all(&changed);
69    /// assert_eq!(tracker.get(0), 1);
70    /// assert_eq!(tracker.get(1), 0);
71    /// assert_eq!(tracker.get(2), 1);
72    /// ```
73    pub fn bump_all(&mut self, changed: &FieldsChanged) {
74        for field_idx in 0..self.versions.len() {
75            if changed.has_field(field_idx) {
76                self.bump(field_idx);
77            }
78        }
79    }
80
81    /// Bump version for a specific field
82    ///
83    /// # Examples
84    ///
85    /// ```ignore
86    /// use juncture_core::pregel::scheduler::FieldVersionTracker;
87    ///
88    /// let mut tracker = FieldVersionTracker::new(3);
89    /// tracker.bump(1);
90    /// assert_eq!(tracker.get(1), 1);
91    /// assert_eq!(tracker.get(0), 0);
92    /// ```
93    pub fn bump(&mut self, field_idx: usize) {
94        self.global_max = self.global_max.saturating_add(1);
95        self.versions[field_idx] = self.global_max;
96    }
97
98    /// Get the current version of a field
99    ///
100    /// # Panics
101    ///
102    /// Panics if `field_idx` is out of bounds.
103    ///
104    /// # Examples
105    ///
106    /// ```ignore
107    /// use juncture_core::pregel::scheduler::FieldVersionTracker;
108    ///
109    /// let mut tracker = FieldVersionTracker::new(3);
110    /// tracker.bump(0);
111    /// assert_eq!(tracker.get(0), 1);
112    /// ```
113    #[must_use]
114    pub fn get(&self, field_idx: usize) -> u64 {
115        self.versions[field_idx]
116    }
117
118    /// Get all field versions as a slice
119    ///
120    /// # Examples
121    ///
122    /// ```ignore
123    /// use juncture_core::pregel::scheduler::FieldVersionTracker;
124    ///
125    /// let tracker = FieldVersionTracker::new(3);
126    /// let versions = tracker.versions();
127    /// assert_eq!(versions, &[0, 0, 0]);
128    /// ```
129    #[must_use]
130    pub fn versions(&self) -> &[u64] {
131        &self.versions
132    }
133
134    /// Get the number of fields being tracked
135    ///
136    /// # Examples
137    ///
138    /// ```ignore
139    /// use juncture_core::pregel::scheduler::FieldVersionTracker;
140    ///
141    /// let tracker = FieldVersionTracker::new(5);
142    /// assert_eq!(tracker.len(), 5);
143    /// ```
144    #[must_use]
145    pub const fn len(&self) -> usize {
146        self.versions.len()
147    }
148
149    /// Check if no fields are being tracked
150    #[must_use]
151    pub const fn is_empty(&self) -> bool {
152        self.versions.is_empty()
153    }
154
155    /// Get all field versions as a slice (alias for `versions()`)
156    ///
157    /// # Examples
158    ///
159    /// ```ignore
160    /// use juncture_core::pregel::scheduler::FieldVersionTracker;
161    ///
162    /// let tracker = FieldVersionTracker::new(3);
163    /// assert_eq!(tracker.as_slice(), &[0, 0, 0]);
164    /// ```
165    #[must_use]
166    pub fn as_slice(&self) -> &[u64] {
167        self.versions()
168    }
169
170    /// Get the global maximum version
171    ///
172    /// # Examples
173    ///
174    /// ```ignore
175    /// use juncture_core::pregel::scheduler::FieldVersionTracker;
176    ///
177    /// let mut tracker = FieldVersionTracker::new(3);
178    /// tracker.bump(0);
179    /// tracker.bump(1);
180    /// assert_eq!(tracker.global_max(), 2);
181    /// ```
182    #[must_use]
183    pub const fn global_max(&self) -> u64 {
184        self.global_max
185    }
186}
187
188/// Version tracking for node activation
189///
190/// Tracks which versions each node has seen to determine when it should
191/// be activated based on its trigger fields.
192#[derive(Clone, Debug)]
193pub struct VersionsSeen {
194    /// Map of node name to the field versions it has seen
195    ///
196    /// Uses `IndexMap` for deterministic iteration order, matching `LangGraph` semantics.
197    seen: IndexMap<String, Vec<u64>>,
198}
199
200impl VersionsSeen {
201    /// Create a new version tracker for the given nodes and fields
202    ///
203    /// # Examples
204    ///
205    /// ```ignore
206    /// use juncture_core::pregel::scheduler::VersionsSeen;
207    ///
208    /// let node_names = vec!["node_a".to_string(), "node_b".to_string()];
209    /// let seen = VersionsSeen::new(&node_names, 3);
210    /// assert_eq!(seen.get_seen("node_a"), &[0, 0, 0]);
211    /// ```
212    #[must_use]
213    pub fn new(node_names: &[String], num_fields: usize) -> Self {
214        let seen = node_names
215            .iter()
216            .map(|name| (name.clone(), vec![0; num_fields]))
217            .collect();
218
219        Self { seen }
220    }
221
222    /// Check if a node should be activated based on its trigger fields
223    ///
224    /// Returns `true` if any of the node's trigger fields have new versions
225    /// that the node hasn't seen yet.
226    ///
227    /// # Examples
228    ///
229    /// ```ignore
230    /// use juncture_core::pregel::scheduler::VersionsSeen;
231    ///
232    /// let node_names = vec!["node_a".to_string()];
233    /// let mut seen = VersionsSeen::new(&node_names, 3);
234    ///
235    /// // Node should activate if field 0 has version > what it has seen
236    /// let trigger_fields = vec![0]; // triggers on field 0
237    /// let current = vec![1, 0, 0]; // field 0 is at version 1
238    /// assert!(seen.should_activate("node_a", &trigger_fields, &current));
239    /// ```
240    #[must_use]
241    pub fn should_activate(
242        &self,
243        node_name: &str,
244        trigger_fields: &[usize],
245        current: &[u64],
246    ) -> bool {
247        let Some(seen_versions) = self.seen.get(node_name) else {
248            return true; // Node not yet tracked, should activate
249        };
250
251        for &field_idx in trigger_fields {
252            if current[field_idx] > seen_versions[field_idx] {
253                return true;
254            }
255        }
256
257        false
258    }
259
260    /// Mark that a node has consumed the current field versions
261    ///
262    /// # Examples
263    ///
264    /// ```ignore
265    /// use juncture_core::pregel::scheduler::VersionsSeen;
266    ///
267    /// let node_names = vec!["node_a".to_string()];
268    /// let mut seen = VersionsSeen::new(&node_names, 3);
269    ///
270    /// let current = vec![1, 0, 0];
271    /// seen.mark_consumed("node_a", &current);
272    ///
273    /// // Now node shouldn't activate for same versions
274    /// assert!(!seen.should_activate("node_a", &[0], &current));
275    /// ```
276    pub fn mark_consumed(&mut self, node_name: &str, current: &[u64]) {
277        if let Some(seen_versions) = self.seen.get_mut(node_name) {
278            seen_versions.copy_from_slice(current);
279        }
280    }
281
282    /// Get the versions a node has seen
283    ///
284    /// Returns an empty slice if the node is not tracked.
285    #[must_use]
286    pub fn get_seen(&self, node_name: &str) -> &[u64] {
287        self.seen.get(node_name).map_or(&[], Vec::as_slice)
288    }
289
290    /// Get the versions a node has seen (alias for `get_seen`)
291    ///
292    /// Returns an empty slice if the node is not tracked.
293    #[must_use]
294    pub fn get_versions(&self, node_name: &str) -> &[u64] {
295        self.get_seen(node_name)
296    }
297
298    /// Compute which fields triggered a node to activate
299    ///
300    /// Compares the node's seen versions with current field versions to determine
301    /// which specific fields had updates that caused the node to be scheduled.
302    ///
303    /// # Arguments
304    ///
305    /// * `node_name` - Name of the node to check
306    /// * `trigger_fields` - Field indices that the node subscribes to
307    /// * `current_versions` - Current field versions
308    ///
309    /// # Returns
310    ///
311    /// Vector of field indices that triggered this node (subset of `trigger_fields`)
312    ///
313    /// # Examples
314    ///
315    /// ```ignore
316    /// use juncture_core::pregel::scheduler::VersionsSeen;
317    ///
318    /// let node_names = vec!["node_a".to_string()];
319    /// let mut seen = VersionsSeen::new(&node_names, 3);
320    ///
321    /// let trigger_fields = vec![0, 2]; // node subscribes to fields 0 and 2
322    /// let current = vec![1, 0, 1]; // fields 0 and 2 have new versions
323    /// let triggered = seen.compute_triggered_fields("node_a", &trigger_fields, &current);
324    /// assert_eq!(triggered, vec![0, 2]); // both fields triggered
325    /// ```
326    #[must_use]
327    pub fn compute_triggered_fields(
328        &self,
329        node_name: &str,
330        trigger_fields: &[usize],
331        current_versions: &[u64],
332    ) -> Vec<usize> {
333        let Some(seen_versions) = self.seen.get(node_name) else {
334            // Node not yet tracked, all trigger fields are new
335            return trigger_fields.to_vec();
336        };
337
338        trigger_fields
339            .iter()
340            .filter(|&&field_idx| current_versions[field_idx] > seen_versions[field_idx])
341            .copied()
342            .collect()
343    }
344}
345
346/// Compute the next set of tasks to execute
347///
348/// This function determines which nodes should be activated in the next
349/// superstep based on:
350/// 1. Commands returned by completed tasks (highest priority)
351/// 2. Trigger table edges (Fixed and Conditional)
352///
353/// # Arguments
354///
355/// * `completed_tasks` - Tasks that completed in the previous superstep
356/// * `trigger_table` - Graph's trigger table
357/// * `state` - Current state
358///
359/// # Returns
360///
361/// A vector of pending tasks to execute in the next superstep.
362///
363/// # Errors
364///
365/// Returns an error if:
366/// - A conditional edge router fails to execute
367/// - A conditional edge returns no target
368///
369/// # Examples
370///
371/// ```ignore
372/// use juncture_core::pregel::scheduler::compute_next_tasks;
373/// use juncture_core::pregel::types::{TaskOutput, SuperstepResult};
374/// use std::time::Duration;
375///
376/// # let completed_tasks = vec![];
377/// # let trigger_table = TriggerTable::<MyState>::new();
378/// # let state = MyState;
379/// let next_tasks = compute_next_tasks(&completed_tasks, &trigger_table, &state)?;
380/// ```
381pub async fn compute_next_tasks<S: State>(
382    completed_tasks: &[TaskOutput<S>],
383    trigger_table: &TriggerTable<S>,
384    trigger_to_nodes: &TriggerToNodes,
385    state: &S,
386) -> Result<Vec<PendingTask<S>>, JunctureError> {
387    let mut next_tasks = Vec::new();
388    let mut seen_nodes = HashSet::new();
389
390    // First, check if any task returned a Command with explicit routing
391    for task_output in completed_tasks {
392        let command = &task_output.command;
393
394        match &command.goto {
395            crate::Goto::None => {
396                // No explicit routing, use trigger table with reverse mapping optimization
397                // Use TriggerToNodes to efficiently find which nodes should be triggered
398                let triggered =
399                    trigger_to_nodes.triggered_nodes(std::slice::from_ref(&task_output.node_name));
400
401                // Filter outgoing edges to only those leading to triggered nodes
402                if let Some(edges) = trigger_table.outgoing.get(&task_output.node_name) {
403                    for edge in edges {
404                        // Only process edges that lead to triggered nodes
405                        if should_process_edge(edge, state, &triggered).await? {
406                            process_edge(
407                                edge,
408                                state,
409                                &mut next_tasks,
410                                &mut seen_nodes,
411                                &task_output.node_name,
412                            )
413                            .await?;
414                        }
415                    }
416                }
417            }
418            crate::Goto::Next(target) => {
419                // Route to single target
420                if !seen_nodes.contains(target) {
421                    seen_nodes.insert(target.clone());
422                    next_tasks.push(PendingTask::pull(
423                        uuid::Uuid::new_v4().to_string(),
424                        target.clone(),
425                    ));
426                }
427            }
428            crate::Goto::Multiple(targets) => {
429                // Route to multiple targets
430                for target in targets {
431                    if !seen_nodes.contains(target) {
432                        seen_nodes.insert(target.clone());
433                        next_tasks.push(PendingTask::pull(
434                            uuid::Uuid::new_v4().to_string(),
435                            target.clone(),
436                        ));
437                    }
438                }
439            }
440            crate::Goto::Send(send_targets) => {
441                // Dynamic fan-out with state overrides.
442                // Each Send target creates a separate task even if multiple targets
443                // share the same node name, because each carries a distinct state override.
444                for (idx, target) in send_targets.iter().enumerate() {
445                    next_tasks.push(PendingTask::push(
446                        uuid::Uuid::new_v4().to_string(),
447                        target.node.clone(),
448                        idx,
449                        target.state.clone(),
450                    ));
451                }
452            }
453            crate::Goto::End => {
454                // Termination, no next tasks
455            }
456        }
457    }
458
459    Ok(next_tasks)
460}
461
462/// Check if an edge should be processed based on triggered nodes
463///
464/// For fixed edges, checks if the target is in the triggered set.
465/// For conditional edges, the router is executed to determine the actual target.
466async fn should_process_edge<S: State>(
467    edge: &CompiledEdge<S>,
468    state: &S,
469    triggered_nodes: &HashSet<String>,
470) -> Result<bool, JunctureError> {
471    match edge {
472        CompiledEdge::Fixed { target } => Ok(triggered_nodes.contains(target)),
473        CompiledEdge::Conditional { router, .. } => {
474            let route_result = router.route(state).await?;
475            Ok(route_result
476                .as_target()
477                .is_some_and(|t| triggered_nodes.contains(t)))
478        }
479    }
480}
481
482/// Process a single edge and add appropriate tasks
483async fn process_edge<S: State>(
484    edge: &CompiledEdge<S>,
485    state: &S,
486    next_tasks: &mut Vec<PendingTask<S>>,
487    seen_nodes: &mut HashSet<String>,
488    from_node: &str,
489) -> Result<(), JunctureError> {
490    match edge {
491        CompiledEdge::Fixed { target } => {
492            if target != crate::edge::END && !seen_nodes.contains(target) {
493                seen_nodes.insert(target.clone());
494                next_tasks.push(PendingTask::pull(
495                    uuid::Uuid::new_v4().to_string(),
496                    target.clone(),
497                ));
498            }
499        }
500        CompiledEdge::Conditional { router, .. } => {
501            let route_result = router.route(state).await?;
502            let target = route_result.as_target().ok_or_else(|| {
503                JunctureError::execution(format!(
504                    "Conditional edge from '{from_node}' returned no target: {route_result:?}"
505                ))
506            })?;
507
508            if target != crate::edge::END && !seen_nodes.contains(target) {
509                seen_nodes.insert(target.to_string());
510                next_tasks.push(PendingTask::pull(
511                    uuid::Uuid::new_v4().to_string(),
512                    target.to_string(),
513                ));
514            }
515        }
516    }
517
518    Ok(())
519}
520
521/// Apply writes from completed tasks to the state
522///
523/// Takes outputs from a superstep and applies all updates to the state.
524/// Uses path-based sorting (PULL tasks sorted by node name, PUSH tasks
525/// sorted by send index) for deterministic merge order, matching the
526/// `LangGraph` merge semantics.
527///
528/// Returns [`FieldsChanged`] indicating which fields were modified.
529///
530/// # Arguments
531///
532/// * `state` - Mutable state to apply updates to
533/// * `task_outputs` - Outputs from completed tasks in the superstep
534/// * `field_versions` - Version tracker to bump for changed fields
535///
536/// # Errors
537///
538/// Returns `JunctureError` if a reducer constraint is violated, such as
539/// multiple nodes writing to a replace channel in the same superstep.
540///
541/// # Examples
542///
543/// ```ignore
544/// use juncture_core::pregel::scheduler::{apply_writes, FieldVersionTracker};
545///
546/// let mut state = MyState::default();
547/// let mut tracker = FieldVersionTracker::new(3);
548/// let changed = apply_writes(&mut state, &task_outputs, &mut tracker)?;
549/// ```
550pub fn apply_writes<S: State>(
551    state: &mut S,
552    task_outputs: &[crate::pregel::types::TaskOutput<S>],
553    field_versions: &mut FieldVersionTracker,
554) -> Result<FieldsChanged, JunctureError> {
555    // Check for multiple-writer conflicts on replace fields before applying any writes.
556    // This must happen first so that we reject the entire superstep rather than
557    // silently applying partial writes with last-write-wins semantics.
558    check_replace_conflicts_from_state::<S>(task_outputs)?;
559
560    let mut total_changed = FieldsChanged(0);
561
562    // Sort indices by path-based ordering for deterministic merge
563    // PULL tasks: alphabetical by node name
564    // PUSH tasks: by send index
565    let mut sorted_indices: Vec<usize> = (0..task_outputs.len()).collect();
566    sorted_indices.sort_by(|&a, &b| {
567        let task_a = &task_outputs[a];
568        let task_b = &task_outputs[b];
569        match (&task_a.trigger, &task_b.trigger) {
570            (crate::pregel::types::TaskTrigger::Pull, crate::pregel::types::TaskTrigger::Pull) => {
571                task_a.node_name.cmp(&task_b.node_name)
572            }
573            (
574                crate::pregel::types::TaskTrigger::Push { index: idx_a },
575                crate::pregel::types::TaskTrigger::Push { index: idx_b },
576            ) => idx_a.cmp(idx_b),
577            (
578                crate::pregel::types::TaskTrigger::Pull,
579                crate::pregel::types::TaskTrigger::Push { .. },
580            ) => std::cmp::Ordering::Less,
581            (
582                crate::pregel::types::TaskTrigger::Push { .. },
583                crate::pregel::types::TaskTrigger::Pull,
584            ) => std::cmp::Ordering::Greater,
585        }
586    });
587
588    for idx in sorted_indices {
589        let output = &task_outputs[idx];
590        if let Some(ref update) = output.command.update {
591            let changed = state
592                .try_apply(update.clone())
593                .map_err(|e| JunctureError::invalid_update(e.to_string()))?;
594            total_changed.merge(&changed);
595        }
596    }
597
598    // Bump field versions for all changed fields
599    field_versions.bump_all(&total_changed);
600
601    Ok(total_changed)
602}
603
604/// Channel-to-node reverse mapping for efficient scheduling
605///
606/// When a channel (field) is updated, only the subscribed nodes need
607/// to be checked, reducing scheduling from `O(nodes)` to `O(triggered_nodes)`.
608///
609/// # Examples
610///
611/// ```ignore
612/// use juncture_core::pregel::scheduler::TriggerToNodes;
613///
614/// let trigger_to_nodes = TriggerToNodes::from_trigger_table(&trigger_table);
615/// let triggered = trigger_to_nodes.triggered_nodes(&["field_a".to_string()]);
616/// assert!(triggered.contains("node_x"));
617/// ```
618pub struct TriggerToNodes {
619    mapping: HashMap<String, HashSet<String>>,
620}
621
622impl TriggerToNodes {
623    /// Build from the compiled [`TriggerTable`]
624    ///
625    /// Constructs a reverse mapping from trigger source names to the
626    /// set of nodes that subscribe to each source.
627    #[must_use]
628    pub fn from_trigger_table<S: State>(table: &TriggerTable<S>) -> Self {
629        let mut mapping: HashMap<String, HashSet<String>> = HashMap::new();
630        for (node_name, sources) in &table.incoming {
631            for source in sources {
632                match source {
633                    TriggerSource::Edge { from } | TriggerSource::Send { from } => {
634                        mapping
635                            .entry(from.clone())
636                            .or_default()
637                            .insert(node_name.clone());
638                    }
639                }
640            }
641        }
642        Self { mapping }
643    }
644
645    /// Given updated channel names, return the nodes that should be checked
646    ///
647    /// Returns the union of all node sets subscribed to any of the
648    /// given channels.
649    #[must_use]
650    pub fn triggered_nodes(&self, updated_channels: &[String]) -> HashSet<String> {
651        updated_channels
652            .iter()
653            .filter_map(|ch| self.mapping.get(ch))
654            .flatten()
655            .cloned()
656            .collect()
657    }
658}
659
660// Rust guideline compliant 2026-05-19
661
662impl std::fmt::Debug for TriggerToNodes {
663    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
664        f.debug_struct("TriggerToNodes")
665            .field("mapping_len", &self.mapping.len())
666            .finish()
667    }
668}
669
670/// Check for replace conflicts in superstep results
671///
672/// For fields using `ReplaceReducer`, only one node is allowed to write
673/// to that field in a single superstep. This function detects violations
674/// of that constraint.
675///
676/// # Arguments
677///
678/// * `superstep_result` - Results from the completed superstep
679/// * `replace_fields` - Field indices that use `ReplaceReducer`
680///
681/// # Returns
682///
683/// - `Ok(())` if no conflicts
684/// - `Err(JunctureError::Execution)` if conflicts exist
685///
686/// # Errors
687///
688/// Returns an error if multiple nodes wrote to the same replace field.
689///
690/// # Examples
691///
692/// ```ignore
693/// use juncture_core::pregel::scheduler::check_replace_conflicts;
694///
695/// let replace_fields = vec![0, 2]; // fields 0 and 2 use ReplaceReducer
696/// check_replace_conflicts(&superstep_result, &replace_fields)?;
697/// ```
698pub fn check_replace_conflicts<S: State>(
699    superstep_result: &SuperstepResult<S>,
700    replace_fields: &[usize],
701) -> Result<(), JunctureError> {
702    for &field_idx in replace_fields {
703        let writers: Vec<&str> = superstep_result
704            .task_outputs
705            .iter()
706            .filter(|o| {
707                o.command
708                    .update
709                    .as_ref()
710                    .is_some_and(|u| S::field_is_set(u, field_idx))
711            })
712            .map(|o| o.node_name.as_str())
713            .collect();
714
715        if writers.len() > 1 {
716            return Err(JunctureError::execution(format!(
717                "Multiple writers for replace field {field_idx}: {writers:?}"
718            )));
719        }
720    }
721    Ok(())
722}
723
724/// Check for replace conflicts using the state's built-in field indices
725///
726/// Uses `S::replace_field_indices()` and `S::field_is_set()` generated by
727/// the proc-macro to detect multiple-writer violations. This is the preferred
728/// entry point for `apply_writes()` since it avoids the caller needing to
729/// track replace field indices separately.
730///
731/// # Errors
732///
733/// Returns an error if multiple nodes wrote to the same replace field.
734fn check_replace_conflicts_from_state<S: State>(
735    task_outputs: &[crate::pregel::types::TaskOutput<S>],
736) -> Result<(), JunctureError> {
737    let replace_fields = S::replace_field_indices();
738    for &field_idx in replace_fields {
739        let writers: Vec<&str> = task_outputs
740            .iter()
741            .filter(|o| {
742                o.command
743                    .update
744                    .as_ref()
745                    .is_some_and(|u| S::field_is_set(u, field_idx))
746            })
747            .map(|o| o.node_name.as_str())
748            .collect();
749
750        if writers.len() > 1 {
751            return Err(JunctureError::multiple_writers(
752                field_idx,
753                writers.into_iter().map(String::from).collect(),
754            ));
755        }
756    }
757    Ok(())
758}
759
760/// Consume triggered channels after `apply_writes`
761///
762/// This function implements the `consume()` step that happens after
763/// `apply_writes` merges all writes but before `reset_ephemeral()`.
764///
765/// For `ephemeral` fields, this marks the channel's consumed flag, indicating
766/// that the value has been read by the framework. The consumed flag is reset
767/// on the next `update()` call. For other field types, `consume_field()` is
768/// a no-op, making it safe to call on any field index.
769///
770/// # Arguments
771///
772/// * `state` - Mutable state to consume channels on
773/// * `triggered_channels` - Field indices of channels that were triggered
774///   (changed) in the current superstep
775///
776/// # Examples
777///
778/// ```ignore
779/// use juncture_core::pregel::scheduler::consume_triggered_channels;
780///
781/// let triggered_channels = vec![0, 2]; // channels 0 and 2 were triggered
782/// consume_triggered_channels(&mut state, &triggered_channels);
783/// ```
784pub fn consume_triggered_channels<S: State>(state: &mut S, triggered_channels: &[usize]) {
785    for &field_idx in triggered_channels {
786        state.consume_field(field_idx);
787    }
788}
789
790/// Schedule error handler tasks for failed nodes
791///
792/// Scans task outputs for failures (indicated by a present `error` field) and
793/// creates recovery [`PendingTask`]s targeting each failed node's registered
794/// error handler. The error handler map is consulted to find the handler node
795/// name for each failed node.
796///
797/// The recovery tasks use [`TaskTrigger::Pull`] and are appended to the next
798/// superstep's pending task list by the caller (`PregelLoop::after_tick`).
799///
800/// # Arguments
801///
802/// * `task_outputs` - All task outputs from the completed superstep
803/// * `nodes` - All nodes in the graph (used to verify handler existence)
804/// * `error_handler_map` - Maps node names to their error handler node names
805///
806/// # Returns
807///
808/// Vector of pending tasks targeting error handler nodes.
809///
810/// # Examples
811///
812/// ```ignore
813/// use juncture_core::pregel::scheduler::schedule_error_handlers;
814///
815/// let recovery_tasks = schedule_error_handlers(&task_outputs, &nodes, &error_handler_map);
816/// for task in recovery_tasks {
817///     // Execute error handler task in next superstep
818/// }
819/// ```
820#[expect(
821    clippy::implicit_hasher,
822    reason = "public API accepts std HashMap; callers typically construct from builder metadata"
823)]
824pub fn schedule_error_handlers<S: State>(
825    task_outputs: &[TaskOutput<S>],
826    nodes: &indexmap::IndexMap<String, std::sync::Arc<dyn crate::Node<S>>>,
827    error_handler_map: &std::collections::HashMap<String, String>,
828) -> Vec<PendingTask<S>> {
829    let mut recovery_tasks = Vec::new();
830
831    for output in task_outputs {
832        let Some(ref error) = output.error else {
833            continue;
834        };
835
836        let Some(handler_name) = error_handler_map.get(&output.node_name) else {
837            continue;
838        };
839
840        // Verify the handler node actually exists in the graph
841        if !nodes.contains_key(handler_name) {
842            tracing::warn!(
843                name: "juncture.error_handler.missing_node",
844                node_name = %output.node_name,
845                handler_name = %handler_name,
846                error = %error,
847                "Error handler node not found in graph, skipping recovery"
848            );
849            continue;
850        }
851
852        recovery_tasks.push(PendingTask::pull(
853            uuid::Uuid::new_v4().to_string(),
854            handler_name.clone(),
855        ));
856    }
857
858    recovery_tasks
859}
860
/// Get the error handler node name for a given node
///
/// Looks up `node_name` in `error_handler_map` and returns an owned copy of
/// the registered handler's node name.
///
/// # Arguments
///
/// * `node_name` - Name of the node that failed
/// * `error_handler_map` - Maps node names to error handler node names
///
/// # Returns
///
/// `Some(error_handler_name)` if an error handler is registered, `None` otherwise
#[must_use]
#[allow(
    dead_code,
    reason = "tested via unit tests; public API awaiting external consumers"
)]
pub fn get_error_handler_node(
    node_name: &str,
    error_handler_map: &std::collections::HashMap<String, String>,
) -> Option<String> {
    let handler = error_handler_map.get(node_name)?;
    Some(handler.clone())
}
885
#[cfg(test)]
mod scheduler_tests {
    use super::*;
    use crate::Command;
    use crate::node::IntoNode;
    use crate::state::FieldVersions;

    /// Node map shape accepted by `schedule_error_handlers` in these tests.
    type NodeMap = indexmap::IndexMap<String, std::sync::Arc<dyn crate::Node<TestState>>>;

    /// Minimal state whose `apply` never reports a changed field.
    #[derive(Clone, Debug, Default)]
    struct TestState;

    impl State for TestState {
        type Update = TestUpdate;
        type FieldVersions = FieldVersions;

        fn apply(&mut self, _: Self::Update) -> FieldsChanged {
            FieldsChanged(0)
        }

        fn reset_ephemeral(&mut self) {}
    }

    /// Update type for `TestState`; carries no data.
    #[derive(Clone, Debug, Default, serde::Serialize)]
    struct TestUpdate;

    /// Builds a node map containing a single handler node named `name`
    /// whose body immediately returns `Command::end()`.
    ///
    /// `name` is `'static` so it can be passed straight to `into_node`,
    /// exactly as the literal names were before this helper existed.
    fn nodes_with_handler(name: &'static str) -> NodeMap {
        let mut nodes = NodeMap::new();
        nodes.insert(
            name.to_string(),
            crate::node::NodeFnCommand(|_s: &TestState| async move { Ok(Command::end()) })
                .into_node(name),
        );
        nodes
    }

    /// Builds a pull-triggered task output for `node_name` carrying `error`.
    fn task_output(
        node_name: &str,
        error: Option<crate::JunctureError>,
    ) -> TaskOutput<TestState> {
        TaskOutput {
            triggered_fields: vec![],
            task_id: "task-1".to_string(),
            node_name: node_name.to_string(),
            command: Command::default(),
            duration: std::time::Duration::ZERO,
            trigger: crate::pregel::types::TaskTrigger::Pull,
            error,
        }
    }

    /// Builds a task output for `node_name` that failed with an execution error.
    fn failed_output(node_name: &str) -> TaskOutput<TestState> {
        task_output(node_name, Some(crate::JunctureError::execution("test failure")))
    }

    #[test]
    fn test_trigger_to_nodes_from_empty_table() {
        let table: TriggerTable<TestState> = TriggerTable::default();
        let ttn = TriggerToNodes::from_trigger_table(&table);
        assert!(ttn.triggered_nodes(&["node_a".to_string()]).is_empty());
    }

    #[test]
    fn test_trigger_to_nodes_with_sources() {
        let mut table: TriggerTable<TestState> = TriggerTable::default();
        table.add_incoming(
            "node_b".to_string(),
            TriggerSource::Edge {
                from: "node_a".to_string(),
            },
        );
        table.add_incoming(
            "node_c".to_string(),
            TriggerSource::Edge {
                from: "node_a".to_string(),
            },
        );
        table.add_incoming(
            "node_c".to_string(),
            TriggerSource::Edge {
                from: "node_d".to_string(),
            },
        );

        let ttn = TriggerToNodes::from_trigger_table(&table);
        let triggered = ttn.triggered_nodes(&["node_a".to_string()]);
        assert!(triggered.contains("node_b"));
        assert!(triggered.contains("node_c"));
        assert!(!triggered.contains("node_d"));

        let triggered_d = ttn.triggered_nodes(&["node_d".to_string()]);
        assert!(triggered_d.contains("node_c"));
        assert!(!triggered_d.contains("node_b"));
    }

    #[test]
    fn test_trigger_to_nodes_debug() {
        let table: TriggerTable<TestState> = TriggerTable::default();
        let ttn = TriggerToNodes::from_trigger_table(&table);
        let debug = format!("{ttn:?}");
        assert!(debug.contains("TriggerToNodes"));
    }

    #[test]
    fn test_apply_writes_empty_outputs() {
        let mut state = TestState;
        let mut tracker = FieldVersionTracker::new(3);
        let outputs: Vec<TaskOutput<TestState>> = Vec::new();

        let changed =
            apply_writes(&mut state, &outputs, &mut tracker).expect("empty outputs should succeed");
        assert_eq!(changed.0, 0);
    }

    #[test]
    fn test_check_replace_conflicts_empty() {
        let result: SuperstepResult<TestState> = SuperstepResult {
            task_outputs: Vec::new(),
            bubble_ups: Vec::new(),
        };
        let replace_fields = vec![0, 1];
        check_replace_conflicts(&result, &replace_fields)
            .expect("no task outputs should produce no replace conflicts");
    }

    #[test]
    fn test_check_replace_conflicts_no_conflicts() {
        let task_output_a: TaskOutput<TestState> = TaskOutput {
            triggered_fields: vec![],
            task_id: "task_1".to_string(),
            node_name: "node_a".to_string(),
            trigger: crate::pregel::types::TaskTrigger::Pull,
            command: Command::end(),
            duration: std::time::Duration::from_millis(10),
            error: None,
        };

        let result: SuperstepResult<TestState> = SuperstepResult {
            task_outputs: vec![task_output_a],
            bubble_ups: Vec::new(),
        };
        let replace_fields = vec![0, 1];
        check_replace_conflicts(&result, &replace_fields)
            .expect("a single task output should produce no replace conflicts");
    }

    // Smoke tests: `TestState` is a unit struct, so there is no observable
    // state to assert on; these only verify the calls do not panic.
    #[test]
    fn test_consume_triggered_channels_empty() {
        let mut state = TestState;
        let triggered_channels = vec![0usize; 0];
        consume_triggered_channels(&mut state, &triggered_channels);
    }

    #[test]
    fn test_consume_triggered_channels_some() {
        let mut state = TestState;
        let triggered_channels = vec![0, 2];
        consume_triggered_channels(&mut state, &triggered_channels);
    }

    #[test]
    fn test_schedule_error_handlers_no_failures() {
        let nodes = NodeMap::new();
        let task_outputs: Vec<TaskOutput<TestState>> = Vec::new();
        let error_handler_map = std::collections::HashMap::new();

        let recovery_tasks = schedule_error_handlers(&task_outputs, &nodes, &error_handler_map);
        assert!(recovery_tasks.is_empty());
    }

    #[test]
    fn test_schedule_error_handlers_with_failure() {
        let nodes = nodes_with_handler("error_handler_a");
        let task_outputs = vec![failed_output("failing_node")];

        let mut error_handler_map = std::collections::HashMap::new();
        error_handler_map.insert("failing_node".to_string(), "error_handler_a".to_string());

        let recovery_tasks = schedule_error_handlers(&task_outputs, &nodes, &error_handler_map);
        assert_eq!(recovery_tasks.len(), 1);
        assert_eq!(recovery_tasks[0].node_name, "error_handler_a");
    }

    #[test]
    fn test_schedule_error_handlers_ignores_successful_outputs() {
        let nodes = nodes_with_handler("error_handler_a");
        // Both nodes have a handler registered, but only one of them failed.
        let task_outputs = vec![
            task_output("healthy_node", None),
            failed_output("failing_node"),
        ];

        let mut error_handler_map = std::collections::HashMap::new();
        error_handler_map.insert("healthy_node".to_string(), "error_handler_a".to_string());
        error_handler_map.insert("failing_node".to_string(), "error_handler_a".to_string());

        let recovery_tasks = schedule_error_handlers(&task_outputs, &nodes, &error_handler_map);
        assert_eq!(
            recovery_tasks.len(),
            1,
            "only the failed output should schedule its handler"
        );
        assert_eq!(recovery_tasks[0].node_name, "error_handler_a");
    }

    #[test]
    fn test_schedule_error_handlers_missing_handler_node() {
        let nodes = NodeMap::new();
        let task_outputs = vec![failed_output("failing_node")];

        let mut error_handler_map = std::collections::HashMap::new();
        error_handler_map.insert(
            "failing_node".to_string(),
            "nonexistent_handler".to_string(),
        );

        let recovery_tasks = schedule_error_handlers(&task_outputs, &nodes, &error_handler_map);
        assert!(
            recovery_tasks.is_empty(),
            "handler node not in graph, no recovery task"
        );
    }

    #[test]
    fn test_schedule_error_handlers_no_handler_registered() {
        let nodes = NodeMap::new();
        let task_outputs = vec![failed_output("failing_node")];
        let error_handler_map = std::collections::HashMap::new();

        let recovery_tasks = schedule_error_handlers(&task_outputs, &nodes, &error_handler_map);
        assert!(recovery_tasks.is_empty());
    }

    #[test]
    fn test_get_error_handler_node_found() {
        let mut error_handler_map = std::collections::HashMap::new();
        error_handler_map.insert("node_a".to_string(), "handler_a".to_string());

        let handler = get_error_handler_node("node_a", &error_handler_map);
        assert_eq!(handler, Some("handler_a".to_string()));
    }

    #[test]
    fn test_get_error_handler_node_not_found() {
        let error_handler_map = std::collections::HashMap::new();

        let handler = get_error_handler_node("node_a", &error_handler_map);
        assert!(handler.is_none());
    }
}
1126
// Rust guideline compliant 2026-05-20