use futures::StreamExt;
use std::time::Duration;
use tokio::time::sleep;
use serde::{Deserialize, Serialize};
use wesichain_core::{HasFinalOutput, HasUserInput, ReActStep, ScratchpadState, WesichainError};
use wesichain_graph::{
GraphBuilder, GraphContext, GraphEvent, GraphNode, GraphState, StateSchema, StateUpdate, END,
};
/// Minimal graph state used by the streaming test in this file.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
struct TestState {
    /// Strings contributed by node updates; `StateSchema::apply` appends
    /// incoming values to this list rather than replacing it.
    value: Vec<String>,
    /// Steps exposed through the `ScratchpadState` impl. Excluded from
    /// (de)serialization via `#[serde(skip)]`.
    #[serde(skip)]
    scratchpad: Vec<ReActStep>,
}
impl StateSchema for TestState {
    type Update = Self;

    /// Merges `update` into `current` and returns the resulting state.
    ///
    /// `value` becomes the current entries followed by the update's entries.
    /// `scratchpad` is copied from `current`; whatever scratchpad the update
    /// carries is discarded.
    fn apply(current: &Self, update: Self::Update) -> Self {
        let merged: Vec<String> = current
            .value
            .iter()
            .cloned()
            .chain(update.value)
            .collect();

        Self {
            value: merged,
            scratchpad: current.scratchpad.clone(),
        }
    }
}
impl HasUserInput for TestState {
    /// Stub: this test state carries no user input, so the result is always
    /// the empty string.
    fn user_input(&self) -> &str {
        const NO_INPUT: &str = "";
        NO_INPUT
    }
}
impl HasFinalOutput for TestState {
    /// Stub: always `None`, because this state never stores a final output.
    fn final_output(&self) -> Option<&str> {
        None
    }

    /// Stub: intentionally a no-op; the supplied output is dropped.
    fn set_final_output(&mut self, _output: String) {}
}
impl ScratchpadState for TestState {
    /// Shared view of the stored scratchpad steps.
    fn scratchpad(&self) -> &Vec<ReActStep> {
        let Self { scratchpad, .. } = self;
        scratchpad
    }

    /// Mutable view of the stored scratchpad steps.
    fn scratchpad_mut(&mut self) -> &mut Vec<ReActStep> {
        let Self { scratchpad, .. } = self;
        scratchpad
    }

    /// Stub: iterations are not tracked, so the count is always zero.
    fn iteration_count(&self) -> u32 {
        0
    }

    /// Stub: intentionally a no-op (see `iteration_count`).
    fn increment_iteration(&mut self) {}
}
/// Test node that sleeps for a fixed time and then reports its own name.
struct SleepNode {
    /// Name emitted as the single `value` entry of the node's state update.
    name: String,
    /// Sleep time in milliseconds before the update is produced.
    delay: u64,
}
#[async_trait::async_trait]
impl GraphNode<TestState> for SleepNode {
async fn invoke_with_context(
&self,
_: GraphState<TestState>,
_: &GraphContext,
) -> Result<StateUpdate<TestState>, WesichainError> {
sleep(Duration::from_millis(self.delay)).await;
Ok(StateUpdate::new(TestState {
value: vec![self.name.clone()],
..Default::default()
}))
}
}
/// Streams a graph whose entry node routes to two 200 ms sleeper nodes and
/// checks that the whole run finishes in under 350 ms (i.e. the sleepers did
/// not run back to back) and that a `NodeEnter` event was seen for both.
#[tokio::test]
async fn test_streaming_parallel() {
    /// Entry node: contributes an empty (default) update and nothing else.
    struct StartNode;

    #[async_trait::async_trait]
    impl GraphNode<TestState> for StartNode {
        async fn invoke_with_context(
            &self,
            _state: GraphState<TestState>,
            _ctx: &GraphContext,
        ) -> Result<StateUpdate<TestState>, WesichainError> {
            Ok(StateUpdate::new(TestState::default()))
        }
    }

    // Builder calls are kept in the original order: A, B, entry, edges, and
    // only then the registration of the entry node itself.
    let graph = GraphBuilder::<TestState>::new()
        .add_node(
            "A",
            SleepNode {
                name: String::from("A"),
                delay: 200,
            },
        )
        .add_node(
            "B",
            SleepNode {
                name: String::from("B"),
                delay: 200,
            },
        )
        .set_entry("start_node")
        // The routing closure always names both sleeper nodes.
        .add_conditional_edge("start_node", |_| {
            vec![String::from("A"), String::from("B")]
        })
        .add_edge("A", END)
        .add_edge("B", END)
        .add_node("start_node", StartNode)
        .build();

    let initial = GraphState::new(TestState::default());
    let mut stream = graph.stream_invoke(initial);

    // Drain the stream, logging every event and failing on the first error.
    let mut events = Vec::new();
    let started_at = std::time::Instant::now();
    while let Some(item) = stream.next().await {
        let event = item.expect("Stream error");
        println!("{:?}", event);
        events.push(event);
    }
    let elapsed = started_at.elapsed();
    println!("Total duration: {:?}", elapsed);

    // Two sequential 200 ms sleeps would take at least 400 ms.
    assert!(elapsed < Duration::from_millis(350));

    // Both sleeper nodes must have produced a NodeEnter event.
    let entered = |wanted: &str| {
        events
            .iter()
            .any(|e| matches!(e, GraphEvent::NodeEnter { node, .. } if node == wanted))
    };
    assert!(entered("A"));
    assert!(entered("B"));

    // The timestamp is only checked when the very first event is a NodeEnter.
    // `events` is known to be non-empty here because the assertions above
    // would already have failed on an empty list.
    if let Some(GraphEvent::NodeEnter { timestamp, .. }) = events.first() {
        assert!(*timestamp > 0);
    }
}