use langchainrust::{
StateGraph, START, END,
AgentState, StateUpdate,
FunctionRouter,
};
use std::collections::HashMap;
/// An input shorter than 10 bytes is expected to travel `entry` -> `short`,
/// leaving `"short"` as the final output.
#[tokio::test]
async fn test_conditional_routing_short_path() {
    let mut workflow = StateGraph::<AgentState>::new();

    // `entry` forwards the state untouched; each branch node stamps its own label.
    workflow.add_node_fn("entry", |state| Ok(StateUpdate::full(state.clone())));
    workflow.add_node_fn("short", |state| {
        let mut updated = state.clone();
        updated.set_output("short".to_string());
        Ok(StateUpdate::full(updated))
    });
    workflow.add_node_fn("long", |state| {
        let mut updated = state.clone();
        updated.set_output("long".to_string());
        Ok(StateUpdate::full(updated))
    });

    // The router picks a route key purely from the byte length of the input.
    let by_length = FunctionRouter::new(|state: &AgentState| {
        let label = if state.input.len() < 10 { "short" } else { "long" };
        label.to_string()
    });

    workflow.add_edge(START, "entry");

    // Route keys and node names coincide, so each key maps onto itself.
    let routes: HashMap<String, String> = ["short", "long"]
        .iter()
        .map(|&name| (name.to_string(), name.to_string()))
        .collect();
    workflow.add_conditional_edges("entry", "router", routes, None);
    workflow.add_edge("short", END);
    workflow.add_edge("long", END);
    workflow.set_conditional_router("router", by_length);

    let runnable = workflow.compile().unwrap();

    // "hi" is 2 bytes long, so the router yields the "short" key.
    let outcome = runnable
        .invoke(AgentState::new("hi".to_string()))
        .await
        .unwrap();
    let expected = Some("short".to_string());
    assert_eq!(outcome.final_state.output, expected);
}
/// An input of 10 bytes or more is expected to travel `entry` -> `long`,
/// leaving `"long"` as the final output.
#[tokio::test]
async fn test_conditional_routing_long_path() {
    let mut workflow = StateGraph::<AgentState>::new();

    // `entry` forwards the state untouched; each branch node stamps its own label.
    workflow.add_node_fn("entry", |state| Ok(StateUpdate::full(state.clone())));
    workflow.add_node_fn("short", |state| {
        let mut tagged = state.clone();
        tagged.set_output("short".to_string());
        Ok(StateUpdate::full(tagged))
    });
    workflow.add_node_fn("long", |state| {
        let mut tagged = state.clone();
        tagged.set_output("long".to_string());
        Ok(StateUpdate::full(tagged))
    });

    // The router picks a route key purely from the byte length of the input.
    let by_length = FunctionRouter::new(|state: &AgentState| {
        let label = match state.input.len() {
            0..=9 => "short",
            _ => "long",
        };
        label.to_string()
    });

    workflow.add_edge(START, "entry");

    let mut routes: HashMap<String, String> = HashMap::new();
    routes.insert("short".to_string(), "short".to_string());
    routes.insert("long".to_string(), "long".to_string());
    workflow.add_conditional_edges("entry", "router", routes, None);
    workflow.add_edge("short", END);
    workflow.add_edge("long", END);
    workflow.set_conditional_router("router", by_length);

    let runnable = workflow.compile().unwrap();

    // The input below is 25 bytes long, so the router yields the "long" key.
    let outcome = runnable
        .invoke(AgentState::new("this is a very long input".to_string()))
        .await
        .unwrap();
    let expected = Some("long".to_string());
    assert_eq!(outcome.final_state.output, expected);
}
/// When the router returns a key that is absent from the target map, the test
/// expects execution to continue at the default target passed to
/// `add_conditional_edges`.
#[tokio::test]
async fn test_conditional_routing_with_default() {
    let mut workflow = StateGraph::<AgentState>::new();

    workflow.add_node_fn("entry", |state| Ok(StateUpdate::full(state.clone())));
    workflow.add_node_fn("path_a", |state| {
        let mut updated = state.clone();
        updated.set_output("a".to_string());
        Ok(StateUpdate::full(updated))
    });
    workflow.add_node_fn("path_b", |state| {
        let mut updated = state.clone();
        updated.set_output("b".to_string());
        Ok(StateUpdate::full(updated))
    });
    workflow.add_node_fn("default_path", |state| {
        let mut updated = state.clone();
        updated.set_output("default".to_string());
        Ok(StateUpdate::full(updated))
    });

    // Route key is derived from the input prefix; anything else becomes
    // "unknown", which deliberately has no entry in the target map.
    let by_prefix = FunctionRouter::new(|state: &AgentState| {
        let key = match state.input.as_str() {
            text if text.starts_with("a") => "a",
            text if text.starts_with("b") => "b",
            _ => "unknown",
        };
        key.to_string()
    });

    workflow.add_edge(START, "entry");

    let mut routes: HashMap<String, String> = HashMap::new();
    routes.insert("a".to_string(), "path_a".to_string());
    routes.insert("b".to_string(), "path_b".to_string());
    let fallback = Some("default_path".to_string());
    workflow.add_conditional_edges("entry", "router", routes, fallback);
    workflow.add_edge("path_a", END);
    workflow.add_edge("path_b", END);
    workflow.add_edge("default_path", END);
    workflow.set_conditional_router("router", by_prefix);

    let runnable = workflow.compile().unwrap();

    // "xyz" starts with neither "a" nor "b", so the router returns "unknown".
    let outcome = runnable
        .invoke(AgentState::new("xyz".to_string()))
        .await
        .unwrap();
    let expected = Some("default".to_string());
    assert_eq!(outcome.final_state.output, expected);
}
/// A single compiled graph with three conditional targets is invoked once per
/// branch; the router simply echoes the input as the route key.
#[tokio::test]
async fn test_multiple_conditional_branches() {
    let mut workflow = StateGraph::<AgentState>::new();

    workflow.add_node_fn("decide", |state| Ok(StateUpdate::full(state.clone())));
    workflow.add_node_fn("alpha", |state| {
        let mut updated = state.clone();
        updated.set_output("alpha_result".to_string());
        Ok(StateUpdate::full(updated))
    });
    workflow.add_node_fn("beta", |state| {
        let mut updated = state.clone();
        updated.set_output("beta_result".to_string());
        Ok(StateUpdate::full(updated))
    });
    workflow.add_node_fn("gamma", |state| {
        let mut updated = state.clone();
        updated.set_output("gamma_result".to_string());
        Ok(StateUpdate::full(updated))
    });

    // The input string itself is used as the route key.
    let echo_input = FunctionRouter::new(|state: &AgentState| state.input.clone());

    workflow.add_edge(START, "decide");

    // Route keys and node names coincide, so each key maps onto itself.
    let routes: HashMap<String, String> = ["alpha", "beta", "gamma"]
        .iter()
        .map(|&name| (name.to_string(), name.to_string()))
        .collect();
    workflow.add_conditional_edges("decide", "branch_router", routes, None);
    workflow.add_edge("alpha", END);
    workflow.add_edge("beta", END);
    workflow.add_edge("gamma", END);
    workflow.set_conditional_router("branch_router", echo_input);

    let runnable = workflow.compile().unwrap();

    // Same compiled graph, one invocation per branch, in the original order.
    let cases = [
        ("alpha", "alpha_result"),
        ("beta", "beta_result"),
        ("gamma", "gamma_result"),
    ];
    for (input, expected) in cases {
        let outcome = runnable
            .invoke(AgentState::new(input.to_string()))
            .await
            .unwrap();
        assert_eq!(outcome.final_state.output, Some(expected.to_string()));
    }
}
/// Conditional edges here name the router `"missing_router"`, but no router is
/// ever registered on the graph; the test expects `compile()` to return an error.
#[test]
fn test_conditional_edge_validation_requires_router() {
    let mut workflow = StateGraph::<AgentState>::new();

    workflow.add_node_fn("entry", |state| Ok(StateUpdate::full(state.clone())));
    workflow.add_node_fn("target", |state| Ok(StateUpdate::full(state.clone())));
    workflow.add_edge(START, "entry");

    let mut routes: HashMap<String, String> = HashMap::new();
    routes.insert("route".to_string(), "target".to_string());

    // Note: `set_conditional_router` is intentionally never called.
    workflow.add_conditional_edges("entry", "missing_router", routes, None);
    workflow.add_edge("target", END);

    assert!(workflow.compile().is_err());
}