use langchainrust::{
StateGraph, START, END,
AgentState, StateUpdate, MessageEntry,
FunctionRouter,
};
use std::collections::HashMap;
/// End-to-end question-answering graph with one conditional branch.
///
/// Flow: `receive_question -> analyze_complexity -> {simple_answer | complex_answer}
/// -> finalize -> END`. The branch is picked by `complexity_router`, which uses the
/// same `is_complex` predicate that `analyze_complexity` reports on, so the logged
/// classification and the route actually taken cannot disagree.
#[tokio::test]
async fn example_smart_qa_pipeline() {
    // Questions with more than this many characters take the detailed branch.
    // Counted in `char`s, not bytes: the inputs are Chinese, and every CJK
    // character is 3 bytes in UTF-8, so `str::len` would inflate their length ~3x
    // compared with ASCII text of the same visible length.
    const MAX_SIMPLE_CHARS: usize = 10;

    // Single source of truth for the complexity heuristic (node + router).
    fn is_complex(input: &str) -> bool {
        input.chars().count() > MAX_SIMPLE_CHARS
    }

    let mut graph: StateGraph<AgentState> = StateGraph::new();

    graph.add_node_fn("receive_question", |state: &AgentState| {
        let mut new_state = state.clone();
        new_state.add_message(MessageEntry::ai(format!("收到问题: {}", state.input)));
        Ok(StateUpdate::full(new_state))
    });
    graph.add_node_fn("analyze_complexity", |state: &AgentState| {
        let kind = if is_complex(&state.input) { "复杂" } else { "简单" };
        let mut new_state = state.clone();
        new_state.add_message(MessageEntry::ai(format!("分析完成 - 问题类型: {}", kind)));
        Ok(StateUpdate::full(new_state))
    });
    graph.add_node_fn("simple_answer", |state: &AgentState| {
        let mut new_state = state.clone();
        let answer = format!("快速回答: {}", state.input);
        new_state.set_output(answer);
        Ok(StateUpdate::full(new_state))
    });
    graph.add_node_fn("complex_answer", |state: &AgentState| {
        let mut new_state = state.clone();
        new_state.add_message(MessageEntry::ai("正在深度分析...".to_string()));
        let answer = format!("详细分析: {}", state.input);
        new_state.set_output(answer);
        Ok(StateUpdate::full(new_state))
    });
    graph.add_node_fn("finalize", |state: &AgentState| {
        let mut new_state = state.clone();
        // The count is taken from the incoming state, i.e. before this node's own
        // message is appended.
        new_state.add_message(MessageEntry::ai(format!(
            "处理完成,共 {} 条消息",
            state.messages.len()
        )));
        Ok(StateUpdate::full(new_state))
    });

    graph.add_edge(START, "receive_question");
    graph.add_edge("receive_question", "analyze_complexity");
    let targets = HashMap::from([
        ("simple".to_string(), "simple_answer".to_string()),
        ("complex".to_string(), "complex_answer".to_string()),
    ]);
    graph.add_conditional_edges("analyze_complexity", "complexity_router", targets, None);
    graph.add_edge("simple_answer", "finalize");
    graph.add_edge("complex_answer", "finalize");
    graph.add_edge("finalize", END);

    let router = FunctionRouter::new(|state: &AgentState| {
        let route = if is_complex(&state.input) { "complex" } else { "simple" };
        route.to_string()
    });
    graph.set_conditional_router("complexity_router", router);

    let compiled = graph.compile().unwrap();

    // 9 characters -> simple branch.
    let simple_question = "什么是 Rust?";
    let simple_result = compiled
        .invoke(AgentState::new(simple_question.to_string()))
        .await
        .unwrap();
    // receive_question, analyze_complexity, simple_answer, finalize.
    assert_eq!(simple_result.recursion_count, 4);
    let output = simple_result.final_state.output.as_ref().unwrap();
    assert_eq!(output, &format!("快速回答: {}", simple_question));
    println!("=== 简单问题 ===");
    println!("输出: {}", output);
    println!("消息数: {}", simple_result.final_state.messages.len());

    // 18 characters -> complex branch.
    let complex_question = "请详细解释所有权系统和生命周期概念?";
    let complex_result = compiled
        .invoke(AgentState::new(complex_question.to_string()))
        .await
        .unwrap();
    // receive_question, analyze_complexity, complex_answer, finalize.
    assert_eq!(complex_result.recursion_count, 4);
    let output2 = complex_result.final_state.output.as_ref().unwrap();
    assert_eq!(output2, &format!("详细分析: {}", complex_question));
    println!("=== 复杂问题 ===");
    println!("输出: {}", output2);
}
/// Keyword-driven routing example: `classify_input` hands off to a math, text or
/// code handler, and any input matching none of the keywords falls through to
/// `default_handler` via the conditional edge's default target.
#[tokio::test]
async fn example_multi_branch_routing() {
    // Ordered keyword table consulted by `type_router`: the first row containing a
    // keyword found in the input wins. No match produces "unknown", which is not a
    // key in the routing targets, so the default target is used.
    const ROUTES: [(&[&str], &str); 3] = [
        (&["计算", "数学"], "math"),
        (&["翻译", "写作"], "text"),
        (&["代码", "编程"], "code"),
    ];

    let mut graph: StateGraph<AgentState> = StateGraph::new();

    // Pass-through node; the actual classification happens on its outgoing edge.
    graph.add_node_fn("classify_input", |state: &AgentState| {
        Ok(StateUpdate::full(state.clone()))
    });
    graph.add_node_fn("math_handler", |state: &AgentState| {
        let mut updated = state.clone();
        updated.set_output(format!("数学解答: {}", state.input));
        Ok(StateUpdate::full(updated))
    });
    graph.add_node_fn("text_handler", |state: &AgentState| {
        let mut updated = state.clone();
        updated.set_output(format!("文本解答: {}", state.input));
        Ok(StateUpdate::full(updated))
    });
    graph.add_node_fn("code_handler", |state: &AgentState| {
        let mut updated = state.clone();
        updated.set_output(format!("代码解答: {}", state.input));
        Ok(StateUpdate::full(updated))
    });
    graph.add_node_fn("default_handler", |state: &AgentState| {
        let mut updated = state.clone();
        updated.set_output(format!("通用解答: {}", state.input));
        Ok(StateUpdate::full(updated))
    });

    graph.add_edge(START, "classify_input");
    let branch_targets = HashMap::from([
        ("math".to_string(), "math_handler".to_string()),
        ("text".to_string(), "text_handler".to_string()),
        ("code".to_string(), "code_handler".to_string()),
    ]);
    graph.add_conditional_edges(
        "classify_input",
        "type_router",
        branch_targets,
        Some("default_handler".to_string()),
    );
    for handler in ["math_handler", "text_handler", "code_handler", "default_handler"] {
        graph.add_edge(handler, END);
    }

    let type_router = FunctionRouter::new(|state: &AgentState| {
        let lowered = state.input.to_lowercase();
        ROUTES
            .iter()
            .find(|(keywords, _)| keywords.iter().any(|kw| lowered.contains(*kw)))
            .map_or("unknown", |&(_, route)| route)
            .to_string()
    });
    graph.set_conditional_router("type_router", type_router);

    let compiled = graph.compile().unwrap();

    let math_result = compiled.invoke(AgentState::new("帮我计算 1+1".to_string())).await.unwrap();
    let math_out = math_result.final_state.output.as_ref().unwrap();
    assert!(math_out.contains("数学解答"));

    let text_result = compiled.invoke(AgentState::new("请翻译这段文字".to_string())).await.unwrap();
    let text_out = text_result.final_state.output.as_ref().unwrap();
    assert!(text_out.contains("文本解答"));

    let code_result = compiled.invoke(AgentState::new("写一段排序代码".to_string())).await.unwrap();
    let code_out = code_result.final_state.output.as_ref().unwrap();
    assert!(code_out.contains("代码解答"));

    let default_result = compiled.invoke(AgentState::new("今天天气怎么样".to_string())).await.unwrap();
    let default_out = default_result.final_state.output.as_ref().unwrap();
    assert!(default_out.contains("通用解答"));

    println!("=== 三分支路由 ===");
    println!("数学 -> {}", math_out);
    println!("文本 -> {}", text_out);
    println!("代码 -> {}", code_out);
    println!("未知 -> {}", default_out);
}
/// Bounded-loop example: `increment_counter -> check_limit` repeats until the
/// counter carried in `state.input` reaches `LIMIT`, then `finish` writes the
/// final count to the output.
#[tokio::test]
async fn example_loop_with_limit() {
    // Counter value at which `limit_router` leaves the loop.
    const LIMIT: i32 = 5;

    // The counter is stored as a decimal string in `state.input`; input that does
    // not parse is treated as 0. Shared by the increment node and the router so
    // both read the counter the same way.
    fn counter(state: &AgentState) -> i32 {
        state.input.parse().unwrap_or(0)
    }

    let mut graph: StateGraph<AgentState> = StateGraph::new();

    graph.add_node_fn("increment_counter", |state: &AgentState| {
        let mut new_state = state.clone();
        new_state.input = (counter(state) + 1).to_string();
        Ok(StateUpdate::full(new_state))
    });
    // Pass-through node; it exists only as the source of the conditional edge.
    graph.add_node_fn("check_limit", |state: &AgentState| {
        Ok(StateUpdate::full(state.clone()))
    });
    graph.add_node_fn("finish", |state: &AgentState| {
        let mut s = state.clone();
        s.set_output(format!("最终计数: {}", state.input));
        Ok(StateUpdate::full(s))
    });

    graph.add_edge(START, "increment_counter");
    graph.add_edge("increment_counter", "check_limit");
    let targets = HashMap::from([
        ("continue".to_string(), "increment_counter".to_string()),
        ("done".to_string(), "finish".to_string()),
    ]);
    graph.add_conditional_edges("check_limit", "limit_router", targets, None);
    graph.add_edge("finish", END);

    let router = FunctionRouter::new(|state: &AgentState| {
        let route = if counter(state) >= LIMIT { "done" } else { "continue" };
        route.to_string()
    });
    graph.set_conditional_router("limit_router", router);

    // Starting from 0, the loop runs increment_counter and check_limit LIMIT times
    // each, plus finish once. 20 leaves headroom, assuming the limit counts node
    // executions (TODO confirm against the StateGraph docs).
    let compiled = graph.compile().unwrap().with_recursion_limit(20);
    let result = compiled.invoke(AgentState::new("0".to_string())).await.unwrap();

    // Exact match: a substring check would also accept e.g. "最终计数: 50".
    let expected = format!("最终计数: {}", LIMIT);
    let output = result.final_state.output.as_deref();
    assert_eq!(output, Some(expected.as_str()));

    println!("=== 循环处理 ===");
    println!("最终输出: {}", output.unwrap());
    println!("执行节点数: {}", result.recursion_count);
}