1use std::sync::Arc;
14use std::time::Instant;
15
16use serde::Serialize;
17use tokio_util::sync::CancellationToken;
18
19use crate::barrier_wait::{
20 BarrierOutcome, apply_barrier_decision_generic, wait_for_barrier_decision,
21};
22use crate::checkpoint::{Checkpoint, NodeId, TraceId};
23use crate::checkpoint_policy::{RetentionPolicy, TriggerPolicy};
24use crate::error::GraphError;
25use crate::event::{BarrierDecision, BarrierDecisionMessage, GraphEvent};
26use crate::execution_engine::{ExecutionEngine, ExecutionSignal, ExecutorState, NextAction};
27use crate::graph::Graph;
28use crate::ids::SpanId;
29use crate::node::{BarrierNode, ConditionNode, ExecutorOperation, FlowNode, LeafNode, NodeKind};
30use crate::state::{ExecutionEntry, GraphResult};
31use crate::trace::{MemoryTraceSink, TraceSink, TraceStep};
32use crate::workflow_state::{MergeStrategy, WorkflowState};
33
34type CheckpointSaveFn<S> = Box<
40 dyn Fn(
41 Checkpoint<S>,
42 TraceId,
43 ) -> std::pin::Pin<
44 Box<
45 dyn std::future::Future<
46 Output = Result<(), crate::checkpoint::CheckpointStoreError>,
47 > + Send,
48 >,
49 > + Send
50 + Sync,
51>;
52
53pub struct CheckpointConfig<S: WorkflowState> {
63 pub trigger: TriggerPolicy,
65 pub retention: RetentionPolicy,
67 save_fn: CheckpointSaveFn<S>,
69 graph_hash: u64,
71 store: Option<std::sync::Arc<dyn crate::store::BlobCheckpointStore>>,
73}
74
75impl<S: WorkflowState> CheckpointConfig<S> {
76 pub fn new(
81 save_fn: impl Fn(
82 Checkpoint<S>,
83 TraceId,
84 ) -> std::pin::Pin<
85 Box<
86 dyn std::future::Future<
87 Output = Result<(), crate::checkpoint::CheckpointStoreError>,
88 > + Send,
89 >,
90 > + Send
91 + Sync
92 + 'static,
93 graph_hash: u64,
94 ) -> Self {
95 Self {
96 save_fn: Box::new(save_fn),
97 trigger: TriggerPolicy::default(),
98 retention: RetentionPolicy::default(),
99 graph_hash,
100 store: None,
101 }
102 }
103
104 pub fn with_trigger(mut self, trigger: TriggerPolicy) -> Self {
106 self.trigger = trigger;
107 self
108 }
109
110 pub fn with_retention(mut self, retention: RetentionPolicy) -> Self {
112 self.retention = retention;
113 self
114 }
115
116 pub fn with_store(
118 mut self,
119 store: std::sync::Arc<dyn crate::store::BlobCheckpointStore>,
120 ) -> Self {
121 self.store = Some(store);
122 self
123 }
124
125 #[allow(deprecated)]
127 pub fn with_policy(mut self, policy: crate::checkpoint::CheckpointPolicy) -> Self {
128 self.trigger = policy.into();
129 self
130 }
131
132 pub fn should_save(&self, has_mutations: bool, is_barrier: bool) -> bool {
134 match self.trigger {
135 TriggerPolicy::EveryNode => true,
136 TriggerPolicy::BarrierOnly => is_barrier,
137 TriggerPolicy::Manual => false,
138 TriggerPolicy::OnMutation => has_mutations,
139 }
140 }
141
142 pub async fn apply_retention(
144 &self,
145 trace_id: &TraceId,
146 ) -> Result<(), crate::checkpoint::CheckpointStoreError> {
147 if let Some(keep) = self.retention.prune_keep() {
148 if let Some(ref store) = self.store {
149 let pruned = store.prune(trace_id, keep).await?;
150 if pruned > 0 {
151 tracing::debug!(pruned, keep, "checkpoint pruned");
152 }
153 }
154 }
155 Ok(())
156 }
157}
158
159pub(crate) async fn run_execution_loop<S, M>(
175 graph: Arc<Graph<S, M>>,
176 state: S,
177 max_steps: usize,
178 trace_id: TraceId,
179 event_tx: tokio::sync::mpsc::Sender<GraphEvent<S>>,
180 mut decision_rx: tokio::sync::mpsc::Receiver<BarrierDecisionMessage>,
181 mut cancel_rx: tokio::sync::mpsc::Receiver<()>,
182 cancel: CancellationToken,
183 checkpoint: Option<CheckpointConfig<S>>,
184 trace_sink: Option<MemoryTraceSink<S::Mutation>>,
185 restore_from: Option<Checkpoint<S>>,
186) where
187 S: WorkflowState + Clone + Send + Sync + Serialize + 'static,
188 S::Mutation: Clone + Send + Sync,
189 M: MergeStrategy<S>,
190{
191 let start_time = Instant::now();
192 let mut execution_log: Vec<ExecutionEntry> = Vec::new();
193
194 let restore_state = restore_from.as_ref().map(|cp| cp.state.clone());
196 let engine_state = restore_state.unwrap_or(state);
197 let mut engine = ExecutionEngine::new(engine_state, None, cancel.clone());
198 let mut current = if let Some(ref cp) = restore_from {
199 cp.current_node.0.clone()
200 } else {
201 graph.start_node().to_string()
202 };
203 let mut step: usize = 0;
204 let mut wildcard_cache = std::collections::HashMap::new();
206 let mut trace_sink = trace_sink;
208
209 let _ = event_tx.send(GraphEvent::GraphStart { trace_id }).await;
210
211 let mut skip_first_execution = false;
215 if restore_from.is_some() {
216 if let Some(node) = graph.nodes.get(¤t) {
217 if matches!(node, NodeKind::Barrier(_)) {
218 let span_id = SpanId::new();
219 let barrier_id = crate::event::BarrierId::new(¤t, 0);
220
221 let _ = event_tx
222 .send(GraphEvent::BarrierWaiting {
223 barrier_id: barrier_id.clone(),
224 node_name: current.clone(),
225 span_id,
226 })
227 .await;
228
229 let barrier_timeout = if let NodeKind::Barrier(bn) = node {
231 bn.timeout
232 } else {
233 None
234 };
235
236 let outcome = wait_for_barrier_decision(
237 &mut decision_rx,
238 &mut cancel_rx,
239 &cancel,
240 &barrier_id,
241 barrier_timeout,
242 &mut wildcard_cache,
243 )
244 .await;
245
246 match outcome {
247 BarrierOutcome::Decision(d) => {
248 let _ = event_tx
249 .send(GraphEvent::BarrierResolved {
250 barrier_id,
251 decision: d.clone(),
252 })
253 .await;
254
255 let reroute_target =
256 apply_barrier_decision_generic(engine.state_mut(), ¤t, &d);
257
258 if let Some(target) = reroute_target {
259 current = target;
260 }
261 skip_first_execution = true;
263 }
264 BarrierOutcome::TimedOut => {
265 let _ = event_tx
266 .send(GraphEvent::BarrierResolved {
267 barrier_id,
268 decision: BarrierDecision::Reject {
269 reason: "timeout on restore".into(),
270 },
271 })
272 .await;
273 apply_barrier_decision_generic(
274 engine.state_mut(),
275 ¤t,
276 &BarrierDecision::Reject {
277 reason: "timeout on restore".into(),
278 },
279 );
280 skip_first_execution = true;
281 }
282 BarrierOutcome::Cancelled => {
283 let _ = event_tx
284 .send(GraphEvent::GraphError {
285 error: GraphError::Terminal(
286 crate::error::TerminalError::BarrierCancelled {
287 node: current.clone(),
288 },
289 ),
290 state: engine.state().clone(),
291 })
292 .await;
293 return;
294 }
295 }
296 }
297 }
298 }
299
300 loop {
301 if step == 0 && skip_first_execution {
303 match graph.resolve_next_inline(¤t, engine.state()) {
305 Ok(target) => current = target,
306 Err(_) => {
307 if current == graph.end_node() {
309 send_complete(
310 &event_tx,
311 trace_id,
312 engine,
313 execution_log,
314 start_time,
315 trace_sink.take(),
316 );
317 break;
318 }
319 }
322 }
323 step += 1;
324 continue;
325 }
326 if cancel.is_cancelled() {
327 let _ = event_tx
328 .send(GraphEvent::GraphError {
329 error: GraphError::Terminal(crate::error::TerminalError::BarrierCancelled {
330 node: "execution cancelled".into(),
331 }),
332 state: engine.state().clone(),
333 })
334 .await;
335 break;
336 }
337
338 step += 1;
339 if step > max_steps {
340 let _ = event_tx
341 .send(GraphEvent::GraphError {
342 error: GraphError::Terminal(crate::error::TerminalError::StepsExceeded {
343 limit: max_steps,
344 }),
345 state: engine.state().clone(),
346 })
347 .await;
348 break;
349 }
350
351 let node = match graph.nodes.get(¤t) {
352 Some(n) => n,
353 None => {
354 let _ = event_tx
355 .send(GraphEvent::GraphError {
356 error: GraphError::Terminal(crate::error::TerminalError::NodeNotFound(
357 current.clone(),
358 )),
359 state: engine.state().clone(),
360 })
361 .await;
362 break;
363 }
364 };
365
366 let node_name = current.clone();
367 let span_id = SpanId::new();
368 let node_start = Instant::now();
369
370 let _ = event_tx
371 .send(GraphEvent::NodeStart {
372 node_name: node_name.clone(),
373 trace_id,
374 span_id,
375 step,
376 })
377 .await;
378
379 let is_barrier = matches!(node, NodeKind::Barrier(_));
381 let node_ok = match node {
382 NodeKind::Task(n) => {
383 let mut ctx = engine.build_node_context();
384 n.execute(&mut ctx).await.is_ok()
385 }
386 NodeKind::Condition(n) => {
387 let mut ctx = engine.build_leaf_context();
388 <ConditionNode<S> as LeafNode<S>>::execute(n, &mut ctx)
389 .await
390 .is_ok()
391 }
392 NodeKind::Barrier(n) => {
393 let mut ctx = engine.build_leaf_context();
394 <BarrierNode<S> as LeafNode<S>>::execute(n, &mut ctx)
395 .await
396 .is_ok()
397 }
398 NodeKind::External(n) => {
399 let mut ctx = engine.build_node_context();
400 n.execute(&mut ctx).await.is_ok()
401 }
402 NodeKind::ExternalLeaf(n) => {
403 let mut ctx = engine.build_leaf_context();
404 n.execute(&mut ctx).await.is_ok()
405 }
406 NodeKind::Parallel(p) => p.execute(&mut engine).await.is_ok(),
407 };
408
409 if !node_ok {
410 let _ = event_tx
411 .send(GraphEvent::NodeEnd {
412 node_name: node_name.clone(),
413 trace_id,
414 span_id,
415 success: false,
416 duration: node_start.elapsed(),
417 })
418 .await;
419
420 let _ = event_tx
421 .send(GraphEvent::GraphError {
422 error: GraphError::Terminal(crate::error::TerminalError::NodeExecutionFailed {
423 node: node_name,
424 source: "node execution failed".into(),
425 }),
426 state: engine.state().clone(),
427 })
428 .await;
429 break;
430 }
431
432 let flow_events = engine.take_flow_events();
434 for fe in flow_events {
435 let _ = event_tx
436 .send(GraphEvent::Node {
437 span_id,
438 node_name: node_name.clone(),
439 event: fe,
440 })
441 .await;
442 }
443
444 let commit_batch = engine.take_commit_batch();
447 let has_mutations = !commit_batch.is_empty();
448
449 if let Some(ref mut sink) = trace_sink {
451 if !commit_batch.is_empty() {
452 sink.record_step(TraceStep {
453 step,
454 node_id: NodeId(node_name.clone()),
455 mutations: commit_batch.clone(),
456 });
457 }
458 }
459
460 engine.apply_batch_to_state(commit_batch);
461
462 let node_duration = node_start.elapsed();
463 execution_log.push(ExecutionEntry {
464 step,
465 node_name: node_name.clone(),
466 start_time,
467 end_time: start_time.checked_add(node_duration).unwrap_or(start_time),
468 success: true,
469 error: None,
470 });
471
472 let _ = event_tx
473 .send(GraphEvent::NodeEnd {
474 node_name: node_name.clone(),
475 trace_id,
476 span_id,
477 success: true,
478 duration: node_duration,
479 })
480 .await;
481
482 let (next_action, signal) = engine.take_control();
484
485 let mut next_action = next_action;
487
488 if let Some(ExecutionSignal::Pause {
489 barrier_id,
490 timeout,
491 }) = signal
492 {
493 let _ = event_tx
494 .send(GraphEvent::BarrierWaiting {
495 barrier_id: barrier_id.clone(),
496 node_name: node_name.clone(),
497 span_id,
498 })
499 .await;
500
501 let outcome = wait_for_barrier_decision(
502 &mut decision_rx,
503 &mut cancel_rx,
504 &cancel,
505 &barrier_id,
506 timeout,
507 &mut wildcard_cache,
508 )
509 .await;
510
511 match outcome {
512 BarrierOutcome::Decision(d) => {
513 let _ = event_tx
514 .send(GraphEvent::BarrierResolved {
515 barrier_id,
516 decision: d.clone(),
517 })
518 .await;
519
520 let reroute_target =
521 apply_barrier_decision_generic(engine.state_mut(), &node_name, &d);
522
523 if let Some(target) = reroute_target {
524 current = target;
525 continue;
526 }
527 next_action = NextAction::Next;
528 }
529 BarrierOutcome::TimedOut => {
530 let _ = event_tx
532 .send(GraphEvent::BarrierResolved {
533 barrier_id,
534 decision: BarrierDecision::Reject {
535 reason: "timeout".into(),
536 },
537 })
538 .await;
539 apply_barrier_decision_generic(
540 engine.state_mut(),
541 &node_name,
542 &BarrierDecision::Reject {
543 reason: "timeout".into(),
544 },
545 );
546 next_action = NextAction::Next;
547 }
548 BarrierOutcome::Cancelled => {
549 let _ = event_tx
550 .send(GraphEvent::GraphError {
551 error: GraphError::Terminal(
552 crate::error::TerminalError::BarrierCancelled {
553 node: node_name.clone(),
554 },
555 ),
556 state: engine.state().clone(),
557 })
558 .await;
559 break;
560 }
561 }
562 }
563
564 match next_action {
566 NextAction::End => {
567 send_complete(
568 &event_tx,
569 trace_id,
570 engine,
571 execution_log,
572 start_time,
573 trace_sink.take(),
574 );
575 break;
576 }
577 NextAction::Goto(target) => {
578 current = target;
579 }
580 NextAction::Next => {
581 if current == graph.end_node() {
582 send_complete(
583 &event_tx,
584 trace_id,
585 engine,
586 execution_log,
587 start_time,
588 trace_sink.take(),
589 );
590 break;
591 }
592 match graph.resolve_next_inline(¤t, engine.state()) {
593 Ok(target) => current = target,
594 Err(e) => {
595 let _ = event_tx
596 .send(GraphEvent::GraphError {
597 error: e,
598 state: engine.state().clone(),
599 })
600 .await;
601 break;
602 }
603 }
604 }
605 }
606
607 if let Some(ref cp_config) = checkpoint {
609 if cp_config.should_save(has_mutations, is_barrier) {
610 let cp = Checkpoint::new(¤t, engine.state().clone(), cp_config.graph_hash);
611 match (cp_config.save_fn)(cp, trace_id).await {
612 Ok(()) => {
613 if let Err(e) = cp_config.apply_retention(&trace_id).await {
615 tracing::warn!(error = %e, "checkpoint retention failed");
616 }
617 }
618 Err(e) => {
619 tracing::warn!(error = %e, "checkpoint save failed");
620 }
621 }
622 }
623 }
624 }
625}
626
627pub(crate) fn send_complete<S: WorkflowState>(
629 event_tx: &tokio::sync::mpsc::Sender<GraphEvent<S>>,
630 trace_id: TraceId,
631 engine: ExecutionEngine<S>,
632 execution_log: Vec<ExecutionEntry>,
633 start_time: Instant,
634 trace_sink: Option<MemoryTraceSink<S::Mutation>>,
635) {
636 let duration = start_time.elapsed();
637 let final_state = engine.into_state();
638 let trace = trace_sink.map(|sink| sink.into_trace());
639 let result = GraphResult {
640 trace_id,
641 state: final_state,
642 execution_log,
643 duration,
644 trace,
645 };
646 let _ = event_tx.try_send(GraphEvent::GraphComplete { result });
647}