use async_trait::async_trait;
use oxidizedgraph::prelude::*;
use std::sync::Arc;
/// Typed state for a retrieval-augmented-generation (RAG) workflow.
///
/// NOTE(review): within the visible portion of this file the workflow is
/// driven through `AgentState` rather than this struct; confirm whether
/// `RAGState` is consumed elsewhere.
#[derive(
    Debug,
    Clone,
    Default,
    serde::Serialize,
    serde::Deserialize,
)]
pub struct RAGState {
    /// The user question the workflow should answer.
    pub query: String,
    /// Context snippets gathered for the query.
    pub context: Vec<String>,
    /// The generated answer; `None` until one has been produced.
    pub response: Option<String>,
    /// Conversation messages associated with this run.
    pub messages: Vec<Message>,
}
impl State for RAGState {
    /// Returns a JSON-Schema-like description of [`RAGState`].
    ///
    /// The result is an object schema with one property entry per struct
    /// field. The `messages` entry additionally carries
    /// `"channel": "append"` (presumably a hint to the framework that updates
    /// are appended rather than replaced; confirm against the `State` docs).
    ///
    /// NOTE(review): `response` is an `Option<String>` on the struct but is
    /// declared here as a plain `"string"`; verify whether the framework
    /// expects nullability to be expressed in the schema.
    fn schema() -> serde_json::Value {
        // Describe each field first, then wrap the result in the object schema.
        let properties = serde_json::json!({
            "query": { "type": "string" },
            "context": { "type": "array", "items": { "type": "string" } },
            "response": { "type": "string" },
            "messages": { "type": "array", "channel": "append" }
        });

        serde_json::json!({
            "type": "object",
            "properties": properties
        })
    }
}
/// Workflow node that looks up context for the current query.
///
/// Carries no fields; it is constructed as `RAGRetrievalNode {}`.
struct RAGRetrievalNode {}
#[async_trait]
impl NodeExecutor for RAGRetrievalNode {
    /// Identifier under which this node is registered in the graph.
    fn id(&self) -> &str {
        "rag_retrieval"
    }

    /// Reads the `"query"` context value, builds two context strings that
    /// embed it, and stores them under `"retrieved_context"`.
    ///
    /// The strings are formatted locally; no external retrieval backend is
    /// called in this block. A missing `"query"` falls back to an empty string.
    ///
    /// # Errors
    ///
    /// Returns an execution-failed [`NodeError`] carrying the lock error's
    /// text if the shared state cannot be locked for reading or writing.
    async fn execute(&self, state: SharedState) -> Result<NodeOutput, NodeError> {
        // The read guard is a temporary and is released at the end of this statement.
        let question = state
            .read()
            .map_err(|err| NodeError::execution_failed(err.to_string()))?
            .get_context::<String>("query")
            .unwrap_or_default();

        let snippets = vec![
            format!("Context 1: Information relevant to '{}'", question),
            format!("Context 2: Additional details about '{}'", question),
        ];

        // The write guard is likewise dropped as soon as the value is stored.
        state
            .write()
            .map_err(|err| NodeError::execution_failed(err.to_string()))?
            .set_context("retrieved_context", snippets);

        Ok(NodeOutput::cont())
    }

    /// Human-readable summary of what this node is meant to do.
    fn description(&self) -> Option<&str> {
        Some("Retrieves relevant context from knowledge graph using GraphRAG")
    }
}
/// Workflow node that turns the query plus retrieved context into a response.
struct RAGGenerationNode;
#[async_trait]
impl NodeExecutor for RAGGenerationNode {
    /// Identifier under which this node is registered in the graph.
    fn id(&self) -> &str {
        "rag_generation"
    }

    /// Formats a response from the `"query"` and `"retrieved_context"` context
    /// values, stores it under `"response"`, and appends it to the message
    /// list as an assistant message.
    ///
    /// The response text is produced by a fixed template in this block; no
    /// model call is made here. Missing context values fall back to their
    /// defaults (empty string / empty list).
    ///
    /// # Errors
    ///
    /// Returns an execution-failed [`NodeError`] carrying the lock error's
    /// text if the shared state cannot be locked for reading or writing.
    async fn execute(&self, state: SharedState) -> Result<NodeOutput, NodeError> {
        // Read phase: take both inputs under a single read guard, then release it.
        let (question, snippets) = {
            let reader = state
                .read()
                .map_err(|err| NodeError::execution_failed(err.to_string()))?;
            (
                reader.get_context::<String>("query").unwrap_or_default(),
                reader
                    .get_context::<Vec<String>>("retrieved_context")
                    .unwrap_or_default(),
            )
        };

        let joined_context = snippets.join("\n");
        let answer = format!(
            "Based on the following context:\n{}\n\nAnswer to '{}': This is a generated response.",
            joined_context,
            question
        );

        // Write phase: the context value gets a copy, the message takes ownership.
        {
            let mut writer = state
                .write()
                .map_err(|err| NodeError::execution_failed(err.to_string()))?;
            writer.set_context("response", answer.clone());
            writer.messages.push(Message::assistant(answer));
        }

        Ok(NodeOutput::finish())
    }

    /// Human-readable summary of what this node is meant to do.
    fn description(&self) -> Option<&str> {
        Some("Generates response using LLM with retrieved context")
    }
}
/// Builds the two-node RAG graph, runs it once with checkpointing enabled,
/// and prints the graph diagram, the run outcome, and the checkpoint history.
///
/// # Errors
///
/// Propagates any error from compiling the graph, invoking the runner, or
/// reading the checkpoint history.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Single thread id shared by the run and the history lookup.
    const THREAD_ID: &str = "rag-thread-1";

    tracing_subscriber::fmt::init();

    // Wire the graph: rag_retrieval -> rag_generation -> end.
    let workflow = GraphBuilder::new()
        .name("rag_workflow")
        .description("RAG-augmented agent workflow")
        .add_node(RAGRetrievalNode {})
        .add_node(RAGGenerationNode)
        .set_entry_point("rag_retrieval")
        .add_edge("rag_retrieval", "rag_generation")
        .add_edge_to_end("rag_generation")
        .compile()?;
    println!("Graph structure:\n{}", workflow.to_mermaid());

    // Keep one handle to the checkpointer so the history can be queried after the run.
    let checkpointer = Arc::new(MemoryCheckpointer::new());
    let runner =
        CheckpointingRunner::new(workflow, checkpointer.clone()).checkpoint_every_node();

    let mut seed_state = AgentState::new();
    seed_state.set_context("query", String::from("What is oxidizedgraph?"));

    let outcome = runner.invoke(THREAD_ID, seed_state).await?;
    match outcome {
        RunResult::Interrupted { checkpoint, reason } => {
            println!("\n=== Workflow Interrupted ===");
            println!("Reason: {}", reason);
            println!("Checkpoint ID: {}", checkpoint.id);
        }
        RunResult::Completed(final_state) => {
            println!("\n=== Workflow Completed ===");
            if let Some(answer) = final_state.get_context::<String>("response") {
                println!("Response: {}", answer);
            }
        }
    }

    // Request up to 10 history entries for this thread.
    let entries = checkpointer.history(THREAD_ID, 10).await?;
    println!("\n=== Checkpoint History ===");
    for entry in entries {
        println!(" - {} at node '{}'", entry.id, entry.node_id);
    }

    Ok(())
}