1use dagx::{DagRunner, Pending, Task, TaskBuilder};
33use serde::{de::DeserializeOwned, Serialize};
34use std::sync::atomic::AtomicUsize;
35use std::sync::Arc;
36
37mod store;
38pub use store::{FileStore, MemoryStore, Storage};
39
/// Error type for the durable task wrappers in this module.
#[derive(Debug)]
pub enum DuraflowError {
    /// An underlying I/O error (produced by the `From<std::io::Error>` conversion).
    Io(std::io::Error),
    /// A `serde_json` error (produced by the `From<serde_json::Error>` conversion).
    Serialize(serde_json::Error),
    /// Saving a task's output failed.
    ///
    /// This also covers values that could not be serialized: `Context::save` reports those
    /// as an `io::Error` whose message begins with "serialize error:".
    Persist {
        /// Storage key (the task id) whose value could not be written.
        key: String,
        /// The error returned while saving.
        source: std::io::Error,
    },
    /// Any other failure, described by a free-form message.
    Other(String),
}
48
49impl std::fmt::Display for DuraflowError {
50 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51 match self {
52 DuraflowError::Io(e) => write!(f, "io error: {}", e),
53 DuraflowError::Serialize(e) => write!(f, "serialize error: {}", e),
54 DuraflowError::Persist { key, source } => {
55 write!(f, "persist failed for {}: {}", key, source)
56 }
57 DuraflowError::Other(s) => write!(f, "{}", s),
58 }
59 }
60}
61
62impl std::error::Error for DuraflowError {}
63
64impl From<std::io::Error> for DuraflowError {
65 fn from(e: std::io::Error) -> Self {
66 DuraflowError::Io(e)
67 }
68}
69
70impl From<serde_json::Error> for DuraflowError {
71 fn from(e: serde_json::Error) -> Self {
72 DuraflowError::Serialize(e)
73 }
74}
75
/// Shared state handed to durable tasks: the backing store and a completion counter.
pub struct Context {
    /// Key/value store; `Context::save` writes each value as a JSON string via `save_raw`,
    /// and `Context::get` reads it back via `get_raw`.
    pub db: Arc<dyn Storage + Send + Sync>,
    /// Running count of durable tasks that have completed, including cache hits.
    ///
    /// It is incremented by the cache-hit and persist helpers, and the new total is passed
    /// to the progress callback.
    pub completed_count: Arc<AtomicUsize>,
}
81
82impl Context {
83 pub fn get<T: DeserializeOwned>(&self, key: &str) -> Option<T> {
85 self.db
86 .get_raw(key)
87 .and_then(|s| serde_json::from_str(&s).ok())
88 }
89
90 pub fn save<T: Serialize>(&self, key: &str, value: &T) -> std::io::Result<()> {
93 let s = serde_json::to_string(value)
94 .map_err(|e| std::io::Error::other(format!("serialize error: {}", e)))?;
95 self.db.save_raw(key, &s)
96 }
97}
98
/// Progress callback, invoked as `cb(task_id, completed)` each time a durable task finishes.
///
/// It fires whether the output came from cache or from a fresh run. `completed` is the
/// new value of `Context::completed_count`.
type ProgressCb = Arc<dyn Fn(&str, usize) + Send + Sync + 'static>;
101
102fn try_cached_and_mark<O: DeserializeOwned>(
104 ctx: &Context,
105 id: &str,
106 cb: &Option<ProgressCb>,
107) -> Option<O> {
108 if let Some(v) = ctx.get::<O>(id) {
109 let completed = ctx
110 .completed_count
111 .fetch_add(1, std::sync::atomic::Ordering::SeqCst)
112 + 1;
113 if let Some(cb) = cb {
114 cb(id, completed);
115 }
116 return Some(v);
117 }
118 None
119}
120
121fn persist_and_mark<O: Serialize>(
123 ctx: &Context,
124 id: &str,
125 value: &O,
126 cb: &Option<ProgressCb>,
127) -> Result<(), DuraflowError> {
128 ctx.save(id, value).map_err(|e| DuraflowError::Persist {
129 key: id.to_string(),
130 source: e,
131 })?;
132 let completed = ctx
133 .completed_count
134 .fetch_add(1, std::sync::atomic::Ordering::SeqCst)
135 + 1;
136 if let Some(cb) = cb {
137 cb(id, completed);
138 }
139 Ok(())
140}
141
/// A [`Task`] wrapper that caches the inner task's output in a [`Context`] store.
///
/// When run, an output already stored under `id` is returned without executing `inner`.
/// Otherwise `inner` runs and its output is saved under `id`.
pub struct Durable<Tk> {
    /// Storage key for this task's output; also passed to the progress callback.
    id: String,
    /// Shared store and completion counter.
    ctx: Arc<Context>,
    /// The wrapped task.
    inner: Tk,
    /// Optional callback notified after this task completes.
    progress_handler: Option<ProgressCb>,
}
153
154impl<Tk> Durable<Tk> {
155 pub fn new(
156 id: &str,
157 ctx: Arc<Context>,
158 inner: Tk,
159 progress_handler: Option<ProgressCb>,
160 ) -> Self {
161 Self {
162 id: id.to_string(),
163 ctx,
164 inner,
165 progress_handler,
166 }
167 }
168
169 pub async fn run_result(self, input: Tk::Input) -> Result<Tk::Output, DuraflowError>
171 where
172 Tk: Task + Send + 'static,
173 Tk::Input: Send + Clone,
174 Tk::Output: Serialize + DeserializeOwned + Send + Sync + 'static,
175 {
176 if let Some(cached) =
178 try_cached_and_mark::<Tk::Output>(&self.ctx, &self.id, &self.progress_handler)
179 {
180 return Ok(cached);
181 }
182
183 let result = self.inner.run(input).await;
185
186 persist_and_mark(&self.ctx, &self.id, &result, &self.progress_handler)?;
188
189 Ok(result)
190 }
191}
192
193impl<Tk> Task for Durable<Tk>
194where
195 Tk: Task + Send + 'static,
196 Tk::Input: Send + Clone,
197 Tk::Output: Serialize + DeserializeOwned + Send + Sync + 'static,
198{
199 type Input = Tk::Input;
200 type Output = Tk::Output;
201
202 #[allow(clippy::manual_async_fn)]
203 fn run(self, input: Self::Input) -> impl std::future::Future<Output = Self::Output> + Send {
204 async move {
205 if let Some(cached) =
207 try_cached_and_mark::<Self::Output>(&self.ctx, &self.id, &self.progress_handler)
208 {
209 return cached;
210 }
211
212 let result = self.inner.run(input).await;
214
215 if let Err(e) = persist_and_mark(&self.ctx, &self.id, &result, &self.progress_handler) {
217 eprintln!("persist failed for {}: {}", self.id, e);
218 }
219
220 result
221 }
222 }
223
224 fn extract_and_run(
225 self,
226 receivers: Vec<Box<dyn std::any::Any + Send>>,
227 ) -> impl std::future::Future<Output = Result<Self::Output, String>> + Send {
228 let id = self.id.clone();
229 let ctx = self.ctx.clone();
230 let progress = self.progress_handler.clone();
231 let inner = self.inner;
232
233 async move {
234 if let Some(cached) = try_cached_and_mark::<Self::Output>(&ctx, &id, &progress) {
236 return Ok(cached);
237 }
238
239 let result = inner.extract_and_run(receivers).await?;
241
242 if let Err(e) = persist_and_mark(&ctx, &id, &result, &progress) {
244 return Err(e.to_string());
245 }
246
247 Ok(result)
248 }
249 }
250}
251
/// Front-end over a [`DagRunner`] that wraps every task added through it in [`Durable`].
///
/// All such tasks share one [`Context`] and an optional progress callback.
pub struct DurableDag<'a> {
    /// The DAG the wrapped tasks are registered with.
    pub dag: &'a DagRunner,
    /// Context cloned into each task added through `add`.
    pub ctx: Arc<Context>,
    /// Callback cloned into each task at the time it is added.
    ///
    /// Tasks added before `with_progress` is called do not receive it.
    pub progress_handler: Option<ProgressCb>,
}
261
262impl<'a> DurableDag<'a> {
263 pub fn new(dag: &'a DagRunner, ctx: Arc<Context>) -> Self {
264 Self {
265 dag,
266 ctx,
267 progress_handler: None,
268 }
269 }
270
271 pub fn with_progress<F>(mut self, handler: F) -> Self
273 where
274 F: Fn(&str, usize) + Send + Sync + 'static,
275 {
276 self.progress_handler = Some(Arc::new(handler));
277 self
278 }
279
280 pub fn add<Tk>(&self, id: &str, task: Tk) -> TaskBuilder<'_, Durable<Tk>, Pending>
283 where
284 Tk: Task + Sync + 'static,
285 Tk::Input: Clone + 'static,
286 Tk::Output: Serialize + DeserializeOwned + Send + Sync + Clone + 'static,
287 {
288 let durable_wrapper =
289 Durable::new(id, self.ctx.clone(), task, self.progress_handler.clone());
290 self.dag.add_task(durable_wrapper)
291 }
292}