1use async_trait::async_trait;
9use floxide_core::{error::FloxideError, ActionType, DefaultAction, Node, NodeId, NodeOutcome};
10use std::collections::HashMap;
11use std::marker::PhantomData;
12use std::sync::Arc;
13use std::time::Duration;
14use tokio::sync::mpsc;
15use tracing::{info, warn};
16use uuid::Uuid;
17
/// A node that is driven by events instead of being invoked once per step.
///
/// A driver first calls [`wait_for_event`](EventDrivenNode::wait_for_event) to obtain an
/// event, then hands it to [`process_event`](EventDrivenNode::process_event), whose
/// returned action is used for routing (see `EventDrivenWorkflow::execute`).
///
/// Note: `EventDrivenWorkflow::execute` treats a `wait_for_event` error whose message
/// contains `"not an event source"` specially, by pulling the event from the workflow's
/// initial node instead.
#[async_trait]
pub trait EventDrivenNode<Event, Context, Action>: Send + Sync
where
    Event: Send + 'static,
    Context: Send + Sync + 'static,
    Action: ActionType + Send + Sync + 'static + Default,
{
    /// Waits for the next event.
    ///
    /// # Errors
    /// Returns an error when no event can be produced. For example, the
    /// `ChannelEventSource` implementation fails once its channel is closed.
    async fn wait_for_event(&mut self) -> Result<Event, FloxideError>;

    /// Processes `event`, possibly mutating `ctx`, and returns the action to route on.
    async fn process_event(&self, event: Event, ctx: &mut Context) -> Result<Action, FloxideError>;

    /// Returns this node's identifier.
    fn id(&self) -> NodeId;
}
35
/// An event source backed by a bounded `tokio::sync::mpsc` channel.
///
/// Events are pushed through the `mpsc::Sender` returned by the constructors and
/// read back by `wait_for_event`.
pub struct ChannelEventSource<Event> {
    /// Receiving half of the channel that events are read from.
    receiver: mpsc::Receiver<Event>,
    /// Identifier reported by `id()`.
    id: NodeId,
}
41
42impl<Event> ChannelEventSource<Event>
43where
44 Event: Send + 'static,
45{
46 pub fn new(capacity: usize) -> (Self, mpsc::Sender<Event>) {
48 let (sender, receiver) = mpsc::channel(capacity);
49 (
50 Self {
51 receiver,
52 id: Uuid::new_v4().to_string(),
53 },
54 sender,
55 )
56 }
57
58 pub fn with_id(capacity: usize, id: impl Into<String>) -> (Self, mpsc::Sender<Event>) {
60 let (sender, receiver) = mpsc::channel(capacity);
61 (
62 Self {
63 receiver,
64 id: id.into(),
65 },
66 sender,
67 )
68 }
69}
70
71#[async_trait]
72impl<Event, Context, Action> EventDrivenNode<Event, Context, Action> for ChannelEventSource<Event>
73where
74 Event: Send + 'static,
75 Context: Send + Sync + 'static,
76 Action: ActionType + Send + Sync + 'static + Default,
77{
78 async fn wait_for_event(&mut self) -> Result<Event, FloxideError> {
79 match self.receiver.recv().await {
80 Some(event) => Ok(event),
81 None => Err(FloxideError::Other("Event channel closed".to_string())),
82 }
83 }
84
85 async fn process_event(
86 &self,
87 _event: Event,
88 _ctx: &mut Context,
89 ) -> Result<Action, FloxideError> {
90 Ok(Action::default())
93 }
94
95 fn id(&self) -> NodeId {
96 self.id.clone()
97 }
98}
99
/// A node that reads events from an internal channel source and turns each event into
/// an action by calling a user-supplied closure.
pub struct EventProcessor<Event, Context, Action, F>
where
    Event: Send + 'static,
    Context: Send + Sync + 'static,
    Action: ActionType + Send + Sync + 'static + Default,
    F: Fn(Event, &mut Context) -> Result<Action, FloxideError> + Send + Sync + 'static,
{
    /// Channel-backed source that `wait_for_event` delegates to.
    source: Arc<tokio::sync::Mutex<ChannelEventSource<Event>>>,
    /// Closure invoked by `process_event` with the event and the mutable context.
    processor: F,
    /// Binds the otherwise-unused `Context` and `Action` type parameters to the struct.
    _phantom: PhantomData<(Context, Action)>,
}
112
113impl<Event, Context, Action, F> EventProcessor<Event, Context, Action, F>
114where
115 Event: Send + 'static,
116 Context: Send + Sync + 'static,
117 Action: ActionType + Send + Sync + 'static + Default,
118 F: Fn(Event, &mut Context) -> Result<Action, FloxideError> + Send + Sync + 'static,
119{
120 pub fn new(capacity: usize, processor: F) -> (Self, mpsc::Sender<Event>) {
122 let (source, sender) = ChannelEventSource::new(capacity);
123 (
124 Self {
125 source: Arc::new(tokio::sync::Mutex::new(source)),
126 processor,
127 _phantom: PhantomData,
128 },
129 sender,
130 )
131 }
132
133 pub fn with_id(
135 capacity: usize,
136 id: impl Into<String>,
137 processor: F,
138 ) -> (Self, mpsc::Sender<Event>) {
139 let (source, sender) = ChannelEventSource::with_id(capacity, id);
140 (
141 Self {
142 source: Arc::new(tokio::sync::Mutex::new(source)),
143 processor,
144 _phantom: PhantomData,
145 },
146 sender,
147 )
148 }
149}
150
151#[async_trait]
152impl<Event, Context, Action, F> EventDrivenNode<Event, Context, Action>
153 for EventProcessor<Event, Context, Action, F>
154where
155 Event: Send + 'static,
156 Context: Send + Sync + 'static,
157 Action: ActionType + Send + Sync + 'static + Default,
158 F: Fn(Event, &mut Context) -> Result<Action, FloxideError> + Send + Sync + 'static,
159{
160 async fn wait_for_event(&mut self) -> Result<Event, FloxideError> {
161 let mut source = self.source.lock().await;
162 <ChannelEventSource<Event> as EventDrivenNode<Event, Context, Action>>::wait_for_event(
163 &mut *source,
164 )
165 .await
166 }
167
168 async fn process_event(&self, event: Event, ctx: &mut Context) -> Result<Action, FloxideError> {
169 (self.processor)(event, ctx)
170 }
171
172 fn id(&self) -> NodeId {
173 self.source
174 .try_lock()
175 .map(|source| {
176 <ChannelEventSource<Event> as EventDrivenNode<Event, Context, Action>>::id(&*source)
177 })
178 .unwrap_or_else(|_| "locked".to_string())
179 }
180}
181
/// Shared, lockable handle to a type-erased event-driven node.
type EventNodeRef<E, C, A> = Arc<tokio::sync::Mutex<dyn EventDrivenNode<E, C, A>>>;
184
/// A workflow that repeatedly waits for an event on the current node, processes it, and
/// follows the route keyed by `(node id, action)` until the termination action is produced.
pub struct EventDrivenWorkflow<Event, Context, Action>
where
    Event: Send + 'static,
    Context: Send + Sync + 'static,
    Action: ActionType + Send + Sync + 'static + Default,
{
    /// Registered nodes, keyed by the id each reported when it was added.
    nodes: HashMap<NodeId, EventNodeRef<Event, Context, Action>>,
    /// Transition table: `(source node id, action)` -> destination node id.
    routes: HashMap<(NodeId, Action), NodeId>,
    /// Id of the node execution starts from.
    initial_node: NodeId,
    /// Action that ends `execute` successfully.
    termination_action: Action,
}
197
198impl<Event, Context, Action> EventDrivenWorkflow<Event, Context, Action>
199where
200 Event: Send + 'static,
201 Context: Send + Sync + 'static,
202 Action: ActionType + Send + Sync + 'static + Default,
203{
204 pub fn new(
206 initial_node: Arc<tokio::sync::Mutex<dyn EventDrivenNode<Event, Context, Action>>>,
207 termination_action: Action,
208 ) -> Self {
209 let id = {
210 initial_node
211 .try_lock()
212 .map(|n| n.id())
213 .unwrap_or_else(|_| "locked".to_string())
214 };
215
216 let mut nodes = HashMap::new();
217 nodes.insert(id.clone(), initial_node);
218
219 Self {
220 nodes,
221 routes: HashMap::new(),
222 initial_node: id,
223 termination_action,
224 }
225 }
226
227 pub fn add_node(
229 &mut self,
230 node: Arc<tokio::sync::Mutex<dyn EventDrivenNode<Event, Context, Action>>>,
231 ) {
232 let id = {
233 node.try_lock()
234 .map(|n| n.id())
235 .unwrap_or_else(|_| "locked".to_string())
236 };
237 self.nodes.insert(id, node);
238 }
239
240 pub fn set_route(&mut self, from_id: &NodeId, action: Action, to_id: &NodeId) {
242 self.routes.insert((from_id.clone(), action), to_id.clone());
244 }
245
246 pub fn set_route_with_validation(
261 &mut self,
262 from_id: &NodeId,
263 action: Action,
264 to_id: &NodeId,
265 ) -> Result<(), FloxideError> {
266 if !self.nodes.contains_key(to_id) {
268 return Err(FloxideError::Other(format!(
269 "Destination node '{}' not found in workflow",
270 to_id
271 )));
272 }
273
274 self.routes.insert((from_id.clone(), action), to_id.clone());
285
286 Ok(())
287 }
288
289 pub async fn execute(&self, ctx: &mut Context) -> Result<(), FloxideError> {
291 let mut current_node_id = self.initial_node.clone();
292
293 loop {
294 let node = self
296 .nodes
297 .get(¤t_node_id)
298 .ok_or_else(|| FloxideError::node_not_found(current_node_id.clone()))?;
299
300 let event = {
302 let mut node_guard = node.lock().await;
303 match node_guard.wait_for_event().await {
304 Ok(event) => event,
305 Err(e) => {
306 if e.to_string().contains("not an event source") {
308 warn!(
311 "Node '{}' is not an event source, routing to initial node",
312 current_node_id
313 );
314
315 current_node_id = self.initial_node.clone();
317 let source_node =
318 self.nodes.get(¤t_node_id).ok_or_else(|| {
319 FloxideError::Other(
320 "Initial node not found in workflow".to_string(),
321 )
322 })?;
323
324 let mut source_guard = source_node.lock().await;
325 source_guard.wait_for_event().await?
326 } else {
327 return Err(e);
329 }
330 }
331 }
332 };
333
334 let action = {
335 let node_guard = node.lock().await;
336 node_guard.process_event(event, ctx).await?
337 };
338
339 if action == self.termination_action {
341 info!("Event-driven workflow terminated with termination action");
342 return Ok(());
343 }
344
345 current_node_id = self
347 .routes
348 .get(&(current_node_id, action.clone()))
349 .ok_or_else(|| {
350 FloxideError::WorkflowDefinitionError(format!(
351 "No route defined for action: {}",
352 action.name()
353 ))
354 })?
355 .clone();
356 }
357 }
358
359 pub async fn execute_with_timeout(
361 &self,
362 ctx: &mut Context,
363 timeout: Duration,
364 ) -> Result<(), FloxideError> {
365 match tokio::time::timeout(timeout, self.execute(ctx)).await {
366 Ok(result) => result,
367 Err(_) => Err(FloxideError::timeout(
368 "Event-driven workflow execution timed out",
369 )),
370 }
371 }
372}
373
/// Adapts an [`EventDrivenNode`] to the `Node` trait.
///
/// Waits at most `timeout` for an event and routes to `timeout_action` if none arrives.
pub struct EventDrivenNodeAdapter<E, C, A>
where
    E: Send + 'static,
    C: Send + Sync + 'static,
    A: ActionType + Send + Sync + 'static + Default,
{
    /// The wrapped event-driven node.
    node: Arc<tokio::sync::Mutex<dyn EventDrivenNode<E, C, A>>>,
    /// Id reported through `Node::id`.
    id: NodeId,
    /// Maximum time to wait for an event in `process`.
    timeout: Duration,
    /// Action returned when no event arrives within `timeout`.
    timeout_action: A,
}
386
387impl<E, C, A> EventDrivenNodeAdapter<E, C, A>
388where
389 E: Send + 'static,
390 C: Send + Sync + 'static,
391 A: ActionType + Send + Sync + 'static + Default,
392{
393 pub fn new(
395 node: Arc<tokio::sync::Mutex<dyn EventDrivenNode<E, C, A>>>,
396 timeout: Duration,
397 timeout_action: A,
398 ) -> Self {
399 let id = {
400 node.try_lock()
401 .map(|n| n.id())
402 .unwrap_or_else(|_| "locked".to_string())
403 };
404
405 Self {
406 node,
407 id,
408 timeout,
409 timeout_action,
410 }
411 }
412
413 pub fn with_id(
415 node: Arc<tokio::sync::Mutex<dyn EventDrivenNode<E, C, A>>>,
416 id: impl Into<String>,
417 timeout: Duration,
418 timeout_action: A,
419 ) -> Self {
420 Self {
421 node,
422 id: id.into(),
423 timeout,
424 timeout_action,
425 }
426 }
427}
428
429#[async_trait]
430impl<E, C, A> Node<C, A> for EventDrivenNodeAdapter<E, C, A>
431where
432 E: Send + 'static,
433 C: Send + Sync + 'static,
434 A: ActionType + Send + Sync + 'static + Default,
435{
436 type Output = ();
437
438 fn id(&self) -> NodeId {
439 self.id.clone()
440 }
441
442 async fn process(&self, ctx: &mut C) -> Result<NodeOutcome<Self::Output, A>, FloxideError> {
443 let wait_for_event_future = async {
445 let mut node_guard = self.node.lock().await;
446 node_guard.wait_for_event().await
447 };
448
449 match tokio::time::timeout(self.timeout, wait_for_event_future).await {
450 Ok(Ok(event)) => {
451 let action = {
452 let node_guard = self.node.lock().await;
453 node_guard.process_event(event, ctx).await?
454 };
455 Ok(NodeOutcome::RouteToAction(action))
456 }
457 Ok(Err(e)) => Err(e),
458 Err(_) => {
459 Ok(NodeOutcome::RouteToAction(self.timeout_action.clone()))
461 }
462 }
463 }
464}
465
/// Runs an [`EventDrivenWorkflow`] as a single `Node`, optionally under a time limit.
pub struct NestedEventDrivenWorkflow<E, C, A>
where
    E: Send + 'static,
    C: Send + Sync + 'static,
    A: ActionType + Send + Sync + 'static + Default,
{
    /// The inner workflow executed by `process`.
    workflow: Arc<EventDrivenWorkflow<E, C, A>>,
    /// Id reported through `Node::id`.
    id: NodeId,
    /// Optional limit on the inner workflow's run time; `None` means no limit.
    timeout: Option<Duration>,
    /// Action returned when the inner workflow finishes successfully.
    complete_action: A,
    /// Action returned when `timeout` elapses before the inner workflow finishes.
    timeout_action: A,
}
479
480impl<E, C, A> NestedEventDrivenWorkflow<E, C, A>
481where
482 E: Send + 'static,
483 C: Send + Sync + 'static,
484 A: ActionType + Send + Sync + 'static + Default,
485{
486 pub fn new(
488 workflow: Arc<EventDrivenWorkflow<E, C, A>>,
489 complete_action: A,
490 timeout_action: A,
491 ) -> Self {
492 Self {
493 workflow,
494 id: Uuid::new_v4().to_string(),
495 timeout: None,
496 complete_action,
497 timeout_action,
498 }
499 }
500
501 pub fn with_timeout(
503 workflow: Arc<EventDrivenWorkflow<E, C, A>>,
504 timeout: Duration,
505 complete_action: A,
506 timeout_action: A,
507 ) -> Self {
508 Self {
509 workflow,
510 id: Uuid::new_v4().to_string(),
511 timeout: Some(timeout),
512 complete_action,
513 timeout_action,
514 }
515 }
516
517 pub fn with_id(
519 workflow: Arc<EventDrivenWorkflow<E, C, A>>,
520 id: impl Into<String>,
521 complete_action: A,
522 timeout_action: A,
523 ) -> Self {
524 Self {
525 workflow,
526 id: id.into(),
527 timeout: None,
528 complete_action,
529 timeout_action,
530 }
531 }
532}
533
534#[async_trait]
535impl<E, C, A> Node<C, A> for NestedEventDrivenWorkflow<E, C, A>
536where
537 E: Send + 'static,
538 C: Send + Sync + 'static,
539 A: ActionType + Send + Sync + 'static + Default,
540{
541 type Output = ();
542
543 fn id(&self) -> NodeId {
544 self.id.clone()
545 }
546
547 async fn process(&self, ctx: &mut C) -> Result<NodeOutcome<Self::Output, A>, FloxideError> {
548 match self.timeout {
549 Some(timeout) => {
550 match tokio::time::timeout(timeout, self.workflow.execute(ctx)).await {
551 Ok(Ok(())) => {
552 Ok(NodeOutcome::RouteToAction(self.complete_action.clone()))
554 }
555 Ok(Err(e)) => {
556 Err(e)
558 }
559 Err(_) => {
560 Ok(NodeOutcome::RouteToAction(self.timeout_action.clone()))
562 }
563 }
564 }
565 None => {
566 self.workflow.execute(ctx).await?;
567 Ok(NodeOutcome::RouteToAction(self.complete_action.clone()))
568 }
569 }
570 }
571}
572
/// Constructors for two well-known actions in event-driven workflows.
///
/// NOTE(review): presumably intended as the `termination_action` / `timeout_action`
/// arguments of the workflow and adapter constructors. Confirm against callers.
pub trait EventActionExt: ActionType {
    /// Returns the action that signals a workflow should stop.
    fn terminate() -> Self;

    /// Returns the action that signals a wait timed out.
    fn timeout() -> Self;
}
581
582impl EventActionExt for DefaultAction {
583 fn terminate() -> Self {
584 DefaultAction::Custom("terminate".into())
585 }
586
587 fn timeout() -> Self {
588 DefaultAction::Custom("timeout".into())
589 }
590}