rexis_rag/graph_retrieval/
algorithms.rs

1//! # Graph Algorithms
2//!
3//! Implementation of graph-based retrieval algorithms including PageRank,
4//! graph traversal, and semantic path finding.
5
6use super::{GraphEdge, GraphError, KnowledgeGraph};
7use crate::RragResult;
8use serde::{Deserialize, Serialize};
9use std::cmp::Ordering;
10use std::collections::{BinaryHeap, HashMap, HashSet, VecDeque};
11
/// Stateless namespace for the graph algorithms in this module.
///
/// Every algorithm is an associated function that borrows a `KnowledgeGraph`;
/// the type itself carries no data, so it is trivially copyable and
/// default-constructible.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct GraphAlgorithms;
14
15impl GraphAlgorithms {
16    /// Calculate PageRank scores for all nodes
17    pub fn pagerank(
18        graph: &KnowledgeGraph,
19        config: &PageRankConfig,
20    ) -> RragResult<HashMap<String, f32>> {
21        let mut scores = HashMap::new();
22        let node_count = graph.nodes.len();
23
24        if node_count == 0 {
25            return Ok(scores);
26        }
27
28        // Initialize scores
29        let initial_score = 1.0 / node_count as f32;
30        for node_id in graph.nodes.keys() {
31            scores.insert(node_id.clone(), initial_score);
32        }
33
34        // Calculate outbound link counts
35        let mut outbound_counts = HashMap::new();
36        for (node_id, neighbors) in &graph.adjacency_list {
37            outbound_counts.insert(node_id.clone(), neighbors.len().max(1)); // Avoid division by zero
38        }
39
40        // Iterative PageRank calculation
41        for _iteration in 0..config.max_iterations {
42            let mut new_scores = HashMap::new();
43            let mut convergence_diff = 0.0;
44
45            for node_id in graph.nodes.keys() {
46                let mut score = config.damping_factor / node_count as f32;
47
48                // Add contributions from incoming links
49                if let Some(incoming_neighbors) = graph.reverse_adjacency_list.get(node_id) {
50                    for neighbor_id in incoming_neighbors {
51                        if let (Some(neighbor_score), Some(neighbor_outbound_count)) =
52                            (scores.get(neighbor_id), outbound_counts.get(neighbor_id))
53                        {
54                            // Get edge weight if available
55                            let edge_weight = graph
56                                .edges
57                                .values()
58                                .find(|edge| {
59                                    edge.source_id == *neighbor_id && edge.target_id == *node_id
60                                })
61                                .map(|edge| edge.weight)
62                                .unwrap_or(1.0);
63
64                            score += (1.0 - config.damping_factor) * (neighbor_score * edge_weight)
65                                / (*neighbor_outbound_count as f32);
66                        }
67                    }
68                }
69
70                // Apply personalization if configured
71                if let Some(personalization) = &config.personalization {
72                    if let Some(personal_score) = personalization.get(node_id) {
73                        score = config.personalization_weight * personal_score
74                            + (1.0 - config.personalization_weight) * score;
75                    }
76                }
77
78                let old_score = scores.get(node_id).unwrap_or(&0.0);
79                convergence_diff += (score - old_score).abs();
80                new_scores.insert(node_id.clone(), score);
81            }
82
83            scores = new_scores;
84
85            // Check for convergence
86            if convergence_diff < config.convergence_threshold {
87                break;
88            }
89        }
90
91        Ok(scores)
92    }
93
94    /// Find shortest paths from a source node using Dijkstra's algorithm
95    pub fn shortest_paths(
96        graph: &KnowledgeGraph,
97        source_node_id: &str,
98        config: &TraversalConfig,
99    ) -> RragResult<HashMap<String, PathInfo>> {
100        if !graph.nodes.contains_key(source_node_id) {
101            return Err(GraphError::Algorithm {
102                algorithm: "shortest_paths".to_string(),
103                message: format!("Source node '{}' not found", source_node_id),
104            }
105            .into());
106        }
107
108        let mut distances = HashMap::new();
109        let mut previous = HashMap::new();
110        let mut heap = BinaryHeap::new();
111
112        // Initialize distances
113        for node_id in graph.nodes.keys() {
114            distances.insert(node_id.clone(), f32::INFINITY);
115        }
116        distances.insert(source_node_id.to_string(), 0.0);
117        heap.push(DijkstraState {
118            cost: 0.0,
119            node_id: source_node_id.to_string(),
120        });
121
122        while let Some(current) = heap.pop() {
123            if current.cost > *distances.get(&current.node_id).unwrap_or(&f32::INFINITY) {
124                continue;
125            }
126
127            // Check max distance limit
128            if current.cost > config.max_distance {
129                continue;
130            }
131
132            // Explore neighbors
133            if let Some(neighbors) = graph.adjacency_list.get(&current.node_id) {
134                for neighbor_id in neighbors {
135                    // Calculate edge weight/cost
136                    let edge_cost = graph
137                        .edges
138                        .values()
139                        .find(|edge| {
140                            edge.source_id == current.node_id && edge.target_id == *neighbor_id
141                        })
142                        .map(|edge| Self::calculate_edge_cost(edge, config))
143                        .unwrap_or(1.0);
144
145                    let new_cost = current.cost + edge_cost;
146                    let neighbor_distance = distances.get(neighbor_id).unwrap_or(&f32::INFINITY);
147
148                    if new_cost < *neighbor_distance {
149                        distances.insert(neighbor_id.clone(), new_cost);
150                        previous.insert(neighbor_id.clone(), current.node_id.clone());
151                        heap.push(DijkstraState {
152                            cost: new_cost,
153                            node_id: neighbor_id.clone(),
154                        });
155                    }
156                }
157            }
158        }
159
160        // Build path information
161        let mut result = HashMap::new();
162        for (node_id, distance) in distances {
163            if distance < f32::INFINITY {
164                let path = Self::reconstruct_path(&previous, source_node_id, &node_id);
165                let hop_count = path.len().saturating_sub(1);
166                result.insert(
167                    node_id,
168                    PathInfo {
169                        distance,
170                        path,
171                        hop_count,
172                    },
173                );
174            }
175        }
176
177        Ok(result)
178    }
179
180    /// Find semantic paths between two nodes
181    pub fn semantic_paths(
182        graph: &KnowledgeGraph,
183        source_node_id: &str,
184        target_node_id: &str,
185        config: &PathFindingConfig,
186    ) -> RragResult<Vec<SemanticPath>> {
187        if !graph.nodes.contains_key(source_node_id) {
188            return Err(GraphError::Algorithm {
189                algorithm: "semantic_paths".to_string(),
190                message: format!("Source node '{}' not found", source_node_id),
191            }
192            .into());
193        }
194
195        if !graph.nodes.contains_key(target_node_id) {
196            return Err(GraphError::Algorithm {
197                algorithm: "semantic_paths".to_string(),
198                message: format!("Target node '{}' not found", target_node_id),
199            }
200            .into());
201        }
202
203        let mut paths = Vec::new();
204        let mut visited = HashSet::new();
205        let mut current_path = Vec::new();
206
207        Self::dfs_semantic_paths(
208            graph,
209            source_node_id,
210            target_node_id,
211            config,
212            &mut visited,
213            &mut current_path,
214            &mut paths,
215            0.0,
216            0,
217        );
218
219        // Sort paths by semantic score (descending)
220        paths.sort_by(|a, b| {
221            b.semantic_score
222                .partial_cmp(&a.semantic_score)
223                .unwrap_or(Ordering::Equal)
224        });
225
226        // Limit number of returned paths
227        paths.truncate(config.max_paths);
228
229        Ok(paths)
230    }
231
232    /// Breadth-first search from a source node
233    pub fn bfs_search(
234        graph: &KnowledgeGraph,
235        source_node_id: &str,
236        config: &TraversalConfig,
237    ) -> RragResult<Vec<String>> {
238        if !graph.nodes.contains_key(source_node_id) {
239            return Err(GraphError::Algorithm {
240                algorithm: "bfs_search".to_string(),
241                message: format!("Source node '{}' not found", source_node_id),
242            }
243            .into());
244        }
245
246        let mut visited = HashSet::new();
247        let mut queue = VecDeque::new();
248        let mut result = Vec::new();
249
250        queue.push_back((source_node_id.to_string(), 0));
251        visited.insert(source_node_id.to_string());
252
253        while let Some((current_node_id, depth)) = queue.pop_front() {
254            result.push(current_node_id.clone());
255
256            // Check depth limit
257            if depth >= config.max_depth {
258                continue;
259            }
260
261            // Explore neighbors
262            if let Some(neighbors) = graph.adjacency_list.get(&current_node_id) {
263                for neighbor_id in neighbors {
264                    if !visited.contains(neighbor_id) {
265                        visited.insert(neighbor_id.clone());
266                        queue.push_back((neighbor_id.clone(), depth + 1));
267
268                        // Check max nodes limit
269                        if result.len() + queue.len() >= config.max_nodes {
270                            break;
271                        }
272                    }
273                }
274
275                if result.len() + queue.len() >= config.max_nodes {
276                    break;
277                }
278            }
279        }
280
281        Ok(result)
282    }
283
284    /// Depth-first search from a source node
285    pub fn dfs_search(
286        graph: &KnowledgeGraph,
287        source_node_id: &str,
288        config: &TraversalConfig,
289    ) -> RragResult<Vec<String>> {
290        if !graph.nodes.contains_key(source_node_id) {
291            return Err(GraphError::Algorithm {
292                algorithm: "dfs_search".to_string(),
293                message: format!("Source node '{}' not found", source_node_id),
294            }
295            .into());
296        }
297
298        let mut visited = HashSet::new();
299        let mut result = Vec::new();
300
301        Self::dfs_recursive(graph, source_node_id, config, &mut visited, &mut result, 0);
302
303        Ok(result)
304    }
305
306    /// Find strongly connected components using Tarjan's algorithm
307    pub fn strongly_connected_components(graph: &KnowledgeGraph) -> Vec<Vec<String>> {
308        let mut components = Vec::new();
309        let mut visited = HashMap::new();
310        let mut low_link = HashMap::new();
311        let mut stack = Vec::new();
312        let mut on_stack = HashSet::new();
313        let mut index = 0;
314
315        for node_id in graph.nodes.keys() {
316            if !visited.contains_key(node_id) {
317                Self::tarjan_scc(
318                    graph,
319                    node_id,
320                    &mut visited,
321                    &mut low_link,
322                    &mut stack,
323                    &mut on_stack,
324                    &mut components,
325                    &mut index,
326                );
327            }
328        }
329
330        components
331    }
332
333    /// Calculate betweenness centrality for all nodes
334    pub fn betweenness_centrality(graph: &KnowledgeGraph) -> HashMap<String, f32> {
335        let mut centrality = HashMap::new();
336        let nodes: Vec<_> = graph.nodes.keys().collect();
337
338        // Initialize centrality scores
339        for node_id in &nodes {
340            centrality.insert(node_id.to_string(), 0.0);
341        }
342
343        // For each node as source
344        for &source in &nodes {
345            let mut stack = Vec::new();
346            let mut predecessors: HashMap<String, Vec<String>> = HashMap::new();
347            let mut sigma: HashMap<String, f32> = HashMap::new();
348            let mut distance: HashMap<String, i32> = HashMap::new();
349            let mut delta: HashMap<String, f32> = HashMap::new();
350            let mut queue = VecDeque::new();
351
352            // Initialize
353            for node_id in &nodes {
354                predecessors.insert(node_id.to_string(), Vec::new());
355                sigma.insert(node_id.to_string(), 0.0);
356                distance.insert(node_id.to_string(), -1);
357                delta.insert(node_id.to_string(), 0.0);
358            }
359
360            sigma.insert(source.to_string(), 1.0);
361            distance.insert(source.to_string(), 0);
362            queue.push_back(source.to_string());
363
364            // BFS
365            while let Some(current) = queue.pop_front() {
366                stack.push(current.clone());
367
368                if let Some(neighbors) = graph.adjacency_list.get(&current) {
369                    for neighbor in neighbors {
370                        let neighbor_distance = *distance.get(neighbor).unwrap();
371                        let current_distance = *distance.get(&current).unwrap();
372
373                        // Found for the first time?
374                        if neighbor_distance < 0 {
375                            queue.push_back(neighbor.clone());
376                            distance.insert(neighbor.clone(), current_distance + 1);
377                        }
378
379                        // Shortest path to neighbor via current?
380                        if neighbor_distance == current_distance + 1 {
381                            let current_sigma = *sigma.get(&current).unwrap();
382                            let neighbor_sigma = sigma.get_mut(neighbor).unwrap();
383                            *neighbor_sigma += current_sigma;
384                            predecessors
385                                .get_mut(neighbor)
386                                .unwrap()
387                                .push(current.clone());
388                        }
389                    }
390                }
391            }
392
393            // Accumulation
394            while let Some(node) = stack.pop() {
395                if let Some(preds) = predecessors.get(&node) {
396                    for pred in preds {
397                        let node_sigma = *sigma.get(&node).unwrap();
398                        let pred_sigma = *sigma.get(pred).unwrap();
399                        let node_delta = *delta.get(&node).unwrap();
400
401                        if pred_sigma > 0.0 {
402                            let pred_delta = delta.get_mut(pred).unwrap();
403                            *pred_delta += (pred_sigma / node_sigma) * (1.0 + node_delta);
404                        }
405                    }
406                }
407
408                if node != *source {
409                    let node_delta = *delta.get(&node).unwrap();
410                    let node_centrality = centrality.get_mut(&node).unwrap();
411                    *node_centrality += node_delta;
412                }
413            }
414        }
415
416        // Normalize
417        let node_count = nodes.len();
418        if node_count > 2 {
419            let normalization = ((node_count - 1) * (node_count - 2)) as f32;
420            for value in centrality.values_mut() {
421                *value /= normalization;
422            }
423        }
424
425        centrality
426    }
427
428    // Helper methods
429
430    fn calculate_edge_cost(edge: &GraphEdge, config: &TraversalConfig) -> f32 {
431        match config.cost_function {
432            CostFunction::Weight => 1.0 / edge.weight.max(0.001), // Higher weight = lower cost
433            CostFunction::InverseConfidence => 1.0 / edge.confidence.max(0.001),
434            CostFunction::Uniform => 1.0,
435        }
436    }
437
438    fn reconstruct_path(
439        previous: &HashMap<String, String>,
440        source: &str,
441        target: &str,
442    ) -> Vec<String> {
443        let mut path = Vec::new();
444        let mut current = target.to_string();
445
446        while current != source {
447            path.push(current.clone());
448            if let Some(prev) = previous.get(&current) {
449                current = prev.clone();
450            } else {
451                return Vec::new(); // No path found
452            }
453        }
454
455        path.push(source.to_string());
456        path.reverse();
457        path
458    }
459
460    fn dfs_semantic_paths(
461        graph: &KnowledgeGraph,
462        current_node_id: &str,
463        target_node_id: &str,
464        config: &PathFindingConfig,
465        visited: &mut HashSet<String>,
466        current_path: &mut Vec<String>,
467        all_paths: &mut Vec<SemanticPath>,
468        current_score: f32,
469        depth: usize,
470    ) {
471        if depth > config.max_depth || all_paths.len() >= config.max_paths {
472            return;
473        }
474
475        current_path.push(current_node_id.to_string());
476        visited.insert(current_node_id.to_string());
477
478        if current_node_id == target_node_id {
479            // Found a path
480            let semantic_path = SemanticPath {
481                nodes: current_path.clone(),
482                semantic_score: current_score,
483                path_length: current_path.len() - 1,
484                edge_types: Self::extract_edge_types_from_path(graph, current_path),
485            };
486            all_paths.push(semantic_path);
487        } else {
488            // Continue exploring
489            if let Some(neighbors) = graph.adjacency_list.get(current_node_id) {
490                for neighbor_id in neighbors {
491                    if !visited.contains(neighbor_id) {
492                        // Calculate semantic score contribution
493                        let edge_score = graph
494                            .edges
495                            .values()
496                            .find(|edge| {
497                                edge.source_id == current_node_id && edge.target_id == *neighbor_id
498                            })
499                            .map(|edge| Self::calculate_semantic_score(edge, config))
500                            .unwrap_or(0.0);
501
502                        let new_score = current_score + edge_score;
503
504                        if new_score >= config.min_semantic_score {
505                            Self::dfs_semantic_paths(
506                                graph,
507                                neighbor_id,
508                                target_node_id,
509                                config,
510                                visited,
511                                current_path,
512                                all_paths,
513                                new_score,
514                                depth + 1,
515                            );
516                        }
517                    }
518                }
519            }
520        }
521
522        current_path.pop();
523        visited.remove(current_node_id);
524    }
525
526    fn extract_edge_types_from_path(graph: &KnowledgeGraph, path: &[String]) -> Vec<String> {
527        let mut edge_types = Vec::new();
528
529        for i in 0..(path.len().saturating_sub(1)) {
530            if let Some(edge) = graph
531                .edges
532                .values()
533                .find(|edge| edge.source_id == path[i] && edge.target_id == path[i + 1])
534            {
535                edge_types.push(edge.label.clone());
536            }
537        }
538
539        edge_types
540    }
541
542    fn calculate_semantic_score(edge: &GraphEdge, config: &PathFindingConfig) -> f32 {
543        let base_score = edge.confidence * edge.weight;
544
545        // Apply semantic type weighting
546        let type_weight = config
547            .semantic_weights
548            .get(&edge.edge_type)
549            .copied()
550            .unwrap_or(1.0);
551
552        base_score * type_weight
553    }
554
555    fn dfs_recursive(
556        graph: &KnowledgeGraph,
557        current_node_id: &str,
558        config: &TraversalConfig,
559        visited: &mut HashSet<String>,
560        result: &mut Vec<String>,
561        depth: usize,
562    ) {
563        if depth > config.max_depth || result.len() >= config.max_nodes {
564            return;
565        }
566
567        visited.insert(current_node_id.to_string());
568        result.push(current_node_id.to_string());
569
570        if let Some(neighbors) = graph.adjacency_list.get(current_node_id) {
571            for neighbor_id in neighbors {
572                if !visited.contains(neighbor_id) && result.len() < config.max_nodes {
573                    Self::dfs_recursive(graph, neighbor_id, config, visited, result, depth + 1);
574                }
575            }
576        }
577    }
578
579    fn tarjan_scc(
580        graph: &KnowledgeGraph,
581        node_id: &str,
582        visited: &mut HashMap<String, usize>,
583        low_link: &mut HashMap<String, usize>,
584        stack: &mut Vec<String>,
585        on_stack: &mut HashSet<String>,
586        components: &mut Vec<Vec<String>>,
587        index: &mut usize,
588    ) {
589        visited.insert(node_id.to_string(), *index);
590        low_link.insert(node_id.to_string(), *index);
591        stack.push(node_id.to_string());
592        on_stack.insert(node_id.to_string());
593        *index += 1;
594
595        if let Some(neighbors) = graph.adjacency_list.get(node_id) {
596            for neighbor_id in neighbors {
597                if !visited.contains_key(neighbor_id) {
598                    Self::tarjan_scc(
599                        graph,
600                        neighbor_id,
601                        visited,
602                        low_link,
603                        stack,
604                        on_stack,
605                        components,
606                        index,
607                    );
608
609                    let node_low = *low_link.get(node_id).unwrap();
610                    let neighbor_low = *low_link.get(neighbor_id).unwrap();
611                    low_link.insert(node_id.to_string(), node_low.min(neighbor_low));
612                } else if on_stack.contains(neighbor_id) {
613                    let node_low = *low_link.get(node_id).unwrap();
614                    let neighbor_visited = *visited.get(neighbor_id).unwrap();
615                    low_link.insert(node_id.to_string(), node_low.min(neighbor_visited));
616                }
617            }
618        }
619
620        // If node_id is a root node, pop the stack and create a component
621        if low_link[node_id] == visited[node_id] {
622            let mut component = Vec::new();
623            loop {
624                if let Some(w) = stack.pop() {
625                    on_stack.remove(&w);
626                    component.push(w.clone());
627                    if w == node_id {
628                        break;
629                    }
630                } else {
631                    break;
632                }
633            }
634            components.push(component);
635        }
636    }
637}
638
/// PageRank algorithm configuration
///
/// Consumed by `GraphAlgorithms::pagerank`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PageRankConfig {
    /// Damping factor (typically 0.85)
    pub damping_factor: f32,

    /// Maximum number of iterations
    ///
    /// The iteration loop runs at most this many rounds, even if the scores
    /// have not converged.
    pub max_iterations: usize,

    /// Convergence threshold
    ///
    /// Iteration stops early once the sum of absolute score changes over all
    /// nodes in one round falls below this value.
    pub convergence_threshold: f32,

    /// Personalization vector (optional)
    ///
    /// Maps node ids to a preferred score. Only nodes present in the map are
    /// adjusted; all others keep their computed score.
    pub personalization: Option<HashMap<String, f32>>,

    /// Weight for personalization
    ///
    /// For a node with a personalization entry `p`, the final score is
    /// `personalization_weight * p + (1 - personalization_weight) * score`.
    pub personalization_weight: f32,
}
657
658impl Default for PageRankConfig {
659    fn default() -> Self {
660        Self {
661            damping_factor: 0.85,
662            max_iterations: 100,
663            convergence_threshold: 1e-6,
664            personalization: None,
665            personalization_weight: 0.15,
666        }
667    }
668}
669
/// Graph traversal configuration
///
/// Shared by `GraphAlgorithms::bfs_search`, `GraphAlgorithms::dfs_search` and
/// `GraphAlgorithms::shortest_paths`.
#[derive(Debug, Clone)]
pub struct TraversalConfig {
    /// Maximum traversal depth
    ///
    /// Consulted by the BFS and DFS traversals: nodes at this depth are still
    /// reported, but their neighbors are not explored.
    pub max_depth: usize,

    /// Maximum number of nodes to visit
    ///
    /// Consulted by the BFS and DFS traversals.
    pub max_nodes: usize,

    /// Maximum distance for shortest path algorithms
    ///
    /// `shortest_paths` does not expand a node whose accumulated cost exceeds
    /// this value.
    pub max_distance: f32,

    /// Cost function for edge traversal
    ///
    /// Determines how `shortest_paths` converts an edge into a cost.
    pub cost_function: CostFunction,
}
685
686impl Default for TraversalConfig {
687    fn default() -> Self {
688        Self {
689            max_depth: 5,
690            max_nodes: 100,
691            max_distance: f32::INFINITY,
692            cost_function: CostFunction::Weight,
693        }
694    }
695}
696
/// Path finding configuration
///
/// Consumed by `GraphAlgorithms::semantic_paths`.
#[derive(Debug, Clone)]
pub struct PathFindingConfig {
    /// Maximum path depth
    ///
    /// The depth-first search abandons a branch once its depth exceeds this
    /// value.
    pub max_depth: usize,

    /// Maximum number of paths to find
    ///
    /// The search stops collecting once this many paths have been found, and
    /// the returned list is truncated to this length.
    pub max_paths: usize,

    /// Minimum semantic score threshold
    ///
    /// A partial path is only extended over an edge when the accumulated
    /// score including that edge is at least this value.
    pub min_semantic_score: f32,

    /// Semantic weights for different edge types
    ///
    /// Multiplier applied to an edge's `confidence * weight`; edge types
    /// without an entry use `1.0`.
    pub semantic_weights: HashMap<super::EdgeType, f32>,
}
712
713impl Default for PathFindingConfig {
714    fn default() -> Self {
715        let mut semantic_weights = HashMap::new();
716        semantic_weights.insert(super::EdgeType::Semantic("is_a".to_string()), 1.0);
717        semantic_weights.insert(super::EdgeType::Semantic("part_of".to_string()), 0.9);
718        semantic_weights.insert(super::EdgeType::Similar, 0.8);
719        semantic_weights.insert(super::EdgeType::CoOccurs, 0.5);
720
721        Self {
722            max_depth: 4,
723            max_paths: 10,
724            min_semantic_score: 0.1,
725            semantic_weights,
726        }
727    }
728}
729
/// Edge cost functions
///
/// Selects how `GraphAlgorithms::shortest_paths` turns an edge into a
/// traversal cost. The enum is fieldless, so it is `Copy` and can be compared
/// and hashed directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CostFunction {
    /// Use edge weight (higher weight = lower cost)
    Weight,

    /// Use inverse confidence (higher confidence = lower cost)
    InverseConfidence,

    /// Uniform cost for all edges
    Uniform,
}
742
/// Path information from shortest path algorithm
///
/// One entry per reachable node in the result of
/// `GraphAlgorithms::shortest_paths`.
#[derive(Debug, Clone, PartialEq)]
pub struct PathInfo {
    /// Total distance/cost from the source to this node
    pub distance: f32,

    /// Node IDs in the path, from the source to this node (both included)
    pub path: Vec<String>,

    /// Number of hops (one less than the number of nodes in `path`)
    pub hop_count: usize,
}
755
/// Semantic path between nodes
///
/// Produced by `GraphAlgorithms::semantic_paths`.
#[derive(Debug, Clone, PartialEq)]
pub struct SemanticPath {
    /// Node IDs in the path, from source to target
    pub nodes: Vec<String>,

    /// Semantic score of the path (sum of the per-edge scores)
    pub semantic_score: f32,

    /// Path length (number of edges)
    pub path_length: usize,

    /// Labels of the edges found along the path
    pub edge_types: Vec<String>,
}
771
/// State for Dijkstra's algorithm
///
/// Entries are ordered so that a max-heap (`BinaryHeap`) pops the *cheapest*
/// state first. The ordering is a genuine total order: costs are compared
/// with `f32::total_cmp` (so NaN cannot break the heap invariants) and ties
/// are broken by `node_id`, which keeps `Eq` consistent with `Ord`.
#[derive(Debug, Clone)]
struct DijkstraState {
    /// Accumulated cost from the source to `node_id`.
    cost: f32,
    /// Node this state refers to.
    node_id: String,
}

impl PartialEq for DijkstraState {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for DijkstraState {}

impl Ord for DijkstraState {
    fn cmp(&self, other: &Self) -> Ordering {
        // Min-heap: reverse the ordering so the lowest cost compares greatest;
        // on equal cost the lexicographically smaller node id pops first.
        other
            .cost
            .total_cmp(&self.cost)
            .then_with(|| other.node_id.cmp(&self.node_id))
    }
}

impl PartialOrd for DijkstraState {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
802
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph_retrieval::{EdgeType, GraphEdge, GraphNode, NodeType};

    /// Build a four-node graph with edges 1 -> 2 -> 3 and 1 -> 4.
    fn create_test_graph() -> KnowledgeGraph {
        let mut graph = KnowledgeGraph::new();

        // Add nodes
        let node1 = GraphNode::new("Node1", NodeType::Concept);
        let node2 = GraphNode::new("Node2", NodeType::Concept);
        let node3 = GraphNode::new("Node3", NodeType::Concept);
        let node4 = GraphNode::new("Node4", NodeType::Concept);

        let node1_id = node1.id.clone();
        let node2_id = node2.id.clone();
        let node3_id = node3.id.clone();
        let node4_id = node4.id.clone();

        graph.add_node(node1).unwrap();
        graph.add_node(node2).unwrap();
        graph.add_node(node3).unwrap();
        graph.add_node(node4).unwrap();

        // Add edges: 1 -> 2 -> 3, 1 -> 4
        graph
            .add_edge(
                GraphEdge::new(
                    node1_id.clone(),
                    node2_id.clone(),
                    "edge1",
                    EdgeType::Similar,
                )
                .with_weight(0.8),
            )
            .unwrap();

        graph
            .add_edge(
                GraphEdge::new(
                    node2_id.clone(),
                    node3_id.clone(),
                    "edge2",
                    EdgeType::Similar,
                )
                .with_weight(0.6),
            )
            .unwrap();

        graph
            .add_edge(
                GraphEdge::new(
                    node1_id.clone(),
                    node4_id.clone(),
                    "edge3",
                    EdgeType::Similar,
                )
                .with_weight(0.9),
            )
            .unwrap();

        graph
    }

    #[test]
    fn test_pagerank() {
        let graph = create_test_graph();
        let config = PageRankConfig::default();

        let scores = GraphAlgorithms::pagerank(&graph, &config).unwrap();
        assert_eq!(scores.len(), 4);

        // Every node receives a positive, finite score.
        for score in scores.values() {
            assert!(score.is_finite());
            assert!(*score > 0.0);
        }

        // Scores start as a probability distribution (1/N each) and rank is
        // only ever passed along or lost at dead ends, never created, so the
        // total cannot exceed 1.
        let total: f32 = scores.values().sum();
        assert!(total > 0.0);
        assert!(total <= 1.0 + 1e-3);
    }

    #[test]
    fn test_shortest_paths() {
        let graph = create_test_graph();
        let config = TraversalConfig::default();
        let node_ids: Vec<_> = graph.nodes.keys().cloned().collect();

        let paths = GraphAlgorithms::shortest_paths(&graph, &node_ids[0], &config).unwrap();

        // Should find paths to all reachable nodes
        assert!(!paths.is_empty());

        // Path to self should have distance 0
        assert_eq!(paths[&node_ids[0]].distance, 0.0);
        assert_eq!(paths[&node_ids[0]].hop_count, 0);
    }

    #[test]
    fn test_bfs_search() {
        let graph = create_test_graph();
        let config = TraversalConfig::default();
        let node_ids: Vec<_> = graph.nodes.keys().cloned().collect();

        let result = GraphAlgorithms::bfs_search(&graph, &node_ids[0], &config).unwrap();

        // Should visit at least the source node
        assert!(!result.is_empty());
        assert_eq!(result[0], node_ids[0]);
    }

    #[test]
    fn test_dfs_search() {
        let graph = create_test_graph();
        let config = TraversalConfig::default();
        let node_ids: Vec<_> = graph.nodes.keys().cloned().collect();

        let result = GraphAlgorithms::dfs_search(&graph, &node_ids[0], &config).unwrap();

        // Should visit at least the source node
        assert!(!result.is_empty());
        assert_eq!(result[0], node_ids[0]);
    }

    #[test]
    fn test_betweenness_centrality() {
        let graph = create_test_graph();
        let centrality = GraphAlgorithms::betweenness_centrality(&graph);

        assert_eq!(centrality.len(), 4);

        // All centrality scores should be non-negative
        for score in centrality.values() {
            assert!(*score >= 0.0);
        }
    }

    #[test]
    fn test_strongly_connected_components() {
        let graph = create_test_graph();
        let components = GraphAlgorithms::strongly_connected_components(&graph);

        // Should have at least one component
        assert!(!components.is_empty());

        // Total nodes in all components should equal graph node count
        let total_nodes: usize = components.iter().map(|c| c.len()).sum();
        assert_eq!(total_nodes, graph.nodes.len());
    }
}