mod llm_integration;
mod async_node;
mod visualize;
mod human_loop;
mod validation;
mod parallel;
use langchainrust::{
StateGraph, GraphBuilder, START, END,
AgentState, StateUpdate,
ThreadSafeMemoryCheckpointer, Checkpointer,
};
use std::collections::HashMap;
/// Runs a two-node linear graph (`START -> node1 -> node2 -> END`) and checks
/// that `node2` observes the state produced by `node1`.
///
/// NOTE(review): the expected values assume `StateUpdate::full` replaces the
/// whole state and that `output` is an `Option<String>` — confirm against the
/// library if an assertion fails.
#[tokio::test]
async fn test_linear_graph_execution() {
    let compiled = GraphBuilder::<AgentState>::new()
        .add_node_fn("node1", |state: &AgentState| {
            Ok(StateUpdate::full(AgentState::new(format!("step1: {}", state.input))))
        })
        .add_node_fn("node2", |state: &AgentState| {
            let mut s = state.clone();
            s.set_output(format!("step2: {}", state.input));
            Ok(StateUpdate::full(s))
        })
        .add_edge(START, "node1")
        .add_edge("node1", "node2")
        .add_edge("node2", END)
        .compile()
        .unwrap();

    let result = compiled.invoke(AgentState::new("test".to_string())).await.unwrap();
    let final_state = &result.final_state;

    // Diagnostic output, visible with `cargo test -- --nocapture`.
    println!("=== test_linear_graph_execution ===");
    println!("final_state.input: {}", final_state.input);
    println!("final_state.output: {:?}", final_state.output);
    println!("recursion_count: {}", result.recursion_count);

    // node1 rewrites the input; node2 keeps it and derives the output from it.
    assert_eq!(final_state.input, "step1: test");
    assert_eq!(final_state.output.as_deref(), Some("step2: step1: test"));
}
/// Runs the smallest useful graph (`START -> single -> END`) and checks that
/// the single node's output lands in the final state.
///
/// NOTE(review): assumes `StateUpdate::full` replaces the whole state and that
/// `output` is an `Option<String>` — confirm against the library.
#[tokio::test]
async fn test_single_node_graph() {
    let compiled = GraphBuilder::<AgentState>::new()
        .add_node_fn("single", |state: &AgentState| {
            let mut s = state.clone();
            s.set_output("done".to_string());
            Ok(StateUpdate::full(s))
        })
        .add_edge(START, "single")
        .add_edge("single", END)
        .compile()
        .unwrap();

    let result = compiled.invoke(AgentState::new("input".to_string())).await.unwrap();
    let final_state = &result.final_state;

    // Diagnostic output, visible with `cargo test -- --nocapture`.
    println!("=== test_single_node_graph ===");
    println!("final_state.input: {}", final_state.input);
    println!("final_state.output: {:?}", final_state.output);
    println!("recursion_count: {}", result.recursion_count);

    // The node clones the incoming state, so the input must be untouched.
    assert_eq!(final_state.input, "input");
    assert_eq!(final_state.output.as_deref(), Some("done"));
}
/// Runs a three-node chain (`first -> second -> third`) and checks that each
/// node's transformation is applied on top of the previous one.
///
/// NOTE(review): assumes `StateUpdate::full` replaces the whole state and that
/// `output` is an `Option<String>` — confirm against the library.
#[tokio::test]
async fn test_three_node_chain() {
    let compiled = GraphBuilder::<AgentState>::new()
        .add_node_fn("first", |state| {
            Ok(StateUpdate::full(AgentState::new(format!("1:{}", state.input))))
        })
        .add_node_fn("second", |state| {
            Ok(StateUpdate::full(AgentState::new(format!("2:{}", state.input))))
        })
        .add_node_fn("third", |state| {
            let mut s = state.clone();
            s.set_output(format!("final:{}", state.input));
            Ok(StateUpdate::full(s))
        })
        .add_edge(START, "first")
        .add_edge("first", "second")
        .add_edge("second", "third")
        .add_edge("third", END)
        .compile()
        .unwrap();

    let result = compiled.invoke(AgentState::new("input".to_string())).await.unwrap();
    let final_state = &result.final_state;

    // Diagnostic output, visible with `cargo test -- --nocapture`.
    println!("=== test_three_node_chain ===");
    println!("final_state.input: {}", final_state.input);
    println!("final_state.output: {:?}", final_state.output);
    println!("recursion_count: {}", result.recursion_count);

    // "input" -> "1:input" -> "2:1:input"; the third node only sets the output.
    assert_eq!(final_state.input, "2:1:input");
    assert_eq!(final_state.output.as_deref(), Some("final:2:1:input"));
}
/// Streams a two-node pass-through graph and checks that at least one event
/// is reported.
///
/// The exact number and shape of events is not pinned here because it is a
/// library detail not visible from this test; the events are printed for
/// inspection instead.
#[tokio::test]
async fn test_stream_execution_returns_events() {
    let compiled = GraphBuilder::<AgentState>::new()
        .add_node_fn("a", |state| Ok(StateUpdate::full(state.clone())))
        .add_node_fn("b", |state| Ok(StateUpdate::full(state.clone())))
        .add_edge(START, "a")
        .add_edge("a", "b")
        .add_edge("b", END)
        .compile()
        .unwrap();

    let events = compiled.stream(AgentState::new("test".to_string())).await.unwrap();

    // Diagnostic output, visible with `cargo test -- --nocapture`.
    println!("=== test_stream_execution_returns_events ===");
    println!("events count: {}", events.len());
    for (i, event) in events.iter().enumerate() {
        println!("event[{}]: {:?}", i, event);
    }

    assert!(
        !events.is_empty(),
        "streaming a graph with two nodes should yield at least one event"
    );
}
/// Attaches an in-memory checkpointer to a compiled graph and checks that
/// execution still produces the expected final state.
///
/// NOTE(review): assumes `StateUpdate::full` replaces the whole state and that
/// `output` is an `Option<String>` — confirm against the library.
#[tokio::test]
async fn test_checkpointer_integration() {
    let checkpointer = ThreadSafeMemoryCheckpointer::<AgentState>::new();

    let compiled = GraphBuilder::<AgentState>::new()
        .add_node_fn("process", |state| {
            let mut s = state.clone();
            s.set_output("processed".to_string());
            Ok(StateUpdate::full(s))
        })
        .add_edge(START, "process")
        .add_edge("process", END)
        .compile()
        .unwrap()
        .with_checkpointer(checkpointer);

    let result = compiled.invoke(AgentState::new("test".to_string())).await.unwrap();
    let final_state = &result.final_state;

    // Diagnostic output, visible with `cargo test -- --nocapture`.
    println!("=== test_checkpointer_integration ===");
    println!("final_state.input: {}", final_state.input);
    println!("final_state.output: {:?}", final_state.output);

    // The checkpointer must not alter what the graph computes.
    assert_eq!(final_state.input, "test");
    assert_eq!(final_state.output.as_deref(), Some("processed"));
}
/// Exercises the checkpointer directly: save two states, list them, load the
/// first one back, delete the second, and list again.
///
/// NOTE(review): the assertions assume each `save` yields a distinct id and
/// that `list` reports exactly the checkpoints currently stored — confirm
/// against the `Checkpointer` contract.
#[tokio::test]
async fn test_checkpointer_save_and_load() {
    let checkpointer = ThreadSafeMemoryCheckpointer::<AgentState>::new();

    let state1 = AgentState::new("first".to_string());
    let id1 = checkpointer.save(&state1).await.unwrap();
    let state2 = AgentState::new("second".to_string());
    let id2 = checkpointer.save(&state2).await.unwrap();

    let list = checkpointer.list().await.unwrap();

    // Diagnostic output, visible with `cargo test -- --nocapture`.
    println!("=== test_checkpointer_save_and_load ===");
    println!("checkpoint id1: {}", id1);
    println!("checkpoint id2: {}", id2);
    println!("list count: {}", list.len());

    assert_ne!(id1, id2, "each saved checkpoint should get its own id");
    assert_eq!(list.len(), 2);

    let loaded = checkpointer.load(&id1).await.unwrap();
    println!("loaded state input: {}", loaded.input);
    assert_eq!(loaded.input, "first");

    checkpointer.delete(&id2).await.unwrap();
    let list_after_delete = checkpointer.list().await.unwrap();
    println!("list count after delete: {}", list_after_delete.len());
    assert_eq!(list_after_delete.len(), 1);
}
/// Builds a minimal graph (`START -> n1 -> END`) through `GraphBuilder::build`
/// and checks that the resulting graph compiles.
#[test]
fn test_graph_builder_creates_valid_graph() {
    let graph = GraphBuilder::<AgentState>::new()
        .add_node_fn("n1", |state| Ok(StateUpdate::full(state.clone())))
        .add_edge(START, "n1")
        .add_edge("n1", END)
        .build();

    let result = graph.compile();

    // Diagnostic output, visible with `cargo test -- --nocapture`.
    println!("=== test_graph_builder_creates_valid_graph ===");
    println!("compile result: {:?}", result.is_ok());

    assert!(
        result.is_ok(),
        "a graph with one node wired START -> n1 -> END should compile"
    );
}
/// Checks that compiling a graph with no nodes or edges is rejected.
#[test]
fn test_empty_graph_fails_compile() {
    let graph: StateGraph<AgentState> = StateGraph::new();

    let result = graph.compile();

    // Diagnostic output, visible with `cargo test -- --nocapture`.
    println!("=== test_empty_graph_fails_compile ===");
    println!("compile result (should be err): {:?}", result.is_err());

    assert!(result.is_err(), "compiling an empty graph should fail");
}
/// Checks that compilation is rejected when the entry point names a node that
/// was never added to the graph.
#[test]
fn test_invalid_entry_point_fails() {
    let mut graph: StateGraph<AgentState> = StateGraph::new();
    graph.add_node_fn("node1", |state| Ok(StateUpdate::full(state.clone())));
    graph.set_entry_point("nonexistent");

    let result = graph.compile();

    // Diagnostic output, visible with `cargo test -- --nocapture`.
    println!("=== test_invalid_entry_point_fails ===");
    println!("compile result (should be err): {:?}", result.is_err());

    assert!(
        result.is_err(),
        "an entry point that refers to an unknown node should fail compilation"
    );
}
/// Compiles a graph in which node `a` has no outgoing edge and node `b` is not
/// connected to anything, and checks the compile outcome.
///
/// NOTE(review): the test name says "fails", but the expectation recorded in
/// the output message below is that compilation succeeds. The assertion
/// follows that message; confirm which behavior is actually intended and
/// rename the test or flip the assertion accordingly.
#[test]
fn test_missing_edge_fails_validation() {
    let mut graph: StateGraph<AgentState> = StateGraph::new();
    graph.add_node_fn("a", |state| Ok(StateUpdate::full(state.clone())));
    graph.add_node_fn("b", |state| Ok(StateUpdate::full(state.clone())));
    graph.add_edge(START, "a");

    let compiled = graph.compile();

    // Diagnostic output, visible with `cargo test -- --nocapture`.
    println!("=== test_missing_edge_fails_validation ===");
    println!("compile result (should be ok): {:?}", compiled.is_ok());

    assert!(
        compiled.is_ok(),
        "a graph with a dangling node is expected to compile"
    );
}