use langchainrust::{
StateGraph, GraphBuilder, START, END,
AgentState, StateUpdate, GraphExecution,
};
/// Verifies that registering an interrupt before `step2` makes `invoke`
/// fail with `GraphError::ExecutionInterrupted`, and that the error text
/// names the node at which execution paused.
#[tokio::test]
async fn test_interrupt_before() {
    // Nodes at which execution should pause before running.
    let pause_points = vec![String::from("step2")];

    // Linear graph of three pass-through nodes: START -> step1 -> step2 -> step3 -> END.
    let graph = GraphBuilder::<AgentState>::new()
        .add_node_fn("step1", |state| Ok(StateUpdate::full(state.clone())))
        .add_node_fn("step2", |state| Ok(StateUpdate::full(state.clone())))
        .add_node_fn("step3", |state| Ok(StateUpdate::full(state.clone())))
        .add_edge(START, "step1")
        .add_edge("step1", "step2")
        .add_edge("step2", "step3")
        .add_edge("step3", END)
        .compile()
        .unwrap()
        .with_interrupt_before(pause_points);

    let outcome = graph.invoke(AgentState::new(String::from("test"))).await;
    assert!(outcome.is_err());

    let error = outcome.unwrap_err();
    assert!(matches!(
        error,
        langchainrust::GraphError::ExecutionInterrupted(_)
    ));

    let description = error.to_string();
    assert!(description.contains("step2"));
}
/// Verifies that registering an interrupt after `process` makes `invoke`
/// fail with `GraphError::ExecutionInterrupted`, and that the error text
/// contains the `after_process` marker.
#[tokio::test]
async fn test_interrupt_after() {
    // Nodes after which execution should pause.
    let pause_points = vec![String::from("process")];

    // START -> process (appends an AI message) -> finalize (sets output) -> END.
    let graph = GraphBuilder::<AgentState>::new()
        .add_node_fn("process", |state| {
            let mut next = state.clone();
            next.add_message(langchainrust::MessageEntry::ai(String::from("processed")));
            Ok(StateUpdate::full(next))
        })
        .add_node_fn("finalize", |state| {
            let mut next = state.clone();
            next.set_output(String::from("done"));
            Ok(StateUpdate::full(next))
        })
        .add_edge(START, "process")
        .add_edge("process", "finalize")
        .add_edge("finalize", END)
        .compile()
        .unwrap()
        .with_interrupt_after(pause_points);

    let outcome = graph.invoke(AgentState::new(String::from("test"))).await;
    assert!(outcome.is_err());

    let error = outcome.unwrap_err();
    assert!(matches!(
        error,
        langchainrust::GraphError::ExecutionInterrupted(_)
    ));

    let description = error.to_string();
    assert!(description.contains("after_process"));
}
/// Checks that a graph interrupted after `step1` can be resumed at `step2`
/// and run through `step3` to completion.
///
/// The initial `invoke` must stop with `GraphError::ExecutionInterrupted`,
/// and its message must contain `after_step1`. The test then builds a
/// `GraphExecution` by hand, targeting `step2` with the `"after_step1"` marker,
/// resumes it, and checks the final output and recursion count.
#[tokio::test]
async fn test_resume_from_interrupt() {
    // START -> step1 (pass-through) -> step2 (appends a message) -> step3 (sets output) -> END.
    let compiled = GraphBuilder::<AgentState>::new()
        .add_node_fn("step1", |state| Ok(StateUpdate::full(state.clone())))
        .add_node_fn("step2", |state| {
            let mut new_state = state.clone();
            new_state.add_message(langchainrust::MessageEntry::ai("step2_done".to_string()));
            Ok(StateUpdate::full(new_state))
        })
        .add_node_fn("step3", |state| {
            let mut new_state = state.clone();
            new_state.set_output("completed".to_string());
            Ok(StateUpdate::full(new_state))
        })
        .add_edge(START, "step1")
        .add_edge("step1", "step2")
        .add_edge("step2", "step3")
        .add_edge("step3", END)
        .compile()
        .unwrap()
        .with_interrupt_after(vec!["step1".to_string()]);

    let input = AgentState::new("test".to_string());
    let result = compiled.invoke(input).await;
    assert!(result.is_err());

    // Confirm the graph really paused after step1. Otherwise the hand-built
    // resume below would also pass after an unrelated failure.
    let err = result.unwrap_err();
    assert!(matches!(err, langchainrust::GraphError::ExecutionInterrupted(_)));
    assert!(err.to_string().contains("after_step1"));

    // NOTE(review): resumes from a fresh `AgentState`, not the state captured at
    // the interrupt. This is equivalent here only because step1 is a pass-through node.
    let mut execution = GraphExecution::new(
        AgentState::new("test".to_string()),
        "step2".to_string(),
        "after_step1".to_string(),
    );
    // Presumably accounts for step1 having already run before the interrupt; confirm.
    execution.recursion_count = 1;

    let resumed_result = compiled.resume(execution).await.unwrap();
    assert_eq!(
        resumed_result.final_state.output.as_deref(),
        Some("completed")
    );
    assert_eq!(resumed_result.recursion_count, 2);
}
/// Checks that interrupts registered before both `check1` and `check2` fire in
/// turn, and that resuming past both lets the graph reach `final`.
///
/// Each interruption must be a `GraphError::ExecutionInterrupted` naming the
/// node at which it paused, not merely any error whose text mentions the node.
/// The final resume must produce the output set by the `final` node.
#[tokio::test]
async fn test_multiple_interrupts() {
    // START -> check1 -> check2 -> final -> END; each check appends a message,
    // and `final` sets the output.
    let compiled = GraphBuilder::<AgentState>::new()
        .add_node_fn("check1", |state| {
            let mut new_state = state.clone();
            new_state.add_message(langchainrust::MessageEntry::ai("check1_pass".to_string()));
            Ok(StateUpdate::full(new_state))
        })
        .add_node_fn("check2", |state| {
            let mut new_state = state.clone();
            new_state.add_message(langchainrust::MessageEntry::ai("check2_pass".to_string()));
            Ok(StateUpdate::full(new_state))
        })
        .add_node_fn("final", |state| {
            let mut new_state = state.clone();
            new_state.set_output("all_checks_passed".to_string());
            Ok(StateUpdate::full(new_state))
        })
        .add_edge(START, "check1")
        .add_edge("check1", "check2")
        .add_edge("check2", "final")
        .add_edge("final", END)
        .compile()
        .unwrap()
        .with_interrupt_before(vec!["check1".to_string(), "check2".to_string()]);

    // First stop: before check1.
    let input = AgentState::new("test".to_string());
    let result1 = compiled.invoke(input).await;
    assert!(result1.is_err());
    let err1 = result1.unwrap_err();
    assert!(matches!(err1, langchainrust::GraphError::ExecutionInterrupted(_)));
    assert!(err1.to_string().contains("check1"));

    // Second stop: resuming at check1 is expected to run it and pause before check2.
    let execution1 = GraphExecution::new(
        AgentState::new("test".to_string()),
        "check1".to_string(),
        "check1".to_string(),
    );
    let result2 = compiled.resume(execution1).await;
    assert!(result2.is_err());
    let err2 = result2.unwrap_err();
    assert!(matches!(err2, langchainrust::GraphError::ExecutionInterrupted(_)));
    assert!(err2.to_string().contains("check2"));

    // Resume at check2 with a state that already carries check1's message,
    // mirroring what check1 would have produced.
    let mut execution2 = GraphExecution::new(
        AgentState::new("test".to_string()),
        "check2".to_string(),
        "check2".to_string(),
    );
    execution2.recursion_count = 1;
    execution2
        .state
        .add_message(langchainrust::MessageEntry::ai("check1_pass".to_string()));

    let result3 = compiled.resume(execution2).await.unwrap();
    assert_eq!(
        result3.final_state.output.as_deref(),
        Some("all_checks_passed")
    );
}