1use async_trait::async_trait;
2use uuid::Uuid;
3
4use crate::action::ActionType;
5use crate::error::FloxideError;
6
/// Identifier of a [`Node`], as returned by [`Node::id`].
pub type NodeId = std::string::String;
9
/// The non-error result of running a [`Node`].
///
/// Failures are reported separately, through the `Err` side of the `Result`
/// returned by [`Node::process`]; this enum only describes the `Ok` cases.
///
/// `PartialEq`/`Eq` are derived so outcomes can be compared directly (for
/// example with `assert_eq!`); those impls only exist when `Output` and
/// `Action` themselves implement the corresponding traits.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NodeOutcome<Output, Action> {
    /// The node finished and produced an output value.
    Success(Output),
    /// The node produced no output. How execution continues afterwards is
    /// decided by whatever drives the nodes (not visible in this file).
    Skipped,
    /// The node returned an action value; presumably the workflow driver uses
    /// it to choose what runs next — confirm against the driver.
    RouteToAction(Action),
}
20
21#[async_trait]
23pub trait Node<Context, Action>: Send + Sync
24where
25 Context: Send + Sync + 'static,
26 Action: ActionType + Send + Sync + 'static,
27 Self::Output: Send + Sync + 'static,
28{
29 type Output;
31
32 fn id(&self) -> NodeId;
34
35 async fn process(
37 &self,
38 ctx: &mut Context,
39 ) -> Result<NodeOutcome<Self::Output, Action>, FloxideError>;
40}
41
42pub mod closure {
44 use std::fmt::Debug;
45 use std::future::Future;
46 use std::marker::PhantomData;
47
48 use super::*;
49
50 pub fn node<Closure, Context, Action, Output, Fut>(
52 closure: Closure,
53 ) -> ClosureNode<Closure, Context, Action, Output>
54 where
55 Context: Clone + Send + Sync + 'static,
56 Action: ActionType + Send + Sync + 'static,
57 Output: Send + Sync + 'static,
58 Closure: Fn(Context) -> Fut + Send + Sync + 'static,
59 Fut: Future<Output = Result<(Context, NodeOutcome<Output, Action>), FloxideError>>
60 + Send
61 + 'static,
62 {
63 ClosureNode {
64 id: Uuid::new_v4().to_string(),
65 closure,
66 _phantom: PhantomData,
67 }
68 }
69
70 #[derive(Clone)]
72 pub struct ClosureNode<Closure, Context, Action, Output> {
73 id: NodeId,
74 closure: Closure,
75 _phantom: PhantomData<(Context, Action, Output)>,
76 }
77
78 impl<Closure, Context, Action, Output> Debug for ClosureNode<Closure, Context, Action, Output> {
79 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80 f.debug_struct("ClosureNode").field("id", &self.id).finish()
81 }
82 }
83
84 #[async_trait]
85 impl<Closure, Context, Action, Output, Fut> Node<Context, Action>
86 for ClosureNode<Closure, Context, Action, Output>
87 where
88 Context: Clone + Send + Sync + 'static,
89 Action: ActionType + Send + Sync + 'static,
90 Output: Send + Sync + 'static,
91 Closure: Fn(Context) -> Fut + Send + Sync + 'static,
92 Fut: Future<Output = Result<(Context, NodeOutcome<Output, Action>), FloxideError>>
93 + Send
94 + 'static,
95 {
96 type Output = Output;
97
98 fn id(&self) -> NodeId {
99 self.id.clone()
100 }
101
102 async fn process(
103 &self,
104 ctx: &mut Context,
105 ) -> Result<NodeOutcome<Self::Output, Action>, FloxideError> {
106 let ctx_clone = ctx.clone();
108
109 let (updated_ctx, outcome) = (self.closure)(ctx_clone).await?;
111
112 *ctx = updated_ctx;
114
115 Ok(outcome)
117 }
118 }
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use crate::action::DefaultAction;

    /// Minimal context shared by the tests: a single integer the nodes can
    /// read or bump.
    #[derive(Debug, Clone)]
    struct TestContext {
        value: i32,
    }

    /// A closure node that increments the context reports the new value as
    /// its output and leaves the incremented value in the caller's context.
    #[tokio::test]
    async fn test_create_node_from_closure() {
        let increment = closure::node(|mut state: TestContext| async move {
            state.value += 1;
            let produced = state.value;
            Ok((state, NodeOutcome::<i32, DefaultAction>::Success(produced)))
        });

        let mut context = TestContext { value: 5 };
        let outcome = increment.process(&mut context).await.unwrap();

        if let NodeOutcome::Success(produced) = outcome {
            assert_eq!(produced, 6);
            assert_eq!(context.value, 6);
        } else {
            panic!("Expected Success outcome");
        }
    }

    /// A node that skips returns `Skipped` and leaves the context untouched.
    #[tokio::test]
    async fn test_skip_node() {
        let skipper = closure::node(|state: TestContext| async move {
            Ok((state, NodeOutcome::<(), DefaultAction>::Skipped))
        });

        let mut context = TestContext { value: 5 };
        let outcome = skipper.process(&mut context).await.unwrap();

        if !matches!(outcome, NodeOutcome::Skipped) {
            panic!("Expected Skipped outcome");
        }

        assert_eq!(context.value, 5);
    }

    /// A node that routes surfaces the action it returned, unchanged.
    #[tokio::test]
    async fn test_route_to_action() {
        let router = closure::node(|state: TestContext| async move {
            let detour = DefaultAction::Custom("alternate_path".into());
            Ok((state, NodeOutcome::<(), DefaultAction>::RouteToAction(detour)))
        });

        let mut context = TestContext { value: 5 };
        let outcome = router.process(&mut context).await.unwrap();

        if let NodeOutcome::RouteToAction(action) = outcome {
            assert_eq!(action.name(), "alternate_path");
        } else {
            panic!("Expected RouteToAction outcome");
        }
    }
}