dag_executor/tasks/
stateful.rs1use crate::context::Context;
4use crate::error::TaskError;
5use crate::storage::Storage;
6use crate::tasks::r#trait::{Task, TaskOutput};
7use async_trait::async_trait;
8use futures::future::BoxFuture;
9use futures::FutureExt;
10use std::future::Future;
11use std::sync::Arc;
12
13type StepFn = Box<
14 dyn Fn(Arc<Context>, serde_json::Value) -> BoxFuture<'static, Result<StepResult, TaskError>>
15 + Send
16 + Sync,
17>;
18
19#[derive(Debug, Clone)]
21pub struct StepResult {
22 pub state: serde_json::Value,
24 pub output: TaskOutput,
26}
27
28impl StepResult {
29 pub fn new(state: serde_json::Value, output: TaskOutput) -> Self {
31 StepResult { state, output }
32 }
33}
34
35pub struct StatefulTask {
42 id: String,
43 deps: Vec<String>,
44 priority: u8,
45 storage: Arc<dyn Storage>,
46 step: StepFn,
47}
48
49impl StatefulTask {
50 pub fn new<F, Fut>(id: impl Into<String>, storage: Arc<dyn Storage>, step: F) -> Self
52 where
53 F: Fn(Arc<Context>, serde_json::Value) -> Fut + Send + Sync + 'static,
54 Fut: Future<Output = Result<StepResult, TaskError>> + Send + 'static,
55 {
56 StatefulTask {
57 id: id.into(),
58 deps: Vec::new(),
59 priority: 0,
60 storage,
61 step: Box::new(move |ctx, state| step(ctx, state).boxed()),
62 }
63 }
64
65 pub fn with_deps<I, S>(mut self, deps: I) -> Self
67 where
68 I: IntoIterator<Item = S>,
69 S: Into<String>,
70 {
71 self.deps = deps.into_iter().map(Into::into).collect();
72 self
73 }
74
75 pub fn with_priority(mut self, priority: u8) -> Self {
77 self.priority = priority;
78 self
79 }
80
81 fn state_key(&self) -> String {
82 format!("state:{}", self.id)
83 }
84}
85
86#[async_trait]
87impl Task for StatefulTask {
88 fn id(&self) -> &str {
89 &self.id
90 }
91
92 fn dependencies(&self) -> Vec<String> {
93 self.deps.clone()
94 }
95
96 fn priority(&self) -> u8 {
97 self.priority
98 }
99
100 async fn execute(&self, ctx: Arc<Context>) -> Result<TaskOutput, TaskError> {
101 let key = self.state_key();
102 let prev = self
103 .storage
104 .load(&key)
105 .await
106 .map_err(|e| TaskError::execution(format!("loading checkpoint: {e}")))?
107 .unwrap_or(serde_json::Value::Null);
108
109 let result = (self.step)(ctx, prev).await?;
110
111 self.storage
112 .save(&key, &result.state)
113 .await
114 .map_err(|e| TaskError::execution(format!("saving checkpoint: {e}")))?;
115
116 Ok(result.output)
117 }
118}