use async_trait::async_trait;
use uuid::Uuid;
use crate::action::ActionType;
use crate::error::FlowrsError;
/// Identifier assigned to every [`Node`]; a plain owned string.
pub type NodeId = std::string::String;
/// Result of successfully running a [`Node`].
///
/// - `Success` carries the value the node produced.
/// - `Skipped` means the node chose not to produce a value.
/// - `RouteToAction` carries an action value chosen by the node.
///
/// `PartialEq`/`Eq` are derived so outcomes can be compared directly (e.g. with
/// `assert_eq!`); like all derives, they only apply when `Output` and `Action`
/// implement the corresponding trait themselves.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NodeOutcome<Output, Action> {
    /// The node completed and produced `Output`.
    Success(Output),
    /// The node produced no value.
    Skipped,
    /// The node returned an action instead of a value.
    RouteToAction(Action),
}
/// An asynchronous unit of work that operates on a shared `Context`.
///
/// Implementors must be `Send + Sync`. The trait is annotated with
/// `#[async_trait]`, so `process` can be declared as an `async fn`.
#[async_trait]
pub trait Node<Context, Action>: Send + Sync
where
    Context: Send + Sync + 'static,
    Action: ActionType + Send + Sync + 'static,
{
    /// Value carried by [`NodeOutcome::Success`].
    ///
    /// The bounds are declared on the associated type itself rather than in a
    /// trait-level `where Self::Output: ...` clause: bounds written here are
    /// implied for every `N: Node<..>`, whereas trait-level where-clauses on
    /// associated types are not, and would force generic callers to restate them.
    type Output: Send + Sync + 'static;

    /// Returns this node's identifier.
    fn id(&self) -> NodeId;

    /// Runs the node against `ctx`, which it may mutate.
    ///
    /// # Errors
    ///
    /// Returns a [`FlowrsError`] when processing fails.
    async fn process(
        &self,
        ctx: &mut Context,
    ) -> Result<NodeOutcome<Self::Output, Action>, FlowrsError>;
}
/// Adapters that turn async closures into [`Node`] implementations.
pub mod closure {
    use std::fmt::Debug;
    use std::future::Future;
    use std::marker::PhantomData;

    use super::*;

    /// Wraps `closure` in a [`ClosureNode`] whose id is a freshly generated UUID v4.
    ///
    /// The closure receives an owned copy of the context and must return the
    /// (possibly modified) context together with the node's outcome.
    pub fn node<Closure, Context, Action, Output, Fut>(
        closure: Closure,
    ) -> ClosureNode<Closure, Context, Action, Output>
    where
        Context: Clone + Send + Sync + 'static,
        Action: ActionType + Send + Sync + 'static,
        Output: Send + Sync + 'static,
        Closure: Fn(Context) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<(Context, NodeOutcome<Output, Action>), FlowrsError>>
            + Send
            + 'static,
    {
        ClosureNode {
            id: Uuid::new_v4().to_string(),
            closure,
            _phantom: PhantomData,
        }
    }

    /// A [`Node`] whose behavior is supplied by an async closure.
    ///
    /// Create one with [`node`].
    pub struct ClosureNode<Closure, Context, Action, Output> {
        id: NodeId,
        closure: Closure,
        _phantom: PhantomData<(Context, Action, Output)>,
    }

    // Hand-written instead of `#[derive(Clone)]`: the derive would also demand
    // `Context: Clone`, `Action: Clone` and `Output: Clone` because of the
    // `PhantomData` tuple, even though only `id` and `closure` are stored.
    // As with the derived version, the clone keeps the same `id`.
    impl<Closure, Context, Action, Output> Clone for ClosureNode<Closure, Context, Action, Output>
    where
        Closure: Clone,
    {
        fn clone(&self) -> Self {
            Self {
                id: self.id.clone(),
                closure: self.closure.clone(),
                _phantom: PhantomData,
            }
        }
    }

    // Only the id is printed; the closure is generally not `Debug`.
    impl<Closure, Context, Action, Output> Debug for ClosureNode<Closure, Context, Action, Output> {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.debug_struct("ClosureNode").field("id", &self.id).finish()
        }
    }

    #[async_trait]
    impl<Closure, Context, Action, Output, Fut> Node<Context, Action>
        for ClosureNode<Closure, Context, Action, Output>
    where
        Context: Clone + Send + Sync + 'static,
        Action: ActionType + Send + Sync + 'static,
        Output: Send + Sync + 'static,
        Closure: Fn(Context) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<(Context, NodeOutcome<Output, Action>), FlowrsError>>
            + Send
            + 'static,
    {
        type Output = Output;

        fn id(&self) -> NodeId {
            self.id.clone()
        }

        /// Clones `ctx`, hands the copy to the closure and, on success, overwrites
        /// `ctx` with the context the closure returned.
        ///
        /// # Errors
        ///
        /// Propagates the closure's [`FlowrsError`]; in that case `ctx` is left
        /// unchanged because the write-back happens only after `?`.
        async fn process(
            &self,
            ctx: &mut Context,
        ) -> Result<NodeOutcome<Self::Output, Action>, FlowrsError> {
            // The closure takes `Context` by value while we only hold `&mut`, so a
            // clone is required; it also keeps `ctx` intact if the closure fails.
            let ctx_clone = ctx.clone();
            let (updated_ctx, outcome) = (self.closure)(ctx_clone).await?;
            *ctx = updated_ctx;
            Ok(outcome)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::action::DefaultAction;

    /// Minimal context holding a single integer counter.
    #[derive(Debug, Clone)]
    struct TestContext {
        value: i32,
    }

    /// An incrementing closure node reports the new value and writes the mutated
    /// context back to the caller.
    #[tokio::test]
    async fn test_create_node_from_closure() {
        let increment = closure::node(|mut state: TestContext| async move {
            state.value += 1;
            let bumped = state.value;
            Ok((state, NodeOutcome::<i32, DefaultAction>::Success(bumped)))
        });

        let mut ctx = TestContext { value: 5 };
        let outcome = increment.process(&mut ctx).await.unwrap();

        if let NodeOutcome::Success(reported) = outcome {
            assert_eq!(reported, 6);
            assert_eq!(ctx.value, 6);
        } else {
            panic!("Expected Success outcome");
        }
    }

    /// A node that returns `Skipped` leaves the context untouched.
    #[tokio::test]
    async fn test_skip_node() {
        let skipper = closure::node(|state: TestContext| async move {
            Ok((state, NodeOutcome::<(), DefaultAction>::Skipped))
        });

        let mut ctx = TestContext { value: 5 };
        let outcome = skipper.process(&mut ctx).await.unwrap();

        assert!(
            matches!(outcome, NodeOutcome::Skipped),
            "Expected Skipped outcome"
        );
        assert_eq!(ctx.value, 5);
    }

    /// A node can return a named custom action instead of a value.
    #[tokio::test]
    async fn test_route_to_action() {
        let router = closure::node(|state: TestContext| async move {
            let target = DefaultAction::Custom("alternate_path".into());
            Ok((state, NodeOutcome::<(), DefaultAction>::RouteToAction(target)))
        });

        let mut ctx = TestContext { value: 5 };
        let outcome = router.process(&mut ctx).await.unwrap();

        if let NodeOutcome::RouteToAction(chosen) = outcome {
            assert_eq!(chosen.name(), "alternate_path");
        } else {
            panic!("Expected RouteToAction outcome");
        }
    }
}