use async_trait::async_trait;
use floxide_core::{error::FloxideError, ActionType, DefaultAction, Node, NodeId, NodeOutcome};
use std::collections::HashMap;
use std::marker::PhantomData;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::mpsc;
use tracing::{info, warn};
use uuid::Uuid;
/// A workflow node that is driven by events rather than invoked directly.
///
/// A driver (such as `EventDrivenWorkflow::execute` or
/// `EventDrivenNodeAdapter::process`) first awaits
/// [`wait_for_event`](EventDrivenNode::wait_for_event) and then hands the
/// received event to [`process_event`](EventDrivenNode::process_event), whose
/// returned action selects what happens next.
#[async_trait]
pub trait EventDrivenNode<Event, Context, Action>: Send + Sync
where
    Event: Send + 'static,
    Context: Send + Sync + 'static,
    Action: ActionType + Send + Sync + 'static + Default,
{
    /// Identifier of this node; workflows register and route nodes by it.
    fn id(&self) -> NodeId;

    /// Waits until the next event is available and returns it.
    ///
    /// # Errors
    /// Implementation-defined; `EventDrivenWorkflow::execute` treats an error
    /// whose message contains "not an event source" specially.
    async fn wait_for_event(&mut self) -> Result<Event, FloxideError>;

    /// Handles `event`, possibly mutating `ctx`, and returns the action used
    /// for routing.
    async fn process_event(&self, event: Event, ctx: &mut Context) -> Result<Action, FloxideError>;
}
/// An event source backed by the receiving half of a bounded tokio mpsc
/// channel; events are pushed in through the matching `mpsc::Sender`.
pub struct ChannelEventSource<Event> {
    /// Identifier reported by `EventDrivenNode::id`.
    id: NodeId,
    /// Receiving half of the channel that `wait_for_event` reads from.
    receiver: mpsc::Receiver<Event>,
}
impl<Event> ChannelEventSource<Event>
where
    Event: Send + 'static,
{
    /// Creates a source with a random UUID v4 id and a channel buffering up
    /// to `capacity` events.
    ///
    /// Returns the source together with the sender used to feed it.
    ///
    /// # Panics
    /// tokio's `mpsc::channel` panics when `capacity` is 0.
    pub fn new(capacity: usize) -> (Self, mpsc::Sender<Event>) {
        Self::with_id(capacity, Uuid::new_v4().to_string())
    }

    /// Like [`new`](Self::new), but uses the caller-supplied `id`.
    ///
    /// # Panics
    /// tokio's `mpsc::channel` panics when `capacity` is 0.
    pub fn with_id(capacity: usize, id: impl Into<String>) -> (Self, mpsc::Sender<Event>) {
        let (tx, rx) = mpsc::channel::<Event>(capacity);
        let source = Self {
            id: id.into(),
            receiver: rx,
        };
        (source, tx)
    }
}
/// A channel source yields the channel's events and, on its own, routes every
/// event to `Action::default()`.
#[async_trait]
impl<Event, Context, Action> EventDrivenNode<Event, Context, Action> for ChannelEventSource<Event>
where
    Event: Send + 'static,
    Context: Send + Sync + 'static,
    Action: ActionType + Send + Sync + 'static + Default,
{
    fn id(&self) -> NodeId {
        self.id.clone()
    }

    /// Returns the next queued event.
    ///
    /// # Errors
    /// `FloxideError::Other("Event channel closed")` once `recv` yields
    /// `None`, i.e. the channel is closed and drained.
    async fn wait_for_event(&mut self) -> Result<Event, FloxideError> {
        self.receiver
            .recv()
            .await
            .ok_or_else(|| FloxideError::Other("Event channel closed".to_string()))
    }

    /// Ignores the event and context and always returns `Action::default()`.
    async fn process_event(
        &self,
        _event: Event,
        _ctx: &mut Context,
    ) -> Result<Action, FloxideError> {
        Ok(Default::default())
    }
}
/// An [`EventDrivenNode`] that receives events from its own bounded channel
/// and handles each one with a synchronous closure.
pub struct EventProcessor<
    Event: Send + 'static,
    Context: Send + Sync + 'static,
    Action: ActionType + Send + Sync + 'static + Default,
    F: Fn(Event, &mut Context) -> Result<Action, FloxideError> + Send + Sync + 'static,
> {
    /// Channel-backed source that supplies this node's events.
    source: Arc<tokio::sync::Mutex<ChannelEventSource<Event>>>,
    /// Closure invoked by `process_event` for every received event.
    processor: F,
    /// Marks the `Context` and `Action` parameters, which appear only in `F`'s bound.
    _phantom: PhantomData<(Context, Action)>,
}
impl<Event, Context, Action, F> EventProcessor<Event, Context, Action, F>
where
Event: Send + 'static,
Context: Send + Sync + 'static,
Action: ActionType + Send + Sync + 'static + Default,
F: Fn(Event, &mut Context) -> Result<Action, FloxideError> + Send + Sync + 'static,
{
pub fn new(capacity: usize, processor: F) -> (Self, mpsc::Sender<Event>) {
let (source, sender) = ChannelEventSource::new(capacity);
(
Self {
source: Arc::new(tokio::sync::Mutex::new(source)),
processor,
_phantom: PhantomData,
},
sender,
)
}
pub fn with_id(
capacity: usize,
id: impl Into<String>,
processor: F,
) -> (Self, mpsc::Sender<Event>) {
let (source, sender) = ChannelEventSource::with_id(capacity, id);
(
Self {
source: Arc::new(tokio::sync::Mutex::new(source)),
processor,
_phantom: PhantomData,
},
sender,
)
}
}
#[async_trait]
impl<Event, Context, Action, F> EventDrivenNode<Event, Context, Action>
for EventProcessor<Event, Context, Action, F>
where
Event: Send + 'static,
Context: Send + Sync + 'static,
Action: ActionType + Send + Sync + 'static + Default,
F: Fn(Event, &mut Context) -> Result<Action, FloxideError> + Send + Sync + 'static,
{
async fn wait_for_event(&mut self) -> Result<Event, FloxideError> {
let mut source = self.source.lock().await;
<ChannelEventSource<Event> as EventDrivenNode<Event, Context, Action>>::wait_for_event(
&mut *source,
)
.await
}
async fn process_event(&self, event: Event, ctx: &mut Context) -> Result<Action, FloxideError> {
(self.processor)(event, ctx)
}
fn id(&self) -> NodeId {
self.source
.try_lock()
.map(|source| {
<ChannelEventSource<Event> as EventDrivenNode<Event, Context, Action>>::id(&*source)
})
.unwrap_or_else(|_| "locked".to_string())
}
}
/// Shared, lockable handle to a type-erased [`EventDrivenNode`]; the value
/// type of `EventDrivenWorkflow`'s node table.
type EventNodeRef<Ev, Ctx, Act> = Arc<tokio::sync::Mutex<dyn EventDrivenNode<Ev, Ctx, Act>>>;
/// A workflow made of [`EventDrivenNode`]s.
///
/// Execution starts at an initial node. Each received event is processed, and
/// the returned action is looked up in a route table to choose the next node.
/// The workflow ends when the termination action is returned.
pub struct EventDrivenWorkflow<
    Event: Send + 'static,
    Context: Send + Sync + 'static,
    Action: ActionType + Send + Sync + 'static + Default,
> {
    /// Registered nodes keyed by the id read from each node at registration.
    nodes: HashMap<NodeId, EventNodeRef<Event, Context, Action>>,
    /// Transition table: `(node id, action)` → id of the next node.
    routes: HashMap<(NodeId, Action), NodeId>,
    /// Id of the node where execution starts.
    initial_node: NodeId,
    /// Action that ends `execute` successfully when any node returns it.
    termination_action: Action,
}
impl<Event, Context, Action> EventDrivenWorkflow<Event, Context, Action>
where
    Event: Send + 'static,
    Context: Send + Sync + 'static,
    Action: ActionType + Send + Sync + 'static + Default,
{
    /// Creates a workflow whose execution starts at `initial_node`.
    ///
    /// The node is registered under the id returned by its
    /// `EventDrivenNode::id`. If its mutex is already held, the id cannot be
    /// read without waiting, and the placeholder `"locked"` is used instead.
    ///
    /// `termination_action` ends [`execute`](Self::execute) successfully when
    /// any node's `process_event` returns it.
    pub fn new(
        initial_node: Arc<tokio::sync::Mutex<dyn EventDrivenNode<Event, Context, Action>>>,
        termination_action: Action,
    ) -> Self {
        let id = Self::registration_id(&initial_node);
        let mut nodes = HashMap::new();
        nodes.insert(id.clone(), initial_node);
        Self {
            nodes,
            routes: HashMap::new(),
            initial_node: id,
            termination_action,
        }
    }

    /// Registers `node` under its id, replacing any node already stored under
    /// that id. The id is read exactly as in [`new`](Self::new), including the
    /// `"locked"` placeholder.
    pub fn add_node(
        &mut self,
        node: Arc<tokio::sync::Mutex<dyn EventDrivenNode<Event, Context, Action>>>,
    ) {
        let id = Self::registration_id(&node);
        self.nodes.insert(id, node);
    }

    /// Reads a node's id without awaiting its lock; falls back to `"locked"`
    /// when `try_lock` fails.
    fn registration_id(
        node: &Arc<tokio::sync::Mutex<dyn EventDrivenNode<Event, Context, Action>>>,
    ) -> NodeId {
        node.try_lock()
            .map(|n| n.id())
            .unwrap_or_else(|_| "locked".to_string())
    }

    /// Routes `action` returned from node `from_id` to node `to_id`,
    /// overwriting any existing route for that pair. Neither id is checked.
    pub fn set_route(&mut self, from_id: &NodeId, action: Action, to_id: &NodeId) {
        self.routes.insert((from_id.clone(), action), to_id.clone());
    }

    /// Like [`set_route`](Self::set_route), but first checks that `to_id` is a
    /// registered node. `from_id` is not checked.
    ///
    /// # Errors
    /// `FloxideError::Other` if `to_id` is not registered.
    pub fn set_route_with_validation(
        &mut self,
        from_id: &NodeId,
        action: Action,
        to_id: &NodeId,
    ) -> Result<(), FloxideError> {
        if !self.nodes.contains_key(to_id) {
            return Err(FloxideError::Other(format!(
                "Destination node '{}' not found in workflow",
                to_id
            )));
        }
        self.routes.insert((from_id.clone(), action), to_id.clone());
        Ok(())
    }

    /// Runs the workflow until a node returns the termination action.
    ///
    /// Each step waits for an event on the current node, processes it with that
    /// node, and follows the route for the returned action. There is no
    /// built-in time bound; see [`execute_with_timeout`](Self::execute_with_timeout).
    ///
    /// If the current node's `wait_for_event` fails with an error mentioning
    /// "not an event source", the event is taken from the initial node instead.
    /// The current node then still processes it. This fallback only applies
    /// when the current node is not itself the initial node.
    ///
    /// # Errors
    /// - The current node id is not registered (`FloxideError::node_not_found`).
    /// - `wait_for_event` fails, including "not an event source" on the initial node.
    /// - `process_event` fails.
    /// - No route exists for the returned action (`WorkflowDefinitionError`).
    pub async fn execute(&self, ctx: &mut Context) -> Result<(), FloxideError> {
        let mut current_node_id = self.initial_node.clone();
        loop {
            let node = self
                .nodes
                .get(&current_node_id)
                .ok_or_else(|| FloxideError::node_not_found(current_node_id.clone()))?;

            // The guard is confined to this block so the node is unlocked
            // before the fallback below locks another node.
            let waited = {
                let mut node_guard = node.lock().await;
                node_guard.wait_for_event().await
            };

            let event = match waited {
                Ok(event) => event,
                // Fall back only to a *different* node. Locking the node that
                // just failed would wait forever, since tokio's Mutex is not
                // reentrant; in that case the error is returned instead.
                Err(e)
                    if e.to_string().contains("not an event source")
                        && current_node_id != self.initial_node =>
                {
                    warn!(
                        "Node '{}' is not an event source, routing to initial node",
                        current_node_id
                    );
                    current_node_id = self.initial_node.clone();
                    let source_node = self.nodes.get(&current_node_id).ok_or_else(|| {
                        FloxideError::Other("Initial node not found in workflow".to_string())
                    })?;
                    let mut source_guard = source_node.lock().await;
                    source_guard.wait_for_event().await?
                }
                Err(e) => return Err(e),
            };

            let action = {
                let node_guard = node.lock().await;
                node_guard.process_event(event, ctx).await?
            };
            if action == self.termination_action {
                info!("Event-driven workflow terminated with termination action");
                return Ok(());
            }

            // NOTE(review): after the fallback above, `current_node_id` is the
            // initial node's id. The route is therefore looked up from the
            // initial node even though `node` produced `action`. This is kept
            // as-is; confirm it is the intended routing.
            current_node_id = self
                .routes
                .get(&(current_node_id, action.clone()))
                .ok_or_else(|| {
                    FloxideError::WorkflowDefinitionError(format!(
                        "No route defined for action: {}",
                        action.name()
                    ))
                })?
                .clone();
        }
    }

    /// Runs [`execute`](Self::execute) with an overall deadline.
    ///
    /// # Errors
    /// Any error from `execute`, or `FloxideError::timeout` once `timeout`
    /// elapses. In the timeout case the in-flight run is dropped, and `ctx` may
    /// hold partial updates.
    pub async fn execute_with_timeout(
        &self,
        ctx: &mut Context,
        timeout: Duration,
    ) -> Result<(), FloxideError> {
        match tokio::time::timeout(timeout, self.execute(ctx)).await {
            Ok(result) => result,
            Err(_) => Err(FloxideError::timeout(
                "Event-driven workflow execution timed out",
            )),
        }
    }
}
/// Exposes an [`EventDrivenNode`] as a regular `Node`. Each call waits for at
/// most one event, bounded by a timeout.
pub struct EventDrivenNodeAdapter<
    E: Send + 'static,
    C: Send + Sync + 'static,
    A: ActionType + Send + Sync + 'static + Default,
> {
    /// The wrapped event-driven node.
    node: Arc<tokio::sync::Mutex<dyn EventDrivenNode<E, C, A>>>,
    /// Id reported through `Node::id`.
    id: NodeId,
    /// Maximum time to wait for an event in `process`.
    timeout: Duration,
    /// Action returned when no event arrives within `timeout`.
    timeout_action: A,
}
impl<E, C, A> EventDrivenNodeAdapter<E, C, A>
where
    E: Send + 'static,
    C: Send + Sync + 'static,
    A: ActionType + Send + Sync + 'static + Default,
{
    /// Wraps `node` and reuses its own id.
    ///
    /// If the node's mutex is currently held, the id cannot be read without
    /// waiting, and `"locked"` is used instead.
    pub fn new(
        node: Arc<tokio::sync::Mutex<dyn EventDrivenNode<E, C, A>>>,
        timeout: Duration,
        timeout_action: A,
    ) -> Self {
        let derived_id = match node.try_lock() {
            Ok(guard) => guard.id(),
            Err(_) => String::from("locked"),
        };
        Self::with_id(node, derived_id, timeout, timeout_action)
    }

    /// Wraps `node` under an explicit `id`.
    pub fn with_id(
        node: Arc<tokio::sync::Mutex<dyn EventDrivenNode<E, C, A>>>,
        id: impl Into<String>,
        timeout: Duration,
        timeout_action: A,
    ) -> Self {
        Self {
            id: id.into(),
            node,
            timeout,
            timeout_action,
        }
    }
}
#[async_trait]
impl<E, C, A> Node<C, A> for EventDrivenNodeAdapter<E, C, A>
where
    E: Send + 'static,
    C: Send + Sync + 'static,
    A: ActionType + Send + Sync + 'static + Default,
{
    type Output = ();

    fn id(&self) -> NodeId {
        self.id.clone()
    }

    /// Waits for one event and processes it, routing to the action it yields.
    ///
    /// If no event arrives within `timeout`, routes to `timeout_action`
    /// instead. Errors from waiting or processing are returned unchanged.
    async fn process(&self, ctx: &mut C) -> Result<NodeOutcome<Self::Output, A>, FloxideError> {
        // The deadline covers acquiring the node lock as well as waiting for
        // the event. Processing the event is not time-limited.
        let next_event = async {
            let mut waiting = self.node.lock().await;
            waiting.wait_for_event().await
        };
        let action = match tokio::time::timeout(self.timeout, next_event).await {
            Err(_elapsed) => self.timeout_action.clone(),
            Ok(received) => {
                let event = received?;
                let handler = self.node.lock().await;
                handler.process_event(event, ctx).await?
            }
        };
        Ok(NodeOutcome::RouteToAction(action))
    }
}
/// Runs a whole [`EventDrivenWorkflow`] as a single `Node`, optionally
/// bounded by a timeout.
pub struct NestedEventDrivenWorkflow<
    E: Send + 'static,
    C: Send + Sync + 'static,
    A: ActionType + Send + Sync + 'static + Default,
> {
    /// The inner workflow executed by `process`.
    workflow: Arc<EventDrivenWorkflow<E, C, A>>,
    /// Id reported through `Node::id`.
    id: NodeId,
    /// Optional limit on one run of the inner workflow.
    timeout: Option<Duration>,
    /// Action returned when the inner workflow finishes successfully.
    complete_action: A,
    /// Action returned when `timeout` elapses first.
    timeout_action: A,
}
impl<E, C, A> NestedEventDrivenWorkflow<E, C, A>
where
E: Send + 'static,
C: Send + Sync + 'static,
A: ActionType + Send + Sync + 'static + Default,
{
pub fn new(
workflow: Arc<EventDrivenWorkflow<E, C, A>>,
complete_action: A,
timeout_action: A,
) -> Self {
Self {
workflow,
id: Uuid::new_v4().to_string(),
timeout: None,
complete_action,
timeout_action,
}
}
pub fn with_timeout(
workflow: Arc<EventDrivenWorkflow<E, C, A>>,
timeout: Duration,
complete_action: A,
timeout_action: A,
) -> Self {
Self {
workflow,
id: Uuid::new_v4().to_string(),
timeout: Some(timeout),
complete_action,
timeout_action,
}
}
pub fn with_id(
workflow: Arc<EventDrivenWorkflow<E, C, A>>,
id: impl Into<String>,
complete_action: A,
timeout_action: A,
) -> Self {
Self {
workflow,
id: id.into(),
timeout: None,
complete_action,
timeout_action,
}
}
}
#[async_trait]
impl<E, C, A> Node<C, A> for NestedEventDrivenWorkflow<E, C, A>
where
    E: Send + 'static,
    C: Send + Sync + 'static,
    A: ActionType + Send + Sync + 'static + Default,
{
    type Output = ();

    fn id(&self) -> NodeId {
        self.id.clone()
    }

    /// Runs the inner workflow and routes to `complete_action` when it
    /// finishes successfully.
    ///
    /// With a timeout configured, routes to `timeout_action` if the run does
    /// not finish in time. The run is then dropped, and `ctx` may hold partial
    /// updates. Errors from the inner workflow are returned unchanged.
    async fn process(&self, ctx: &mut C) -> Result<NodeOutcome<Self::Output, A>, FloxideError> {
        let action = if let Some(limit) = self.timeout {
            match tokio::time::timeout(limit, self.workflow.execute(ctx)).await {
                Err(_elapsed) => self.timeout_action.clone(),
                Ok(run) => {
                    run?;
                    self.complete_action.clone()
                }
            }
        } else {
            self.workflow.execute(ctx).await?;
            self.complete_action.clone()
        };
        Ok(NodeOutcome::RouteToAction(action))
    }
}
/// Constructors for the `timeout` and `terminate` actions used with
/// event-driven nodes. They are typically passed as a `timeout_action` or
/// `termination_action`.
pub trait EventActionExt: ActionType {
    /// Returns this action type's "timeout" action.
    fn timeout() -> Self;
    /// Returns this action type's "terminate" action.
    fn terminate() -> Self;
}
/// Maps both actions onto `DefaultAction::Custom` with the names "timeout"
/// and "terminate".
impl EventActionExt for DefaultAction {
    fn timeout() -> Self {
        Self::Custom("timeout".into())
    }

    fn terminate() -> Self {
        Self::Custom("terminate".into())
    }
}