use crate::errors::Result;
use crate::state::State;
use crate::types::Send as SendType;
use async_trait::async_trait;
use std::future::Future;
/// Routing decision returned by [`Branch::route`].
///
/// Each variant carries the data needed to describe where execution should
/// go next; how the variants are acted upon is decided by the caller of
/// `route`, which is not part of this file.
#[derive(Debug, Clone)]
pub enum BranchResult {
    /// Exactly one target node, identified by name.
    Single(String),
    /// Several target nodes, identified by name, in the order given.
    Multiple(Vec<String>),
    /// A list of `crate::types::Send` values.
    ///
    /// NOTE(review): presumably each entry addresses a node with its own
    /// payload — confirm against `crate::types::Send`.
    Send(Vec<SendType>),
    /// Terminal marker that carries no target.
    ///
    /// NOTE(review): presumably signals that the run should stop — confirm
    /// against the executor.
    End,
}
impl BranchResult {
pub fn single(node: impl Into<String>) -> Self {
BranchResult::Single(node.into())
}
pub fn multiple(nodes: Vec<impl Into<String>>) -> Self {
BranchResult::Multiple(nodes.into_iter().map(|n| n.into()).collect())
}
pub fn send(sends: Vec<SendType>) -> Self {
BranchResult::Send(sends)
}
pub fn end() -> Self {
BranchResult::End
}
pub fn is_end(&self) -> bool {
matches!(self, BranchResult::End)
}
pub fn node_names(&self) -> Vec<String> {
match self {
BranchResult::Single(name) => vec![name.clone()],
BranchResult::Multiple(names) => names.clone(),
BranchResult::Send(_) => vec![],
BranchResult::End => vec![],
}
}
}
/// Asynchronous routing decision over a state of type `S`.
///
/// Implementors must be `Send + Sync`. The trait is declared through
/// `#[async_trait]`, which is what allows the `async fn` below to be used
/// behind a trait object such as `Box<dyn Branch<S>>`.
#[async_trait]
pub trait Branch<S>: std::marker::Send + Sync
where
    S: State,
{
    /// Inspects `state` (read-only) and returns the routing decision.
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementation reports through the
    /// crate-level `Result` alias.
    async fn route(&self, state: &S) -> Result<BranchResult>;
}
/// Blanket implementation that lets any `Send + Sync` function or closure of
/// shape `Fn(&S) -> Fut` be used directly as a [`Branch`], provided `Fut` is a
/// `Send` future resolving to `Result<BranchResult>`.
#[async_trait]
impl<S, F, Fut> Branch<S> for F
where
    S: State,
    F: Fn(&S) -> Fut + std::marker::Send + Sync,
    Fut: Future<Output = Result<BranchResult>> + std::marker::Send,
{
    /// Calls the wrapped function with `state` and awaits the future it
    /// returns, passing its result through unchanged.
    async fn route(&self, state: &S) -> Result<BranchResult> {
        let pending = self(state);
        pending.await
    }
}
/// Heap-allocated, type-erased [`Branch`] over state type `S`.
pub type BranchBox<S> = Box<dyn Branch<S>>;
/// Returns `f` unchanged, typed as an opaque [`Branch`] implementation.
///
/// Passing a function or closure through this helper checks it against the
/// `Fn(&S) -> Fut` bounds below at the call site.
///
/// Both `F` and `Fut` are required to be `'static`, so the returned value is
/// `'static` as well. The return type states this explicitly: without the
/// `+ 'static` bound the opaque type is only known to outlive `'static` when
/// every captured type parameter (including `S`) does, which prevents callers
/// that are generic over `S` from storing the result in a
/// `Box<dyn Branch<S>>`.
pub fn branch_fn<S, F, Fut>(f: F) -> impl Branch<S> + 'static
where
    S: State,
    F: Fn(&S) -> Fut + std::marker::Send + Sync + 'static,
    Fut: Future<Output = Result<BranchResult>> + std::marker::Send + 'static,
{
    f
}
/// Branch that always routes to the same, fixed node name.
///
/// Derives `Debug` and `Clone` so the type can be logged and duplicated like
/// the other public types in this module.
#[derive(Debug, Clone)]
pub struct StaticBranch {
    /// Name of the node returned on every routing call.
    target: String,
}
impl StaticBranch {
    /// Creates a branch whose routing target is `target`, converted into an
    /// owned `String`.
    pub fn new(target: impl Into<String>) -> Self {
        let target: String = target.into();
        Self { target }
    }
}
#[async_trait]
impl<S: State> Branch<S> for StaticBranch {
    /// Ignores the state and always yields `BranchResult::Single` holding a
    /// clone of the configured target name. Never returns an error.
    async fn route(&self, _state: &S) -> Result<BranchResult> {
        let target = self.target.clone();
        Ok(BranchResult::Single(target))
    }
}
/// Stateless branch marker type with no fields.
///
/// Derives the common traits a unit struct can support for free, so it can be
/// logged, copied, and default-constructed.
#[derive(Debug, Clone, Copy, Default)]
pub struct EndBranch;
#[async_trait]
impl<S: State> Branch<S> for EndBranch {
    /// Ignores the state and always yields `BranchResult::End`. Never returns
    /// an error.
    async fn route(&self, _state: &S) -> Result<BranchResult> {
        let outcome = BranchResult::end();
        Ok(outcome)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::state::DictState;

    /// The constructors build the matching variant and `node_names` reports
    /// the names each variant carries.
    #[test]
    fn test_branch_result_creation() {
        let one = BranchResult::single("node1");
        assert!(matches!(one, BranchResult::Single(_)));
        assert_eq!(one.node_names(), vec!["node1"]);

        let many = BranchResult::multiple(vec!["node1", "node2"]);
        assert!(matches!(many, BranchResult::Multiple(_)));
        assert_eq!(many.node_names(), vec!["node1", "node2"]);

        let terminal = BranchResult::end();
        assert!(terminal.is_end());
        assert!(terminal.node_names().is_empty());
    }

    /// `StaticBranch` routes to its configured target.
    #[tokio::test]
    async fn test_static_branch() {
        let router = StaticBranch::new("target");
        let state = DictState::new();

        let outcome = router.route(&state).await.unwrap();

        assert!(matches!(outcome, BranchResult::Single(_)));
        assert_eq!(outcome.node_names(), vec!["target"]);
    }

    /// `EndBranch` always reports the end marker.
    #[tokio::test]
    async fn test_end_branch() {
        let router = EndBranch;
        let state = DictState::new();

        let outcome = router.route(&state).await.unwrap();

        assert!(outcome.is_end());
    }

    /// A plain closure returning a future is usable as a `Branch` through the
    /// blanket implementation.
    #[tokio::test]
    async fn test_branch_closure() {
        let router = |_state: &DictState| async {
            Ok(BranchResult::single("dynamic_node"))
        };
        let state = DictState::new();

        let outcome = router.route(&state).await.unwrap();

        assert_eq!(outcome.node_names(), vec!["dynamic_node"]);
    }

    /// `BranchResult::send` wraps its input in the `Send` variant.
    #[tokio::test]
    async fn test_branch_with_send() {
        let payloads = vec![
            SendType::new("process", serde_json::json!({"id": 1})),
            SendType::new("process", serde_json::json!({"id": 2})),
        ];

        let outcome = BranchResult::send(payloads);

        assert!(matches!(outcome, BranchResult::Send(_)));
    }
}