Skip to main content

duraflow_rs/
lib.rs

1//! duraflow-rs — durable, resumable dag tasks built on top of `dagx`
2//!
3//! High level features:
4//! - Durable task decorator that persists task outputs to a `Storage` backend
5//! - `run_result` API for callers to observe persistence errors without changing `Task::run`
6//! - Built-in `MemoryStore` and `FileStore` backends
7//!
8//! Example (simple):
9//!
10//! ```no_run
11//! #[tokio::main]
12//! async fn main() {
13//!     use duraflow_rs::{DurableDag, Context, MemoryStore};
14//!     use dagx::{DagRunner, task, Task};
15//!     use std::sync::{Arc, atomic::AtomicUsize};
16//!
17//!     // define a task
18//!     struct Load(i32);
19//!     #[task]
20//!     impl Load { async fn run(&self) -> i32 { self.0 } }
21//!
22//!     let dag = DagRunner::new();
23//!     let db = Arc::new(MemoryStore::new());
24//!     let ctx = Arc::new(Context { db, completed_count: Arc::new(AtomicUsize::new(0)) });
25//!     let d = DurableDag::new(&dag, ctx.clone());
26//!     let a = d.add("v1", Load(5));
27//!     dag.run(|f| { tokio::spawn(f); }).await.unwrap();
28//!     assert_eq!(dag.get(a).unwrap(), 5);
29//! }
30//! ```
31
32use dagx::{DagRunner, Pending, Task, TaskBuilder};
33use serde::{de::DeserializeOwned, Serialize};
34use std::sync::atomic::AtomicUsize;
35use std::sync::Arc;
36
37mod store;
38pub use store::{FileStore, MemoryStore, Storage};
39
40/// Typed error for duraflow-rs operations
41#[derive(Debug)]
42pub enum DuraflowError {
43    Io(std::io::Error),
44    Serialize(serde_json::Error),
45    Persist { key: String, source: std::io::Error },
46    Other(String),
47}
48
49impl std::fmt::Display for DuraflowError {
50    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51        match self {
52            DuraflowError::Io(e) => write!(f, "io error: {}", e),
53            DuraflowError::Serialize(e) => write!(f, "serialize error: {}", e),
54            DuraflowError::Persist { key, source } => {
55                write!(f, "persist failed for {}: {}", key, source)
56            }
57            DuraflowError::Other(s) => write!(f, "{}", s),
58        }
59    }
60}
61
62impl std::error::Error for DuraflowError {}
63
64impl From<std::io::Error> for DuraflowError {
65    fn from(e: std::io::Error) -> Self {
66        DuraflowError::Io(e)
67    }
68}
69
70impl From<serde_json::Error> for DuraflowError {
71    fn from(e: serde_json::Error) -> Self {
72        DuraflowError::Serialize(e)
73    }
74}
75
76/// Shared context for progress and durability
77pub struct Context {
78    pub db: Arc<dyn Storage + Send + Sync>,
79    pub completed_count: Arc<AtomicUsize>,
80}
81
82impl Context {
83    /// Typed helper: deserialize stored JSON into T
84    pub fn get<T: DeserializeOwned>(&self, key: &str) -> Option<T> {
85        self.db
86            .get_raw(key)
87            .and_then(|s| serde_json::from_str(&s).ok())
88    }
89
90    /// Typed helper: serialize value and persist as JSON
91    /// Returns an io::Error if serialization or storage fails.
92    pub fn save<T: Serialize>(&self, key: &str, value: &T) -> std::io::Result<()> {
93        let s = serde_json::to_string(value)
94            .map_err(|e| std::io::Error::other(format!("serialize error: {}", e)))?;
95        self.db.save_raw(key, &s)
96    }
97}
98
// Shared, thread-safe progress callback. It is invoked with
// (task id, number of tasks completed so far).
type ProgressCb = Arc<dyn Fn(&str, usize) + Send + Sync>;
101
102// Helper: check cache and mark completion if present
103fn try_cached_and_mark<O: DeserializeOwned>(
104    ctx: &Context,
105    id: &str,
106    cb: &Option<ProgressCb>,
107) -> Option<O> {
108    if let Some(v) = ctx.get::<O>(id) {
109        let completed = ctx
110            .completed_count
111            .fetch_add(1, std::sync::atomic::Ordering::SeqCst)
112            + 1;
113        if let Some(cb) = cb {
114            cb(id, completed);
115        }
116        return Some(v);
117    }
118    None
119}
120
121// Helper: persist value and mark completion
122fn persist_and_mark<O: Serialize>(
123    ctx: &Context,
124    id: &str,
125    value: &O,
126    cb: &Option<ProgressCb>,
127) -> Result<(), DuraflowError> {
128    ctx.save(id, value).map_err(|e| DuraflowError::Persist {
129        key: id.to_string(),
130        source: e,
131    })?;
132    let completed = ctx
133        .completed_count
134        .fetch_add(1, std::sync::atomic::Ordering::SeqCst)
135        + 1;
136    if let Some(cb) = cb {
137        cb(id, completed);
138    }
139    Ok(())
140}
141
142// -----------------------------------------------------------------------------
143// Durable decorator
144// -----------------------------------------------------------------------------
145
146/// The "Durable" decorator wraps any Task implementation and persists outputs.
147pub struct Durable<Tk> {
148    id: String,
149    ctx: Arc<Context>,
150    inner: Tk,
151    progress_handler: Option<ProgressCb>,
152}
153
154impl<Tk> Durable<Tk> {
155    pub fn new(
156        id: &str,
157        ctx: Arc<Context>,
158        inner: Tk,
159        progress_handler: Option<ProgressCb>,
160    ) -> Self {
161        Self {
162            id: id.to_string(),
163            ctx,
164            inner,
165            progress_handler,
166        }
167    }
168
169    /// Run the task but return a Result so callers can observe persistence errors.
170    pub async fn run_result(self, input: Tk::Input) -> Result<Tk::Output, DuraflowError>
171    where
172        Tk: Task + Send + 'static,
173        Tk::Input: Send + Clone,
174        Tk::Output: Serialize + DeserializeOwned + Send + Sync + 'static,
175    {
176        // 1. Cached path
177        if let Some(cached) =
178            try_cached_and_mark::<Tk::Output>(&self.ctx, &self.id, &self.progress_handler)
179        {
180            return Ok(cached);
181        }
182
183        // 2. Run inner
184        let result = self.inner.run(input).await;
185
186        // 3. Persist — propagate typed error to caller
187        persist_and_mark(&self.ctx, &self.id, &result, &self.progress_handler)?;
188
189        Ok(result)
190    }
191}
192
193impl<Tk> Task for Durable<Tk>
194where
195    Tk: Task + Send + 'static,
196    Tk::Input: Send + Clone,
197    Tk::Output: Serialize + DeserializeOwned + Send + Sync + 'static,
198{
199    type Input = Tk::Input;
200    type Output = Tk::Output;
201
202    #[allow(clippy::manual_async_fn)]
203    fn run(self, input: Self::Input) -> impl std::future::Future<Output = Self::Output> + Send {
204        async move {
205            // 1. Check cache (DRY via helper)
206            if let Some(cached) =
207                try_cached_and_mark::<Self::Output>(&self.ctx, &self.id, &self.progress_handler)
208            {
209                return cached;
210            }
211
212            // 2. Run inner task
213            let result = self.inner.run(input).await;
214
215            // 3. Persist + progress — do not panic, log on failure
216            if let Err(e) = persist_and_mark(&self.ctx, &self.id, &result, &self.progress_handler) {
217                eprintln!("persist failed for {}: {}", self.id, e);
218            }
219
220            result
221        }
222    }
223
224    fn extract_and_run(
225        self,
226        receivers: Vec<Box<dyn std::any::Any + Send>>,
227    ) -> impl std::future::Future<Output = Result<Self::Output, String>> + Send {
228        let id = self.id.clone();
229        let ctx = self.ctx.clone();
230        let progress = self.progress_handler.clone();
231        let inner = self.inner;
232
233        async move {
234            // 1. Cached path (DRY via helper)
235            if let Some(cached) = try_cached_and_mark::<Self::Output>(&ctx, &id, &progress) {
236                return Ok(cached);
237            }
238
239            // 2. Delegate to inner extraction
240            let result = inner.extract_and_run(receivers).await?;
241
242            // 3. Persist and update — propagate storage error to caller (map to string)
243            if let Err(e) = persist_and_mark(&ctx, &id, &result, &progress) {
244                return Err(e.to_string());
245            }
246
247            Ok(result)
248        }
249    }
250}
251
252// -----------------------------------------------------------------------------
253// DurableDag builder wrapper
254// -----------------------------------------------------------------------------
255
256pub struct DurableDag<'a> {
257    pub dag: &'a DagRunner,
258    pub ctx: Arc<Context>,
259    pub progress_handler: Option<ProgressCb>,
260}
261
262impl<'a> DurableDag<'a> {
263    pub fn new(dag: &'a DagRunner, ctx: Arc<Context>) -> Self {
264        Self {
265            dag,
266            ctx,
267            progress_handler: None,
268        }
269    }
270
271    /// Attach a progress handler that will be called with (task_id, completed_count)
272    pub fn with_progress<F>(mut self, handler: F) -> Self
273    where
274        F: Fn(&str, usize) + Send + Sync + 'static,
275    {
276        self.progress_handler = Some(Arc::new(handler));
277        self
278    }
279
280    /// Adds a durable task and returns a standard TaskBuilder.
281    /// This allows you to chain .depends_on() just like normal dagx.
282    pub fn add<Tk>(&self, id: &str, task: Tk) -> TaskBuilder<'_, Durable<Tk>, Pending>
283    where
284        Tk: Task + Sync + 'static,
285        Tk::Input: Clone + 'static,
286        Tk::Output: Serialize + DeserializeOwned + Send + Sync + Clone + 'static,
287    {
288        let durable_wrapper =
289            Durable::new(id, self.ctx.clone(), task, self.progress_handler.clone());
290        self.dag.add_task(durable_wrapper)
291    }
292}