use enact_core::graph::{EdgeTarget, NodeState, StateGraph};
/// Linear two-step pipeline: "uppercase" feeds "add_exclaim", which feeds END.
#[tokio::test]
async fn test_simple_two_node_graph() {
    let pipeline = StateGraph::new()
        .add_node("uppercase", |state: NodeState| async move {
            // Missing/non-text state is treated as the empty string.
            let shouted = state.as_str().unwrap_or_default().to_uppercase();
            Ok(NodeState::from_string(&shouted))
        })
        .add_node("add_exclaim", |state: NodeState| async move {
            let mut excited = state.as_str().unwrap_or_default().to_string();
            excited.push('!');
            Ok(NodeState::from_string(&excited))
        })
        .set_entry_point("uppercase")
        .add_edge("uppercase", "add_exclaim")
        .add_edge_to_end("add_exclaim")
        .compile()
        .expect("Failed to compile graph");

    let output = pipeline
        .run("hello world")
        .await
        .expect("Failed to run graph");
    assert_eq!(output.as_str(), Some("HELLO WORLD!"));
}
/// "classify" labels its input, and a conditional edge routes on that label to one of
/// two terminal handlers.
#[tokio::test]
async fn test_conditional_routing() {
    let router = StateGraph::new()
        .add_node("classify", |state: NodeState| async move {
            // Any input containing "error" is labelled "error"; everything else "success".
            let is_error = state.as_str().unwrap_or_default().contains("error");
            let label = if is_error { "error" } else { "success" };
            Ok(NodeState::from_string(label))
        })
        .add_node("handle_error", |_state: NodeState| async move {
            Ok(NodeState::from_string("Error handled"))
        })
        .add_node("handle_success", |_state: NodeState| async move {
            Ok(NodeState::from_string("Success!"))
        })
        .set_entry_point("classify")
        .add_conditional_edge("classify", |output: &str| {
            if output == "error" {
                EdgeTarget::node("handle_error")
            } else {
                EdgeTarget::node("handle_success")
            }
        })
        .add_edge_to_end("handle_error")
        .add_edge_to_end("handle_success")
        .compile()
        .expect("Failed to compile graph");

    // Exercise the success branch first, then the error branch.
    let cases = [("hello", "Success!"), ("this has an error", "Error handled")];
    for (input, expected) in cases {
        let output = router.run(input).await.expect("Failed to run");
        assert_eq!(output.as_str(), Some(expected));
    }
}
/// Compilation rejects an empty graph but accepts a single pass-through node.
#[tokio::test]
async fn test_graph_validation() {
    // No nodes at all: compile must fail.
    let empty = StateGraph::new().compile();
    assert!(empty.is_err());

    // One node, with no explicit entry point or edges configured here: compile must succeed.
    let single = StateGraph::new()
        .add_node("a", |state: NodeState| async move { Ok(state) })
        .compile();
    assert!(single.is_ok());
}
/// A node with an edge back to itself (a -> a) must be rejected with a cycle error.
#[tokio::test]
async fn test_cycle_detection() {
    let result = StateGraph::new()
        .add_node("a", |state: NodeState| async move { Ok(state) })
        .set_entry_point("a")
        .add_edge("a", "a")
        .compile();
    // A single `match` replaces `is_err()` followed by `err().unwrap()`, and it gives a
    // descriptive failure message. The `Ok` payload is not printed because the compiled
    // graph type may not implement `Debug`.
    let err = match result {
        Ok(_) => panic!("Expected cycle error, but the self-loop graph compiled"),
        Err(e) => e.to_string(),
    };
    assert!(err.contains("cycle"), "Expected cycle error, got: {}", err);
}
/// A two-node loop (a -> b -> a) must be rejected with a cycle error.
#[tokio::test]
async fn test_two_node_cycle_detection() {
    let result = StateGraph::new()
        .add_node("a", |state: NodeState| async move { Ok(state) })
        .add_node("b", |state: NodeState| async move { Ok(state) })
        .set_entry_point("a")
        .add_edge("a", "b")
        .add_edge("b", "a")
        .compile();
    // A single `match` replaces `is_err()` followed by `err().unwrap()`, and it gives a
    // descriptive failure message. The `Ok` payload is not printed because the compiled
    // graph type may not implement `Debug`.
    let err = match result {
        Ok(_) => panic!("Expected cycle error, but the a <-> b graph compiled"),
        Err(e) => e.to_string(),
    };
    assert!(err.contains("cycle"), "Expected cycle error, got: {}", err);
}
/// A diamond-shaped DAG (a -> {b, c} -> d -> END) has no cycle and must compile.
#[tokio::test]
async fn test_valid_dag_compiles() {
    // Every node forwards its state unchanged. The closure captures nothing, so it is
    // `Copy` and can be handed to each node.
    let passthrough = |state: NodeState| async move { Ok(state) };

    let with_nodes = StateGraph::new()
        .add_node("a", passthrough)
        .add_node("b", passthrough)
        .add_node("c", passthrough)
        .add_node("d", passthrough);

    let wired = with_nodes
        .set_entry_point("a")
        .add_edge("a", "b")
        .add_edge("a", "c")
        .add_edge("b", "d")
        .add_edge("c", "d")
        .add_edge_to_end("d");

    let compiled = wired.compile();
    assert!(
        compiled.is_ok(),
        "Valid DAG should compile: {:?}",
        compiled.as_ref().err()
    );
}