1use std::collections::HashMap;
10use std::time::Instant;
11
12use tokio::sync::mpsc;
13
14use crate::barrier_node::BarrierDefaultAction;
15use crate::checkpoint::{
16 Checkpoint, CheckpointPolicy, CheckpointScore, CheckpointStore, CheckpointTrigger,
17 ExecutionMetadata, IncrementalSnapshotState,
18};
19use crate::delta::{Reducer, ReducerRegistry, StateDelta};
20use crate::error::{GraphError, ObservedError, TerminalError};
21use crate::event::{
22 BarrierDecision, BarrierDecisionMessage, BarrierId, FlowEvent, GraphEvent, GraphExecution,
23 GraphHandle,
24};
25use crate::graph::Graph;
26use crate::ids::{SpanId, TraceId};
27use crate::node::{FlowNode, NextStep, NodeKind, ParallelErrorStrategy, StreamNodeResult};
28use crate::state::{ExecutionEntry, GraphResult, State};
29
/// Bookkeeping for barrier decisions that do not match the barrier currently
/// being waited on, plus the per-node counters used to mint barrier ids.
struct DecisionRegistry {
    /// Exact-match decisions that arrived while a different barrier was the
    /// wait target; removed again when their barrier is looked up.
    pending: HashMap<BarrierId, BarrierDecision>,
    /// Most recent wildcard decision per node id. Entries are cloned on
    /// lookup and never removed, so they apply to every later occurrence.
    wildcards: HashMap<String, BarrierDecision>,
    /// How many barrier ids have been issued so far for each node id.
    occurrence_counter: HashMap<String, u32>,
}
40
41impl DecisionRegistry {
42 fn new() -> Self {
43 Self {
44 pending: HashMap::new(),
45 wildcards: HashMap::new(),
46 occurrence_counter: HashMap::new(),
47 }
48 }
49
50 fn next_id(&mut self, node_id: &str) -> BarrierId {
51 let occ = self
52 .occurrence_counter
53 .entry(node_id.to_string())
54 .or_insert(0);
55 *occ += 1;
56 BarrierId::new(node_id, *occ)
57 }
58
59 fn take(&mut self, target_id: &BarrierId) -> Option<BarrierDecision> {
60 if let Some(decision) = self.pending.remove(target_id) {
61 return Some(decision);
62 }
63 self.wildcards.get(&target_id.node_id).cloned()
64 }
65
66 fn process_message(
67 &mut self,
68 msg: BarrierDecisionMessage,
69 target_id: &BarrierId,
70 ) -> Option<BarrierDecision> {
71 match msg {
72 BarrierDecisionMessage::Exact {
73 barrier_id,
74 decision,
75 } => {
76 if barrier_id == *target_id {
77 Some(decision)
78 } else {
79 self.pending.insert(barrier_id, decision);
80 None
81 }
82 }
83 BarrierDecisionMessage::Wildcard { node_id, decision } => {
84 self.wildcards.insert(node_id.clone(), decision.clone());
86 if node_id == target_id.node_id {
87 Some(decision)
88 } else {
89 None
90 }
91 }
92 }
93 }
94}
95
/// Result of handling one node outcome, telling the run loop what to do next.
#[derive(Debug)]
enum StepOutcome {
    /// Keep stepping; the payload is the name of the node to execute next.
    Continue(String),
    /// Stop stepping. The run loop reacts by calling `send_graph_complete`.
    Break,
    /// Stop immediately without emitting anything further; the handler has
    /// already reported the failure (or found the event consumer gone).
    ErrorSent,
}
108
/// Executes a [`Graph`] step by step, streaming [`GraphEvent`]s and optionally
/// persisting checkpoints.
pub struct GraphExecutor {
    /// Upper bound on the number of node executions in one run; exceeding it
    /// ends the run with `TerminalError::StepsExceeded`.
    pub max_steps: usize,
    /// Where checkpoints are written. With `None`, no checkpoint is saved.
    store: Option<std::sync::Arc<dyn CheckpointStore>>,
    /// Decides which non-adaptive triggers result in a checkpoint.
    policy: CheckpointPolicy,
    /// Hash of the graph, stamped onto every checkpoint this executor writes.
    graph_hash: String,
    /// Reducers queued via `register_reducer`; registered with a fresh
    /// `ReducerRegistry` at the start of every run.
    pending_reducers: Vec<(String, Reducer)>,
    /// Scoring used to decide whether an adaptive trigger saves a checkpoint.
    checkpoint_score: CheckpointScore,
    /// Set by `resume_from` to the restored checkpoint state.
    /// NOTE(review): not read anywhere in this file — confirm it is still needed.
    last_checkpoint_state: Option<State>,
    /// NOTE(review): only ever initialised empty and cloned in this file —
    /// confirm whether anything consumes it.
    pending_deltas: Vec<StateDelta>,
    /// Passed to `IncrementalSnapshotState::new` at the start of every run.
    delta_compact_threshold: usize,
}
135
136impl Clone for GraphExecutor {
137 fn clone(&self) -> Self {
138 Self {
139 max_steps: self.max_steps,
140 store: self.store.clone(),
141 policy: self.policy.clone(),
142 graph_hash: self.graph_hash.clone(),
143 pending_reducers: self.pending_reducers.clone(),
144 checkpoint_score: self.checkpoint_score.clone(),
145 last_checkpoint_state: self.last_checkpoint_state.clone(),
146 pending_deltas: self.pending_deltas.clone(),
147 delta_compact_threshold: self.delta_compact_threshold,
148 }
149 }
150}
151
152impl std::fmt::Debug for GraphExecutor {
153 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
154 f.debug_struct("GraphExecutor")
155 .field("max_steps", &self.max_steps)
156 .field("has_store", &self.store.is_some())
157 .field("policy", &self.policy)
158 .field("graph_hash", &self.graph_hash)
159 .finish()
160 }
161}
162
163impl Default for GraphExecutor {
164 fn default() -> Self {
165 Self {
166 max_steps: 50,
167 store: None,
168 policy: CheckpointPolicy::default(),
169 graph_hash: String::new(),
170 pending_reducers: Vec::new(),
171 checkpoint_score: CheckpointScore::default(),
172 last_checkpoint_state: None,
173 pending_deltas: Vec::new(),
174 delta_compact_threshold: 20,
175 }
176 }
177}
178
179impl GraphExecutor {
180 pub fn new(max_steps: usize) -> Self {
182 Self {
183 max_steps,
184 store: None,
185 policy: CheckpointPolicy::default(),
186 graph_hash: String::new(),
187 pending_reducers: Vec::new(),
188 checkpoint_score: CheckpointScore::default(),
189 last_checkpoint_state: None,
190 pending_deltas: Vec::new(),
191 delta_compact_threshold: 20,
192 }
193 }
194
195 pub fn with_checkpoint(
197 max_steps: usize,
198 store: std::sync::Arc<dyn CheckpointStore>,
199 policy: CheckpointPolicy,
200 graph: &Graph,
201 ) -> Self {
202 Self {
203 max_steps,
204 store: Some(store),
205 policy,
206 graph_hash: graph.hash(),
207 pending_reducers: Vec::new(),
208 checkpoint_score: CheckpointScore::default(),
209 last_checkpoint_state: None,
210 pending_deltas: Vec::new(),
211 delta_compact_threshold: 20,
212 }
213 }
214
    /// Attaches (or replaces) the checkpoint store used by later runs.
    pub fn set_store(&mut self, store: std::sync::Arc<dyn CheckpointStore>) {
        self.store = Some(store);
    }
219
    /// Replaces the checkpoint policy.
    pub fn set_policy(&mut self, policy: CheckpointPolicy) {
        self.policy = policy;
    }
224
    /// Queues `reducer` for `key`.
    ///
    /// The pair is only stored here; it is registered with the run's
    /// `ReducerRegistry` when an execution starts. Repeated calls for the
    /// same key append further entries rather than replacing earlier ones.
    pub fn register_reducer(&mut self, key: &str, reducer: Reducer) {
        self.pending_reducers.push((key.to_string(), reducer));
    }
231
    /// Replaces the adaptive checkpoint score.
    ///
    /// Unlike the other setters this one consumes `self` and returns it, so
    /// it can be chained builder-style.
    pub fn set_checkpoint_score(mut self, score: CheckpointScore) -> Self {
        self.checkpoint_score = score;
        self
    }
237
    /// Recomputes the stored graph hash from `graph`.
    pub fn set_graph(&mut self, graph: &Graph) {
        self.graph_hash = graph.hash();
    }
242
243 pub async fn execute(
252 &self,
253 graph: std::sync::Arc<Graph>,
254 initial_state: State,
255 ) -> Result<GraphResult, GraphError> {
256 for (name, node) in &graph.nodes {
257 if matches!(node, NodeKind::Barrier(_)) {
258 return Err(GraphError::Terminal(TerminalError::InvalidGraph(format!(
259 "BarrierNode '{}' requires stream mode. Use GraphExecutor::execute_stream() for human-in-the-loop.",
260 name
261 ))));
262 }
263 }
264
265 let GraphExecution { mut stream, handle } = self.execute_stream(graph, initial_state);
266
267 drop(handle);
268
269 let mut result = None;
270
271 while let Some(event) = stream.recv().await {
272 match event {
273 GraphEvent::GraphComplete { result: r } => {
274 result = Some(Ok(r));
275 }
276 GraphEvent::GraphError { error, .. } => {
277 result = Some(Err(error));
278 }
279 _ => {}
280 }
281 }
282
283 result.unwrap_or_else(|| {
284 Err(GraphError::Terminal(TerminalError::InvalidGraph(
285 "stream ended without completion".into(),
286 )))
287 })
288 }
289
290 pub fn execute_stream(
296 &self,
297 graph: std::sync::Arc<Graph>,
298 initial_state: State,
299 ) -> GraphExecution {
300 let executor = self.clone();
301 let (event_tx, event_rx) = mpsc::channel(32);
302 let (decision_tx, decision_rx) = mpsc::channel(16);
303 let (cancel_tx, cancel_rx) = mpsc::channel(1);
304 let (checkpoint_tx, checkpoint_rx) = mpsc::channel(8);
305
306 let handle = GraphHandle::new(decision_tx, cancel_tx, checkpoint_tx);
307
308 tokio::spawn(async move {
309 executor
310 .run_loop(
311 graph,
312 initial_state,
313 event_tx,
314 decision_rx,
315 cancel_rx,
316 checkpoint_rx,
317 )
318 .await;
319 });
320
321 GraphExecution {
322 stream: event_rx,
323 handle,
324 }
325 }
326
327 async fn run_loop(
329 &self,
330 graph: std::sync::Arc<Graph>,
331 initial_state: State,
332 event_tx: mpsc::Sender<GraphEvent>,
333 mut decision_rx: mpsc::Receiver<BarrierDecisionMessage>,
334 mut cancel_rx: mpsc::Receiver<()>,
335 mut checkpoint_rx: mpsc::Receiver<()>,
336 ) {
337 let start_time = Instant::now();
338 let mut state = initial_state;
339 let mut execution_log = Vec::new();
340 let mut decision_registry = DecisionRegistry::new();
341 let mut reducer_registry = ReducerRegistry::new();
342 let mut snapshot_state = IncrementalSnapshotState::new(self.delta_compact_threshold);
343
344 for (key, reducer) in &self.pending_reducers {
346 reducer_registry.register(key, *reducer);
347 }
348
349 let mut current = graph.start_node().to_string();
350 let mut step: usize = 0;
351 let trace_id = TraceId::default();
352
353 if self
355 .send(&event_tx, GraphEvent::GraphStart { trace_id })
356 .await
357 {
358 return;
359 }
360
361 loop {
362 if cancel_rx.try_recv().is_ok() {
364 self.send_graph_error(
365 &event_tx,
366 GraphError::Terminal(TerminalError::BarrierCancelled {
367 node: "execution cancelled by handle".into(),
368 }),
369 &state,
370 &execution_log,
371 start_time,
372 trace_id,
373 )
374 .await;
375 return;
376 }
377
378 if checkpoint_rx.try_recv().is_ok() {
380 self.save_checkpoint_if_needed(
381 &event_tx,
382 &trace_id,
383 ¤t,
384 &state,
385 step,
386 CheckpointTrigger::Explicit,
387 &mut snapshot_state,
388 )
389 .await;
390 }
391
392 step += 1;
393
394 if step > self.max_steps {
396 self.send_graph_error(
397 &event_tx,
398 GraphError::Terminal(TerminalError::StepsExceeded {
399 limit: self.max_steps,
400 }),
401 &state,
402 &execution_log,
403 start_time,
404 trace_id,
405 )
406 .await;
407 return;
408 }
409
410 let node = match graph.nodes.get(¤t) {
412 Some(n) => n,
413 None => {
414 self.send_graph_error(
415 &event_tx,
416 GraphError::Terminal(TerminalError::NodeNotFound(current.clone())),
417 &state,
418 &execution_log,
419 start_time,
420 trace_id,
421 )
422 .await;
423 return;
424 }
425 };
426
427 let node_name = current.clone();
428 let span_id = SpanId::new();
429
430 if self
431 .send(
432 &event_tx,
433 GraphEvent::NodeStart {
434 node_name: node_name.clone(),
435 trace_id,
436 span_id,
437 step,
438 },
439 )
440 .await
441 {
442 return;
443 }
444
445 let node_start = Instant::now();
446 let result = if matches!(node, NodeKind::Parallel(_)) {
447 self.handle_parallel(node, &state, &event_tx, span_id, &node_name)
448 .await
449 } else {
450 node.execute_stream(&state, &event_tx, span_id).await
451 };
452 let node_end = Instant::now();
453 let duration = node_end.duration_since(node_start);
454
455 match result {
456 Ok(StreamNodeResult::Continue {
457 deltas,
458 next,
459 span_id,
460 observed,
461 metadata,
462 }) => {
463 if self.policy.has_adaptive_trigger() {
465 let exec_metadata = ExecutionMetadata {
466 duration_ms: duration.as_millis() as u64,
467 token_cost: metadata.as_ref().map_or(0.0, |m| m.token_cost),
468 has_side_effects: metadata
469 .as_ref()
470 .is_some_and(|m| m.has_side_effects),
471 };
472 self.save_checkpoint_if_needed(
473 &event_tx,
474 &trace_id,
475 ¤t,
476 &state,
477 step,
478 CheckpointTrigger::Adaptive(exec_metadata),
479 &mut snapshot_state,
480 )
481 .await;
482 }
483
484 snapshot_state.record_deltas(deltas.clone());
486
487 if matches!(node, NodeKind::Parallel(_)) {
489 if let Err(e) = reducer_registry.merge_deltas(&mut state, &deltas) {
491 self.handle_error(
493 &event_tx,
494 &mut execution_log,
495 &node_name,
496 node_start,
497 node_end,
498 span_id,
499 step,
500 trace_id,
501 GraphError::Terminal(TerminalError::StateError(format!(
502 "parallel merge conflict: {}",
503 e
504 ))),
505 &state,
506 )
507 .await;
508 return;
509 }
510 for delta in &deltas {
512 let d: StateDelta = delta.clone();
513 let _ = self
514 .send(
515 &event_tx,
516 GraphEvent::Node {
517 span_id: SpanId::new(),
518 node_name: node_name.to_string(),
519 event: FlowEvent::StateChanged {
520 node_id: node_name.to_string(),
521 delta: d,
522 },
523 },
524 )
525 .await;
526 }
527 } else {
528 self.apply_deltas(
529 &event_tx,
530 &mut reducer_registry,
531 &mut state,
532 &node_name,
533 &deltas,
534 )
535 .await;
536 }
537
538 let outcome = self
539 .handle_continue(
540 &event_tx,
541 &graph,
542 ¤t,
543 &mut state,
544 &mut execution_log,
545 next,
546 span_id,
547 observed,
548 step,
549 &node_name,
550 node_start,
551 node_end,
552 duration,
553 trace_id,
554 )
555 .await;
556
557 match outcome {
558 StepOutcome::Continue(target) => {
559 self.save_checkpoint_if_needed(
561 &event_tx,
562 &trace_id,
563 &target,
564 &state,
565 step,
566 CheckpointTrigger::Explicit,
567 &mut snapshot_state,
568 )
569 .await;
570 current = target;
571 }
572 StepOutcome::Break => {
573 self.send_graph_complete(
575 &event_tx,
576 &state,
577 &execution_log,
578 start_time,
579 trace_id,
580 &mut snapshot_state,
581 )
582 .await;
583 return;
584 }
585 StepOutcome::ErrorSent => {
586 return;
587 }
588 }
589 }
590
591 Ok(StreamNodeResult::Pause {
592 deltas: barrier_deltas,
593 node_name: barrier_name,
594 span_id,
595 timeout,
596 default_action,
597 ..
598 }) => {
599 self.apply_deltas(
601 &event_tx,
602 &mut reducer_registry,
603 &mut state,
604 &barrier_name,
605 &barrier_deltas,
606 )
607 .await;
608
609 let outcome = self
610 .handle_barrier(
611 &event_tx,
612 &graph,
613 &mut decision_rx,
614 &mut decision_registry,
615 &mut cancel_rx,
616 &mut reducer_registry,
617 node,
618 ¤t,
619 &mut state,
620 &mut execution_log,
621 &barrier_name,
622 span_id,
623 timeout,
624 default_action,
625 step,
626 node_start,
627 trace_id,
628 )
629 .await;
630
631 match outcome {
632 StepOutcome::Continue(target) => {
633 self.save_checkpoint_if_needed(
635 &event_tx,
636 &trace_id,
637 &target,
638 &state,
639 step,
640 CheckpointTrigger::BarrierResolved,
641 &mut snapshot_state,
642 )
643 .await;
644 current = target;
645 }
646 StepOutcome::Break => {
647 self.send_graph_complete(
649 &event_tx,
650 &state,
651 &execution_log,
652 start_time,
653 trace_id,
654 &mut snapshot_state,
655 )
656 .await;
657 return;
658 }
659 StepOutcome::ErrorSent => {
660 return;
661 }
662 }
663 }
664
665 Ok(StreamNodeResult::Fallback {
666 deltas: fallback_deltas,
667 reason,
668 node_name: fallback_node,
669 }) => {
670 self.apply_deltas(
672 &event_tx,
673 &mut reducer_registry,
674 &mut state,
675 &fallback_node,
676 &fallback_deltas,
677 )
678 .await;
679
680 let outcome = self
681 .handle_fallback(
682 &event_tx,
683 &graph,
684 ¤t,
685 &mut state,
686 &mut execution_log,
687 &fallback_node,
688 &reason,
689 step,
690 node_start,
691 node_end,
692 trace_id,
693 )
694 .await;
695
696 match outcome {
697 StepOutcome::Continue(target) => {
698 current = target;
699 }
700 StepOutcome::ErrorSent => {
701 return;
702 }
703 StepOutcome::Break => {
704 unreachable!("handle_fallback only returns Continue or ErrorSent");
706 }
707 }
708 }
709
710 Err(e) => {
711 self.handle_error(
712 &event_tx,
713 &mut execution_log,
714 &node_name,
715 node_start,
716 node_end,
717 span_id,
718 step,
719 trace_id,
720 e,
721 &state,
722 )
723 .await;
724 return;
725 }
726 }
727 }
728 }
729
730 #[allow(clippy::too_many_arguments)]
734 async fn handle_continue(
735 &self,
736 event_tx: &mpsc::Sender<GraphEvent>,
737 graph: &Graph,
738 current: &str,
739 state: &mut State,
740 execution_log: &mut Vec<ExecutionEntry>,
741 next: NextStep,
742 span_id: SpanId,
743 observed: Option<ObservedError>,
744 step: usize,
745 node_name: &str,
746 node_start: Instant,
747 node_end: Instant,
748 duration: std::time::Duration,
749 trace_id: TraceId,
750 ) -> StepOutcome {
751 execution_log.push(ExecutionEntry {
753 step,
754 node_name: node_name.to_string(),
755 start_time: node_start,
756 end_time: node_end,
757 success: true,
758 });
759
760 if self
762 .send(
763 event_tx,
764 GraphEvent::NodeEnd {
765 node_name: node_name.to_string(),
766 trace_id,
767 span_id,
768 success: true,
769 duration,
770 },
771 )
772 .await
773 {
774 return StepOutcome::Break;
775 }
776
777 if let Some(error) = observed
779 && self
780 .send(
781 event_tx,
782 GraphEvent::ObservedError {
783 error,
784 node_name: node_name.to_string(),
785 },
786 )
787 .await
788 {
789 return StepOutcome::Break;
790 }
791
792 if current == graph.end_node() {
794 return StepOutcome::Break;
795 }
796
797 match self.resolve_next(graph, current, state, next) {
799 Ok(target) => StepOutcome::Continue(target),
800 Err(e) => {
801 self.send_graph_error(event_tx, e, state, execution_log, Instant::now(), trace_id)
802 .await;
803 StepOutcome::ErrorSent
804 }
805 }
806 }
807
    /// Handles a node that paused at a barrier: announces the barrier, waits
    /// for a decision, applies it and picks the next node.
    ///
    /// Returns `Break` when the consumer is gone or `current` is the end
    /// node, `Continue(target)` when routing succeeds, and `ErrorSent` after
    /// reporting cancellation, a non-barrier node, or a routing failure.
    #[allow(clippy::too_many_arguments)]
    async fn handle_barrier(
        &self,
        event_tx: &mpsc::Sender<GraphEvent>,
        graph: &Graph,
        decision_rx: &mut mpsc::Receiver<BarrierDecisionMessage>,
        decision_registry: &mut DecisionRegistry,
        cancel_rx: &mut mpsc::Receiver<()>,
        reducer_registry: &mut ReducerRegistry,
        node: &NodeKind,
        current: &str,
        state: &mut State,
        execution_log: &mut Vec<ExecutionEntry>,
        barrier_name: &str,
        span_id: SpanId,
        timeout: Option<std::time::Duration>,
        default_action: BarrierDefaultAction,
        step: usize,
        node_start: Instant,
        trace_id: TraceId,
    ) -> StepOutcome {
        // Each pause of the same node gets a fresh occurrence-numbered id.
        let barrier_id = decision_registry.next_id(barrier_name);

        if self
            .send(
                event_tx,
                GraphEvent::BarrierWaiting {
                    barrier_id: barrier_id.clone(),
                    node_name: barrier_name.to_string(),
                    span_id,
                },
            )
            .await
        {
            return StepOutcome::Break;
        }

        let decision = self
            .wait_barrier_decision(
                decision_rx,
                decision_registry,
                &barrier_id,
                timeout,
                &default_action,
                cancel_rx,
            )
            .await;

        // A cancel message still queued after the wait aborts the run; the
        // decision obtained above is discarded in that case.
        if cancel_rx.try_recv().is_ok() {
            self.send_graph_error(
                event_tx,
                GraphError::Terminal(TerminalError::BarrierCancelled {
                    node: barrier_name.to_string(),
                }),
                state,
                execution_log,
                node_start,
                trace_id,
            )
            .await;
            return StepOutcome::ErrorSent;
        }

        if self
            .send(
                event_tx,
                GraphEvent::BarrierResolved {
                    barrier_id: barrier_id.clone(),
                    decision: decision.clone(),
                },
            )
            .await
        {
            return StepOutcome::Break;
        }

        // Only a barrier node can turn a decision into a next step and deltas.
        let (next, barrier_deltas) = match node {
            NodeKind::Barrier(b) => b.apply_decision(decision),
            _ => {
                self.send_graph_error(
                    event_tx,
                    GraphError::Terminal(TerminalError::InvalidGraph(
                        "expected BarrierNode but got unexpected node type for BarrierPaused"
                            .to_string(),
                    )),
                    state,
                    execution_log,
                    node_start,
                    trace_id,
                )
                .await;
                return StepOutcome::ErrorSent;
            }
        };

        self.apply_deltas(
            event_tx,
            reducer_registry,
            state,
            barrier_name,
            &barrier_deltas,
        )
        .await;

        // The logged duration includes the time spent waiting for the decision.
        let end_time = Instant::now();
        execution_log.push(ExecutionEntry {
            step,
            node_name: barrier_name.to_string(),
            start_time: node_start,
            end_time,
            success: true,
        });

        if self
            .send(
                event_tx,
                GraphEvent::NodeEnd {
                    node_name: barrier_name.to_string(),
                    trace_id,
                    span_id,
                    success: true,
                    duration: end_time.duration_since(node_start),
                },
            )
            .await
        {
            return StepOutcome::Break;
        }

        if current == graph.end_node() {
            return StepOutcome::Break;
        }

        match self.resolve_next(graph, current, state, next) {
            Ok(target) => StepOutcome::Continue(target),
            Err(e) => {
                self.send_graph_error(event_tx, e, state, execution_log, end_time, trace_id)
                    .await;
                StepOutcome::ErrorSent
            }
        }
    }
963
964 #[allow(clippy::too_many_arguments)]
970 async fn handle_fallback(
971 &self,
972 event_tx: &mpsc::Sender<GraphEvent>,
973 graph: &Graph,
974 current: &str,
975 state: &mut State,
976 execution_log: &mut Vec<ExecutionEntry>,
977 fallback_node: &str,
978 reason: &str,
979 step: usize,
980 node_start: Instant,
981 node_end: Instant,
982 trace_id: TraceId,
983 ) -> StepOutcome {
984 execution_log.push(ExecutionEntry {
986 step,
987 node_name: fallback_node.to_string(),
988 start_time: node_start,
989 end_time: node_end,
990 success: false,
991 });
992
993 if let Some(fallback_target) = graph.find_fallback_edge(current) {
995 if self
997 .send(
998 event_tx,
999 GraphEvent::ObservedError {
1000 error: ObservedError::Degraded {
1001 node: fallback_node.to_string(),
1002 message: format!("fallback to '{}': {}", fallback_target, reason),
1003 },
1004 node_name: fallback_node.to_string(),
1005 },
1006 )
1007 .await
1008 {
1009 return StepOutcome::ErrorSent;
1010 }
1011 StepOutcome::Continue(fallback_target)
1012 } else {
1013 self.send_graph_error(
1015 event_tx,
1016 GraphError::Terminal(TerminalError::NodeExecutionFailed {
1017 node: fallback_node.to_string(),
1018 source: format!("fallback with no fallback edge: {}", reason).into(),
1019 }),
1020 state,
1021 execution_log,
1022 node_end,
1023 trace_id,
1024 )
1025 .await;
1026 StepOutcome::ErrorSent
1027 }
1028 }
1029
1030 #[allow(clippy::too_many_arguments)]
1034 async fn handle_error(
1035 &self,
1036 event_tx: &mpsc::Sender<GraphEvent>,
1037 execution_log: &mut Vec<ExecutionEntry>,
1038 node_name: &str,
1039 node_start: Instant,
1040 node_end: Instant,
1041 span_id: SpanId,
1042 step: usize,
1043 trace_id: TraceId,
1044 error: GraphError,
1045 state: &State,
1046 ) {
1047 let duration = node_end.duration_since(node_start);
1048
1049 execution_log.push(ExecutionEntry {
1051 step,
1052 node_name: node_name.to_string(),
1053 start_time: node_start,
1054 end_time: node_end,
1055 success: false,
1056 });
1057
1058 if self
1060 .send(
1061 event_tx,
1062 GraphEvent::NodeEnd {
1063 node_name: node_name.to_string(),
1064 trace_id,
1065 span_id,
1066 success: false,
1067 duration,
1068 },
1069 )
1070 .await
1071 {
1072 return;
1073 }
1074
1075 self.send_graph_error(event_tx, error, state, execution_log, node_end, trace_id)
1077 .await;
1078 }
1079
1080 async fn handle_parallel(
1086 &self,
1087 node: &NodeKind,
1088 state: &State,
1089 event_tx: &mpsc::Sender<GraphEvent>,
1090 parent_span_id: SpanId,
1091 node_name: &str,
1092 ) -> Result<StreamNodeResult, GraphError> {
1093 let parallel = match node {
1094 NodeKind::Parallel(p) => p,
1095 _ => unreachable!("handle_parallel called on non-Parallel node"),
1096 };
1097
1098 let branch_count = parallel.branch_count();
1099 let error_strategy = parallel.error_strategy();
1100 let display_name = parallel.label().unwrap_or(node_name).to_string();
1101
1102 if self
1104 .send(
1105 event_tx,
1106 GraphEvent::Node {
1107 span_id: parent_span_id,
1108 node_name: node_name.to_string(),
1109 event: FlowEvent::ParallelStarted {
1110 node_id: display_name.clone(),
1111 branch_count,
1112 span_id: parent_span_id,
1113 },
1114 },
1115 )
1116 .await
1117 {
1118 return Err(GraphError::Terminal(TerminalError::InvalidGraph(
1119 "consumer dropped during parallel execution".into(),
1120 )));
1121 }
1122
1123 let parallel_start = Instant::now();
1124
1125 let mut handles = Vec::with_capacity(branch_count);
1127 for (branch_name, branch_node) in parallel.branches_iter() {
1128 let state_copy = state.clone();
1129 let branch_node = branch_node.clone();
1130 let name = branch_name.to_string();
1131
1132 let handle = tokio::spawn(async move {
1133 let branch_start = Instant::now();
1134 let result = branch_node.execute(&state_copy).await;
1137 let branch_end = Instant::now();
1138 (name, result, branch_end.duration_since(branch_start))
1139 });
1140
1141 handles.push(handle);
1142 }
1143
1144 let mut all_deltas: Vec<StateDelta> = Vec::new();
1146 let mut first_error: Option<GraphError> = None;
1147 let mut any_failure = false;
1148
1149 for handle in handles {
1150 let (branch_name, result, branch_duration) = match handle.await {
1151 Ok(res) => res,
1152 Err(join_err) => {
1153 let err = GraphError::Terminal(TerminalError::NodeExecutionFailed {
1154 node: format!("{}/{}", display_name, "<unknown>"),
1155 source: join_err.into(),
1156 });
1157 let _ = self
1159 .send(
1160 event_tx,
1161 GraphEvent::Node {
1162 span_id: parent_span_id,
1163 node_name: node_name.to_string(),
1164 event: FlowEvent::BranchCompleted {
1165 branch_name: "<unknown>".to_string(),
1166 node_id: display_name.clone(),
1167 span_id: SpanId::new(),
1168 success: false,
1169 duration: std::time::Duration::ZERO,
1170 },
1171 },
1172 )
1173 .await;
1174
1175 if matches!(error_strategy, ParallelErrorStrategy::FailFast) {
1176 return Err(err);
1177 }
1178 first_error.get_or_insert(err);
1179 any_failure = true;
1180 continue;
1181 }
1182 };
1183
1184 match result {
1185 Ok(output) => {
1186 all_deltas.extend(output.deltas);
1187
1188 let _ = self
1190 .send(
1191 event_tx,
1192 GraphEvent::Node {
1193 span_id: parent_span_id,
1194 node_name: node_name.to_string(),
1195 event: FlowEvent::BranchCompleted {
1196 branch_name: branch_name.clone(),
1197 node_id: display_name.clone(),
1198 span_id: SpanId::new(),
1199 success: true,
1200 duration: branch_duration,
1201 },
1202 },
1203 )
1204 .await;
1205 }
1206 Err(e) => {
1207 let _ = self
1209 .send(
1210 event_tx,
1211 GraphEvent::Node {
1212 span_id: parent_span_id,
1213 node_name: node_name.to_string(),
1214 event: FlowEvent::BranchCompleted {
1215 branch_name: branch_name.clone(),
1216 node_id: display_name.clone(),
1217 span_id: SpanId::new(),
1218 success: false,
1219 duration: branch_duration,
1220 },
1221 },
1222 )
1223 .await;
1224
1225 if matches!(error_strategy, ParallelErrorStrategy::FailFast) {
1226 return Err(e);
1227 }
1228 first_error.get_or_insert(e);
1229 any_failure = true;
1230 }
1231 }
1232 }
1233
1234 if any_failure {
1236 return Err(first_error.unwrap());
1237 }
1238
1239 let parallel_duration = parallel_start.elapsed();
1245 let _ = self
1246 .send(
1247 event_tx,
1248 GraphEvent::Node {
1249 span_id: parent_span_id,
1250 node_name: node_name.to_string(),
1251 event: FlowEvent::ParallelCompleted {
1252 node_id: display_name,
1253 span_id: parent_span_id,
1254 duration: parallel_duration,
1255 },
1256 },
1257 )
1258 .await;
1259
1260 Ok(StreamNodeResult::Continue {
1261 deltas: all_deltas,
1262 next: NextStep::GoToNext,
1263 span_id: parent_span_id,
1264 observed: None,
1265 metadata: None,
1266 })
1267 }
1268
1269 async fn apply_deltas(
1273 &self,
1274 event_tx: &mpsc::Sender<GraphEvent>,
1275 reducer_registry: &mut ReducerRegistry,
1276 state: &mut State,
1277 node_name: &str,
1278 deltas: &[StateDelta],
1279 ) {
1280 for delta in deltas {
1281 if let Err(e) = reducer_registry.apply_delta(state, delta) {
1283 tracing::warn!(
1284 node = %node_name,
1285 key = %delta.key,
1286 error = %e,
1287 "failed to apply state delta"
1288 );
1289 }
1290
1291 let _ = self
1293 .send(
1294 event_tx,
1295 GraphEvent::Node {
1296 span_id: SpanId::new(), node_name: node_name.to_string(),
1298 event: FlowEvent::StateChanged {
1299 node_id: node_name.to_string(),
1300 delta: delta.clone(),
1301 },
1302 },
1303 )
1304 .await;
1305 }
1306 }
1307
    /// Sends `event` to the consumer.
    ///
    /// Note the polarity: returns `true` when the send FAILED (the channel
    /// send returned an error) and `false` when the event was delivered to
    /// the channel. Callers use `true` as the signal to stop.
    async fn send(&self, event_tx: &mpsc::Sender<GraphEvent>, event: GraphEvent) -> bool {
        event_tx.send(event).await.is_err()
    }
1312
1313 async fn send_graph_error(
1315 &self,
1316 event_tx: &mpsc::Sender<GraphEvent>,
1317 error: GraphError,
1318 state: &State,
1319 _execution_log: &Vec<ExecutionEntry>,
1320 _start_time: Instant,
1321 _trace_id: TraceId,
1322 ) {
1323 let _ = self
1324 .send(
1325 event_tx,
1326 GraphEvent::GraphError {
1327 error,
1328 state: state.clone(),
1329 },
1330 )
1331 .await;
1332 }
1333
1334 async fn send_graph_complete(
1336 &self,
1337 event_tx: &mpsc::Sender<GraphEvent>,
1338 state: &State,
1339 execution_log: &[ExecutionEntry],
1340 start_time: Instant,
1341 trace_id: TraceId,
1342 snapshot_state: &mut IncrementalSnapshotState,
1343 ) {
1344 if self.policy.should_checkpoint_on_completion()
1346 && let Some(store) = &self.store
1347 {
1348 let (base, deltas, current) = snapshot_state.snapshot(state);
1350 let ck = if let Some(base_state) = base {
1351 Checkpoint::with_snapshot(
1352 trace_id,
1353 &self.graph_hash,
1354 "__complete__",
1355 current,
1356 base_state,
1357 deltas,
1358 )
1359 } else if !deltas.is_empty() {
1360 Checkpoint::with_snapshot(
1361 trace_id,
1362 &self.graph_hash,
1363 "__complete__",
1364 current.clone(),
1365 current,
1366 deltas,
1367 )
1368 } else {
1369 Checkpoint::new(trace_id, &self.graph_hash, "__complete__", state.clone())
1370 };
1371
1372 match store.save(&ck).await {
1373 Ok(()) => {
1374 let _ = self
1375 .send(
1376 event_tx,
1377 GraphEvent::CheckpointSaved {
1378 checkpoint_id: ck.checkpoint_id.clone(),
1379 node_name: "__complete__".to_string(),
1380 step: execution_log.len(),
1381 },
1382 )
1383 .await;
1384 tracing::debug!(
1385 checkpoint = %ck.checkpoint_id,
1386 "final checkpoint saved on completion"
1387 );
1388 }
1389 Err(e) => {
1390 tracing::warn!(error = %e, "final checkpoint save failed");
1391 }
1392 }
1393 }
1394
1395 let _ = self
1396 .send(
1397 event_tx,
1398 GraphEvent::GraphComplete {
1399 result: GraphResult {
1400 trace_id,
1401 state: state.clone(),
1402 execution_log: execution_log.to_vec(),
1403 duration: start_time.elapsed(),
1404 },
1405 },
1406 )
1407 .await;
1408 }
1409
1410 async fn wait_barrier_decision(
1414 &self,
1415 decision_rx: &mut mpsc::Receiver<BarrierDecisionMessage>,
1416 registry: &mut DecisionRegistry,
1417 target_id: &BarrierId,
1418 timeout: Option<std::time::Duration>,
1419 default_action: &BarrierDefaultAction,
1420 cancel_rx: &mut mpsc::Receiver<()>,
1421 ) -> BarrierDecision {
1422 if let Some(decision) = registry.take(target_id) {
1423 return decision;
1424 }
1425
1426 while let Ok(msg) = decision_rx.try_recv() {
1427 if let Some(decision) = registry.process_message(msg, target_id) {
1428 return decision;
1429 }
1430 }
1431
1432 if cancel_rx.try_recv().is_ok() {
1433 return Self::default_decision(default_action);
1434 }
1435
1436 if let Some(timeout) = timeout {
1437 let start = std::time::Instant::now();
1438 loop {
1439 match tokio::time::timeout(std::time::Duration::from_millis(50), decision_rx.recv())
1440 .await
1441 {
1442 Ok(Some(msg)) => {
1443 if let Some(decision) = registry.process_message(msg, target_id) {
1444 return decision;
1445 }
1446 }
1447 Ok(None) => return Self::default_decision(default_action),
1448 Err(_) => {}
1449 }
1450 if cancel_rx.try_recv().is_ok() {
1451 return Self::default_decision(default_action);
1452 }
1453 if start.elapsed() >= timeout {
1454 return Self::default_decision(default_action);
1455 }
1456 }
1457 } else {
1458 loop {
1459 if let Some(msg) = decision_rx.recv().await {
1460 if let Some(decision) = registry.process_message(msg, target_id) {
1461 return decision;
1462 }
1463 } else {
1464 return Self::default_decision(default_action);
1465 }
1466 if cancel_rx.try_recv().is_ok() {
1467 return Self::default_decision(default_action);
1468 }
1469 }
1470 }
1471 }
1472
1473 fn default_decision(action: &BarrierDefaultAction) -> BarrierDecision {
1474 match action {
1475 BarrierDefaultAction::Approve => BarrierDecision::Approve,
1476 BarrierDefaultAction::Reject => BarrierDecision::Reject {
1477 reason: "timeout — no decision received".into(),
1478 },
1479 BarrierDefaultAction::Skip => BarrierDecision::Approve,
1480 }
1481 }
1482
1483 fn resolve_next(
1487 &self,
1488 graph: &Graph,
1489 current: &str,
1490 state: &mut State,
1491 next: NextStep,
1492 ) -> Result<String, GraphError> {
1493 match next {
1494 NextStep::Goto(target) => {
1495 graph.find_edge(current, &target).ok_or_else(|| {
1496 GraphError::Terminal(TerminalError::MissingEdge {
1497 from: current.to_string(),
1498 to: target.clone(),
1499 })
1500 })?;
1501 Ok(target)
1502 }
1503 NextStep::GoToNext => Self::find_next_node(graph, current, state),
1504 NextStep::End => Err(GraphError::Terminal(TerminalError::InvalidGraph(
1505 "unexpected End next step".into(),
1506 ))),
1507 }
1508 }
1509
1510 fn find_next_node(graph: &Graph, current: &str, state: &State) -> Result<String, GraphError> {
1512 let edges = graph.edges_from(current);
1513
1514 if edges.is_empty() {
1515 return Err(GraphError::Terminal(TerminalError::InvalidGraph(format!(
1516 "node '{}' has no outgoing edges and is not the end node",
1517 current
1518 ))));
1519 }
1520
1521 for edge in &edges {
1523 if edge.is_conditional() && edge.condition.as_ref().is_some_and(|c| c(state)) {
1524 return Ok(edge.to.clone());
1525 }
1526 }
1527
1528 for edge in &edges {
1530 if edge.is_normal() {
1531 return Ok(edge.to.clone());
1532 }
1533 }
1534
1535 for edge in &edges {
1537 if edge.fallback {
1538 return Ok(edge.to.clone());
1539 }
1540 }
1541
1542 let attempted: Vec<crate::error::ConditionEval> = edges
1544 .iter()
1545 .map(|e| crate::error::ConditionEval {
1546 edge: format!("{}→{}", e.from, e.to),
1547 condition: e.condition.as_ref().map(|_| "condition".to_string()),
1548 matched: e.condition.as_ref().is_some_and(|c| c(state)),
1549 })
1550 .collect();
1551
1552 Err(GraphError::Terminal(TerminalError::Unrouted {
1553 node: current.to_string(),
1554 attempted_conditions: attempted,
1555 }))
1556 }
1557
1558 #[allow(clippy::too_many_arguments)]
1568 async fn save_checkpoint_if_needed(
1569 &self,
1570 event_tx: &mpsc::Sender<GraphEvent>,
1571 trace_id: &TraceId,
1572 next_node: &str,
1573 state: &State,
1574 step: usize,
1575 trigger: CheckpointTrigger,
1576 snapshot_state: &mut IncrementalSnapshotState,
1577 ) {
1578 let should_save = match &trigger {
1580 CheckpointTrigger::BarrierResolved => self.policy.should_checkpoint_on_barrier(),
1581 CheckpointTrigger::ExecutionCompleted => self.policy.should_checkpoint_on_completion(),
1582 CheckpointTrigger::HumanDecision => self.policy.should_checkpoint_on_human_decision(),
1583 CheckpointTrigger::Explicit => self.policy.should_checkpoint_on_explicit(),
1584 CheckpointTrigger::Adaptive(metadata) => {
1585 self.checkpoint_score.should_checkpoint(metadata)
1587 }
1588 };
1589
1590 if !should_save {
1591 return;
1592 }
1593
1594 let store = match &self.store {
1595 Some(s) => s,
1596 None => return,
1597 };
1598
1599 let (base, deltas, current) = snapshot_state.snapshot(state);
1601 let ck = if let Some(base_state) = base {
1602 Checkpoint::with_snapshot(
1604 *trace_id,
1605 &self.graph_hash,
1606 next_node,
1607 current,
1608 base_state,
1609 deltas,
1610 )
1611 } else if !deltas.is_empty() {
1612 Checkpoint::with_snapshot(
1614 *trace_id,
1615 &self.graph_hash,
1616 next_node,
1617 current.clone(),
1618 current,
1619 deltas,
1620 )
1621 } else {
1622 Checkpoint::new(*trace_id, &self.graph_hash, next_node, state.clone())
1624 };
1625
1626 match store.save(&ck).await {
1627 Ok(()) => {
1628 snapshot_state.base_state = Some(state.clone());
1630 snapshot_state.clear_pending();
1631
1632 let _ = self
1633 .send(
1634 event_tx,
1635 GraphEvent::CheckpointSaved {
1636 checkpoint_id: ck.checkpoint_id.clone(),
1637 node_name: next_node.to_string(),
1638 step,
1639 },
1640 )
1641 .await;
1642 tracing::debug!(
1643 checkpoint = %ck.checkpoint_id,
1644 node = %next_node,
1645 step,
1646 "checkpoint saved"
1647 );
1648 }
1649 Err(e) => {
1650 tracing::warn!(error = %e, node = %next_node, step, "checkpoint save failed");
1651 }
1652 }
1653 }
1654
1655 pub async fn resume_from(
1670 &self,
1671 store: &dyn CheckpointStore,
1672 trace_id: &TraceId,
1673 graph: &std::sync::Arc<Graph>,
1674 ) -> Result<GraphExecution, GraphError> {
1675 let checkpoint = store
1677 .load_latest(trace_id)
1678 .await
1679 .map_err(|e| {
1680 GraphError::Terminal(TerminalError::InvalidGraph(format!(
1681 "failed to load checkpoint: {}",
1682 e
1683 )))
1684 })?
1685 .ok_or_else(|| {
1686 GraphError::Terminal(TerminalError::InvalidGraph(format!(
1687 "no checkpoint found for trace {}",
1688 trace_id
1689 )))
1690 })?;
1691
1692 let current_hash = graph.hash();
1694 if checkpoint.graph_hash != current_hash {
1695 tracing::warn!(
1696 saved_hash = %checkpoint.graph_hash,
1697 current_hash = %current_hash,
1698 "graph structure has changed since checkpoint — resuming anyway (Force mode)",
1699 );
1700 }
1703
1704 let executor = Self {
1706 max_steps: self.max_steps,
1707 store: self.store.clone(),
1708 policy: self.policy.clone(),
1709 graph_hash: current_hash,
1710 pending_reducers: self.pending_reducers.clone(),
1711 checkpoint_score: self.checkpoint_score.clone(),
1712 last_checkpoint_state: Some(checkpoint.restore_state_simple()),
1716 pending_deltas: Vec::new(),
1717 delta_compact_threshold: self.delta_compact_threshold,
1718 };
1719
1720 let initial_state = checkpoint.state.clone();
1722
1723 let execution = executor.execute_stream(graph.clone(), initial_state);
1727
1728 Ok(execution)
1732 }
1733}