Skip to main content

rag/
graph.rs

1use crate::errors::{RagError, Result};
2use dashmap::DashMap;
3use serde::{Deserialize, Serialize};
4use std::collections::{HashMap, HashSet, VecDeque};
5use std::fs;
6use std::path::Path;
7use uuid::Uuid;
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct GraphNode {
11    pub id: String,
12    pub label: String,
13    pub name: String,
14    pub properties: HashMap<String, String>,
15}
16
17impl GraphNode {
18    pub fn new(name: String, label: String) -> Self {
19        Self {
20            id: Uuid::new_v4().to_string(),
21            label,
22            name,
23            properties: HashMap::new(),
24        }
25    }
26
27    pub fn with_id(mut self, id: String) -> Self {
28        self.id = id;
29        self
30    }
31
32    pub fn with_property(mut self, key: String, value: String) -> Self {
33        self.properties.insert(key, value);
34        self
35    }
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct GraphEdge {
40    pub id: String,
41    pub source: String,
42    pub target: String,
43    pub relation: String,
44    pub weight: f32,
45    pub properties: HashMap<String, String>,
46}
47
48impl GraphEdge {
49    pub fn new(source: String, target: String, relation: String) -> Self {
50        Self {
51            id: Uuid::new_v4().to_string(),
52            source,
53            target,
54            relation,
55            weight: 1.0,
56            properties: HashMap::new(),
57        }
58    }
59
60    pub fn with_weight(mut self, weight: f32) -> Self {
61        self.weight = weight;
62        self
63    }
64
65    pub fn with_property(mut self, key: String, value: String) -> Self {
66        self.properties.insert(key, value);
67        self
68    }
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize)]
72pub struct GraphPath {
73    pub node_ids: Vec<String>,
74    pub edge_ids: Vec<String>,
75    pub total_weight: f32,
76}
77
78#[derive(Debug, Clone, Serialize, Deserialize)]
79pub struct Community {
80    pub id: usize,
81    pub node_ids: Vec<String>,
82    pub size: usize,
83}
84
85#[derive(Debug, Clone, Serialize, Deserialize)]
86pub struct GraphPersisted {
87    pub nodes: Vec<GraphNode>,
88    pub edges: Vec<GraphEdge>,
89}
90
91pub struct GraphStore {
92    nodes: DashMap<String, GraphNode>,
93    edges: DashMap<String, GraphEdge>,
94    out_edges: DashMap<String, HashSet<String>>,
95    in_edges: DashMap<String, HashSet<String>>,
96    name_index: DashMap<String, String>,
97}
98
99impl Default for GraphStore {
100    fn default() -> Self {
101        Self::new()
102    }
103}
104
105impl GraphStore {
106    pub fn new() -> Self {
107        Self {
108            nodes: DashMap::new(),
109            edges: DashMap::new(),
110            out_edges: DashMap::new(),
111            in_edges: DashMap::new(),
112            name_index: DashMap::new(),
113        }
114    }
115
116    pub fn add_node(&self, node: GraphNode) -> Result<()> {
117        let id = node.id.clone();
118        let name = node.name.to_lowercase();
119        self.name_index.insert(name, id.clone());
120        self.nodes.insert(id, node);
121        Ok(())
122    }
123
124    pub fn get_node(&self, id: &str) -> Option<GraphNode> {
125        self.nodes.get(id).map(|n| n.value().clone())
126    }
127
128    pub fn get_node_by_name(&self, name: &str) -> Option<GraphNode> {
129        let key = name.to_lowercase();
130        self.name_index
131            .get(&key)
132            .and_then(|id| self.get_node(id.value()))
133    }
134
135    pub fn update_node(&self, id: &str, node: GraphNode) -> Result<bool> {
136        if self.nodes.contains_key(id) {
137            self.nodes.insert(id.to_string(), node);
138            Ok(true)
139        } else {
140            Err(RagError::GraphError(format!("Node not found: {}", id)))
141        }
142    }
143
144    pub fn remove_node(&self, id: &str) -> Result<bool> {
145        if let Some((_, node)) = self.nodes.remove(id) {
146            let name = node.name.to_lowercase();
147            self.name_index.remove(&name);
148
149            let edge_ids_to_remove: Vec<String> = self
150                .out_edges
151                .get(id)
152                .map(|s| s.value().iter().cloned().collect())
153                .unwrap_or_default();
154
155            let in_edge_ids: Vec<String> = self
156                .in_edges
157                .get(id)
158                .map(|s| s.value().iter().cloned().collect())
159                .unwrap_or_default();
160
161            for eid in edge_ids_to_remove.iter().chain(in_edge_ids.iter()) {
162                self.remove_edge_direct(eid);
163            }
164
165            self.out_edges.remove(id);
166            self.in_edges.remove(id);
167
168            for eid in &edge_ids_to_remove {
169                if let Some(edge) = self.edges.get(eid) {
170                    let target = edge.target.clone();
171                    drop(edge);
172                    if let Some(mut set) = self.in_edges.get_mut(&target) {
173                        set.remove(eid);
174                    }
175                }
176            }
177
178            for eid in &in_edge_ids {
179                if let Some(edge) = self.edges.get(eid) {
180                    let source = edge.source.clone();
181                    drop(edge);
182                    if let Some(mut set) = self.out_edges.get_mut(&source) {
183                        set.remove(eid);
184                    }
185                }
186            }
187
188            Ok(true)
189        } else {
190            Ok(false)
191        }
192    }
193
194    pub fn add_edge(&self, edge: GraphEdge) -> Result<()> {
195        let source = edge.source.clone();
196        let target = edge.target.clone();
197        let id = edge.id.clone();
198
199        if !self.nodes.contains_key(&source) {
200            return Err(RagError::GraphError(format!(
201                "Source node not found: {}",
202                source
203            )));
204        }
205        if !self.nodes.contains_key(&target) {
206            return Err(RagError::GraphError(format!(
207                "Target node not found: {}",
208                target
209            )));
210        }
211
212        self.edges.insert(id.clone(), edge);
213
214        self.out_edges
215            .entry(source)
216            .or_insert_with(HashSet::new)
217            .insert(id.clone());
218
219        self.in_edges
220            .entry(target)
221            .or_insert_with(HashSet::new)
222            .insert(id);
223
224        Ok(())
225    }
226
227    pub fn get_edge(&self, id: &str) -> Option<GraphEdge> {
228        self.edges.get(id).map(|e| e.value().clone())
229    }
230
231    pub fn remove_edge(&self, id: &str) -> bool {
232        self.remove_edge_direct(id)
233    }
234
235    fn remove_edge_direct(&self, id: &str) -> bool {
236        if let Some((_, edge)) = self.edges.remove(id) {
237            if let Some(mut set) = self.out_edges.get_mut(&edge.source) {
238                set.remove(id);
239            }
240            if let Some(mut set) = self.in_edges.get_mut(&edge.target) {
241                set.remove(id);
242            }
243            true
244        } else {
245            false
246        }
247    }
248
249    pub fn upsert_edge(&self, edge: GraphEdge) -> Result<()> {
250        let source = edge.source.clone();
251        let target = edge.target.clone();
252        let relation = edge.relation.clone();
253
254        let existing = self.find_edge(&source, &target, &relation);
255        if let Some(existing) = existing {
256            self.remove_edge(&existing.id);
257        }
258
259        self.add_edge(edge)
260    }
261
262    pub fn find_edge(&self, source: &str, target: &str, relation: &str) -> Option<GraphEdge> {
263        self.edges
264            .iter()
265            .find(|e| {
266                e.value().source == source
267                    && e.value().target == target
268                    && e.value().relation == relation
269            })
270            .map(|e| e.value().clone())
271    }
272
273    pub fn neighbors(&self, node_id: &str) -> Vec<GraphNode> {
274        let mut result = Vec::new();
275        let mut seen = HashSet::new();
276
277        if let Some(edge_ids) = self.out_edges.get(node_id) {
278            for eid in edge_ids.value().iter() {
279                if let Some(edge) = self.edges.get(eid) {
280                    if seen.insert(edge.target.clone()) {
281                        if let Some(node) = self.nodes.get(&edge.target) {
282                            result.push(node.value().clone());
283                        }
284                    }
285                }
286            }
287        }
288
289        if let Some(edge_ids) = self.in_edges.get(node_id) {
290            for eid in edge_ids.value().iter() {
291                if let Some(edge) = self.edges.get(eid) {
292                    if seen.insert(edge.source.clone()) {
293                        if let Some(node) = self.nodes.get(&edge.source) {
294                            result.push(node.value().clone());
295                        }
296                    }
297                }
298            }
299        }
300
301        result
302    }
303
304    pub fn out_neighbors(&self, node_id: &str) -> Vec<GraphNode> {
305        let mut result = Vec::new();
306        let mut seen = HashSet::new();
307
308        if let Some(edge_ids) = self.out_edges.get(node_id) {
309            for eid in edge_ids.value().iter() {
310                if let Some(edge) = self.edges.get(eid) {
311                    if seen.insert(edge.target.clone()) {
312                        if let Some(node) = self.nodes.get(&edge.target) {
313                            result.push(node.value().clone());
314                        }
315                    }
316                }
317            }
318        }
319
320        result
321    }
322
323    pub fn in_neighbors(&self, node_id: &str) -> Vec<GraphNode> {
324        let mut result = Vec::new();
325        let mut seen = HashSet::new();
326
327        if let Some(edge_ids) = self.in_edges.get(node_id) {
328            for eid in edge_ids.value().iter() {
329                if let Some(edge) = self.edges.get(eid) {
330                    if seen.insert(edge.source.clone()) {
331                        if let Some(node) = self.nodes.get(&edge.source) {
332                            result.push(node.value().clone());
333                        }
334                    }
335                }
336            }
337        }
338
339        result
340    }
341
342    pub fn degree(&self, node_id: &str) -> usize {
343        let out = self
344            .out_edges
345            .get(node_id)
346            .map(|s| s.value().len())
347            .unwrap_or(0);
348        let in_deg = self
349            .in_edges
350            .get(node_id)
351            .map(|s| s.value().len())
352            .unwrap_or(0);
353        out + in_deg
354    }
355
356    pub fn edges_between(&self, source: &str, target: &str) -> Vec<GraphEdge> {
357        let mut result = Vec::new();
358
359        if let Some(edge_ids) = self.out_edges.get(source) {
360            for eid in edge_ids.value().iter() {
361                if let Some(edge) = self.edges.get(eid) {
362                    if edge.target == target {
363                        result.push(edge.value().clone());
364                    }
365                }
366            }
367        }
368
369        result
370    }
371
372    pub fn nodes_by_label(&self, label: &str) -> Vec<GraphNode> {
373        self.nodes
374            .iter()
375            .filter(|n| n.value().label == label)
376            .map(|n| n.value().clone())
377            .collect()
378    }
379
380    pub fn edges_by_relation(&self, relation: &str) -> Vec<GraphEdge> {
381        self.edges
382            .iter()
383            .filter(|e| e.value().relation == relation)
384            .map(|e| e.value().clone())
385            .collect()
386    }
387
388    pub fn bfs(&self, start_id: &str, max_depth: usize) -> Vec<GraphNode> {
389        let mut visited = HashSet::new();
390        let mut queue = VecDeque::new();
391        let mut result = Vec::new();
392
393        visited.insert(start_id.to_string());
394        queue.push_back((start_id.to_string(), 0usize));
395
396        while let Some((node_id, depth)) = queue.pop_front() {
397            if depth > 0 {
398                if let Some(node) = self.nodes.get(&node_id) {
399                    result.push(node.value().clone());
400                }
401            }
402
403            if depth < max_depth {
404                for neighbor in self.neighbors(&node_id) {
405                    if visited.insert(neighbor.id.clone()) {
406                        queue.push_back((neighbor.id.clone(), depth + 1));
407                    }
408                }
409            }
410        }
411
412        result
413    }
414
415    pub fn k_hop(&self, start_id: &str, k: usize) -> Vec<Vec<GraphNode>> {
416        let mut levels = Vec::new();
417        let mut visited = HashSet::new();
418        visited.insert(start_id.to_string());
419        let mut current_level = vec![start_id.to_string()];
420
421        for _ in 0..k {
422            let mut next_level_nodes = Vec::new();
423            let mut next_level_ids = Vec::new();
424
425            for nid in &current_level {
426                for neighbor in self.neighbors(nid) {
427                    if visited.insert(neighbor.id.clone()) {
428                        next_level_nodes.push(neighbor.clone());
429                        next_level_ids.push(neighbor.id.clone());
430                    }
431                }
432            }
433
434            levels.push(next_level_nodes);
435            current_level = next_level_ids;
436
437            if current_level.is_empty() {
438                break;
439            }
440        }
441
442        levels
443    }
444
445    pub fn shortest_path(&self, source: &str, target: &str) -> Option<GraphPath> {
446        if source == target {
447            return Some(GraphPath {
448                node_ids: vec![source.to_string()],
449                edge_ids: vec![],
450                total_weight: 0.0,
451            });
452        }
453
454        let mut visited = HashMap::new();
455        let mut queue = VecDeque::new();
456        queue.push_back(source.to_string());
457        visited.insert(
458            source.to_string(),
459            (None::<String>, None::<String>, 0.0f32),
460        );
461
462        while let Some(current) = queue.pop_front() {
463            if current == target {
464                let mut path_nodes = Vec::new();
465                let mut path_edges = Vec::new();
466                let mut total_weight = 0.0f32;
467                let mut node = Some(target.to_string());
468
469                while let Some(n) = node {
470                    if let Some((prev_node, edge_id, weight)) = visited.get(&n) {
471                        if let Some(eid) = edge_id {
472                            path_edges.push(eid.clone());
473                        }
474                        total_weight += weight;
475                        path_nodes.push(n.clone());
476                        node = prev_node.clone();
477                    } else {
478                        path_nodes.push(n.clone());
479                        break;
480                    }
481                }
482
483                path_nodes.reverse();
484                path_edges.reverse();
485
486                return Some(GraphPath {
487                    node_ids: path_nodes,
488                    edge_ids: path_edges,
489                    total_weight,
490                });
491            }
492
493            if let Some(edge_ids) = self.out_edges.get(&current) {
494                for eid in edge_ids.value().iter() {
495                    if let Some(edge) = self.edges.get(eid) {
496                        if !visited.contains_key(&edge.target) {
497                            visited.insert(
498                                edge.target.clone(),
499                                (
500                                    Some(current.clone()),
501                                    Some(eid.clone()),
502                                    edge.weight,
503                                ),
504                            );
505                            queue.push_back(edge.target.clone());
506                        }
507                    }
508                }
509            }
510
511            if let Some(edge_ids) = self.in_edges.get(&current) {
512                for eid in edge_ids.value().iter() {
513                    if let Some(edge) = self.edges.get(eid) {
514                        if !visited.contains_key(&edge.source) {
515                            visited.insert(
516                                edge.source.clone(),
517                                (
518                                    Some(current.clone()),
519                                    Some(eid.clone()),
520                                    edge.weight,
521                                ),
522                            );
523                            queue.push_back(edge.source.clone());
524                        }
525                    }
526                }
527            }
528        }
529
530        None
531    }
532
533    pub fn detect_communities(&self) -> Vec<Community> {
534        let node_ids: Vec<String> = self.nodes.iter().map(|n| n.key().clone()).collect();
535
536        if node_ids.is_empty() {
537            return Vec::new();
538        }
539
540        let mut labels: HashMap<String, usize> = HashMap::new();
541        for (i, id) in node_ids.iter().enumerate() {
542            labels.insert(id.clone(), i);
543        }
544
545        let max_iterations = 20;
546        for _ in 0..max_iterations {
547            let mut changed = false;
548
549            for node_id in &node_ids {
550                let neighbor_ids: Vec<String> = self
551                    .neighbors(node_id)
552                    .into_iter()
553                    .map(|n| n.id)
554                    .collect();
555
556                if neighbor_ids.is_empty() {
557                    continue;
558                }
559
560                let mut label_counts: HashMap<usize, usize> = HashMap::new();
561                for nid in &neighbor_ids {
562                    if let Some(label) = labels.get(nid) {
563                        *label_counts.entry(*label).or_insert(0) += 1;
564                    }
565                }
566
567                if let Some(best_label) = label_counts
568                    .into_iter()
569                    .max_by_key(|(_, count)| *count)
570                    .map(|(label, _)| label)
571                {
572                    if labels.get(node_id) != Some(&best_label) {
573                        labels.insert(node_id.clone(), best_label);
574                        changed = true;
575                    }
576                }
577            }
578
579            if !changed {
580                break;
581            }
582        }
583
584        let mut community_map: HashMap<usize, Vec<String>> = HashMap::new();
585        for (node_id, label) in &labels {
586            community_map
587                .entry(*label)
588                .or_insert_with(Vec::new)
589                .push(node_id.clone());
590        }
591
592        community_map
593            .into_iter()
594            .map(|(id, node_ids)| Community {
595                id,
596                size: node_ids.len(),
597                node_ids,
598            })
599            .collect()
600    }
601
602    pub fn node_count(&self) -> usize {
603        self.nodes.len()
604    }
605
606    pub fn edge_count(&self) -> usize {
607        self.edges.len()
608    }
609
610    pub fn density(&self) -> f64 {
611        let n = self.nodes.len() as f64;
612        if n <= 1.0 {
613            return 0.0;
614        }
615        let max_edges = n * (n - 1.0);
616        self.edges.len() as f64 / max_edges
617    }
618
619    pub fn all_nodes(&self) -> Vec<GraphNode> {
620        self.nodes.iter().map(|n| n.value().clone()).collect()
621    }
622
623    pub fn all_edges(&self) -> Vec<GraphEdge> {
624        self.edges.iter().map(|e| e.value().clone()).collect()
625    }
626
627    pub fn clear(&self) {
628        self.nodes.clear();
629        self.edges.clear();
630        self.out_edges.clear();
631        self.in_edges.clear();
632        self.name_index.clear();
633    }
634
635    pub fn save_to_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
636        let data = GraphPersisted {
637            nodes: self.all_nodes(),
638            edges: self.all_edges(),
639        };
640        let json = serde_json::to_string_pretty(&data)?;
641        fs::write(path, json)?;
642        Ok(())
643    }
644
645    pub fn load_from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
646        let content = fs::read_to_string(path)?;
647        let data: GraphPersisted = serde_json::from_str(&content)?;
648        Self::from_persisted(data)
649    }
650
651    /// Restore a graph from a [`GraphPersisted`] payload (for example after JSON decode).
652    pub fn from_persisted(data: GraphPersisted) -> Result<Self> {
653        let store = Self::new();
654
655        for node in data.nodes {
656            let id = node.id.clone();
657            let name = node.name.to_lowercase();
658            store.name_index.insert(name, id.clone());
659            store.nodes.insert(id, node);
660        }
661
662        for edge in data.edges {
663            let id = edge.id.clone();
664            let source = edge.source.clone();
665            let target = edge.target.clone();
666            store
667                .out_edges
668                .entry(source)
669                .or_insert_with(HashSet::new)
670                .insert(id.clone());
671            store
672                .in_edges
673                .entry(target)
674                .or_insert_with(HashSet::new)
675                .insert(id.clone());
676            store.edges.insert(id, edge);
677        }
678
679        Ok(store)
680    }
681
682    pub fn subgraph(&self, node_ids: &[String]) -> Self {
683        let sub = Self::new();
684        let node_set: HashSet<&String> = node_ids.iter().collect();
685
686        for nid in node_ids {
687            if let Some(node) = self.get_node(nid) {
688                let _ = sub.add_node(node);
689            }
690        }
691
692        for edge in self.all_edges() {
693            if node_set.contains(&edge.source) && node_set.contains(&edge.target) {
694                let _ = sub.add_edge(edge);
695            }
696        }
697
698        sub
699    }
700}
701
702#[cfg(test)]
703mod tests {
704    use super::*;
705
706    #[test]
707    fn test_add_and_get_node() {
708        let store = GraphStore::new();
709        let node = GraphNode::new("Alice".to_string(), "person".to_string());
710        let id = node.id.clone();
711        store.add_node(node).unwrap();
712
713        let retrieved = store.get_node(&id).unwrap();
714        assert_eq!(retrieved.name, "Alice");
715        assert_eq!(retrieved.label, "person");
716    }
717
718    #[test]
719    fn test_get_node_by_name() {
720        let store = GraphStore::new();
721        let node = GraphNode::new("New York".to_string(), "location".to_string());
722        store.add_node(node).unwrap();
723
724        let retrieved = store.get_node_by_name("New York").unwrap();
725        assert_eq!(retrieved.label, "location");
726
727        let retrieved_lower = store.get_node_by_name("new york").unwrap();
728        assert_eq!(retrieved_lower.name, "New York");
729    }
730
731    #[test]
732    fn test_remove_node() {
733        let store = GraphStore::new();
734        let node = GraphNode::new("Alice".to_string(), "person".to_string());
735        let id = node.id.clone();
736        store.add_node(node).unwrap();
737
738        assert!(store.remove_node(&id).unwrap());
739        assert!(store.get_node(&id).is_none());
740        assert!(store.get_node_by_name("Alice").is_none());
741    }
742
743    #[test]
744    fn test_add_and_get_edge() {
745        let store = GraphStore::new();
746        let a = GraphNode::new("A".to_string(), "entity".to_string());
747        let b = GraphNode::new("B".to_string(), "entity".to_string());
748        let a_id = a.id.clone();
749        let b_id = b.id.clone();
750        store.add_node(a).unwrap();
751        store.add_node(b).unwrap();
752
753        let edge = GraphEdge::new(a_id.clone(), b_id.clone(), "connects".to_string());
754        let e_id = edge.id.clone();
755        store.add_edge(edge).unwrap();
756
757        let retrieved = store.get_edge(&e_id).unwrap();
758        assert_eq!(retrieved.source, a_id);
759        assert_eq!(retrieved.target, b_id);
760        assert_eq!(retrieved.relation, "connects");
761    }
762
763    #[test]
764    fn test_add_edge_missing_node() {
765        let store = GraphStore::new();
766        let edge = GraphEdge::new("nonexistent".to_string(), "also".to_string(), "x".to_string());
767        assert!(store.add_edge(edge).is_err());
768    }
769
770    #[test]
771    fn test_remove_edge() {
772        let store = GraphStore::new();
773        let a = GraphNode::new("A".to_string(), "e".to_string());
774        let b = GraphNode::new("B".to_string(), "e".to_string());
775        let a_id = a.id.clone();
776        let b_id = b.id.clone();
777        store.add_node(a).unwrap();
778        store.add_node(b).unwrap();
779
780        let edge = GraphEdge::new(a_id, b_id, "rel".to_string());
781        let e_id = edge.id.clone();
782        store.add_edge(edge).unwrap();
783
784        assert!(store.remove_edge(&e_id));
785        assert!(store.get_edge(&e_id).is_none());
786    }
787
788    #[test]
789    fn test_neighbors() {
790        let store = GraphStore::new();
791        let a = GraphNode::new("A".to_string(), "e".to_string());
792        let b = GraphNode::new("B".to_string(), "e".to_string());
793        let c = GraphNode::new("C".to_string(), "e".to_string());
794        let a_id = a.id.clone();
795        let b_id = b.id.clone();
796        let c_id = c.id.clone();
797        store.add_node(a).unwrap();
798        store.add_node(b).unwrap();
799        store.add_node(c).unwrap();
800
801        store
802            .add_edge(GraphEdge::new(a_id.clone(), b_id.clone(), "knows".to_string()))
803            .unwrap();
804        store
805            .add_edge(GraphEdge::new(c_id.clone(), a_id.clone(), "knows".to_string()))
806            .unwrap();
807
808        let neighbors = store.neighbors(&a_id);
809        assert_eq!(neighbors.len(), 2);
810        let names: Vec<&str> = neighbors.iter().map(|n| n.name.as_str()).collect();
811        assert!(names.contains(&"B"));
812        assert!(names.contains(&"C"));
813    }
814
815    #[test]
816    fn test_out_neighbors() {
817        let store = GraphStore::new();
818        let a = GraphNode::new("A".to_string(), "e".to_string());
819        let b = GraphNode::new("B".to_string(), "e".to_string());
820        let a_id = a.id.clone();
821        let b_id = b.id.clone();
822        store.add_node(a).unwrap();
823        store.add_node(b).unwrap();
824
825        store
826            .add_edge(GraphEdge::new(a_id.clone(), b_id.clone(), "follows".to_string()))
827            .unwrap();
828
829        let out = store.out_neighbors(&a_id);
830        assert_eq!(out.len(), 1);
831        assert_eq!(out[0].name, "B");
832
833        let out_b = store.out_neighbors(&b_id);
834        assert!(out_b.is_empty());
835    }
836
837    #[test]
838    fn test_degree() {
839        let store = GraphStore::new();
840        let a = GraphNode::new("A".to_string(), "e".to_string());
841        let b = GraphNode::new("B".to_string(), "e".to_string());
842        let c = GraphNode::new("C".to_string(), "e".to_string());
843        let a_id = a.id.clone();
844        let b_id = b.id.clone();
845        let c_id = c.id.clone();
846        store.add_node(a).unwrap();
847        store.add_node(b).unwrap();
848        store.add_node(c).unwrap();
849
850        store
851            .add_edge(GraphEdge::new(a_id.clone(), b_id, "knows".to_string()))
852            .unwrap();
853        store
854            .add_edge(GraphEdge::new(c_id, a_id.clone(), "knows".to_string()))
855            .unwrap();
856
857        assert_eq!(store.degree(&a_id), 2);
858    }
859
860    #[test]
861    fn test_bfs() {
862        let store = GraphStore::new();
863        let a = GraphNode::new("A".to_string(), "e".to_string());
864        let b = GraphNode::new("B".to_string(), "e".to_string());
865        let c = GraphNode::new("C".to_string(), "e".to_string());
866        let d = GraphNode::new("D".to_string(), "e".to_string());
867        let a_id = a.id.clone();
868        let b_id = b.id.clone();
869        let c_id = c.id.clone();
870        let d_id = d.id.clone();
871        store.add_node(a).unwrap();
872        store.add_node(b).unwrap();
873        store.add_node(c).unwrap();
874        store.add_node(d).unwrap();
875
876        store
877            .add_edge(GraphEdge::new(a_id.clone(), b_id.clone(), "e".to_string()))
878            .unwrap();
879        store
880            .add_edge(GraphEdge::new(b_id.clone(), c_id.clone(), "e".to_string()))
881            .unwrap();
882        store
883            .add_edge(GraphEdge::new(a_id.clone(), d_id, "e".to_string()))
884            .unwrap();
885
886        let reachable = store.bfs(&a_id, 2);
887        assert_eq!(reachable.len(), 3);
888        let names: Vec<&str> = reachable.iter().map(|n| n.name.as_str()).collect();
889        assert!(names.contains(&"B"));
890        assert!(names.contains(&"C"));
891        assert!(names.contains(&"D"));
892    }
893
894    #[test]
895    fn test_k_hop() {
896        let store = GraphStore::new();
897        let a = GraphNode::new("A".to_string(), "e".to_string());
898        let b = GraphNode::new("B".to_string(), "e".to_string());
899        let c = GraphNode::new("C".to_string(), "e".to_string());
900        let a_id = a.id.clone();
901        let b_id = b.id.clone();
902        let c_id = c.id.clone();
903        store.add_node(a).unwrap();
904        store.add_node(b).unwrap();
905        store.add_node(c).unwrap();
906
907        store
908            .add_edge(GraphEdge::new(a_id.clone(), b_id.clone(), "e".to_string()))
909            .unwrap();
910        store
911            .add_edge(GraphEdge::new(b_id, c_id, "e".to_string()))
912            .unwrap();
913
914        let levels = store.k_hop(&a_id, 2);
915        assert_eq!(levels.len(), 2);
916        assert_eq!(levels[0].len(), 1);
917        assert_eq!(levels[0][0].name, "B");
918        assert_eq!(levels[1].len(), 1);
919        assert_eq!(levels[1][0].name, "C");
920    }
921
922    #[test]
923    fn test_shortest_path() {
924        let store = GraphStore::new();
925        let a = GraphNode::new("A".to_string(), "e".to_string());
926        let b = GraphNode::new("B".to_string(), "e".to_string());
927        let c = GraphNode::new("C".to_string(), "e".to_string());
928        let a_id = a.id.clone();
929        let b_id = b.id.clone();
930        let c_id = c.id.clone();
931        store.add_node(a).unwrap();
932        store.add_node(b).unwrap();
933        store.add_node(c).unwrap();
934
935        store
936            .add_edge(GraphEdge::new(a_id.clone(), b_id.clone(), "e".to_string()))
937            .unwrap();
938        store
939            .add_edge(GraphEdge::new(b_id.clone(), c_id.clone(), "e".to_string()))
940            .unwrap();
941
942        let path = store.shortest_path(&a_id, &c_id).unwrap();
943        assert_eq!(path.node_ids.len(), 3);
944        assert_eq!(path.node_ids[0], a_id);
945        assert_eq!(path.node_ids[2], c_id);
946    }
947
948    #[test]
949    fn test_shortest_path_not_found() {
950        let store = GraphStore::new();
951        let a = GraphNode::new("A".to_string(), "e".to_string());
952        let b = GraphNode::new("B".to_string(), "e".to_string());
953        let a_id = a.id.clone();
954        let b_id = b.id.clone();
955        store.add_node(a).unwrap();
956        store.add_node(b).unwrap();
957
958        assert!(store.shortest_path(&a_id, &b_id).is_none());
959    }
960
961    #[test]
962    fn test_shortest_path_same_node() {
963        let store = GraphStore::new();
964        let a = GraphNode::new("A".to_string(), "e".to_string());
965        let a_id = a.id.clone();
966        store.add_node(a).unwrap();
967
968        let path = store.shortest_path(&a_id, &a_id).unwrap();
969        assert_eq!(path.node_ids.len(), 1);
970        assert_eq!(path.total_weight, 0.0);
971    }
972
973    #[test]
974    fn test_detect_communities() {
975        let store = GraphStore::new();
976
977        let a = GraphNode::new("A".to_string(), "e".to_string());
978        let b = GraphNode::new("B".to_string(), "e".to_string());
979        let c = GraphNode::new("C".to_string(), "e".to_string());
980        let d = GraphNode::new("D".to_string(), "e".to_string());
981        let a_id = a.id.clone();
982        let b_id = b.id.clone();
983        let c_id = c.id.clone();
984        let d_id = d.id.clone();
985        store.add_node(a).unwrap();
986        store.add_node(b).unwrap();
987        store.add_node(c).unwrap();
988        store.add_node(d).unwrap();
989
990        store
991            .add_edge(GraphEdge::new(a_id.clone(), b_id.clone(), "e".to_string()))
992            .unwrap();
993        store
994            .add_edge(GraphEdge::new(b_id.clone(), a_id, "e".to_string()))
995            .unwrap();
996        store
997            .add_edge(GraphEdge::new(c_id.clone(), d_id.clone(), "e".to_string()))
998            .unwrap();
999        store
1000            .add_edge(GraphEdge::new(d_id, c_id, "e".to_string()))
1001            .unwrap();
1002
1003        let communities = store.detect_communities();
1004        assert_eq!(communities.len(), 2);
1005
1006        let sizes: Vec<usize> = communities.iter().map(|c| c.size).collect();
1007        assert!(sizes.contains(&2));
1008        assert!(sizes.contains(&2));
1009    }
1010
1011    #[test]
1012    fn test_nodes_by_label() {
1013        let store = GraphStore::new();
1014        store
1015            .add_node(GraphNode::new("Alice".to_string(), "person".to_string()))
1016            .unwrap();
1017        store
1018            .add_node(GraphNode::new("Bob".to_string(), "person".to_string()))
1019            .unwrap();
1020        store
1021            .add_node(GraphNode::new("Paris".to_string(), "location".to_string()))
1022            .unwrap();
1023
1024        let people = store.nodes_by_label("person");
1025        assert_eq!(people.len(), 2);
1026
1027        let locations = store.nodes_by_label("location");
1028        assert_eq!(locations.len(), 1);
1029    }
1030
1031    #[test]
1032    fn test_edges_by_relation() {
1033        let store = GraphStore::new();
1034        let a = GraphNode::new("A".to_string(), "e".to_string());
1035        let b = GraphNode::new("B".to_string(), "e".to_string());
1036        let a_id = a.id.clone();
1037        let b_id = b.id.clone();
1038        store.add_node(a).unwrap();
1039        store.add_node(b).unwrap();
1040
1041        store
1042            .add_edge(GraphEdge::new(a_id.clone(), b_id.clone(), "friend".to_string()))
1043            .unwrap();
1044        store
1045            .add_edge(GraphEdge::new(b_id, a_id, "colleague".to_string()))
1046            .unwrap();
1047
1048        let friends = store.edges_by_relation("friend");
1049        assert_eq!(friends.len(), 1);
1050        let colleagues = store.edges_by_relation("colleague");
1051        assert_eq!(colleagues.len(), 1);
1052    }
1053
1054    #[test]
1055    fn test_density() {
1056        let store = GraphStore::new();
1057        assert_eq!(store.density(), 0.0);
1058
1059        let a = GraphNode::new("A".to_string(), "e".to_string());
1060        let b = GraphNode::new("B".to_string(), "e".to_string());
1061        let a_id = a.id.clone();
1062        let b_id = b.id.clone();
1063        store.add_node(a).unwrap();
1064        store.add_node(b).unwrap();
1065
1066        store
1067            .add_edge(GraphEdge::new(a_id, b_id, "e".to_string()))
1068            .unwrap();
1069
1070        let density = store.density();
1071        assert!(density > 0.0 && density <= 1.0);
1072    }
1073
1074    #[test]
1075    fn test_clear() {
1076        let store = GraphStore::new();
1077        store
1078            .add_node(GraphNode::new("A".to_string(), "e".to_string()))
1079            .unwrap();
1080        store.clear();
1081        assert_eq!(store.node_count(), 0);
1082        assert_eq!(store.edge_count(), 0);
1083    }
1084
1085    #[test]
1086    fn test_subgraph() {
1087        let store = GraphStore::new();
1088        let a = GraphNode::new("A".to_string(), "e".to_string());
1089        let b = GraphNode::new("B".to_string(), "e".to_string());
1090        let c = GraphNode::new("C".to_string(), "e".to_string());
1091        let a_id = a.id.clone();
1092        let b_id = b.id.clone();
1093        let c_id = c.id.clone();
1094        store.add_node(a).unwrap();
1095        store.add_node(b).unwrap();
1096        store.add_node(c).unwrap();
1097
1098        store
1099            .add_edge(GraphEdge::new(a_id.clone(), b_id.clone(), "e".to_string()))
1100            .unwrap();
1101        store
1102            .add_edge(GraphEdge::new(b_id.clone(), c_id, "e".to_string()))
1103            .unwrap();
1104
1105        let sub = store.subgraph(&[a_id.clone(), b_id.clone()]);
1106        assert_eq!(sub.node_count(), 2);
1107        assert_eq!(sub.edge_count(), 1);
1108    }
1109
1110    #[test]
1111    fn test_save_load() {
1112        let dir = tempfile::tempdir().unwrap();
1113        let path = dir.path().join("graph.json");
1114
1115        let store = GraphStore::new();
1116        let a = GraphNode::new("Alice".to_string(), "person".to_string());
1117        let b = GraphNode::new("Bob".to_string(), "person".to_string());
1118        let a_id = a.id.clone();
1119        let b_id = b.id.clone();
1120        store.add_node(a).unwrap();
1121        store.add_node(b).unwrap();
1122        store
1123            .add_edge(GraphEdge::new(a_id.clone(), b_id, "knows".to_string()))
1124            .unwrap();
1125
1126        store.save_to_file(&path).unwrap();
1127        let loaded = GraphStore::load_from_file(&path).unwrap();
1128
1129        assert_eq!(loaded.node_count(), 2);
1130        assert_eq!(loaded.edge_count(), 1);
1131        assert!(loaded.get_node_by_name("Alice").is_some());
1132        assert!(loaded.get_node_by_name("Bob").is_some());
1133    }
1134
1135    #[test]
1136    fn test_upsert_edge() {
1137        let store = GraphStore::new();
1138        let a = GraphNode::new("A".to_string(), "e".to_string());
1139        let b = GraphNode::new("B".to_string(), "e".to_string());
1140        let a_id = a.id.clone();
1141        let b_id = b.id.clone();
1142        store.add_node(a).unwrap();
1143        store.add_node(b).unwrap();
1144
1145        let edge1 = GraphEdge::new(a_id.clone(), b_id.clone(), "rel".to_string()).with_weight(1.0);
1146        store.upsert_edge(edge1).unwrap();
1147        assert_eq!(store.edge_count(), 1);
1148
1149        let edge2 = GraphEdge::new(a_id.clone(), b_id.clone(), "rel".to_string()).with_weight(2.0);
1150        store.upsert_edge(edge2).unwrap();
1151        assert_eq!(store.edge_count(), 1);
1152
1153        let edges = store.edges_between(&a_id, &b_id);
1154        assert_eq!(edges.len(), 1);
1155        assert!((edges[0].weight - 2.0).abs() < 0.01);
1156    }
1157
1158    #[test]
1159    fn test_node_with_property() {
1160        let node = GraphNode::new("test".to_string(), "type".to_string())
1161            .with_property("key".to_string(), "value".to_string());
1162        assert_eq!(node.properties.get("key"), Some(&"value".to_string()));
1163    }
1164
1165    #[test]
1166    fn test_edge_with_weight() {
1167        let edge = GraphEdge::new("a".to_string(), "b".to_string(), "rel".to_string()).with_weight(3.5);
1168        assert!((edge.weight - 3.5).abs() < 0.01);
1169    }
1170}