use super::{EdgeType, GraphEdge, GraphError, GraphNode, KnowledgeGraph, NodeType};
use crate::RragResult;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::{BTreeMap, HashMap, HashSet};
/// Async persistence interface for knowledge graphs.
///
/// Implementations must be shareable across threads (`Send + Sync`); every
/// operation is fallible and returns a [`RragResult`].
#[async_trait]
pub trait GraphStorage: Send + Sync {
    /// Persists a whole graph, including its nodes and edges.
    async fn store_graph(&self, graph: &KnowledgeGraph) -> RragResult<()>;
    /// Loads a previously stored graph by id; errors if it does not exist.
    async fn load_graph(&self, graph_id: &str) -> RragResult<KnowledgeGraph>;
    /// Stores (or re-indexes) a batch of nodes.
    async fn store_nodes(&self, nodes: &[GraphNode]) -> RragResult<()>;
    /// Stores (or re-indexes) a batch of edges.
    async fn store_edges(&self, edges: &[GraphEdge]) -> RragResult<()>;
    /// Returns the nodes matching the given filter query.
    async fn query_nodes(&self, query: &NodeQuery) -> RragResult<Vec<GraphNode>>;
    /// Returns the edges matching the given filter query.
    async fn query_edges(&self, query: &EdgeQuery) -> RragResult<Vec<GraphEdge>>;
    /// Fetches a single node by id, or `None` if absent.
    async fn get_node(&self, node_id: &str) -> RragResult<Option<GraphNode>>;
    /// Fetches a single edge by id, or `None` if absent.
    async fn get_edge(&self, edge_id: &str) -> RragResult<Option<GraphEdge>>;
    /// Returns the nodes adjacent to `node_id` along the given edge direction.
    async fn get_neighbors(
        &self,
        node_id: &str,
        direction: EdgeDirection,
    ) -> RragResult<Vec<GraphNode>>;
    /// Removes the given nodes from storage.
    async fn delete_nodes(&self, node_ids: &[String]) -> RragResult<()>;
    /// Removes the given edges from storage.
    async fn delete_edges(&self, edge_ids: &[String]) -> RragResult<()>;
    /// Clears all stored graphs, nodes, and edges.
    async fn clear(&self) -> RragResult<()>;
    /// Reports aggregate statistics about the stored data.
    async fn get_stats(&self) -> RragResult<StorageStats>;
}
/// In-memory [`GraphStorage`] implementation guarded by `tokio` RwLocks.
///
/// Whole graphs live in a map keyed by a generated id, while nodes and edges
/// are additionally kept in secondary indexes for fast lookup and search.
pub struct InMemoryGraphStorage {
    // Stored graphs keyed by generated graph id.
    graphs: tokio::sync::RwLock<HashMap<String, KnowledgeGraph>>,
    // Secondary index over all stored nodes.
    node_index: tokio::sync::RwLock<GraphIndex<GraphNode>>,
    // Secondary index over all stored edges.
    edge_index: tokio::sync::RwLock<GraphIndex<GraphEdge>>,
    // Storage tuning options. NOTE(review): stored but not consulted by any
    // code path visible in this file — confirm intended.
    config: GraphStorageConfig,
}
/// Generic secondary index over items of type `T`.
///
/// Supports primary-key lookup (`by_id`), exact-match field lookup
/// (`indices`: field name -> field value -> item ids), and token-based
/// full-text lookup (`text_index`: lowercased token -> item ids).
#[derive(Debug, Clone)]
pub struct GraphIndex<T> {
    by_id: HashMap<String, T>,
    indices: HashMap<String, BTreeMap<String, HashSet<String>>>,
    text_index: HashMap<String, HashSet<String>>,
}
/// Filter criteria for node queries; all present filters are ANDed together.
#[derive(Debug, Clone)]
pub struct NodeQuery {
    /// Restrict candidates to exactly these ids (in the in-memory
    /// implementation this takes precedence over `text_search`).
    pub node_ids: Option<Vec<String>>,
    /// Keep only nodes whose type is one of these.
    pub node_types: Option<Vec<NodeType>>,
    /// Full-text search over the node's indexed fields.
    pub text_search: Option<String>,
    /// Exact-match constraints on node attributes (key -> required value).
    pub attribute_filters: HashMap<String, serde_json::Value>,
    /// Keep only nodes sourced from at least one of these documents.
    pub source_document_filters: Option<Vec<String>>,
    /// Minimum node confidence (inclusive).
    pub min_confidence: Option<f32>,
    /// Maximum number of results to return.
    pub limit: Option<usize>,
    /// Number of results to skip before collecting.
    pub offset: Option<usize>,
}
/// Filter criteria for edge queries; all present filters are ANDed together.
#[derive(Debug, Clone)]
pub struct EdgeQuery {
    /// Restrict candidates to exactly these ids (in the in-memory
    /// implementation this takes precedence over `text_search`).
    pub edge_ids: Option<Vec<String>>,
    /// Keep only edges originating from one of these nodes.
    pub source_node_ids: Option<Vec<String>>,
    /// Keep only edges pointing at one of these nodes.
    pub target_node_ids: Option<Vec<String>>,
    /// Keep only edges whose type is one of these.
    pub edge_types: Option<Vec<EdgeType>>,
    /// Full-text search over the edge's indexed fields.
    pub text_search: Option<String>,
    /// Exact-match constraints on edge attributes (key -> required value).
    pub attribute_filters: HashMap<String, serde_json::Value>,
    /// Inclusive `(min, max)` bounds on the edge weight.
    pub weight_range: Option<(f32, f32)>,
    /// Minimum edge confidence (inclusive).
    pub min_confidence: Option<f32>,
    /// Maximum number of results to return.
    pub limit: Option<usize>,
    /// Number of results to skip before collecting.
    pub offset: Option<usize>,
}
/// Direction of edges to follow when collecting a node's neighbors.
#[derive(Debug, Clone, Copy)]
pub enum EdgeDirection {
    /// Follow edges whose source is the anchor node.
    Outgoing,
    /// Follow edges whose target is the anchor node.
    Incoming,
    /// Follow edges in either direction.
    Both,
}
/// A pattern-matching query anchored at a set of start nodes.
///
/// NOTE(review): no executor for this query type is visible in this file;
/// field semantics below are inferred from names — confirm against the
/// component that consumes `GraphQuery`.
#[derive(Debug, Clone)]
pub struct GraphQuery {
    /// Ids of the nodes traversal starts from.
    pub start_nodes: Vec<String>,
    /// Pattern the traversal must match.
    pub pattern: GraphPattern,
    /// Maximum traversal depth from the start nodes.
    pub max_depth: usize,
    /// Optional cap on the number of results.
    pub limit: Option<usize>,
}
/// A subgraph pattern: node/edge templates tied together by pattern
/// variables, plus extra constraints over those variables.
#[derive(Debug, Clone)]
pub struct GraphPattern {
    /// Node templates, each bound to a variable name.
    pub node_patterns: Vec<NodePattern>,
    /// Edge templates connecting node variables.
    pub edge_patterns: Vec<EdgePattern>,
    /// Additional constraints (distance, path, cardinality).
    pub constraints: Vec<PatternConstraint>,
}
/// Template a matched node must satisfy, bound to `variable`.
#[derive(Debug, Clone)]
pub struct NodePattern {
    /// Variable name this node binds to in the result.
    pub variable: String,
    /// Required node type, if any.
    pub node_type: Option<NodeType>,
    /// Pattern the node label must match, if any.
    pub label_pattern: Option<String>,
    /// Exact-match constraints on node attributes.
    pub attribute_constraints: HashMap<String, serde_json::Value>,
}
/// Template a matched edge must satisfy, connecting two node variables.
#[derive(Debug, Clone)]
pub struct EdgePattern {
    /// Variable of the edge's source node.
    pub source_variable: String,
    /// Variable of the edge's target node.
    pub target_variable: String,
    /// Required edge type, if any.
    pub edge_type: Option<EdgeType>,
    /// Pattern the edge label must match, if any.
    pub label_pattern: Option<String>,
    /// Exact-match constraints on edge attributes.
    pub attribute_constraints: HashMap<String, serde_json::Value>,
}
/// Extra constraint over pattern variables in a [`GraphPattern`].
#[derive(Debug, Clone)]
pub enum PatternConstraint {
    /// The nodes bound to `var1` and `var2` must be within `max_distance`
    /// hops of each other.
    Distance {
        var1: String,
        var2: String,
        max_distance: usize,
    },
    /// A path of the given kind must exist between the two bound nodes.
    Path {
        start_var: String,
        end_var: String,
        path_type: PathType,
    },
    /// Cardinality bound on how many matches a variable may take.
    Count {
        variable: String,
        min_count: usize,
        max_count: Option<usize>,
    },
}
/// Kind of path required by a [`PatternConstraint::Path`] constraint.
#[derive(Debug, Clone)]
pub enum PathType {
    /// Any path satisfies the constraint.
    Any,
    /// Only a shortest path satisfies the constraint.
    Shortest,
    /// Only paths composed of these edge types satisfy the constraint.
    EdgeTypes(Vec<EdgeType>),
}
/// Result of executing a [`GraphQuery`]: variable bindings plus execution
/// metrics.
#[derive(Debug, Clone)]
pub struct GraphQueryResult {
    /// One map per match, binding pattern variables to node ids.
    pub bindings: Vec<HashMap<String, String>>,
    /// Wall-clock execution time in milliseconds.
    pub execution_time_ms: u64,
    /// Number of nodes visited while matching.
    pub nodes_examined: usize,
    /// Number of edges visited while matching.
    pub edges_examined: usize,
}
/// Aggregate statistics reported by [`GraphStorage::get_stats`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StorageStats {
    /// Number of whole graphs currently stored.
    pub graph_count: usize,
    /// Total number of indexed nodes.
    pub total_nodes: usize,
    /// Total number of indexed edges.
    pub total_edges: usize,
    /// Approximate storage footprint in bytes (the in-memory backend reports
    /// a rough per-item estimate, not a measured size).
    pub storage_size_bytes: usize,
    /// Approximate index footprint in bytes (also a rough estimate).
    pub index_size_bytes: usize,
    /// Node count per stringified node type.
    pub node_type_distribution: HashMap<String, usize>,
    /// Edge count per stringified edge type.
    pub edge_type_distribution: HashMap<String, usize>,
    /// Timestamp at which these statistics were computed.
    pub last_updated: chrono::DateTime<chrono::Utc>,
}
/// Tuning options for graph storage backends.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphStorageConfig {
    /// Whether to build the full-text token index.
    pub enable_text_indexing: bool,
    /// Whether to build per-attribute exact-match indexes.
    pub enable_attribute_indexing: bool,
    /// Maximum number of cached entries.
    pub max_cache_size: usize,
    /// Batch size for bulk operations.
    pub batch_size: usize,
    /// Whether stored data should be compressed.
    pub enable_compression: bool,
}
impl Default for GraphStorageConfig {
    /// Defaults: both index kinds enabled, 10k cache entries, 1k batch size,
    /// compression off.
    fn default() -> Self {
        Self {
            enable_text_indexing: true,
            enable_attribute_indexing: true,
            max_cache_size: 10_000,
            batch_size: 1_000,
            enable_compression: false,
        }
    }
}
impl<T> GraphIndex<T>
where
    T: Clone + Send + Sync,
{
    /// Creates an empty index.
    pub fn new() -> Self {
        Self {
            by_id: HashMap::new(),
            indices: HashMap::new(),
            text_index: HashMap::new(),
        }
    }

    /// Inserts `item` under `id` and indexes it by the given fields.
    ///
    /// Each `(field_name, field_value)` pair is added to the exact-match
    /// index, and every field value is tokenized into the text index.
    ///
    /// NOTE(review): re-adding an existing id replaces the stored item but
    /// does not remove index entries derived from its previous field values;
    /// call `remove_item` first if the fields may have changed.
    pub fn add_item(&mut self, id: String, item: T, indexable_fields: &HashMap<String, String>) {
        self.by_id.insert(id.clone(), item);
        // Single pass over the fields: feed both the exact-match index and
        // the full-text index (the original code iterated the map twice).
        for (field_name, field_value) in indexable_fields {
            self.indices
                .entry(field_name.clone())
                .or_insert_with(BTreeMap::new)
                .entry(field_value.clone())
                .or_insert_with(HashSet::new)
                .insert(id.clone());
            for token in Self::tokenize(field_value) {
                self.text_index
                    .entry(token.to_lowercase())
                    .or_insert_with(HashSet::new)
                    .insert(id.clone());
            }
        }
    }

    /// Removes the item and all of its index entries, pruning posting sets
    /// that become empty so the indexes do not accumulate dead entries.
    pub fn remove_item(&mut self, id: &str) {
        self.by_id.remove(id);
        for index in self.indices.values_mut() {
            for ids in index.values_mut() {
                ids.remove(id);
            }
            // Drop now-empty posting sets to bound index growth.
            index.retain(|_, ids| !ids.is_empty());
        }
        self.text_index.retain(|_, ids| {
            ids.remove(id);
            !ids.is_empty()
        });
    }

    /// Primary-key lookup.
    pub fn get_by_id(&self, id: &str) -> Option<&T> {
        self.by_id.get(id)
    }

    /// Returns all items whose indexed `field_name` equals `field_value`
    /// exactly; unknown fields or values yield an empty vector.
    pub fn find_by_field(&self, field_name: &str, field_value: &str) -> Vec<&T> {
        self.indices
            .get(field_name)
            .and_then(|index| index.get(field_value))
            .map(|ids| ids.iter().filter_map(|id| self.by_id.get(id)).collect())
            .unwrap_or_default()
    }

    /// Case-insensitive token-AND search: returns the items matching every
    /// token of `query`. An empty query matches nothing.
    ///
    /// Fixed: the previous implementation silently skipped tokens that had no
    /// postings, so "foo missing" returned all "foo" matches instead of the
    /// empty intersection.
    pub fn text_search(&self, query: &str) -> Vec<&T> {
        let mut matching_ids: Option<HashSet<String>> = None;
        for token in Self::tokenize(query) {
            let ids = match self.text_index.get(&token.to_lowercase()) {
                Some(ids) => ids,
                // A token with no postings makes the AND-intersection empty.
                None => return Vec::new(),
            };
            matching_ids = Some(match matching_ids {
                // Seed the intersection with the first token's postings.
                None => ids.clone(),
                // Intersect with each subsequent token's postings.
                Some(mut acc) => {
                    acc.retain(|id| ids.contains(id));
                    acc
                }
            });
        }
        let matching_ids = matching_ids.unwrap_or_default();
        matching_ids
            .iter()
            .filter_map(|id| self.by_id.get(id))
            .collect()
    }

    /// Returns references to every stored item (unspecified order).
    pub fn get_all(&self) -> Vec<&T> {
        self.by_id.values().collect()
    }

    /// Basic size metrics for this index.
    pub fn stats(&self) -> HashMap<String, usize> {
        let mut stats = HashMap::new();
        stats.insert("total_items".to_string(), self.by_id.len());
        stats.insert("indices_count".to_string(), self.indices.len());
        stats.insert("text_terms".to_string(), self.text_index.len());
        stats
    }

    /// Removes every item and all index entries.
    pub fn clear(&mut self) {
        self.by_id.clear();
        self.indices.clear();
        self.text_index.clear();
    }

    /// Splits on whitespace and strips non-alphanumeric characters from both
    /// ends of each token; empty tokens are dropped.
    fn tokenize(text: &str) -> Vec<String> {
        text.split_whitespace()
            .map(|s| s.trim_matches(|c: char| !c.is_alphanumeric()))
            .filter(|s| !s.is_empty())
            .map(|s| s.to_string())
            .collect()
    }
}
impl Default for NodeQuery {
    /// An unconstrained query: no filters, no pagination — matches all nodes.
    fn default() -> Self {
        Self {
            node_ids: None,
            node_types: None,
            text_search: None,
            attribute_filters: HashMap::new(),
            source_document_filters: None,
            min_confidence: None,
            limit: None,
            offset: None,
        }
    }
}
impl Default for EdgeQuery {
    /// An unconstrained query: no filters, no pagination — matches all edges.
    fn default() -> Self {
        Self {
            edge_ids: None,
            source_node_ids: None,
            target_node_ids: None,
            edge_types: None,
            text_search: None,
            attribute_filters: HashMap::new(),
            weight_range: None,
            min_confidence: None,
            limit: None,
            offset: None,
        }
    }
}
impl InMemoryGraphStorage {
pub fn new() -> Self {
Self::with_config(GraphStorageConfig::default())
}
pub fn with_config(config: GraphStorageConfig) -> Self {
Self {
graphs: tokio::sync::RwLock::new(HashMap::new()),
node_index: tokio::sync::RwLock::new(GraphIndex::new()),
edge_index: tokio::sync::RwLock::new(GraphIndex::new()),
config,
}
}
fn create_node_indexable_fields(node: &GraphNode) -> HashMap<String, String> {
let mut fields = HashMap::new();
fields.insert("label".to_string(), node.label.clone());
fields.insert(
"node_type".to_string(),
Self::node_type_string(&node.node_type),
);
fields.insert("confidence".to_string(), node.confidence.to_string());
for (key, value) in &node.attributes {
if let Some(string_value) = value.as_str() {
fields.insert(format!("attr_{}", key), string_value.to_string());
}
}
fields
}
fn create_edge_indexable_fields(edge: &GraphEdge) -> HashMap<String, String> {
let mut fields = HashMap::new();
fields.insert("label".to_string(), edge.label.clone());
fields.insert(
"edge_type".to_string(),
Self::edge_type_string(&edge.edge_type),
);
fields.insert("source_id".to_string(), edge.source_id.clone());
fields.insert("target_id".to_string(), edge.target_id.clone());
fields.insert("weight".to_string(), edge.weight.to_string());
fields.insert("confidence".to_string(), edge.confidence.to_string());
for (key, value) in &edge.attributes {
if let Some(string_value) = value.as_str() {
fields.insert(format!("attr_{}", key), string_value.to_string());
}
}
fields
}
fn node_type_string(node_type: &NodeType) -> String {
match node_type {
NodeType::Entity(entity_type) => format!("Entity({})", entity_type),
NodeType::Concept => "Concept".to_string(),
NodeType::Document => "Document".to_string(),
NodeType::DocumentChunk => "DocumentChunk".to_string(),
NodeType::Keyword => "Keyword".to_string(),
NodeType::Custom(custom_type) => format!("Custom({})", custom_type),
}
}
fn edge_type_string(edge_type: &EdgeType) -> String {
match edge_type {
EdgeType::Semantic(relation) => format!("Semantic({})", relation),
EdgeType::Hierarchical => "Hierarchical".to_string(),
EdgeType::Contains => "Contains".to_string(),
EdgeType::References => "References".to_string(),
EdgeType::CoOccurs => "CoOccurs".to_string(),
EdgeType::Similar => "Similar".to_string(),
EdgeType::Custom(custom_type) => format!("Custom({})", custom_type),
}
}
fn apply_node_filters(&self, nodes: Vec<&GraphNode>, query: &NodeQuery) -> Vec<GraphNode> {
let mut result: Vec<_> = nodes.into_iter().cloned().collect();
if let Some(node_types) = &query.node_types {
result.retain(|node| node_types.contains(&node.node_type));
}
if let Some(min_confidence) = query.min_confidence {
result.retain(|node| node.confidence >= min_confidence);
}
for (attr_key, attr_value) in &query.attribute_filters {
result.retain(|node| {
node.attributes
.get(attr_key)
.map_or(false, |v| v == attr_value)
});
}
if let Some(source_docs) = &query.source_document_filters {
result.retain(|node| {
node.source_documents
.iter()
.any(|doc| source_docs.contains(doc))
});
}
if let Some(offset) = query.offset {
if offset < result.len() {
result.drain(0..offset);
} else {
result.clear();
}
}
if let Some(limit) = query.limit {
result.truncate(limit);
}
result
}
fn apply_edge_filters(&self, edges: Vec<&GraphEdge>, query: &EdgeQuery) -> Vec<GraphEdge> {
let mut result: Vec<_> = edges.into_iter().cloned().collect();
if let Some(source_ids) = &query.source_node_ids {
result.retain(|edge| source_ids.contains(&edge.source_id));
}
if let Some(target_ids) = &query.target_node_ids {
result.retain(|edge| target_ids.contains(&edge.target_id));
}
if let Some(edge_types) = &query.edge_types {
result.retain(|edge| edge_types.contains(&edge.edge_type));
}
if let Some((min_weight, max_weight)) = query.weight_range {
result.retain(|edge| edge.weight >= min_weight && edge.weight <= max_weight);
}
if let Some(min_confidence) = query.min_confidence {
result.retain(|edge| edge.confidence >= min_confidence);
}
for (attr_key, attr_value) in &query.attribute_filters {
result.retain(|edge| {
edge.attributes
.get(attr_key)
.map_or(false, |v| v == attr_value)
});
}
if let Some(offset) = query.offset {
if offset < result.len() {
result.drain(0..offset);
} else {
result.clear();
}
}
if let Some(limit) = query.limit {
result.truncate(limit);
}
result
}
}
#[async_trait]
impl GraphStorage for InMemoryGraphStorage {
    /// Stores the graph under a freshly generated UUID and indexes all of its
    /// nodes and edges.
    ///
    /// NOTE(review): the generated graph id is never returned to the caller,
    /// so a graph stored here cannot be retrieved via `load_graph`. Confirm
    /// whether the id should come from the graph itself or be returned.
    async fn store_graph(&self, graph: &KnowledgeGraph) -> RragResult<()> {
        let graph_id = uuid::Uuid::new_v4().to_string();
        self.graphs.write().await.insert(graph_id, graph.clone());
        self.store_nodes(&graph.nodes.values().cloned().collect::<Vec<_>>())
            .await?;
        self.store_edges(&graph.edges.values().cloned().collect::<Vec<_>>())
            .await?;
        Ok(())
    }

    /// Looks up a stored graph by id; a miss becomes `GraphError::Storage`.
    async fn load_graph(&self, graph_id: &str) -> RragResult<KnowledgeGraph> {
        self.graphs
            .read()
            .await
            .get(graph_id)
            .cloned()
            .ok_or_else(|| {
                GraphError::Storage {
                    operation: "load_graph".to_string(),
                    message: format!("Graph '{}' not found", graph_id),
                }
                .into()
            })
    }

    /// Adds each node to the node index keyed by its id (existing ids are
    /// overwritten).
    async fn store_nodes(&self, nodes: &[GraphNode]) -> RragResult<()> {
        let mut node_index = self.node_index.write().await;
        for node in nodes {
            let indexable_fields = Self::create_node_indexable_fields(node);
            node_index.add_item(node.id.clone(), node.clone(), &indexable_fields);
        }
        Ok(())
    }

    /// Adds each edge to the edge index keyed by its id (existing ids are
    /// overwritten).
    async fn store_edges(&self, edges: &[GraphEdge]) -> RragResult<()> {
        let mut edge_index = self.edge_index.write().await;
        for edge in edges {
            let indexable_fields = Self::create_edge_indexable_fields(edge);
            edge_index.add_item(edge.id.clone(), edge.clone(), &indexable_fields);
        }
        Ok(())
    }

    /// Selects candidate nodes — by explicit id list, else by text search,
    /// else every node — then applies the remaining filters.
    ///
    /// Note: when `node_ids` is present, `text_search` is ignored.
    async fn query_nodes(&self, query: &NodeQuery) -> RragResult<Vec<GraphNode>> {
        let node_index = self.node_index.read().await;
        let mut candidates = Vec::new();
        if let Some(node_ids) = &query.node_ids {
            // Unknown ids are silently skipped.
            for node_id in node_ids {
                if let Some(node) = node_index.get_by_id(node_id) {
                    candidates.push(node);
                }
            }
        } else if let Some(text_query) = &query.text_search {
            candidates = node_index.text_search(text_query);
        } else {
            candidates = node_index.get_all();
        }
        Ok(self.apply_node_filters(candidates, query))
    }

    /// Selects candidate edges — by explicit id list, else by text search,
    /// else every edge — then applies the remaining filters.
    ///
    /// Note: when `edge_ids` is present, `text_search` is ignored.
    async fn query_edges(&self, query: &EdgeQuery) -> RragResult<Vec<GraphEdge>> {
        let edge_index = self.edge_index.read().await;
        let mut candidates = Vec::new();
        if let Some(edge_ids) = &query.edge_ids {
            // Unknown ids are silently skipped.
            for edge_id in edge_ids {
                if let Some(edge) = edge_index.get_by_id(edge_id) {
                    candidates.push(edge);
                }
            }
        } else if let Some(text_query) = &query.text_search {
            candidates = edge_index.text_search(text_query);
        } else {
            candidates = edge_index.get_all();
        }
        Ok(self.apply_edge_filters(candidates, query))
    }

    /// Fetches a single node by id.
    async fn get_node(&self, node_id: &str) -> RragResult<Option<GraphNode>> {
        let node_index = self.node_index.read().await;
        Ok(node_index.get_by_id(node_id).cloned())
    }

    /// Fetches a single edge by id.
    async fn get_edge(&self, edge_id: &str) -> RragResult<Option<GraphEdge>> {
        let edge_index = self.edge_index.read().await;
        Ok(edge_index.get_by_id(edge_id).cloned())
    }

    /// Collects the distinct neighbors of `node_id` via the `source_id` /
    /// `target_id` exact-match indexes on edges, then resolves the ids to
    /// nodes (ids whose node is missing are skipped).
    async fn get_neighbors(
        &self,
        node_id: &str,
        direction: EdgeDirection,
    ) -> RragResult<Vec<GraphNode>> {
        let edge_index = self.edge_index.read().await;
        let node_index = self.node_index.read().await;
        // HashSet deduplicates neighbors reachable through multiple edges.
        let mut neighbor_ids = HashSet::new();
        match direction {
            EdgeDirection::Outgoing => {
                let outgoing_edges = edge_index.find_by_field("source_id", node_id);
                for edge in outgoing_edges {
                    neighbor_ids.insert(&edge.target_id);
                }
            }
            EdgeDirection::Incoming => {
                let incoming_edges = edge_index.find_by_field("target_id", node_id);
                for edge in incoming_edges {
                    neighbor_ids.insert(&edge.source_id);
                }
            }
            EdgeDirection::Both => {
                let outgoing_edges = edge_index.find_by_field("source_id", node_id);
                for edge in outgoing_edges {
                    neighbor_ids.insert(&edge.target_id);
                }
                let incoming_edges = edge_index.find_by_field("target_id", node_id);
                for edge in incoming_edges {
                    neighbor_ids.insert(&edge.source_id);
                }
            }
        }
        let neighbors = neighbor_ids
            .into_iter()
            .filter_map(|id| node_index.get_by_id(id))
            .cloned()
            .collect();
        Ok(neighbors)
    }

    /// Removes the given nodes from the node index.
    ///
    /// NOTE(review): edges referencing a deleted node are left in place —
    /// confirm whether dangling edges are acceptable for callers.
    async fn delete_nodes(&self, node_ids: &[String]) -> RragResult<()> {
        let mut node_index = self.node_index.write().await;
        for node_id in node_ids {
            node_index.remove_item(node_id);
        }
        Ok(())
    }

    /// Removes the given edges from the edge index.
    async fn delete_edges(&self, edge_ids: &[String]) -> RragResult<()> {
        let mut edge_index = self.edge_index.write().await;
        for edge_id in edge_ids {
            edge_index.remove_item(edge_id);
        }
        Ok(())
    }

    /// Clears all stored graphs and both indexes.
    async fn clear(&self) -> RragResult<()> {
        self.graphs.write().await.clear();
        self.node_index.write().await.clear();
        self.edge_index.write().await.clear();
        Ok(())
    }

    /// Computes counts, per-type distributions, and rough size estimates.
    async fn get_stats(&self) -> RragResult<StorageStats> {
        let graphs = self.graphs.read().await;
        let node_index = self.node_index.read().await;
        let edge_index = self.edge_index.read().await;
        let graph_count = graphs.len();
        let total_nodes = node_index.by_id.len();
        let total_edges = edge_index.by_id.len();
        let mut node_type_distribution = HashMap::new();
        for node in node_index.by_id.values() {
            let type_key = Self::node_type_string(&node.node_type);
            *node_type_distribution.entry(type_key).or_insert(0) += 1;
        }
        let mut edge_type_distribution = HashMap::new();
        for edge in edge_index.by_id.values() {
            let type_key = Self::edge_type_string(&edge.edge_type);
            *edge_type_distribution.entry(type_key).or_insert(0) += 1;
        }
        // Crude size heuristics: ~1 KB per item, ~100 B per index entry.
        let storage_size_bytes = (total_nodes + total_edges) * 1000;
        let index_size_bytes = (node_index.indices.len() + edge_index.indices.len()) * 100;
        Ok(StorageStats {
            graph_count,
            total_nodes,
            total_edges,
            storage_size_bytes,
            index_size_bytes,
            node_type_distribution,
            edge_type_distribution,
            last_updated: chrono::Utc::now(),
        })
    }
}
impl Default for InMemoryGraphStorage {
    /// Equivalent to [`InMemoryGraphStorage::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph_retrieval::{EdgeType, NodeType};

    /// Round-trips nodes through the store and exercises text search and
    /// direct id lookup. (Removed the unused `node2_id` binding.)
    #[tokio::test]
    async fn test_in_memory_graph_storage() {
        let storage = InMemoryGraphStorage::new();
        let node1 = GraphNode::new("Test Node 1", NodeType::Concept);
        let node2 = GraphNode::new("Test Node 2", NodeType::Entity("Person".to_string()));
        let node1_id = node1.id.clone();
        storage
            .store_nodes(&[node1.clone(), node2.clone()])
            .await
            .unwrap();
        // Both nodes share the "Test" token, so text search finds both.
        let mut query = NodeQuery::default();
        query.text_search = Some("Test".to_string());
        let results = storage.query_nodes(&query).await.unwrap();
        assert_eq!(results.len(), 2);
        // Direct id lookup returns the stored node unchanged.
        let retrieved_node = storage.get_node(&node1_id).await.unwrap();
        assert!(retrieved_node.is_some());
        assert_eq!(retrieved_node.unwrap().label, "Test Node 1");
    }

    /// Stores an edge between two nodes, queries it by source id, and checks
    /// neighbor traversal. (Removed the unused `edge_id` binding.)
    #[tokio::test]
    async fn test_edge_storage_and_queries() {
        let storage = InMemoryGraphStorage::new();
        let node1 = GraphNode::new("Node 1", NodeType::Concept);
        let node2 = GraphNode::new("Node 2", NodeType::Concept);
        let node1_id = node1.id.clone();
        let node2_id = node2.id.clone();
        storage.store_nodes(&[node1, node2]).await.unwrap();
        let edge = GraphEdge::new(
            node1_id.clone(),
            node2_id.clone(),
            "test_relation",
            EdgeType::Similar,
        );
        storage.store_edges(&[edge]).await.unwrap();
        // Query by source id should return exactly the stored edge.
        let mut query = EdgeQuery::default();
        query.source_node_ids = Some(vec![node1_id.clone()]);
        let results = storage.query_edges(&query).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].source_id, node1_id);
        assert_eq!(results[0].target_id, node2_id);
        // Outgoing traversal from node1 reaches node2.
        let neighbors = storage
            .get_neighbors(&node1_id, EdgeDirection::Outgoing)
            .await
            .unwrap();
        assert_eq!(neighbors.len(), 1);
        assert_eq!(neighbors[0].id, node2_id);
    }

    /// Verifies counts and the per-type distribution reported by get_stats.
    #[tokio::test]
    async fn test_storage_stats() {
        let storage = InMemoryGraphStorage::new();
        let nodes = vec![
            GraphNode::new("Node 1", NodeType::Concept),
            GraphNode::new("Node 2", NodeType::Entity("Person".to_string())),
            GraphNode::new("Node 3", NodeType::Document),
        ];
        storage.store_nodes(&nodes).await.unwrap();
        let stats = storage.get_stats().await.unwrap();
        assert_eq!(stats.total_nodes, 3);
        assert_eq!(stats.total_edges, 0);
        assert!(stats.node_type_distribution.contains_key("Concept"));
        assert!(stats.node_type_distribution.contains_key("Entity(Person)"));
    }
}