1use std::collections::HashMap;
2use std::sync::Arc;
3use serde::{Deserialize, Serialize};
4use thiserror::Error;
5
6#[derive(Error, Debug)]
7pub enum FlowError {
8 #[error("Execution error: {0}")]
9 Execution(String),
10 #[error("Node not found")]
11 NodeNotFound,
12 #[error("Parameter error: {0}")]
13 ParameterError(String),
14}
15
16pub type Result<T> = std::result::Result<T, FlowError>;
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct Params {
20 inner: HashMap<String, serde_json::Value>,
21}
22
23impl Params {
24 pub fn new() -> Self {
25 Self {
26 inner: HashMap::new(),
27 }
28 }
29
30 pub fn insert<S: Into<String>, V: Into<serde_json::Value>>(&mut self, key: S, value: V) {
31 self.inner.insert(key.into(), value.into());
32 }
33
34 pub fn get<S: AsRef<str>>(&self, key: S) -> Option<&serde_json::Value> {
35 self.inner.get(key.as_ref())
36 }
37
38 pub fn merge(&mut self, other: &Params) {
39 self.inner.extend(other.inner.clone());
40 }
41
42 pub fn remove<S: AsRef<str>>(&mut self, key: S) -> Option<serde_json::Value> {
43 self.inner.remove(key.as_ref())
44 }
45}
46
47impl Default for Params {
48 fn default() -> Self {
49 Self::new()
50 }
51}
52
/// Marker trait for shared state that is `Clone`, `Send` and `Sync`.
///
/// It has no methods and is not referenced elsewhere in this file; types opt in with an
/// empty `impl SharedData for T {}`.
pub trait SharedData
where
    Self: Clone + Send + Sync,
{
}
54
55type NodeFunc = Arc<dyn Fn(&mut (dyn std::any::Any + Send), &Params) -> Result<Option<String>> + Send + Sync>;
56
57#[derive(Clone)]
58pub struct Node {
59 name: String,
60 params: Params,
61 successors: HashMap<String, Node>,
62 func: NodeFunc,
63}
64
65impl Node {
66 pub fn new<F>(name: impl Into<String>, func: F) -> Self
67 where
68 F: Fn(&mut (dyn std::any::Any + Send), &Params) -> Result<Option<String>> + Send + Sync + 'static,
69 {
70 Self {
71 name: name.into(),
72 params: Params::new(),
73 successors: HashMap::new(),
74 func: Arc::new(func),
75 }
76 }
77
78 pub fn add_successor(&mut self, action: impl Into<String>, node: Node) -> &mut Self {
79 self.successors.insert(action.into(), node);
80 self
81 }
82
83 pub fn next(&mut self, node: Node) -> &mut Self {
84 self.add_successor("default", node)
85 }
86
87 pub fn set_params(&mut self, params: Params) {
88 self.params = params;
89 }
90
91 pub fn get_params(&self) -> &Params {
92 &self.params
93 }
94
95 pub fn get_successor(&self, action: &str) -> Option<&Node> {
96 self.successors.get(action)
97 }
98
99 pub fn has_successors(&self) -> bool {
100 !self.successors.is_empty()
101 }
102
103 pub fn run(&self, shared: &mut (dyn std::any::Any + Send)) -> Result<()> {
104 if self.has_successors() {
105 eprintln!("Warning: Node won't run successors. Use Flow.");
106 }
107
108 (self.func)(shared, &self.params)?;
109 Ok(())
110 }
111
112 pub fn run_recursive(&self, shared: &mut (dyn std::any::Any + Send)) -> Result<()> {
113 let action = (self.func)(shared, &self.params)?;
114
115 if let Some(next_node) = action
116 .as_ref()
117 .and_then(|a| self.successors.get(a))
118 .or_else(|| self.successors.get("default")) {
119 next_node.run_recursive(shared)?;
120 }
121
122 Ok(())
123 }
124}
125
126pub struct Flow {
127 start_node: Option<Node>,
128 params: Params,
129}
130
131impl Flow {
132 pub fn new() -> Self {
133 Self {
134 start_node: None,
135 params: Params::new(),
136 }
137 }
138
139 pub fn start(mut self, node: Node) -> Self {
140 self.start_node = Some(node);
141 self
142 }
143
144 pub fn set_params(&mut self, params: Params) {
145 self.params = params;
146 }
147
148 pub fn run(&self, shared: &mut (dyn std::any::Any + Send)) -> Result<()> {
149 if let Some(ref node) = self.start_node {
150 let mut node = node.clone();
151 node.set_params(self.params.clone());
152 node.run_recursive(shared)?;
153 }
154 Ok(())
155 }
156
157 pub fn run_with_params(&self, shared: &mut (dyn std::any::Any + Send), params: Params) -> Result<()> {
158 if let Some(ref node) = self.start_node {
159 let mut node = node.clone();
160 let mut merged_params = self.params.clone();
161 merged_params.merge(¶ms);
162 node.set_params(merged_params);
163 node.run_recursive(shared)?;
164 }
165 Ok(())
166 }
167}
168
169impl Default for Flow {
170 fn default() -> Self {
171 Self::new()
172 }
173}
174
175pub struct BatchFlow {
176 start_node: Option<Node>,
177 params: Params,
178}
179
180impl BatchFlow {
181 pub fn new() -> Self {
182 Self {
183 start_node: None,
184 params: Params::new(),
185 }
186 }
187
188 pub fn start(mut self, node: Node) -> Self {
189 self.start_node = Some(node);
190 self
191 }
192
193 pub fn set_params(&mut self, params: Params) {
194 self.params = params;
195 }
196
197 pub fn run_batch(&self, shared: &mut (dyn std::any::Any + Send), batch_params: Vec<Params>) -> Result<Vec<()>> {
198 let mut results = Vec::with_capacity(batch_params.len());
199 for params in batch_params {
200 let flow = Flow {
201 start_node: self.start_node.clone(),
202 params: self.params.clone(),
203 };
204 flow.run_with_params(shared, params)?;
205 results.push(());
206 }
207 Ok(results)
208 }
209}
210
211impl Default for BatchFlow {
212 fn default() -> Self {
213 Self::new()
214 }
215}
216
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// Shared state passed to every node in these tests.
    #[derive(Default, Clone)]
    struct TestShared {
        pub counter: Arc<Mutex<i32>>,
    }

    /// Adds `delta` to the counter when `shared` is a `TestShared`; otherwise does nothing.
    fn bump(shared: &mut (dyn std::any::Any + Send), delta: i32) {
        if let Some(state) = shared.downcast_mut::<TestShared>() {
            *state.counter.lock().unwrap() += delta;
        }
    }

    /// Builds a node that adds a fixed `delta` to the counter and returns no action.
    fn adder(name: &str, delta: i32) -> Node {
        Node::new(name, move |shared, _params| {
            bump(shared, delta);
            Ok(None)
        })
    }

    /// Reads the current counter value.
    fn counter_of(state: &TestShared) -> i32 {
        *state.counter.lock().unwrap()
    }

    #[test]
    fn test_basic_node() {
        let mut shared = TestShared::default();

        adder("test", 1).run(&mut shared).unwrap();

        assert_eq!(counter_of(&shared), 1);
    }

    #[test]
    fn test_flow_execution() {
        let mut shared = TestShared::default();

        let flow = Flow::new().start(adder("test", 1));
        flow.run(&mut shared).unwrap();

        assert_eq!(counter_of(&shared), 1);
    }

    #[test]
    fn test_chained_flow() {
        let mut shared = TestShared::default();

        let mut first = adder("node1", 1);
        first.next(adder("node2", 10));
        Flow::new().start(first).run(&mut shared).unwrap();

        assert_eq!(counter_of(&shared), 11);
    }

    #[test]
    fn test_batch_flow() {
        let mut shared = TestShared::default();

        let node = Node::new("batch", |shared, params| {
            let delta = params
                .get("value")
                .and_then(|value| value.as_i64())
                .map(|num| num as i32);
            if let Some(delta) = delta {
                bump(shared, delta);
            }
            Ok(None)
        });

        let batch_params: Vec<Params> = [1, 2]
            .iter()
            .map(|&value| {
                let mut params = Params::new();
                params.insert("value", value);
                params
            })
            .collect();

        BatchFlow::new()
            .start(node)
            .run_batch(&mut shared, batch_params)
            .unwrap();

        assert_eq!(counter_of(&shared), 3);
    }

    #[test]
    fn test_conditional_flow() {
        let mut shared = TestShared::default();

        let mut gate = Node::new("node1", |_shared, params| {
            let proceed = params
                .get("continue")
                .and_then(|v| v.as_bool())
                .unwrap_or(false);
            Ok(if proceed {
                Some("continue".to_string())
            } else {
                None
            })
        });
        gate.add_successor("continue", adder("node2", 100));

        let mut params = Params::new();
        params.insert("continue", true);

        Flow::new()
            .start(gate)
            .run_with_params(&mut shared, params)
            .unwrap();

        assert_eq!(counter_of(&shared), 100);
    }
}
363
364pub mod prelude {
366 pub use super::{Node, Flow, BatchFlow, Params, Result, FlowError};
367}
368
/// Shorthand for `Node::new(name, func)`: `node!(name, func)`.
///
/// The expansion uses a `$crate`-qualified path, so callers need not import `Node`.
/// NOTE(review): `$crate::Node` assumes `Node` lives at the crate root (this file is the
/// library root) — confirm.
#[macro_export]
macro_rules! node {
    ($name:expr, $func:expr) => {
        $crate::Node::new($name, $func)
    };
}
376
377#[macro_export]
378macro_rules! flow {
379 ($($node:expr),+ $(,)?) => {{
380 let mut current = None;
381 $(current = Some($node);)*
382 Flow::new().start(current.unwrap())
383 }};
384}
385
/// Links nodes into a linear chain and returns the head node.
///
/// `chain!(a, b, c)` makes `b` the `"default"` successor of `a` and `c` the `"default"`
/// successor of `b`, returning `a`. A single argument is returned unchanged. Node
/// expressions are evaluated left to right.
///
/// Successors are owned by value, so the chain is built from the tail: each node is attached
/// to its predecessor only after its own successors are in place.
#[macro_export]
macro_rules! chain {
    ($last:expr $(,)?) => {{
        $last
    }};
    ($first:expr, $($rest:expr),+ $(,)?) => {{
        let mut head = $first;
        head.next($crate::chain!($($rest),+));
        head
    }};
}