1use crate::time::Instant;
7use crate::{
8 JunctureError, Node, State,
9 checkpoint::{
10 Checkpoint, CheckpointMetadata, CheckpointSource, DeltaCounters, generate_checkpoint_id,
11 },
12 edge::TriggerTable,
13 interrupt::should_interrupt,
14 pregel::{
15 budget::BudgetTracker,
16 context::ExecutionContext,
17 durability::Durability,
18 runner::execute_superstep,
19 scheduler::{
20 FieldVersionTracker, VersionsSeen, apply_writes, compute_next_tasks,
21 schedule_error_handlers_filtered, schedule_fallback_tasks,
22 },
23 types::{BubbleUp, LoopStatus, PendingTask, SuperstepResult},
24 },
25 state::FieldsChanged,
26 stream::{DebugEvent, StreamEvent},
27};
28use indexmap::IndexMap;
29use std::collections::{HashMap, HashSet};
30use std::sync::Arc;
31use std::sync::atomic::{AtomicBool, Ordering};
32use tokio::sync::mpsc;
33use tokio_util::sync::CancellationToken;
34
/// Cloneable handle used to ask a running loop to drain.
///
/// Every clone shares the same atomic flag, so a request made through one
/// handle is observable through all the others.
#[derive(Clone, Debug, Default)]
pub struct RunControl {
    /// Shared flag; `true` once a drain has been requested.
    drain_requested: Arc<AtomicBool>,
}

impl RunControl {
    /// Creates a handle whose drain flag starts out unset.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Raises the drain flag.
    ///
    /// The store uses `Release` ordering and pairs with the `Acquire` load in
    /// [`RunControl::is_drain_requested`]. The flag is never cleared by this type.
    pub fn request_drain(&self) {
        self.drain_requested.store(true, Ordering::Release);
    }

    /// Returns `true` once [`RunControl::request_drain`] has been called on
    /// this handle or any of its clones.
    #[must_use]
    pub fn is_drain_requested(&self) -> bool {
        self.drain_requested.load(Ordering::Acquire)
    }
}
120
/// Driver for superstep-based execution of a graph over a state `S`.
///
/// The methods visible in this file are used in a cycle: `tick` decides whether
/// another superstep should run, `execute_superstep` runs the pending tasks, and
/// `after_tick` applies their writes and schedules the next set of tasks.
pub struct PregelLoop<S: State> {
    /// Current graph state. Temporarily moved out while a superstep runs.
    pub state: S,

    /// Registered nodes, keyed by node name.
    pub nodes: IndexMap<String, Arc<dyn Node<S>>>,

    /// Trigger table used to compute the initial and subsequent tasks.
    pub trigger_table: TriggerTable<S>,

    /// Per-field version tracker, updated when writes are applied.
    pub field_versions: FieldVersionTracker,

    /// Per-node record of which field versions have been consumed.
    pub versions_seen: VersionsSeen,

    /// Run configuration (recursion limit, interrupt lists, resource limits, budget).
    pub runnable_config: crate::config::RunnableConfig,

    /// Token checked at the start of every `tick` and handed to the runner.
    pub cancellation_token: CancellationToken,

    /// Optional sink for stream events; send errors are ignored.
    pub stream_tx: Option<mpsc::UnboundedSender<StreamEvent<S>>>,

    /// Optional checkpoint saver.
    pub checkpointer: Option<Arc<dyn crate::checkpoint::CheckpointSaver>>,

    /// Index of the current superstep; incremented at the end of `after_tick`.
    pub step: usize,

    /// Current loop status.
    pub status: LoopStatus,

    /// Tasks scheduled for the next superstep.
    pub pending_tasks: Vec<PendingTask<S>>,

    /// Fields changed by the previous superstep; consumed by the next `after_tick`.
    previous_superstep_changed_fields: FieldsChanged,

    /// Budget tracker, shared with `runnable_config.budget_tracker`.
    budget_tracker: Option<Arc<BudgetTracker>>,

    /// Handle through which callers can request a drain.
    run_control: RunControl,

    /// UUID v4 string generated when the loop is constructed.
    run_id: String,

    /// Interrupt receiver returned by the runner; drained in `after_tick`.
    interrupt_rx: Option<mpsc::UnboundedReceiver<crate::interrupt::InterruptSignal>>,

    /// Interrupt signals recorded by the most recent interrupt.
    pending_interrupts: Vec<crate::interrupt::InterruptSignal>,

    /// Scratchpad on which processed interrupts are marked.
    scratchpad: crate::interrupt::Scratchpad,

    /// Channel versions recorded the last time a configured interrupt fired.
    interrupt_versions_seen: HashMap<String, u64>,

    /// Start time of the current superstep, used for the `SuperstepEnd` duration.
    superstep_start: Option<Instant>,

    /// Error-handler routing passed to the runner and the error-handler scheduler.
    /// NOTE(review): presumably maps a node name to its handler node — confirm.
    error_handler_map: HashMap<String, String>,

    /// Index derived once from `trigger_table` at construction time.
    trigger_to_nodes: crate::pregel::scheduler::TriggerToNodes,

    /// Retry policies passed to the runner.
    /// NOTE(review): presumably keyed by node name — confirm.
    retry_policies: HashMap<String, crate::graph::RetryPolicy>,

    /// Timeout policies passed to the runner.
    /// NOTE(review): presumably keyed by node name — confirm.
    timeout_policies: HashMap<String, crate::pregel::context::TimeoutPolicy>,

    /// Circuit breaker configuration, keyed by node name.
    circuit_breaker_configs: HashMap<String, crate::graph::CircuitBreakerConfig>,

    /// Circuit breaker runtime state, keyed by node name.
    circuit_breaker_states: HashMap<String, crate::graph::CircuitBreakerState>,

    /// Fallback routing passed to the runner and the fallback scheduler.
    /// NOTE(review): presumably maps a node name to its fallback node — confirm.
    fallback_map: HashMap<String, String>,

    /// Delta counters, updated after writes are applied.
    /// NOTE(review): key semantics are not visible in this file — confirm.
    delta_counters: HashMap<String, DeltaCounters>,

    /// Starts as `false`.
    /// NOTE(review): presumably set once the channels have been finished — confirm.
    channels_finished: bool,
}
265
266impl<S: State> std::fmt::Debug for PregelLoop<S> {
267 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
268 f.debug_struct("PregelLoop")
269 .field("state", &"<state>")
270 .field("nodes", &self.nodes.len())
271 .field("trigger_table", &self.trigger_table)
272 .field("field_versions", &self.field_versions)
273 .field("versions_seen", &self.versions_seen)
274 .field("runnable_config", &self.runnable_config)
275 .field("cancellation_token", &self.cancellation_token)
276 .field("stream_tx", &self.stream_tx.is_some())
277 .field("checkpointer", &self.checkpointer.is_some())
278 .field("step", &self.step)
279 .field("status", &self.status)
280 .field("pending_tasks", &self.pending_tasks)
281 .field(
282 "previous_superstep_changed_fields",
283 &self.previous_superstep_changed_fields,
284 )
285 .field("budget_tracker", &self.budget_tracker.is_some())
286 .field("run_control", &self.run_control)
287 .field("run_id", &self.run_id)
288 .field("interrupt_rx", &self.interrupt_rx.is_some())
289 .field("pending_interrupts", &self.pending_interrupts.len())
290 .field("scratchpad", &self.scratchpad)
291 .field("interrupt_versions_seen", &self.interrupt_versions_seen)
292 .field("superstep_start", &self.superstep_start.is_some())
293 .field("error_handler_map", &self.error_handler_map.len())
294 .field("trigger_to_nodes", &"<cached>")
295 .field(
296 "retry_policies",
297 &self.retry_policies.keys().collect::<Vec<_>>(),
298 )
299 .field(
300 "timeout_policies",
301 &self.timeout_policies.keys().collect::<Vec<_>>(),
302 )
303 .field(
304 "circuit_breaker_configs",
305 &self.circuit_breaker_configs.keys().collect::<Vec<_>>(),
306 )
307 .field(
308 "circuit_breaker_states",
309 &self.circuit_breaker_states.keys().collect::<Vec<_>>(),
310 )
311 .field("fallback_map", &self.fallback_map.len())
312 .field("delta_counters", &self.delta_counters.len())
313 .field("channels_finished", &self.channels_finished)
314 .finish()
315 }
316}
317
318impl<S: State> PregelLoop<S> {
    /// Creates a loop with no error handlers registered.
    ///
    /// Delegates to [`Self::with_error_handlers`] with an empty handler map.
    ///
    /// # Errors
    ///
    /// Returns whatever [`Self::with_error_handlers`] returns.
    pub fn new(
        state: S,
        nodes: IndexMap<String, Arc<dyn Node<S>>>,
        trigger_table: TriggerTable<S>,
        config: crate::config::RunnableConfig,
        num_fields: usize,
    ) -> Result<Self, JunctureError> {
        Self::with_error_handlers(
            state,
            nodes,
            trigger_table,
            config,
            num_fields,
            HashMap::new(),
        )
    }
363
364 pub fn with_error_handlers(
384 state: S,
385 nodes: IndexMap<String, Arc<dyn Node<S>>>,
386 trigger_table: TriggerTable<S>,
387 config: crate::config::RunnableConfig,
388 num_fields: usize,
389 error_handler_map: HashMap<String, String>,
390 ) -> Result<Self, JunctureError> {
391 let node_names: Vec<String> = nodes.keys().cloned().collect();
392 let field_versions = FieldVersionTracker::new(num_fields);
393 let versions_seen = VersionsSeen::new(&node_names, num_fields);
394 let cancellation_token = CancellationToken::new();
395
396 let pending_tasks = Self::compute_initial_tasks(&trigger_table);
398
399 let trigger_to_nodes =
401 crate::pregel::scheduler::TriggerToNodes::from_trigger_table(&trigger_table);
402
403 let run_id = uuid::Uuid::new_v4().to_string();
405
406 Ok(Self {
407 state,
408 nodes,
409 trigger_table,
410 field_versions,
411 versions_seen,
412 runnable_config: config,
413 cancellation_token,
414 stream_tx: None,
415 checkpointer: None,
416 step: 0,
417 status: LoopStatus::Running,
418 pending_tasks,
419 previous_superstep_changed_fields: FieldsChanged(0),
420 budget_tracker: None,
421 run_control: RunControl::new(),
422 run_id,
423 interrupt_rx: None,
424 pending_interrupts: Vec::new(),
425 scratchpad: crate::interrupt::Scratchpad::new(),
426 interrupt_versions_seen: HashMap::new(),
427 superstep_start: None,
428 error_handler_map,
429 trigger_to_nodes,
430 retry_policies: HashMap::new(),
431 timeout_policies: HashMap::new(),
432 circuit_breaker_configs: HashMap::new(),
433 circuit_breaker_states: HashMap::new(),
434 fallback_map: HashMap::new(),
435 delta_counters: HashMap::new(),
436 channels_finished: false,
437 })
438 }
439
    /// Installs the sender on which stream events are emitted, replacing any
    /// previously configured sender.
    pub fn set_stream_sender(&mut self, tx: mpsc::UnboundedSender<StreamEvent<S>>) {
        self.stream_tx = Some(tx);
    }
453
    /// Installs the checkpoint saver, replacing any previously configured one.
    pub fn set_checkpointer(&mut self, saver: Arc<dyn crate::checkpoint::CheckpointSaver>) {
        self.checkpointer = Some(saver);
    }
458
459 pub fn set_budget_tracker(&mut self, tracker: BudgetTracker) {
475 let shared = Arc::new(tracker);
476 self.runnable_config.budget_tracker = Some(Arc::clone(&shared));
477 self.budget_tracker = Some(shared);
478 }
479
    /// Replaces the retry policies handed to the runner on each superstep.
    pub fn set_retry_policies(&mut self, policies: HashMap<String, crate::graph::RetryPolicy>) {
        self.retry_policies = policies;
    }
489
    /// Replaces the timeout policies handed to the runner on each superstep.
    pub fn set_timeout_policies(
        &mut self,
        policies: HashMap<String, crate::pregel::context::TimeoutPolicy>,
    ) {
        self.timeout_policies = policies;
    }
502
503 pub fn set_circuit_breaker_policies(
510 &mut self,
511 configs: HashMap<String, crate::graph::CircuitBreakerConfig>,
512 ) {
513 let states = configs
515 .keys()
516 .map(|name| (name.clone(), crate::graph::CircuitBreakerState::new()))
517 .collect();
518 self.circuit_breaker_configs = configs;
519 self.circuit_breaker_states = states;
520 }
521
    /// Replaces the fallback routing used by the runner and by fallback scheduling.
    pub fn set_fallback_map(&mut self, map: HashMap<String, String>) {
        self.fallback_map = map;
    }
530
531 fn check_circuit_breakers(&mut self) -> Vec<crate::pregel::types::TaskOutput<S>> {
537 if self.circuit_breaker_configs.is_empty() {
538 return Vec::new();
539 }
540
541 let mut blocked_outputs = Vec::new();
542 let mut blocked_ids = std::collections::HashSet::new();
543
544 for task in &self.pending_tasks {
545 if let Some(config) = self.circuit_breaker_configs.get(&task.node_name) {
546 if let Some(state) = self.circuit_breaker_states.get_mut(&task.node_name) {
547 if !state.should_allow(config) {
548 let failures = state.consecutive_failures();
549
550 tracing::warn!(
551 name: "juncture.circuit_breaker.open",
552 node_name = %task.node_name,
553 consecutive_failures = failures,
554 "Circuit breaker is open, skipping node execution"
555 );
556
557 blocked_outputs.push(crate::pregel::types::TaskOutput {
558 task_id: task.id.clone(),
559 node_name: task.node_name.clone(),
560 command: crate::Command::default(),
561 duration: std::time::Duration::ZERO,
562 trigger: crate::pregel::types::TaskTrigger::Pull,
563 triggered_fields: Vec::new(),
564 error: Some(crate::JunctureError::execution(format!(
565 "Circuit breaker open for node '{}': {failures} consecutive failures",
566 task.node_name,
567 ))),
568 circuit_blocked: true,
569 });
570 blocked_ids.insert(task.id.clone());
571 }
572 }
573 }
574 }
575
576 self.pending_tasks
578 .retain(|task| !blocked_ids.contains(&task.id));
579
580 if !blocked_outputs.is_empty() {
582 let count = u64::try_from(blocked_outputs.len()).unwrap_or(u64::MAX);
583 self.emit_counter("juncture.circuit_breaker.blocked", count);
584 }
585
586 blocked_outputs
587 }
588
589 fn record_circuit_breaker_results(&mut self, outputs: &[crate::pregel::types::TaskOutput<S>]) {
593 if self.circuit_breaker_configs.is_empty() {
594 return;
595 }
596
597 for output in outputs {
598 if let Some(config) = self.circuit_breaker_configs.get(&output.node_name) {
599 if let Some(state) = self.circuit_breaker_states.get_mut(&output.node_name) {
600 state.mark_half_open_attempt();
604
605 if output.error.is_some() {
606 state.record_failure(config);
607
608 tracing::debug!(
609 name: "juncture.circuit_breaker.failure_recorded",
610 node_name = %output.node_name,
611 consecutive_failures = state.consecutive_failures(),
612 circuit_state = ?state.state(),
613 "Circuit breaker recorded failure"
614 );
615 } else {
616 state.record_success();
617 }
618 }
619 }
620 }
621 }
622
623 fn check_state_size_limit(&self, changed: &crate::FieldsChanged) -> Result<(), JunctureError>
632 where
633 S: serde::Serialize,
634 {
635 let Some(ref limits) = self.runnable_config.resource_limits else {
636 return Ok(());
637 };
638
639 let Some(max_bytes) = limits.max_state_size_bytes else {
640 return Ok(());
641 };
642
643 if changed.is_empty() {
645 return Ok(());
646 }
647
648 let serialized = serde_json::to_vec(&self.state).map_err(|e| {
649 JunctureError::execution(format!("State serialization failed for size check: {e}"))
650 })?;
651
652 if serialized.len() > max_bytes {
653 tracing::warn!(
654 name: "juncture.resource.state_size_exceeded",
655 actual_bytes = serialized.len(),
656 max_bytes = max_bytes,
657 "State size exceeds configured limit"
658 );
659
660 self.emit_counter("juncture.resource.state_size_exceeded", 1);
662
663 return Err(JunctureError::execution(format!(
664 "State size limit exceeded: {} bytes > {} bytes limit",
665 serialized.len(),
666 max_bytes,
667 )));
668 }
669
670 Ok(())
671 }
672
673 #[must_use]
679 pub fn health(&self) -> crate::pregel::types::HealthStatus {
680 use crate::graph::CircuitState;
681 use crate::pregel::types::{NodeHealth, NodeHealthState};
682
683 let mut nodes = std::collections::HashMap::new();
684 let mut open_circuit_breakers = 0;
685
686 for node_name in self.circuit_breaker_configs.keys() {
688 if let Some(state) = self.circuit_breaker_states.get(node_name) {
689 let (status, circuit_state_str) = match state.state() {
690 CircuitState::Closed => {
691 if state.consecutive_failures() > 0 {
692 (NodeHealthState::Degraded, "closed".to_string())
693 } else {
694 (NodeHealthState::Healthy, "closed".to_string())
695 }
696 }
697 CircuitState::HalfOpen => (NodeHealthState::Degraded, "half_open".to_string()),
698 CircuitState::Open => {
699 open_circuit_breakers += 1;
700 (NodeHealthState::Unhealthy, "open".to_string())
701 }
702 };
703
704 nodes.insert(
705 node_name.clone(),
706 NodeHealth {
707 status,
708 consecutive_failures: state.consecutive_failures(),
709 circuit_state: Some(circuit_state_str),
710 },
711 );
712 }
713 }
714
715 for node_name in self.nodes.keys() {
717 if !nodes.contains_key(node_name) {
718 nodes.insert(
719 node_name.clone(),
720 NodeHealth {
721 status: NodeHealthState::Healthy,
722 consecutive_failures: 0,
723 circuit_state: None,
724 },
725 );
726 }
727 }
728
729 let has_degraded_or_unhealthy =
731 nodes.values().any(|n| n.status != NodeHealthState::Healthy);
732 let healthy = !has_degraded_or_unhealthy;
733
734 crate::pregel::types::HealthStatus {
735 healthy,
736 nodes,
737 open_circuit_breakers,
738 }
739 }
740
741 fn compute_initial_tasks(trigger_table: &TriggerTable<S>) -> Vec<PendingTask<S>> {
743 let mut initial_tasks = Vec::new();
745
746 for (node_name, sources) in &trigger_table.incoming {
747 for source in sources {
748 if let crate::edge::TriggerSource::Edge { from } = source
749 && from == crate::edge::START
750 {
751 initial_tasks.push(PendingTask::pull(
752 uuid::Uuid::new_v4().to_string(),
753 node_name.clone(),
754 ));
755 break;
756 }
757 }
758 }
759
760 initial_tasks
761 }
762
763 #[allow(
783 clippy::too_many_lines,
784 reason = "Function contains multiple termination path checks (recursion limit, cancellation, budget, drain, interrupt) with finish_all_channels() calls on each path. Refactoring would reduce clarity by splitting related checks."
785 )]
786 pub fn tick(&mut self) -> Result<bool, JunctureError> {
787 let span = tracing::info_span!(
789 "juncture.graph.invoke",
790 "juncture.thread.id" = ?std::thread::current().id(),
791 "juncture.step" = self.step,
792 "juncture.recursion.limit" = self.runnable_config.recursion_limit,
793 "juncture.graph.name" = ?self.runnable_config.graph_name,
794 "juncture.run.id" = %self.run_id,
795 );
796 let _enter = span.enter();
797
798 if self.step >= self.runnable_config.recursion_limit {
800 self.status = LoopStatus::OutOfSteps;
801 self.emit_counter("juncture.graph.errors", 1);
802 let result: Result<(), JunctureError> = Err(JunctureError::recursion_limit(
803 self.step,
804 self.runnable_config.recursion_limit,
805 ));
806 self.on_graph_end(&result);
807 self.finish_all_channels();
809 let Err(err) = result else {
812 unreachable!("result was constructed as Err");
813 };
814 return Err(err);
815 }
816
817 if self.cancellation_token.is_cancelled() {
819 self.status = LoopStatus::Cancelled;
820 self.on_graph_end(&Ok(()));
821 self.finish_all_channels();
823 return Ok(false);
824 }
825
826 if let Some(tracker) = &self.budget_tracker
828 && let Some(reason) = tracker.check()
829 {
830 self.status = LoopStatus::BudgetExceeded;
831 self.emit_counter("juncture.graph.errors", 1);
832 let result: Result<(), JunctureError> = Err(JunctureError::execution(format!(
833 "Budget exceeded: {reason}"
834 )));
835 self.on_graph_end(&result);
836 self.finish_all_channels();
838 let Err(err) = result else {
839 unreachable!("result was constructed as Err");
840 };
841 return Err(err);
842 }
843
844 if let Some(ref tracker) = self.budget_tracker {
846 let usage = tracker.current_usage();
847 if let Some(ref budget) = self.runnable_config.budget {
848 if let Some(max_tokens) = budget.max_tokens {
849 self.emit_gauge(
850 "juncture.budget.remaining_tokens",
851 max_tokens.saturating_sub(usage.tokens_used),
852 );
853 }
854 if let Some(max_cost) = budget.max_cost_usd {
855 #[allow(
856 clippy::cast_possible_truncation,
857 clippy::cast_sign_loss,
858 reason = "Gauge values are u64; cost is converted to micro-units (6 decimal places) for precision. Truncation is acceptable for gauge display."
859 )]
860 let remaining_micro_usd =
861 ((max_cost - usage.cost_usd).max(0.0) * 1_000_000.0) as u64;
862 self.emit_gauge("juncture.budget.remaining_cost_usd", remaining_micro_usd);
863 }
864 }
865 }
866
867 if self.pending_tasks.is_empty() {
869 if self.run_control.is_drain_requested() {
871 self.finish_all_channels();
874 self.status = LoopStatus::Done;
875 self.on_graph_end(&Ok(()));
876 return Ok(false);
877 }
878
879 self.finish_all_channels();
884 self.status = LoopStatus::Done;
885 self.on_graph_end(&Ok(()));
886 return Ok(false);
887 }
888
889 if let Some(ref interrupt_before_nodes) = self.runnable_config.interrupt_before {
891 let interrupt_before_set: HashSet<String> =
892 interrupt_before_nodes.iter().cloned().collect();
893
894 let channel_versions: HashMap<String, u64> = self
896 .field_versions
897 .versions()
898 .iter()
899 .enumerate()
900 .map(|(idx, ver)| (format!("field_{idx}"), *ver))
901 .collect();
902
903 if let Some(signals) = should_interrupt(
904 &self.pending_tasks,
905 &interrupt_before_set,
906 &HashSet::new(), &channel_versions,
908 &self.interrupt_versions_seen,
909 ) {
910 self.interrupt_versions_seen = channel_versions;
911 self.pending_interrupts.clone_from(&signals);
912 self.status = LoopStatus::InterruptBefore(signals);
913 self.finish_all_channels();
915 return Ok(false);
916 }
917 }
918
919 Ok(true)
920 }
921
922 fn resolve_state_json(tasks: &mut [PendingTask<S>]) -> Result<(), JunctureError>
927 where
928 S: serde::de::DeserializeOwned,
929 {
930 for task in tasks {
931 if task.state_override.is_none()
932 && let Some(ref json) = task.state_json
933 {
934 let deserialized = serde_json::from_value::<S>(json.clone()).map_err(|e| {
935 JunctureError::execution(format!(
936 "failed to deserialize state_json for task '{}': {e}",
937 task.node_name
938 ))
939 })?;
940 task.state_override = Some(deserialized);
941 }
942 }
943 Ok(())
944 }
945
946 pub async fn execute_superstep(&mut self) -> Result<SuperstepResult<S>, JunctureError>
963 where
964 S: serde::de::DeserializeOwned,
965 S::Update: serde::Serialize,
966 {
967 Self::resolve_state_json(&mut self.pending_tasks)?;
971
972 let circuit_blocked = self.check_circuit_breakers();
976
977 let arc_state: Arc<S> = Arc::new(std::mem::take(&mut self.state));
983 let node_names: Vec<_> = self
984 .pending_tasks
985 .iter()
986 .map(|t| t.node_name.as_str())
987 .collect();
988 let span = tracing::info_span!(
989 "juncture.superstep",
990 step = self.step,
991 num_tasks = self.pending_tasks.len(),
992 "juncture.step.nodes" = ?node_names,
993 "juncture.step.duration_ms" = tracing::field::Empty,
994 );
995 let _enter = span.enter();
996
997 let start = Instant::now();
999 self.superstep_start = Some(start);
1000
1001 if let Some(ref tx) = self.stream_tx {
1003 let _ = tx.send(StreamEvent::Debug(DebugEvent::SuperstepStart {
1004 step: self.step,
1005 pending_nodes: node_names
1006 .iter()
1007 .copied()
1008 .map(std::string::ToString::to_string)
1009 .collect(),
1010 }));
1011 }
1012
1013 let num_tasks = u64::try_from(self.pending_tasks.len()).unwrap_or(u64::MAX);
1015 self.emit_counter("juncture.superstep.tasks", num_tasks);
1016
1017 let (result, interrupt_rx) = execute_superstep(
1018 &self.pending_tasks,
1019 &arc_state,
1020 &self.nodes,
1021 &self.runnable_config,
1022 &self.cancellation_token,
1023 self.checkpointer.as_ref(),
1024 &self.pending_interrupts,
1025 &self.scratchpad,
1026 &self.error_handler_map,
1027 &self.retry_policies,
1028 &self.timeout_policies,
1029 &self.fallback_map,
1030 self.step,
1031 )
1032 .await?;
1033
1034 self.state = match Arc::try_unwrap(arc_state) {
1037 Ok(state) => state,
1038 Err(arc) => {
1039 tracing::warn!(
1040 name: "juncture.state.arc_leak",
1041 step = self.step,
1042 "Arc refcount > 1 after superstep, falling back to clone"
1043 );
1044 S::clone(&*arc)
1045 }
1046 };
1047
1048 for signal in &self.pending_interrupts {
1055 if let Some(ref id) = signal.id {
1056 self.scratchpad.mark_interrupt_processed(id);
1057 }
1058 }
1059
1060 let duration = start.elapsed().as_millis();
1061 tracing::Span::current().record("juncture.step.duration_ms", duration);
1062
1063 let duration_ms = u64::try_from(duration).unwrap_or(u64::MAX);
1067 #[allow(
1068 clippy::cast_precision_loss,
1069 reason = "millisecond durations fit well within f64 precision for histogram recording"
1070 )]
1071 let duration_f64 = duration_ms as f64;
1072 self.emit_histogram("juncture.superstep.duration_ms", duration_f64);
1073
1074 tracing::debug!(
1076 name: "juncture.superstep.duration_ms",
1077 step = self.step,
1078 duration_ms = duration,
1079 );
1080
1081 self.record_circuit_breaker_results(&result.task_outputs);
1083
1084 let mut merged_result = result;
1086 if !circuit_blocked.is_empty() {
1087 merged_result.task_outputs.extend(circuit_blocked);
1088 }
1089
1090 self.interrupt_rx = Some(interrupt_rx);
1093
1094 Ok(merged_result)
1095 }
1096
1097 #[expect(
1114 clippy::too_many_lines,
1115 reason = "after_tick orchestrates multiple sequential phases: apply writes, bump versions, consume channels, emit events including stream_data, compute tasks, drain interrupts, check interrupts, finish channels, increment step"
1116 )]
1117 #[allow(
1118 clippy::cognitive_complexity,
1119 reason = "after_tick orchestrates multiple sequential phases: apply writes, bump versions, consume channels, emit events including stream_data, compute tasks, drain interrupts, check interrupts, finish channels, increment step"
1120 )]
1121 pub async fn after_tick(&mut self, result: SuperstepResult<S>) -> Result<(), JunctureError>
1122 where
1123 S: Clone + serde::Serialize,
1124 {
1125 let versions_before_apply = self.field_versions.versions().to_vec();
1131
1132 let needs_snapshot = self
1135 .runnable_config
1136 .resource_limits
1137 .as_ref()
1138 .is_some_and(|r| r.max_state_size_bytes.is_some());
1139 let state_snapshot = needs_snapshot.then(|| self.state.clone());
1140 let field_versions_snapshot = needs_snapshot.then(|| self.field_versions.clone());
1141 let circuit_breaker_snapshot = needs_snapshot.then(|| self.circuit_breaker_states.clone());
1142
1143 let executed_task_outputs: Vec<_> = result
1150 .task_outputs
1151 .iter()
1152 .filter(|o| !o.circuit_blocked)
1153 .cloned()
1154 .collect();
1155 let total_changed = apply_writes(
1156 &mut self.state,
1157 &executed_task_outputs,
1158 &mut self.field_versions,
1159 )?;
1160
1161 if let Err(e) = self.check_state_size_limit(&total_changed) {
1164 if let (Some(snapshot), Some(versions_snapshot), Some(cb_snapshot)) = (
1165 state_snapshot,
1166 field_versions_snapshot,
1167 circuit_breaker_snapshot,
1168 ) {
1169 self.state = snapshot;
1170 self.field_versions = versions_snapshot;
1171 self.circuit_breaker_states = cb_snapshot;
1172 tracing::warn!(
1173 name: "juncture.resource.state_size_rollback",
1174 step = self.step,
1175 "State size limit exceeded, rolled back to pre-superstep snapshot"
1176 );
1177 }
1178 return Err(e);
1179 }
1180
1181 self.update_delta_counters(&total_changed);
1186
1187 let fields_to_consume = self.previous_superstep_changed_fields.clone();
1198 self.consume_triggered_channels(&fields_to_consume);
1199
1200 for task_output in &result.task_outputs {
1206 if task_output.circuit_blocked {
1207 continue;
1208 }
1209 self.versions_seen
1210 .mark_consumed(&task_output.node_name, &versions_before_apply);
1211 }
1212
1213 self.state.reset_ephemeral();
1215
1216 if let Some(ref tx) = self.stream_tx {
1218 for task_output in &result.task_outputs {
1219 if task_output.circuit_blocked {
1220 continue;
1221 }
1222 let start_event = StreamEvent::TaskStart {
1224 node: task_output.node_name.clone(),
1225 task_id: task_output.task_id.clone(),
1226 step: self.step,
1227 };
1228 let _ = tx.send(start_event);
1229
1230 let end_event = StreamEvent::TaskEnd {
1232 node: task_output.node_name.clone(),
1233 task_id: task_output.task_id.clone(),
1234 step: self.step,
1235 duration_ms: u64::try_from(task_output.duration.as_millis())
1236 .expect("duration should fit in u64"),
1237 };
1238 let _ = tx.send(end_event);
1239
1240 for data in &task_output.command.stream_data {
1244 let custom_event = StreamEvent::Custom {
1245 node: task_output.node_name.clone(),
1246 data: data.clone(),
1247 ns: Vec::new(),
1248 };
1249 let _ = tx.send(custom_event);
1250 }
1251
1252 if let Some(ref update) = task_output.command.update {
1254 let updates_event = StreamEvent::Updates {
1255 node: task_output.node_name.clone(),
1256 update: update.clone(),
1257 step: self.step,
1258 };
1259 let _ = tx.send(updates_event);
1260 }
1261 }
1262
1263 let values_event = StreamEvent::Values {
1265 state: self.state.clone(),
1266 step: self.step,
1267 };
1268 let _ = tx.send(values_event);
1269
1270 if let Some(superstep_start) = self.superstep_start {
1272 let duration_ms =
1273 u64::try_from(superstep_start.elapsed().as_millis()).unwrap_or(u64::MAX);
1274 let end_event = StreamEvent::Debug(DebugEvent::SuperstepEnd {
1275 step: self.step,
1276 duration_ms,
1277 });
1278 let _ = tx.send(end_event);
1279 }
1280 }
1281
1282 let executed_outputs: Vec<_> = result
1286 .task_outputs
1287 .iter()
1288 .filter(|o| !o.circuit_blocked)
1289 .cloned()
1290 .collect();
1291 self.pending_tasks = compute_next_tasks(
1292 &executed_outputs,
1293 &self.trigger_table,
1294 &self.trigger_to_nodes,
1295 &self.state,
1296 )
1297 .await?;
1298
1299 let (fallback_tasks, fallback_handled) =
1307 schedule_fallback_tasks(&executed_outputs, &self.nodes, &self.fallback_map);
1308 if !fallback_tasks.is_empty() {
1309 let existing_nodes: std::collections::HashSet<&str> = self
1310 .pending_tasks
1311 .iter()
1312 .map(|t| t.node_name.as_str())
1313 .collect();
1314 let deduplicated_fallbacks: Vec<_> = fallback_tasks
1315 .into_iter()
1316 .filter(|t| !existing_nodes.contains(t.node_name.as_str()))
1317 .collect();
1318 tracing::debug!(
1319 name: "juncture.fallback.recovery_tasks",
1320 step = self.step,
1321 count = deduplicated_fallbacks.len(),
1322 "Scheduling fallback recovery tasks"
1323 );
1324 self.pending_tasks.extend(deduplicated_fallbacks);
1325 }
1326
1327 let recovery_tasks = schedule_error_handlers_filtered(
1334 &executed_outputs,
1335 &self.nodes,
1336 &self.error_handler_map,
1337 &fallback_handled,
1338 );
1339 if !recovery_tasks.is_empty() {
1340 tracing::debug!(
1341 name: "juncture.error_handler.recovery_tasks",
1342 step = self.step,
1343 count = recovery_tasks.len(),
1344 "Scheduling error handler recovery tasks"
1345 );
1346 self.pending_tasks.extend(recovery_tasks);
1347 }
1348
1349 if let Some(ref tx) = self.stream_tx {
1351 let next_node_names: Vec<String> = self
1352 .pending_tasks
1353 .iter()
1354 .map(|t| t.node_name.clone())
1355 .collect();
1356 for node_name in &next_node_names {
1358 let edge_event = StreamEvent::Debug(DebugEvent::EdgeTraversed {
1359 from: "superstep".to_string(),
1360 to: node_name.clone(),
1361 edge_type: "conditional".to_string(),
1362 });
1363 let _ = tx.send(edge_event);
1364 }
1365 }
1366
1367 self.previous_superstep_changed_fields = total_changed;
1371
1372 self.save_superstep_checkpoint().await;
1377
1378 let mut node_interrupts = Vec::new();
1381 if let Some(mut rx) = self.interrupt_rx.take() {
1382 while let Ok(signal) = rx.try_recv() {
1384 node_interrupts.push(signal);
1385 }
1386 }
1387
1388 if !node_interrupts.is_empty() {
1390 self.pending_interrupts.clone_from(&node_interrupts);
1391 self.status = LoopStatus::InterruptAfter(node_interrupts.clone());
1392
1393 self.emit_interrupt_events(&node_interrupts);
1395
1396 let node = self.interrupt_node_name().to_string();
1398 self.save_interrupt_checkpoint(&node).await;
1399
1400 self.finish_all_channels();
1402 return Ok(());
1403 }
1404
1405 if result.has_bubble_ups() && self.handle_bubble_ups(&result.bubble_ups) {
1407 if self.status.is_interrupted() {
1410 let node = self.interrupt_node_name().to_string();
1411 self.save_interrupt_checkpoint(&node).await;
1412 }
1413 self.finish_all_channels();
1415 return Ok(());
1416 }
1417
1418 if let Some(ref interrupt_after_nodes) = self.runnable_config.interrupt_after {
1420 let interrupt_after_set: HashSet<String> =
1421 interrupt_after_nodes.iter().cloned().collect();
1422
1423 let channel_versions: HashMap<String, u64> = self
1425 .field_versions
1426 .versions()
1427 .iter()
1428 .enumerate()
1429 .map(|(idx, ver)| (format!("field_{idx}"), *ver))
1430 .collect();
1431
1432 if let Some(signals) = should_interrupt(
1433 &self.pending_tasks,
1434 &HashSet::new(), &interrupt_after_set,
1436 &channel_versions,
1437 &self.interrupt_versions_seen,
1438 ) {
1439 self.interrupt_versions_seen = channel_versions;
1440 self.pending_interrupts.clone_from(&signals);
1441 self.status = LoopStatus::InterruptAfter(signals.clone());
1442
1443 self.emit_interrupt_events(&signals);
1445
1446 let node = self.interrupt_node_name().to_string();
1448 self.save_interrupt_checkpoint(&node).await;
1449
1450 self.finish_all_channels();
1452 return Ok(());
1453 }
1454 }
1455
1456 if self.pending_tasks.is_empty() {
1460 self.finish_all_channels();
1461 if self.effective_durability() == Durability::Exit {
1464 self.save_exit_checkpoint().await;
1465 }
1466 }
1467
1468 self.step += 1;
1470
1471 if let Some(ref tracker) = self.budget_tracker {
1473 tracker.report_step();
1474 }
1475
1476 Ok(())
1477 }
1478
1479 fn handle_bubble_ups(&mut self, bubble_ups: &[BubbleUp<S>]) -> bool {
1487 let mut should_stop = false;
1488
1489 for bubble_up in bubble_ups {
1490 match bubble_up {
1491 BubbleUp::Interrupt(graph_interrupt) => {
1492 self.handle_bubble_up_interrupt(graph_interrupt);
1493 should_stop = true;
1494 }
1495 BubbleUp::Drained(drained) => {
1496 self.handle_bubble_up_drained(drained);
1497 should_stop = true;
1498 }
1499 BubbleUp::ParentCommand(cmd) => {
1500 self.handle_bubble_up_parent_command(cmd);
1501 }
1502 }
1503 }
1504
1505 should_stop
1506 }
1507
1508 fn handle_bubble_up_interrupt(
1510 &mut self,
1511 graph_interrupt: &crate::pregel::types::GraphInterrupt,
1512 ) {
1513 tracing::debug!(
1514 step = self.step,
1515 num_signals = graph_interrupt.interrupts.len(),
1516 interrupt_step = graph_interrupt.step,
1517 namespace = ?graph_interrupt.namespace,
1518 "Subgraph interrupt bubbling up to parent"
1519 );
1520
1521 self.pending_interrupts
1522 .clone_from(&graph_interrupt.interrupts);
1523 self.status = LoopStatus::InterruptAfter(graph_interrupt.interrupts.clone());
1524
1525 self.emit_interrupt_events_with_namespace(
1529 &graph_interrupt.interrupts,
1530 &graph_interrupt.namespace,
1531 );
1532 }
1533
    /// Records that a subgraph reported it was drained.
    ///
    /// Logs the reason at debug level and sets the loop status to
    /// `LoopStatus::Drained`.
    fn handle_bubble_up_drained(&mut self, drained: &crate::pregel::types::GraphDrained) {
        tracing::debug!(
            step = self.step,
            reason = %drained.reason,
            "Subgraph drained bubbling up to parent"
        );

        self.status = LoopStatus::Drained;
    }
1544
1545 fn handle_bubble_up_parent_command(&mut self, parent_cmd: &crate::command::ParentCommand<S>) {
1547 tracing::debug!(
1548 step = self.step,
1549 source_node = %parent_cmd.source_node,
1550 namespace = %parent_cmd.namespace,
1551 goto = ?parent_cmd.command.goto,
1552 "Subgraph parent command bubbling up"
1553 );
1554
1555 if let Some(ref update) = parent_cmd.command.update {
1556 let changed = self.state.try_apply(update.clone());
1557 match changed {
1558 Ok(changed) => self.field_versions.bump_all(&changed),
1559 Err(err) => {
1560 tracing::warn!(
1561 name: "juncture.subgraph.parent_command.apply_failed",
1562 step = self.step,
1563 source_node = %parent_cmd.source_node,
1564 namespace = %parent_cmd.namespace,
1565 error = %err,
1566 "Failed to apply parent command from subgraph"
1567 );
1568 }
1569 }
1570 }
1571 }
1572
1573 #[must_use]
1581 pub fn into_state(self) -> S {
1582 self.state
1583 }
1584
1585 #[must_use]
1587 pub const fn step(&self) -> usize {
1588 self.step
1589 }
1590
1591 #[must_use]
1593 pub fn run_id(&self) -> &str {
1594 &self.run_id
1595 }
1596
1597 #[must_use]
1599 pub const fn status(&self) -> &LoopStatus {
1600 &self.status
1601 }
1602
1603 #[must_use]
1605 pub fn pending_interrupts(&self) -> &[crate::interrupt::InterruptSignal] {
1606 &self.pending_interrupts
1607 }
1608
1609 #[must_use]
1611 pub const fn scratchpad(&self) -> &crate::interrupt::Scratchpad {
1612 &self.scratchpad
1613 }
1614
1615 pub const fn scratchpad_mut(&mut self) -> &mut crate::interrupt::Scratchpad {
1617 &mut self.scratchpad
1618 }
1619
1620 #[must_use]
1622 pub const fn is_running(&self) -> bool {
1623 matches!(self.status, LoopStatus::Running)
1624 }
1625
1626 #[must_use]
1631 pub fn snapshot_state(&self) -> S
1632 where
1633 S: Clone,
1634 {
1635 self.state.clone()
1636 }
1637
1638 #[must_use]
1657 pub const fn run_control(&self) -> &RunControl {
1658 &self.run_control
1659 }
1660
1661 #[must_use]
1681 #[allow(
1682 clippy::clone_on_copy,
1683 reason = "ExecutionContext requires owned state, not reference"
1684 )]
1685 pub fn as_context(&self) -> ExecutionContext<S>
1686 where
1687 S: Clone,
1688 {
1689 ExecutionContext {
1690 state: self.state.clone(),
1691 field_versions: self.field_versions.clone(),
1692 versions_seen: self.versions_seen.clone(),
1693 pending_writes: vec![],
1694 }
1695 }
1696
1697 #[must_use]
1717 pub fn as_config(&self) -> crate::pregel::context::ExecutionConfig {
1718 crate::pregel::context::ExecutionConfig {
1719 recursion_limit: self.runnable_config.recursion_limit,
1720 interrupt_before: self
1721 .runnable_config
1722 .interrupt_before
1723 .as_ref()
1724 .map_or_else(HashSet::new, |v| v.iter().cloned().collect()),
1725 interrupt_after: self
1726 .runnable_config
1727 .interrupt_after
1728 .as_ref()
1729 .map_or_else(HashSet::new, |v| v.iter().cloned().collect()),
1730 budget: self.runnable_config.budget.clone(),
1731 durability: self.runnable_config.durability.clone().unwrap_or_default(),
1732 retry_policies: std::collections::HashMap::new(),
1733 timeout_policies: std::collections::HashMap::new(),
1734 }
1735 }
1736
1737 #[allow(
1749 clippy::cognitive_complexity,
1750 clippy::too_many_lines,
1751 reason = "durability match arms and checkpoint construction logic are necessarily complex for handling Sync/Async/Exit modes"
1752 )]
1753 async fn save_interrupt_checkpoint(&mut self, node: &str)
1754 where
1755 S: serde::Serialize,
1756 {
1757 let Some(ref checkpointer) = self.checkpointer else {
1758 return;
1759 };
1760
1761 let channel_values = match serde_json::to_value(&self.state) {
1762 Ok(v) => v,
1763 Err(err) => {
1764 tracing::warn!(
1765 name: "juncture.checkpoint.interrupt.serialize_failed",
1766 node = node,
1767 error = %err,
1768 "Failed to serialize state for interrupt checkpoint"
1769 );
1770 return;
1771 }
1772 };
1773
1774 let (channel_versions, new_versions, versions_seen) = self.build_checkpoint_versions();
1775
1776 let checkpoint_id = generate_checkpoint_id();
1777 let cp_id_for_event = checkpoint_id.clone();
1778 let created_at = chrono::Utc::now().to_rfc3339();
1779
1780 let checkpoint = Checkpoint {
1781 id: checkpoint_id,
1782 channel_values,
1783 channel_versions,
1784 versions_seen,
1785 pending_tasks: Vec::new(),
1786 pending_sends: Vec::new(),
1787 pending_interrupts: self.pending_interrupts.clone(),
1788 schema_version: S::schema_version(),
1789 created_at,
1790 v: 1,
1791 new_versions,
1792 counters_since_delta_snapshot: self.build_checkpoint_delta_counters(),
1793 };
1794
1795 let metadata = CheckpointMetadata {
1796 source: CheckpointSource::Interrupt {
1797 node: node.to_string(),
1798 },
1799 step: i64::try_from(self.step).unwrap_or(i64::MAX),
1800 writes: HashMap::new(),
1801 parents: HashMap::new(),
1802 run_id: self.run_id.clone(),
1803 };
1804
1805 let cp_config = self.runnable_config.clone();
1806 let stream_tx_clone = self.stream_tx.clone();
1807 match self.effective_durability() {
1808 Durability::Async => {
1809 let step = self.step;
1810 let node_label = node.to_string();
1811 let checkpointer_arc = Arc::clone(checkpointer);
1812 let metadata_for_event = metadata.clone();
1813 tokio::spawn(async move {
1814 match checkpointer_arc.put(&cp_config, checkpoint, metadata).await {
1815 Ok(_updated_config) => {
1816 tracing::info!(
1817 name: "juncture.checkpoint.put",
1818 checkpoint_step = step,
1819 checkpoint_source = "Interrupt",
1820 "Interrupt checkpoint persisted (async)"
1821 );
1822 if let Some(ref collector) = cp_config.metrics_collector {
1824 collector.inc_counter("juncture.checkpoint.writes", 1);
1825 }
1826 Self::emit_checkpoint_saved_event(
1828 stream_tx_clone.as_ref(),
1829 cp_id_for_event,
1830 metadata_for_event,
1831 step,
1832 );
1833 }
1834 Err(err) => {
1835 tracing::warn!(
1836 name: "juncture.checkpoint.interrupt.save_failed",
1837 node = node_label,
1838 error = %err,
1839 "Failed to save interrupt checkpoint (async)"
1840 );
1841 if let Some(ref collector) = cp_config.metrics_collector {
1843 collector.inc_counter("juncture.checkpoint.errors", 1);
1844 }
1845 }
1846 }
1847 });
1848 self.reset_delta_counters();
1849 }
1850 Durability::Sync | Durability::Exit => {
1851 let metadata_for_event = metadata.clone();
1852 match checkpointer
1853 .put(&self.runnable_config, checkpoint, metadata)
1854 .await
1855 {
1856 Ok(updated_config) => {
1857 self.runnable_config.checkpoint_id = updated_config.checkpoint_id;
1858 self.reset_delta_counters();
1859 self.emit_counter("juncture.checkpoint.writes", 1);
1861 tracing::info!(
1862 name: "juncture.checkpoint.put",
1863 checkpoint_id = %self.runnable_config.checkpoint_id.as_deref().unwrap_or("unknown"),
1864 checkpoint_step = self.step,
1865 checkpoint_source = "Interrupt",
1866 "Interrupt checkpoint persisted"
1867 );
1868 if let Some(ref cp_id) = self.runnable_config.checkpoint_id {
1869 self.on_checkpoint_saved(cp_id, self.step);
1870 Self::emit_checkpoint_saved_event(
1872 self.stream_tx.as_ref(),
1873 cp_id.clone(),
1874 metadata_for_event,
1875 self.step,
1876 );
1877 }
1878 }
1879 Err(err) => {
1880 tracing::warn!(
1881 name: "juncture.checkpoint.interrupt.save_failed",
1882 node = node,
1883 error = %err,
1884 "Failed to save interrupt checkpoint"
1885 );
1886 self.emit_counter("juncture.checkpoint.errors", 1);
1888 }
1889 }
1890 }
1891 }
1892 }
1893
1894 #[allow(
1909 clippy::cognitive_complexity,
1910 clippy::too_many_lines,
1911 reason = "durability match arms and checkpoint construction logic are necessarily complex for handling Sync/Async/Exit modes"
1912 )]
1913 async fn save_superstep_checkpoint(&mut self)
1914 where
1915 S: serde::Serialize,
1916 {
1917 let Some(ref checkpointer) = self.checkpointer else {
1918 return;
1919 };
1920
1921 if self.effective_durability() == Durability::Exit {
1925 return;
1926 }
1927
1928 let needs_full_snapshot = self.should_take_full_snapshot();
1933 tracing::debug!(
1934 name = "juncture.checkpoint.superstep.delta_decision",
1935 step = self.step,
1936 needs_full_snapshot = needs_full_snapshot,
1937 "Delta-channel snapshot frequency evaluation"
1938 );
1939
1940 if !needs_full_snapshot {
1943 tracing::debug!(
1944 name = "juncture.checkpoint.superstep.skipped",
1945 step = self.step,
1946 "Skipped full checkpoint - delta optimization active"
1947 );
1948 return;
1949 }
1950
1951 let channel_values = match serde_json::to_value(&self.state) {
1952 Ok(v) => v,
1953 Err(err) => {
1954 tracing::warn!(
1955 name: "juncture.checkpoint.superstep.serialize_failed",
1956 step = self.step,
1957 error = %err,
1958 "Failed to serialize state for superstep checkpoint"
1959 );
1960 return;
1961 }
1962 };
1963
1964 let (channel_versions, new_versions, versions_seen) = self.build_checkpoint_versions();
1965
1966 let pending_tasks: Vec<crate::checkpoint::CheckpointPendingTask> = self
1969 .pending_tasks
1970 .iter()
1971 .map(|task| crate::checkpoint::CheckpointPendingTask {
1972 id: task.id.clone(),
1973 node: task.node_name.clone(),
1974 triggers: Vec::new(),
1975 state_override: None,
1976 })
1977 .collect();
1978
1979 let checkpoint_id = generate_checkpoint_id();
1980 let cp_id_for_event = checkpoint_id.clone();
1981 let created_at = chrono::Utc::now().to_rfc3339();
1982
1983 let checkpoint = Checkpoint {
1984 id: checkpoint_id,
1985 channel_values,
1986 channel_versions,
1987 versions_seen,
1988 pending_tasks,
1989 pending_sends: Vec::new(),
1990 pending_interrupts: Vec::new(),
1991 schema_version: S::schema_version(),
1992 created_at,
1993 v: 1,
1994 new_versions,
1995 counters_since_delta_snapshot: self.build_checkpoint_delta_counters(),
1996 };
1997
1998 let metadata = CheckpointMetadata {
1999 source: CheckpointSource::Loop,
2000 step: i64::try_from(self.step).unwrap_or(i64::MAX),
2001 writes: HashMap::new(),
2002 parents: HashMap::new(),
2003 run_id: self.run_id.clone(),
2004 };
2005
2006 let cp_config = self.runnable_config.clone();
2007 let stream_tx_clone = self.stream_tx.clone();
2008 match self.effective_durability() {
2009 Durability::Async => {
2010 let step = self.step;
2011 let checkpointer_arc = Arc::clone(checkpointer);
2012 let metadata_for_event = metadata.clone();
2013 tokio::spawn(async move {
2014 match checkpointer_arc.put(&cp_config, checkpoint, metadata).await {
2015 Ok(_updated_config) => {
2016 tracing::info!(
2017 name: "juncture.checkpoint.put",
2018 checkpoint_step = step,
2019 checkpoint_source = "Loop",
2020 "Superstep checkpoint persisted (async)"
2021 );
2022 if let Some(ref collector) = cp_config.metrics_collector {
2024 collector.inc_counter("juncture.checkpoint.writes", 1);
2025 }
2026 Self::emit_checkpoint_saved_event(
2028 stream_tx_clone.as_ref(),
2029 cp_id_for_event,
2030 metadata_for_event,
2031 step,
2032 );
2033 }
2034 Err(err) => {
2035 tracing::warn!(
2036 name: "juncture.checkpoint.superstep.save_failed",
2037 step = step,
2038 error = %err,
2039 "Failed to save superstep checkpoint (async)"
2040 );
2041 if let Some(ref collector) = cp_config.metrics_collector {
2043 collector.inc_counter("juncture.checkpoint.errors", 1);
2044 }
2045 }
2046 }
2047 });
2048 self.reset_delta_counters();
2049 }
2050 Durability::Sync | Durability::Exit => {
2051 let metadata_for_event = metadata.clone();
2052 match checkpointer
2053 .put(&self.runnable_config, checkpoint, metadata)
2054 .await
2055 {
2056 Ok(updated_config) => {
2057 self.runnable_config.checkpoint_id = updated_config.checkpoint_id;
2058 self.reset_delta_counters();
2062 self.emit_counter("juncture.checkpoint.writes", 1);
2064 tracing::info!(
2065 name: "juncture.checkpoint.put",
2066 checkpoint_id = %self.runnable_config.checkpoint_id.as_deref().unwrap_or("unknown"),
2067 checkpoint_step = self.step,
2068 checkpoint_source = "Loop",
2069 "Superstep checkpoint persisted"
2070 );
2071 if let Some(ref cp_id) = self.runnable_config.checkpoint_id {
2072 self.on_checkpoint_saved(cp_id, self.step);
2073 Self::emit_checkpoint_saved_event(
2075 self.stream_tx.as_ref(),
2076 cp_id.clone(),
2077 metadata_for_event,
2078 self.step,
2079 );
2080 }
2081 }
2082 Err(err) => {
2083 tracing::warn!(
2084 name: "juncture.checkpoint.superstep.save_failed",
2085 step = self.step,
2086 error = %err,
2087 "Failed to save superstep checkpoint"
2088 );
2089 self.emit_counter("juncture.checkpoint.errors", 1);
2091 }
2092 }
2093 }
2094 }
2095 }
2096
2097 pub async fn save_pending_interrupt_checkpoint(&mut self)
2116 where
2117 S: serde::Serialize,
2118 {
2119 if !self.status.is_interrupted() || self.checkpointer.is_none() {
2120 return;
2121 }
2122 let node = self.interrupt_node_name().to_string();
2123 self.save_interrupt_checkpoint(&node).await;
2124 }
2125
2126 fn interrupt_node_name(&self) -> &str {
2131 static UNKNOWN: &str = "unknown";
2132 self.pending_interrupts
2133 .first()
2134 .and_then(|s| s.payload.get("node"))
2135 .and_then(|v| v.as_str())
2136 .unwrap_or(UNKNOWN)
2137 }
2138
2139 fn current_ns(&self) -> Vec<String> {
2146 self.runnable_config
2147 .checkpoint_ns
2148 .as_ref()
2149 .map(|ns| {
2150 ns.segments
2151 .iter()
2152 .map(|seg| seg.node_name.clone())
2153 .collect()
2154 })
2155 .unwrap_or_default()
2156 }
2157
2158 fn emit_interrupt_events(&self, signals: &[crate::interrupt::InterruptSignal]) {
2164 self.emit_interrupt_events_with_namespace(signals, &self.current_ns());
2165 }
2166
2167 fn emit_interrupt_events_with_namespace(
2178 &self,
2179 signals: &[crate::interrupt::InterruptSignal],
2180 namespace: &[String],
2181 ) {
2182 let Some(ref tx) = self.stream_tx else {
2183 return;
2184 };
2185
2186 for signal in signals {
2187 let node = signal
2188 .payload
2189 .get("node")
2190 .and_then(|v| v.as_str())
2191 .unwrap_or("unknown");
2192
2193 let tags: &[String] = &[];
2196 if crate::interrupt::is_hidden_node(node, tags) {
2197 continue;
2198 }
2199
2200 let event = StreamEvent::Interrupt {
2201 node: node.to_string(),
2202 payload: signal.payload.clone(),
2203 resumable: true,
2204 ns: namespace.to_vec(),
2205 };
2206 let _ = tx.send(event);
2207 }
2208 }
2209
2210 fn finish_all_channels(&mut self) {
2233 if self.channels_finished {
2238 return;
2239 }
2240
2241 for &field_idx in S::replace_after_finish_field_indices() {
2242 self.state.finish_field(field_idx);
2243 }
2244
2245 self.channels_finished = true;
2246 }
2247
2248 fn consume_triggered_channels(&mut self, changed: &crate::FieldsChanged) {
2260 for field_idx in 0..S::field_count() {
2261 if changed.has_field(field_idx) {
2262 self.state.consume_field(field_idx);
2263 }
2264 }
2265 }
2266
2267 #[must_use]
2273 fn effective_durability(&self) -> Durability {
2274 self.runnable_config
2275 .durability
2276 .clone()
2277 .unwrap_or(Durability::Sync)
2278 }
2279
2280 #[must_use]
2288 #[allow(
2289 clippy::type_complexity,
2290 reason = "return type is a direct mapping of the three version maps required by Checkpoint struct; factoring into a named type adds indirection without benefit"
2291 )]
2292 fn build_checkpoint_versions(
2293 &self,
2294 ) -> (
2295 HashMap<String, u64>,
2296 HashMap<String, u64>,
2297 HashMap<String, HashMap<String, u64>>,
2298 ) {
2299 let channel_versions: HashMap<String, u64> = self
2300 .field_versions
2301 .versions()
2302 .iter()
2303 .enumerate()
2304 .map(|(idx, ver)| (format!("field_{idx}"), *ver))
2305 .collect();
2306
2307 let new_versions = channel_versions.clone();
2308
2309 let versions_seen: HashMap<String, HashMap<String, u64>> = self
2310 .nodes
2311 .keys()
2312 .map(|node_name| {
2313 let versions = self.versions_seen.get_versions(node_name);
2314 let map: HashMap<String, u64> = versions
2315 .iter()
2316 .enumerate()
2317 .map(|(idx, ver)| (format!("field_{idx}"), *ver))
2318 .collect();
2319 (node_name.clone(), map)
2320 })
2321 .collect();
2322
2323 (channel_versions, new_versions, versions_seen)
2324 }
2325
2326 async fn save_exit_checkpoint(&mut self)
2335 where
2336 S: serde::Serialize,
2337 {
2338 let Some(ref checkpointer) = self.checkpointer else {
2339 return;
2340 };
2341
2342 let channel_values = match serde_json::to_value(&self.state) {
2343 Ok(v) => v,
2344 Err(err) => {
2345 tracing::warn!(
2346 name: "juncture.checkpoint.exit.serialize_failed",
2347 step = self.step,
2348 error = %err,
2349 "Failed to serialize state for exit checkpoint"
2350 );
2351 return;
2352 }
2353 };
2354
2355 let (channel_versions, new_versions, versions_seen) = self.build_checkpoint_versions();
2356
2357 let pending_tasks: Vec<crate::checkpoint::CheckpointPendingTask> = self
2358 .pending_tasks
2359 .iter()
2360 .map(|task| crate::checkpoint::CheckpointPendingTask {
2361 id: task.id.clone(),
2362 node: task.node_name.clone(),
2363 triggers: Vec::new(),
2364 state_override: None,
2365 })
2366 .collect();
2367
2368 let checkpoint_id = generate_checkpoint_id();
2369 let created_at = chrono::Utc::now().to_rfc3339();
2370
2371 let checkpoint = Checkpoint {
2372 id: checkpoint_id,
2373 channel_values,
2374 channel_versions,
2375 versions_seen,
2376 pending_tasks,
2377 pending_sends: Vec::new(),
2378 pending_interrupts: Vec::new(),
2379 schema_version: S::schema_version(),
2380 created_at,
2381 v: 1,
2382 new_versions,
2383 counters_since_delta_snapshot: HashMap::new(),
2384 };
2385
2386 let metadata = CheckpointMetadata {
2387 source: CheckpointSource::Loop,
2388 step: i64::try_from(self.step).unwrap_or(i64::MAX),
2389 writes: HashMap::new(),
2390 parents: HashMap::new(),
2391 run_id: self.run_id.clone(),
2392 };
2393
2394 let metadata_for_event = metadata.clone();
2395 match checkpointer
2396 .put(&self.runnable_config, checkpoint, metadata)
2397 .await
2398 {
2399 Ok(updated_config) => {
2400 self.runnable_config.checkpoint_id = updated_config.checkpoint_id;
2401 self.emit_counter("juncture.checkpoint.writes", 1);
2403 tracing::info!(
2404 name: "juncture.checkpoint.put",
2405 checkpoint_id = %self.runnable_config.checkpoint_id.as_deref().unwrap_or("unknown"),
2406 checkpoint_step = self.step,
2407 checkpoint_source = "Loop",
2408 "Exit checkpoint persisted"
2409 );
2410 if let Some(ref cp_id) = self.runnable_config.checkpoint_id {
2411 self.on_checkpoint_saved(cp_id, self.step);
2412 Self::emit_checkpoint_saved_event(
2414 self.stream_tx.as_ref(),
2415 cp_id.clone(),
2416 metadata_for_event,
2417 self.step,
2418 );
2419 }
2420 }
2421 Err(err) => {
2422 tracing::warn!(
2423 name: "juncture.checkpoint.exit.save_failed",
2424 step = self.step,
2425 error = %err,
2426 "Failed to save exit checkpoint"
2427 );
2428 self.emit_counter("juncture.checkpoint.errors", 1);
2430 }
2431 }
2432 }
2433
2434 fn update_delta_counters(&mut self, changed: &crate::FieldsChanged) {
2445 let field_names = S::field_names();
2446 let num_fields = field_names.len().min(self.field_versions.len());
2447
2448 for field_idx in 0..num_fields {
2449 let channel_name = format!("field_{field_idx}");
2450 let entry = self.delta_counters.entry(channel_name).or_default();
2451
2452 entry.supersteps = entry.supersteps.saturating_add(1);
2454
2455 if changed.has_field(field_idx) {
2457 entry.updates = entry.updates.saturating_add(1);
2458 }
2459 }
2460 }
2461
2462 fn build_checkpoint_delta_counters(&self) -> HashMap<String, DeltaCounters> {
2467 self.delta_counters.clone()
2468 }
2469
2470 fn should_take_full_snapshot(&self) -> bool {
2477 let specs = S::delta_channel_specs();
2478 if specs.is_empty() {
2479 return true;
2482 }
2483
2484 for &(field_idx, frequency) in specs {
2485 let channel_name = format!("field_{field_idx}");
2486 if let Some(counters) = self.delta_counters.get(&channel_name)
2487 && counters.exceeds_frequency(frequency)
2488 {
2489 return true;
2490 }
2491 }
2492
2493 false
2494 }
2495
2496 fn reset_delta_counters(&mut self) {
2498 self.delta_counters.clear();
2499 }
2500
2501 #[inline]
2507 fn emit_counter(&self, name: &str, value: u64) {
2508 if let Some(ref collector) = self.runnable_config.metrics_collector {
2509 collector.inc_counter(name, value);
2510 }
2511 }
2512
2513 #[inline]
2515 fn emit_histogram(&self, name: &str, value: f64) {
2516 if let Some(ref collector) = self.runnable_config.metrics_collector {
2517 collector.record_histogram(name, value);
2518 }
2519 }
2520
2521 #[inline]
2523 fn emit_gauge(&self, name: &str, value: u64) {
2524 if let Some(ref collector) = self.runnable_config.metrics_collector {
2525 collector.set_gauge(name, value);
2526 }
2527 }
2528
2529 #[inline]
2536 fn on_graph_end(&self, result: &Result<(), JunctureError>) {
2537 let (total_tokens, cost_usd) = self.budget_tracker.as_ref().map_or((0, 0.0), |tracker| {
2539 let usage = tracker.current_usage();
2540 (usage.tokens_used, usage.cost_usd)
2541 });
2542
2543 let success = result.is_ok();
2544 let span = tracing::info_span!(
2545 "juncture.graph.complete",
2546 total_steps = self.step,
2547 total_tokens = total_tokens,
2548 cost_usd = cost_usd,
2549 success = success,
2550 );
2551 let _enter = span.enter();
2552
2553 tracing::info!("Graph execution completed");
2554
2555 if let Some(ref handler) = self.runnable_config.callback_handler {
2556 handler.on_graph_end(result);
2557 }
2558 }
2559
2560 #[inline]
2562 fn on_checkpoint_saved(&self, checkpoint_id: &str, step: usize) {
2563 if let Some(ref handler) = self.runnable_config.callback_handler {
2564 handler.on_checkpoint_saved(checkpoint_id, step);
2565 }
2566 }
2567
2568 #[inline]
2570 fn emit_checkpoint_saved_event(
2571 stream_tx: Option<&mpsc::UnboundedSender<StreamEvent<S>>>,
2572 checkpoint_id: String,
2573 metadata: CheckpointMetadata,
2574 step: usize,
2575 ) {
2576 if let Some(tx) = stream_tx {
2577 let _ = tx.send(StreamEvent::CheckpointSaved {
2578 checkpoint_id,
2579 metadata,
2580 step,
2581 });
2582 }
2583 }
2584}
2585
2586#[cfg(test)]
2587mod tests {
2588 use super::*;
2589 use crate::state::FieldVersions;
2590 use crate::{
2591 Command,
2592 node::IntoNode,
2593 node::NodeFnCommand,
2594 pregel::types::{TaskOutput, TaskTrigger},
2595 };
2596 use chrono::Utc;
2597
/// A loop can be constructed from a single node and an empty trigger table.
#[test]
fn test_pregel_loop_creation() {
    type EndFuture = std::pin::Pin<
        Box<
            dyn std::future::Future<
                    Output = Result<crate::Command<TestState>, crate::JunctureError>,
                > + Send,
        >,
    >;

    let end_node = NodeFnCommand(|_state: &TestState| -> EndFuture {
        Box::pin(async move { Ok(crate::Command::end()) })
    })
    .into_node("test_node");

    let mut nodes = IndexMap::new();
    nodes.insert("test_node".to_string(), end_node);

    let built = PregelLoop::new(
        TestState,
        nodes,
        TriggerTable::new(),
        crate::config::RunnableConfig::new(),
        0,
    );
    built.unwrap();
}
2622
/// Bumping a field assigns it the next global version.
#[test]
fn test_field_version_tracker() {
    let mut versions = FieldVersionTracker::new(5);

    // Nothing has been bumped yet.
    assert_eq!(versions.get(0), 0);
    assert_eq!(versions.global_max(), 0);

    // Each bump is checked against both the field's version and the global maximum.
    for (field, expected) in [(0, 1), (2, 2)] {
        versions.bump(field);
        assert_eq!(versions.get(field), expected);
        assert_eq!(versions.global_max(), expected);
    }
}
2638
/// A node activates only when a watched field has a version it has not consumed yet.
#[test]
fn test_versions_seen() {
    let node_names = vec!["node_a".to_string(), "node_b".to_string()];
    let mut seen = VersionsSeen::new(&node_names, 3);

    // All versions are zero: nothing new to react to.
    assert!(!seen.should_activate("node_a", &[0], &[0, 0, 0]));

    // Field 0 advanced, so the node watching it must activate.
    let current = vec![1, 0, 0];
    assert!(seen.should_activate("node_a", &[0], &current));

    // Once consumed, the same versions no longer activate the node.
    seen.mark_consumed("node_a", &current);
    assert!(!seen.should_activate("node_a", &[0], &current));
}
2652
/// `request_drain` flips the drain flag from unset to set.
#[test]
fn test_run_control() {
    let control = RunControl::new();
    let before = control.is_drain_requested();
    assert!(!before);

    control.request_drain();
    let after = control.is_drain_requested();
    assert!(after);
}
2661
/// A default-constructed `RunControl` has no drain requested.
#[test]
fn test_run_control_default() {
    let drain_requested = RunControl::default().is_drain_requested();
    assert!(!drain_requested);
}
2667
/// A bubbled-up interrupt stops the loop and is recorded as the pending interrupt.
#[test]
fn test_handle_bubble_up_interrupt_sets_status() {
    type EndFuture = std::pin::Pin<
        Box<
            dyn std::future::Future<
                    Output = Result<crate::Command<TestState>, crate::JunctureError>,
                > + Send,
        >,
    >;

    let end_node = NodeFnCommand(|_state: &TestState| -> EndFuture {
        Box::pin(async move { Ok(crate::Command::end()) })
    })
    .into_node("test_node");

    let mut nodes = IndexMap::new();
    nodes.insert("test_node".to_string(), end_node);

    let mut loop_ = PregelLoop::new(
        TestState,
        nodes,
        TriggerTable::new(),
        crate::config::RunnableConfig::new(),
        0,
    )
    .unwrap();

    let subgraph_interrupt = crate::pregel::types::GraphInterrupt {
        interrupts: vec![crate::interrupt::InterruptSignal {
            index: 0,
            id: Some("sub-int-0".to_string()),
            payload: serde_json::json!({"node": "subgraph_node"}),
            timestamp: Utc::now(),
        }],
        step: 2,
        namespace: vec![],
    };
    let bubble_ups = vec![BubbleUp::Interrupt(subgraph_interrupt)];

    let should_stop = loop_.handle_bubble_ups(&bubble_ups);

    assert!(should_stop);
    assert!(loop_.status.is_interrupted());
    assert_eq!(loop_.pending_interrupts.len(), 1);
    assert_eq!(loop_.pending_interrupts[0].id.as_deref(), Some("sub-int-0"));
}
2710
/// A bubbled-up drain stops the loop and leaves it in the terminal `Drained` status.
#[test]
fn test_handle_bubble_up_drained_sets_status() {
    type EndFuture = std::pin::Pin<
        Box<
            dyn std::future::Future<
                    Output = Result<crate::Command<TestState>, crate::JunctureError>,
                > + Send,
        >,
    >;

    let end_node = NodeFnCommand(|_state: &TestState| -> EndFuture {
        Box::pin(async move { Ok(crate::Command::end()) })
    })
    .into_node("test_node");

    let mut nodes = IndexMap::new();
    nodes.insert("test_node".to_string(), end_node);

    let mut loop_ = PregelLoop::new(
        TestState,
        nodes,
        TriggerTable::new(),
        crate::config::RunnableConfig::new(),
        0,
    )
    .unwrap();

    let drained = crate::pregel::types::GraphDrained {
        reason: "subgraph completed".to_string(),
    };
    let bubble_ups = vec![BubbleUp::Drained(drained)];

    let should_stop = loop_.handle_bubble_ups(&bubble_ups);

    assert!(should_stop);
    assert!(loop_.status.is_terminal());
    assert!(matches!(loop_.status, LoopStatus::Drained));
}
2744
/// A parent command bubbling up neither stops the loop nor changes its running status.
#[test]
fn test_handle_bubble_up_parent_command_does_not_stop() {
    type EndFuture = std::pin::Pin<
        Box<
            dyn std::future::Future<
                    Output = Result<crate::Command<TestState>, crate::JunctureError>,
                > + Send,
        >,
    >;

    let end_node = NodeFnCommand(|_state: &TestState| -> EndFuture {
        Box::pin(async move { Ok(crate::Command::end()) })
    })
    .into_node("test_node");

    let mut nodes = IndexMap::new();
    nodes.insert("test_node".to_string(), end_node);

    let mut loop_ = PregelLoop::new(
        TestState,
        nodes,
        TriggerTable::new(),
        crate::config::RunnableConfig::new(),
        0,
    )
    .unwrap();

    let bubble_ups = vec![BubbleUp::ParentCommand(
        crate::command::ParentCommand::from_subgraph(
            Command::end(),
            "test_subgraph_node",
            "test_namespace",
        ),
    )];

    let should_stop = loop_.handle_bubble_ups(&bubble_ups);

    assert!(!should_stop);
    assert!(loop_.status.is_running());
}
2780
/// With no bubble-ups to process, the loop keeps running and is not asked to stop.
#[test]
fn test_handle_bubble_up_empty_does_nothing() {
    type EndFuture = std::pin::Pin<
        Box<
            dyn std::future::Future<
                    Output = Result<crate::Command<TestState>, crate::JunctureError>,
                > + Send,
        >,
    >;

    let end_node = NodeFnCommand(|_state: &TestState| -> EndFuture {
        Box::pin(async move { Ok(crate::Command::end()) })
    })
    .into_node("test_node");

    let mut nodes = IndexMap::new();
    nodes.insert("test_node".to_string(), end_node);

    let mut loop_ = PregelLoop::new(
        TestState,
        nodes,
        TriggerTable::new(),
        crate::config::RunnableConfig::new(),
        0,
    )
    .unwrap();

    let should_stop = loop_.handle_bubble_ups(&[]);

    assert!(!should_stop);
    assert!(loop_.status.is_running());
}
2809
/// When a drain is followed by an interrupt in one batch, the loop ends up interrupted.
#[test]
fn test_handle_bubble_up_interrupt_takes_priority_over_drain() {
    type EndFuture = std::pin::Pin<
        Box<
            dyn std::future::Future<
                    Output = Result<crate::Command<TestState>, crate::JunctureError>,
                > + Send,
        >,
    >;

    let end_node = NodeFnCommand(|_state: &TestState| -> EndFuture {
        Box::pin(async move { Ok(crate::Command::end()) })
    })
    .into_node("test_node");

    let mut nodes = IndexMap::new();
    nodes.insert("test_node".to_string(), end_node);

    let mut loop_ = PregelLoop::new(
        TestState,
        nodes,
        TriggerTable::new(),
        crate::config::RunnableConfig::new(),
        0,
    )
    .unwrap();

    let drained = BubbleUp::Drained(crate::pregel::types::GraphDrained {
        reason: "drained".to_string(),
    });
    let interrupted = BubbleUp::Interrupt(crate::pregel::types::GraphInterrupt {
        interrupts: vec![crate::interrupt::InterruptSignal {
            index: 0,
            id: None,
            payload: serde_json::Value::Null,
            timestamp: Utc::now(),
        }],
        step: 1,
        namespace: vec![],
    });
    // Order matters for this scenario: the drain is processed before the interrupt.
    let bubble_ups = vec![drained, interrupted];

    let should_stop = loop_.handle_bubble_ups(&bubble_ups);

    assert!(should_stop);
    assert!(loop_.status.is_interrupted());
}
2855
2856 #[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
2857 struct TestState;
2858
2859 impl State for TestState {
2860 type Update = TestUpdate;
2861 type FieldVersions = FieldVersions;
2862
2863 fn apply(&mut self, _: Self::Update) -> crate::FieldsChanged {
2864 crate::FieldsChanged(0)
2865 }
2866
2867 fn reset_ephemeral(&mut self) {}
2868 }
2869
2870 #[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
2871 struct TestUpdate;
2872
2873 #[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
2877 struct DeltaTestState {
2878 value: i32,
2879 messages: Vec<String>,
2880 }
2881
2882 #[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
2883 struct DeltaTestUpdate {
2884 value: Option<i32>,
2885 messages: Option<Vec<String>>,
2886 }
2887
2888 impl State for DeltaTestState {
2889 type Update = DeltaTestUpdate;
2890 type FieldVersions = FieldVersions;
2891
2892 fn apply(&mut self, update: Self::Update) -> crate::FieldsChanged {
2893 let mut changed = crate::FieldsChanged(0);
2894 if let Some(v) = update.value {
2895 self.value = v;
2896 changed.set_field(0);
2897 }
2898 if let Some(msgs) = update.messages {
2899 self.messages.extend(msgs);
2900 changed.set_field(1);
2901 }
2902 changed
2903 }
2904
2905 fn reset_ephemeral(&mut self) {}
2906
2907 fn field_names() -> &'static [&'static str] {
2908 &["value", "messages"]
2909 }
2910
2911 fn field_count() -> usize {
2912 2
2913 }
2914
2915 fn delta_channel_specs() -> &'static [(usize, usize)] {
2917 &[(1, 3)]
2918 }
2919 }
2920
2921 struct CapturingCheckpointer {
2923 captured: Arc<std::sync::Mutex<Option<crate::checkpoint::Checkpoint>>>,
2924 }
2925
2926 #[async_trait::async_trait]
2927 impl crate::checkpoint::CheckpointSaver for CapturingCheckpointer {
2928 async fn get_tuple(
2929 &self,
2930 _: &crate::config::RunnableConfig,
2931 ) -> Result<Option<crate::checkpoint::CheckpointTuple>, crate::checkpoint::CheckpointError>
2932 {
2933 Ok(None)
2934 }
2935
2936 async fn list(
2937 &self,
2938 _: &crate::config::RunnableConfig,
2939 _: Option<crate::checkpoint::CheckpointFilter>,
2940 ) -> Result<Vec<crate::checkpoint::CheckpointTuple>, crate::checkpoint::CheckpointError>
2941 {
2942 Ok(Vec::new())
2943 }
2944
2945 async fn put(
2946 &self,
2947 _: &crate::config::RunnableConfig,
2948 checkpoint: crate::checkpoint::Checkpoint,
2949 _metadata: crate::checkpoint::CheckpointMetadata,
2950 ) -> Result<crate::config::RunnableConfig, crate::checkpoint::CheckpointError> {
2951 *self
2952 .captured
2953 .lock()
2954 .unwrap_or_else(std::sync::PoisonError::into_inner) = Some(checkpoint);
2955 let mut cfg = crate::config::RunnableConfig::new();
2956 cfg.checkpoint_id = Some("cp-capture".to_string());
2957 Ok(cfg)
2958 }
2959
2960 async fn put_writes(
2961 &self,
2962 _: &crate::config::RunnableConfig,
2963 _: Vec<crate::checkpoint::PendingWrite>,
2964 _: &str,
2965 ) -> Result<(), crate::checkpoint::CheckpointError> {
2966 Ok(())
2967 }
2968 }
2969
2970 #[tokio::test]
2972 async fn test_delta_counters_increment_on_field_change() {
2973 let state = DeltaTestState {
2974 value: 0,
2975 messages: vec![],
2976 };
2977 let mut nodes = IndexMap::new();
2978 nodes.insert(
2979 "test_node".to_string(),
2980 NodeFnCommand(
2981 |_s: &DeltaTestState| -> std::pin::Pin<
2982 Box<
2983 dyn std::future::Future<
2984 Output = Result<
2985 crate::Command<DeltaTestState>,
2986 crate::JunctureError,
2987 >,
2988 > + Send,
2989 >,
2990 > { Box::pin(async move { Ok(crate::Command::end()) }) },
2991 )
2992 .into_node("test_node"),
2993 );
2994 let trigger_table = TriggerTable::new();
2995 let config = crate::config::RunnableConfig::new();
2996
2997 let mut loop_ = PregelLoop::new(state, nodes, trigger_table, config, 2).unwrap();
2998
2999 let changed = crate::FieldsChanged(0b11); loop_.update_delta_counters(&changed);
3002
3003 assert_eq!(loop_.delta_counters.len(), 2, "should track both fields");
3004
3005 let field_0 = loop_
3006 .delta_counters
3007 .get("field_0")
3008 .expect("field_0 should exist");
3009 assert_eq!(field_0.updates, 1, "field_0 should have 1 update");
3010 assert_eq!(field_0.supersteps, 1, "field_0 should have 1 superstep");
3011
3012 let field_1 = loop_
3013 .delta_counters
3014 .get("field_1")
3015 .expect("field_1 should exist");
3016 assert_eq!(field_1.updates, 1, "field_1 should have 1 update");
3017 assert_eq!(field_1.supersteps, 1, "field_1 should have 1 superstep");
3018 }
3019
3020 #[tokio::test]
3022 async fn test_delta_counters_increment_unchanged_fields_get_superstep_only() {
3023 let state = DeltaTestState {
3024 value: 0,
3025 messages: vec![],
3026 };
3027 let mut nodes = IndexMap::new();
3028 nodes.insert(
3029 "test_node".to_string(),
3030 NodeFnCommand(
3031 |_s: &DeltaTestState| -> std::pin::Pin<
3032 Box<
3033 dyn std::future::Future<
3034 Output = Result<
3035 crate::Command<DeltaTestState>,
3036 crate::JunctureError,
3037 >,
3038 > + Send,
3039 >,
3040 > { Box::pin(async move { Ok(crate::Command::end()) }) },
3041 )
3042 .into_node("test_node"),
3043 );
3044 let trigger_table = TriggerTable::new();
3045 let config = crate::config::RunnableConfig::new();
3046
3047 let mut loop_ = PregelLoop::new(state, nodes, trigger_table, config, 2).unwrap();
3048
3049 let changed = crate::FieldsChanged(0b01);
3051 loop_.update_delta_counters(&changed);
3052
3053 let field_0 = loop_
3054 .delta_counters
3055 .get("field_0")
3056 .expect("field_0 should exist");
3057 assert_eq!(field_0.updates, 1, "field_0 should have 1 update");
3058
3059 let field_1 = loop_
3060 .delta_counters
3061 .get("field_1")
3062 .expect("field_1 should exist");
3063 assert_eq!(
3064 field_1.updates, 0,
3065 "field_1 should have 0 updates (not changed)"
3066 );
3067 assert_eq!(
3068 field_1.supersteps, 1,
3069 "field_1 should still have 1 superstep"
3070 );
3071 }
3072
3073 #[tokio::test]
3075 async fn test_delta_counters_accumulate_across_supersteps() {
3076 let state = DeltaTestState {
3077 value: 0,
3078 messages: vec![],
3079 };
3080 let mut nodes = IndexMap::new();
3081 nodes.insert(
3082 "test_node".to_string(),
3083 NodeFnCommand(
3084 |_s: &DeltaTestState| -> std::pin::Pin<
3085 Box<
3086 dyn std::future::Future<
3087 Output = Result<
3088 crate::Command<DeltaTestState>,
3089 crate::JunctureError,
3090 >,
3091 > + Send,
3092 >,
3093 > { Box::pin(async move { Ok(crate::Command::end()) }) },
3094 )
3095 .into_node("test_node"),
3096 );
3097 let trigger_table = TriggerTable::new();
3098 let config = crate::config::RunnableConfig::new();
3099
3100 let mut loop_ = PregelLoop::new(state, nodes, trigger_table, config, 2).unwrap();
3101
3102 loop_.update_delta_counters(&crate::FieldsChanged(0b01));
3104 loop_.update_delta_counters(&crate::FieldsChanged(0b11));
3106
3107 let field_0 = loop_
3108 .delta_counters
3109 .get("field_0")
3110 .expect("field_0 should exist");
3111 assert_eq!(field_0.updates, 2, "field_0 updated in both supersteps");
3112 assert_eq!(field_0.supersteps, 2, "field_0 has 2 supersteps");
3113
3114 let field_1 = loop_
3115 .delta_counters
3116 .get("field_1")
3117 .expect("field_1 should exist");
3118 assert_eq!(
3119 field_1.updates, 1,
3120 "field_1 updated in only second superstep"
3121 );
3122 assert_eq!(field_1.supersteps, 2, "field_1 has 2 supersteps");
3123 }
3124
3125 #[tokio::test]
3127 async fn test_delta_counters_populated_in_checkpoint_and_reset() {
3128 let state = DeltaTestState {
3129 value: 0,
3130 messages: vec![],
3131 };
3132 let mut nodes = IndexMap::new();
3133 nodes.insert(
3134 "test_node".to_string(),
3135 NodeFnCommand(
3136 |_s: &DeltaTestState| -> std::pin::Pin<
3137 Box<
3138 dyn std::future::Future<
3139 Output = Result<
3140 crate::Command<DeltaTestState>,
3141 crate::JunctureError,
3142 >,
3143 > + Send,
3144 >,
3145 > { Box::pin(async move { Ok(crate::Command::end()) }) },
3146 )
3147 .into_node("test_node"),
3148 );
3149 let trigger_table = TriggerTable::new();
3150 let mut config = crate::config::RunnableConfig::new();
3151 config.thread_id = Some("test-thread".to_string());
3152
3153 let mut loop_ = PregelLoop::new(state, nodes, trigger_table, config, 2).unwrap();
3154
3155 let captured: Arc<std::sync::Mutex<Option<crate::checkpoint::Checkpoint>>> =
3156 Arc::new(std::sync::Mutex::new(None));
3157 let checkpointer = CapturingCheckpointer {
3158 captured: Arc::clone(&captured),
3159 };
3160 loop_.set_checkpointer(Arc::new(checkpointer));
3161
3162 loop_.delta_counters.insert(
3167 "field_0".to_string(),
3168 DeltaCounters {
3169 updates: 1,
3170 supersteps: 1,
3171 },
3172 );
3173 loop_.delta_counters.insert(
3176 "field_1".to_string(),
3177 DeltaCounters {
3178 updates: 3,
3179 supersteps: 1,
3180 },
3181 );
3182
3183 loop_.pending_tasks = vec![PendingTask::pull(
3186 uuid::Uuid::new_v4().to_string(),
3187 "test_node".to_string(),
3188 )];
3189 let _ = loop_.execute_superstep().await;
3190 let _ = loop_.after_tick(SuperstepResult::empty()).await;
3191
3192 let checkpoint = captured
3194 .lock()
3195 .unwrap_or_else(std::sync::PoisonError::into_inner)
3196 .take()
3197 .expect("checkpoint should have been saved");
3198 assert!(
3199 !checkpoint.counters_since_delta_snapshot.is_empty(),
3200 "counters_since_delta_snapshot should be populated"
3201 );
3202 let field_0 = checkpoint
3203 .counters_since_delta_snapshot
3204 .get("field_0")
3205 .expect("field_0 should be in delta counters");
3206 assert_eq!(
3209 field_0.updates, 1,
3210 "field_0 should have 1 update in checkpoint"
3211 );
3212 assert_eq!(
3213 field_0.supersteps, 2,
3214 "field_0 should have 2 supersteps in checkpoint"
3215 );
3216
3217 let field_1 = checkpoint
3218 .counters_since_delta_snapshot
3219 .get("field_1")
3220 .expect("field_1 should be in delta counters");
3221 assert_eq!(
3222 field_1.updates, 3,
3223 "field_1 should have 3 updates in checkpoint"
3224 );
3225
3226 assert!(
3228 loop_.delta_counters.is_empty(),
3229 "delta counters should be reset after checkpoint save"
3230 );
3231 }
3232
/// For a `TestState` loop, `should_take_full_snapshot` is expected to return
/// `true` both before and after a large counter entry is recorded.
#[test]
fn test_should_take_full_snapshot_no_delta_channels() {
    // Boxed, sendable future returned by the no-op node closure.
    type NodeFuture = std::pin::Pin<
        Box<
            dyn std::future::Future<
                    Output = Result<crate::Command<TestState>, crate::JunctureError>,
                > + Send,
        >,
    >;

    let noop_node = NodeFnCommand(|_state: &TestState| -> NodeFuture {
        Box::pin(async move { Ok(crate::Command::end()) })
    })
    .into_node("test_node");
    let mut node_map = IndexMap::new();
    node_map.insert("test_node".to_string(), noop_node);

    let mut pregel = PregelLoop::new(
        TestState,
        node_map,
        TriggerTable::new(),
        crate::config::RunnableConfig::new(),
        0,
    )
    .unwrap();

    assert!(
        pregel.should_take_full_snapshot(),
        "should always take full snapshot with no delta channels"
    );

    // Even a heavily-updated counter must not change the answer.
    let heavy = DeltaCounters {
        updates: 100,
        supersteps: 50,
    };
    pregel.delta_counters.insert("field_0".to_string(), heavy);
    assert!(
        pregel.should_take_full_snapshot(),
        "still full snapshot when specs are empty (no delta optimization)"
    );
}
3276
/// Overwrites the `field_1` counter with values below, at, and above 3
/// updates and checks `should_take_full_snapshot` flips from `false` to
/// `true` at 3.
#[test]
fn test_should_take_full_snapshot_respects_frequency() {
    // Boxed, sendable future returned by the no-op node closure.
    type NodeFuture = std::pin::Pin<
        Box<
            dyn std::future::Future<
                    Output = Result<crate::Command<DeltaTestState>, crate::JunctureError>,
                > + Send,
        >,
    >;

    let noop_node = NodeFnCommand(|_state: &DeltaTestState| -> NodeFuture {
        Box::pin(async move { Ok(crate::Command::end()) })
    })
    .into_node("test_node");
    let mut node_map = IndexMap::new();
    node_map.insert("test_node".to_string(), noop_node);

    let initial = DeltaTestState {
        value: 0,
        messages: Vec::new(),
    };
    let mut pregel = PregelLoop::new(
        initial,
        node_map,
        TriggerTable::new(),
        crate::config::RunnableConfig::new(),
        2,
    )
    .unwrap();

    // (updates, supersteps, expected answer, failure message)
    let cases = [
        (
            2,
            2,
            false,
            "should not take full snapshot below frequency threshold",
        ),
        (
            3,
            3,
            true,
            "should take full snapshot at frequency threshold",
        ),
        (
            10,
            5,
            true,
            "should take full snapshot above frequency threshold",
        ),
    ];

    for (updates, supersteps, expect_full, why) in cases {
        // Each insert replaces the previous `field_1` entry.
        pregel.delta_counters.insert(
            "field_1".to_string(),
            DeltaCounters {
                updates,
                supersteps,
            },
        );
        assert!(pregel.should_take_full_snapshot() == expect_full, "{why}");
    }
}
3345
/// Exercises `DeltaCounters::exceeds_frequency` directly: fresh counters
/// start at zero, frequency 0 is always exceeded, and with frequency 3 the
/// result turns `true` once `updates` reaches 3.
#[test]
fn test_delta_counters_exceeds_frequency() {
    let fresh = DeltaCounters::new();
    assert_eq!(fresh.updates, 0);
    assert_eq!(fresh.supersteps, 0);

    assert!(
        fresh.exceeds_frequency(0),
        "frequency 0 always snapshots"
    );

    let below = DeltaCounters {
        updates: 2,
        supersteps: 1,
    };
    assert!(!below.exceeds_frequency(3), "2 < 3, not exceeded");

    let at_threshold = DeltaCounters {
        updates: 3,
        supersteps: 1,
    };
    assert!(at_threshold.exceeds_frequency(3), "3 >= 3, exceeded");

    let above = DeltaCounters {
        updates: 10,
        supersteps: 1,
    };
    assert!(above.exceeds_frequency(3), "10 >= 3, exceeded");
}
3380
3381 #[tokio::test]
3388 async fn test_scratchpad_populated_after_execute_superstep() {
3389 let state = TestState;
3390
3391 let mut nodes = IndexMap::new();
3392 nodes.insert(
3393 "test_node".to_string(),
3394 NodeFnCommand(
3395 |_s: &TestState| -> std::pin::Pin<
3396 Box<
3397 dyn std::future::Future<
3398 Output = Result<crate::Command<TestState>, crate::JunctureError>,
3399 > + Send,
3400 >,
3401 > { Box::pin(async move { Ok(crate::Command::end()) }) },
3402 )
3403 .into_node("test_node"),
3404 );
3405
3406 let trigger_table = TriggerTable::new();
3407 let config = crate::config::RunnableConfig::new();
3408
3409 let mut loop_ = PregelLoop::new(state, nodes, trigger_table, config, 0).unwrap();
3410
3411 loop_.pending_tasks = vec![PendingTask::pull(
3413 uuid::Uuid::new_v4().to_string(),
3414 "test_node".to_string(),
3415 )];
3416 loop_.pending_interrupts = vec![
3417 crate::interrupt::InterruptSignal {
3418 index: 0,
3419 id: Some("int-alpha".to_string()),
3420 payload: serde_json::Value::Null,
3421 timestamp: Utc::now(),
3422 },
3423 crate::interrupt::InterruptSignal {
3424 index: 1,
3425 id: Some("int-beta".to_string()),
3426 payload: serde_json::Value::Null,
3427 timestamp: Utc::now(),
3428 },
3429 ];
3430
3431 assert!(
3433 !loop_.scratchpad.is_interrupt_processed("int-alpha"),
3434 "scratchpad should be empty before superstep"
3435 );
3436 assert!(
3437 !loop_.scratchpad.is_interrupt_processed("int-beta"),
3438 "scratchpad should be empty before superstep"
3439 );
3440
3441 let result = loop_.execute_superstep().await;
3442 assert!(result.is_ok(), "execute_superstep should succeed");
3443
3444 assert!(
3446 loop_.scratchpad.is_interrupt_processed("int-alpha"),
3447 "int-alpha should be marked as processed after superstep"
3448 );
3449 assert!(
3450 loop_.scratchpad.is_interrupt_processed("int-beta"),
3451 "int-beta should be marked as processed after superstep"
3452 );
3453 assert!(
3454 !loop_.scratchpad.is_interrupt_processed("int-gamma"),
3455 "unrelated interrupt should not be marked as processed"
3456 );
3457 }
3458
3459 #[tokio::test]
3462 async fn test_scratchpad_accumulates_across_supersteps() {
3463 let state = TestState;
3464
3465 let mut nodes = IndexMap::new();
3466 nodes.insert(
3467 "test_node".to_string(),
3468 NodeFnCommand(
3469 |_s: &TestState| -> std::pin::Pin<
3470 Box<
3471 dyn std::future::Future<
3472 Output = Result<crate::Command<TestState>, crate::JunctureError>,
3473 > + Send,
3474 >,
3475 > { Box::pin(async move { Ok(crate::Command::end()) }) },
3476 )
3477 .into_node("test_node"),
3478 );
3479
3480 let trigger_table = TriggerTable::new();
3481 let config = crate::config::RunnableConfig::new();
3482
3483 let mut loop_ = PregelLoop::new(state, nodes, trigger_table, config, 0).unwrap();
3484
3485 loop_.pending_tasks = vec![PendingTask::pull(
3487 uuid::Uuid::new_v4().to_string(),
3488 "test_node".to_string(),
3489 )];
3490 loop_.pending_interrupts = vec![crate::interrupt::InterruptSignal {
3491 index: 0,
3492 id: Some("int-1".to_string()),
3493 payload: serde_json::Value::Null,
3494 timestamp: Utc::now(),
3495 }];
3496
3497 let _ = loop_.execute_superstep().await;
3498 let _ = loop_.after_tick(SuperstepResult::empty()).await;
3499
3500 loop_.pending_tasks = vec![PendingTask::pull(
3502 uuid::Uuid::new_v4().to_string(),
3503 "test_node".to_string(),
3504 )];
3505 loop_.pending_interrupts = vec![crate::interrupt::InterruptSignal {
3506 index: 0,
3507 id: Some("int-2".to_string()),
3508 payload: serde_json::Value::Null,
3509 timestamp: Utc::now(),
3510 }];
3511
3512 let _ = loop_.execute_superstep().await;
3513
3514 assert!(
3516 loop_.scratchpad.is_interrupt_processed("int-1"),
3517 "int-1 from first superstep should still be tracked"
3518 );
3519 assert!(
3520 loop_.scratchpad.is_interrupt_processed("int-2"),
3521 "int-2 from second superstep should be tracked"
3522 );
3523 }
3524
3525 #[derive(Clone, Debug, PartialEq, Eq)]
3529 enum ObservedCall {
3530 Put {
3531 source: crate::checkpoint::CheckpointSource,
3532 step: i64,
3533 },
3534 }
3535
3536 struct TrackingCheckpointer {
3538 observed: Arc<std::sync::Mutex<Vec<ObservedCall>>>,
3539 }
3540
3541 #[async_trait::async_trait]
3542 impl crate::checkpoint::CheckpointSaver for TrackingCheckpointer {
3543 async fn get_tuple(
3544 &self,
3545 _: &crate::config::RunnableConfig,
3546 ) -> Result<Option<crate::checkpoint::CheckpointTuple>, crate::checkpoint::CheckpointError>
3547 {
3548 Ok(None)
3549 }
3550
3551 async fn list(
3552 &self,
3553 _: &crate::config::RunnableConfig,
3554 _: Option<crate::checkpoint::CheckpointFilter>,
3555 ) -> Result<Vec<crate::checkpoint::CheckpointTuple>, crate::checkpoint::CheckpointError>
3556 {
3557 Ok(Vec::new())
3558 }
3559
3560 async fn put(
3561 &self,
3562 _: &crate::config::RunnableConfig,
3563 _checkpoint: crate::checkpoint::Checkpoint,
3564 metadata: crate::checkpoint::CheckpointMetadata,
3565 ) -> Result<crate::config::RunnableConfig, crate::checkpoint::CheckpointError> {
3566 self.observed
3567 .lock()
3568 .unwrap_or_else(std::sync::PoisonError::into_inner)
3569 .push(ObservedCall::Put {
3570 source: metadata.source,
3571 step: metadata.step,
3572 });
3573 let mut cfg = crate::config::RunnableConfig::new();
3574 cfg.checkpoint_id = Some("cp-test".to_string());
3575 Ok(cfg)
3576 }
3577
3578 async fn put_writes(
3579 &self,
3580 _: &crate::config::RunnableConfig,
3581 _: Vec<crate::checkpoint::PendingWrite>,
3582 _: &str,
3583 ) -> Result<(), crate::checkpoint::CheckpointError> {
3584 Ok(())
3585 }
3586 }
3587
3588 #[tokio::test]
3591 async fn test_superstep_checkpoint_saved_on_normal_completion() {
3592 let state = TestState;
3593
3594 let mut nodes = IndexMap::new();
3595 nodes.insert(
3596 "test_node".to_string(),
3597 NodeFnCommand(
3598 |_s: &TestState| -> std::pin::Pin<
3599 Box<
3600 dyn std::future::Future<
3601 Output = Result<crate::Command<TestState>, crate::JunctureError>,
3602 > + Send,
3603 >,
3604 > { Box::pin(async move { Ok(crate::Command::end()) }) },
3605 )
3606 .into_node("test_node"),
3607 );
3608
3609 let trigger_table = TriggerTable::new();
3610 let mut config = crate::config::RunnableConfig::new();
3611 config.thread_id = Some("test-thread".to_string());
3612
3613 let mut loop_ = PregelLoop::new(state, nodes, trigger_table, config, 0).unwrap();
3614
3615 let observed = Arc::new(std::sync::Mutex::new(Vec::new()));
3616 let checkpointer = TrackingCheckpointer {
3617 observed: Arc::clone(&observed),
3618 };
3619 loop_.set_checkpointer(Arc::new(checkpointer));
3620
3621 loop_.pending_tasks = vec![PendingTask::pull(
3623 uuid::Uuid::new_v4().to_string(),
3624 "test_node".to_string(),
3625 )];
3626
3627 let _ = loop_.execute_superstep().await;
3628 let _ = loop_.after_tick(SuperstepResult::empty()).await;
3629
3630 let has_loop_checkpoint = {
3632 let calls = observed
3633 .lock()
3634 .unwrap_or_else(std::sync::PoisonError::into_inner);
3635 calls.iter().any(|c| {
3636 matches!(
3637 c,
3638 ObservedCall::Put {
3639 source: crate::checkpoint::CheckpointSource::Loop,
3640 step: 0,
3641 }
3642 )
3643 })
3644 };
3645 assert!(has_loop_checkpoint, "expected a Loop checkpoint at step 0");
3646 }
3647
3648 #[tokio::test]
3651 async fn test_superstep_checkpoint_step_increments() {
3652 let state = TestState;
3653
3654 let mut nodes = IndexMap::new();
3655 nodes.insert(
3656 "test_node".to_string(),
3657 NodeFnCommand(
3658 |_s: &TestState| -> std::pin::Pin<
3659 Box<
3660 dyn std::future::Future<
3661 Output = Result<crate::Command<TestState>, crate::JunctureError>,
3662 > + Send,
3663 >,
3664 > { Box::pin(async move { Ok(crate::Command::end()) }) },
3665 )
3666 .into_node("test_node"),
3667 );
3668
3669 let trigger_table = TriggerTable::new();
3670 let mut config = crate::config::RunnableConfig::new();
3671 config.thread_id = Some("test-thread".to_string());
3672
3673 let mut loop_ = PregelLoop::new(state, nodes, trigger_table, config, 0).unwrap();
3674
3675 let observed = Arc::new(std::sync::Mutex::new(Vec::new()));
3676 let checkpointer = TrackingCheckpointer {
3677 observed: Arc::clone(&observed),
3678 };
3679 loop_.set_checkpointer(Arc::new(checkpointer));
3680
3681 loop_.pending_tasks = vec![PendingTask::pull(
3683 uuid::Uuid::new_v4().to_string(),
3684 "test_node".to_string(),
3685 )];
3686 let _ = loop_.execute_superstep().await;
3687 let _ = loop_.after_tick(SuperstepResult::empty()).await;
3688
3689 loop_.pending_tasks = vec![PendingTask::pull(
3691 uuid::Uuid::new_v4().to_string(),
3692 "test_node".to_string(),
3693 )];
3694 let _ = loop_.execute_superstep().await;
3695 let _ = loop_.after_tick(SuperstepResult::empty()).await;
3696
3697 let loop_steps: Vec<i64> = {
3698 let calls = observed
3699 .lock()
3700 .unwrap_or_else(std::sync::PoisonError::into_inner);
3701 calls
3702 .iter()
3703 .filter_map(|c| match c {
3704 ObservedCall::Put {
3705 source: crate::checkpoint::CheckpointSource::Loop,
3706 step,
3707 } => Some(*step),
3708 ObservedCall::Put { .. } => None,
3709 })
3710 .collect()
3711 };
3712
3713 assert_eq!(
3714 loop_steps,
3715 vec![0, 1],
3716 "expected Loop checkpoints at steps 0 and 1, got: {loop_steps:?}"
3717 );
3718 }
3719
3720 #[tokio::test]
3723 async fn test_superstep_checkpoint_noop_without_checkpointer() {
3724 let state = TestState;
3725
3726 let mut nodes = IndexMap::new();
3727 nodes.insert(
3728 "test_node".to_string(),
3729 NodeFnCommand(
3730 |_s: &TestState| -> std::pin::Pin<
3731 Box<
3732 dyn std::future::Future<
3733 Output = Result<crate::Command<TestState>, crate::JunctureError>,
3734 > + Send,
3735 >,
3736 > { Box::pin(async move { Ok(crate::Command::end()) }) },
3737 )
3738 .into_node("test_node"),
3739 );
3740
3741 let trigger_table = TriggerTable::new();
3742 let config = crate::config::RunnableConfig::new();
3743
3744 let mut loop_ = PregelLoop::new(state, nodes, trigger_table, config, 0).unwrap();
3745 assert!(
3746 loop_.checkpointer.is_none(),
3747 "no checkpointer should be configured by default"
3748 );
3749
3750 loop_.pending_tasks = vec![PendingTask::pull(
3752 uuid::Uuid::new_v4().to_string(),
3753 "test_node".to_string(),
3754 )];
3755
3756 let result = loop_.execute_superstep().await;
3757 assert!(result.is_ok(), "execute_superstep should succeed");
3758
3759 let after_result = loop_.after_tick(SuperstepResult::empty()).await;
3760 assert!(
3761 after_result.is_ok(),
3762 "after_tick should succeed without checkpointer"
3763 );
3764 }
3765
/// A loop built from a default `RunnableConfig` (no checkpoint namespace set)
/// is expected to report an empty `current_ns`.
#[test]
fn test_current_ns_empty_when_no_checkpoint_ns() {
    // Boxed, sendable future returned by the no-op node closure.
    type NodeFuture = std::pin::Pin<
        Box<
            dyn std::future::Future<
                    Output = Result<crate::Command<TestState>, crate::JunctureError>,
                > + Send,
        >,
    >;

    let noop_node = NodeFnCommand(|_state: &TestState| -> NodeFuture {
        Box::pin(async move { Ok(crate::Command::end()) })
    })
    .into_node("test_node");
    let mut node_map = IndexMap::new();
    node_map.insert("test_node".to_string(), noop_node);

    let pregel = PregelLoop::new(
        TestState,
        node_map,
        TriggerTable::new(),
        crate::config::RunnableConfig::new(),
        0,
    )
    .unwrap();

    assert!(
        pregel.current_ns().is_empty(),
        "root-level graph should have empty ns"
    );
}
3794
/// With a two-segment checkpoint namespace (`review`/`uuid-1`,
/// `detail`/`uuid-2`), `current_ns` is expected to return only the first
/// component of each segment, in order.
#[test]
fn test_current_ns_extracts_node_names_from_checkpoint_ns() {
    // Boxed, sendable future returned by the no-op node closure.
    type NodeFuture = std::pin::Pin<
        Box<
            dyn std::future::Future<
                    Output = Result<crate::Command<TestState>, crate::JunctureError>,
                > + Send,
        >,
    >;

    let noop_node = NodeFnCommand(|_state: &TestState| -> NodeFuture {
        Box::pin(async move { Ok(crate::Command::end()) })
    })
    .into_node("test_node");
    let mut node_map = IndexMap::new();
    node_map.insert("test_node".to_string(), noop_node);

    let segments = vec![
        crate::checkpoint::NamespaceSegment::new("review".to_string(), "uuid-1".to_string()),
        crate::checkpoint::NamespaceSegment::new("detail".to_string(), "uuid-2".to_string()),
    ];
    let run_config = crate::config::RunnableConfig::new()
        .with_checkpoint_ns(crate::checkpoint::CheckpointNamespace::new(segments));

    let pregel =
        PregelLoop::new(TestState, node_map, TriggerTable::new(), run_config, 0).unwrap();

    assert_eq!(pregel.current_ns(), vec!["review", "detail"]);
}
3830
/// With a single-segment checkpoint namespace (`agent`/`uuid-single`),
/// `current_ns` is expected to return just `["agent"]`.
#[test]
fn test_current_ns_single_segment() {
    // Boxed, sendable future returned by the no-op node closure.
    type NodeFuture = std::pin::Pin<
        Box<
            dyn std::future::Future<
                    Output = Result<crate::Command<TestState>, crate::JunctureError>,
                > + Send,
        >,
    >;

    let noop_node = NodeFnCommand(|_state: &TestState| -> NodeFuture {
        Box::pin(async move { Ok(crate::Command::end()) })
    })
    .into_node("test_node");
    let mut node_map = IndexMap::new();
    node_map.insert("test_node".to_string(), noop_node);

    let only_segment = crate::checkpoint::NamespaceSegment::new(
        "agent".to_string(),
        "uuid-single".to_string(),
    );
    let run_config = crate::config::RunnableConfig::new()
        .with_checkpoint_ns(crate::checkpoint::CheckpointNamespace::new(vec![only_segment]));

    let pregel =
        PregelLoop::new(TestState, node_map, TriggerTable::new(), run_config, 0).unwrap();

    assert_eq!(pregel.current_ns(), vec!["agent"]);
}
3862
/// Feeds one bubbled-up interrupt through `handle_bubble_ups` on a loop whose
/// checkpoint namespace is `review`, and expects the emitted
/// `StreamEvent::Interrupt` to carry `ns == ["review"]`.
#[test]
fn test_bubble_up_interrupt_emits_ns_from_checkpoint_ns() {
    // Boxed, sendable future returned by the no-op node closure.
    type NodeFuture = std::pin::Pin<
        Box<
            dyn std::future::Future<
                    Output = Result<crate::Command<TestState>, crate::JunctureError>,
                > + Send,
        >,
    >;

    let noop_node = NodeFnCommand(|_state: &TestState| -> NodeFuture {
        Box::pin(async move { Ok(crate::Command::end()) })
    })
    .into_node("test_node");
    let mut node_map = IndexMap::new();
    node_map.insert("test_node".to_string(), noop_node);

    let parent_segment = crate::checkpoint::NamespaceSegment::new(
        "review".to_string(),
        "uuid-parent".to_string(),
    );
    let run_config = crate::config::RunnableConfig::new()
        .with_checkpoint_ns(crate::checkpoint::CheckpointNamespace::new(vec![parent_segment]));

    let mut pregel =
        PregelLoop::new(TestState, node_map, TriggerTable::new(), run_config, 0).unwrap();

    // Attach a stream sender so emitted events can be read back from `events`.
    let (sender, mut events) = mpsc::unbounded_channel();
    pregel.stream_tx = Some(sender);

    let interrupt = crate::pregel::types::GraphInterrupt {
        interrupts: vec![crate::interrupt::InterruptSignal {
            index: 0,
            id: Some("int-ns-0".to_string()),
            payload: serde_json::json!({"node": "child_node"}),
            timestamp: Utc::now(),
        }],
        step: 1,
        namespace: vec!["review".to_string()],
    };
    let bubble_ups = vec![BubbleUp::Interrupt(interrupt)];

    let _ = pregel.handle_bubble_ups(&bubble_ups);

    let emitted = events
        .try_recv()
        .expect("should have received an interrupt event");
    match emitted {
        StreamEvent::Interrupt { ns, .. } => assert_eq!(ns, vec!["review"]),
        other => panic!("expected Interrupt event, got {other:?}"),
    }
}
3923
/// Sends three bubbled-up interrupt signals whose payload `node` values are
/// `agent`, `__route__`, and `review`, and expects stream events only for
/// `agent` and `review`.
#[test]
fn test_hidden_node_filtered_from_bubble_up_interrupt_stream() {
    // Boxed, sendable future returned by the no-op node closure.
    type NodeFuture = std::pin::Pin<
        Box<
            dyn std::future::Future<
                    Output = Result<crate::Command<TestState>, crate::JunctureError>,
                > + Send,
        >,
    >;

    let noop_node = NodeFnCommand(|_state: &TestState| -> NodeFuture {
        Box::pin(async move { Ok(crate::Command::end()) })
    })
    .into_node("test_node");
    let mut node_map = IndexMap::new();
    node_map.insert("test_node".to_string(), noop_node);

    let mut pregel = PregelLoop::new(
        TestState,
        node_map,
        TriggerTable::new(),
        crate::config::RunnableConfig::new(),
        0,
    )
    .unwrap();

    let (sender, mut events) = mpsc::unbounded_channel();
    pregel.stream_tx = Some(sender);

    // Builds a signal whose payload names the originating node.
    let signal = |index, id: &str, node: &str| crate::interrupt::InterruptSignal {
        index,
        id: Some(id.to_string()),
        payload: serde_json::json!({ "node": node }),
        timestamp: Utc::now(),
    };
    let interrupt = crate::pregel::types::GraphInterrupt {
        interrupts: vec![
            signal(0, "int-visible", "agent"),
            signal(1, "int-hidden", "__route__"),
            signal(2, "int-also-visible", "review"),
        ],
        step: 1,
        namespace: vec![],
    };
    let bubble_ups = vec![BubbleUp::Interrupt(interrupt)];

    let _ = pregel.handle_bubble_ups(&bubble_ups);

    // Drain the channel until `try_recv` fails, keeping each event's node name.
    let received_nodes: Vec<_> = std::iter::from_fn(|| events.try_recv().ok())
        .map(|event| match event {
            StreamEvent::Interrupt { node, .. } => node,
            other => panic!("unexpected event: {other:?}"),
        })
        .collect();

    assert_eq!(
        received_nodes,
        vec!["agent", "review"],
        "hidden node __route__ should be filtered from stream"
    );
}
3997
/// When every bubbled-up signal names a double-underscore node (`__route__`,
/// `__handler__`), no stream event is expected, yet `pending_interrupts` is
/// expected to hold both signals.
#[test]
fn test_all_hidden_nodes_produce_no_stream_events() {
    // Boxed, sendable future returned by the no-op node closure.
    type NodeFuture = std::pin::Pin<
        Box<
            dyn std::future::Future<
                    Output = Result<crate::Command<TestState>, crate::JunctureError>,
                > + Send,
        >,
    >;

    let noop_node = NodeFnCommand(|_state: &TestState| -> NodeFuture {
        Box::pin(async move { Ok(crate::Command::end()) })
    })
    .into_node("test_node");
    let mut node_map = IndexMap::new();
    node_map.insert("test_node".to_string(), noop_node);

    let mut pregel = PregelLoop::new(
        TestState,
        node_map,
        TriggerTable::new(),
        crate::config::RunnableConfig::new(),
        0,
    )
    .unwrap();

    let (sender, mut events) = mpsc::unbounded_channel();
    pregel.stream_tx = Some(sender);

    // Builds a signal whose payload names the originating node.
    let signal = |index, id: &str, node: &str| crate::interrupt::InterruptSignal {
        index,
        id: Some(id.to_string()),
        payload: serde_json::json!({ "node": node }),
        timestamp: Utc::now(),
    };
    let interrupt = crate::pregel::types::GraphInterrupt {
        interrupts: vec![
            signal(0, "int-h1", "__route__"),
            signal(1, "int-h2", "__handler__"),
        ],
        step: 1,
        namespace: vec![],
    };
    let bubble_ups = vec![BubbleUp::Interrupt(interrupt)];

    let _ = pregel.handle_bubble_ups(&bubble_ups);

    assert!(
        events.try_recv().is_err(),
        "all-hidden signals should produce no stream events"
    );
    // The signals themselves are still expected to be recorded on the loop.
    assert_eq!(pregel.pending_interrupts.len(), 2);
}
4055
/// A loop built from a default `RunnableConfig` is expected to report
/// `Durability::Sync` from `effective_durability`.
#[test]
fn test_effective_durability_defaults_to_sync() {
    // Boxed, sendable future returned by the no-op node closure.
    type NodeFuture = std::pin::Pin<
        Box<
            dyn std::future::Future<
                    Output = Result<crate::Command<TestState>, crate::JunctureError>,
                > + Send,
        >,
    >;

    let noop_node = NodeFnCommand(|_state: &TestState| -> NodeFuture {
        Box::pin(async move { Ok(crate::Command::end()) })
    })
    .into_node("test_node");
    let mut node_map = IndexMap::new();
    node_map.insert("test_node".to_string(), noop_node);

    let pregel = PregelLoop::new(
        TestState,
        node_map,
        TriggerTable::new(),
        crate::config::RunnableConfig::new(),
        0,
    )
    .unwrap();

    assert_eq!(
        pregel.effective_durability(),
        Durability::Sync,
        "default durability should be Sync"
    );
}
4087
4088 #[tokio::test]
4091 async fn test_durability_exit_skips_superstep_saves_final() {
4092 let state = TestState;
4093
4094 let mut nodes = IndexMap::new();
4095 nodes.insert(
4096 "test_node".to_string(),
4097 NodeFnCommand(
4098 |_s: &TestState| -> std::pin::Pin<
4099 Box<
4100 dyn std::future::Future<
4101 Output = Result<crate::Command<TestState>, crate::JunctureError>,
4102 > + Send,
4103 >,
4104 > { Box::pin(async move { Ok(crate::Command::end()) }) },
4105 )
4106 .into_node("test_node"),
4107 );
4108
4109 let trigger_table = TriggerTable::new();
4110 let mut config = crate::config::RunnableConfig::new();
4111 config.thread_id = Some("test-thread".to_string());
4112 config.durability = Some(Durability::Exit);
4113
4114 let mut loop_ = PregelLoop::new(state, nodes, trigger_table, config, 0).unwrap();
4115
4116 let observed = Arc::new(std::sync::Mutex::new(Vec::new()));
4117 let checkpointer = TrackingCheckpointer {
4118 observed: Arc::clone(&observed),
4119 };
4120 loop_.set_checkpointer(Arc::new(checkpointer));
4121
4122 loop_.pending_tasks = vec![PendingTask::pull(
4126 uuid::Uuid::new_v4().to_string(),
4127 "test_node".to_string(),
4128 )];
4129 let _ = loop_.execute_superstep().await;
4130 let _ = loop_.after_tick(SuperstepResult::empty()).await;
4131
4132 let calls = observed
4135 .lock()
4136 .unwrap_or_else(std::sync::PoisonError::into_inner)
4137 .clone();
4138 assert_eq!(
4139 calls.len(),
4140 1,
4141 "Exit mode should save exactly one final checkpoint"
4142 );
4143 assert!(
4144 matches!(
4145 &calls[0],
4146 ObservedCall::Put {
4147 source: crate::checkpoint::CheckpointSource::Loop,
4148 step: 0
4149 }
4150 ),
4151 "Final exit checkpoint should have Loop source at step 0"
4152 );
4153 }
4154
4155 #[tokio::test]
4157 async fn test_durability_sync_saves_superstep_checkpoint() {
4158 let state = TestState;
4159
4160 let mut nodes = IndexMap::new();
4161 nodes.insert(
4162 "test_node".to_string(),
4163 NodeFnCommand(
4164 |_s: &TestState| -> std::pin::Pin<
4165 Box<
4166 dyn std::future::Future<
4167 Output = Result<crate::Command<TestState>, crate::JunctureError>,
4168 > + Send,
4169 >,
4170 > { Box::pin(async move { Ok(crate::Command::end()) }) },
4171 )
4172 .into_node("test_node"),
4173 );
4174
4175 let trigger_table = TriggerTable::new();
4176 let mut config = crate::config::RunnableConfig::new();
4177 config.thread_id = Some("test-thread".to_string());
4178 config.durability = Some(Durability::Sync);
4179
4180 let mut loop_ = PregelLoop::new(state, nodes, trigger_table, config, 0).unwrap();
4181
4182 let observed = Arc::new(std::sync::Mutex::new(Vec::new()));
4183 let checkpointer = TrackingCheckpointer {
4184 observed: Arc::clone(&observed),
4185 };
4186 loop_.set_checkpointer(Arc::new(checkpointer));
4187
4188 loop_.pending_tasks = vec![PendingTask::pull(
4190 uuid::Uuid::new_v4().to_string(),
4191 "test_node".to_string(),
4192 )];
4193 let _ = loop_.execute_superstep().await;
4194 let _ = loop_.after_tick(SuperstepResult::empty()).await;
4195
4196 let has_loop_checkpoint = {
4197 let calls = observed
4198 .lock()
4199 .unwrap_or_else(std::sync::PoisonError::into_inner);
4200 calls.iter().any(|c| {
4201 matches!(
4202 c,
4203 ObservedCall::Put {
4204 source: crate::checkpoint::CheckpointSource::Loop,
4205 step: 0,
4206 }
4207 )
4208 })
4209 };
4210 assert!(
4211 has_loop_checkpoint,
4212 "Sync mode should save a Loop checkpoint at step 0"
4213 );
4214 }
4215
4216 #[tokio::test]
4218 async fn test_durability_exit_saves_interrupt_checkpoint() {
4219 let state = TestState;
4220
4221 let mut nodes = IndexMap::new();
4222 nodes.insert(
4223 "test_node".to_string(),
4224 NodeFnCommand(
4225 |_s: &TestState| -> std::pin::Pin<
4226 Box<
4227 dyn std::future::Future<
4228 Output = Result<crate::Command<TestState>, crate::JunctureError>,
4229 > + Send,
4230 >,
4231 > { Box::pin(async move { Ok(crate::Command::end()) }) },
4232 )
4233 .into_node("test_node"),
4234 );
4235
4236 let trigger_table = TriggerTable::new();
4237 let mut config = crate::config::RunnableConfig::new();
4238 config.thread_id = Some("test-thread".to_string());
4239 config.durability = Some(Durability::Exit);
4240
4241 let mut loop_ = PregelLoop::new(state, nodes, trigger_table, config, 0).unwrap();
4242
4243 let observed = Arc::new(std::sync::Mutex::new(Vec::new()));
4244 let checkpointer = TrackingCheckpointer {
4245 observed: Arc::clone(&observed),
4246 };
4247 loop_.set_checkpointer(Arc::new(checkpointer));
4248
4249 loop_.pending_interrupts = vec![crate::interrupt::InterruptSignal {
4251 index: 0,
4252 id: Some("int-exit-test".to_string()),
4253 payload: serde_json::json!({"node": "test_node"}),
4254 timestamp: Utc::now(),
4255 }];
4256 loop_.save_interrupt_checkpoint("test_node").await;
4257
4258 let has_interrupt_checkpoint = {
4259 let calls = observed
4260 .lock()
4261 .unwrap_or_else(std::sync::PoisonError::into_inner);
4262 calls.iter().any(|c| {
4263 matches!(
4264 c,
4265 ObservedCall::Put {
4266 source: crate::checkpoint::CheckpointSource::Interrupt { .. },
4267 step: 0,
4268 }
4269 )
4270 })
4271 };
4272 assert!(
4273 has_interrupt_checkpoint,
4274 "Exit mode should still save interrupt checkpoints"
4275 );
4276 }
4277
4278 #[tokio::test]
4284 async fn test_budget_tracker_arc_sharing() {
4285 let state = TestState;
4286 let mut nodes = IndexMap::new();
4287 nodes.insert(
4288 "test_node".to_string(),
4289 NodeFnCommand(
4290 |_s: &TestState| -> std::pin::Pin<
4291 Box<
4292 dyn std::future::Future<
4293 Output = Result<crate::Command<TestState>, crate::JunctureError>,
4294 > + Send,
4295 >,
4296 > { Box::pin(async move { Ok(crate::Command::end()) }) },
4297 )
4298 .into_node("test_node"),
4299 );
4300
4301 let trigger_table = TriggerTable::new();
4302 let budget = crate::pregel::budget::BudgetConfig::new().with_max_tokens(100);
4303 let config = crate::config::RunnableConfig::new().with_budget(budget);
4304
4305 let mut loop_ = PregelLoop::new(state, nodes, trigger_table, config, 0).unwrap();
4306
4307 let tracker_config = loop_.runnable_config.budget.clone().unwrap();
4309 loop_.set_budget_tracker(BudgetTracker::new(tracker_config));
4310
4311 assert!(loop_.budget_tracker.as_ref().unwrap().check().is_none());
4313
4314 if let Some(ref tracker) = loop_.runnable_config.budget_tracker {
4316 tracker.report_model_call(30, 20); }
4318
4319 let usage = loop_.budget_tracker.as_ref().unwrap().current_usage();
4321 assert_eq!(usage.tokens_used, 50);
4322
4323 assert!(loop_.budget_tracker.as_ref().unwrap().check().is_none());
4325
4326 if let Some(ref tracker) = loop_.runnable_config.budget_tracker {
4328 tracker.report_model_call(40, 30); }
4330
4331 assert!(loop_.budget_tracker.as_ref().unwrap().check().is_some());
4333 assert_eq!(
4334 loop_
4335 .budget_tracker
4336 .as_ref()
4337 .unwrap()
4338 .current_usage()
4339 .tokens_used,
4340 120
4341 );
4342
4343 let _ = loop_.tick().unwrap_err();
4345 assert!(loop_.status.is_terminal());
4346 }
4347
4348 #[tokio::test]
4352 async fn test_budget_tracker_cost_via_config() {
4353 let state = TestState;
4354 let mut nodes = IndexMap::new();
4355 nodes.insert(
4356 "test_node".to_string(),
4357 NodeFnCommand(
4358 |_s: &TestState| -> std::pin::Pin<
4359 Box<
4360 dyn std::future::Future<
4361 Output = Result<crate::Command<TestState>, crate::JunctureError>,
4362 > + Send,
4363 >,
4364 > { Box::pin(async move { Ok(crate::Command::end()) }) },
4365 )
4366 .into_node("test_node"),
4367 );
4368
4369 let trigger_table = TriggerTable::new();
4370 let budget = crate::pregel::budget::BudgetConfig::new().with_max_cost_usd(0.01);
4371 let config = crate::config::RunnableConfig::new().with_budget(budget);
4372
4373 let mut loop_ = PregelLoop::new(state, nodes, trigger_table, config, 0).unwrap();
4374
4375 let tracker_config = loop_.runnable_config.budget.clone().unwrap();
4376 loop_.set_budget_tracker(BudgetTracker::new(tracker_config));
4377
4378 if let Some(ref tracker) = loop_.runnable_config.budget_tracker {
4380 tracker.report_cost(0.003);
4381 tracker.report_cost(0.004);
4382 }
4383
4384 let usage = loop_.budget_tracker.as_ref().unwrap().current_usage();
4386 assert!((usage.cost_usd - 0.007).abs() < 0.0001);
4387 assert!(loop_.budget_tracker.as_ref().unwrap().check().is_none());
4388
4389 if let Some(ref tracker) = loop_.runnable_config.budget_tracker {
4391 tracker.report_cost(0.004); }
4393
4394 assert!(loop_.budget_tracker.as_ref().unwrap().check().is_some());
4395
4396 let _ = loop_.tick().unwrap_err();
4398 assert!(loop_.status.is_terminal());
4399 }
4400
4401 #[tokio::test]
4403 async fn test_durability_async_does_not_block() {
4404 let state = TestState;
4405
4406 let mut nodes = IndexMap::new();
4407 nodes.insert(
4408 "test_node".to_string(),
4409 NodeFnCommand(
4410 |_s: &TestState| -> std::pin::Pin<
4411 Box<
4412 dyn std::future::Future<
4413 Output = Result<crate::Command<TestState>, crate::JunctureError>,
4414 > + Send,
4415 >,
4416 > { Box::pin(async move { Ok(crate::Command::end()) }) },
4417 )
4418 .into_node("test_node"),
4419 );
4420
4421 let trigger_table = TriggerTable::new();
4422 let mut config = crate::config::RunnableConfig::new();
4423 config.thread_id = Some("test-thread".to_string());
4424 config.durability = Some(Durability::Async);
4425
4426 let mut loop_ = PregelLoop::new(state, nodes, trigger_table, config, 0).unwrap();
4427
4428 let observed = Arc::new(std::sync::Mutex::new(Vec::new()));
4429 let checkpointer = TrackingCheckpointer {
4430 observed: Arc::clone(&observed),
4431 };
4432 loop_.set_checkpointer(Arc::new(checkpointer));
4433
4434 loop_.pending_tasks = vec![PendingTask::pull(
4436 uuid::Uuid::new_v4().to_string(),
4437 "test_node".to_string(),
4438 )];
4439 let _ = loop_.execute_superstep().await;
4440 let _ = loop_.after_tick(SuperstepResult::empty()).await;
4441
4442 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
4445
4446 let has_checkpoint = {
4448 let calls = observed
4449 .lock()
4450 .unwrap_or_else(std::sync::PoisonError::into_inner);
4451 calls.iter().any(|c| {
4452 matches!(
4453 c,
4454 ObservedCall::Put {
4455 source: crate::checkpoint::CheckpointSource::Loop,
4456 step: 0,
4457 }
4458 )
4459 })
4460 };
4461 assert!(
4462 has_checkpoint,
4463 "Async mode should eventually persist the checkpoint via spawned task"
4464 );
4465 }
4466
4467 #[tokio::test]
4472 async fn test_stream_data_emits_custom_events() {
4473 let state = TestState;
4474 let nodes = IndexMap::new();
4475 let trigger_table = TriggerTable::new();
4476 let config = crate::config::RunnableConfig::new();
4477
4478 let mut loop_ = PregelLoop::new(state, nodes, trigger_table, config, 0).unwrap();
4479
4480 let (tx, mut rx) = mpsc::unbounded_channel();
4481 loop_.stream_tx = Some(tx);
4482
4483 let result = SuperstepResult {
4485 task_outputs: vec![TaskOutput {
4486 triggered_fields: vec![],
4487 task_id: "task-1".to_string(),
4488 node_name: "test_node".to_string(),
4489 command: Command::end()
4490 .with_stream_data(serde_json::json!({"event": "first"}))
4491 .with_stream_data(serde_json::json!({"event": "second"})),
4492 duration: std::time::Duration::from_millis(1),
4493 trigger: TaskTrigger::Pull,
4494 error: None,
4495 circuit_blocked: false,
4496 }],
4497 bubble_ups: Vec::new(),
4498 };
4499
4500 let () = loop_.after_tick(result).await.unwrap();
4501
4502 let mut custom_data = Vec::new();
4504 while let Ok(event) = rx.try_recv() {
4505 if let StreamEvent::Custom { node, data, ns } = event {
4506 assert_eq!(node, "test_node");
4507 assert!(ns.is_empty());
4508 custom_data.push(data);
4509 }
4510 }
4511
4512 assert_eq!(custom_data.len(), 2, "should emit two custom events");
4513 assert_eq!(custom_data[0], serde_json::json!({"event": "first"}));
4514 assert_eq!(custom_data[1], serde_json::json!({"event": "second"}));
4515 }
4516
4517 #[tokio::test]
4519 async fn test_stream_data_empty_produces_no_custom_events() {
4520 let state = TestState;
4521 let nodes = IndexMap::new();
4522 let trigger_table = TriggerTable::new();
4523 let config = crate::config::RunnableConfig::new();
4524
4525 let mut loop_ = PregelLoop::new(state, nodes, trigger_table, config, 0).unwrap();
4526
4527 let (tx, mut rx) = mpsc::unbounded_channel();
4528 loop_.stream_tx = Some(tx);
4529
4530 let result = SuperstepResult {
4532 task_outputs: vec![TaskOutput {
4533 triggered_fields: vec![],
4534 task_id: "task-1".to_string(),
4535 node_name: "test_node".to_string(),
4536 command: Command::end(),
4537 duration: std::time::Duration::from_millis(1),
4538 trigger: TaskTrigger::Pull,
4539 error: None,
4540 circuit_blocked: false,
4541 }],
4542 bubble_ups: Vec::new(),
4543 };
4544
4545 let () = loop_.after_tick(result).await.unwrap();
4546
4547 while let Ok(event) = rx.try_recv() {
4549 assert!(
4550 !matches!(event, StreamEvent::Custom { .. }),
4551 "no Custom events expected for empty stream_data"
4552 );
4553 }
4554 }
4555
4556 #[tokio::test]
4558 async fn test_stream_data_multiple_tasks() {
4559 let state = TestState;
4560 let nodes = IndexMap::new();
4561 let trigger_table = TriggerTable::new();
4562 let config = crate::config::RunnableConfig::new();
4563
4564 let mut loop_ = PregelLoop::new(state, nodes, trigger_table, config, 0).unwrap();
4565
4566 let (tx, mut rx) = mpsc::unbounded_channel();
4567 loop_.stream_tx = Some(tx);
4568
4569 let result = SuperstepResult {
4571 task_outputs: vec![
4572 TaskOutput {
4573 triggered_fields: vec![],
4574 task_id: "task-1".to_string(),
4575 node_name: "node_a".to_string(),
4576 command: Command::end().with_stream_data(serde_json::json!("from_a")),
4577 duration: std::time::Duration::from_millis(1),
4578 trigger: TaskTrigger::Pull,
4579 error: None,
4580 circuit_blocked: false,
4581 },
4582 TaskOutput {
4583 triggered_fields: vec![],
4584 task_id: "task-2".to_string(),
4585 node_name: "node_b".to_string(),
4586 command: Command::end(),
4587 duration: std::time::Duration::from_millis(2),
4588 trigger: TaskTrigger::Pull,
4589 error: None,
4590 circuit_blocked: false,
4591 },
4592 ],
4593 bubble_ups: Vec::new(),
4594 };
4595
4596 let () = loop_.after_tick(result).await.unwrap();
4597
4598 let mut custom_events = Vec::new();
4600 while let Ok(event) = rx.try_recv() {
4601 if let StreamEvent::Custom { node, data, .. } = event {
4602 custom_events.push((node, data));
4603 }
4604 }
4605
4606 assert_eq!(
4607 custom_events.len(),
4608 1,
4609 "only node_a should emit a custom event"
4610 );
4611 assert_eq!(custom_events[0].0, "node_a");
4612 assert_eq!(custom_events[0].1, serde_json::json!("from_a"));
4613 }
4614}
4615
4616