1use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11use serde::Serialize;
12use tokio_util::sync::CancellationToken;
13
14use crate::checkpoint::{Checkpoint, CheckpointSink, FrameInfo, TraceId};
15use crate::event::{BarrierDecisionMessage, GraphEvent};
16use crate::exec::execution_engine::ExecutionEngine;
17use crate::graph::{Graph, StepCallback};
18use crate::node::barrier_sink::ChannelBarrierSink;
19use crate::state::workflow_state::WorkflowState;
20use crate::state::{ExecutionEntry, GraphResult};
21
/// Checkpointing settings for a graph run: when to checkpoint, how many
/// checkpoints to keep, and how a checkpoint is persisted.
///
/// Built with [`CheckpointConfig::new`] and refined through the `with_*`
/// builder methods.
pub struct CheckpointConfig<S: WorkflowState> {
    /// Trigger policy. Starts as `TriggerPolicy::default()` and is replaced by
    /// `with_trigger` / `with_policy`.
    pub trigger: crate::checkpoint::checkpoint_policy::TriggerPolicy,
    /// Retention policy. Its `prune_keep()` value, when present, is the `keep`
    /// argument handed to the store's `prune`.
    pub retention: crate::checkpoint::checkpoint_policy::RetentionPolicy,
    /// Shared callback invoked with the checkpoint and the trace id to persist it.
    save_fn: Arc<crate::checkpoint::checkpoint_policy::CheckpointSaveFn<S>>,
    /// Graph hash stamped onto every checkpoint created from this configuration.
    graph_hash: u64,
    /// Optional store used for pruning; when `None`, retention is a no-op.
    store: Option<Arc<dyn crate::checkpoint::store::BlobCheckpointStore>>,
}
37
38impl<S: WorkflowState> CheckpointConfig<S> {
39 pub fn new(
40 save_fn: impl Fn(
41 Checkpoint<S>,
42 TraceId,
43 ) -> std::pin::Pin<
44 Box<
45 dyn std::future::Future<
46 Output = Result<(), crate::checkpoint::CheckpointStoreError>,
47 > + Send,
48 >,
49 > + Send
50 + Sync
51 + 'static,
52 graph_hash: u64,
53 ) -> Self {
54 Self {
55 save_fn: Arc::new(Box::new(save_fn)),
56 trigger: crate::checkpoint::checkpoint_policy::TriggerPolicy::default(),
57 retention: crate::checkpoint::checkpoint_policy::RetentionPolicy::default(),
58 graph_hash,
59 store: None,
60 }
61 }
62
63 pub fn with_trigger(
64 mut self,
65 trigger: crate::checkpoint::checkpoint_policy::TriggerPolicy,
66 ) -> Self {
67 self.trigger = trigger;
68 self
69 }
70
71 pub fn with_retention(
72 mut self,
73 retention: crate::checkpoint::checkpoint_policy::RetentionPolicy,
74 ) -> Self {
75 self.retention = retention;
76 self
77 }
78
79 pub fn with_store(
80 mut self,
81 store: Arc<dyn crate::checkpoint::store::BlobCheckpointStore>,
82 ) -> Self {
83 self.store = Some(store);
84 self
85 }
86
87 #[allow(deprecated)]
88 pub fn with_policy(mut self, policy: crate::checkpoint::CheckpointPolicy) -> Self {
89 self.trigger = policy.into();
90 self
91 }
92
93 pub async fn apply_retention(
94 &self,
95 trace_id: &TraceId,
96 ) -> Result<(), crate::checkpoint::CheckpointStoreError> {
97 if let Some(keep) = self.retention.prune_keep() {
98 if let Some(ref store) = self.store {
99 let pruned = store.prune(trace_id, keep).await?;
100 if pruned > 0 {
101 tracing::debug!(pruned, keep, "checkpoint pruned");
102 }
103 }
104 }
105 Ok(())
106 }
107}
108
/// Checkpoint sink that persists each reported checkpoint through the
/// configured save callback and then applies retention.
///
/// Holds the parts of a [`CheckpointConfig`] it needs plus the trace id of the
/// run it belongs to.
pub struct CheckpointSaveSink<S: WorkflowState> {
    /// Callback that performs the actual save.
    save_fn: Arc<crate::checkpoint::checkpoint_policy::CheckpointSaveFn<S>>,
    /// Graph hash recorded on every checkpoint this sink creates.
    graph_hash: u64,
    /// Trace id passed to the save callback and to the store's `prune`.
    trace_id: TraceId,
    /// Retention policy consulted after each successful save.
    retention: crate::checkpoint::checkpoint_policy::RetentionPolicy,
    /// Optional store used for pruning; when `None`, no pruning happens.
    store: Option<Arc<dyn crate::checkpoint::store::BlobCheckpointStore>>,
}
119
120impl<S: WorkflowState> CheckpointSaveSink<S> {
121 pub fn new(config: CheckpointConfig<S>, trace_id: TraceId) -> Self {
122 Self {
123 save_fn: config.save_fn,
124 graph_hash: config.graph_hash,
125 trace_id,
126 retention: config.retention,
127 store: config.store,
128 }
129 }
130}
131
132impl<S: WorkflowState + 'static> CheckpointSink<S> for CheckpointSaveSink<S> {
133 fn on_checkpoint(&mut self, state: &S, frame: &FrameInfo) {
134 let save_fn = self.save_fn.clone();
135 let graph_hash = self.graph_hash;
136 let trace_id = self.trace_id;
137 let retention = self.retention.clone();
138 let store = self.store.clone();
139 let cp = Checkpoint::new(frame.node_id.clone(), state, graph_hash);
140
141 tokio::spawn(async move {
142 match save_fn(cp, trace_id).await {
143 Ok(()) => {
144 if let Some(keep) = retention.prune_keep() {
145 if let Some(ref s) = store {
146 if let Err(e) = s.prune(&trace_id, keep).await {
147 tracing::warn!(error = %e, "checkpoint retention failed");
148 }
149 }
150 }
151 }
152 Err(e) => {
153 tracing::warn!(error = %e, "checkpoint save failed");
154 }
155 }
156 });
157 }
158}
159
/// Step callback that accumulates one `ExecutionEntry` per reported step.
struct EventStepCallback {
    /// Instant supplied at construction; used as the start time of every entry.
    start_time: Instant,
    /// Entries recorded so far, in the order the steps were reported.
    execution_log: Vec<ExecutionEntry>,
}
167
168impl EventStepCallback {
169 fn new(start_time: Instant) -> Self {
170 Self {
171 start_time,
172 execution_log: Vec::new(),
173 }
174 }
175
176 fn into_log(self) -> Vec<ExecutionEntry> {
177 self.execution_log
178 }
179}
180
181impl StepCallback<'_> for EventStepCallback {
182 fn on_step(&mut self, node_name: &str, step: usize, duration: Duration) {
183 let node_end = self
184 .start_time
185 .checked_add(duration)
186 .unwrap_or(self.start_time);
187 self.execution_log.push(ExecutionEntry {
188 step,
189 node_name: node_name.to_string(),
190 start_time: self.start_time,
191 end_time: node_end,
192 success: true,
193 error: None,
194 });
195 }
196}
197
198pub(crate) async fn run_execution_loop<S, M>(
213 graph: Arc<Graph<S, M>>,
214 state: S,
215 max_steps: usize,
216 trace_id: TraceId,
217 event_tx: tokio::sync::mpsc::Sender<GraphEvent<S>>,
218 decision_rx: tokio::sync::mpsc::Receiver<BarrierDecisionMessage>,
219 cancel_rx: tokio::sync::mpsc::Receiver<()>,
220 cancel: CancellationToken,
221 checkpoint: Option<CheckpointConfig<S>>,
222 _trace_sink: Option<crate::checkpoint::trace::MemoryTraceSink<S::Mutation>>,
223 restore_from: Option<Checkpoint<S>>,
224) where
225 S: WorkflowState + Clone + Send + Sync + Serialize + 'static,
226 S::Mutation: Clone + Send + Sync,
227 M: crate::state::workflow_state::MergeStrategy<S>,
228{
229 let start_time = Instant::now();
230
231 let restore_state = restore_from.as_ref().map(|cp| S::restore(cp.state.clone()));
233 let mut engine_state = restore_state.unwrap_or(state);
234
235 let mut barrier_sink = ChannelBarrierSink::new(decision_rx, cancel_rx, cancel.clone());
237
238 let mut cp_sink: Option<CheckpointSaveSink<S>> =
240 checkpoint.map(|cfg| CheckpointSaveSink::new(cfg, trace_id));
241
242 let _ = event_tx.send(GraphEvent::GraphStart { trace_id }).await;
244
245 let mut step_cb = EventStepCallback::new(start_time);
247
248 let result = {
250 let mut engine = ExecutionEngine::new(
251 &mut engine_state,
252 None,
253 cancel.clone(),
254 cp_sink.as_mut().map(|s| s as &mut dyn CheckpointSink<S>),
255 Some(&mut barrier_sink),
256 );
257 graph.run_inline(&mut engine, max_steps, &mut step_cb).await
258 };
259
260 let final_state = engine_state;
262 let execution_log = step_cb.into_log();
263
264 match result {
265 Ok(()) => {
266 let duration = start_time.elapsed();
267 let graph_result = GraphResult {
268 trace_id,
269 state: final_state,
270 execution_log,
271 duration,
272 trace: None,
273 };
274 let _ = event_tx.try_send(GraphEvent::GraphComplete {
275 result: graph_result,
276 });
277 }
278 Err(error) => {
279 let _ = event_tx
280 .send(GraphEvent::GraphError {
281 error,
282 state: final_state,
283 })
284 .await;
285 }
286 }
287}
288
289#[allow(dead_code)]
295pub(crate) fn send_complete<S: WorkflowState>(
296 event_tx: &tokio::sync::mpsc::Sender<GraphEvent<S>>,
297 trace_id: TraceId,
298 final_state: &S,
299 execution_log: Vec<ExecutionEntry>,
300 start_time: Instant,
301 trace_sink: Option<crate::checkpoint::trace::MemoryTraceSink<S::Mutation>>,
302) {
303 let duration = start_time.elapsed();
304 let trace = trace_sink.map(|sink| sink.into_trace());
305 let result = GraphResult {
306 trace_id,
307 state: final_state.clone(),
308 execution_log,
309 duration,
310 trace,
311 };
312 let _ = event_tx.try_send(GraphEvent::GraphComplete { result });
313}