use langchainrust::{
StateGraph, GraphBuilder, START, END,
AgentState, StateUpdate,
};
/// A single async node runs exactly once and its appended AI message reaches
/// the final state.
#[tokio::test(flavor = "multi_thread")]
async fn test_async_node_basic() {
    let compiled = GraphBuilder::<AgentState>::new()
        .add_async_node("async_process", |state: &AgentState| {
            // Clone before building the future so the `async move` block owns
            // its state instead of borrowing the closure argument.
            let state = state.clone();
            async move {
                let mut new_state = state;
                new_state.add_message(langchainrust::MessageEntry::ai(
                    "Async processed".to_string(),
                ));
                Ok(StateUpdate::full(new_state))
            }
        })
        .add_edge(START, "async_process")
        .add_edge("async_process", END)
        .compile()
        .unwrap();
    let input = AgentState::new("Hello async".to_string());
    let result = compiled.invoke(input).await.unwrap();
    assert_eq!(result.recursion_count, 1);
    // `test_mixed_sync_async_nodes` expects 3 messages after two appends, so the
    // initial state carries one; this node adds exactly one more. An exact count
    // (instead of `> 1`) also catches a node that runs twice or an update that
    // appends rather than replaces.
    assert_eq!(result.final_state.messages.len(), 2);
}
/// Chains three async nodes (`step1 -> step2 -> step3`) and checks that each runs
/// once, in order, with `step3`'s output surviving to the final state.
#[tokio::test(flavor = "multi_thread")]
async fn test_multiple_async_nodes() {
    let graph = GraphBuilder::<AgentState>::new()
        .add_async_node("step1", |current: &AgentState| {
            // The returned future must own its data, so take a private copy first.
            let mut next = current.clone();
            async move {
                next.add_message(langchainrust::MessageEntry::ai(String::from("Step 1")));
                Ok(StateUpdate::full(next))
            }
        })
        .add_async_node("step2", |current: &AgentState| {
            let mut next = current.clone();
            async move {
                next.add_message(langchainrust::MessageEntry::ai(String::from("Step 2")));
                Ok(StateUpdate::full(next))
            }
        })
        .add_async_node("step3", |current: &AgentState| {
            let mut next = current.clone();
            async move {
                next.set_output(String::from("Done"));
                Ok(StateUpdate::full(next))
            }
        })
        .add_edge(START, "step1")
        .add_edge("step1", "step2")
        .add_edge("step2", "step3")
        .add_edge("step3", END)
        .compile()
        .unwrap();

    let outcome = graph
        .invoke(AgentState::new(String::from("Test")))
        .await
        .unwrap();
    let final_state = &outcome.final_state;

    assert_eq!(outcome.recursion_count, 3);
    assert_eq!(final_state.output, Some(String::from("Done")));
    // NOTE(review): only step1 and step2 call `add_message`; confirm which other
    // source (the seeded input and/or `set_output`) accounts for the total of 4.
    assert_eq!(final_state.messages.len(), 4);
}
/// A synchronous node (`add_node_fn`) feeding an async node (`add_async_node`):
/// both execute once and both of their messages reach the final state.
#[tokio::test(flavor = "multi_thread")]
async fn test_mixed_sync_async_nodes() {
    let graph = GraphBuilder::<AgentState>::new()
        .add_node_fn("sync_step", |current: &AgentState| {
            let mut next = current.clone();
            next.add_message(langchainrust::MessageEntry::ai(String::from("Sync")));
            Ok(StateUpdate::full(next))
        })
        .add_async_node("async_step", |current: &AgentState| {
            // Owned copy captured by the `async move` block.
            let mut next = current.clone();
            async move {
                next.add_message(langchainrust::MessageEntry::ai(String::from("Async")));
                Ok(StateUpdate::full(next))
            }
        })
        .add_edge(START, "sync_step")
        .add_edge("sync_step", "async_step")
        .add_edge("async_step", END)
        .compile()
        .unwrap();

    let outcome = graph
        .invoke(AgentState::new(String::from("Mixed")))
        .await
        .unwrap();

    assert_eq!(outcome.recursion_count, 2);
    // The two node messages plus, presumably, one seeded by `AgentState::new`.
    assert_eq!(outcome.final_state.messages.len(), 3);
}
/// An async node that awaits a timer is driven to completion: the graph runs it
/// exactly once, `invoke` waits out the delay, and the node's output is kept.
#[tokio::test(flavor = "multi_thread")]
async fn test_async_node_with_delay() {
    use std::time::{Duration, Instant};

    // Shared by the node and the timing assertion so the two cannot drift apart.
    const DELAY: Duration = Duration::from_millis(50);

    let compiled = GraphBuilder::<AgentState>::new()
        .add_async_node("delayed", |state: &AgentState| {
            let state = state.clone();
            async move {
                tokio::time::sleep(DELAY).await;
                let mut new_state = state;
                new_state.set_output("Delayed result".to_string());
                Ok(StateUpdate::full(new_state))
            }
        })
        .add_edge(START, "delayed")
        .add_edge("delayed", END)
        .compile()
        .unwrap();
    let input = AgentState::new("Test delay".to_string());
    let started = Instant::now();
    let result = compiled.invoke(input).await.unwrap();
    let elapsed = started.elapsed();

    // Consistent with the sibling single-node test: one node, one step.
    assert_eq!(result.recursion_count, 1);
    assert_eq!(result.final_state.output, Some("Delayed result".to_string()));
    // `tokio::time::sleep` does not complete before its deadline, so `invoke`
    // returning sooner would mean it did not wait for the delayed node.
    assert!(
        elapsed >= DELAY,
        "invoke returned after {:?}, before the {:?} node delay",
        elapsed,
        DELAY
    );
}