1use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::Instant;
9
10use tokio::sync::mpsc;
11
12use crate::barrier_node::BarrierDefaultAction;
13use crate::branch_state::BranchState;
14use crate::checkpoint::{Checkpoint, CheckpointPolicy, CheckpointStore, TraceId};
15use crate::delta::ReducerRegistry;
16use crate::error::{GraphError, TerminalError};
17use crate::event::{
18 BarrierDecision, BarrierDecisionMessage, BarrierId, FlowEvent, GraphEvent, GraphExecution,
19 GraphHandle,
20};
21use crate::graph::Graph;
22use crate::ids::SpanId;
23use crate::node::{FlowNode, NodeKind};
24use crate::node_context::{ExecutionSignal, NextAction, NodeContext, NodeMetadata};
25use crate::runtime_event::RuntimeEvent;
26use crate::state::{ExecutionEntry, GraphResult, State};
27use crate::stream_emitter::NoopSink;
28use crate::workflow_state::WorkflowState;
29use tokio_util::sync::CancellationToken;
30
31struct RunLoopContext {
35 graph: Arc<Graph>,
36 initial_state: State,
37 event_tx: mpsc::Sender<GraphEvent>,
38 decision_rx: mpsc::Receiver<BarrierDecisionMessage>,
39 cancel_rx: mpsc::Receiver<()>,
40 start_node: Option<String>,
41 trace_id: Option<TraceId>,
42}
43
44struct BarrierHandlerContext<'a> {
48 event_tx: &'a mpsc::Sender<GraphEvent>,
49 graph: &'a Graph,
50 decision_rx: &'a mut mpsc::Receiver<BarrierDecisionMessage>,
51 decision_registry: &'a mut DecisionRegistry,
52 cancel_rx: &'a mut mpsc::Receiver<()>,
53 node: &'a NodeKind,
54 current: &'a str,
55 state: &'a mut State,
56 execution_log: &'a mut Vec<ExecutionEntry>,
57 barrier_name: &'a str,
58 barrier_id: BarrierId,
59 timeout: Option<std::time::Duration>,
60 step: usize,
61 node_start: Instant,
62 trace_id: TraceId,
63}
64
65#[allow(dead_code)]
68struct DecisionRegistry {
69 pending: HashMap<BarrierId, BarrierDecision>,
70 wildcards: HashMap<String, BarrierDecision>,
71 occurrence_counter: HashMap<String, u32>,
72}
73
74impl DecisionRegistry {
75 fn new() -> Self {
76 Self {
77 pending: HashMap::new(),
78 wildcards: HashMap::new(),
79 occurrence_counter: HashMap::new(),
80 }
81 }
82
83 #[allow(dead_code)]
84 fn next_id(&mut self, node_id: &str) -> BarrierId {
85 let occ = self
86 .occurrence_counter
87 .entry(node_id.to_string())
88 .or_insert(0);
89 *occ += 1;
90 BarrierId::new(node_id, *occ)
91 }
92
93 fn take(&mut self, target_id: &BarrierId) -> Option<BarrierDecision> {
94 if let Some(decision) = self.pending.remove(target_id) {
95 return Some(decision);
96 }
97 self.wildcards.get(&target_id.node_id).cloned()
98 }
99
100 fn process_message(
101 &mut self,
102 msg: BarrierDecisionMessage,
103 target_id: &BarrierId,
104 ) -> Option<BarrierDecision> {
105 match msg {
106 BarrierDecisionMessage::Exact {
107 barrier_id,
108 decision,
109 } => {
110 if barrier_id == *target_id {
111 Some(decision)
112 } else {
113 self.pending.insert(barrier_id, decision);
114 None
115 }
116 }
117 BarrierDecisionMessage::Wildcard { node_id, decision } => {
118 self.wildcards.insert(node_id.clone(), decision.clone());
119 if node_id == target_id.node_id {
120 Some(decision)
121 } else {
122 None
123 }
124 }
125 }
126 }
127}
128
/// What the run loop should do once a step (or a barrier) has been handled.
#[derive(Debug)]
enum StepOutcome {
    /// Move on to the named node.
    Continue(String),
    /// Stop and report completion. The barrier handler also returns this when an event
    /// could not be sent because the receiver is gone.
    Break,
    /// Stop immediately; a `GraphError` has already been sent.
    ErrorSent,
}
137
138pub struct GraphExecutor {
142 pub max_steps: usize,
143 store: Option<Arc<dyn CheckpointStore>>,
144 policy: CheckpointPolicy,
145 graph_hash: String,
146 pending_reducers: Vec<(String, crate::delta::Reducer)>,
147}
148
149impl Clone for GraphExecutor {
150 fn clone(&self) -> Self {
151 Self {
152 max_steps: self.max_steps,
153 store: self.store.clone(),
154 policy: self.policy,
155 graph_hash: self.graph_hash.clone(),
156 pending_reducers: self.pending_reducers.clone(),
157 }
158 }
159}
160
161impl std::fmt::Debug for GraphExecutor {
162 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163 f.debug_struct("GraphExecutor")
164 .field("max_steps", &self.max_steps)
165 .field("has_store", &self.store.is_some())
166 .field("policy", &self.policy)
167 .finish()
168 }
169}
170
171impl Default for GraphExecutor {
172 fn default() -> Self {
173 Self {
174 max_steps: 50,
175 store: None,
176 policy: CheckpointPolicy::default(),
177 graph_hash: String::new(),
178 pending_reducers: Vec::new(),
179 }
180 }
181}
182
183impl GraphExecutor {
184 pub fn new(max_steps: usize) -> Self {
185 Self {
186 max_steps,
187 ..Default::default()
188 }
189 }
190
191 pub fn with_checkpoint(
192 max_steps: usize,
193 store: Arc<dyn CheckpointStore>,
194 policy: CheckpointPolicy,
195 graph: &Graph,
196 ) -> Self {
197 Self {
198 max_steps,
199 store: Some(store),
200 policy,
201 graph_hash: graph.hash(),
202 ..Default::default()
203 }
204 }
205
206 pub fn register_reducer(&mut self, key: &str, reducer: crate::delta::Reducer) {
207 self.pending_reducers.push((key.to_string(), reducer));
208 }
209
210 pub fn set_graph(&mut self, graph: &Graph) {
211 self.graph_hash = graph.hash();
212 }
213
214 pub async fn execute(
217 &self,
218 graph: Arc<Graph>,
219 initial_state: State,
220 ) -> Result<GraphResult, GraphError> {
221 let GraphExecution { mut stream, handle } = self.execute_stream(graph, initial_state);
222 drop(handle);
223
224 let mut result = None;
225 while let Some(event) = stream.recv().await {
226 match event {
227 GraphEvent::GraphComplete { result: r } => result = Some(Ok(r)),
228 GraphEvent::GraphError { error, .. } => result = Some(Err(error)),
229 _ => {}
230 }
231 }
232 result.unwrap_or_else(|| {
233 Err(GraphError::Terminal(TerminalError::InvalidGraph(
234 "stream ended without completion".into(),
235 )))
236 })
237 }
238
239 pub fn execute_stream(&self, graph: Arc<Graph>, initial_state: State) -> GraphExecution {
242 let executor = self.clone();
243 let (event_tx, event_rx) = mpsc::channel(32);
244 let (decision_tx, decision_rx) = mpsc::channel(16);
245 let (cancel_tx, cancel_rx) = mpsc::channel(1);
246
247 let handle = GraphHandle::new(decision_tx, cancel_tx);
248
249 tokio::spawn(async move {
250 executor
251 .run_loop(RunLoopContext {
252 graph,
253 initial_state,
254 event_tx,
255 decision_rx,
256 cancel_rx,
257 start_node: None,
258 trace_id: None,
259 })
260 .await;
261 });
262
263 GraphExecution {
264 stream: event_rx,
265 handle,
266 }
267 }
268
269 fn execute_stream_with(
271 &self,
272 graph: Arc<Graph>,
273 initial_state: State,
274 start_node: Option<String>,
275 trace_id: Option<TraceId>,
276 ) -> GraphExecution {
277 let executor = self.clone();
278 let (event_tx, event_rx) = mpsc::channel(32);
279 let (decision_tx, decision_rx) = mpsc::channel(16);
280 let (cancel_tx, cancel_rx) = mpsc::channel(1);
281
282 let handle = GraphHandle::new(decision_tx, cancel_tx);
283
284 tokio::spawn(async move {
285 executor
286 .run_loop(RunLoopContext {
287 graph,
288 initial_state,
289 event_tx,
290 decision_rx,
291 cancel_rx,
292 start_node,
293 trace_id,
294 })
295 .await;
296 });
297
298 GraphExecution {
299 stream: event_rx,
300 handle,
301 }
302 }
303
304 async fn run_loop(&self, ctx: RunLoopContext) {
307 let RunLoopContext {
308 graph,
309 initial_state,
310 event_tx,
311 mut decision_rx,
312 mut cancel_rx,
313 start_node,
314 trace_id,
315 } = ctx;
316 let start_time = Instant::now();
317 let mut state = initial_state;
318 let mut execution_log = Vec::new();
319 let mut decision_registry = DecisionRegistry::new();
320 let mut _reducer_registry = ReducerRegistry::new();
321
322 for (key, reducer) in &self.pending_reducers {
323 _reducer_registry.register(key, *reducer);
324 }
325
326 let mut current = start_node.unwrap_or_else(|| graph.start_node().to_string());
327 let mut step: usize = 0;
328 let trace_id = trace_id.unwrap_or_default();
329
330 self.emit_runtime(
332 &event_tx,
333 RuntimeEvent::ExecutionStarted {
334 trace_id,
335 graph_name: graph.name().to_string(),
336 },
337 )
338 .await;
339 if self
340 .send(&event_tx, GraphEvent::GraphStart { trace_id })
341 .await
342 {
343 return;
344 }
345
346 loop {
347 if cancel_rx.try_recv().is_ok() {
349 self.send_graph_error(
350 &event_tx,
351 GraphError::Terminal(TerminalError::BarrierCancelled {
352 node: "execution cancelled by handle".into(),
353 }),
354 &state,
355 start_time,
356 trace_id,
357 )
358 .await;
359 return;
360 }
361
362 step += 1;
363
364 if step > self.max_steps {
366 self.send_graph_error(
367 &event_tx,
368 GraphError::Terminal(TerminalError::StepsExceeded {
369 limit: self.max_steps,
370 }),
371 &state,
372 start_time,
373 trace_id,
374 )
375 .await;
376 return;
377 }
378
379 let node = match graph.nodes.get(¤t) {
381 Some(n) => n,
382 None => {
383 self.send_graph_error(
384 &event_tx,
385 GraphError::Terminal(TerminalError::NodeNotFound(current.clone())),
386 &state,
387 start_time,
388 trace_id,
389 )
390 .await;
391 return;
392 }
393 };
394
395 let node_name = current.clone();
396 let span_id = SpanId::new();
397
398 self.emit_runtime(
400 &event_tx,
401 RuntimeEvent::NodeStarted {
402 node_name: node_name.clone(),
403 trace_id,
404 span_id,
405 step,
406 },
407 )
408 .await;
409 if self
410 .send(
411 &event_tx,
412 GraphEvent::NodeStart {
413 node_name: node_name.clone(),
414 trace_id,
415 span_id,
416 step,
417 },
418 )
419 .await
420 {
421 return;
422 }
423
424 let node_start = Instant::now();
425
426 let exec_result = self
428 .execute_node(node, &mut state, &node_name, span_id)
429 .await;
430 let node_end = Instant::now();
431 let duration = node_end.duration_since(node_start);
432
433 match exec_result {
434 Ok((next_action, signal, metadata, flow_events)) => {
435 execution_log.push(ExecutionEntry {
437 step,
438 node_name: node_name.clone(),
439 start_time: node_start,
440 end_time: node_end,
441 success: true,
442 error: None,
443 });
444
445 for flow_event in flow_events {
447 if self
448 .send(
449 &event_tx,
450 GraphEvent::Node {
451 span_id,
452 node_name: node_name.clone(),
453 event: flow_event,
454 },
455 )
456 .await
457 {
458 return;
459 }
460 }
461
462 if self.should_checkpoint(CheckpointPolicyTrigger::NodeExecuted {
464 has_side_effects: metadata.has_side_effects,
465 }) {
466 self.save_checkpoint(&event_tx, &trace_id, ¤t, &state, step)
467 .await;
468 }
469
470 self.emit_runtime(
472 &event_tx,
473 RuntimeEvent::NodeCompleted {
474 node_name: node_name.clone(),
475 trace_id,
476 span_id,
477 duration,
478 },
479 )
480 .await;
481 if self
482 .send(
483 &event_tx,
484 GraphEvent::NodeEnd {
485 node_name: node_name.clone(),
486 trace_id,
487 span_id,
488 success: true,
489 duration,
490 },
491 )
492 .await
493 {
494 return;
495 }
496
497 if let Some(signal) = signal {
499 match signal {
500 ExecutionSignal::Pause {
501 barrier_id,
502 timeout,
503 } => {
504 let outcome = self
505 .handle_barrier_signal(BarrierHandlerContext {
506 event_tx: &event_tx,
507 graph: &graph,
508 decision_rx: &mut decision_rx,
509 decision_registry: &mut decision_registry,
510 cancel_rx: &mut cancel_rx,
511 node,
512 current: ¤t,
513 state: &mut state,
514 execution_log: &mut execution_log,
515 barrier_name: &node_name,
516 barrier_id,
517 timeout,
518 step,
519 node_start,
520 trace_id,
521 })
522 .await;
523 match outcome {
524 StepOutcome::Continue(target) => {
525 if self.should_checkpoint(
527 CheckpointPolicyTrigger::BarrierResolved,
528 ) {
529 self.save_checkpoint(
530 &event_tx, &trace_id, &target, &state, step,
531 )
532 .await;
533 }
534 current = target;
535 }
536 StepOutcome::Break => {
537 self.send_graph_complete(
538 &event_tx,
539 &state,
540 &execution_log,
541 start_time,
542 trace_id,
543 )
544 .await;
545 return;
546 }
547 StepOutcome::ErrorSent => return,
548 }
549 continue;
550 }
551 }
552 }
553
554 let outcome = match next_action {
556 NextAction::End => StepOutcome::Break,
557 NextAction::Goto(target) => StepOutcome::Continue(target),
558 NextAction::Next => {
559 if current == graph.end_node() {
561 StepOutcome::Break
562 } else {
563 match self.resolve_next(&graph, ¤t, &state) {
564 Ok(target) => StepOutcome::Continue(target),
565 Err(e) => {
566 self.send_graph_error(
567 &event_tx, e, &state, start_time, trace_id,
568 )
569 .await;
570 StepOutcome::ErrorSent
571 }
572 }
573 }
574 }
575 };
576
577 match outcome {
578 StepOutcome::Continue(target) => {
579 current = target;
580 }
581 StepOutcome::Break => {
582 if self.should_checkpoint(CheckpointPolicyTrigger::GraphComplete) {
584 self.save_checkpoint(
585 &event_tx,
586 &trace_id,
587 "__complete__",
588 &state,
589 step,
590 )
591 .await;
592 }
593
594 self.send_graph_complete(
595 &event_tx,
596 &state,
597 &execution_log,
598 start_time,
599 trace_id,
600 )
601 .await;
602 return;
603 }
604 StepOutcome::ErrorSent => return,
605 }
606 }
607 Err(e) => {
608 let error_str = e.to_string();
610 execution_log.push(ExecutionEntry {
611 step,
612 node_name: node_name.clone(),
613 start_time: node_start,
614 end_time: node_end,
615 success: false,
616 error: Some(error_str.clone()),
617 });
618
619 self.emit_runtime(
620 &event_tx,
621 RuntimeEvent::NodeFailed {
622 node_name: node_name.clone(),
623 trace_id,
624 span_id,
625 error: e.to_string(),
626 },
627 )
628 .await;
629 if self
630 .send(
631 &event_tx,
632 GraphEvent::NodeEnd {
633 node_name: node_name.clone(),
634 trace_id,
635 span_id,
636 success: false,
637 duration,
638 },
639 )
640 .await
641 {
642 return;
643 }
644
645 self.send_graph_error(&event_tx, e, &state, start_time, trace_id)
646 .await;
647 return;
648 }
649 }
650 }
651 }
652
653 async fn execute_node(
657 &self,
658 node: &NodeKind,
659 state: &mut State,
660 _node_name: &str,
661 _span_id: SpanId,
662 ) -> Result<
663 (
664 NextAction,
665 Option<ExecutionSignal>,
666 NodeMetadata,
667 Vec<FlowEvent>,
668 ),
669 GraphError,
670 > {
671 let mut branch = BranchState::from_state(state.clone());
672 let cancel = CancellationToken::new();
673 let noop_sink = NoopSink;
674 let mut ctx = NodeContext::new(state, &mut branch, Some(&noop_sink), cancel);
675
676 node.execute(&mut ctx).await?;
677
678 let effects = ctx.consume_effects();
679 let (next_action, signal) = ctx.take_control();
680 let metadata = ctx.take_metadata();
681 let flow_events = ctx.take_flow_events();
682
683 state.apply_batch(effects);
684
685 Ok((next_action, signal, metadata, flow_events))
686 }
687
688 async fn handle_barrier_signal(&self, ctx: BarrierHandlerContext<'_>) -> StepOutcome {
691 let BarrierHandlerContext {
692 event_tx,
693 graph,
694 decision_rx,
695 decision_registry,
696 cancel_rx,
697 node,
698 current,
699 state,
700 execution_log,
701 barrier_name,
702 barrier_id,
703 timeout,
704 step,
705 node_start,
706 trace_id,
707 } = ctx;
708 self.emit_runtime(
709 event_tx,
710 RuntimeEvent::BarrierWaiting {
711 barrier_id: barrier_id.clone(),
712 node_name: barrier_name.to_string(),
713 span_id: SpanId::new(),
714 },
715 )
716 .await;
717 if self
718 .send(
719 event_tx,
720 GraphEvent::BarrierWaiting {
721 barrier_id: barrier_id.clone(),
722 node_name: barrier_name.to_string(),
723 span_id: SpanId::new(),
724 },
725 )
726 .await
727 {
728 return StepOutcome::Break;
729 }
730
731 let decision = self
732 .wait_barrier_decision(
733 decision_rx,
734 decision_registry,
735 &barrier_id,
736 timeout,
737 cancel_rx,
738 )
739 .await;
740
741 if cancel_rx.try_recv().is_ok() {
742 self.send_graph_error(
743 event_tx,
744 GraphError::Terminal(TerminalError::BarrierCancelled {
745 node: barrier_name.to_string(),
746 }),
747 state,
748 node_start,
749 trace_id,
750 )
751 .await;
752 return StepOutcome::ErrorSent;
753 }
754
755 self.emit_runtime(
756 event_tx,
757 RuntimeEvent::BarrierResolved {
758 barrier_id: barrier_id.clone(),
759 },
760 )
761 .await;
762 if self
763 .send(
764 event_tx,
765 GraphEvent::BarrierResolved {
766 barrier_id: barrier_id.clone(),
767 decision: decision.clone(),
768 },
769 )
770 .await
771 {
772 return StepOutcome::Break;
773 }
774
775 match node {
776 NodeKind::Barrier(b) => {
777 let mut branch = BranchState::from_state(state.clone());
778 let cancel = CancellationToken::new();
779 let mut ctx = NodeContext::new(state, &mut branch, None, cancel);
780 b.apply_decision_to_ctx(&mut ctx, decision);
781 let (next, _signal) = ctx.take_control();
782
783 let effects = ctx.consume_effects();
784 state.apply_batch(effects);
785
786 let end_time = Instant::now();
787 execution_log.push(ExecutionEntry {
788 step,
789 node_name: barrier_name.to_string(),
790 start_time: node_start,
791 end_time,
792 success: true,
793 error: None,
794 });
795
796 if self
797 .send(
798 event_tx,
799 GraphEvent::NodeEnd {
800 node_name: barrier_name.to_string(),
801 trace_id,
802 span_id: SpanId::new(),
803 success: true,
804 duration: end_time.duration_since(node_start),
805 },
806 )
807 .await
808 {
809 return StepOutcome::Break;
810 }
811
812 if current == graph.end_node() {
813 return StepOutcome::Break;
814 }
815
816 match next {
817 NextAction::End => StepOutcome::Break,
818 NextAction::Goto(target) => StepOutcome::Continue(target),
819 NextAction::Next => match self.resolve_next(graph, current, state) {
820 Ok(target) => StepOutcome::Continue(target),
821 Err(e) => {
822 self.send_graph_error(event_tx, e, state, end_time, trace_id)
823 .await;
824 StepOutcome::ErrorSent
825 }
826 },
827 }
828 }
829 _ => {
830 self.send_graph_error(
831 event_tx,
832 GraphError::Terminal(TerminalError::InvalidGraph(
833 "expected BarrierNode for pause signal".into(),
834 )),
835 state,
836 node_start,
837 trace_id,
838 )
839 .await;
840 StepOutcome::ErrorSent
841 }
842 }
843 }
844
845 async fn emit_runtime(
848 &self,
849 _event_tx: &mpsc::Sender<GraphEvent>,
850 runtime_event: RuntimeEvent,
851 ) {
852 tracing::debug!(?runtime_event, "runtime event");
853 }
854
855 async fn send(&self, event_tx: &mpsc::Sender<GraphEvent>, event: GraphEvent) -> bool {
856 event_tx.send(event).await.is_err()
857 }
858
859 async fn send_graph_error(
860 &self,
861 event_tx: &mpsc::Sender<GraphEvent>,
862 error: GraphError,
863 state: &State,
864 _start_time: Instant,
865 _trace_id: TraceId,
866 ) {
867 let _ = self
868 .send(
869 event_tx,
870 GraphEvent::GraphError {
871 error,
872 state: state.clone(),
873 },
874 )
875 .await;
876 }
877
878 async fn send_graph_complete(
879 &self,
880 event_tx: &mpsc::Sender<GraphEvent>,
881 state: &State,
882 execution_log: &[ExecutionEntry],
883 start_time: Instant,
884 trace_id: TraceId,
885 ) {
886 self.emit_runtime(
887 event_tx,
888 RuntimeEvent::ExecutionCompleted {
889 trace_id,
890 duration: start_time.elapsed(),
891 },
892 )
893 .await;
894
895 let _ = self
896 .send(
897 event_tx,
898 GraphEvent::GraphComplete {
899 result: GraphResult {
900 trace_id,
901 state: state.clone(),
902 execution_log: execution_log.to_vec(),
903 duration: start_time.elapsed(),
904 },
905 },
906 )
907 .await;
908 }
909
910 async fn wait_barrier_decision(
913 &self,
914 decision_rx: &mut mpsc::Receiver<BarrierDecisionMessage>,
915 registry: &mut DecisionRegistry,
916 target_id: &BarrierId,
917 timeout: Option<std::time::Duration>,
918 cancel_rx: &mut mpsc::Receiver<()>,
919 ) -> BarrierDecision {
920 if let Some(decision) = registry.take(target_id) {
921 return decision;
922 }
923
924 while let Ok(msg) = decision_rx.try_recv() {
925 if let Some(decision) = registry.process_message(msg, target_id) {
926 return decision;
927 }
928 }
929
930 if cancel_rx.try_recv().is_ok() {
931 return BarrierDecision::Reject {
932 reason: "cancelled".into(),
933 };
934 }
935
936 let default_action = BarrierDefaultAction::Reject;
937
938 if let Some(timeout) = timeout {
939 let start = Instant::now();
940 loop {
941 match tokio::time::timeout(std::time::Duration::from_millis(50), decision_rx.recv())
942 .await
943 {
944 Ok(Some(msg)) => {
945 if let Some(decision) = registry.process_message(msg, target_id) {
946 return decision;
947 }
948 }
949 Ok(None) => return Self::default_decision(&default_action),
950 Err(_) => {}
951 }
952 if cancel_rx.try_recv().is_ok() {
953 return Self::default_decision(&default_action);
954 }
955 if start.elapsed() >= timeout {
956 return Self::default_decision(&default_action);
957 }
958 }
959 } else {
960 loop {
961 if let Some(msg) = decision_rx.recv().await {
962 if let Some(decision) = registry.process_message(msg, target_id) {
963 return decision;
964 }
965 } else {
966 return Self::default_decision(&default_action);
967 }
968 if cancel_rx.try_recv().is_ok() {
969 return Self::default_decision(&default_action);
970 }
971 }
972 }
973 }
974
975 fn default_decision(action: &BarrierDefaultAction) -> BarrierDecision {
976 match action {
977 BarrierDefaultAction::Approve => BarrierDecision::Approve,
978 BarrierDefaultAction::Reject => BarrierDecision::Reject {
979 reason: "timeout — no decision received".into(),
980 },
981 BarrierDefaultAction::Skip => BarrierDecision::Approve,
982 }
983 }
984
985 fn resolve_next(
988 &self,
989 graph: &Graph,
990 current: &str,
991 state: &State,
992 ) -> Result<String, GraphError> {
993 Self::find_next_node(graph, current, state)
994 }
995
996 fn find_next_node(graph: &Graph, current: &str, state: &State) -> Result<String, GraphError> {
997 let edges = graph.edges_from(current);
998
999 if edges.is_empty() {
1000 return Err(GraphError::Terminal(TerminalError::InvalidGraph(format!(
1001 "node '{}' has no outgoing edges and is not the end node",
1002 current
1003 ))));
1004 }
1005
1006 for edge in &edges {
1007 if edge.is_conditional() && edge.condition.as_ref().is_some_and(|c| c(state)) {
1008 return Ok(edge.to.clone());
1009 }
1010 }
1011
1012 for edge in &edges {
1013 if edge.is_normal() {
1014 return Ok(edge.to.clone());
1015 }
1016 }
1017
1018 for edge in &edges {
1019 if edge.fallback {
1020 return Ok(edge.to.clone());
1021 }
1022 }
1023
1024 let attempted: Vec<crate::error::ConditionEval> = edges
1025 .iter()
1026 .map(|e| crate::error::ConditionEval {
1027 edge: format!("{}→{}", e.from, e.to),
1028 condition: e.condition.as_ref().map(|_| "condition".to_string()),
1029 matched: e.condition.as_ref().is_some_and(|c| c(state)),
1030 })
1031 .collect();
1032
1033 Err(GraphError::Terminal(TerminalError::Unrouted {
1034 node: current.to_string(),
1035 attempted_conditions: attempted,
1036 }))
1037 }
1038
1039 fn should_checkpoint(&self, trigger: CheckpointPolicyTrigger) -> bool {
1043 match self.policy {
1044 CheckpointPolicy::EveryNode => true,
1045 CheckpointPolicy::BarrierOnly => matches!(
1046 trigger,
1047 CheckpointPolicyTrigger::BarrierResolved | CheckpointPolicyTrigger::GraphComplete
1048 ),
1049 CheckpointPolicy::Manual => false,
1050 }
1051 }
1052
1053 async fn save_checkpoint(
1055 &self,
1056 event_tx: &mpsc::Sender<GraphEvent>,
1057 trace_id: &TraceId,
1058 node_name: &str,
1059 state: &State,
1060 step: usize,
1061 ) {
1062 let store = match &self.store {
1063 Some(s) => s,
1064 None => return,
1065 };
1066
1067 let ck = Checkpoint::new(node_name, state.clone());
1068
1069 match store.save_with_trace(trace_id, &ck).await {
1070 Ok(()) => {
1071 let _ = self
1072 .send(
1073 event_tx,
1074 GraphEvent::CheckpointSaved {
1075 checkpoint_id: ck.checkpoint_id.clone(),
1076 node_name: node_name.to_string(),
1077 step,
1078 },
1079 )
1080 .await;
1081 }
1082 Err(e) => tracing::warn!(error = %e, "checkpoint save failed"),
1083 }
1084 }
1085
1086 pub async fn resume_from(
1089 &self,
1090 store: &dyn CheckpointStore,
1091 trace_id: &TraceId,
1092 graph: &Arc<Graph>,
1093 ) -> Result<GraphExecution, GraphError> {
1094 let checkpoint = store
1095 .load_latest(trace_id)
1096 .await
1097 .map_err(|e| {
1098 GraphError::Terminal(TerminalError::InvalidGraph(format!(
1099 "failed to load checkpoint: {}",
1100 e
1101 )))
1102 })?
1103 .ok_or_else(|| {
1104 GraphError::Terminal(TerminalError::InvalidGraph(format!(
1105 "no checkpoint found for trace {}",
1106 trace_id
1107 )))
1108 })?;
1109
1110 let initial_state = checkpoint.state.clone();
1111
1112 let resume_node = {
1113 let cn = checkpoint.current_node.0.as_str();
1114 if cn == "__complete__" || cn == graph.end_node() {
1115 tracing::warn!(
1116 trace_id = %trace_id,
1117 current_node = %cn,
1118 "checkpoint indicates graph was already complete; \
1119 resuming from start node. \
1120 Consider using an intermediate checkpoint for true recovery."
1121 );
1122 None
1123 } else if graph.nodes.contains_key(cn) {
1124 tracing::info!(
1125 trace_id = %trace_id,
1126 resume_node = %cn,
1127 "resuming from checkpoint node"
1128 );
1129 Some(cn.to_string())
1130 } else {
1131 tracing::warn!(
1132 trace_id = %trace_id,
1133 current_node = %cn,
1134 "checkpoint node not found in graph; resuming from start node"
1135 );
1136 None
1137 }
1138 };
1139
1140 let execution =
1141 self.execute_stream_with(graph.clone(), initial_state, resume_node, Some(*trace_id));
1142 Ok(execution)
1143 }
1144}
1145
/// Moments at which a checkpoint may be taken; `GraphExecutor::should_checkpoint` decides
/// which of them actually produce one.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CheckpointPolicyTrigger {
    /// A node finished successfully; the flag is copied from the node's metadata.
    NodeExecuted { has_side_effects: bool },
    /// A barrier decision was applied and execution moves on to another node.
    BarrierResolved,
    /// The graph reached completion.
    GraphComplete,
}