1use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::Instant;
9
10use tokio::sync::mpsc;
11
12use crate::barrier_node::BarrierDefaultAction;
13use crate::branch_state::BranchState;
14use crate::checkpoint::{Checkpoint, CheckpointPolicy, CheckpointStore, TraceId};
15use crate::delta::ReducerRegistry;
16use crate::error::{GraphError, TerminalError};
17use crate::event::{
18 BarrierDecision, BarrierDecisionMessage, BarrierId, FlowEvent, GraphEvent, GraphExecution,
19 GraphHandle,
20};
21use crate::graph::Graph;
22use crate::ids::SpanId;
23use crate::node::{FlowNode, NodeKind};
24use crate::node_context::{ExecutionSignal, NextAction, NodeContext, NodeMetadata};
25use crate::runtime_event::RuntimeEvent;
26use crate::state::{ExecutionEntry, GraphResult, State};
27use crate::stream_emitter::StreamEmitter;
28use crate::workflow_state::WorkflowState;
29
30struct RunLoopContext {
34 graph: Arc<Graph>,
35 initial_state: State,
36 event_tx: mpsc::Sender<GraphEvent>,
37 decision_rx: mpsc::Receiver<BarrierDecisionMessage>,
38 cancel_rx: mpsc::Receiver<()>,
39 start_node: Option<String>,
40 trace_id: Option<TraceId>,
41}
42
43struct BarrierHandlerContext<'a> {
47 event_tx: &'a mpsc::Sender<GraphEvent>,
48 graph: &'a Graph,
49 decision_rx: &'a mut mpsc::Receiver<BarrierDecisionMessage>,
50 decision_registry: &'a mut DecisionRegistry,
51 cancel_rx: &'a mut mpsc::Receiver<()>,
52 node: &'a NodeKind,
53 current: &'a str,
54 state: &'a mut State,
55 execution_log: &'a mut Vec<ExecutionEntry>,
56 barrier_name: &'a str,
57 barrier_id: BarrierId,
58 timeout: Option<std::time::Duration>,
59 step: usize,
60 node_start: Instant,
61 trace_id: TraceId,
62}
63
64#[allow(dead_code)]
67struct DecisionRegistry {
68 pending: HashMap<BarrierId, BarrierDecision>,
69 wildcards: HashMap<String, BarrierDecision>,
70 occurrence_counter: HashMap<String, u32>,
71}
72
73impl DecisionRegistry {
74 fn new() -> Self {
75 Self {
76 pending: HashMap::new(),
77 wildcards: HashMap::new(),
78 occurrence_counter: HashMap::new(),
79 }
80 }
81
82 #[allow(dead_code)]
83 fn next_id(&mut self, node_id: &str) -> BarrierId {
84 let occ = self
85 .occurrence_counter
86 .entry(node_id.to_string())
87 .or_insert(0);
88 *occ += 1;
89 BarrierId::new(node_id, *occ)
90 }
91
92 fn take(&mut self, target_id: &BarrierId) -> Option<BarrierDecision> {
93 if let Some(decision) = self.pending.remove(target_id) {
94 return Some(decision);
95 }
96 self.wildcards.get(&target_id.node_id).cloned()
97 }
98
99 fn process_message(
100 &mut self,
101 msg: BarrierDecisionMessage,
102 target_id: &BarrierId,
103 ) -> Option<BarrierDecision> {
104 match msg {
105 BarrierDecisionMessage::Exact {
106 barrier_id,
107 decision,
108 } => {
109 if barrier_id == *target_id {
110 Some(decision)
111 } else {
112 self.pending.insert(barrier_id, decision);
113 None
114 }
115 }
116 BarrierDecisionMessage::Wildcard { node_id, decision } => {
117 self.wildcards.insert(node_id.clone(), decision.clone());
118 if node_id == target_id.node_id {
119 Some(decision)
120 } else {
121 None
122 }
123 }
124 }
125 }
126}
127
/// What the run loop should do after a step has been resolved.
#[derive(Debug)]
enum StepOutcome {
    /// Run the named node next.
    Continue(String),
    /// The graph is finished; report completion.
    Break,
    /// Stop without reporting completion (e.g. a terminal error has already been sent).
    ErrorSent,
}
136
137pub struct GraphExecutor {
141 pub max_steps: usize,
142 store: Option<Arc<dyn CheckpointStore>>,
143 policy: CheckpointPolicy,
144 graph_hash: String,
145 pending_reducers: Vec<(String, crate::delta::Reducer)>,
146}
147
148impl Clone for GraphExecutor {
149 fn clone(&self) -> Self {
150 Self {
151 max_steps: self.max_steps,
152 store: self.store.clone(),
153 policy: self.policy,
154 graph_hash: self.graph_hash.clone(),
155 pending_reducers: self.pending_reducers.clone(),
156 }
157 }
158}
159
160impl std::fmt::Debug for GraphExecutor {
161 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162 f.debug_struct("GraphExecutor")
163 .field("max_steps", &self.max_steps)
164 .field("has_store", &self.store.is_some())
165 .field("policy", &self.policy)
166 .finish()
167 }
168}
169
170impl Default for GraphExecutor {
171 fn default() -> Self {
172 Self {
173 max_steps: 50,
174 store: None,
175 policy: CheckpointPolicy::default(),
176 graph_hash: String::new(),
177 pending_reducers: Vec::new(),
178 }
179 }
180}
181
182impl GraphExecutor {
183 pub fn new(max_steps: usize) -> Self {
184 Self {
185 max_steps,
186 ..Default::default()
187 }
188 }
189
190 pub fn with_checkpoint(
191 max_steps: usize,
192 store: Arc<dyn CheckpointStore>,
193 policy: CheckpointPolicy,
194 graph: &Graph,
195 ) -> Self {
196 Self {
197 max_steps,
198 store: Some(store),
199 policy,
200 graph_hash: graph.hash(),
201 ..Default::default()
202 }
203 }
204
205 pub fn register_reducer(&mut self, key: &str, reducer: crate::delta::Reducer) {
206 self.pending_reducers.push((key.to_string(), reducer));
207 }
208
209 pub fn set_graph(&mut self, graph: &Graph) {
210 self.graph_hash = graph.hash();
211 }
212
213 pub async fn execute(
216 &self,
217 graph: Arc<Graph>,
218 initial_state: State,
219 ) -> Result<GraphResult, GraphError> {
220 let GraphExecution { mut stream, handle } = self.execute_stream(graph, initial_state);
221 drop(handle);
222
223 let mut result = None;
224 while let Some(event) = stream.recv().await {
225 match event {
226 GraphEvent::GraphComplete { result: r } => result = Some(Ok(r)),
227 GraphEvent::GraphError { error, .. } => result = Some(Err(error)),
228 _ => {}
229 }
230 }
231 result.unwrap_or_else(|| {
232 Err(GraphError::Terminal(TerminalError::InvalidGraph(
233 "stream ended without completion".into(),
234 )))
235 })
236 }
237
238 pub fn execute_stream(&self, graph: Arc<Graph>, initial_state: State) -> GraphExecution {
241 let executor = self.clone();
242 let (event_tx, event_rx) = mpsc::channel(32);
243 let (decision_tx, decision_rx) = mpsc::channel(16);
244 let (cancel_tx, cancel_rx) = mpsc::channel(1);
245
246 let handle = GraphHandle::new(decision_tx, cancel_tx);
247
248 tokio::spawn(async move {
249 executor
250 .run_loop(RunLoopContext {
251 graph,
252 initial_state,
253 event_tx,
254 decision_rx,
255 cancel_rx,
256 start_node: None,
257 trace_id: None,
258 })
259 .await;
260 });
261
262 GraphExecution {
263 stream: event_rx,
264 handle,
265 }
266 }
267
268 fn execute_stream_with(
270 &self,
271 graph: Arc<Graph>,
272 initial_state: State,
273 start_node: Option<String>,
274 trace_id: Option<TraceId>,
275 ) -> GraphExecution {
276 let executor = self.clone();
277 let (event_tx, event_rx) = mpsc::channel(32);
278 let (decision_tx, decision_rx) = mpsc::channel(16);
279 let (cancel_tx, cancel_rx) = mpsc::channel(1);
280
281 let handle = GraphHandle::new(decision_tx, cancel_tx);
282
283 tokio::spawn(async move {
284 executor
285 .run_loop(RunLoopContext {
286 graph,
287 initial_state,
288 event_tx,
289 decision_rx,
290 cancel_rx,
291 start_node,
292 trace_id,
293 })
294 .await;
295 });
296
297 GraphExecution {
298 stream: event_rx,
299 handle,
300 }
301 }
302
303 async fn run_loop(&self, ctx: RunLoopContext) {
306 let RunLoopContext {
307 graph,
308 initial_state,
309 event_tx,
310 mut decision_rx,
311 mut cancel_rx,
312 start_node,
313 trace_id,
314 } = ctx;
315 let start_time = Instant::now();
316 let mut state = initial_state;
317 let mut execution_log = Vec::new();
318 let mut decision_registry = DecisionRegistry::new();
319 let mut _reducer_registry = ReducerRegistry::new();
320
321 for (key, reducer) in &self.pending_reducers {
322 _reducer_registry.register(key, *reducer);
323 }
324
325 let mut current = start_node.unwrap_or_else(|| graph.start_node().to_string());
326 let mut step: usize = 0;
327 let trace_id = trace_id.unwrap_or_default();
328
329 self.emit_runtime(
331 &event_tx,
332 RuntimeEvent::ExecutionStarted {
333 trace_id,
334 graph_name: graph.name().to_string(),
335 },
336 )
337 .await;
338 if self
339 .send(&event_tx, GraphEvent::GraphStart { trace_id })
340 .await
341 {
342 return;
343 }
344
345 loop {
346 if cancel_rx.try_recv().is_ok() {
348 self.send_graph_error(
349 &event_tx,
350 GraphError::Terminal(TerminalError::BarrierCancelled {
351 node: "execution cancelled by handle".into(),
352 }),
353 &state,
354 start_time,
355 trace_id,
356 )
357 .await;
358 return;
359 }
360
361 step += 1;
362
363 if step > self.max_steps {
365 self.send_graph_error(
366 &event_tx,
367 GraphError::Terminal(TerminalError::StepsExceeded {
368 limit: self.max_steps,
369 }),
370 &state,
371 start_time,
372 trace_id,
373 )
374 .await;
375 return;
376 }
377
378 let node = match graph.nodes.get(¤t) {
380 Some(n) => n,
381 None => {
382 self.send_graph_error(
383 &event_tx,
384 GraphError::Terminal(TerminalError::NodeNotFound(current.clone())),
385 &state,
386 start_time,
387 trace_id,
388 )
389 .await;
390 return;
391 }
392 };
393
394 let node_name = current.clone();
395 let span_id = SpanId::new();
396
397 self.emit_runtime(
399 &event_tx,
400 RuntimeEvent::NodeStarted {
401 node_name: node_name.clone(),
402 trace_id,
403 span_id,
404 step,
405 },
406 )
407 .await;
408 if self
409 .send(
410 &event_tx,
411 GraphEvent::NodeStart {
412 node_name: node_name.clone(),
413 trace_id,
414 span_id,
415 step,
416 },
417 )
418 .await
419 {
420 return;
421 }
422
423 let node_start = Instant::now();
424
425 let exec_result = self
427 .execute_node(node, &mut state, &node_name, span_id)
428 .await;
429 let node_end = Instant::now();
430 let duration = node_end.duration_since(node_start);
431
432 match exec_result {
433 Ok((next_action, signal, metadata, flow_events)) => {
434 execution_log.push(ExecutionEntry {
436 step,
437 node_name: node_name.clone(),
438 start_time: node_start,
439 end_time: node_end,
440 success: true,
441 error: None,
442 });
443
444 for flow_event in flow_events {
446 if self
447 .send(
448 &event_tx,
449 GraphEvent::Node {
450 span_id,
451 node_name: node_name.clone(),
452 event: flow_event,
453 },
454 )
455 .await
456 {
457 return;
458 }
459 }
460
461 if self.should_checkpoint(CheckpointPolicyTrigger::NodeExecuted {
463 has_side_effects: metadata.has_side_effects,
464 }) {
465 self.save_checkpoint(&event_tx, &trace_id, ¤t, &state, step)
466 .await;
467 }
468
469 self.emit_runtime(
471 &event_tx,
472 RuntimeEvent::NodeCompleted {
473 node_name: node_name.clone(),
474 trace_id,
475 span_id,
476 duration,
477 },
478 )
479 .await;
480 if self
481 .send(
482 &event_tx,
483 GraphEvent::NodeEnd {
484 node_name: node_name.clone(),
485 trace_id,
486 span_id,
487 success: true,
488 duration,
489 },
490 )
491 .await
492 {
493 return;
494 }
495
496 if let Some(signal) = signal {
498 match signal {
499 ExecutionSignal::Pause {
500 barrier_id,
501 timeout,
502 } => {
503 let outcome = self
504 .handle_barrier_signal(BarrierHandlerContext {
505 event_tx: &event_tx,
506 graph: &graph,
507 decision_rx: &mut decision_rx,
508 decision_registry: &mut decision_registry,
509 cancel_rx: &mut cancel_rx,
510 node,
511 current: ¤t,
512 state: &mut state,
513 execution_log: &mut execution_log,
514 barrier_name: &node_name,
515 barrier_id,
516 timeout,
517 step,
518 node_start,
519 trace_id,
520 })
521 .await;
522 match outcome {
523 StepOutcome::Continue(target) => {
524 if self.should_checkpoint(
526 CheckpointPolicyTrigger::BarrierResolved,
527 ) {
528 self.save_checkpoint(
529 &event_tx, &trace_id, &target, &state, step,
530 )
531 .await;
532 }
533 current = target;
534 }
535 StepOutcome::Break => {
536 self.send_graph_complete(
537 &event_tx,
538 &state,
539 &execution_log,
540 start_time,
541 trace_id,
542 )
543 .await;
544 return;
545 }
546 StepOutcome::ErrorSent => return,
547 }
548 continue;
549 }
550 }
551 }
552
553 let outcome = match next_action {
555 NextAction::End => StepOutcome::Break,
556 NextAction::Goto(target) => StepOutcome::Continue(target),
557 NextAction::Next => {
558 if current == graph.end_node() {
560 StepOutcome::Break
561 } else {
562 match self.resolve_next(&graph, ¤t, &state) {
563 Ok(target) => StepOutcome::Continue(target),
564 Err(e) => {
565 self.send_graph_error(
566 &event_tx, e, &state, start_time, trace_id,
567 )
568 .await;
569 StepOutcome::ErrorSent
570 }
571 }
572 }
573 }
574 };
575
576 match outcome {
577 StepOutcome::Continue(target) => {
578 current = target;
579 }
580 StepOutcome::Break => {
581 if self.should_checkpoint(CheckpointPolicyTrigger::GraphComplete) {
583 self.save_checkpoint(
584 &event_tx,
585 &trace_id,
586 "__complete__",
587 &state,
588 step,
589 )
590 .await;
591 }
592
593 self.send_graph_complete(
594 &event_tx,
595 &state,
596 &execution_log,
597 start_time,
598 trace_id,
599 )
600 .await;
601 return;
602 }
603 StepOutcome::ErrorSent => return,
604 }
605 }
606 Err(e) => {
607 let error_str = e.to_string();
609 execution_log.push(ExecutionEntry {
610 step,
611 node_name: node_name.clone(),
612 start_time: node_start,
613 end_time: node_end,
614 success: false,
615 error: Some(error_str.clone()),
616 });
617
618 self.emit_runtime(
619 &event_tx,
620 RuntimeEvent::NodeFailed {
621 node_name: node_name.clone(),
622 trace_id,
623 span_id,
624 error: e.to_string(),
625 },
626 )
627 .await;
628 if self
629 .send(
630 &event_tx,
631 GraphEvent::NodeEnd {
632 node_name: node_name.clone(),
633 trace_id,
634 span_id,
635 success: false,
636 duration,
637 },
638 )
639 .await
640 {
641 return;
642 }
643
644 self.send_graph_error(&event_tx, e, &state, start_time, trace_id)
645 .await;
646 return;
647 }
648 }
649 }
650 }
651
652 async fn execute_node(
656 &self,
657 node: &NodeKind,
658 state: &mut State,
659 _node_name: &str,
660 _span_id: SpanId,
661 ) -> Result<
662 (
663 NextAction,
664 Option<ExecutionSignal>,
665 NodeMetadata,
666 Vec<FlowEvent>,
667 ),
668 GraphError,
669 > {
670 let mut branch = BranchState::from_state(state.clone());
671 let (tx, _rx) = mpsc::channel(64);
672 let emitter = StreamEmitter::new(tx);
673 let mut ctx = NodeContext::new(state, &mut branch, Some(&emitter));
674
675 node.execute(&mut ctx).await?;
676
677 let effects = ctx.consume_effects();
678 let (next_action, signal) = ctx.take_control();
679 let metadata = ctx.take_metadata();
680 let flow_events = ctx.take_flow_events();
681
682 state.apply_batch(effects);
683
684 Ok((next_action, signal, metadata, flow_events))
685 }
686
687 async fn handle_barrier_signal(&self, ctx: BarrierHandlerContext<'_>) -> StepOutcome {
690 let BarrierHandlerContext {
691 event_tx,
692 graph,
693 decision_rx,
694 decision_registry,
695 cancel_rx,
696 node,
697 current,
698 state,
699 execution_log,
700 barrier_name,
701 barrier_id,
702 timeout,
703 step,
704 node_start,
705 trace_id,
706 } = ctx;
707 self.emit_runtime(
708 event_tx,
709 RuntimeEvent::BarrierWaiting {
710 barrier_id: barrier_id.clone(),
711 node_name: barrier_name.to_string(),
712 span_id: SpanId::new(),
713 },
714 )
715 .await;
716 if self
717 .send(
718 event_tx,
719 GraphEvent::BarrierWaiting {
720 barrier_id: barrier_id.clone(),
721 node_name: barrier_name.to_string(),
722 span_id: SpanId::new(),
723 },
724 )
725 .await
726 {
727 return StepOutcome::Break;
728 }
729
730 let decision = self
731 .wait_barrier_decision(
732 decision_rx,
733 decision_registry,
734 &barrier_id,
735 timeout,
736 cancel_rx,
737 )
738 .await;
739
740 if cancel_rx.try_recv().is_ok() {
741 self.send_graph_error(
742 event_tx,
743 GraphError::Terminal(TerminalError::BarrierCancelled {
744 node: barrier_name.to_string(),
745 }),
746 state,
747 node_start,
748 trace_id,
749 )
750 .await;
751 return StepOutcome::ErrorSent;
752 }
753
754 self.emit_runtime(
755 event_tx,
756 RuntimeEvent::BarrierResolved {
757 barrier_id: barrier_id.clone(),
758 },
759 )
760 .await;
761 if self
762 .send(
763 event_tx,
764 GraphEvent::BarrierResolved {
765 barrier_id: barrier_id.clone(),
766 decision: decision.clone(),
767 },
768 )
769 .await
770 {
771 return StepOutcome::Break;
772 }
773
774 match node {
775 NodeKind::Barrier(b) => {
776 let mut branch = BranchState::from_state(state.clone());
777 let mut ctx = NodeContext::new(state, &mut branch, None);
778 b.apply_decision_to_ctx(&mut ctx, decision);
779 let (next, _signal) = ctx.take_control();
780
781 let effects = ctx.consume_effects();
782 state.apply_batch(effects);
783
784 let end_time = Instant::now();
785 execution_log.push(ExecutionEntry {
786 step,
787 node_name: barrier_name.to_string(),
788 start_time: node_start,
789 end_time,
790 success: true,
791 error: None,
792 });
793
794 if self
795 .send(
796 event_tx,
797 GraphEvent::NodeEnd {
798 node_name: barrier_name.to_string(),
799 trace_id,
800 span_id: SpanId::new(),
801 success: true,
802 duration: end_time.duration_since(node_start),
803 },
804 )
805 .await
806 {
807 return StepOutcome::Break;
808 }
809
810 if current == graph.end_node() {
811 return StepOutcome::Break;
812 }
813
814 match next {
815 NextAction::End => StepOutcome::Break,
816 NextAction::Goto(target) => StepOutcome::Continue(target),
817 NextAction::Next => match self.resolve_next(graph, current, state) {
818 Ok(target) => StepOutcome::Continue(target),
819 Err(e) => {
820 self.send_graph_error(event_tx, e, state, end_time, trace_id)
821 .await;
822 StepOutcome::ErrorSent
823 }
824 },
825 }
826 }
827 _ => {
828 self.send_graph_error(
829 event_tx,
830 GraphError::Terminal(TerminalError::InvalidGraph(
831 "expected BarrierNode for pause signal".into(),
832 )),
833 state,
834 node_start,
835 trace_id,
836 )
837 .await;
838 StepOutcome::ErrorSent
839 }
840 }
841 }
842
843 async fn emit_runtime(
846 &self,
847 _event_tx: &mpsc::Sender<GraphEvent>,
848 runtime_event: RuntimeEvent,
849 ) {
850 tracing::debug!(?runtime_event, "runtime event");
851 }
852
853 async fn send(&self, event_tx: &mpsc::Sender<GraphEvent>, event: GraphEvent) -> bool {
854 event_tx.send(event).await.is_err()
855 }
856
857 async fn send_graph_error(
858 &self,
859 event_tx: &mpsc::Sender<GraphEvent>,
860 error: GraphError,
861 state: &State,
862 _start_time: Instant,
863 _trace_id: TraceId,
864 ) {
865 let _ = self
866 .send(
867 event_tx,
868 GraphEvent::GraphError {
869 error,
870 state: state.clone(),
871 },
872 )
873 .await;
874 }
875
876 async fn send_graph_complete(
877 &self,
878 event_tx: &mpsc::Sender<GraphEvent>,
879 state: &State,
880 execution_log: &[ExecutionEntry],
881 start_time: Instant,
882 trace_id: TraceId,
883 ) {
884 self.emit_runtime(
885 event_tx,
886 RuntimeEvent::ExecutionCompleted {
887 trace_id,
888 duration: start_time.elapsed(),
889 },
890 )
891 .await;
892
893 let _ = self
894 .send(
895 event_tx,
896 GraphEvent::GraphComplete {
897 result: GraphResult {
898 trace_id,
899 state: state.clone(),
900 execution_log: execution_log.to_vec(),
901 duration: start_time.elapsed(),
902 },
903 },
904 )
905 .await;
906 }
907
908 async fn wait_barrier_decision(
911 &self,
912 decision_rx: &mut mpsc::Receiver<BarrierDecisionMessage>,
913 registry: &mut DecisionRegistry,
914 target_id: &BarrierId,
915 timeout: Option<std::time::Duration>,
916 cancel_rx: &mut mpsc::Receiver<()>,
917 ) -> BarrierDecision {
918 if let Some(decision) = registry.take(target_id) {
919 return decision;
920 }
921
922 while let Ok(msg) = decision_rx.try_recv() {
923 if let Some(decision) = registry.process_message(msg, target_id) {
924 return decision;
925 }
926 }
927
928 if cancel_rx.try_recv().is_ok() {
929 return BarrierDecision::Reject {
930 reason: "cancelled".into(),
931 };
932 }
933
934 let default_action = BarrierDefaultAction::Reject;
935
936 if let Some(timeout) = timeout {
937 let start = Instant::now();
938 loop {
939 match tokio::time::timeout(std::time::Duration::from_millis(50), decision_rx.recv())
940 .await
941 {
942 Ok(Some(msg)) => {
943 if let Some(decision) = registry.process_message(msg, target_id) {
944 return decision;
945 }
946 }
947 Ok(None) => return Self::default_decision(&default_action),
948 Err(_) => {}
949 }
950 if cancel_rx.try_recv().is_ok() {
951 return Self::default_decision(&default_action);
952 }
953 if start.elapsed() >= timeout {
954 return Self::default_decision(&default_action);
955 }
956 }
957 } else {
958 loop {
959 if let Some(msg) = decision_rx.recv().await {
960 if let Some(decision) = registry.process_message(msg, target_id) {
961 return decision;
962 }
963 } else {
964 return Self::default_decision(&default_action);
965 }
966 if cancel_rx.try_recv().is_ok() {
967 return Self::default_decision(&default_action);
968 }
969 }
970 }
971 }
972
973 fn default_decision(action: &BarrierDefaultAction) -> BarrierDecision {
974 match action {
975 BarrierDefaultAction::Approve => BarrierDecision::Approve,
976 BarrierDefaultAction::Reject => BarrierDecision::Reject {
977 reason: "timeout — no decision received".into(),
978 },
979 BarrierDefaultAction::Skip => BarrierDecision::Approve,
980 }
981 }
982
983 fn resolve_next(
986 &self,
987 graph: &Graph,
988 current: &str,
989 state: &State,
990 ) -> Result<String, GraphError> {
991 Self::find_next_node(graph, current, state)
992 }
993
994 fn find_next_node(graph: &Graph, current: &str, state: &State) -> Result<String, GraphError> {
995 let edges = graph.edges_from(current);
996
997 if edges.is_empty() {
998 return Err(GraphError::Terminal(TerminalError::InvalidGraph(format!(
999 "node '{}' has no outgoing edges and is not the end node",
1000 current
1001 ))));
1002 }
1003
1004 for edge in &edges {
1005 if edge.is_conditional() && edge.condition.as_ref().is_some_and(|c| c(state)) {
1006 return Ok(edge.to.clone());
1007 }
1008 }
1009
1010 for edge in &edges {
1011 if edge.is_normal() {
1012 return Ok(edge.to.clone());
1013 }
1014 }
1015
1016 for edge in &edges {
1017 if edge.fallback {
1018 return Ok(edge.to.clone());
1019 }
1020 }
1021
1022 let attempted: Vec<crate::error::ConditionEval> = edges
1023 .iter()
1024 .map(|e| crate::error::ConditionEval {
1025 edge: format!("{}→{}", e.from, e.to),
1026 condition: e.condition.as_ref().map(|_| "condition".to_string()),
1027 matched: e.condition.as_ref().is_some_and(|c| c(state)),
1028 })
1029 .collect();
1030
1031 Err(GraphError::Terminal(TerminalError::Unrouted {
1032 node: current.to_string(),
1033 attempted_conditions: attempted,
1034 }))
1035 }
1036
1037 fn should_checkpoint(&self, trigger: CheckpointPolicyTrigger) -> bool {
1041 match self.policy {
1042 CheckpointPolicy::EveryNode => true,
1043 CheckpointPolicy::BarrierOnly => matches!(
1044 trigger,
1045 CheckpointPolicyTrigger::BarrierResolved | CheckpointPolicyTrigger::GraphComplete
1046 ),
1047 CheckpointPolicy::Manual => false,
1048 }
1049 }
1050
1051 async fn save_checkpoint(
1053 &self,
1054 event_tx: &mpsc::Sender<GraphEvent>,
1055 trace_id: &TraceId,
1056 node_name: &str,
1057 state: &State,
1058 step: usize,
1059 ) {
1060 let store = match &self.store {
1061 Some(s) => s,
1062 None => return,
1063 };
1064
1065 let ck = Checkpoint::new(node_name, state.clone());
1066
1067 match store.save_with_trace(trace_id, &ck).await {
1068 Ok(()) => {
1069 let _ = self
1070 .send(
1071 event_tx,
1072 GraphEvent::CheckpointSaved {
1073 checkpoint_id: ck.checkpoint_id.clone(),
1074 node_name: node_name.to_string(),
1075 step,
1076 },
1077 )
1078 .await;
1079 }
1080 Err(e) => tracing::warn!(error = %e, "checkpoint save failed"),
1081 }
1082 }
1083
1084 pub async fn resume_from(
1087 &self,
1088 store: &dyn CheckpointStore,
1089 trace_id: &TraceId,
1090 graph: &Arc<Graph>,
1091 ) -> Result<GraphExecution, GraphError> {
1092 let checkpoint = store
1093 .load_latest(trace_id)
1094 .await
1095 .map_err(|e| {
1096 GraphError::Terminal(TerminalError::InvalidGraph(format!(
1097 "failed to load checkpoint: {}",
1098 e
1099 )))
1100 })?
1101 .ok_or_else(|| {
1102 GraphError::Terminal(TerminalError::InvalidGraph(format!(
1103 "no checkpoint found for trace {}",
1104 trace_id
1105 )))
1106 })?;
1107
1108 let initial_state = checkpoint.state.clone();
1109
1110 let resume_node = {
1111 let cn = checkpoint.current_node.0.as_str();
1112 if cn == "__complete__" || cn == graph.end_node() {
1113 tracing::warn!(
1114 trace_id = %trace_id,
1115 current_node = %cn,
1116 "checkpoint indicates graph was already complete; \
1117 resuming from start node. \
1118 Consider using an intermediate checkpoint for true recovery."
1119 );
1120 None
1121 } else if graph.nodes.contains_key(cn) {
1122 tracing::info!(
1123 trace_id = %trace_id,
1124 resume_node = %cn,
1125 "resuming from checkpoint node"
1126 );
1127 Some(cn.to_string())
1128 } else {
1129 tracing::warn!(
1130 trace_id = %trace_id,
1131 current_node = %cn,
1132 "checkpoint node not found in graph; resuming from start node"
1133 );
1134 None
1135 }
1136 };
1137
1138 let execution =
1139 self.execute_stream_with(graph.clone(), initial_state, resume_node, Some(*trace_id));
1140 Ok(execution)
1141 }
1142}
1143
/// Event that may cause a checkpoint, evaluated by `GraphExecutor::should_checkpoint`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CheckpointPolicyTrigger {
    /// A node finished successfully. `has_side_effects` comes from the node's metadata.
    /// NOTE(review): the checkpoint policy does not currently look at this flag.
    NodeExecuted { has_side_effects: bool },
    /// A barrier received its decision and routed onward.
    BarrierResolved,
    /// The graph reached its end.
    GraphComplete,
}