use async_trait::async_trait;
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use synaptic_core::SynapticError;
use synaptic_graph::{Command, Node, NodeOutput, State, StateGraph, StreamMode, END};
/// Graph state shared by the tests in this file.
///
/// Every test node contributes `1` to `counter` and appends a name to
/// `visited`, so the final values reveal which nodes actually ran.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
struct CounterState {
    /// Running total of node contributions.
    counter: usize,
    /// Names recorded by the nodes, in the order they were added.
    visited: Vec<String>,
}
impl State for CounterState {
fn merge(&mut self, other: Self) {
self.counter += other.counter;
self.visited.extend(other.visited);
}
}
/// Test node that counts one step and records its own name, then lets the
/// graph's regular edges decide what runs next.
struct IncrementNode {
    /// Label pushed onto `CounterState::visited` each time the node runs.
    name: String,
}
#[async_trait]
impl Node<CounterState> for IncrementNode {
async fn process(
&self,
mut state: CounterState,
) -> Result<NodeOutput<CounterState>, SynapticError> {
state.counter += 1;
state.visited.push(self.name.clone());
Ok(state.into())
}
}
/// Test node that can either behave like a plain step or answer with a
/// routing [`Command`], depending on `command`.
struct CommandNode {
    /// Label recorded in `CounterState::visited` when the node runs.
    name: String,
    /// Routing override to emit; `None` means "return the updated state and
    /// issue no command".
    command: Option<CommandKind>,
}
/// Which routing command a [`CommandNode`] should emit.
#[derive(Clone)]
enum CommandKind {
    /// Emit a goto command targeting the node with the given name.
    Goto(String),
    /// Emit a goto command targeting `END`.
    End,
}
#[async_trait]
impl Node<CounterState> for CommandNode {
    /// Records one visit and, when configured, overrides routing.
    ///
    /// * `command == None`: returns the full incoming state with `counter`
    ///   increased by one and `name` appended to `visited`.
    /// * `command == Some(Goto(target))`: returns
    ///   `Command::goto_with_update(target, delta)`.
    /// * `command == Some(End)`: returns
    ///   `Command::goto_with_update(END, delta)`.
    ///
    /// On the command paths `delta` carries only this node's contribution
    /// (`counter: 1`, `visited: [name]`); the incoming state is not part of
    /// the output. The assertions in this file assume the graph combines the
    /// delta with the current state — presumably through `State::merge`;
    /// confirm against the graph implementation.
    async fn process(
        &self,
        mut state: CounterState,
    ) -> Result<NodeOutput<CounterState>, SynapticError> {
        let kind = match &self.command {
            Some(kind) => kind,
            None => {
                // Plain step: hand back the whole updated state.
                state.counter += 1;
                state.visited.push(self.name.clone());
                return Ok(state.into());
            }
        };

        // Built once and shared by both command variants. The incoming
        // `state` is intentionally left untouched here because it is dropped.
        let delta = CounterState {
            counter: 1,
            visited: vec![self.name.clone()],
        };
        let command = match kind {
            CommandKind::Goto(target) => Command::goto_with_update(target.clone(), delta),
            CommandKind::End => Command::goto_with_update(END, delta),
        };
        Ok(NodeOutput::Command(command))
    }
}
/// The `Debug` output of both a goto command and an end command is expected
/// to contain the text `goto`.
#[test]
fn command_debug() {
    let jump: Command<CounterState> = Command::goto("target");
    let stop: Command<CounterState> = Command::end();

    let jump_repr = format!("{:?}", jump);
    let stop_repr = format!("{:?}", stop);

    assert!(jump_repr.contains("goto"));
    assert!(stop_repr.contains("goto"));
}
/// A cloned command still renders with `goto` in its `Debug` output.
#[test]
fn command_clone() {
    let original: Command<CounterState> = Command::goto("node_a");
    let rendered = format!("{:?}", original.clone());
    assert!(rendered.contains("goto"));
}
#[tokio::test]
async fn goto_command_skips_node() {
let graph = StateGraph::new()
.add_node(
"a",
CommandNode {
name: "a".into(),
command: Some(CommandKind::Goto("c".to_string())), },
)
.add_node(
"b",
CommandNode {
name: "b".into(),
command: None,
},
)
.add_node(
"c",
CommandNode {
name: "c".into(),
command: None,
},
)
.add_edge("a", "b")
.add_edge("b", "c")
.add_edge("c", END)
.set_entry_point("a")
.compile()
.unwrap();
let result = graph
.invoke(CounterState::default())
.await
.unwrap()
.into_state();
assert_eq!(result.visited, vec!["a", "c"]);
assert_eq!(result.counter, 2);
}
/// A goto command may target a node that already ran, forming a loop that
/// the node itself terminates by eventually routing to `END`.
#[tokio::test]
async fn goto_command_redirects_to_earlier_node() {
    /// Node that sends execution back to `a` until the counter, including
    /// this node's own contribution, reaches `threshold`.
    struct LoopNode {
        threshold: usize,
    }

    #[async_trait]
    impl Node<CounterState> for LoopNode {
        async fn process(
            &self,
            state: CounterState,
        ) -> Result<NodeOutput<CounterState>, SynapticError> {
            // Both branches return a command carrying only `delta`, so the
            // incoming state is read but never modified or returned.
            let next_count = state.counter + 1;
            let delta = CounterState {
                counter: 1,
                visited: vec!["loop".to_string()],
            };

            let command = if next_count < self.threshold {
                Command::goto_with_update("a", delta)
            } else {
                Command::goto_with_update(END, delta)
            };
            Ok(NodeOutput::Command(command))
        }
    }

    let graph = StateGraph::new()
        .add_node("a", IncrementNode { name: "a".into() })
        .add_node("loop", LoopNode { threshold: 4 })
        .add_edge("a", "loop")
        .add_edge("loop", END)
        .set_entry_point("a")
        .compile()
        .unwrap();

    let result = graph
        .invoke(CounterState::default())
        .await
        .unwrap()
        .into_state();

    // Expected trace: a (1) -> loop (2, back to a) -> a (3) -> loop (4, END).
    assert_eq!(result.counter, 4);
    assert_eq!(result.visited, vec!["a", "loop", "a", "loop"]);
}
/// An end command from `a` finishes the run immediately even though a
/// static `a -> b` edge exists.
#[tokio::test]
async fn end_command_stops_execution() {
    let first = CommandNode {
        name: String::from("a"),
        command: Some(CommandKind::End),
    };
    let second = CommandNode {
        name: String::from("b"),
        command: None,
    };

    let compiled = StateGraph::new()
        .add_node("a", first)
        .add_node("b", second)
        .add_edge("a", "b")
        .add_edge("b", END)
        .set_entry_point("a")
        .compile()
        .unwrap();

    let outcome = compiled.invoke(CounterState::default()).await.unwrap();
    let final_state = outcome.into_state();

    // Only `a` contributed before the run ended.
    assert_eq!(final_state.visited, vec!["a"]);
    assert_eq!(final_state.counter, 1);
}
/// Without any commands, execution simply follows the static edges
/// `a -> b -> c -> END`.
#[tokio::test]
async fn no_command_preserves_normal_routing() {
    let step = |name: &str| IncrementNode {
        name: name.to_string(),
    };

    let compiled = StateGraph::new()
        .add_node("a", step("a"))
        .add_node("b", step("b"))
        .add_node("c", step("c"))
        .add_edge("a", "b")
        .add_edge("b", "c")
        .add_edge("c", END)
        .set_entry_point("a")
        .compile()
        .unwrap();

    let outcome = compiled.invoke(CounterState::default()).await.unwrap();
    let final_state = outcome.into_state();

    // All three nodes ran once, in edge order.
    assert_eq!(final_state.visited, vec!["a", "b", "c"]);
    assert_eq!(final_state.counter, 3);
}
#[tokio::test]
async fn goto_command_works_in_stream_mode() {
let graph = StateGraph::new()
.add_node(
"a",
CommandNode {
name: "a".into(),
command: Some(CommandKind::Goto("c".to_string())),
},
)
.add_node(
"b",
CommandNode {
name: "b".into(),
command: None,
},
)
.add_node(
"c",
CommandNode {
name: "c".into(),
command: None,
},
)
.add_edge("a", "b")
.add_edge("b", "c")
.add_edge("c", END)
.set_entry_point("a")
.compile()
.unwrap();
let events: Vec<_> = graph
.stream(CounterState::default(), StreamMode::Values)
.collect::<Vec<_>>()
.await
.into_iter()
.collect::<Result<Vec<_>, _>>()
.unwrap();
assert_eq!(events.len(), 2);
assert_eq!(events[0].node, "a");
assert_eq!(events[1].node, "c");
assert_eq!(events[1].state.visited, vec!["a", "c"]);
}
/// Streaming with `StreamMode::Values` honours an end command: a single
/// event for `a` is produced and `b` never runs.
#[tokio::test]
async fn end_command_works_in_stream_mode() {
    let first = CommandNode {
        name: String::from("a"),
        command: Some(CommandKind::End),
    };
    let second = CommandNode {
        name: String::from("b"),
        command: None,
    };

    let compiled = StateGraph::new()
        .add_node("a", first)
        .add_node("b", second)
        .add_edge("a", "b")
        .add_edge("b", END)
        .set_entry_point("a")
        .compile()
        .unwrap();

    // Drain the stream first, then fail the test on the first error item.
    let raw_events = compiled
        .stream(CounterState::default(), StreamMode::Values)
        .collect::<Vec<_>>()
        .await;
    let events: Vec<_> = raw_events
        .into_iter()
        .collect::<Result<Vec<_>, _>>()
        .unwrap();

    assert_eq!(events.len(), 1);
    assert_eq!(events[0].node, "a");
    assert_eq!(events[0].state.counter, 1);
}