use oris_runtime::graph::{
execute_task_with_cache, function_node, FunctionTask, GraphError, InMemorySaver, MessagesState,
RunnableConfig, StateGraph, Task, TaskCache, END, START,
};
use oris_runtime::schemas::messages::Message;
use serde_json::Value;
use std::sync::Arc;
/// Example: run a simulated "API call" task from inside a graph node via
/// `execute_task_with_cache`, then execute the single-node graph
/// (`START -> api_node -> END`) with an `InMemorySaver` checkpointer and print
/// the resulting messages.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    use std::collections::HashMap;

    // Simulated API call: echoes the requested URL (defaulting to
    // "https://example.com" when the input has no string "url" field) along
    // with a fake response body and the current Unix time in seconds.
    let api_task: Arc<dyn Task> = Arc::new(FunctionTask::new("api_call", |input: Value| {
        Box::pin(async move {
            let url = input
                .get("url")
                .and_then(|v| v.as_str())
                .unwrap_or("https://example.com");
            // `duration_since` fails if the system clock is set before the
            // epoch; report 0 instead of panicking inside the task.
            let timestamp = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map(|d| d.as_secs())
                .unwrap_or(0);
            Ok(serde_json::json!({
                "url": url,
                "response": format!("Response from {}", url),
                "timestamp": timestamp
            }))
        })
    }));

    // Created once and shared by every invocation of the node. Building the
    // cache inside the node body would hand each call an empty cache, so no
    // result could ever be reused.
    let cache = Arc::new(TaskCache::new());

    let api_node = function_node("api_node", move |_state: &MessagesState| {
        let api_task = Arc::clone(&api_task);
        let cache = Arc::clone(&cache);
        async move {
            let task_input = serde_json::json!({
                "url": "https://api.example.com/data"
            });
            let task_result =
                execute_task_with_cache(api_task.as_ref(), task_input, Some(&*cache))
                    .await
                    .map_err(|e| GraphError::ExecutionError(e.to_string()))?;

            // Formatting a `Value` directly would render strings JSON-quoted
            // and escaped; pull out the raw text instead.
            let response = task_result
                .get("response")
                .and_then(|v| v.as_str())
                .unwrap_or("No response");

            let mut update = HashMap::new();
            update.insert(
                "messages".to_string(),
                serde_json::to_value(vec![Message::new_ai_message(format!(
                    "API response: {}",
                    response
                ))])?,
            );
            Ok(update)
        }
    });

    let mut graph = StateGraph::<MessagesState>::new();
    graph.add_node("api_node", api_node)?;
    graph.add_edge(START, "api_node");
    graph.add_edge("api_node", END);

    let checkpointer = Arc::new(InMemorySaver::new());
    let compiled = graph.compile_with_persistence(Some(checkpointer), None)?;

    let config = RunnableConfig::with_thread_id("thread-task-1");
    let initial_state =
        MessagesState::with_messages(vec![Message::new_human_message("Fetch data")]);
    let final_state = compiled
        .invoke_with_config(Some(initial_state), &config)
        .await?;

    println!("Final messages:");
    for message in &final_state.messages {
        println!(
            " {}: {}",
            message.message_type.to_string(),
            message.content
        );
    }
    println!("\nNote: If resuming from checkpoint, task results are cached");
    println!("and the task will not be re-executed.");
    Ok(())
}