1use std::fmt::Debug;
2use std::future::Future;
3use std::marker::PhantomData;
4
5use async_trait::async_trait;
6use tracing::debug;
7use uuid::Uuid;
8
9use crate::action::ActionType;
10use crate::error::FloxideError;
11use crate::node::{Node, NodeId, NodeOutcome};
12
13#[async_trait]
15pub trait LifecycleNode<Context, Action>: Send + Sync
16where
17 Context: Send + Sync + 'static,
18 Action: ActionType + Send + Sync + 'static,
19 Self::PrepOutput: Clone + Send + Sync + 'static,
20 Self::ExecOutput: Clone + Send + Sync + 'static,
21{
22 type PrepOutput;
24
25 type ExecOutput;
27
28 fn id(&self) -> NodeId;
30
31 async fn prep(&self, ctx: &mut Context) -> Result<Self::PrepOutput, FloxideError>;
33
34 async fn exec(&self, prep_result: Self::PrepOutput) -> Result<Self::ExecOutput, FloxideError>;
36
37 async fn post(
39 &self,
40 prep_result: Self::PrepOutput,
41 exec_result: Self::ExecOutput,
42 ctx: &mut Context,
43 ) -> Result<Action, FloxideError>;
44}
45
46pub struct LifecycleNodeAdapter<LN, Context, Action>
48where
49 LN: LifecycleNode<Context, Action>,
50 Context: Send + Sync + 'static,
51 Action: ActionType + Send + Sync + 'static,
52{
53 inner: LN,
54 _phantom: PhantomData<(Context, Action)>,
55}
56
57impl<LN, Context, Action> LifecycleNodeAdapter<LN, Context, Action>
58where
59 LN: LifecycleNode<Context, Action>,
60 Context: Send + Sync + 'static,
61 Action: ActionType + Send + Sync + 'static,
62{
63 pub fn new(inner: LN) -> Self {
65 Self {
66 inner,
67 _phantom: PhantomData,
68 }
69 }
70}
71
72impl<LN, Context, Action> Debug for LifecycleNodeAdapter<LN, Context, Action>
73where
74 LN: LifecycleNode<Context, Action> + Debug,
75 Context: Send + Sync + 'static,
76 Action: ActionType + Send + Sync + 'static,
77{
78 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79 f.debug_struct("LifecycleNodeAdapter")
80 .field("inner", &self.inner)
81 .finish()
82 }
83}
84
85#[async_trait]
86impl<LN, Context, Action> Node<Context, Action> for LifecycleNodeAdapter<LN, Context, Action>
87where
88 LN: LifecycleNode<Context, Action> + Send + Sync + 'static,
89 Context: Send + Sync + 'static,
90 Action: ActionType + Send + Sync + 'static,
91 LN::ExecOutput: Send + Sync + 'static,
92{
93 type Output = LN::ExecOutput;
94
95 fn id(&self) -> NodeId {
96 self.inner.id()
97 }
98
99 async fn process(
100 &self,
101 ctx: &mut Context,
102 ) -> Result<NodeOutcome<Self::Output, Action>, FloxideError> {
103 debug!(node_id = %self.id(), "Starting prep phase");
105 let prep_result = self.inner.prep(ctx).await?;
106
107 debug!(node_id = %self.id(), "Starting exec phase");
108 let exec_result = self.inner.exec(prep_result.clone()).await?;
109
110 debug!(node_id = %self.id(), "Starting post phase");
111 let next_action = self
112 .inner
113 .post(prep_result, exec_result.clone(), ctx)
114 .await?;
115
116 Ok(NodeOutcome::RouteToAction(next_action))
118 }
119}
120
121pub fn lifecycle_node<
123 PrepFn,
124 ExecFn,
125 PostFn,
126 Context,
127 Action,
128 PrepOut,
129 ExecOut,
130 PrepFut,
131 ExecFut,
132 PostFut,
133>(
134 id: Option<String>,
135 prep_fn: PrepFn,
136 exec_fn: ExecFn,
137 post_fn: PostFn,
138) -> impl Node<Context, Action, Output = ExecOut>
139where
140 Context: Send + Sync + 'static,
141 Action: ActionType + Send + Sync + 'static,
142 PrepOut: Send + Sync + Clone + 'static,
143 ExecOut: Send + Sync + Clone + 'static,
144 PrepFn: Fn(&mut Context) -> PrepFut + Send + Sync + 'static,
145 ExecFn: Fn(PrepOut) -> ExecFut + Send + Sync + 'static,
146 PostFn: Fn(PrepOut, ExecOut, &mut Context) -> PostFut + Send + Sync + 'static,
147 PrepFut: Future<Output = Result<PrepOut, FloxideError>> + Send + 'static,
148 ExecFut: Future<Output = Result<ExecOut, FloxideError>> + Send + 'static,
149 PostFut: Future<Output = Result<Action, FloxideError>> + Send + 'static,
150{
151 struct ClosureLifecycleNode<P, E, Po, Ctx, Act, PO, EO> {
152 id: NodeId,
153 prep_fn: P,
154 exec_fn: E,
155 post_fn: Po,
156 _phantom: PhantomData<(Ctx, Act, PO, EO)>,
157 }
158
159 impl<P, E, Po, Ctx, Act, PO, EO> Debug for ClosureLifecycleNode<P, E, Po, Ctx, Act, PO, EO> {
160 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
161 f.debug_struct("ClosureLifecycleNode")
162 .field("id", &self.id)
163 .finish()
164 }
165 }
166
167 #[async_trait]
168 impl<P, E, Po, Ctx, Act, PO, EO, PF, EF, PoF> LifecycleNode<Ctx, Act>
169 for ClosureLifecycleNode<P, E, Po, Ctx, Act, PO, EO>
170 where
171 Ctx: Send + Sync + 'static,
172 Act: ActionType + Send + Sync + 'static,
173 PO: Send + Sync + Clone + 'static,
174 EO: Send + Sync + Clone + 'static,
175 P: Fn(&mut Ctx) -> PF + Send + Sync + 'static,
176 E: Fn(PO) -> EF + Send + Sync + 'static,
177 Po: Fn(PO, EO, &mut Ctx) -> PoF + Send + Sync + 'static,
178 PF: Future<Output = Result<PO, FloxideError>> + Send + 'static,
179 EF: Future<Output = Result<EO, FloxideError>> + Send + 'static,
180 PoF: Future<Output = Result<Act, FloxideError>> + Send + 'static,
181 {
182 type PrepOutput = PO;
183 type ExecOutput = EO;
184
185 fn id(&self) -> NodeId {
186 self.id.clone()
187 }
188
189 async fn prep(&self, ctx: &mut Ctx) -> Result<Self::PrepOutput, FloxideError> {
190 (self.prep_fn)(ctx).await
191 }
192
193 async fn exec(
194 &self,
195 prep_result: Self::PrepOutput,
196 ) -> Result<Self::ExecOutput, FloxideError> {
197 (self.exec_fn)(prep_result).await
198 }
199
200 async fn post(
201 &self,
202 prep_result: Self::PrepOutput,
203 exec_result: Self::ExecOutput,
204 ctx: &mut Ctx,
205 ) -> Result<Act, FloxideError> {
206 (self.post_fn)(prep_result, exec_result, ctx).await
207 }
208 }
209
210 let node_id = id.unwrap_or_else(|| Uuid::new_v4().to_string());
211
212 let lifecycle_node = ClosureLifecycleNode {
213 id: node_id,
214 prep_fn,
215 exec_fn,
216 post_fn,
217 _phantom: PhantomData,
218 };
219
220 LifecycleNodeAdapter::new(lifecycle_node)
221}
222
#[cfg(test)]
mod tests {
    use super::*;
    use crate::action::DefaultAction;

    /// Test context: `value` is the data being processed and `path` records,
    /// in order, which phases touched the context.
    #[derive(Debug, Clone)]
    struct TestContext {
        value: i32,
        path: Vec<String>,
    }

    impl TestContext {
        fn new(value: i32) -> Self {
            Self {
                value,
                path: Vec::new(),
            }
        }
    }

    /// Doubles `ctx.value`, recording the phases that touch the context.
    struct TestLifecycleNode {
        id: NodeId,
    }

    #[async_trait]
    impl LifecycleNode<TestContext, DefaultAction> for TestLifecycleNode {
        type PrepOutput = i32;
        type ExecOutput = i32;

        fn id(&self) -> NodeId {
            self.id.clone()
        }

        async fn prep(&self, ctx: &mut TestContext) -> Result<Self::PrepOutput, FloxideError> {
            ctx.path.push("prep".to_string());
            Ok(ctx.value)
        }

        async fn exec(
            &self,
            prep_result: Self::PrepOutput,
        ) -> Result<Self::ExecOutput, FloxideError> {
            Ok(prep_result * 2)
        }

        async fn post(
            &self,
            _prep_result: Self::PrepOutput,
            exec_result: Self::ExecOutput,
            ctx: &mut TestContext,
        ) -> Result<DefaultAction, FloxideError> {
            ctx.path.push("post".to_string());
            ctx.value = exec_result;
            Ok(DefaultAction::Next)
        }
    }

    /// Fails in `exec`; records phase entry so tests can check `post` is skipped.
    struct ErrorLifecycleNode {
        id: NodeId,
    }

    #[async_trait]
    impl LifecycleNode<TestContext, DefaultAction> for ErrorLifecycleNode {
        type PrepOutput = i32;
        type ExecOutput = i32;

        fn id(&self) -> NodeId {
            self.id.clone()
        }

        async fn prep(&self, ctx: &mut TestContext) -> Result<Self::PrepOutput, FloxideError> {
            ctx.path.push("prep".to_string());
            Ok(42)
        }

        async fn exec(
            &self,
            _prep_result: Self::PrepOutput,
        ) -> Result<Self::ExecOutput, FloxideError> {
            Err(FloxideError::node_execution("test", "Simulated error"))
        }

        async fn post(
            &self,
            _prep_result: Self::PrepOutput,
            _exec_result: Self::ExecOutput,
            ctx: &mut TestContext,
        ) -> Result<DefaultAction, FloxideError> {
            ctx.path.push("post".to_string());
            ctx.value = -1;
            Ok(DefaultAction::Next)
        }
    }

    /// Closure-based equivalent of `TestLifecycleNode`, built via `lifecycle_node`.
    fn doubling_closure_node(
        id: Option<String>,
    ) -> impl Node<TestContext, DefaultAction, Output = i32> {
        lifecycle_node(
            id,
            |ctx: &mut TestContext| {
                ctx.path.push("prep".to_string());
                // Copy the value out: the returned future must not borrow `ctx`.
                let value = ctx.value;
                async move { Ok::<_, FloxideError>(value) }
            },
            |prep: i32| async move { Ok::<_, FloxideError>(prep * 2) },
            |_prep: i32, exec: i32, ctx: &mut TestContext| {
                ctx.path.push("post".to_string());
                ctx.value = exec;
                async { Ok::<_, FloxideError>(DefaultAction::Next) }
            },
        )
    }

    #[tokio::test]
    async fn test_lifecycle_node() {
        let node = LifecycleNodeAdapter::new(TestLifecycleNode {
            id: "test-node".to_string(),
        });
        let mut ctx = TestContext::new(5);

        let result = node.process(&mut ctx).await.unwrap();

        match result {
            NodeOutcome::RouteToAction(action) => assert_eq!(action, DefaultAction::Next),
            _ => panic!("Expected RouteToAction outcome"),
        }
        assert_eq!(ctx.value, 10);
        assert_eq!(ctx.path, vec!["prep", "post"]);
    }

    #[tokio::test]
    async fn test_lifecycle_node_with_error() {
        let node = LifecycleNodeAdapter::new(ErrorLifecycleNode {
            id: "error-node".to_string(),
        });
        let mut ctx = TestContext::new(5);

        let result = node.process(&mut ctx).await;

        assert!(result.is_err());
        // `post` must not run after `exec` fails, so the context is untouched by it.
        assert_eq!(ctx.path, vec!["prep"]);
        assert_eq!(ctx.value, 5);
    }

    #[tokio::test]
    async fn test_closure_lifecycle_node() {
        let node = doubling_closure_node(Some("closure-node".to_string()));
        assert_eq!(node.id(), "closure-node");

        let mut ctx = TestContext::new(7);
        let result = node.process(&mut ctx).await.unwrap();

        match result {
            NodeOutcome::RouteToAction(action) => assert_eq!(action, DefaultAction::Next),
            _ => panic!("Expected RouteToAction outcome"),
        }
        assert_eq!(ctx.value, 14);
        assert_eq!(ctx.path, vec!["prep", "post"]);
    }

    #[test]
    fn test_closure_lifecycle_node_generates_uuid_id() {
        let node = doubling_closure_node(None);
        assert!(Uuid::parse_str(&node.id()).is_ok());
    }
}