Skip to main content

dag_executor/tasks/
stateful.rs

1//! A task that persists checkpointed state across attempts and runs.
2
3use crate::context::Context;
4use crate::error::TaskError;
5use crate::storage::Storage;
6use crate::tasks::r#trait::{Task, TaskOutput};
7use async_trait::async_trait;
8use futures::future::BoxFuture;
9use futures::FutureExt;
10use std::future::Future;
11use std::sync::Arc;
12
13type StepFn = Box<
14    dyn Fn(Arc<Context>, serde_json::Value) -> BoxFuture<'static, Result<StepResult, TaskError>>
15        + Send
16        + Sync,
17>;
18
19/// The outcome of one [`StatefulTask`] step: a new checkpoint plus the output.
20#[derive(Debug, Clone)]
21pub struct StepResult {
22    /// State to persist for the next attempt/run.
23    pub state: serde_json::Value,
24    /// Output to report as the task result.
25    pub output: TaskOutput,
26}
27
28impl StepResult {
29    /// Build a step result.
30    pub fn new(state: serde_json::Value, output: TaskOutput) -> Self {
31        StepResult { state, output }
32    }
33}
34
35/// A task whose checkpoint is loaded before, and saved after, each execution.
36///
37/// The checkpoint is keyed by `state:{id}` in the provided [`Storage`], so a
38/// task that crashed partway through can resume from its last saved state
39/// instead of starting over. The step closure receives the previously saved
40/// state (or `null` on first run).
41pub struct StatefulTask {
42    id: String,
43    deps: Vec<String>,
44    priority: u8,
45    storage: Arc<dyn Storage>,
46    step: StepFn,
47}
48
49impl StatefulTask {
50    /// Create a stateful task backed by `storage`.
51    pub fn new<F, Fut>(id: impl Into<String>, storage: Arc<dyn Storage>, step: F) -> Self
52    where
53        F: Fn(Arc<Context>, serde_json::Value) -> Fut + Send + Sync + 'static,
54        Fut: Future<Output = Result<StepResult, TaskError>> + Send + 'static,
55    {
56        StatefulTask {
57            id: id.into(),
58            deps: Vec::new(),
59            priority: 0,
60            storage,
61            step: Box::new(move |ctx, state| step(ctx, state).boxed()),
62        }
63    }
64
65    /// Declare dependencies.
66    pub fn with_deps<I, S>(mut self, deps: I) -> Self
67    where
68        I: IntoIterator<Item = S>,
69        S: Into<String>,
70    {
71        self.deps = deps.into_iter().map(Into::into).collect();
72        self
73    }
74
75    /// Set scheduling priority.
76    pub fn with_priority(mut self, priority: u8) -> Self {
77        self.priority = priority;
78        self
79    }
80
81    fn state_key(&self) -> String {
82        format!("state:{}", self.id)
83    }
84}
85
86#[async_trait]
87impl Task for StatefulTask {
88    fn id(&self) -> &str {
89        &self.id
90    }
91
92    fn dependencies(&self) -> Vec<String> {
93        self.deps.clone()
94    }
95
96    fn priority(&self) -> u8 {
97        self.priority
98    }
99
100    async fn execute(&self, ctx: Arc<Context>) -> Result<TaskOutput, TaskError> {
101        let key = self.state_key();
102        let prev = self
103            .storage
104            .load(&key)
105            .await
106            .map_err(|e| TaskError::execution(format!("loading checkpoint: {e}")))?
107            .unwrap_or(serde_json::Value::Null);
108
109        let result = (self.step)(ctx, prev).await?;
110
111        self.storage
112            .save(&key, &result.state)
113            .await
114            .map_err(|e| TaskError::execution(format!("saving checkpoint: {e}")))?;
115
116        Ok(result.output)
117    }
118}