rexis_rag/graph_retrieval/
storage.rs

1//! # Graph Storage and Indexing
2//!
3//! Efficient storage and indexing system for knowledge graphs with support for
4//! various storage backends and optimized query operations.
5
6use super::{EdgeType, GraphEdge, GraphError, GraphNode, KnowledgeGraph, NodeType};
7use crate::RragResult;
8use async_trait::async_trait;
9use serde::{Deserialize, Serialize};
10use std::collections::{BTreeMap, HashMap, HashSet};
11
12/// Graph storage trait for different storage backends
13#[async_trait]
14pub trait GraphStorage: Send + Sync {
15    /// Store a complete knowledge graph
16    async fn store_graph(&self, graph: &KnowledgeGraph) -> RragResult<()>;
17
18    /// Load a complete knowledge graph
19    async fn load_graph(&self, graph_id: &str) -> RragResult<KnowledgeGraph>;
20
21    /// Store individual nodes
22    async fn store_nodes(&self, nodes: &[GraphNode]) -> RragResult<()>;
23
24    /// Store individual edges
25    async fn store_edges(&self, edges: &[GraphEdge]) -> RragResult<()>;
26
27    /// Query nodes by criteria
28    async fn query_nodes(&self, query: &NodeQuery) -> RragResult<Vec<GraphNode>>;
29
30    /// Query edges by criteria
31    async fn query_edges(&self, query: &EdgeQuery) -> RragResult<Vec<GraphEdge>>;
32
33    /// Get node by ID
34    async fn get_node(&self, node_id: &str) -> RragResult<Option<GraphNode>>;
35
36    /// Get edge by ID
37    async fn get_edge(&self, edge_id: &str) -> RragResult<Option<GraphEdge>>;
38
39    /// Get neighbors of a node
40    async fn get_neighbors(
41        &self,
42        node_id: &str,
43        direction: EdgeDirection,
44    ) -> RragResult<Vec<GraphNode>>;
45
46    /// Delete nodes
47    async fn delete_nodes(&self, node_ids: &[String]) -> RragResult<()>;
48
49    /// Delete edges
50    async fn delete_edges(&self, edge_ids: &[String]) -> RragResult<()>;
51
52    /// Clear all data
53    async fn clear(&self) -> RragResult<()>;
54
55    /// Get storage statistics
56    async fn get_stats(&self) -> RragResult<StorageStats>;
57}
58
59/// In-memory graph storage implementation
60pub struct InMemoryGraphStorage {
61    /// Stored graphs
62    graphs: tokio::sync::RwLock<HashMap<String, KnowledgeGraph>>,
63
64    /// Global node index
65    node_index: tokio::sync::RwLock<GraphIndex<GraphNode>>,
66
67    /// Global edge index
68    edge_index: tokio::sync::RwLock<GraphIndex<GraphEdge>>,
69
70    /// Configuration
71    config: GraphStorageConfig,
72}
73
/// Generic in-memory index offering lookup by ID, by exact field value, and by
/// text token.
#[derive(Debug, Clone)]
pub struct GraphIndex<T> {
    /// Primary storage: item ID -> item.
    by_id: HashMap<String, T>,

    /// Exact-match secondary indices: field name -> field value -> IDs of the
    /// items carrying that value.
    indices: HashMap<String, BTreeMap<String, HashSet<String>>>,

    /// Minimal full-text index: lowercase token -> IDs of the items whose
    /// indexed field values contain that token.
    text_index: HashMap<String, HashSet<String>>,
}
86
87/// Node query parameters
88#[derive(Debug, Clone)]
89pub struct NodeQuery {
90    /// Node IDs to match
91    pub node_ids: Option<Vec<String>>,
92
93    /// Node types to match
94    pub node_types: Option<Vec<NodeType>>,
95
96    /// Text search in labels
97    pub text_search: Option<String>,
98
99    /// Attribute filters
100    pub attribute_filters: HashMap<String, serde_json::Value>,
101
102    /// Source document filters
103    pub source_document_filters: Option<Vec<String>>,
104
105    /// Confidence threshold
106    pub min_confidence: Option<f32>,
107
108    /// Limit number of results
109    pub limit: Option<usize>,
110
111    /// Offset for pagination
112    pub offset: Option<usize>,
113}
114
115/// Edge query parameters
116#[derive(Debug, Clone)]
117pub struct EdgeQuery {
118    /// Edge IDs to match
119    pub edge_ids: Option<Vec<String>>,
120
121    /// Source node IDs
122    pub source_node_ids: Option<Vec<String>>,
123
124    /// Target node IDs
125    pub target_node_ids: Option<Vec<String>>,
126
127    /// Edge types to match
128    pub edge_types: Option<Vec<EdgeType>>,
129
130    /// Text search in labels
131    pub text_search: Option<String>,
132
133    /// Attribute filters
134    pub attribute_filters: HashMap<String, serde_json::Value>,
135
136    /// Weight range
137    pub weight_range: Option<(f32, f32)>,
138
139    /// Confidence threshold
140    pub min_confidence: Option<f32>,
141
142    /// Limit number of results
143    pub limit: Option<usize>,
144
145    /// Offset for pagination
146    pub offset: Option<usize>,
147}
148
/// Which edges to follow when collecting a node's neighbors.
///
/// Besides `Copy`, the enum derives `PartialEq`, `Eq` and `Hash`, so callers
/// can compare directions or use them as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EdgeDirection {
    /// Follow edges whose source is the node; neighbors are their targets.
    Outgoing,

    /// Follow edges whose target is the node; neighbors are their sources.
    Incoming,

    /// Follow edges in either direction.
    Both,
}
161
162/// Graph query for complex graph operations
163#[derive(Debug, Clone)]
164pub struct GraphQuery {
165    /// Starting nodes for the query
166    pub start_nodes: Vec<String>,
167
168    /// Query pattern (simplified graph pattern matching)
169    pub pattern: GraphPattern,
170
171    /// Maximum traversal depth
172    pub max_depth: usize,
173
174    /// Result limit
175    pub limit: Option<usize>,
176}
177
178/// Graph pattern for pattern matching
179#[derive(Debug, Clone)]
180pub struct GraphPattern {
181    /// Node patterns
182    pub node_patterns: Vec<NodePattern>,
183
184    /// Edge patterns
185    pub edge_patterns: Vec<EdgePattern>,
186
187    /// Pattern constraints
188    pub constraints: Vec<PatternConstraint>,
189}
190
191/// Node pattern for pattern matching
192#[derive(Debug, Clone)]
193pub struct NodePattern {
194    /// Pattern variable name
195    pub variable: String,
196
197    /// Node type constraint
198    pub node_type: Option<NodeType>,
199
200    /// Label constraint
201    pub label_pattern: Option<String>,
202
203    /// Attribute constraints
204    pub attribute_constraints: HashMap<String, serde_json::Value>,
205}
206
207/// Edge pattern for pattern matching
208#[derive(Debug, Clone)]
209pub struct EdgePattern {
210    /// Source node variable
211    pub source_variable: String,
212
213    /// Target node variable
214    pub target_variable: String,
215
216    /// Edge type constraint
217    pub edge_type: Option<EdgeType>,
218
219    /// Label constraint
220    pub label_pattern: Option<String>,
221
222    /// Attribute constraints
223    pub attribute_constraints: HashMap<String, serde_json::Value>,
224}
225
226/// Pattern constraint
227#[derive(Debug, Clone)]
228pub enum PatternConstraint {
229    /// Distance constraint between two nodes
230    Distance {
231        var1: String,
232        var2: String,
233        max_distance: usize,
234    },
235
236    /// Path constraint
237    Path {
238        start_var: String,
239        end_var: String,
240        path_type: PathType,
241    },
242
243    /// Count constraint
244    Count {
245        variable: String,
246        min_count: usize,
247        max_count: Option<usize>,
248    },
249}
250
251/// Path type for path constraints
252#[derive(Debug, Clone)]
253pub enum PathType {
254    /// Any path
255    Any,
256
257    /// Shortest path
258    Shortest,
259
260    /// Path with specific edge types
261    EdgeTypes(Vec<EdgeType>),
262}
263
/// Outcome of executing a [`GraphQuery`].
#[derive(Debug, Clone)]
pub struct GraphQueryResult {
    /// One map per match, from pattern variable name to the bound value
    /// (presumably a node or edge ID — confirm with the executor).
    pub bindings: Vec<HashMap<String, String>>,

    /// Wall-clock execution time in milliseconds.
    pub execution_time_ms: u64,

    /// How many nodes were inspected.
    pub nodes_examined: usize,

    /// How many edges were inspected.
    pub edges_examined: usize,
}
279
280/// Storage statistics
281#[derive(Debug, Clone, Serialize, Deserialize)]
282pub struct StorageStats {
283    /// Number of stored graphs
284    pub graph_count: usize,
285
286    /// Total number of nodes
287    pub total_nodes: usize,
288
289    /// Total number of edges
290    pub total_edges: usize,
291
292    /// Storage size in bytes (estimate)
293    pub storage_size_bytes: usize,
294
295    /// Index size in bytes (estimate)
296    pub index_size_bytes: usize,
297
298    /// Node type distribution
299    pub node_type_distribution: HashMap<String, usize>,
300
301    /// Edge type distribution
302    pub edge_type_distribution: HashMap<String, usize>,
303
304    /// Last update timestamp
305    pub last_updated: chrono::DateTime<chrono::Utc>,
306}
307
308/// Graph storage configuration
309#[derive(Debug, Clone, Serialize, Deserialize)]
310pub struct GraphStorageConfig {
311    /// Enable full-text indexing
312    pub enable_text_indexing: bool,
313
314    /// Enable attribute indexing
315    pub enable_attribute_indexing: bool,
316
317    /// Maximum cache size for frequently accessed items
318    pub max_cache_size: usize,
319
320    /// Batch size for bulk operations
321    pub batch_size: usize,
322
323    /// Enable compression for storage
324    pub enable_compression: bool,
325}
326
327impl Default for GraphStorageConfig {
328    fn default() -> Self {
329        Self {
330            enable_text_indexing: true,
331            enable_attribute_indexing: true,
332            max_cache_size: 10_000,
333            batch_size: 1_000,
334            enable_compression: false,
335        }
336    }
337}
338
339impl<T> GraphIndex<T>
340where
341    T: Clone + Send + Sync,
342{
343    /// Create a new graph index
344    pub fn new() -> Self {
345        Self {
346            by_id: HashMap::new(),
347            indices: HashMap::new(),
348            text_index: HashMap::new(),
349        }
350    }
351
352    /// Add an item to the index
353    pub fn add_item(&mut self, id: String, item: T, indexable_fields: &HashMap<String, String>) {
354        // Add to primary index
355        self.by_id.insert(id.clone(), item);
356
357        // Add to secondary indices
358        for (field_name, field_value) in indexable_fields {
359            self.indices
360                .entry(field_name.clone())
361                .or_insert_with(BTreeMap::new)
362                .entry(field_value.clone())
363                .or_insert_with(HashSet::new)
364                .insert(id.clone());
365        }
366
367        // Add to text index (simple tokenization)
368        for (_, field_value) in indexable_fields {
369            let tokens = Self::tokenize(field_value);
370            for token in tokens {
371                self.text_index
372                    .entry(token.to_lowercase())
373                    .or_insert_with(HashSet::new)
374                    .insert(id.clone());
375            }
376        }
377    }
378
379    /// Remove an item from the index
380    pub fn remove_item(&mut self, id: &str) {
381        self.by_id.remove(id);
382
383        // Remove from secondary indices
384        for index in self.indices.values_mut() {
385            for ids in index.values_mut() {
386                ids.remove(id);
387            }
388        }
389
390        // Remove from text index
391        for ids in self.text_index.values_mut() {
392            ids.remove(id);
393        }
394    }
395
396    /// Get item by ID
397    pub fn get_by_id(&self, id: &str) -> Option<&T> {
398        self.by_id.get(id)
399    }
400
401    /// Find items by field value
402    pub fn find_by_field(&self, field_name: &str, field_value: &str) -> Vec<&T> {
403        if let Some(index) = self.indices.get(field_name) {
404            if let Some(ids) = index.get(field_value) {
405                return ids.iter().filter_map(|id| self.by_id.get(id)).collect();
406            }
407        }
408        Vec::new()
409    }
410
411    /// Text search
412    pub fn text_search(&self, query: &str) -> Vec<&T> {
413        let tokens = Self::tokenize(query);
414        let mut matching_ids = HashSet::new();
415
416        for (i, token) in tokens.iter().enumerate() {
417            if let Some(ids) = self.text_index.get(&token.to_lowercase()) {
418                if i == 0 {
419                    matching_ids.extend(ids.clone());
420                } else {
421                    matching_ids.retain(|id| ids.contains(id));
422                }
423            }
424        }
425
426        matching_ids
427            .iter()
428            .filter_map(|id| self.by_id.get(id))
429            .collect()
430    }
431
432    /// Get all items
433    pub fn get_all(&self) -> Vec<&T> {
434        self.by_id.values().collect()
435    }
436
437    /// Get statistics
438    pub fn stats(&self) -> HashMap<String, usize> {
439        let mut stats = HashMap::new();
440        stats.insert("total_items".to_string(), self.by_id.len());
441        stats.insert("indices_count".to_string(), self.indices.len());
442        stats.insert("text_terms".to_string(), self.text_index.len());
443        stats
444    }
445
446    /// Clear all data
447    pub fn clear(&mut self) {
448        self.by_id.clear();
449        self.indices.clear();
450        self.text_index.clear();
451    }
452
453    /// Simple tokenization
454    fn tokenize(text: &str) -> Vec<String> {
455        text.split_whitespace()
456            .map(|s| s.trim_matches(|c: char| !c.is_alphanumeric()))
457            .filter(|s| !s.is_empty())
458            .map(|s| s.to_string())
459            .collect()
460    }
461}
462
463impl Default for NodeQuery {
464    fn default() -> Self {
465        Self {
466            node_ids: None,
467            node_types: None,
468            text_search: None,
469            attribute_filters: HashMap::new(),
470            source_document_filters: None,
471            min_confidence: None,
472            limit: None,
473            offset: None,
474        }
475    }
476}
477
478impl Default for EdgeQuery {
479    fn default() -> Self {
480        Self {
481            edge_ids: None,
482            source_node_ids: None,
483            target_node_ids: None,
484            edge_types: None,
485            text_search: None,
486            attribute_filters: HashMap::new(),
487            weight_range: None,
488            min_confidence: None,
489            limit: None,
490            offset: None,
491        }
492    }
493}
494
495impl InMemoryGraphStorage {
496    /// Create a new in-memory graph storage
497    pub fn new() -> Self {
498        Self::with_config(GraphStorageConfig::default())
499    }
500
501    /// Create with custom configuration
502    pub fn with_config(config: GraphStorageConfig) -> Self {
503        Self {
504            graphs: tokio::sync::RwLock::new(HashMap::new()),
505            node_index: tokio::sync::RwLock::new(GraphIndex::new()),
506            edge_index: tokio::sync::RwLock::new(GraphIndex::new()),
507            config,
508        }
509    }
510
511    /// Create indexable fields for a node
512    fn create_node_indexable_fields(node: &GraphNode) -> HashMap<String, String> {
513        let mut fields = HashMap::new();
514
515        fields.insert("label".to_string(), node.label.clone());
516        fields.insert(
517            "node_type".to_string(),
518            Self::node_type_string(&node.node_type),
519        );
520        fields.insert("confidence".to_string(), node.confidence.to_string());
521
522        // Add attribute fields
523        for (key, value) in &node.attributes {
524            if let Some(string_value) = value.as_str() {
525                fields.insert(format!("attr_{}", key), string_value.to_string());
526            }
527        }
528
529        fields
530    }
531
532    /// Create indexable fields for an edge
533    fn create_edge_indexable_fields(edge: &GraphEdge) -> HashMap<String, String> {
534        let mut fields = HashMap::new();
535
536        fields.insert("label".to_string(), edge.label.clone());
537        fields.insert(
538            "edge_type".to_string(),
539            Self::edge_type_string(&edge.edge_type),
540        );
541        fields.insert("source_id".to_string(), edge.source_id.clone());
542        fields.insert("target_id".to_string(), edge.target_id.clone());
543        fields.insert("weight".to_string(), edge.weight.to_string());
544        fields.insert("confidence".to_string(), edge.confidence.to_string());
545
546        // Add attribute fields
547        for (key, value) in &edge.attributes {
548            if let Some(string_value) = value.as_str() {
549                fields.insert(format!("attr_{}", key), string_value.to_string());
550            }
551        }
552
553        fields
554    }
555
556    /// Convert node type to string for indexing
557    fn node_type_string(node_type: &NodeType) -> String {
558        match node_type {
559            NodeType::Entity(entity_type) => format!("Entity({})", entity_type),
560            NodeType::Concept => "Concept".to_string(),
561            NodeType::Document => "Document".to_string(),
562            NodeType::DocumentChunk => "DocumentChunk".to_string(),
563            NodeType::Keyword => "Keyword".to_string(),
564            NodeType::Custom(custom_type) => format!("Custom({})", custom_type),
565        }
566    }
567
568    /// Convert edge type to string for indexing
569    fn edge_type_string(edge_type: &EdgeType) -> String {
570        match edge_type {
571            EdgeType::Semantic(relation) => format!("Semantic({})", relation),
572            EdgeType::Hierarchical => "Hierarchical".to_string(),
573            EdgeType::Contains => "Contains".to_string(),
574            EdgeType::References => "References".to_string(),
575            EdgeType::CoOccurs => "CoOccurs".to_string(),
576            EdgeType::Similar => "Similar".to_string(),
577            EdgeType::Custom(custom_type) => format!("Custom({})", custom_type),
578        }
579    }
580
581    /// Apply filters to node query results
582    fn apply_node_filters(&self, nodes: Vec<&GraphNode>, query: &NodeQuery) -> Vec<GraphNode> {
583        let mut result: Vec<_> = nodes.into_iter().cloned().collect();
584
585        // Apply node type filter
586        if let Some(node_types) = &query.node_types {
587            result.retain(|node| node_types.contains(&node.node_type));
588        }
589
590        // Apply confidence filter
591        if let Some(min_confidence) = query.min_confidence {
592            result.retain(|node| node.confidence >= min_confidence);
593        }
594
595        // Apply attribute filters
596        for (attr_key, attr_value) in &query.attribute_filters {
597            result.retain(|node| {
598                node.attributes
599                    .get(attr_key)
600                    .map_or(false, |v| v == attr_value)
601            });
602        }
603
604        // Apply source document filter
605        if let Some(source_docs) = &query.source_document_filters {
606            result.retain(|node| {
607                node.source_documents
608                    .iter()
609                    .any(|doc| source_docs.contains(doc))
610            });
611        }
612
613        // Apply pagination
614        if let Some(offset) = query.offset {
615            if offset < result.len() {
616                result.drain(0..offset);
617            } else {
618                result.clear();
619            }
620        }
621
622        if let Some(limit) = query.limit {
623            result.truncate(limit);
624        }
625
626        result
627    }
628
629    /// Apply filters to edge query results
630    fn apply_edge_filters(&self, edges: Vec<&GraphEdge>, query: &EdgeQuery) -> Vec<GraphEdge> {
631        let mut result: Vec<_> = edges.into_iter().cloned().collect();
632
633        // Apply source/target node filters
634        if let Some(source_ids) = &query.source_node_ids {
635            result.retain(|edge| source_ids.contains(&edge.source_id));
636        }
637
638        if let Some(target_ids) = &query.target_node_ids {
639            result.retain(|edge| target_ids.contains(&edge.target_id));
640        }
641
642        // Apply edge type filter
643        if let Some(edge_types) = &query.edge_types {
644            result.retain(|edge| edge_types.contains(&edge.edge_type));
645        }
646
647        // Apply weight range filter
648        if let Some((min_weight, max_weight)) = query.weight_range {
649            result.retain(|edge| edge.weight >= min_weight && edge.weight <= max_weight);
650        }
651
652        // Apply confidence filter
653        if let Some(min_confidence) = query.min_confidence {
654            result.retain(|edge| edge.confidence >= min_confidence);
655        }
656
657        // Apply attribute filters
658        for (attr_key, attr_value) in &query.attribute_filters {
659            result.retain(|edge| {
660                edge.attributes
661                    .get(attr_key)
662                    .map_or(false, |v| v == attr_value)
663            });
664        }
665
666        // Apply pagination
667        if let Some(offset) = query.offset {
668            if offset < result.len() {
669                result.drain(0..offset);
670            } else {
671                result.clear();
672            }
673        }
674
675        if let Some(limit) = query.limit {
676            result.truncate(limit);
677        }
678
679        result
680    }
681}
682
683#[async_trait]
684impl GraphStorage for InMemoryGraphStorage {
685    async fn store_graph(&self, graph: &KnowledgeGraph) -> RragResult<()> {
686        // Store graph
687        let graph_id = uuid::Uuid::new_v4().to_string();
688        self.graphs.write().await.insert(graph_id, graph.clone());
689
690        // Update indices
691        self.store_nodes(&graph.nodes.values().cloned().collect::<Vec<_>>())
692            .await?;
693        self.store_edges(&graph.edges.values().cloned().collect::<Vec<_>>())
694            .await?;
695
696        Ok(())
697    }
698
699    async fn load_graph(&self, graph_id: &str) -> RragResult<KnowledgeGraph> {
700        self.graphs
701            .read()
702            .await
703            .get(graph_id)
704            .cloned()
705            .ok_or_else(|| {
706                GraphError::Storage {
707                    operation: "load_graph".to_string(),
708                    message: format!("Graph '{}' not found", graph_id),
709                }
710                .into()
711            })
712    }
713
714    async fn store_nodes(&self, nodes: &[GraphNode]) -> RragResult<()> {
715        let mut node_index = self.node_index.write().await;
716
717        for node in nodes {
718            let indexable_fields = Self::create_node_indexable_fields(node);
719            node_index.add_item(node.id.clone(), node.clone(), &indexable_fields);
720        }
721
722        Ok(())
723    }
724
725    async fn store_edges(&self, edges: &[GraphEdge]) -> RragResult<()> {
726        let mut edge_index = self.edge_index.write().await;
727
728        for edge in edges {
729            let indexable_fields = Self::create_edge_indexable_fields(edge);
730            edge_index.add_item(edge.id.clone(), edge.clone(), &indexable_fields);
731        }
732
733        Ok(())
734    }
735
736    async fn query_nodes(&self, query: &NodeQuery) -> RragResult<Vec<GraphNode>> {
737        let node_index = self.node_index.read().await;
738        let mut candidates = Vec::new();
739
740        if let Some(node_ids) = &query.node_ids {
741            // Query by specific IDs
742            for node_id in node_ids {
743                if let Some(node) = node_index.get_by_id(node_id) {
744                    candidates.push(node);
745                }
746            }
747        } else if let Some(text_query) = &query.text_search {
748            // Text search
749            candidates = node_index.text_search(text_query);
750        } else {
751            // Get all nodes
752            candidates = node_index.get_all();
753        }
754
755        Ok(self.apply_node_filters(candidates, query))
756    }
757
758    async fn query_edges(&self, query: &EdgeQuery) -> RragResult<Vec<GraphEdge>> {
759        let edge_index = self.edge_index.read().await;
760        let mut candidates = Vec::new();
761
762        if let Some(edge_ids) = &query.edge_ids {
763            // Query by specific IDs
764            for edge_id in edge_ids {
765                if let Some(edge) = edge_index.get_by_id(edge_id) {
766                    candidates.push(edge);
767                }
768            }
769        } else if let Some(text_query) = &query.text_search {
770            // Text search
771            candidates = edge_index.text_search(text_query);
772        } else {
773            // Get all edges
774            candidates = edge_index.get_all();
775        }
776
777        Ok(self.apply_edge_filters(candidates, query))
778    }
779
780    async fn get_node(&self, node_id: &str) -> RragResult<Option<GraphNode>> {
781        let node_index = self.node_index.read().await;
782        Ok(node_index.get_by_id(node_id).cloned())
783    }
784
785    async fn get_edge(&self, edge_id: &str) -> RragResult<Option<GraphEdge>> {
786        let edge_index = self.edge_index.read().await;
787        Ok(edge_index.get_by_id(edge_id).cloned())
788    }
789
790    async fn get_neighbors(
791        &self,
792        node_id: &str,
793        direction: EdgeDirection,
794    ) -> RragResult<Vec<GraphNode>> {
795        let edge_index = self.edge_index.read().await;
796        let node_index = self.node_index.read().await;
797        let mut neighbor_ids = HashSet::new();
798
799        match direction {
800            EdgeDirection::Outgoing => {
801                let outgoing_edges = edge_index.find_by_field("source_id", node_id);
802                for edge in outgoing_edges {
803                    neighbor_ids.insert(&edge.target_id);
804                }
805            }
806            EdgeDirection::Incoming => {
807                let incoming_edges = edge_index.find_by_field("target_id", node_id);
808                for edge in incoming_edges {
809                    neighbor_ids.insert(&edge.source_id);
810                }
811            }
812            EdgeDirection::Both => {
813                let outgoing_edges = edge_index.find_by_field("source_id", node_id);
814                for edge in outgoing_edges {
815                    neighbor_ids.insert(&edge.target_id);
816                }
817                let incoming_edges = edge_index.find_by_field("target_id", node_id);
818                for edge in incoming_edges {
819                    neighbor_ids.insert(&edge.source_id);
820                }
821            }
822        }
823
824        let neighbors = neighbor_ids
825            .into_iter()
826            .filter_map(|id| node_index.get_by_id(id))
827            .cloned()
828            .collect();
829
830        Ok(neighbors)
831    }
832
833    async fn delete_nodes(&self, node_ids: &[String]) -> RragResult<()> {
834        let mut node_index = self.node_index.write().await;
835
836        for node_id in node_ids {
837            node_index.remove_item(node_id);
838        }
839
840        Ok(())
841    }
842
843    async fn delete_edges(&self, edge_ids: &[String]) -> RragResult<()> {
844        let mut edge_index = self.edge_index.write().await;
845
846        for edge_id in edge_ids {
847            edge_index.remove_item(edge_id);
848        }
849
850        Ok(())
851    }
852
853    async fn clear(&self) -> RragResult<()> {
854        self.graphs.write().await.clear();
855        self.node_index.write().await.clear();
856        self.edge_index.write().await.clear();
857        Ok(())
858    }
859
860    async fn get_stats(&self) -> RragResult<StorageStats> {
861        let graphs = self.graphs.read().await;
862        let node_index = self.node_index.read().await;
863        let edge_index = self.edge_index.read().await;
864
865        let graph_count = graphs.len();
866        let total_nodes = node_index.by_id.len();
867        let total_edges = edge_index.by_id.len();
868
869        // Calculate node type distribution
870        let mut node_type_distribution = HashMap::new();
871        for node in node_index.by_id.values() {
872            let type_key = Self::node_type_string(&node.node_type);
873            *node_type_distribution.entry(type_key).or_insert(0) += 1;
874        }
875
876        // Calculate edge type distribution
877        let mut edge_type_distribution = HashMap::new();
878        for edge in edge_index.by_id.values() {
879            let type_key = Self::edge_type_string(&edge.edge_type);
880            *edge_type_distribution.entry(type_key).or_insert(0) += 1;
881        }
882
883        // Rough size estimates
884        let storage_size_bytes = (total_nodes + total_edges) * 1000; // Rough estimate
885        let index_size_bytes = (node_index.indices.len() + edge_index.indices.len()) * 100;
886
887        Ok(StorageStats {
888            graph_count,
889            total_nodes,
890            total_edges,
891            storage_size_bytes,
892            index_size_bytes,
893            node_type_distribution,
894            edge_type_distribution,
895            last_updated: chrono::Utc::now(),
896        })
897    }
898}
899
900impl Default for InMemoryGraphStorage {
901    fn default() -> Self {
902        Self::new()
903    }
904}
905
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph_retrieval::{EdgeType, NodeType};

    /// Nodes can be stored, found by text search, and fetched by ID.
    #[tokio::test]
    async fn test_in_memory_graph_storage() {
        let storage = InMemoryGraphStorage::new();

        let node1 = GraphNode::new("Test Node 1", NodeType::Concept);
        let node2 = GraphNode::new("Test Node 2", NodeType::Entity("Person".to_string()));

        let node1_id = node1.id.clone();
        let node2_id = node2.id.clone();

        storage
            .store_nodes(&[node1.clone(), node2.clone()])
            .await
            .unwrap();

        let query = NodeQuery {
            text_search: Some("Test".to_string()),
            ..Default::default()
        };

        let results = storage.query_nodes(&query).await.unwrap();
        assert_eq!(results.len(), 2);
        assert!(results.iter().any(|node| node.id == node1_id));
        assert!(results.iter().any(|node| node.id == node2_id));

        let retrieved_node = storage.get_node(&node1_id).await.unwrap();
        assert!(retrieved_node.is_some());
        assert_eq!(retrieved_node.unwrap().label, "Test Node 1");
    }

    /// Edges can be stored, filtered by source, fetched by ID, and followed.
    #[tokio::test]
    async fn test_edge_storage_and_queries() {
        let storage = InMemoryGraphStorage::new();

        let node1 = GraphNode::new("Node 1", NodeType::Concept);
        let node2 = GraphNode::new("Node 2", NodeType::Concept);

        let node1_id = node1.id.clone();
        let node2_id = node2.id.clone();

        storage.store_nodes(&[node1, node2]).await.unwrap();

        let edge = GraphEdge::new(
            node1_id.clone(),
            node2_id.clone(),
            "test_relation",
            EdgeType::Similar,
        );
        let edge_id = edge.id.clone();

        storage.store_edges(&[edge]).await.unwrap();

        let query = EdgeQuery {
            source_node_ids: Some(vec![node1_id.clone()]),
            ..Default::default()
        };

        let results = storage.query_edges(&query).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].source_id, node1_id);
        assert_eq!(results[0].target_id, node2_id);

        let retrieved_edge = storage.get_edge(&edge_id).await.unwrap();
        assert!(retrieved_edge.is_some());
        assert_eq!(retrieved_edge.unwrap().target_id, node2_id);

        let neighbors = storage
            .get_neighbors(&node1_id, EdgeDirection::Outgoing)
            .await
            .unwrap();
        assert_eq!(neighbors.len(), 1);
        assert_eq!(neighbors[0].id, node2_id);
    }

    /// Stats report counts and a per-type distribution.
    #[tokio::test]
    async fn test_storage_stats() {
        let storage = InMemoryGraphStorage::new();

        let nodes = vec![
            GraphNode::new("Node 1", NodeType::Concept),
            GraphNode::new("Node 2", NodeType::Entity("Person".to_string())),
            GraphNode::new("Node 3", NodeType::Document),
        ];

        storage.store_nodes(&nodes).await.unwrap();

        let stats = storage.get_stats().await.unwrap();
        assert_eq!(stats.graph_count, 0);
        assert_eq!(stats.total_nodes, 3);
        assert_eq!(stats.total_edges, 0);
        assert!(stats.node_type_distribution.contains_key("Concept"));
        assert!(stats.node_type_distribution.contains_key("Entity(Person)"));
        assert!(stats.node_type_distribution.contains_key("Document"));
    }

    /// `limit`/`offset` bound results, and delete/clear remove data.
    #[tokio::test]
    async fn test_pagination_and_deletion() {
        let storage = InMemoryGraphStorage::new();

        let nodes = vec![
            GraphNode::new("Alpha", NodeType::Concept),
            GraphNode::new("Beta", NodeType::Concept),
            GraphNode::new("Gamma", NodeType::Concept),
        ];
        let ids: Vec<String> = nodes.iter().map(|node| node.id.clone()).collect();
        storage.store_nodes(&nodes).await.unwrap();

        let limited = storage
            .query_nodes(&NodeQuery {
                limit: Some(2),
                ..Default::default()
            })
            .await
            .unwrap();
        assert_eq!(limited.len(), 2);

        let past_end = storage
            .query_nodes(&NodeQuery {
                offset: Some(10),
                ..Default::default()
            })
            .await
            .unwrap();
        assert!(past_end.is_empty());

        storage.delete_nodes(&ids[..1]).await.unwrap();
        assert!(storage.get_node(&ids[0]).await.unwrap().is_none());
        assert_eq!(storage.get_stats().await.unwrap().total_nodes, 2);

        storage.clear().await.unwrap();
        assert_eq!(storage.get_stats().await.unwrap().total_nodes, 0);
    }
}