use oris_runtime::graph::{
function_node, InMemorySaver, MessagesState, RunnableConfig, StateGraph, END, START,
};
use oris_runtime::schemas::messages::Message;
/// Example entry point for a minimal graph run with checkpointing.
///
/// It builds a single-node graph whose only node (`mock_llm`) answers with a fixed AI message.
/// It runs that graph once on thread `"thread-1"` with an `InMemorySaver` checkpointer attached.
/// It then prints three things: the final messages, the latest checkpoint snapshot, and the
/// checkpoint history for that thread.
///
/// # Errors
///
/// Returns any error from adding the node, compiling the graph, invoking it, or reading back
/// the checkpoint state and history.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Fake model node. It never looks at the incoming state and always emits a single AI
    // message under the "messages" key of the state update. A serialization failure is
    // propagated with `?`; its target error type is whatever `function_node` expects.
    let mock_llm = function_node("mock_llm", |_state: &MessagesState| async move {
        use std::collections::HashMap;

        let reply = serde_json::to_value(vec![Message::new_ai_message("hello world")])?;
        Ok(HashMap::from([("messages".to_string(), reply)]))
    });

    // Topology: START -> mock_llm -> END.
    let mut builder = StateGraph::<MessagesState>::new();
    builder.add_node("mock_llm", mock_llm)?;
    builder.add_edge(START, "mock_llm");
    builder.add_edge("mock_llm", END);

    // Attach the saver at compile time. The same `config` (thread id "thread-1") is used
    // for the run and for both state queries further down.
    let saver = std::sync::Arc::new(InMemorySaver::new());
    let app = builder.compile_with_persistence(Some(saver.clone()), None)?;
    let config = RunnableConfig::with_thread_id("thread-1");

    let seed = MessagesState::with_messages(vec![Message::new_human_message("hi!")]);
    let result = app.invoke_with_config(Some(seed), &config).await?;

    println!("Final messages:");
    for msg in &result.messages {
        let role = msg.message_type.to_string();
        println!(" {}: {}", role, msg.content);
    }

    // Most recent checkpoint recorded for this thread.
    let latest = app.get_state(&config).await?;
    println!("\nLatest checkpoint:");
    println!(" Thread ID: {}", latest.thread_id());
    println!(" Checkpoint ID: {:?}", latest.checkpoint_id());
    println!(" Next nodes: {:?}", latest.next);
    println!(" Created at: {}", latest.created_at);

    // Every checkpoint for this thread, numbered from 1 in the order the API returns them.
    let history = app.get_state_history(&config).await?;
    println!("\nState history ({} checkpoints):", history.len());
    for (ordinal, entry) in (1usize..).zip(history.iter()) {
        println!(
            " {}: checkpoint_id={:?}, step={:?}",
            ordinal,
            entry.checkpoint_id(),
            entry.metadata.get("step")
        );
    }

    Ok(())
}