use std::sync::Arc;
use std::time::{Duration, Instant};
use serde::Serialize;
use tokio_util::sync::CancellationToken;
use crate::checkpoint::{Checkpoint, CheckpointSink, FrameInfo, TraceId};
use crate::event::{BarrierDecisionMessage, GraphEvent};
use crate::exec::execution_engine::ExecutionEngine;
use crate::graph::{Graph, StepCallback};
use crate::node::barrier_sink::ChannelBarrierSink;
use crate::state::workflow_state::WorkflowState;
use crate::state::{ExecutionEntry, GraphResult};
/// Checkpointing settings for a graph run: trigger and retention policies, the save callback,
/// the graph hash stamped onto checkpoints, and an optional blob store.
///
/// Built with [`CheckpointConfig::new`] and refined through the `with_*` builder methods.
pub struct CheckpointConfig<S: WorkflowState> {
/// Trigger policy; `new` initialises it to `TriggerPolicy::default()`.
pub trigger: crate::checkpoint::checkpoint_policy::TriggerPolicy,
/// Retention policy; its `prune_keep()` value decides whether (and how much) to prune.
pub retention: crate::checkpoint::checkpoint_policy::RetentionPolicy,
/// Shared save callback, invoked with a checkpoint and the trace id.
save_fn: Arc<crate::checkpoint::checkpoint_policy::CheckpointSaveFn<S>>,
/// Hash value handed to `Checkpoint::new` when a checkpoint is created.
graph_hash: u64,
/// Optional blob store; pruning is skipped when this is `None`.
store: Option<Arc<dyn crate::checkpoint::store::BlobCheckpointStore>>,
}
impl<S: WorkflowState> CheckpointConfig<S> {
    /// Creates a configuration around `save_fn`, using the default trigger and retention
    /// policies and no blob store.
    ///
    /// `save_fn` is called with a checkpoint and the trace id and returns a boxed future that
    /// resolves to the outcome of the save. `graph_hash` is stored and later passed along when
    /// checkpoints are created.
    pub fn new(
        save_fn: impl Fn(
                Checkpoint<S>,
                TraceId,
            ) -> std::pin::Pin<
                Box<
                    dyn std::future::Future<
                            Output = Result<(), crate::checkpoint::CheckpointStoreError>,
                        > + Send,
                >,
            > + Send
            + Sync
            + 'static,
        graph_hash: u64,
    ) -> Self {
        Self {
            graph_hash,
            store: None,
            trigger: Default::default(),
            retention: Default::default(),
            // Kept in field position so the expected field type drives the coercion.
            save_fn: Arc::new(Box::new(save_fn)),
        }
    }

    /// Returns the configuration with its trigger policy replaced by `trigger`.
    pub fn with_trigger(
        self,
        trigger: crate::checkpoint::checkpoint_policy::TriggerPolicy,
    ) -> Self {
        Self { trigger, ..self }
    }

    /// Returns the configuration with its retention policy replaced by `retention`.
    pub fn with_retention(
        self,
        retention: crate::checkpoint::checkpoint_policy::RetentionPolicy,
    ) -> Self {
        Self { retention, ..self }
    }

    /// Returns the configuration with `store` attached as the blob store used for pruning.
    pub fn with_store(
        self,
        store: Arc<dyn crate::checkpoint::store::BlobCheckpointStore>,
    ) -> Self {
        Self {
            store: Some(store),
            ..self
        }
    }

    /// Returns the configuration with its trigger policy derived from `policy` via `Into`.
    // NOTE(review): the `allow` suggests `CheckpointPolicy` is a deprecated type — confirm.
    #[allow(deprecated)]
    pub fn with_policy(self, policy: crate::checkpoint::CheckpointPolicy) -> Self {
        Self {
            trigger: policy.into(),
            ..self
        }
    }

    /// Prunes stored checkpoints for `trace_id` according to the retention policy.
    ///
    /// Does nothing when the retention policy yields no keep value or when no store is
    /// attached. Logs at debug level when at least one checkpoint was pruned.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by the store's `prune` call.
    pub async fn apply_retention(
        &self,
        trace_id: &TraceId,
    ) -> Result<(), crate::checkpoint::CheckpointStoreError> {
        let keep = match self.retention.prune_keep() {
            Some(keep) => keep,
            None => return Ok(()),
        };
        let blob_store = match self.store.as_ref() {
            Some(blob_store) => blob_store,
            None => return Ok(()),
        };

        let pruned = blob_store.prune(trace_id, keep).await?;
        if pruned > 0 {
            tracing::debug!(pruned, keep, "checkpoint pruned");
        }
        Ok(())
    }
}
/// Checkpoint sink that persists each checkpoint through a save callback and then applies the
/// retention policy against an optional blob store.
pub struct CheckpointSaveSink<S: WorkflowState> {
/// Shared save callback, invoked with each checkpoint and the trace id.
save_fn: Arc<crate::checkpoint::checkpoint_policy::CheckpointSaveFn<S>>,
/// Hash value handed to `Checkpoint::new` for every checkpoint this sink creates.
graph_hash: u64,
/// Trace id passed to the save callback and to the store's `prune`.
trace_id: TraceId,
/// Retention policy; its `prune_keep()` value decides whether to prune after a save.
retention: crate::checkpoint::checkpoint_policy::RetentionPolicy,
/// Optional blob store; pruning is skipped when this is `None`.
store: Option<Arc<dyn crate::checkpoint::store::BlobCheckpointStore>>,
}
impl<S: WorkflowState> CheckpointSaveSink<S> {
pub fn new(config: CheckpointConfig<S>, trace_id: TraceId) -> Self {
Self {
save_fn: config.save_fn,
graph_hash: config.graph_hash,
trace_id,
retention: config.retention,
store: config.store,
}
}
}
impl<S: WorkflowState + 'static> CheckpointSink<S> for CheckpointSaveSink<S> {
fn on_checkpoint(&mut self, state: &S, frame: &FrameInfo) {
let save_fn = self.save_fn.clone();
let graph_hash = self.graph_hash;
let trace_id = self.trace_id;
let retention = self.retention.clone();
let store = self.store.clone();
let cp = Checkpoint::new(frame.node_id.clone(), state, graph_hash);
tokio::spawn(async move {
match save_fn(cp, trace_id).await {
Ok(()) => {
if let Some(keep) = retention.prune_keep() {
if let Some(ref s) = store {
if let Err(e) = s.prune(&trace_id, keep).await {
tracing::warn!(error = %e, "checkpoint retention failed");
}
}
}
}
Err(e) => {
tracing::warn!(error = %e, "checkpoint save failed");
}
}
});
}
}
/// Step callback that records one `ExecutionEntry` per reported step.
struct EventStepCallback {
/// Instant supplied at construction; used as the `start_time` of every recorded entry.
start_time: Instant,
/// Entries accumulated so far, in the order the steps were reported.
execution_log: Vec<ExecutionEntry>,
}
impl EventStepCallback {
    /// Creates a callback with an empty log, anchored at `start_time`.
    fn new(start_time: Instant) -> Self {
        let execution_log = Vec::new();
        Self {
            start_time,
            execution_log,
        }
    }

    /// Consumes the callback and hands back the recorded entries.
    fn into_log(self) -> Vec<ExecutionEntry> {
        let Self { execution_log, .. } = self;
        execution_log
    }
}
impl StepCallback<'_> for EventStepCallback {
    /// Appends a successful entry for `node_name` at `step` to the execution log.
    ///
    /// The entry spans from the callback's `start_time` to `start_time + duration`; if that
    /// addition overflows, the end time falls back to `start_time`.
    // NOTE(review): every entry starts at the callback's `start_time`, not at the node's own
    // start — confirm this is intended.
    fn on_step(&mut self, node_name: &str, step: usize, duration: Duration) {
        let began = self.start_time;
        let finished = match began.checked_add(duration) {
            Some(instant) => instant,
            None => began,
        };

        let entry = ExecutionEntry {
            step,
            node_name: node_name.to_string(),
            start_time: began,
            end_time: finished,
            success: true,
            error: None,
        };
        self.execution_log.push(entry);
    }
}
/// Runs `graph` to completion and reports the run's lifecycle on `event_tx`.
///
/// Sends `GraphEvent::GraphStart`, executes the graph inline through an `ExecutionEngine`, and
/// then sends either `GraphEvent::GraphComplete` (final state, execution log, elapsed time) or
/// `GraphEvent::GraphError` (the error plus the state reached).
///
/// # Parameters
///
/// * `state` - initial state; replaced by the restored state when `restore_from` is `Some`.
/// * `max_steps` - step limit forwarded to `Graph::run_inline`.
/// * `trace_id` - identifier attached to the start event, the result and checkpoints.
/// * `decision_rx`, `cancel_rx`, `cancel` - wired into the `ChannelBarrierSink`; `cancel` is
///   also handed to the engine.
/// * `checkpoint` - when present, a `CheckpointSaveSink` is built from it and given to the
///   engine.
/// * `_trace_sink` - currently unused; the result's `trace` is always `None`.
/// * `restore_from` - checkpoint whose state is restored via `S::restore` before running.
///
/// All three events are delivered with an awaited `send`, so none is dropped because the
/// channel buffer is momentarily full. Send errors are discarded.
pub(crate) async fn run_execution_loop<S, M>(
    graph: Arc<Graph<S, M>>,
    state: S,
    max_steps: usize,
    trace_id: TraceId,
    event_tx: tokio::sync::mpsc::Sender<GraphEvent<S>>,
    decision_rx: tokio::sync::mpsc::Receiver<BarrierDecisionMessage>,
    cancel_rx: tokio::sync::mpsc::Receiver<()>,
    cancel: CancellationToken,
    checkpoint: Option<CheckpointConfig<S>>,
    _trace_sink: Option<crate::checkpoint::trace::MemoryTraceSink<S::Mutation>>,
    restore_from: Option<Checkpoint<S>>,
) where
    S: WorkflowState + Clone + Send + Sync + Serialize + 'static,
    S::Mutation: Clone + Send + Sync,
    M: crate::state::workflow_state::MergeStrategy<S>,
{
    let start_time = Instant::now();

    // A restored checkpoint state takes precedence over the caller-supplied initial state.
    let restore_state = restore_from.as_ref().map(|cp| S::restore(cp.state.clone()));
    let mut engine_state = restore_state.unwrap_or(state);

    let mut barrier_sink = ChannelBarrierSink::new(decision_rx, cancel_rx, cancel.clone());
    let mut cp_sink: Option<CheckpointSaveSink<S>> =
        checkpoint.map(|cfg| CheckpointSaveSink::new(cfg, trace_id));

    let _ = event_tx.send(GraphEvent::GraphStart { trace_id }).await;

    let mut step_cb = EventStepCallback::new(start_time);
    let result = {
        // Scoped so the engine's mutable borrow of `engine_state` ends before the state is
        // moved into the terminal event.
        let mut engine = ExecutionEngine::new(
            &mut engine_state,
            None,
            cancel.clone(),
            cp_sink.as_mut().map(|s| s as &mut dyn CheckpointSink<S>),
            Some(&mut barrier_sink),
        );
        graph.run_inline(&mut engine, max_steps, &mut step_cb).await
    };

    let final_state = engine_state;
    let execution_log = step_cb.into_log();

    match result {
        Ok(()) => {
            let duration = start_time.elapsed();
            let graph_result = GraphResult {
                trace_id,
                state: final_state,
                execution_log,
                duration,
                trace: None,
            };
            // Awaited like the start and error events: a non-blocking `try_send` would drop
            // the terminal event whenever the bounded channel is full.
            let _ = event_tx
                .send(GraphEvent::GraphComplete {
                    result: graph_result,
                })
                .await;
        }
        Err(error) => {
            let _ = event_tx
                .send(GraphEvent::GraphError {
                    error,
                    state: final_state,
                })
                .await;
        }
    }
}
/// Builds a `GraphResult` and offers it to `event_tx` as a `GraphEvent::GraphComplete`.
///
/// The result carries a clone of `final_state`, the given `execution_log`, the time elapsed
/// since `start_time`, and the trace extracted from `trace_sink` when one is supplied.
/// The event is sent with the non-blocking `try_send`, and its outcome is discarded.
// NOTE(review): marked dead code; no caller is visible in this part of the file — confirm
// whether it is still needed.
#[allow(dead_code)]
pub(crate) fn send_complete<S: WorkflowState>(
    event_tx: &tokio::sync::mpsc::Sender<GraphEvent<S>>,
    trace_id: TraceId,
    final_state: &S,
    execution_log: Vec<ExecutionEntry>,
    start_time: Instant,
    trace_sink: Option<crate::checkpoint::trace::MemoryTraceSink<S::Mutation>>,
) {
    let elapsed = start_time.elapsed();
    let collected_trace = trace_sink.map(|sink| sink.into_trace());

    let completion = GraphEvent::GraphComplete {
        result: GraphResult {
            trace_id,
            state: final_state.clone(),
            execution_log,
            duration: elapsed,
            trace: collected_trace,
        },
    };
    let _ = event_tx.try_send(completion);
}