use langchainrust::{
StateGraph, START, END,
AgentState, StateUpdate,
};
use std::time::Duration;
/// Verifies that a graph whose `source` node fans out to three targets compiles.
///
/// Every node is an identity pass-through; only the graph topology is under test.
#[test]
fn test_fan_out_edge_creation() {
    // Single source of truth for the fan-out targets: used for node registration,
    // the fan-out edge itself, and each target's edge to END.
    const TARGETS: [&str; 3] = ["target1", "target2", "target3"];

    let mut graph: StateGraph<AgentState> = StateGraph::new();
    graph.add_node_fn("source", |state| Ok(StateUpdate::full(state.clone())));
    for target in TARGETS {
        graph.add_node_fn(target, |state| Ok(StateUpdate::full(state.clone())));
    }

    graph.add_edge(START, "source");
    graph.add_fan_out(
        "source",
        TARGETS
            .iter()
            .map(|target| target.to_string())
            .collect::<Vec<String>>(),
    );
    for target in TARGETS {
        graph.add_edge(target, END);
    }

    // `expect` includes the actual compile error in the failure output, which the
    // previous `assert!(result.is_ok(), ..)` discarded.
    let _compiled = graph
        .compile()
        .expect("FanOut edge should compile successfully");
}
/// Builds a fan-in edge directly and checks that it reports both source nodes
/// and the fixed merge target.
#[test]
fn test_fan_in_edge_creation() {
    let edge = langchainrust::GraphEdge::fan_in(
        vec!["source1".to_string(), "source2".to_string()],
        "merge",
    );

    // A single `expect` replaces the `is_some()` check followed by `unwrap()`:
    // one call, and a descriptive panic if this is not a fan-in edge.
    let sources = edge
        .fan_in_sources()
        .expect("fan_in edge should expose its source list");
    assert_eq!(sources.len(), 2);
    assert!(sources.contains(&"source1".to_string()));
    assert!(sources.contains(&"source2".to_string()));

    assert_eq!(edge.fixed_target(), Some("merge"));
}
/// Fans the `start` node out to two branches that each append one AI message
/// tagged with the branch's node name, then checks that the final state holds
/// at least two messages.
#[tokio::test]
async fn test_parallel_execution_flow() {
    const BRANCHES: [&str; 2] = ["parallel1", "parallel2"];

    let mut graph: StateGraph<AgentState> = StateGraph::new();
    graph.add_node_fn("start", |state| Ok(StateUpdate::full(state.clone())));
    for branch in BRANCHES {
        // `move` copies the branch name into the closure so each node tags its own message.
        graph.add_node_fn(branch, move |state| {
            let mut tagged = state.clone();
            tagged.add_message(langchainrust::MessageEntry::ai(branch.to_string()));
            Ok(StateUpdate::full(tagged))
        });
    }

    graph.add_edge(START, "start");
    graph.add_fan_out(
        "start",
        BRANCHES
            .iter()
            .map(|branch| branch.to_string())
            .collect::<Vec<String>>(),
    );
    for branch in BRANCHES {
        graph.add_edge(branch, END);
    }

    let compiled = graph.compile().unwrap();
    let outcome = compiled
        .invoke(AgentState::new("test".to_string()))
        .await
        .unwrap();

    // NOTE(review): if `AgentState::new` seeds one message (test_parallel_state_merge
    // expects >= 3 after two steps, which suggests it does), this bound is met even
    // when only one branch's update survives the merge. Consider asserting on the
    // branch-specific message contents — confirm the MessageEntry API first.
    assert!(outcome.final_state.messages.len() >= 2);
}
/// Runs a single async node that sleeps before setting the output. Checks that
/// `invoke` actually awaits the node (the elapsed time covers the sleep) and
/// that the node's output reaches the final state.
///
/// Despite the name, only one node is involved; no fan-out is exercised here.
#[tokio::test]
async fn test_parallel_execution_timing() {
    // Duration of the node's sleep, and therefore the minimum expected elapsed time.
    const TASK_DELAY: Duration = Duration::from_millis(50);

    let mut graph: StateGraph<AgentState> = StateGraph::new();
    graph.add_async_node("slow_task", |state: &AgentState| {
        // Clone before building the future so it owns its data instead of borrowing `state`.
        let snapshot = state.clone();
        async move {
            tokio::time::sleep(TASK_DELAY).await;
            let mut finished = snapshot;
            finished.set_output("slow_task_done".to_string());
            Ok(StateUpdate::full(finished))
        }
    });
    graph.add_edge(START, "slow_task");
    graph.add_edge("slow_task", END);

    let compiled = graph.compile().unwrap();

    let clock = std::time::Instant::now();
    let outcome = compiled
        .invoke(AgentState::new("timing_test".to_string()))
        .await
        .unwrap();
    let elapsed = clock.elapsed();

    // Logs the async node's execution time in milliseconds.
    println!("异步节点执行耗时: {}ms", elapsed.as_millis());
    assert!(elapsed >= TASK_DELAY);
    assert!(outcome.final_state.output.is_some());
}
/// Runs two nodes in sequence (`step1` -> `step2`). Checks that the final state
/// carries the output set by `step2` and the messages appended by both steps.
///
/// Despite the name, the graph built here is strictly sequential (no fan-out).
#[tokio::test]
async fn test_parallel_state_merge() {
    let mut graph: StateGraph<AgentState> = StateGraph::new();
    graph.add_node_fn("step1", |state| {
        let mut new_state = state.clone();
        new_state.add_message(langchainrust::MessageEntry::ai("step1_result".to_string()));
        Ok(StateUpdate::full(new_state))
    });
    graph.add_node_fn("step2", |state| {
        let mut new_state = state.clone();
        new_state.add_message(langchainrust::MessageEntry::ai("step2_result".to_string()));
        new_state.set_output("final_result".to_string());
        Ok(StateUpdate::full(new_state))
    });
    graph.add_edge(START, "step1");
    graph.add_edge("step1", "step2");
    graph.add_edge("step2", END);

    let compiled = graph.compile().unwrap();
    let input = AgentState::new("merge_test".to_string());
    let result = compiled.invoke(input).await.unwrap();

    // One `expect` replaces the separate `is_some()` assertion + `unwrap()`.
    let output = result
        .final_state
        .output
        .expect("step2 should have set the final output");
    // Logs the final output.
    println!("最终输出: {}", output);
    // step2 is the last node and sets exactly this value, so pin it rather than
    // only checking that some output exists.
    assert_eq!(output, "final_result");
    // NOTE(review): the `>= 3` bound presumably counts a message seeded by
    // `AgentState::new` plus one per step — confirm against AgentState.
    assert!(result.final_state.messages.len() >= 3);
}