pocketflow_core/
lib.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use serde::{Deserialize, Serialize};
4use thiserror::Error;
5
6#[derive(Error, Debug)]
7pub enum FlowError {
8    #[error("Execution error: {0}")]
9    Execution(String),
10    #[error("Node not found")]
11    NodeNotFound,
12    #[error("Parameter error: {0}")]
13    ParameterError(String),
14}
15
16pub type Result<T> = std::result::Result<T, FlowError>;
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct Params {
20    inner: HashMap<String, serde_json::Value>,
21}
22
23impl Params {
24    pub fn new() -> Self {
25        Self {
26            inner: HashMap::new(),
27        }
28    }
29
30    pub fn insert<S: Into<String>, V: Into<serde_json::Value>>(&mut self, key: S, value: V) {
31        self.inner.insert(key.into(), value.into());
32    }
33
34    pub fn get<S: AsRef<str>>(&self, key: S) -> Option<&serde_json::Value> {
35        self.inner.get(key.as_ref())
36    }
37
38    pub fn merge(&mut self, other: &Params) {
39        self.inner.extend(other.inner.clone());
40    }
41
42    pub fn remove<S: AsRef<str>>(&mut self, key: S) -> Option<serde_json::Value> {
43        self.inner.remove(key.as_ref())
44    }
45}
46
47impl Default for Params {
48    fn default() -> Self {
49        Self::new()
50    }
51}
52
/// Marker trait for shared state that is `Clone + Send + Sync`.
///
/// It has no methods. No blanket implementation is provided in this file, so
/// types implement it explicitly.
// NOTE(review): nothing in this file takes a `SharedData` bound; node functions
// receive `dyn Any + Send` instead — confirm whether this trait is still needed.
pub trait SharedData
where
    Self: Clone + Send + Sync,
{
}
54
55type NodeFunc = Arc<dyn Fn(&mut (dyn std::any::Any + Send), &Params) -> Result<Option<String>> + Send + Sync>;
56
57#[derive(Clone)]
58pub struct Node {
59    name: String,
60    params: Params,
61    successors: HashMap<String, Node>,
62    func: NodeFunc,
63}
64
65impl Node {
66    pub fn new<F>(name: impl Into<String>, func: F) -> Self
67    where
68        F: Fn(&mut (dyn std::any::Any + Send), &Params) -> Result<Option<String>> + Send + Sync + 'static,
69    {
70        Self {
71            name: name.into(),
72            params: Params::new(),
73            successors: HashMap::new(),
74            func: Arc::new(func),
75        }
76    }
77
78    pub fn add_successor(&mut self, action: impl Into<String>, node: Node) -> &mut Self {
79        self.successors.insert(action.into(), node);
80        self
81    }
82
83    pub fn next(&mut self, node: Node) -> &mut Self {
84        self.add_successor("default", node)
85    }
86
87    pub fn set_params(&mut self, params: Params) {
88        self.params = params;
89    }
90
91    pub fn get_params(&self) -> &Params {
92        &self.params
93    }
94
95    pub fn get_successor(&self, action: &str) -> Option<&Node> {
96        self.successors.get(action)
97    }
98
99    pub fn has_successors(&self) -> bool {
100        !self.successors.is_empty()
101    }
102
103    pub fn run(&self, shared: &mut (dyn std::any::Any + Send)) -> Result<()> {
104        if self.has_successors() {
105            eprintln!("Warning: Node won't run successors. Use Flow.");
106        }
107        
108        (self.func)(shared, &self.params)?;
109        Ok(())
110    }
111
112    pub fn run_recursive(&self, shared: &mut (dyn std::any::Any + Send)) -> Result<()> {
113        let action = (self.func)(shared, &self.params)?;
114        
115        if let Some(next_node) = action
116            .as_ref()
117            .and_then(|a| self.successors.get(a))
118            .or_else(|| self.successors.get("default")) {
119            next_node.run_recursive(shared)?;
120        }
121        
122        Ok(())
123    }
124}
125
126pub struct Flow {
127    start_node: Option<Node>,
128    params: Params,
129}
130
131impl Flow {
132    pub fn new() -> Self {
133        Self {
134            start_node: None,
135            params: Params::new(),
136        }
137    }
138
139    pub fn start(mut self, node: Node) -> Self {
140        self.start_node = Some(node);
141        self
142    }
143
144    pub fn set_params(&mut self, params: Params) {
145        self.params = params;
146    }
147
148    pub fn run(&self, shared: &mut (dyn std::any::Any + Send)) -> Result<()> {
149        if let Some(ref node) = self.start_node {
150            let mut node = node.clone();
151            node.set_params(self.params.clone());
152            node.run_recursive(shared)?;
153        }
154        Ok(())
155    }
156
157    pub fn run_with_params(&self, shared: &mut (dyn std::any::Any + Send), params: Params) -> Result<()> {
158        if let Some(ref node) = self.start_node {
159            let mut node = node.clone();
160            let mut merged_params = self.params.clone();
161            merged_params.merge(&params);
162            node.set_params(merged_params);
163            node.run_recursive(shared)?;
164        }
165        Ok(())
166    }
167}
168
169impl Default for Flow {
170    fn default() -> Self {
171        Self::new()
172    }
173}
174
175pub struct BatchFlow {
176    start_node: Option<Node>,
177    params: Params,
178}
179
180impl BatchFlow {
181    pub fn new() -> Self {
182        Self {
183            start_node: None,
184            params: Params::new(),
185        }
186    }
187
188    pub fn start(mut self, node: Node) -> Self {
189        self.start_node = Some(node);
190        self
191    }
192
193    pub fn set_params(&mut self, params: Params) {
194        self.params = params;
195    }
196
197    pub fn run_batch(&self, shared: &mut (dyn std::any::Any + Send), batch_params: Vec<Params>) -> Result<Vec<()>> {
198        let mut results = Vec::with_capacity(batch_params.len());
199        for params in batch_params {
200            let flow = Flow {
201                start_node: self.start_node.clone(),
202                params: self.params.clone(),
203            };
204            flow.run_with_params(shared, params)?;
205            results.push(());
206        }
207        Ok(results)
208    }
209}
210
211impl Default for BatchFlow {
212    fn default() -> Self {
213        Self::new()
214    }
215}
216
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// Shared state for every test: a counter that nodes add to.
    #[derive(Default, Clone)]
    struct TestShared {
        pub counter: Arc<Mutex<i32>>,
    }

    /// Adds `delta` to the counter when `shared` really is a `TestShared`.
    fn bump(shared: &mut (dyn std::any::Any + Send), delta: i32) {
        if let Some(state) = shared.downcast_mut::<TestShared>() {
            *state.counter.lock().unwrap() += delta;
        }
    }

    /// A node that adds `delta` to the counter and returns no action.
    fn adder(name: &str, delta: i32) -> Node {
        Node::new(name, move |shared, _params| {
            bump(shared, delta);
            Ok(None)
        })
    }

    /// Current value of the shared counter.
    fn count(shared: &TestShared) -> i32 {
        *shared.counter.lock().unwrap()
    }

    #[test]
    fn test_basic_node() {
        let mut shared = TestShared::default();
        adder("test", 1).run(&mut shared).unwrap();
        assert_eq!(count(&shared), 1);
    }

    #[test]
    fn test_flow_execution() {
        let mut shared = TestShared::default();
        Flow::new()
            .start(adder("test", 1))
            .run(&mut shared)
            .unwrap();
        assert_eq!(count(&shared), 1);
    }

    #[test]
    fn test_chained_flow() {
        let mut shared = TestShared::default();
        let mut first = adder("node1", 1);
        first.next(adder("node2", 10));
        Flow::new().start(first).run(&mut shared).unwrap();
        assert_eq!(count(&shared), 11);
    }

    #[test]
    fn test_batch_flow() {
        let mut shared = TestShared::default();
        let node = Node::new("batch", |shared, params| {
            if let Some(num) = params.get("value").and_then(|v| v.as_i64()) {
                bump(shared, num as i32);
            }
            Ok(None)
        });

        let batch_params: Vec<Params> = (1..=2)
            .map(|value| {
                let mut params = Params::new();
                params.insert("value", value);
                params
            })
            .collect();

        BatchFlow::new()
            .start(node)
            .run_batch(&mut shared, batch_params)
            .unwrap();
        assert_eq!(count(&shared), 3);
    }

    #[test]
    fn test_conditional_flow() {
        let mut shared = TestShared::default();
        let mut gate = Node::new("node1", |_shared, params| {
            let go_on = params
                .get("continue")
                .and_then(|v| v.as_bool())
                .unwrap_or(false);
            Ok(go_on.then(|| "continue".to_string()))
        });
        gate.add_successor("continue", adder("node2", 100));

        let mut params = Params::new();
        params.insert("continue", true);

        Flow::new()
            .start(gate)
            .run_with_params(&mut shared, params)
            .unwrap();
        assert_eq!(count(&shared), 100);
    }
}
363
364// Re-export common types
365pub mod prelude {
366    pub use super::{Node, Flow, BatchFlow, Params, Result, FlowError};
367}
368
// Helper macros for building nodes, chains and flows with less boilerplate.

/// Builds a [`Node`]: `node!(name, func)` expands to `Node::new(name, func)`.
///
/// The expansion uses a `$crate`-qualified path, so callers do not need `Node`
/// in scope. A trailing comma is accepted.
#[macro_export]
macro_rules! node {
    ($name:expr, $func:expr $(,)?) => {
        $crate::Node::new($name, $func)
    };
}
376
/// Builds a [`Flow`] from one or more nodes, linked in the order given.
///
/// `flow!(a, b, c)` makes `b` the `"default"` successor of `a` and `c` the
/// `"default"` successor of `b`, then starts the flow at `a`. A single node
/// simply becomes the start node. Other successors already registered on the
/// nodes are kept, but each node's `"default"` successor (except the last
/// node's) is replaced by the next node in the list. A trailing comma is
/// accepted.
#[macro_export]
macro_rules! flow {
    ($($node:expr),+ $(,)?) => {{
        let mut nodes: ::std::vec::Vec<$crate::Node> = ::std::vec![$($node),+];
        // Successors are owned by value, so link back to front: build the tail
        // first and move it into its predecessor.
        let mut head = nodes.pop().expect("flow! matches at least one node");
        while let Some(mut prev) = nodes.pop() {
            prev.next(head);
            head = prev;
        }
        $crate::Flow::new().start(head)
    }};
}
385
/// Links nodes in order through `"default"` successors and returns the head.
///
/// `chain!(a, b, c)` returns `a`, with `b` as its `"default"` successor and `c`
/// as `b`'s. With a single argument the node is returned unchanged. A trailing
/// comma is accepted.
#[macro_export]
macro_rules! chain {
    ($first:expr $(, $rest:expr)* $(,)?) => {{
        let mut nodes: ::std::vec::Vec<$crate::Node> = ::std::vec![$first $(, $rest)*];
        // `Node::next` takes the successor by value and returns `&mut Self`, so
        // build from the tail: move each finished suffix into its predecessor.
        let mut head = nodes.pop().expect("chain! matches at least one node");
        while let Some(mut prev) = nodes.pop() {
            prev.next(head);
            head = prev;
        }
        head
    }};
}