1use super::{EdgeType, GraphEdge, GraphError, GraphNode, KnowledgeGraph, NodeType};
7use crate::RragResult;
8use async_trait::async_trait;
9use serde::{Deserialize, Serialize};
10use std::collections::{BTreeMap, HashMap, HashSet};
11
/// Backend-agnostic persistence and lookup interface for knowledge graphs.
///
/// Implementations must be `Send + Sync` so one storage instance can be shared
/// across async tasks. Every operation is async and reports failures through
/// [`RragResult`].
#[async_trait]
pub trait GraphStorage: Send + Sync {
    /// Persists a whole knowledge graph together with its nodes and edges.
    ///
    /// NOTE(review): this method does not return the identifier the graph was
    /// stored under, so callers can only use [`GraphStorage::load_graph`] if
    /// they can derive that id some other way — confirm the intended contract.
    async fn store_graph(&self, graph: &KnowledgeGraph) -> RragResult<()>;

    /// Loads a previously stored graph by its identifier.
    async fn load_graph(&self, graph_id: &str) -> RragResult<KnowledgeGraph>;

    /// Stores a batch of nodes.
    async fn store_nodes(&self, nodes: &[GraphNode]) -> RragResult<()>;

    /// Stores a batch of edges.
    async fn store_edges(&self, edges: &[GraphEdge]) -> RragResult<()>;

    /// Returns the nodes selected by `query`.
    async fn query_nodes(&self, query: &NodeQuery) -> RragResult<Vec<GraphNode>>;

    /// Returns the edges selected by `query`.
    async fn query_edges(&self, query: &EdgeQuery) -> RragResult<Vec<GraphEdge>>;

    /// Looks up a single node by id; `Ok(None)` when it is not stored.
    async fn get_node(&self, node_id: &str) -> RragResult<Option<GraphNode>>;

    /// Looks up a single edge by id; `Ok(None)` when it is not stored.
    async fn get_edge(&self, edge_id: &str) -> RragResult<Option<GraphEdge>>;

    /// Returns the nodes adjacent to `node_id`, following edges in `direction`.
    async fn get_neighbors(
        &self,
        node_id: &str,
        direction: EdgeDirection,
    ) -> RragResult<Vec<GraphNode>>;

    /// Removes the nodes with the given ids.
    async fn delete_nodes(&self, node_ids: &[String]) -> RragResult<()>;

    /// Removes the edges with the given ids.
    async fn delete_edges(&self, edge_ids: &[String]) -> RragResult<()>;

    /// Removes all stored data.
    async fn clear(&self) -> RragResult<()>;

    /// Returns summary statistics about the stored data.
    async fn get_stats(&self) -> RragResult<StorageStats>;
}
58
/// [`GraphStorage`] implementation that keeps everything in process memory.
///
/// Graphs, nodes and edges each sit behind their own `tokio::sync::RwLock`,
/// so the three collections are locked independently.
pub struct InMemoryGraphStorage {
    /// Whole graphs, keyed by the id generated in `store_graph`.
    graphs: tokio::sync::RwLock<HashMap<String, KnowledgeGraph>>,

    /// Primary store plus field/text indexes for individual nodes.
    node_index: tokio::sync::RwLock<GraphIndex<GraphNode>>,

    /// Primary store plus field/text indexes for individual edges.
    edge_index: tokio::sync::RwLock<GraphIndex<GraphEdge>>,

    /// Storage configuration.
    ///
    /// NOTE(review): nothing in this file reads this field yet, so flags such
    /// as `enable_text_indexing` currently have no effect.
    config: GraphStorageConfig,
}
73
/// In-memory store of items keyed by id, with exact-match field indexes and a
/// lowercase token index for simple text search.
#[derive(Debug, Clone)]
pub struct GraphIndex<T> {
    /// Primary storage: item id -> item.
    by_id: HashMap<String, T>,

    /// Exact-match indexes: field name -> field value -> ids having that value.
    indices: HashMap<String, BTreeMap<String, HashSet<String>>>,

    /// Inverted text index: lowercase token -> ids whose indexed fields contain it.
    text_index: HashMap<String, HashSet<String>>,
}
86
/// Selection, filter and pagination options for [`GraphStorage::query_nodes`].
///
/// `None` / empty fields impose no constraint. The per-field descriptions
/// below follow how `InMemoryGraphStorage` applies them.
#[derive(Debug, Clone)]
pub struct NodeQuery {
    /// Restrict candidates to these node ids.
    pub node_ids: Option<Vec<String>>,

    /// Keep only nodes whose type is one of these.
    pub node_types: Option<Vec<NodeType>>,

    /// Free-text query matched against the tokens of the indexed node fields.
    pub text_search: Option<String>,

    /// Keep only nodes whose attribute `key` equals the given JSON value; all
    /// entries must match.
    pub attribute_filters: HashMap<String, serde_json::Value>,

    /// Keep only nodes that list at least one of these source documents.
    pub source_document_filters: Option<Vec<String>>,

    /// Minimum node confidence (inclusive).
    pub min_confidence: Option<f32>,

    /// Maximum number of results, applied after `offset`.
    pub limit: Option<usize>,

    /// Number of matching results to skip.
    pub offset: Option<usize>,
}
114
/// Selection, filter and pagination options for [`GraphStorage::query_edges`].
///
/// `None` / empty fields impose no constraint. The per-field descriptions
/// below follow how `InMemoryGraphStorage` applies them.
#[derive(Debug, Clone)]
pub struct EdgeQuery {
    /// Restrict candidates to these edge ids.
    pub edge_ids: Option<Vec<String>>,

    /// Keep only edges whose source node id is one of these.
    pub source_node_ids: Option<Vec<String>>,

    /// Keep only edges whose target node id is one of these.
    pub target_node_ids: Option<Vec<String>>,

    /// Keep only edges whose type is one of these.
    pub edge_types: Option<Vec<EdgeType>>,

    /// Free-text query matched against the tokens of the indexed edge fields.
    pub text_search: Option<String>,

    /// Keep only edges whose attribute `key` equals the given JSON value; all
    /// entries must match.
    pub attribute_filters: HashMap<String, serde_json::Value>,

    /// Inclusive `(min, max)` bounds on the edge weight.
    pub weight_range: Option<(f32, f32)>,

    /// Minimum edge confidence (inclusive).
    pub min_confidence: Option<f32>,

    /// Maximum number of results, applied after `offset`.
    pub limit: Option<usize>,

    /// Number of matching results to skip.
    pub offset: Option<usize>,
}
148
/// Which incident edges to follow when collecting the neighbors of a node.
#[derive(Debug, Clone, Copy)]
pub enum EdgeDirection {
    /// Follow edges whose source is the node; neighbors are the edge targets.
    Outgoing,

    /// Follow edges whose target is the node; neighbors are the edge sources.
    Incoming,

    /// Follow edges in both directions.
    Both,
}
161
/// A pattern-matching query over a graph.
///
/// NOTE(review): nothing in this file executes a `GraphQuery`; the field
/// descriptions are inferred from names — confirm against the executor.
#[derive(Debug, Clone)]
pub struct GraphQuery {
    /// Presumably the ids of the nodes matching starts from.
    pub start_nodes: Vec<String>,

    /// Node/edge pattern to match.
    pub pattern: GraphPattern,

    /// Presumably the maximum traversal depth from `start_nodes`.
    pub max_depth: usize,

    /// Optional cap on the number of results.
    pub limit: Option<usize>,
}
177
/// A graph pattern made of node patterns, edge patterns and extra constraints.
///
/// NOTE(review): no matcher for this type exists in this file; presumably the
/// edge patterns connect node patterns through their variable names.
#[derive(Debug, Clone)]
pub struct GraphPattern {
    /// Node patterns, each introducing a variable.
    pub node_patterns: Vec<NodePattern>,

    /// Edge patterns between pattern variables.
    pub edge_patterns: Vec<EdgePattern>,

    /// Additional constraints over the pattern variables.
    pub constraints: Vec<PatternConstraint>,
}
190
/// Pattern that a single node must satisfy, bound to a named variable.
///
/// NOTE(review): matching semantics are not implemented in this file; the
/// descriptions below are inferred from field names.
#[derive(Debug, Clone)]
pub struct NodePattern {
    /// Variable name the matched node is bound to.
    pub variable: String,

    /// Required node type, if any.
    pub node_type: Option<NodeType>,

    /// Pattern the node label should match, if any (syntax unspecified here).
    pub label_pattern: Option<String>,

    /// Attribute values the node should carry.
    pub attribute_constraints: HashMap<String, serde_json::Value>,
}
206
/// Pattern that an edge between two pattern variables must satisfy.
///
/// NOTE(review): matching semantics are not implemented in this file; the
/// descriptions below are inferred from field names.
#[derive(Debug, Clone)]
pub struct EdgePattern {
    /// Variable bound to the edge's source node.
    pub source_variable: String,

    /// Variable bound to the edge's target node.
    pub target_variable: String,

    /// Required edge type, if any.
    pub edge_type: Option<EdgeType>,

    /// Pattern the edge label should match, if any (syntax unspecified here).
    pub label_pattern: Option<String>,

    /// Attribute values the edge should carry.
    pub attribute_constraints: HashMap<String, serde_json::Value>,
}
225
/// Additional constraint on a graph-pattern match, expressed over pattern
/// variables.
///
/// NOTE(review): no code in this file evaluates these constraints; the
/// per-variant descriptions are inferred from field names — confirm.
#[derive(Debug, Clone)]
pub enum PatternConstraint {
    /// Nodes bound to `var1` and `var2` should be at most `max_distance` apart.
    Distance {
        var1: String,
        var2: String,
        max_distance: usize,
    },

    /// A path of kind `path_type` should connect `start_var` to `end_var`.
    Path {
        start_var: String,
        end_var: String,
        path_type: PathType,
    },

    /// The number of bindings for `variable` should be at least `min_count`
    /// and, when `max_count` is set, at most `max_count`.
    Count {
        variable: String,
        min_count: usize,
        max_count: Option<usize>,
    },
}
250
/// Kind of path required by a path constraint.
///
/// NOTE(review): variant semantics are inferred from names; no path search
/// exists in this file.
#[derive(Debug, Clone)]
pub enum PathType {
    /// Any connecting path.
    Any,

    /// Presumably only a shortest connecting path.
    Shortest,

    /// A path using only the listed edge types.
    EdgeTypes(Vec<EdgeType>),
}
263
/// Outcome of executing a graph pattern query.
///
/// NOTE(review): nothing in this file produces this type; the descriptions
/// below are inferred from field names.
#[derive(Debug, Clone)]
pub struct GraphQueryResult {
    /// One map per match, presumably from variable name to the bound element id.
    pub bindings: Vec<HashMap<String, String>>,

    /// Wall-clock execution time in milliseconds.
    pub execution_time_ms: u64,

    /// Number of nodes inspected while matching.
    pub nodes_examined: usize,

    /// Number of edges inspected while matching.
    pub edges_examined: usize,
}
279
/// Summary statistics reported by [`GraphStorage::get_stats`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StorageStats {
    /// Number of whole graphs stored.
    pub graph_count: usize,

    /// Number of individually stored nodes.
    pub total_nodes: usize,

    /// Number of individually stored edges.
    pub total_edges: usize,

    /// Approximate data size in bytes. The in-memory backend in this file
    /// fills it with a fixed per-item estimate, not a measurement.
    pub storage_size_bytes: usize,

    /// Approximate index size in bytes. The in-memory backend in this file
    /// fills it with a fixed per-field estimate, not a measurement.
    pub index_size_bytes: usize,

    /// Node count per node-type string (e.g. `Concept`, `Entity(Person)`).
    pub node_type_distribution: HashMap<String, usize>,

    /// Edge count per edge-type string (e.g. `Similar`, `Semantic(rel)`).
    pub edge_type_distribution: HashMap<String, usize>,

    /// Time the statistics were computed.
    pub last_updated: chrono::DateTime<chrono::Utc>,
}
307
/// Tuning options for graph storage backends (defaults in the `Default` impl).
///
/// NOTE(review): `InMemoryGraphStorage` stores this value but does not
/// currently consult any of these flags.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphStorageConfig {
    /// Whether item fields should be tokenized into a text index.
    pub enable_text_indexing: bool,

    /// Whether item fields should be indexed for exact-match lookup.
    pub enable_attribute_indexing: bool,

    /// Intended maximum number of cached entries.
    pub max_cache_size: usize,

    /// Intended number of items per write batch.
    pub batch_size: usize,

    /// Whether stored data should be compressed.
    pub enable_compression: bool,
}
326
327impl Default for GraphStorageConfig {
328 fn default() -> Self {
329 Self {
330 enable_text_indexing: true,
331 enable_attribute_indexing: true,
332 max_cache_size: 10_000,
333 batch_size: 1_000,
334 enable_compression: false,
335 }
336 }
337}
338
339impl<T> GraphIndex<T>
340where
341 T: Clone + Send + Sync,
342{
343 pub fn new() -> Self {
345 Self {
346 by_id: HashMap::new(),
347 indices: HashMap::new(),
348 text_index: HashMap::new(),
349 }
350 }
351
352 pub fn add_item(&mut self, id: String, item: T, indexable_fields: &HashMap<String, String>) {
354 self.by_id.insert(id.clone(), item);
356
357 for (field_name, field_value) in indexable_fields {
359 self.indices
360 .entry(field_name.clone())
361 .or_insert_with(BTreeMap::new)
362 .entry(field_value.clone())
363 .or_insert_with(HashSet::new)
364 .insert(id.clone());
365 }
366
367 for (_, field_value) in indexable_fields {
369 let tokens = Self::tokenize(field_value);
370 for token in tokens {
371 self.text_index
372 .entry(token.to_lowercase())
373 .or_insert_with(HashSet::new)
374 .insert(id.clone());
375 }
376 }
377 }
378
379 pub fn remove_item(&mut self, id: &str) {
381 self.by_id.remove(id);
382
383 for index in self.indices.values_mut() {
385 for ids in index.values_mut() {
386 ids.remove(id);
387 }
388 }
389
390 for ids in self.text_index.values_mut() {
392 ids.remove(id);
393 }
394 }
395
396 pub fn get_by_id(&self, id: &str) -> Option<&T> {
398 self.by_id.get(id)
399 }
400
401 pub fn find_by_field(&self, field_name: &str, field_value: &str) -> Vec<&T> {
403 if let Some(index) = self.indices.get(field_name) {
404 if let Some(ids) = index.get(field_value) {
405 return ids.iter().filter_map(|id| self.by_id.get(id)).collect();
406 }
407 }
408 Vec::new()
409 }
410
411 pub fn text_search(&self, query: &str) -> Vec<&T> {
413 let tokens = Self::tokenize(query);
414 let mut matching_ids = HashSet::new();
415
416 for (i, token) in tokens.iter().enumerate() {
417 if let Some(ids) = self.text_index.get(&token.to_lowercase()) {
418 if i == 0 {
419 matching_ids.extend(ids.clone());
420 } else {
421 matching_ids.retain(|id| ids.contains(id));
422 }
423 }
424 }
425
426 matching_ids
427 .iter()
428 .filter_map(|id| self.by_id.get(id))
429 .collect()
430 }
431
432 pub fn get_all(&self) -> Vec<&T> {
434 self.by_id.values().collect()
435 }
436
437 pub fn stats(&self) -> HashMap<String, usize> {
439 let mut stats = HashMap::new();
440 stats.insert("total_items".to_string(), self.by_id.len());
441 stats.insert("indices_count".to_string(), self.indices.len());
442 stats.insert("text_terms".to_string(), self.text_index.len());
443 stats
444 }
445
446 pub fn clear(&mut self) {
448 self.by_id.clear();
449 self.indices.clear();
450 self.text_index.clear();
451 }
452
453 fn tokenize(text: &str) -> Vec<String> {
455 text.split_whitespace()
456 .map(|s| s.trim_matches(|c: char| !c.is_alphanumeric()))
457 .filter(|s| !s.is_empty())
458 .map(|s| s.to_string())
459 .collect()
460 }
461}
462
463impl Default for NodeQuery {
464 fn default() -> Self {
465 Self {
466 node_ids: None,
467 node_types: None,
468 text_search: None,
469 attribute_filters: HashMap::new(),
470 source_document_filters: None,
471 min_confidence: None,
472 limit: None,
473 offset: None,
474 }
475 }
476}
477
478impl Default for EdgeQuery {
479 fn default() -> Self {
480 Self {
481 edge_ids: None,
482 source_node_ids: None,
483 target_node_ids: None,
484 edge_types: None,
485 text_search: None,
486 attribute_filters: HashMap::new(),
487 weight_range: None,
488 min_confidence: None,
489 limit: None,
490 offset: None,
491 }
492 }
493}
494
495impl InMemoryGraphStorage {
496 pub fn new() -> Self {
498 Self::with_config(GraphStorageConfig::default())
499 }
500
501 pub fn with_config(config: GraphStorageConfig) -> Self {
503 Self {
504 graphs: tokio::sync::RwLock::new(HashMap::new()),
505 node_index: tokio::sync::RwLock::new(GraphIndex::new()),
506 edge_index: tokio::sync::RwLock::new(GraphIndex::new()),
507 config,
508 }
509 }
510
511 fn create_node_indexable_fields(node: &GraphNode) -> HashMap<String, String> {
513 let mut fields = HashMap::new();
514
515 fields.insert("label".to_string(), node.label.clone());
516 fields.insert(
517 "node_type".to_string(),
518 Self::node_type_string(&node.node_type),
519 );
520 fields.insert("confidence".to_string(), node.confidence.to_string());
521
522 for (key, value) in &node.attributes {
524 if let Some(string_value) = value.as_str() {
525 fields.insert(format!("attr_{}", key), string_value.to_string());
526 }
527 }
528
529 fields
530 }
531
532 fn create_edge_indexable_fields(edge: &GraphEdge) -> HashMap<String, String> {
534 let mut fields = HashMap::new();
535
536 fields.insert("label".to_string(), edge.label.clone());
537 fields.insert(
538 "edge_type".to_string(),
539 Self::edge_type_string(&edge.edge_type),
540 );
541 fields.insert("source_id".to_string(), edge.source_id.clone());
542 fields.insert("target_id".to_string(), edge.target_id.clone());
543 fields.insert("weight".to_string(), edge.weight.to_string());
544 fields.insert("confidence".to_string(), edge.confidence.to_string());
545
546 for (key, value) in &edge.attributes {
548 if let Some(string_value) = value.as_str() {
549 fields.insert(format!("attr_{}", key), string_value.to_string());
550 }
551 }
552
553 fields
554 }
555
556 fn node_type_string(node_type: &NodeType) -> String {
558 match node_type {
559 NodeType::Entity(entity_type) => format!("Entity({})", entity_type),
560 NodeType::Concept => "Concept".to_string(),
561 NodeType::Document => "Document".to_string(),
562 NodeType::DocumentChunk => "DocumentChunk".to_string(),
563 NodeType::Keyword => "Keyword".to_string(),
564 NodeType::Custom(custom_type) => format!("Custom({})", custom_type),
565 }
566 }
567
568 fn edge_type_string(edge_type: &EdgeType) -> String {
570 match edge_type {
571 EdgeType::Semantic(relation) => format!("Semantic({})", relation),
572 EdgeType::Hierarchical => "Hierarchical".to_string(),
573 EdgeType::Contains => "Contains".to_string(),
574 EdgeType::References => "References".to_string(),
575 EdgeType::CoOccurs => "CoOccurs".to_string(),
576 EdgeType::Similar => "Similar".to_string(),
577 EdgeType::Custom(custom_type) => format!("Custom({})", custom_type),
578 }
579 }
580
581 fn apply_node_filters(&self, nodes: Vec<&GraphNode>, query: &NodeQuery) -> Vec<GraphNode> {
583 let mut result: Vec<_> = nodes.into_iter().cloned().collect();
584
585 if let Some(node_types) = &query.node_types {
587 result.retain(|node| node_types.contains(&node.node_type));
588 }
589
590 if let Some(min_confidence) = query.min_confidence {
592 result.retain(|node| node.confidence >= min_confidence);
593 }
594
595 for (attr_key, attr_value) in &query.attribute_filters {
597 result.retain(|node| {
598 node.attributes
599 .get(attr_key)
600 .map_or(false, |v| v == attr_value)
601 });
602 }
603
604 if let Some(source_docs) = &query.source_document_filters {
606 result.retain(|node| {
607 node.source_documents
608 .iter()
609 .any(|doc| source_docs.contains(doc))
610 });
611 }
612
613 if let Some(offset) = query.offset {
615 if offset < result.len() {
616 result.drain(0..offset);
617 } else {
618 result.clear();
619 }
620 }
621
622 if let Some(limit) = query.limit {
623 result.truncate(limit);
624 }
625
626 result
627 }
628
629 fn apply_edge_filters(&self, edges: Vec<&GraphEdge>, query: &EdgeQuery) -> Vec<GraphEdge> {
631 let mut result: Vec<_> = edges.into_iter().cloned().collect();
632
633 if let Some(source_ids) = &query.source_node_ids {
635 result.retain(|edge| source_ids.contains(&edge.source_id));
636 }
637
638 if let Some(target_ids) = &query.target_node_ids {
639 result.retain(|edge| target_ids.contains(&edge.target_id));
640 }
641
642 if let Some(edge_types) = &query.edge_types {
644 result.retain(|edge| edge_types.contains(&edge.edge_type));
645 }
646
647 if let Some((min_weight, max_weight)) = query.weight_range {
649 result.retain(|edge| edge.weight >= min_weight && edge.weight <= max_weight);
650 }
651
652 if let Some(min_confidence) = query.min_confidence {
654 result.retain(|edge| edge.confidence >= min_confidence);
655 }
656
657 for (attr_key, attr_value) in &query.attribute_filters {
659 result.retain(|edge| {
660 edge.attributes
661 .get(attr_key)
662 .map_or(false, |v| v == attr_value)
663 });
664 }
665
666 if let Some(offset) = query.offset {
668 if offset < result.len() {
669 result.drain(0..offset);
670 } else {
671 result.clear();
672 }
673 }
674
675 if let Some(limit) = query.limit {
676 result.truncate(limit);
677 }
678
679 result
680 }
681}
682
683#[async_trait]
684impl GraphStorage for InMemoryGraphStorage {
685 async fn store_graph(&self, graph: &KnowledgeGraph) -> RragResult<()> {
686 let graph_id = uuid::Uuid::new_v4().to_string();
688 self.graphs.write().await.insert(graph_id, graph.clone());
689
690 self.store_nodes(&graph.nodes.values().cloned().collect::<Vec<_>>())
692 .await?;
693 self.store_edges(&graph.edges.values().cloned().collect::<Vec<_>>())
694 .await?;
695
696 Ok(())
697 }
698
699 async fn load_graph(&self, graph_id: &str) -> RragResult<KnowledgeGraph> {
700 self.graphs
701 .read()
702 .await
703 .get(graph_id)
704 .cloned()
705 .ok_or_else(|| {
706 GraphError::Storage {
707 operation: "load_graph".to_string(),
708 message: format!("Graph '{}' not found", graph_id),
709 }
710 .into()
711 })
712 }
713
714 async fn store_nodes(&self, nodes: &[GraphNode]) -> RragResult<()> {
715 let mut node_index = self.node_index.write().await;
716
717 for node in nodes {
718 let indexable_fields = Self::create_node_indexable_fields(node);
719 node_index.add_item(node.id.clone(), node.clone(), &indexable_fields);
720 }
721
722 Ok(())
723 }
724
725 async fn store_edges(&self, edges: &[GraphEdge]) -> RragResult<()> {
726 let mut edge_index = self.edge_index.write().await;
727
728 for edge in edges {
729 let indexable_fields = Self::create_edge_indexable_fields(edge);
730 edge_index.add_item(edge.id.clone(), edge.clone(), &indexable_fields);
731 }
732
733 Ok(())
734 }
735
736 async fn query_nodes(&self, query: &NodeQuery) -> RragResult<Vec<GraphNode>> {
737 let node_index = self.node_index.read().await;
738 let mut candidates = Vec::new();
739
740 if let Some(node_ids) = &query.node_ids {
741 for node_id in node_ids {
743 if let Some(node) = node_index.get_by_id(node_id) {
744 candidates.push(node);
745 }
746 }
747 } else if let Some(text_query) = &query.text_search {
748 candidates = node_index.text_search(text_query);
750 } else {
751 candidates = node_index.get_all();
753 }
754
755 Ok(self.apply_node_filters(candidates, query))
756 }
757
758 async fn query_edges(&self, query: &EdgeQuery) -> RragResult<Vec<GraphEdge>> {
759 let edge_index = self.edge_index.read().await;
760 let mut candidates = Vec::new();
761
762 if let Some(edge_ids) = &query.edge_ids {
763 for edge_id in edge_ids {
765 if let Some(edge) = edge_index.get_by_id(edge_id) {
766 candidates.push(edge);
767 }
768 }
769 } else if let Some(text_query) = &query.text_search {
770 candidates = edge_index.text_search(text_query);
772 } else {
773 candidates = edge_index.get_all();
775 }
776
777 Ok(self.apply_edge_filters(candidates, query))
778 }
779
780 async fn get_node(&self, node_id: &str) -> RragResult<Option<GraphNode>> {
781 let node_index = self.node_index.read().await;
782 Ok(node_index.get_by_id(node_id).cloned())
783 }
784
785 async fn get_edge(&self, edge_id: &str) -> RragResult<Option<GraphEdge>> {
786 let edge_index = self.edge_index.read().await;
787 Ok(edge_index.get_by_id(edge_id).cloned())
788 }
789
790 async fn get_neighbors(
791 &self,
792 node_id: &str,
793 direction: EdgeDirection,
794 ) -> RragResult<Vec<GraphNode>> {
795 let edge_index = self.edge_index.read().await;
796 let node_index = self.node_index.read().await;
797 let mut neighbor_ids = HashSet::new();
798
799 match direction {
800 EdgeDirection::Outgoing => {
801 let outgoing_edges = edge_index.find_by_field("source_id", node_id);
802 for edge in outgoing_edges {
803 neighbor_ids.insert(&edge.target_id);
804 }
805 }
806 EdgeDirection::Incoming => {
807 let incoming_edges = edge_index.find_by_field("target_id", node_id);
808 for edge in incoming_edges {
809 neighbor_ids.insert(&edge.source_id);
810 }
811 }
812 EdgeDirection::Both => {
813 let outgoing_edges = edge_index.find_by_field("source_id", node_id);
814 for edge in outgoing_edges {
815 neighbor_ids.insert(&edge.target_id);
816 }
817 let incoming_edges = edge_index.find_by_field("target_id", node_id);
818 for edge in incoming_edges {
819 neighbor_ids.insert(&edge.source_id);
820 }
821 }
822 }
823
824 let neighbors = neighbor_ids
825 .into_iter()
826 .filter_map(|id| node_index.get_by_id(id))
827 .cloned()
828 .collect();
829
830 Ok(neighbors)
831 }
832
833 async fn delete_nodes(&self, node_ids: &[String]) -> RragResult<()> {
834 let mut node_index = self.node_index.write().await;
835
836 for node_id in node_ids {
837 node_index.remove_item(node_id);
838 }
839
840 Ok(())
841 }
842
843 async fn delete_edges(&self, edge_ids: &[String]) -> RragResult<()> {
844 let mut edge_index = self.edge_index.write().await;
845
846 for edge_id in edge_ids {
847 edge_index.remove_item(edge_id);
848 }
849
850 Ok(())
851 }
852
853 async fn clear(&self) -> RragResult<()> {
854 self.graphs.write().await.clear();
855 self.node_index.write().await.clear();
856 self.edge_index.write().await.clear();
857 Ok(())
858 }
859
860 async fn get_stats(&self) -> RragResult<StorageStats> {
861 let graphs = self.graphs.read().await;
862 let node_index = self.node_index.read().await;
863 let edge_index = self.edge_index.read().await;
864
865 let graph_count = graphs.len();
866 let total_nodes = node_index.by_id.len();
867 let total_edges = edge_index.by_id.len();
868
869 let mut node_type_distribution = HashMap::new();
871 for node in node_index.by_id.values() {
872 let type_key = Self::node_type_string(&node.node_type);
873 *node_type_distribution.entry(type_key).or_insert(0) += 1;
874 }
875
876 let mut edge_type_distribution = HashMap::new();
878 for edge in edge_index.by_id.values() {
879 let type_key = Self::edge_type_string(&edge.edge_type);
880 *edge_type_distribution.entry(type_key).or_insert(0) += 1;
881 }
882
883 let storage_size_bytes = (total_nodes + total_edges) * 1000; let index_size_bytes = (node_index.indices.len() + edge_index.indices.len()) * 100;
886
887 Ok(StorageStats {
888 graph_count,
889 total_nodes,
890 total_edges,
891 storage_size_bytes,
892 index_size_bytes,
893 node_type_distribution,
894 edge_type_distribution,
895 last_updated: chrono::Utc::now(),
896 })
897 }
898}
899
900impl Default for InMemoryGraphStorage {
901 fn default() -> Self {
902 Self::new()
903 }
904}
905
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph_retrieval::{EdgeType, NodeType};

    #[tokio::test]
    async fn test_in_memory_graph_storage() {
        let storage = InMemoryGraphStorage::new();

        let node1 = GraphNode::new("Test Node 1", NodeType::Concept);
        let node2 = GraphNode::new("Test Node 2", NodeType::Entity("Person".to_string()));

        let node1_id = node1.id.clone();
        let node2_id = node2.id.clone();

        storage.store_nodes(&[node1, node2]).await.unwrap();

        // Both labels contain the token "test".
        let query = NodeQuery {
            text_search: Some("Test".to_string()),
            ..Default::default()
        };
        let results = storage.query_nodes(&query).await.unwrap();
        assert_eq!(results.len(), 2);

        let retrieved_node = storage.get_node(&node1_id).await.unwrap();
        assert!(retrieved_node.is_some());
        assert_eq!(retrieved_node.unwrap().label, "Test Node 1");

        let second = storage
            .get_node(&node2_id)
            .await
            .unwrap()
            .expect("node 2 was stored");
        assert_eq!(second.label, "Test Node 2");

        assert!(storage.get_node("no-such-node").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_edge_storage_and_queries() {
        let storage = InMemoryGraphStorage::new();

        let node1 = GraphNode::new("Node 1", NodeType::Concept);
        let node2 = GraphNode::new("Node 2", NodeType::Concept);

        let node1_id = node1.id.clone();
        let node2_id = node2.id.clone();

        storage.store_nodes(&[node1, node2]).await.unwrap();

        let edge = GraphEdge::new(
            node1_id.clone(),
            node2_id.clone(),
            "test_relation",
            EdgeType::Similar,
        );
        let edge_id = edge.id.clone();

        storage.store_edges(&[edge]).await.unwrap();

        let query = EdgeQuery {
            source_node_ids: Some(vec![node1_id.clone()]),
            ..Default::default()
        };
        let results = storage.query_edges(&query).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].source_id, node1_id);
        assert_eq!(results[0].target_id, node2_id);

        assert!(storage.get_edge(&edge_id).await.unwrap().is_some());

        let outgoing = storage
            .get_neighbors(&node1_id, EdgeDirection::Outgoing)
            .await
            .unwrap();
        assert_eq!(outgoing.len(), 1);
        assert_eq!(outgoing[0].id, node2_id);

        let incoming = storage
            .get_neighbors(&node2_id, EdgeDirection::Incoming)
            .await
            .unwrap();
        assert_eq!(incoming.len(), 1);
        assert_eq!(incoming[0].id, node1_id);
    }

    #[tokio::test]
    async fn test_storage_stats() {
        let storage = InMemoryGraphStorage::new();

        let nodes = vec![
            GraphNode::new("Node 1", NodeType::Concept),
            GraphNode::new("Node 2", NodeType::Entity("Person".to_string())),
            GraphNode::new("Node 3", NodeType::Document),
        ];

        storage.store_nodes(&nodes).await.unwrap();

        let stats = storage.get_stats().await.unwrap();
        assert_eq!(stats.total_nodes, 3);
        assert_eq!(stats.total_edges, 0);
        assert!(stats.node_type_distribution.contains_key("Concept"));
        assert!(stats.node_type_distribution.contains_key("Entity(Person)"));
    }

    #[tokio::test]
    async fn test_node_query_pagination() {
        let storage = InMemoryGraphStorage::new();

        let nodes = vec![
            GraphNode::new("Node A", NodeType::Concept),
            GraphNode::new("Node B", NodeType::Concept),
            GraphNode::new("Node C", NodeType::Concept),
        ];
        storage.store_nodes(&nodes).await.unwrap();

        let page = NodeQuery {
            offset: Some(1),
            limit: Some(1),
            ..Default::default()
        };
        assert_eq!(storage.query_nodes(&page).await.unwrap().len(), 1);

        let past_end = NodeQuery {
            offset: Some(10),
            ..Default::default()
        };
        assert!(storage.query_nodes(&past_end).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_delete_and_clear() {
        let storage = InMemoryGraphStorage::new();

        let node = GraphNode::new("Doomed", NodeType::Concept);
        let node_id = node.id.clone();
        storage.store_nodes(&[node]).await.unwrap();

        storage.delete_nodes(&[node_id.clone()]).await.unwrap();
        assert!(storage.get_node(&node_id).await.unwrap().is_none());
        assert!(storage
            .query_nodes(&NodeQuery::default())
            .await
            .unwrap()
            .is_empty());

        storage
            .store_nodes(&[GraphNode::new("Again", NodeType::Keyword)])
            .await
            .unwrap();
        storage.clear().await.unwrap();
        assert_eq!(storage.get_stats().await.unwrap().total_nodes, 0);
    }

    #[tokio::test]
    async fn test_load_unknown_graph_fails() {
        let storage = InMemoryGraphStorage::new();
        assert!(storage.load_graph("missing").await.is_err());
    }
}