Skip to main content

graphrag_core/graph/
analytics.rs

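//! Graph Analytics
//!
//! Graph analysis algorithms for undirected, weighted graphs, including:
//! - Community detection (connected components, as a simplified stand-in for the Louvain algorithm)
//! - Centrality measures (betweenness, closeness, degree)
//! - Path finding (shortest path, all paths up to a depth limit)
//! - Graph-level metrics (density, average clustering coefficient)
//!
//! Graph embeddings preparation and temporal graph analysis are not implemented here yet.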
9
10use serde::{Deserialize, Serialize};
11use std::collections::{HashMap, HashSet, VecDeque};
12
/// Community detection result
///
/// One entry of the list returned by `GraphAnalytics::detect_communities`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Community {
    /// Community ID (a counter assigned during detection, starting at 0)
    pub id: usize,
    /// Node IDs in this community
    pub nodes: Vec<String>,
    /// Community modularity score
    ///
    /// Computed per community by `GraphAnalytics` with a simplified formula;
    /// it is not the whole-graph modularity.
    pub modularity: f32,
}
23
/// Centrality scores for a node
///
/// One value of the map returned by `GraphAnalytics::calculate_centrality`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CentralityScores {
    /// Node ID
    pub node_id: String,
    /// Degree centrality (normalized): the node's degree divided by `n - 1`,
    /// where `n` is the number of nodes in the graph
    pub degree: f32,
    /// Betweenness centrality
    pub betweenness: f32,
    /// Closeness centrality: number of reachable nodes divided by the summed
    /// shortest-path weight to them (0.0 when nothing is reachable)
    pub closeness: f32,
    /// PageRank score (if available)
    ///
    /// `GraphAnalytics::calculate_centrality` always leaves this as `None`.
    pub pagerank: Option<f32>,
}
38
/// Path between two nodes
///
/// Returned by `GraphAnalytics::shortest_path` and `GraphAnalytics::all_paths`.
#[derive(Debug, Clone)]
pub struct Path {
    /// Node IDs in order, from the start node to the end node (both included)
    pub nodes: Vec<String>,
    /// Total path weight (sum of the weights of the edges walked)
    pub weight: f32,
}
47
/// Helper struct for DFS path search state
///
/// Bundles the mutable borrows threaded through the recursive `dfs_paths`
/// calls so the helper does not need one parameter per piece of state.
struct PathSearchState<'a> {
    /// Nodes on the branch currently being explored, in visiting order
    path: &'a mut Vec<String>,
    /// Nodes currently on `path`; used to avoid revisiting a node within one path
    visited: &'a mut HashSet<String>,
    /// Accumulator for every complete start-to-end path found so far
    all_paths: &'a mut Vec<Path>,
    /// Summed edge weight of the branch currently being explored
    weight: f32,
}
55
56/// Graph analytics engine
57pub struct GraphAnalytics {
58    /// Adjacency list representation
59    adjacency: HashMap<String, Vec<(String, f32)>>,
60    /// Node degrees
61    degrees: HashMap<String, usize>,
62}
63
64impl GraphAnalytics {
65    /// Create analytics engine from edges
66    ///
67    /// # Arguments
68    /// * `edges` - List of (source, target, weight) tuples
69    pub fn new(edges: Vec<(String, String, f32)>) -> Self {
70        let mut adjacency: HashMap<String, Vec<(String, f32)>> = HashMap::new();
71        let mut degrees: HashMap<String, usize> = HashMap::new();
72
73        for (source, target, weight) in edges {
74            adjacency
75                .entry(source.clone())
76                .or_default()
77                .push((target.clone(), weight));
78
79            adjacency
80                .entry(target.clone())
81                .or_default()
82                .push((source.clone(), weight));
83
84            *degrees.entry(source).or_insert(0) += 1;
85            *degrees.entry(target).or_insert(0) += 1;
86        }
87
88        Self { adjacency, degrees }
89    }
90
91    /// Detect communities using Louvain algorithm
92    ///
93    /// This is a simplified implementation. Full Louvain requires iterative optimization.
94    ///
95    /// # Returns
96    /// List of detected communities
97    pub fn detect_communities(&self) -> Vec<Community> {
98        let nodes: Vec<String> = self.adjacency.keys().cloned().collect();
99        let mut communities: HashMap<String, usize> = HashMap::new();
100        let mut community_id = 0;
101
102        // Simple connected components as initial communities
103        for node in &nodes {
104            if !communities.contains_key(node) {
105                let component = self.get_connected_component(node);
106                for n in component {
107                    communities.insert(n, community_id);
108                }
109                community_id += 1;
110            }
111        }
112
113        // Group nodes by community
114        let mut community_map: HashMap<usize, Vec<String>> = HashMap::new();
115        for (node, id) in communities {
116            community_map.entry(id).or_default().push(node);
117        }
118
119        // Calculate modularity for each community
120        community_map
121            .into_iter()
122            .map(|(id, nodes)| {
123                let modularity = self.calculate_modularity(&nodes);
124                Community {
125                    id,
126                    nodes,
127                    modularity,
128                }
129            })
130            .collect()
131    }
132
133    /// Get connected component starting from a node
134    fn get_connected_component(&self, start: &str) -> Vec<String> {
135        let mut visited = HashSet::new();
136        let mut queue = VecDeque::new();
137        queue.push_back(start.to_string());
138
139        while let Some(node) = queue.pop_front() {
140            if visited.contains(&node) {
141                continue;
142            }
143            visited.insert(node.clone());
144
145            if let Some(neighbors) = self.adjacency.get(&node) {
146                for (neighbor, _) in neighbors {
147                    if !visited.contains(neighbor) {
148                        queue.push_back(neighbor.clone());
149                    }
150                }
151            }
152        }
153
154        visited.into_iter().collect()
155    }
156
157    /// Calculate modularity for a set of nodes
158    fn calculate_modularity(&self, nodes: &[String]) -> f32 {
159        let total_edges = self.adjacency.len() as f32;
160        let mut internal_edges = 0.0;
161
162        let node_set: HashSet<_> = nodes.iter().collect();
163
164        for node in nodes {
165            if let Some(neighbors) = self.adjacency.get(node) {
166                for (neighbor, _) in neighbors {
167                    if node_set.contains(&neighbor) {
168                        internal_edges += 1.0;
169                    }
170                }
171            }
172        }
173
174        // Normalize (simplified formula)
175        internal_edges / (2.0 * total_edges)
176    }
177
178    /// Calculate centrality scores for all nodes
179    ///
180    /// # Returns
181    /// Map of node ID to centrality scores
182    pub fn calculate_centrality(&self) -> HashMap<String, CentralityScores> {
183        let nodes: Vec<String> = self.adjacency.keys().cloned().collect();
184        let n = nodes.len() as f32;
185
186        let mut scores = HashMap::new();
187
188        for node in &nodes {
189            let degree = self.degree_centrality(node, n);
190            let betweenness = self.betweenness_centrality(node);
191            let closeness = self.closeness_centrality(node);
192
193            scores.insert(
194                node.clone(),
195                CentralityScores {
196                    node_id: node.clone(),
197                    degree,
198                    betweenness,
199                    closeness,
200                    pagerank: None, // Would be filled by PageRank module
201                },
202            );
203        }
204
205        scores
206    }
207
208    /// Calculate degree centrality
209    fn degree_centrality(&self, node: &str, n: f32) -> f32 {
210        let degree = *self.degrees.get(node).unwrap_or(&0) as f32;
211        if n > 1.0 {
212            degree / (n - 1.0)
213        } else {
214            0.0
215        }
216    }
217
218    /// Calculate betweenness centrality (simplified)
219    fn betweenness_centrality(&self, node: &str) -> f32 {
220        let nodes: Vec<String> = self.adjacency.keys().cloned().collect();
221        let mut betweenness = 0.0;
222
223        // For each pair of nodes, count shortest paths through this node
224        for source in &nodes {
225            if source == node {
226                continue;
227            }
228            for target in &nodes {
229                if target == node || source == target {
230                    continue;
231                }
232
233                if let Some(path) = self.shortest_path(source, target) {
234                    if path.nodes.contains(&node.to_string()) {
235                        betweenness += 1.0;
236                    }
237                }
238            }
239        }
240
241        let n = nodes.len() as f32;
242        if n > 2.0 {
243            betweenness / ((n - 1.0) * (n - 2.0) / 2.0)
244        } else {
245            0.0
246        }
247    }
248
249    /// Calculate closeness centrality
250    fn closeness_centrality(&self, node: &str) -> f32 {
251        let nodes: Vec<String> = self.adjacency.keys().cloned().collect();
252        let mut total_distance = 0.0;
253        let mut reachable = 0;
254
255        for target in &nodes {
256            if target == node {
257                continue;
258            }
259
260            if let Some(path) = self.shortest_path(node, target) {
261                total_distance += path.weight;
262                reachable += 1;
263            }
264        }
265
266        if reachable > 0 && total_distance > 0.0 {
267            (reachable as f32) / total_distance
268        } else {
269            0.0
270        }
271    }
272
273    /// Find shortest path between two nodes (Dijkstra's algorithm)
274    ///
275    /// # Arguments
276    /// * `start` - Starting node ID
277    /// * `end` - Ending node ID
278    ///
279    /// # Returns
280    /// Shortest path if exists
281    pub fn shortest_path(&self, start: &str, end: &str) -> Option<Path> {
282        let mut distances: HashMap<String, f32> = HashMap::new();
283        let mut previous: HashMap<String, String> = HashMap::new();
284        let mut unvisited: HashSet<String> = self.adjacency.keys().cloned().collect();
285
286        distances.insert(start.to_string(), 0.0);
287
288        while !unvisited.is_empty() {
289            // Find node with minimum distance
290            let current = unvisited
291                .iter()
292                .min_by(|a, b| {
293                    let dist_a = *distances.get(*a).unwrap_or(&f32::INFINITY);
294                    let dist_b = *distances.get(*b).unwrap_or(&f32::INFINITY);
295                    dist_a
296                        .partial_cmp(&dist_b)
297                        .unwrap_or(std::cmp::Ordering::Equal)
298                })?
299                .clone();
300
301            if current == end {
302                break;
303            }
304
305            unvisited.remove(&current);
306
307            let current_dist = *distances.get(&current).unwrap_or(&f32::INFINITY);
308
309            if let Some(neighbors) = self.adjacency.get(&current) {
310                for (neighbor, weight) in neighbors {
311                    if unvisited.contains(neighbor) {
312                        let alt = current_dist + weight;
313                        let neighbor_dist = *distances.get(neighbor).unwrap_or(&f32::INFINITY);
314
315                        if alt < neighbor_dist {
316                            distances.insert(neighbor.clone(), alt);
317                            previous.insert(neighbor.clone(), current.clone());
318                        }
319                    }
320                }
321            }
322        }
323
324        // Reconstruct path
325        let mut path_nodes = Vec::new();
326        let mut current = end.to_string();
327
328        while let Some(prev) = previous.get(&current) {
329            path_nodes.push(current.clone());
330            current = prev.clone();
331        }
332
333        if current == start {
334            path_nodes.push(start.to_string());
335            path_nodes.reverse();
336
337            let weight = *distances.get(end).unwrap_or(&f32::INFINITY);
338
339            Some(Path {
340                nodes: path_nodes,
341                weight,
342            })
343        } else {
344            None
345        }
346    }
347
348    /// Find all paths between two nodes (limited depth)
349    ///
350    /// # Arguments
351    /// * `start` - Starting node
352    /// * `end` - Ending node
353    /// * `max_depth` - Maximum path length
354    ///
355    /// # Returns
356    /// All paths up to max_depth
357    pub fn all_paths(&self, start: &str, end: &str, max_depth: usize) -> Vec<Path> {
358        let mut paths = Vec::new();
359        let mut current_path = Vec::new();
360        let mut visited = HashSet::new();
361
362        let mut state = PathSearchState {
363            path: &mut current_path,
364            visited: &mut visited,
365            all_paths: &mut paths,
366            weight: 0.0,
367        };
368
369        self.dfs_paths(start, end, &mut state, max_depth);
370
371        paths
372    }
373
374    /// DFS helper for all_paths
375    fn dfs_paths(&self, current: &str, end: &str, state: &mut PathSearchState, max_depth: usize) {
376        if state.path.len() >= max_depth {
377            return;
378        }
379
380        state.path.push(current.to_string());
381        state.visited.insert(current.to_string());
382
383        if current == end {
384            state.all_paths.push(Path {
385                nodes: state.path.clone(),
386                weight: state.weight,
387            });
388        } else if let Some(neighbors) = self.adjacency.get(current) {
389            for (neighbor, edge_weight) in neighbors {
390                if !state.visited.contains(neighbor) {
391                    let old_weight = state.weight;
392                    state.weight += edge_weight;
393                    self.dfs_paths(neighbor, end, state, max_depth);
394                    state.weight = old_weight;
395                }
396            }
397        } else {
398            // Current node has no neighbors in the graph
399        }
400
401        state.path.pop();
402        state.visited.remove(current);
403    }
404
405    /// Get nodes with highest degree centrality
406    ///
407    /// # Arguments
408    /// * `top_k` - Number of top nodes to return
409    ///
410    /// # Returns
411    /// List of (node_id, degree_centrality) sorted by degree
412    pub fn top_degree_nodes(&self, top_k: usize) -> Vec<(String, f32)> {
413        let n = self.adjacency.len() as f32;
414        let mut scores: Vec<_> = self
415            .adjacency
416            .keys()
417            .map(|node| {
418                let degree = self.degree_centrality(node, n);
419                (node.clone(), degree)
420            })
421            .collect();
422
423        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
424        scores.truncate(top_k);
425        scores
426    }
427
428    /// Get graph density
429    ///
430    /// # Returns
431    /// Graph density (0.0 to 1.0)
432    pub fn density(&self) -> f32 {
433        let n = self.adjacency.len() as f32;
434        let edge_count: usize = self.adjacency.values().map(|v| v.len()).sum();
435        let actual_edges = (edge_count / 2) as f32; // Undirected graph
436
437        if n > 1.0 {
438            (2.0 * actual_edges) / (n * (n - 1.0))
439        } else {
440            0.0
441        }
442    }
443
444    /// Get clustering coefficient
445    ///
446    /// # Returns
447    /// Average clustering coefficient
448    pub fn clustering_coefficient(&self) -> f32 {
449        let mut total = 0.0;
450        let mut count = 0;
451
452        for neighbors in self.adjacency.values() {
453            if neighbors.len() < 2 {
454                continue;
455            }
456
457            let neighbor_set: HashSet<_> = neighbors.iter().map(|(n, _)| n).collect();
458            let mut triangles = 0;
459
460            for (n1, _) in neighbors {
461                if let Some(n1_neighbors) = self.adjacency.get(n1) {
462                    for (n2, _) in n1_neighbors {
463                        if neighbor_set.contains(&n2) {
464                            triangles += 1;
465                        }
466                    }
467                }
468            }
469
470            let k = neighbors.len() as f32;
471            let coefficient = triangles as f32 / (k * (k - 1.0));
472            total += coefficient;
473            count += 1;
474        }
475
476        if count > 0 {
477            total / count as f32
478        } else {
479            0.0
480        }
481    }
482}
483
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the shared fixture: four nodes joined by five unit-weight edges
    /// (A-B, A-C, B-C, B-D, C-D).
    fn create_test_graph() -> GraphAnalytics {
        let edges = [("A", "B"), ("A", "C"), ("B", "C"), ("B", "D"), ("C", "D")]
            .iter()
            .map(|&(source, target)| (source.to_string(), target.to_string(), 1.0))
            .collect();
        GraphAnalytics::new(edges)
    }

    /// A to D needs exactly two unit-weight hops.
    #[test]
    fn test_shortest_path() {
        let route = create_test_graph()
            .shortest_path("A", "D")
            .expect("D is reachable from A");

        // A -> B -> D or A -> C -> D
        assert_eq!(route.nodes.len(), 3);
        assert_eq!(route.weight, 2.0);
    }

    /// Every node gets a score entry, and B is at least as "between" as A.
    #[test]
    fn test_centrality() {
        let scores = create_test_graph().calculate_centrality();

        for id in ["A", "B", "C", "D"].iter() {
            assert!(scores.contains_key(*id));
        }

        // B and C should have higher betweenness (they're central)
        assert!(scores["B"].betweenness >= scores["A"].betweenness);
    }

    /// The fixture is connected, so it forms a single community of four nodes.
    #[test]
    fn test_community_detection() {
        let found = create_test_graph().detect_communities();

        assert_eq!(found.len(), 1); // Should be one connected component
        assert_eq!(found[0].nodes.len(), 4);
    }

    /// Density of a non-empty simple graph lies in (0, 1].
    #[test]
    fn test_density() {
        let value = create_test_graph().density();

        assert!(value > 0.0);
        assert!(value <= 1.0);
    }

    /// The average clustering coefficient lies in [0, 1].
    #[test]
    fn test_clustering() {
        let value = create_test_graph().clustering_coefficient();
        let unit_interval = 0.0..=1.0;

        assert!(unit_interval.contains(&value));
    }
}
548}