use oris_runtime::graph::{
function_node, interrupt, Command, InMemorySaver, MessagesState, RunnableConfig, StateGraph,
StateOrCommand, END, START,
};
use oris_runtime::schemas::messages::Message;
use std::collections::HashMap;
/// Runs a human-in-the-loop round trip on a one-node graph.
///
/// The `approval` node calls `interrupt`, so the first invocation pauses and its
/// pending interrupt(s) are printed. A second invocation on the same thread id
/// sends `Command::resume(true)`, and the resulting message history is printed.
///
/// # Errors
///
/// Propagates any error from node registration, graph compilation, or either
/// invocation.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Single node: ask for approval, then record the answer as an AI message
    // under the "messages" key of the state update.
    let approve = function_node("approval", |_state: &MessagesState| async move {
        // NOTE(review): on resume this presumably evaluates to the value passed
        // via `Command::resume` — confirm against oris_runtime's interrupt docs.
        let answer = interrupt("Do you approve this action?")
            .await
            .map_err(oris_runtime::graph::GraphError::InterruptError)?;
        let reply = Message::new_ai_message(format!("Approval result: {}", answer));
        let encoded = serde_json::to_value(vec![reply])?;
        Ok(HashMap::from([("messages".to_string(), encoded)]))
    });

    // Topology: START -> approval -> END.
    let mut builder = StateGraph::<MessagesState>::new();
    builder.add_node("approval", approve)?;
    builder.add_edge(START, "approval");
    builder.add_edge("approval", END);

    let saver = std::sync::Arc::new(InMemorySaver::new());
    let app = builder.compile_with_persistence(Some(saver.clone()), None)?;

    // Both invocations share this thread id so the second call targets the
    // run that the first call left paused.
    let run_cfg = RunnableConfig::with_thread_id("thread-1");

    // First pass: execute until the approval node interrupts.
    let seed = MessagesState::with_messages(vec![Message::new_human_message("start")]);
    let first_pass = app
        .invoke_with_config_interrupt(StateOrCommand::State(seed), &run_cfg)
        .await?;
    if let Some(pending) = first_pass.interrupt {
        println!("Interrupted: {:?}", pending);
        if let Some(head) = pending.first() {
            println!("Interrupt value: {}", head.value);
        }
    }

    // Second pass: answer the pending interrupt with `true`.
    let resume = StateOrCommand::Command(Command::resume(serde_json::json!(true)));
    let second_pass = app.invoke_with_config_interrupt(resume, &run_cfg).await?;

    let history = &second_pass.state.messages;
    println!("Final state messages count: {}", history.len());
    for msg in history {
        println!(" {}: {}", msg.message_type.to_string(), msg.content);
    }
    Ok(())
}