use oris_runtime::graph::{
function_node_with_store, GraphError, InMemorySaver, InMemoryStore, MessagesState,
RunnableConfig, StateGraph, END, START,
};
use oris_runtime::schemas::messages::Message;
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let call_model = function_node_with_store(
"call_model",
|state: &MessagesState, config: &RunnableConfig, store: oris_runtime::graph::StoreBox| {
let user_id = config.get_user_id().unwrap_or("default".to_string());
let last_message = state
.messages
.last()
.map(|m| m.content.clone())
.unwrap_or_default();
async move {
use std::collections::HashMap;
let namespace = ["memories", user_id.as_str()];
let last_msg_ref = last_message.as_str();
let memories = store
.search(&namespace, Some(last_msg_ref), Some(3))
.await
.map_err(|e| GraphError::ExecutionError(format!("Store error: {}", e)))?;
let memory_context: String = memories
.iter()
.map(|item| {
if let Some(data) = item.value.get("data") {
data.as_str().unwrap_or("").to_string()
} else {
String::new()
}
})
.collect::<Vec<_>>()
.join("\n");
let last_msg_lower = last_message.to_lowercase();
if last_msg_lower.contains("remember") {
let memory_text = last_message.replace("remember", "").trim().to_string();
if !memory_text.is_empty() {
let memory_id = format!("memory-{}", chrono::Utc::now().timestamp());
store
.put(
&namespace,
&memory_id,
serde_json::json!({"data": memory_text}),
)
.await
.map_err(|e| {
GraphError::ExecutionError(format!("Store error: {}", e))
})?;
}
}
let response_text = if !memory_context.is_empty() {
format!(
"Based on your memories: {}. Response to: {}",
memory_context, last_message
)
} else {
format!("Response to: {}", last_message)
};
let mut update = HashMap::new();
update.insert(
"messages".to_string(),
serde_json::to_value(vec![Message::new_ai_message(response_text)])?,
);
Ok(update)
}
},
);
let mut graph = StateGraph::<MessagesState>::new();
graph.add_node("call_model", call_model)?;
graph.add_edge(START, "call_model");
graph.add_edge("call_model", END);
let checkpointer = std::sync::Arc::new(InMemorySaver::new());
let store = std::sync::Arc::new(InMemoryStore::new());
let compiled =
graph.compile_with_persistence(Some(checkpointer.clone()), Some(store.clone()))?;
println!("=== Thread 1: User introduces themselves ===");
let config1 = {
let mut cfg = RunnableConfig::with_thread_id("thread-1");
cfg.configurable
.insert("user_id".to_string(), serde_json::json!("user-123"));
cfg
};
let initial_state1 = MessagesState::with_messages(vec![Message::new_human_message(
"Hi! Remember: my name is Bob",
)]);
let result1 = compiled
.invoke_with_config(Some(initial_state1), &config1)
.await?;
println!("Response: {}", result1.messages.last().unwrap().content);
println!("\n=== Thread 2: Same user, different thread ===");
let config2 = {
let mut cfg = RunnableConfig::with_thread_id("thread-2");
cfg.configurable
.insert("user_id".to_string(), serde_json::json!("user-123"));
cfg
};
let initial_state2 =
MessagesState::with_messages(vec![Message::new_human_message("What is my name?")]);
let result2 = compiled
.invoke_with_config(Some(initial_state2), &config2)
.await?;
println!("Response: {}", result2.messages.last().unwrap().content);
Ok(())
}