Skip to main content

floxide_longrunning/
lib.rs

//! Support for long-running processes with checkpoints in the Floxide framework.
//!
//! This crate provides the `LongRunningNode` trait and related implementations for
//! processes that run in steps and are suspended and resumed over time, with their
//! intermediate state saved in a `StateStore` between steps.

6use std::collections::HashMap;
7use std::fmt::Debug;
8use std::sync::{Arc, Mutex};
9use std::time::Duration;
10
11use async_trait::async_trait;
12use floxide_core::{ActionType, DefaultAction, FloxideError, Node, NodeId, NodeOutcome};
13use serde::{Deserialize, Serialize};
14
/// Result of running one step of a long-running process.
///
/// `T` is the type of the final result. `S` is the checkpoint state that is
/// handed back when the process has not finished yet.
#[derive(Debug)]
pub enum LongRunningOutcome<T, S> {
    /// The process finished and produced its final result.
    Complete(T),
    /// The process is not finished; this state is the checkpoint to resume from.
    Suspend(S),
}

24/// Trait for nodes that handle long-running processes with checkpoints.
25///
26/// A `LongRunningNode` is capable of processing work incrementally, saving its state
27/// between executions, and resuming from the last checkpoint.
28#[async_trait]
29pub trait LongRunningNode<Context, Action>: Send + Sync
30where
31    Context: Send + Sync + 'static,
32    Action: ActionType + Send + Sync + 'static + Debug,
33    Self::State: Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static,
34    Self::Output: Send + 'static,
35{
36    /// Type representing the node's processing state
37    type State;
38
39    /// Type representing the final output
40    type Output;
41
42    /// Process the next step, potentially suspending execution
43    async fn process(
44        &self,
45        state: Option<Self::State>,
46        ctx: &mut Context,
47    ) -> Result<LongRunningOutcome<Self::Output, Self::State>, FloxideError>;
48
49    /// Get the node's unique identifier
50    fn id(&self) -> NodeId;
51}
52
53/// Extension trait for ActionType to support long-running specific actions
54pub trait LongRunningActionExt: ActionType {
55    /// Create a suspend action for long-running nodes
56    fn suspend() -> Self;
57
58    /// Create a resume action for long-running nodes
59    fn resume() -> Self;
60
61    /// Create a complete action for long-running nodes
62    fn complete() -> Self;
63
64    /// Check if this is a suspend action
65    fn is_suspend(&self) -> bool;
66
67    /// Check if this is a resume action
68    fn is_resume(&self) -> bool;
69
70    /// Check if this is a complete action
71    fn is_complete(&self) -> bool;
72}
73
74// Implement the extension trait for DefaultAction
75impl LongRunningActionExt for DefaultAction {
76    fn suspend() -> Self {
77        DefaultAction::Custom("suspend".to_string())
78    }
79
80    fn resume() -> Self {
81        DefaultAction::Custom("resume".to_string())
82    }
83
84    fn complete() -> Self {
85        DefaultAction::Custom("complete".to_string())
86    }
87
88    fn is_suspend(&self) -> bool {
89        matches!(self, DefaultAction::Custom(s) if s == "suspend")
90    }
91
92    fn is_resume(&self) -> bool {
93        matches!(self, DefaultAction::Custom(s) if s == "resume")
94    }
95
96    fn is_complete(&self) -> bool {
97        matches!(self, DefaultAction::Custom(s) if s == "complete")
98    }
99}
100
101/// A simple long-running node that uses a closure for processing.
102pub struct SimpleLongRunningNode<Context, Action, State, Output, F>
103where
104    Context: Send + Sync + 'static,
105    Action: ActionType + Send + Sync + 'static + Debug,
106    State: Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static,
107    Output: Send + Sync + 'static,
108    F: Fn(Option<State>, &mut Context) -> Result<LongRunningOutcome<Output, State>, FloxideError>
109        + Send
110        + Sync
111        + 'static,
112{
113    id: NodeId,
114    process_fn: Arc<F>,
115    _phantom: std::marker::PhantomData<(Context, Action, State, Output)>,
116}
117
118impl<Context, Action, State, Output, F> SimpleLongRunningNode<Context, Action, State, Output, F>
119where
120    Context: Send + Sync + 'static,
121    Action: ActionType + Send + Sync + 'static + Debug,
122    State: Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static,
123    Output: Send + Sync + 'static,
124    F: Fn(Option<State>, &mut Context) -> Result<LongRunningOutcome<Output, State>, FloxideError>
125        + Send
126        + Sync
127        + 'static,
128{
129    /// Create a new SimpleLongRunningNode with a unique ID
130    pub fn new(process_fn: F) -> Self {
131        Self {
132            id: NodeId::new(),
133            process_fn: Arc::new(process_fn),
134            _phantom: std::marker::PhantomData,
135        }
136    }
137
138    /// Create a new SimpleLongRunningNode with a specified ID
139    pub fn with_id(id: impl Into<NodeId>, process_fn: F) -> Self {
140        Self {
141            id: id.into(),
142            process_fn: Arc::new(process_fn),
143            _phantom: std::marker::PhantomData,
144        }
145    }
146}
147
148#[async_trait]
149impl<Context, Action, State, Output, F> LongRunningNode<Context, Action>
150    for SimpleLongRunningNode<Context, Action, State, Output, F>
151where
152    Context: Send + Sync + 'static,
153    Action: ActionType + Send + Sync + 'static + Debug,
154    State: Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static,
155    Output: Send + Sync + 'static,
156    F: Fn(Option<State>, &mut Context) -> Result<LongRunningOutcome<Output, State>, FloxideError>
157        + Send
158        + Sync
159        + 'static,
160{
161    type State = State;
162    type Output = Output;
163
164    async fn process(
165        &self,
166        state: Option<Self::State>,
167        ctx: &mut Context,
168    ) -> Result<LongRunningOutcome<Self::Output, Self::State>, FloxideError> {
169        (self.process_fn)(state, ctx)
170    }
171
172    fn id(&self) -> NodeId {
173        self.id.clone()
174    }
175}
176
177/// An adapter that allows a LongRunningNode to be used as a standard Node.
178/// This adapter handles state management and action conversion.
179pub struct LongRunningNodeAdapter<LRN, Context, Action, S>
180where
181    LRN: LongRunningNode<Context, Action>,
182    Context: Send + Sync + 'static,
183    Action: ActionType + LongRunningActionExt + Send + Sync + 'static + Debug,
184    S: StateStore,
185{
186    node: Arc<LRN>,
187    state_store: Arc<S>,
188    suspend_timeout: Option<Duration>,
189    _phantom: std::marker::PhantomData<(Context, Action)>,
190}
191
192impl<LRN, Context, Action, S> LongRunningNodeAdapter<LRN, Context, Action, S>
193where
194    LRN: LongRunningNode<Context, Action>,
195    Context: Send + Sync + 'static,
196    Action: ActionType + LongRunningActionExt + Send + Sync + 'static + Debug,
197    S: StateStore,
198{
199    /// Create a new adapter for a long-running node
200    pub fn new(node: LRN, state_store: S) -> Self {
201        Self {
202            node: Arc::new(node),
203            state_store: Arc::new(state_store),
204            suspend_timeout: None,
205            _phantom: std::marker::PhantomData,
206        }
207    }
208
209    /// Set a timeout for suspended nodes
210    pub fn with_suspend_timeout(mut self, timeout: Duration) -> Self {
211        self.suspend_timeout = Some(timeout);
212        self
213    }
214}
215
216#[async_trait]
217impl<LRN, Context, Action, S> Node<Context, Action>
218    for LongRunningNodeAdapter<LRN, Context, Action, S>
219where
220    LRN: LongRunningNode<Context, Action>,
221    Context: Send + Sync + 'static,
222    Action: ActionType + LongRunningActionExt + Send + Sync + 'static + Debug,
223    S: StateStore,
224    LRN::Output: Send + Sync + 'static, // Needed for Node trait bounds
225{
226    type Output = Action;
227
228    async fn process(
229        &self,
230        ctx: &mut Context,
231    ) -> Result<NodeOutcome<Self::Output, Action>, FloxideError> {
232        // Get the node's state if it exists
233        let node_id = self.node.id();
234        let state = match self
235            .state_store
236            .get_state::<LRN::State>(node_id.clone())
237            .await
238        {
239            Ok(state) => state,
240            Err(e) => return Err(e),
241        };
242
243        // Process the next step
244        match self.node.process(state, ctx).await {
245            Ok(LongRunningOutcome::Complete(_output)) => {
246                // Processing is complete, clean up state
247                self.state_store.remove_state(node_id).await?;
248
249                // Return with complete action
250                Ok(NodeOutcome::RouteToAction(Action::complete()))
251            }
252            Ok(LongRunningOutcome::Suspend(state)) => {
253                // Save the state
254                self.state_store.save_state(node_id, &state).await?;
255
256                // Return with suspend action
257                Ok(NodeOutcome::RouteToAction(Action::suspend()))
258            }
259            Err(e) => Err(e),
260        }
261    }
262
263    fn id(&self) -> NodeId {
264        self.node.id()
265    }
266}
267
268/// A workflow for orchestrating long-running nodes.
269/// This manages the suspension and resumption of nodes.
270pub struct LongRunningWorkflow<Context, Action, S>
271where
272    Context: Send + Sync + 'static,
273    Action: ActionType + LongRunningActionExt + Send + Sync + 'static + Debug,
274    S: StateStore,
275{
276    state_store: Arc<S>,
277    _phantom: std::marker::PhantomData<(Context, Action)>,
278}
279
280impl<Context, Action, S> LongRunningWorkflow<Context, Action, S>
281where
282    Context: Send + Sync + 'static,
283    Action: ActionType + LongRunningActionExt + Send + Sync + 'static + Debug,
284    S: StateStore,
285{
286    /// Create a new long-running workflow
287    pub fn new(state_store: S) -> Self {
288        Self {
289            state_store: Arc::new(state_store),
290            _phantom: std::marker::PhantomData,
291        }
292    }
293
294    /// Execute a single long-running node until completion or suspension
295    pub async fn execute_node<LRN>(
296        &self,
297        node: &LRN,
298        ctx: &mut Context,
299    ) -> Result<NodeOutcome<Action, Action>, FloxideError>
300    where
301        LRN: LongRunningNode<Context, Action>,
302    {
303        let node_id = node.id();
304        let state = self
305            .state_store
306            .get_state::<LRN::State>(node_id.clone())
307            .await?;
308
309        match node.process(state, ctx).await? {
310            LongRunningOutcome::Complete(_output) => {
311                // Clean up state
312                self.state_store.remove_state(node_id).await?;
313                // Store output if needed
314
315                // Return complete action
316                Ok(NodeOutcome::Success(Action::complete()))
317            }
318            LongRunningOutcome::Suspend(state) => {
319                // Save state
320                self.state_store.save_state(node_id, &state).await?;
321
322                // Return suspend action
323                Ok(NodeOutcome::Success(Action::suspend()))
324            }
325        }
326    }
327
328    /// Check if a node has suspended state
329    pub async fn has_suspended_state(&self, node_id: NodeId) -> Result<bool, FloxideError> {
330        self.state_store.has_state(node_id).await
331    }
332
333    /// Get all node IDs with suspended state
334    pub async fn get_suspended_nodes(&self) -> Result<Vec<NodeId>, FloxideError> {
335        self.state_store.get_all_node_ids().await
336    }
337}
338
339/// Trait for storing and retrieving node states.
340#[async_trait]
341pub trait StateStore: Send + Sync + 'static {
342    /// Save a node's state
343    async fn save_state<T: Serialize + Send + Sync + 'static>(
344        &self,
345        node_id: NodeId,
346        state: &T,
347    ) -> Result<(), FloxideError>;
348
349    /// Get a node's state if it exists
350    async fn get_state<T: for<'de> Deserialize<'de> + Send + Sync + 'static>(
351        &self,
352        node_id: NodeId,
353    ) -> Result<Option<T>, FloxideError>;
354
355    /// Check if a node has saved state
356    async fn has_state(&self, node_id: NodeId) -> Result<bool, FloxideError>;
357
358    /// Remove a node's state
359    async fn remove_state(&self, node_id: NodeId) -> Result<(), FloxideError>;
360
361    /// Get all node IDs with saved states
362    async fn get_all_node_ids(&self) -> Result<Vec<NodeId>, FloxideError>;
363}
364
365/// A simple in-memory implementation of StateStore for testing.
366pub struct InMemoryStateStore {
367    states: Arc<Mutex<HashMap<NodeId, Vec<u8>>>>,
368}
369
370impl InMemoryStateStore {
371    /// Create a new in-memory state store
372    pub fn new() -> Self {
373        Self {
374            states: Arc::new(Mutex::new(HashMap::new())),
375        }
376    }
377}
378
379impl Default for InMemoryStateStore {
380    fn default() -> Self {
381        Self::new()
382    }
383}
384
385#[async_trait]
386impl StateStore for InMemoryStateStore {
387    async fn save_state<T: Serialize + Send + Sync + 'static>(
388        &self,
389        node_id: NodeId,
390        state: &T,
391    ) -> Result<(), FloxideError> {
392        let serialized = serde_json::to_vec(state)
393            .map_err(|e| FloxideError::SerializationError(e.to_string()))?;
394
395        let mut states = self
396            .states
397            .lock()
398            .map_err(|e| FloxideError::Other(format!("Failed to acquire mutex lock: {}", e)))?;
399
400        states.insert(node_id, serialized);
401        Ok(())
402    }
403
404    async fn get_state<T: for<'de> Deserialize<'de> + Send + Sync + 'static>(
405        &self,
406        node_id: NodeId,
407    ) -> Result<Option<T>, FloxideError> {
408        let states = self
409            .states
410            .lock()
411            .map_err(|e| FloxideError::Other(format!("Failed to acquire mutex lock: {}", e)))?;
412
413        match states.get(&node_id) {
414            Some(serialized) => {
415                let state = serde_json::from_slice(serialized)
416                    .map_err(|e| FloxideError::DeserializationError(e.to_string()))?;
417                Ok(Some(state))
418            }
419            None => Ok(None),
420        }
421    }
422
423    async fn has_state(&self, node_id: NodeId) -> Result<bool, FloxideError> {
424        let states = self
425            .states
426            .lock()
427            .map_err(|e| FloxideError::Other(format!("Failed to acquire mutex lock: {}", e)))?;
428
429        Ok(states.contains_key(&node_id))
430    }
431
432    async fn remove_state(&self, node_id: NodeId) -> Result<(), FloxideError> {
433        let mut states = self
434            .states
435            .lock()
436            .map_err(|e| FloxideError::Other(format!("Failed to acquire mutex lock: {}", e)))?;
437
438        states.remove(&node_id);
439        Ok(())
440    }
441
442    async fn get_all_node_ids(&self) -> Result<Vec<NodeId>, FloxideError> {
443        let states = self
444            .states
445            .lock()
446            .map_err(|e| FloxideError::Other(format!("Failed to acquire mutex lock: {}", e)))?;
447
448        Ok(states.keys().cloned().collect())
449    }
450}
451
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs a three-step node through the workflow. It checks the suspend/complete
    /// actions, the checkpoint saved after each step, and that the state is cleaned up.
    #[tokio::test]
    async fn test_simple_long_running_node() {
        // Create a node that completes after 3 steps
        let node = SimpleLongRunningNode::<_, DefaultAction, _, _, _>::new(
            |state: Option<u32>, _ctx: &mut ()| {
                let current_step = state.unwrap_or(0);
                let next_step = current_step + 1;

                if next_step >= 3 {
                    // Complete after 3 steps
                    Ok(LongRunningOutcome::Complete("done"))
                } else {
                    // Otherwise suspend with updated state
                    Ok(LongRunningOutcome::Suspend(next_step))
                }
            },
        );

        let workflow = LongRunningWorkflow::new(InMemoryStateStore::new());
        let mut ctx = ();

        // Steps 1 and 2 suspend and checkpoint the step counter.
        for expected_state in 1..=2u32 {
            let outcome = workflow.execute_node(&node, &mut ctx).await.unwrap();
            assert!(matches!(outcome, NodeOutcome::Success(action) if action.is_suspend()));
            assert!(workflow.has_suspended_state(node.id()).await.unwrap());
            let saved = workflow
                .state_store
                .get_state::<u32>(node.id())
                .await
                .unwrap();
            assert_eq!(saved, Some(expected_state));
            assert_eq!(workflow.get_suspended_nodes().await.unwrap().len(), 1);
        }

        // Step 3 completes and removes the checkpoint.
        let outcome = workflow.execute_node(&node, &mut ctx).await.unwrap();
        assert!(matches!(outcome, NodeOutcome::Success(action) if action.is_complete()));
        assert!(!workflow.has_suspended_state(node.id()).await.unwrap());
        assert!(workflow.get_suspended_nodes().await.unwrap().is_empty());
    }

    /// Exercises `InMemoryStateStore` directly: save, load, overwrite, decode
    /// failure, and removal.
    #[tokio::test]
    async fn test_in_memory_state_store_round_trip() {
        let store = InMemoryStateStore::new();
        let node_id = NodeId::new();

        assert_eq!(store.get_state::<u32>(node_id.clone()).await.unwrap(), None);
        assert!(!store.has_state(node_id.clone()).await.unwrap());

        store.save_state(node_id.clone(), &7u32).await.unwrap();
        assert!(store.has_state(node_id.clone()).await.unwrap());
        assert_eq!(store.get_state::<u32>(node_id.clone()).await.unwrap(), Some(7));
        let ids = store.get_all_node_ids().await.unwrap();
        assert_eq!(ids.len(), 1);
        assert!(ids.contains(&node_id));

        // Saving again replaces the previous value.
        store.save_state(node_id.clone(), &8u32).await.unwrap();
        assert_eq!(store.get_state::<u32>(node_id.clone()).await.unwrap(), Some(8));

        // Stored bytes that do not decode as the requested type surface as an error.
        store
            .save_state(node_id.clone(), &"not a number".to_string())
            .await
            .unwrap();
        assert!(store.get_state::<u32>(node_id.clone()).await.is_err());

        store.remove_state(node_id.clone()).await.unwrap();
        assert!(!store.has_state(node_id.clone()).await.unwrap());
        assert!(store.get_all_node_ids().await.unwrap().is_empty());
    }

    /// Each constructed action is recognised only by its own predicate.
    #[test]
    fn test_default_action_long_running_ext() {
        let suspend = <DefaultAction as LongRunningActionExt>::suspend();
        let resume = <DefaultAction as LongRunningActionExt>::resume();
        let complete = <DefaultAction as LongRunningActionExt>::complete();

        assert!(suspend.is_suspend() && !suspend.is_resume() && !suspend.is_complete());
        assert!(!resume.is_suspend() && resume.is_resume() && !resume.is_complete());
        assert!(!complete.is_suspend() && !complete.is_resume() && complete.is_complete());

        let other = DefaultAction::Custom("other".to_string());
        assert!(!other.is_suspend() && !other.is_resume() && !other.is_complete());
    }
}