1use std::collections::HashMap;
7use std::fmt::Debug;
8use std::sync::{Arc, Mutex};
9use std::time::Duration;
10
11use async_trait::async_trait;
12use floxide_core::{ActionType, DefaultAction, FloxideError, Node, NodeId, NodeOutcome};
13use serde::{Deserialize, Serialize};
14
/// Result of a single step of a [`LongRunningNode`].
///
/// `Complete` carries the node's final output; `Suspend` carries the
/// checkpoint state that should be persisted and handed back to the node on
/// its next invocation.
///
/// `Clone`/`PartialEq`/`Eq` are derived (conditionally on `T` and `S`) so
/// outcomes can be duplicated and compared, e.g. in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LongRunningOutcome<T, S> {
    /// The node finished and produced its output.
    Complete(T),
    /// The node paused; the contained state is its checkpoint.
    Suspend(S),
}
23
24#[async_trait]
29pub trait LongRunningNode<Context, Action>: Send + Sync
30where
31 Context: Send + Sync + 'static,
32 Action: ActionType + Send + Sync + 'static + Debug,
33 Self::State: Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static,
34 Self::Output: Send + 'static,
35{
36 type State;
38
39 type Output;
41
42 async fn process(
44 &self,
45 state: Option<Self::State>,
46 ctx: &mut Context,
47 ) -> Result<LongRunningOutcome<Self::Output, Self::State>, FloxideError>;
48
49 fn id(&self) -> NodeId;
51}
52
53pub trait LongRunningActionExt: ActionType {
55 fn suspend() -> Self;
57
58 fn resume() -> Self;
60
61 fn complete() -> Self;
63
64 fn is_suspend(&self) -> bool;
66
67 fn is_resume(&self) -> bool;
69
70 fn is_complete(&self) -> bool;
72}
73
74impl LongRunningActionExt for DefaultAction {
76 fn suspend() -> Self {
77 DefaultAction::Custom("suspend".to_string())
78 }
79
80 fn resume() -> Self {
81 DefaultAction::Custom("resume".to_string())
82 }
83
84 fn complete() -> Self {
85 DefaultAction::Custom("complete".to_string())
86 }
87
88 fn is_suspend(&self) -> bool {
89 matches!(self, DefaultAction::Custom(s) if s == "suspend")
90 }
91
92 fn is_resume(&self) -> bool {
93 matches!(self, DefaultAction::Custom(s) if s == "resume")
94 }
95
96 fn is_complete(&self) -> bool {
97 matches!(self, DefaultAction::Custom(s) if s == "complete")
98 }
99}
100
101pub struct SimpleLongRunningNode<Context, Action, State, Output, F>
103where
104 Context: Send + Sync + 'static,
105 Action: ActionType + Send + Sync + 'static + Debug,
106 State: Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static,
107 Output: Send + Sync + 'static,
108 F: Fn(Option<State>, &mut Context) -> Result<LongRunningOutcome<Output, State>, FloxideError>
109 + Send
110 + Sync
111 + 'static,
112{
113 id: NodeId,
114 process_fn: Arc<F>,
115 _phantom: std::marker::PhantomData<(Context, Action, State, Output)>,
116}
117
118impl<Context, Action, State, Output, F> SimpleLongRunningNode<Context, Action, State, Output, F>
119where
120 Context: Send + Sync + 'static,
121 Action: ActionType + Send + Sync + 'static + Debug,
122 State: Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static,
123 Output: Send + Sync + 'static,
124 F: Fn(Option<State>, &mut Context) -> Result<LongRunningOutcome<Output, State>, FloxideError>
125 + Send
126 + Sync
127 + 'static,
128{
129 pub fn new(process_fn: F) -> Self {
131 Self {
132 id: NodeId::new(),
133 process_fn: Arc::new(process_fn),
134 _phantom: std::marker::PhantomData,
135 }
136 }
137
138 pub fn with_id(id: impl Into<NodeId>, process_fn: F) -> Self {
140 Self {
141 id: id.into(),
142 process_fn: Arc::new(process_fn),
143 _phantom: std::marker::PhantomData,
144 }
145 }
146}
147
148#[async_trait]
149impl<Context, Action, State, Output, F> LongRunningNode<Context, Action>
150 for SimpleLongRunningNode<Context, Action, State, Output, F>
151where
152 Context: Send + Sync + 'static,
153 Action: ActionType + Send + Sync + 'static + Debug,
154 State: Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static,
155 Output: Send + Sync + 'static,
156 F: Fn(Option<State>, &mut Context) -> Result<LongRunningOutcome<Output, State>, FloxideError>
157 + Send
158 + Sync
159 + 'static,
160{
161 type State = State;
162 type Output = Output;
163
164 async fn process(
165 &self,
166 state: Option<Self::State>,
167 ctx: &mut Context,
168 ) -> Result<LongRunningOutcome<Self::Output, Self::State>, FloxideError> {
169 (self.process_fn)(state, ctx)
170 }
171
172 fn id(&self) -> NodeId {
173 self.id.clone()
174 }
175}
176
/// Adapts a [`LongRunningNode`] into a regular [`Node`], persisting the
/// node's checkpoint in a [`StateStore`] between invocations.
///
/// The trait bounds live on the `impl` blocks that need them rather than on
/// the struct definition.
pub struct LongRunningNodeAdapter<LRN, Context, Action, S> {
    /// The wrapped long-running node.
    node: Arc<LRN>,
    /// Where the node's checkpoint is loaded from and saved to.
    state_store: Arc<S>,
    /// Set by `with_suspend_timeout`.
    // NOTE(review): in the code visible here this value is stored but never
    // read, so no timeout is actually enforced — confirm intended behavior.
    suspend_timeout: Option<Duration>,
    /// Ties `Context` and `Action` to the struct without storing them.
    _phantom: std::marker::PhantomData<(Context, Action)>,
}
191
192impl<LRN, Context, Action, S> LongRunningNodeAdapter<LRN, Context, Action, S>
193where
194 LRN: LongRunningNode<Context, Action>,
195 Context: Send + Sync + 'static,
196 Action: ActionType + LongRunningActionExt + Send + Sync + 'static + Debug,
197 S: StateStore,
198{
199 pub fn new(node: LRN, state_store: S) -> Self {
201 Self {
202 node: Arc::new(node),
203 state_store: Arc::new(state_store),
204 suspend_timeout: None,
205 _phantom: std::marker::PhantomData,
206 }
207 }
208
209 pub fn with_suspend_timeout(mut self, timeout: Duration) -> Self {
211 self.suspend_timeout = Some(timeout);
212 self
213 }
214}
215
216#[async_trait]
217impl<LRN, Context, Action, S> Node<Context, Action>
218 for LongRunningNodeAdapter<LRN, Context, Action, S>
219where
220 LRN: LongRunningNode<Context, Action>,
221 Context: Send + Sync + 'static,
222 Action: ActionType + LongRunningActionExt + Send + Sync + 'static + Debug,
223 S: StateStore,
224 LRN::Output: Send + Sync + 'static, {
226 type Output = Action;
227
228 async fn process(
229 &self,
230 ctx: &mut Context,
231 ) -> Result<NodeOutcome<Self::Output, Action>, FloxideError> {
232 let node_id = self.node.id();
234 let state = match self
235 .state_store
236 .get_state::<LRN::State>(node_id.clone())
237 .await
238 {
239 Ok(state) => state,
240 Err(e) => return Err(e),
241 };
242
243 match self.node.process(state, ctx).await {
245 Ok(LongRunningOutcome::Complete(_output)) => {
246 self.state_store.remove_state(node_id).await?;
248
249 Ok(NodeOutcome::RouteToAction(Action::complete()))
251 }
252 Ok(LongRunningOutcome::Suspend(state)) => {
253 self.state_store.save_state(node_id, &state).await?;
255
256 Ok(NodeOutcome::RouteToAction(Action::suspend()))
258 }
259 Err(e) => Err(e),
260 }
261 }
262
263 fn id(&self) -> NodeId {
264 self.node.id()
265 }
266}
267
/// Drives [`LongRunningNode`]s directly, persisting their checkpoints in a
/// [`StateStore`] between executions.
///
/// The trait bounds live on the `impl` block rather than on the struct
/// definition.
pub struct LongRunningWorkflow<Context, Action, S> {
    /// Backing store for suspended node state.
    state_store: Arc<S>,
    /// Ties `Context` and `Action` to the struct without storing them.
    _phantom: std::marker::PhantomData<(Context, Action)>,
}
279
280impl<Context, Action, S> LongRunningWorkflow<Context, Action, S>
281where
282 Context: Send + Sync + 'static,
283 Action: ActionType + LongRunningActionExt + Send + Sync + 'static + Debug,
284 S: StateStore,
285{
286 pub fn new(state_store: S) -> Self {
288 Self {
289 state_store: Arc::new(state_store),
290 _phantom: std::marker::PhantomData,
291 }
292 }
293
294 pub async fn execute_node<LRN>(
296 &self,
297 node: &LRN,
298 ctx: &mut Context,
299 ) -> Result<NodeOutcome<Action, Action>, FloxideError>
300 where
301 LRN: LongRunningNode<Context, Action>,
302 {
303 let node_id = node.id();
304 let state = self
305 .state_store
306 .get_state::<LRN::State>(node_id.clone())
307 .await?;
308
309 match node.process(state, ctx).await? {
310 LongRunningOutcome::Complete(_output) => {
311 self.state_store.remove_state(node_id).await?;
313 Ok(NodeOutcome::Success(Action::complete()))
317 }
318 LongRunningOutcome::Suspend(state) => {
319 self.state_store.save_state(node_id, &state).await?;
321
322 Ok(NodeOutcome::Success(Action::suspend()))
324 }
325 }
326 }
327
328 pub async fn has_suspended_state(&self, node_id: NodeId) -> Result<bool, FloxideError> {
330 self.state_store.has_state(node_id).await
331 }
332
333 pub async fn get_suspended_nodes(&self) -> Result<Vec<NodeId>, FloxideError> {
335 self.state_store.get_all_node_ids().await
336 }
337}
338
339#[async_trait]
341pub trait StateStore: Send + Sync + 'static {
342 async fn save_state<T: Serialize + Send + Sync + 'static>(
344 &self,
345 node_id: NodeId,
346 state: &T,
347 ) -> Result<(), FloxideError>;
348
349 async fn get_state<T: for<'de> Deserialize<'de> + Send + Sync + 'static>(
351 &self,
352 node_id: NodeId,
353 ) -> Result<Option<T>, FloxideError>;
354
355 async fn has_state(&self, node_id: NodeId) -> Result<bool, FloxideError>;
357
358 async fn remove_state(&self, node_id: NodeId) -> Result<(), FloxideError>;
360
361 async fn get_all_node_ids(&self) -> Result<Vec<NodeId>, FloxideError>;
363}
364
365pub struct InMemoryStateStore {
367 states: Arc<Mutex<HashMap<NodeId, Vec<u8>>>>,
368}
369
370impl InMemoryStateStore {
371 pub fn new() -> Self {
373 Self {
374 states: Arc::new(Mutex::new(HashMap::new())),
375 }
376 }
377}
378
379impl Default for InMemoryStateStore {
380 fn default() -> Self {
381 Self::new()
382 }
383}
384
385#[async_trait]
386impl StateStore for InMemoryStateStore {
387 async fn save_state<T: Serialize + Send + Sync + 'static>(
388 &self,
389 node_id: NodeId,
390 state: &T,
391 ) -> Result<(), FloxideError> {
392 let serialized = serde_json::to_vec(state)
393 .map_err(|e| FloxideError::SerializationError(e.to_string()))?;
394
395 let mut states = self
396 .states
397 .lock()
398 .map_err(|e| FloxideError::Other(format!("Failed to acquire mutex lock: {}", e)))?;
399
400 states.insert(node_id, serialized);
401 Ok(())
402 }
403
404 async fn get_state<T: for<'de> Deserialize<'de> + Send + Sync + 'static>(
405 &self,
406 node_id: NodeId,
407 ) -> Result<Option<T>, FloxideError> {
408 let states = self
409 .states
410 .lock()
411 .map_err(|e| FloxideError::Other(format!("Failed to acquire mutex lock: {}", e)))?;
412
413 match states.get(&node_id) {
414 Some(serialized) => {
415 let state = serde_json::from_slice(serialized)
416 .map_err(|e| FloxideError::DeserializationError(e.to_string()))?;
417 Ok(Some(state))
418 }
419 None => Ok(None),
420 }
421 }
422
423 async fn has_state(&self, node_id: NodeId) -> Result<bool, FloxideError> {
424 let states = self
425 .states
426 .lock()
427 .map_err(|e| FloxideError::Other(format!("Failed to acquire mutex lock: {}", e)))?;
428
429 Ok(states.contains_key(&node_id))
430 }
431
432 async fn remove_state(&self, node_id: NodeId) -> Result<(), FloxideError> {
433 let mut states = self
434 .states
435 .lock()
436 .map_err(|e| FloxideError::Other(format!("Failed to acquire mutex lock: {}", e)))?;
437
438 states.remove(&node_id);
439 Ok(())
440 }
441
442 async fn get_all_node_ids(&self) -> Result<Vec<NodeId>, FloxideError> {
443 let states = self
444 .states
445 .lock()
446 .map_err(|e| FloxideError::Other(format!("Failed to acquire mutex lock: {}", e)))?;
447
448 Ok(states.keys().cloned().collect())
449 }
450}
451
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_simple_long_running_node() {
        // Counts steps through the persisted state and completes on the third call.
        let counter = SimpleLongRunningNode::<_, DefaultAction, _, _, _>::new(
            |previous: Option<u32>, _ctx: &mut ()| {
                let step = previous.map_or(1, |done| done + 1);
                if step < 3 {
                    Ok(LongRunningOutcome::Suspend(step))
                } else {
                    Ok(LongRunningOutcome::Complete("done"))
                }
            },
        );

        let workflow = LongRunningWorkflow::new(InMemoryStateStore::new());

        // First step suspends and leaves a checkpoint behind.
        let first = workflow.execute_node(&counter, &mut ()).await.unwrap();
        assert!(matches!(first, NodeOutcome::Success(action) if action.is_suspend()));
        assert!(workflow.has_suspended_state(counter.id()).await.unwrap());

        // Second step resumes and suspends again.
        let second = workflow.execute_node(&counter, &mut ()).await.unwrap();
        assert!(matches!(second, NodeOutcome::Success(action) if action.is_suspend()));

        // Third step completes and clears the checkpoint.
        let third = workflow.execute_node(&counter, &mut ()).await.unwrap();
        assert!(matches!(third, NodeOutcome::Success(action) if action.is_complete()));
        assert!(!workflow.has_suspended_state(counter.id()).await.unwrap());
    }
}