use oris_runtime::graph::{function_node, MessagesState, StateGraph, END, START};
use oris_runtime::schemas::messages::Message;
// Example entry point: wires a one-node message graph, runs it once with a
// single human message, and prints every message found in the resulting state.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Stand-in for a real model call: ignores the incoming state and always
    // produces a state update whose "messages" entry holds one AI message.
    let mock_llm = function_node("mock_llm", |_state: &MessagesState| async move {
        let reply = serde_json::to_value(vec![Message::new_ai_message("hello world")])?;
        Ok(std::collections::HashMap::from([(
            "messages".to_string(),
            reply,
        )]))
    });

    // Graph topology: START -> mock_llm -> END.
    let mut builder = StateGraph::<MessagesState>::new();
    builder.add_node("mock_llm", mock_llm)?;
    builder.add_edge(START, "mock_llm");
    builder.add_edge("mock_llm", END);
    let runnable = builder.compile()?;

    // Seed the run with one human message and wait for the graph to finish.
    let seed = MessagesState::with_messages(vec![Message::new_human_message("hi!")]);
    let final_state = runnable.invoke(seed).await?;

    // Dump each message as "<type>: <content>".
    println!("Final messages:");
    for entry in &final_state.messages {
        let role = entry.message_type.to_string();
        println!(" {}: {}", role, entry.content);
    }

    Ok(())
}