1use std::sync::Arc;
14use std::time::Instant;
15
16use serde::Serialize;
17use tokio_util::sync::CancellationToken;
18
19use crate::barrier_wait::{
20 BarrierOutcome, apply_barrier_decision_generic, wait_for_barrier_decision,
21};
22use crate::checkpoint::{Checkpoint, NodeId, TraceId};
23use crate::checkpoint_policy::{RetentionPolicy, TriggerPolicy};
24use crate::error::GraphError;
25use crate::event::{BarrierDecision, BarrierDecisionMessage, GraphEvent};
26use crate::execution_engine::{ExecutionEngine, ExecutionSignal, ExecutorState, NextAction};
27use crate::graph::Graph;
28use crate::ids::SpanId;
29use crate::node::{BarrierNode, ConditionNode, FlowNode, LeafNode, NodeKind};
30use crate::state::{ExecutionEntry, GraphResult};
31use crate::trace::{MemoryTraceSink, TraceSink, TraceStep};
32use crate::workflow_state::{MergeStrategy, WorkflowState};
33
34type CheckpointSaveFn<S> = Box<
40 dyn Fn(
41 Checkpoint<S>,
42 TraceId,
43 ) -> std::pin::Pin<
44 Box<
45 dyn std::future::Future<
46 Output = Result<(), crate::checkpoint::CheckpointStoreError>,
47 > + Send,
48 >,
49 > + Send
50 + Sync,
51>;
52
53pub struct CheckpointConfig<S: WorkflowState> {
63 pub trigger: TriggerPolicy,
65 pub retention: RetentionPolicy,
67 save_fn: CheckpointSaveFn<S>,
69 graph_hash: u64,
71 store: Option<std::sync::Arc<dyn crate::store::BlobCheckpointStore>>,
73}
74
75impl<S: WorkflowState> CheckpointConfig<S> {
76 pub fn new(
81 save_fn: impl Fn(
82 Checkpoint<S>,
83 TraceId,
84 ) -> std::pin::Pin<
85 Box<
86 dyn std::future::Future<
87 Output = Result<(), crate::checkpoint::CheckpointStoreError>,
88 > + Send,
89 >,
90 > + Send
91 + Sync
92 + 'static,
93 graph_hash: u64,
94 ) -> Self {
95 Self {
96 save_fn: Box::new(save_fn),
97 trigger: TriggerPolicy::default(),
98 retention: RetentionPolicy::default(),
99 graph_hash,
100 store: None,
101 }
102 }
103
104 pub fn with_trigger(mut self, trigger: TriggerPolicy) -> Self {
106 self.trigger = trigger;
107 self
108 }
109
110 pub fn with_retention(mut self, retention: RetentionPolicy) -> Self {
112 self.retention = retention;
113 self
114 }
115
116 pub fn with_store(
118 mut self,
119 store: std::sync::Arc<dyn crate::store::BlobCheckpointStore>,
120 ) -> Self {
121 self.store = Some(store);
122 self
123 }
124
125 #[allow(deprecated)]
127 pub fn with_policy(mut self, policy: crate::checkpoint::CheckpointPolicy) -> Self {
128 self.trigger = policy.into();
129 self
130 }
131
132 pub fn should_save(&self, has_mutations: bool, is_barrier: bool) -> bool {
134 match self.trigger {
135 TriggerPolicy::EveryNode => true,
136 TriggerPolicy::BarrierOnly => is_barrier,
137 TriggerPolicy::Manual => false,
138 TriggerPolicy::OnMutation => has_mutations,
139 }
140 }
141
142 pub async fn apply_retention(
144 &self,
145 trace_id: &TraceId,
146 ) -> Result<(), crate::checkpoint::CheckpointStoreError> {
147 if let Some(keep) = self.retention.prune_keep() {
148 if let Some(ref store) = self.store {
149 let pruned = store.prune(trace_id, keep).await?;
150 if pruned > 0 {
151 tracing::debug!(pruned, keep, "checkpoint pruned");
152 }
153 }
154 }
155 Ok(())
156 }
157}
158
159pub(crate) async fn run_execution_loop<S, M>(
175 graph: Arc<Graph<S, M>>,
176 state: S,
177 max_steps: usize,
178 trace_id: TraceId,
179 event_tx: tokio::sync::mpsc::Sender<GraphEvent<S>>,
180 mut decision_rx: tokio::sync::mpsc::Receiver<BarrierDecisionMessage>,
181 mut cancel_rx: tokio::sync::mpsc::Receiver<()>,
182 cancel: CancellationToken,
183 checkpoint: Option<CheckpointConfig<S>>,
184 trace_sink: Option<MemoryTraceSink<S::Mutation>>,
185 restore_from: Option<Checkpoint<S>>,
186) where
187 S: WorkflowState + Clone + Send + Sync + Serialize + 'static,
188 S::Mutation: Clone + Send + Sync,
189 M: MergeStrategy<S>,
190{
191 let start_time = Instant::now();
192 let mut execution_log: Vec<ExecutionEntry> = Vec::new();
193
194 let restore_state = restore_from.as_ref().map(|cp| S::restore(cp.state.clone()));
197 let mut engine_state = restore_state.unwrap_or(state);
198 let mut engine = ExecutionEngine::new(&mut engine_state, None, cancel.clone());
199 let mut current = if let Some(ref cp) = restore_from {
200 cp.current_node.0.clone()
201 } else {
202 graph.start_node().to_string()
203 };
204 let mut step: usize = 0;
205 let mut wildcard_cache = std::collections::HashMap::new();
207 let mut trace_sink = trace_sink;
209
210 let _ = event_tx.send(GraphEvent::GraphStart { trace_id }).await;
211
212 let mut skip_first_execution = false;
216 if restore_from.is_some() {
217 if let Some(node) = graph.nodes.get(¤t) {
218 if matches!(node, NodeKind::Barrier(_)) {
219 let span_id = SpanId::new();
220 let barrier_id = crate::event::BarrierId::new(¤t, 0);
221
222 let _ = event_tx
223 .send(GraphEvent::BarrierWaiting {
224 barrier_id: barrier_id.clone(),
225 node_name: current.clone(),
226 span_id,
227 })
228 .await;
229
230 let barrier_timeout = if let NodeKind::Barrier(bn) = node {
232 bn.timeout
233 } else {
234 None
235 };
236
237 let outcome = wait_for_barrier_decision(
238 &mut decision_rx,
239 &mut cancel_rx,
240 &cancel,
241 &barrier_id,
242 barrier_timeout,
243 &mut wildcard_cache,
244 )
245 .await;
246
247 match outcome {
248 BarrierOutcome::Decision(d) => {
249 let _ = event_tx
250 .send(GraphEvent::BarrierResolved {
251 barrier_id,
252 decision: d.clone(),
253 })
254 .await;
255
256 let reroute_target =
257 apply_barrier_decision_generic(engine.state_mut(), ¤t, &d);
258
259 if let Some(target) = reroute_target {
260 current = target;
261 }
262 skip_first_execution = true;
264 }
265 BarrierOutcome::TimedOut => {
266 let _ = event_tx
267 .send(GraphEvent::BarrierResolved {
268 barrier_id,
269 decision: BarrierDecision::Reject {
270 reason: "timeout on restore".into(),
271 },
272 })
273 .await;
274 apply_barrier_decision_generic(
275 engine.state_mut(),
276 ¤t,
277 &BarrierDecision::Reject {
278 reason: "timeout on restore".into(),
279 },
280 );
281 skip_first_execution = true;
282 }
283 BarrierOutcome::Cancelled => {
284 let _ = event_tx
285 .send(GraphEvent::GraphError {
286 error: GraphError::Terminal(
287 crate::error::TerminalError::BarrierCancelled {
288 node: current.clone(),
289 },
290 ),
291 state: engine.state().clone(),
292 })
293 .await;
294 return;
295 }
296 }
297 }
298 }
299 }
300
301 loop {
302 if step == 0 && skip_first_execution {
304 match graph.resolve_next_inline(¤t, engine.state()) {
306 Ok(target) => current = target,
307 Err(_) => {
308 if current == graph.end_node() {
310 send_complete(
311 &event_tx,
312 trace_id,
313 engine.state(),
314 execution_log,
315 start_time,
316 trace_sink.take(),
317 );
318 break;
319 }
320 }
323 }
324 step += 1;
325 continue;
326 }
327 if cancel.is_cancelled() {
328 let _ = event_tx
329 .send(GraphEvent::GraphError {
330 error: GraphError::Terminal(crate::error::TerminalError::BarrierCancelled {
331 node: "execution cancelled".into(),
332 }),
333 state: engine.state().clone(),
334 })
335 .await;
336 break;
337 }
338
339 step += 1;
340 if step > max_steps {
341 let _ = event_tx
342 .send(GraphEvent::GraphError {
343 error: GraphError::Terminal(crate::error::TerminalError::StepsExceeded {
344 limit: max_steps,
345 }),
346 state: engine.state().clone(),
347 })
348 .await;
349 break;
350 }
351
352 let node = match graph.nodes.get(¤t) {
353 Some(n) => n,
354 None => {
355 let _ = event_tx
356 .send(GraphEvent::GraphError {
357 error: GraphError::Terminal(crate::error::TerminalError::NodeNotFound(
358 current.clone(),
359 )),
360 state: engine.state().clone(),
361 })
362 .await;
363 break;
364 }
365 };
366
367 let node_name = current.clone();
368 let span_id = SpanId::new();
369 let node_start = Instant::now();
370
371 let _ = event_tx
372 .send(GraphEvent::NodeStart {
373 node_name: node_name.clone(),
374 trace_id,
375 span_id,
376 step,
377 })
378 .await;
379
380 let is_barrier = matches!(node, NodeKind::Barrier(_));
382 let node_ok = match node {
383 NodeKind::Task(n) => {
384 let mut ctx = engine.build_node_context();
385 n.execute(&mut ctx).await.is_ok()
386 }
387 NodeKind::Condition(n) => {
388 let mut ctx = engine.build_leaf_context();
389 <ConditionNode<S> as LeafNode<S>>::execute(n, &mut ctx)
390 .await
391 .is_ok()
392 }
393 NodeKind::Barrier(n) => {
394 let mut ctx = engine.build_leaf_context();
395 <BarrierNode<S> as LeafNode<S>>::execute(n, &mut ctx)
396 .await
397 .is_ok()
398 }
399 NodeKind::External(n) => {
400 let mut ctx = engine.build_node_context();
401 n.execute(&mut ctx).await.is_ok()
402 }
403 NodeKind::ExternalLeaf(n) => {
404 let mut ctx = engine.build_leaf_context();
405 n.execute(&mut ctx).await.is_ok()
406 }
407 NodeKind::Parallel(p) => p.execute(&mut engine).await.is_ok(),
408 NodeKind::Subgraph(subgraph) => {
409 let stream = engine.stream_sink();
411 let cancel = engine.cancel_token().clone();
412 subgraph
413 .execute(engine.state_mut(), stream, cancel)
414 .await
415 .is_ok()
416 }
417 };
418
419 if !node_ok {
420 let _ = event_tx
421 .send(GraphEvent::NodeEnd {
422 node_name: node_name.clone(),
423 trace_id,
424 span_id,
425 success: false,
426 duration: node_start.elapsed(),
427 })
428 .await;
429
430 let _ = event_tx
431 .send(GraphEvent::GraphError {
432 error: GraphError::Terminal(crate::error::TerminalError::NodeExecutionFailed {
433 node: node_name,
434 source: "node execution failed".into(),
435 }),
436 state: engine.state().clone(),
437 })
438 .await;
439 break;
440 }
441
442 let flow_events = engine.take_flow_events();
444 for fe in flow_events {
445 let _ = event_tx
446 .send(GraphEvent::Node {
447 span_id,
448 node_name: node_name.clone(),
449 event: fe,
450 })
451 .await;
452 }
453
454 let commit_batch = engine.take_commit_batch();
457 let has_mutations = !commit_batch.is_empty();
458
459 if let Some(ref mut sink) = trace_sink {
461 if !commit_batch.is_empty() {
462 sink.record_step(TraceStep {
463 step,
464 node_id: NodeId(node_name.clone()),
465 mutations: commit_batch.clone(),
466 });
467 }
468 }
469
470 engine.apply_batch_to_state(commit_batch);
471
472 let node_duration = node_start.elapsed();
473 execution_log.push(ExecutionEntry {
474 step,
475 node_name: node_name.clone(),
476 start_time,
477 end_time: start_time.checked_add(node_duration).unwrap_or(start_time),
478 success: true,
479 error: None,
480 });
481
482 let _ = event_tx
483 .send(GraphEvent::NodeEnd {
484 node_name: node_name.clone(),
485 trace_id,
486 span_id,
487 success: true,
488 duration: node_duration,
489 })
490 .await;
491
492 let (next_action, signal) = engine.take_control();
494
495 let mut next_action = next_action;
497
498 if let Some(ExecutionSignal::Pause {
499 barrier_id,
500 timeout,
501 }) = signal
502 {
503 let _ = event_tx
504 .send(GraphEvent::BarrierWaiting {
505 barrier_id: barrier_id.clone(),
506 node_name: node_name.clone(),
507 span_id,
508 })
509 .await;
510
511 let outcome = wait_for_barrier_decision(
512 &mut decision_rx,
513 &mut cancel_rx,
514 &cancel,
515 &barrier_id,
516 timeout,
517 &mut wildcard_cache,
518 )
519 .await;
520
521 match outcome {
522 BarrierOutcome::Decision(d) => {
523 let _ = event_tx
524 .send(GraphEvent::BarrierResolved {
525 barrier_id,
526 decision: d.clone(),
527 })
528 .await;
529
530 let reroute_target =
531 apply_barrier_decision_generic(engine.state_mut(), &node_name, &d);
532
533 if let Some(target) = reroute_target {
534 current = target;
535 continue;
536 }
537 next_action = NextAction::Next;
538 }
539 BarrierOutcome::TimedOut => {
540 let _ = event_tx
542 .send(GraphEvent::BarrierResolved {
543 barrier_id,
544 decision: BarrierDecision::Reject {
545 reason: "timeout".into(),
546 },
547 })
548 .await;
549 apply_barrier_decision_generic(
550 engine.state_mut(),
551 &node_name,
552 &BarrierDecision::Reject {
553 reason: "timeout".into(),
554 },
555 );
556 next_action = NextAction::Next;
557 }
558 BarrierOutcome::Cancelled => {
559 let _ = event_tx
560 .send(GraphEvent::GraphError {
561 error: GraphError::Terminal(
562 crate::error::TerminalError::BarrierCancelled {
563 node: node_name.clone(),
564 },
565 ),
566 state: engine.state().clone(),
567 })
568 .await;
569 break;
570 }
571 }
572 }
573
574 match next_action {
576 NextAction::End => {
577 send_complete(
578 &event_tx,
579 trace_id,
580 engine.state(),
581 execution_log,
582 start_time,
583 trace_sink.take(),
584 );
585 break;
586 }
587 NextAction::Goto(target) => {
588 current = target;
589 }
590 NextAction::Next => {
591 if current == graph.end_node() {
592 send_complete(
593 &event_tx,
594 trace_id,
595 engine.state(),
596 execution_log,
597 start_time,
598 trace_sink.take(),
599 );
600 break;
601 }
602 match graph.resolve_next_inline(¤t, engine.state()) {
603 Ok(target) => current = target,
604 Err(e) => {
605 let _ = event_tx
606 .send(GraphEvent::GraphError {
607 error: e,
608 state: engine.state().clone(),
609 })
610 .await;
611 break;
612 }
613 }
614 }
615 }
616
617 if let Some(ref cp_config) = checkpoint {
619 if cp_config.should_save(has_mutations, is_barrier) {
620 let cp = Checkpoint::new(¤t, engine.state(), cp_config.graph_hash);
622 match (cp_config.save_fn)(cp, trace_id).await {
623 Ok(()) => {
624 if let Err(e) = cp_config.apply_retention(&trace_id).await {
626 tracing::warn!(error = %e, "checkpoint retention failed");
627 }
628 }
629 Err(e) => {
630 tracing::warn!(error = %e, "checkpoint save failed");
631 }
632 }
633 }
634 }
635 }
636}
637
638pub(crate) fn send_complete<S: WorkflowState>(
640 event_tx: &tokio::sync::mpsc::Sender<GraphEvent<S>>,
641 trace_id: TraceId,
642 final_state: &S,
643 execution_log: Vec<ExecutionEntry>,
644 start_time: Instant,
645 trace_sink: Option<MemoryTraceSink<S::Mutation>>,
646) {
647 let duration = start_time.elapsed();
648 let trace = trace_sink.map(|sink| sink.into_trace());
649 let result = GraphResult {
650 trace_id,
651 state: final_state.clone(),
652 execution_log,
653 duration,
654 trace,
655 };
656 let _ = event_tx.try_send(GraphEvent::GraphComplete { result });
657}