Skip to main content

graphrag_core/graph/
analytics.rs

//! Graph Analytics
//!
//! Advanced graph analysis algorithms including:
//! - Community detection (Louvain algorithm)
//! - Centrality measures (betweenness, closeness, degree)
//! - Path finding (shortest path, all paths)
//! - Graph embeddings preparation
//! - Temporal graph analysis
9
10use std::collections::{HashMap, HashSet, VecDeque};
11use serde::{Deserialize, Serialize};
12
13/// Community detection result
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct Community {
16    /// Community ID
17    pub id: usize,
18    /// Node IDs in this community
19    pub nodes: Vec<String>,
20    /// Community modularity score
21    pub modularity: f32,
22}
23
24/// Centrality scores for a node
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct CentralityScores {
27    /// Node ID
28    pub node_id: String,
29    /// Degree centrality (normalized)
30    pub degree: f32,
31    /// Betweenness centrality
32    pub betweenness: f32,
33    /// Closeness centrality
34    pub closeness: f32,
35    /// PageRank score (if available)
36    pub pagerank: Option<f32>,
37}
38
/// Path between two nodes.
///
/// `PartialEq` is derived so paths can be compared directly in assertions.
#[derive(Debug, Clone, PartialEq)]
pub struct Path {
    /// Node IDs in traversal order, from the start node to the end node
    pub nodes: Vec<String>,
    /// Total path weight (sum of the weights of the traversed edges)
    pub weight: f32,
}
47
48/// Helper struct for DFS path search state
49struct PathSearchState<'a> {
50    path: &'a mut Vec<String>,
51    visited: &'a mut HashSet<String>,
52    all_paths: &'a mut Vec<Path>,
53    weight: f32,
54}
55
/// Graph analytics engine.
///
/// Holds an undirected, weighted graph. `Default` yields an empty graph.
#[derive(Debug, Clone, Default)]
pub struct GraphAnalytics {
    /// Adjacency list representation: node ID -> list of (neighbor ID, edge weight)
    adjacency: HashMap<String, Vec<(String, f32)>>,
    /// Node degrees: node ID -> number of edge endpoints attached to the node
    degrees: HashMap<String, usize>,
}
63
64impl GraphAnalytics {
65    /// Create analytics engine from edges
66    ///
67    /// # Arguments
68    /// * `edges` - List of (source, target, weight) tuples
69    pub fn new(edges: Vec<(String, String, f32)>) -> Self {
70        let mut adjacency: HashMap<String, Vec<(String, f32)>> = HashMap::new();
71        let mut degrees: HashMap<String, usize> = HashMap::new();
72
73        for (source, target, weight) in edges {
74            adjacency
75                .entry(source.clone())
76                .or_default()
77                .push((target.clone(), weight));
78
79            adjacency
80                .entry(target.clone())
81                .or_default()
82                .push((source.clone(), weight));
83
84            *degrees.entry(source).or_insert(0) += 1;
85            *degrees.entry(target).or_insert(0) += 1;
86        }
87
88        Self { adjacency, degrees }
89    }
90
91    /// Detect communities using Louvain algorithm
92    ///
93    /// This is a simplified implementation. Full Louvain requires iterative optimization.
94    ///
95    /// # Returns
96    /// List of detected communities
97    pub fn detect_communities(&self) -> Vec<Community> {
98        let nodes: Vec<String> = self.adjacency.keys().cloned().collect();
99        let mut communities: HashMap<String, usize> = HashMap::new();
100        let mut community_id = 0;
101
102        // Simple connected components as initial communities
103        for node in &nodes {
104            if !communities.contains_key(node) {
105                let component = self.get_connected_component(node);
106                for n in component {
107                    communities.insert(n, community_id);
108                }
109                community_id += 1;
110            }
111        }
112
113        // Group nodes by community
114        let mut community_map: HashMap<usize, Vec<String>> = HashMap::new();
115        for (node, id) in communities {
116            community_map.entry(id).or_default().push(node);
117        }
118
119        // Calculate modularity for each community
120        community_map
121            .into_iter()
122            .map(|(id, nodes)| {
123                let modularity = self.calculate_modularity(&nodes);
124                Community {
125                    id,
126                    nodes,
127                    modularity,
128                }
129            })
130            .collect()
131    }
132
133    /// Get connected component starting from a node
134    fn get_connected_component(&self, start: &str) -> Vec<String> {
135        let mut visited = HashSet::new();
136        let mut queue = VecDeque::new();
137        queue.push_back(start.to_string());
138
139        while let Some(node) = queue.pop_front() {
140            if visited.contains(&node) {
141                continue;
142            }
143            visited.insert(node.clone());
144
145            if let Some(neighbors) = self.adjacency.get(&node) {
146                for (neighbor, _) in neighbors {
147                    if !visited.contains(neighbor) {
148                        queue.push_back(neighbor.clone());
149                    }
150                }
151            }
152        }
153
154        visited.into_iter().collect()
155    }
156
157    /// Calculate modularity for a set of nodes
158    fn calculate_modularity(&self, nodes: &[String]) -> f32 {
159        let total_edges = self.adjacency.len() as f32;
160        let mut internal_edges = 0.0;
161
162        let node_set: HashSet<_> = nodes.iter().collect();
163
164        for node in nodes {
165            if let Some(neighbors) = self.adjacency.get(node) {
166                for (neighbor, _) in neighbors {
167                    if node_set.contains(&neighbor) {
168                        internal_edges += 1.0;
169                    }
170                }
171            }
172        }
173
174        // Normalize (simplified formula)
175        internal_edges / (2.0 * total_edges)
176    }
177
178    /// Calculate centrality scores for all nodes
179    ///
180    /// # Returns
181    /// Map of node ID to centrality scores
182    pub fn calculate_centrality(&self) -> HashMap<String, CentralityScores> {
183        let nodes: Vec<String> = self.adjacency.keys().cloned().collect();
184        let n = nodes.len() as f32;
185
186        let mut scores = HashMap::new();
187
188        for node in &nodes {
189            let degree = self.degree_centrality(node, n);
190            let betweenness = self.betweenness_centrality(node);
191            let closeness = self.closeness_centrality(node);
192
193            scores.insert(
194                node.clone(),
195                CentralityScores {
196                    node_id: node.clone(),
197                    degree,
198                    betweenness,
199                    closeness,
200                    pagerank: None, // Would be filled by PageRank module
201                },
202            );
203        }
204
205        scores
206    }
207
208    /// Calculate degree centrality
209    fn degree_centrality(&self, node: &str, n: f32) -> f32 {
210        let degree = *self.degrees.get(node).unwrap_or(&0) as f32;
211        if n > 1.0 {
212            degree / (n - 1.0)
213        } else {
214            0.0
215        }
216    }
217
218    /// Calculate betweenness centrality (simplified)
219    fn betweenness_centrality(&self, node: &str) -> f32 {
220        let nodes: Vec<String> = self.adjacency.keys().cloned().collect();
221        let mut betweenness = 0.0;
222
223        // For each pair of nodes, count shortest paths through this node
224        for source in &nodes {
225            if source == node {
226                continue;
227            }
228            for target in &nodes {
229                if target == node || source == target {
230                    continue;
231                }
232
233                if let Some(path) = self.shortest_path(source, target) {
234                    if path.nodes.contains(&node.to_string()) {
235                        betweenness += 1.0;
236                    }
237                }
238            }
239        }
240
241        let n = nodes.len() as f32;
242        if n > 2.0 {
243            betweenness / ((n - 1.0) * (n - 2.0) / 2.0)
244        } else {
245            0.0
246        }
247    }
248
249    /// Calculate closeness centrality
250    fn closeness_centrality(&self, node: &str) -> f32 {
251        let nodes: Vec<String> = self.adjacency.keys().cloned().collect();
252        let mut total_distance = 0.0;
253        let mut reachable = 0;
254
255        for target in &nodes {
256            if target == node {
257                continue;
258            }
259
260            if let Some(path) = self.shortest_path(node, target) {
261                total_distance += path.weight;
262                reachable += 1;
263            }
264        }
265
266        if reachable > 0 && total_distance > 0.0 {
267            (reachable as f32) / total_distance
268        } else {
269            0.0
270        }
271    }
272
273    /// Find shortest path between two nodes (Dijkstra's algorithm)
274    ///
275    /// # Arguments
276    /// * `start` - Starting node ID
277    /// * `end` - Ending node ID
278    ///
279    /// # Returns
280    /// Shortest path if exists
281    pub fn shortest_path(&self, start: &str, end: &str) -> Option<Path> {
282        let mut distances: HashMap<String, f32> = HashMap::new();
283        let mut previous: HashMap<String, String> = HashMap::new();
284        let mut unvisited: HashSet<String> = self.adjacency.keys().cloned().collect();
285
286        distances.insert(start.to_string(), 0.0);
287
288        while !unvisited.is_empty() {
289            // Find node with minimum distance
290            let current = unvisited
291                .iter()
292                .min_by(|a, b| {
293                    let dist_a = *distances.get(*a).unwrap_or(&f32::INFINITY);
294                    let dist_b = *distances.get(*b).unwrap_or(&f32::INFINITY);
295                    dist_a.partial_cmp(&dist_b).unwrap()
296                })?
297                .clone();
298
299            if current == end {
300                break;
301            }
302
303            unvisited.remove(&current);
304
305            let current_dist = *distances.get(&current).unwrap_or(&f32::INFINITY);
306
307            if let Some(neighbors) = self.adjacency.get(&current) {
308                for (neighbor, weight) in neighbors {
309                    if unvisited.contains(neighbor) {
310                        let alt = current_dist + weight;
311                        let neighbor_dist = *distances.get(neighbor).unwrap_or(&f32::INFINITY);
312
313                        if alt < neighbor_dist {
314                            distances.insert(neighbor.clone(), alt);
315                            previous.insert(neighbor.clone(), current.clone());
316                        }
317                    }
318                }
319            }
320        }
321
322        // Reconstruct path
323        let mut path_nodes = Vec::new();
324        let mut current = end.to_string();
325
326        while let Some(prev) = previous.get(&current) {
327            path_nodes.push(current.clone());
328            current = prev.clone();
329        }
330
331        if current == start {
332            path_nodes.push(start.to_string());
333            path_nodes.reverse();
334
335            let weight = *distances.get(end).unwrap_or(&f32::INFINITY);
336
337            Some(Path {
338                nodes: path_nodes,
339                weight,
340            })
341        } else {
342            None
343        }
344    }
345
346    /// Find all paths between two nodes (limited depth)
347    ///
348    /// # Arguments
349    /// * `start` - Starting node
350    /// * `end` - Ending node
351    /// * `max_depth` - Maximum path length
352    ///
353    /// # Returns
354    /// All paths up to max_depth
355    pub fn all_paths(&self, start: &str, end: &str, max_depth: usize) -> Vec<Path> {
356        let mut paths = Vec::new();
357        let mut current_path = Vec::new();
358        let mut visited = HashSet::new();
359
360        let mut state = PathSearchState {
361            path: &mut current_path,
362            visited: &mut visited,
363            all_paths: &mut paths,
364            weight: 0.0,
365        };
366
367        self.dfs_paths(start, end, &mut state, max_depth);
368
369        paths
370    }
371
372    /// DFS helper for all_paths
373    fn dfs_paths(
374        &self,
375        current: &str,
376        end: &str,
377        state: &mut PathSearchState,
378        max_depth: usize,
379    ) {
380        if state.path.len() >= max_depth {
381            return;
382        }
383
384        state.path.push(current.to_string());
385        state.visited.insert(current.to_string());
386
387        if current == end {
388            state.all_paths.push(Path {
389                nodes: state.path.clone(),
390                weight: state.weight,
391            });
392        } else if let Some(neighbors) = self.adjacency.get(current) {
393            for (neighbor, edge_weight) in neighbors {
394                if !state.visited.contains(neighbor) {
395                    let old_weight = state.weight;
396                    state.weight += edge_weight;
397                    self.dfs_paths(neighbor, end, state, max_depth);
398                    state.weight = old_weight;
399                }
400            }
401        } else {
402            // Current node has no neighbors in the graph
403        }
404
405        state.path.pop();
406        state.visited.remove(current);
407    }
408
409    /// Get nodes with highest degree centrality
410    ///
411    /// # Arguments
412    /// * `top_k` - Number of top nodes to return
413    ///
414    /// # Returns
415    /// List of (node_id, degree_centrality) sorted by degree
416    pub fn top_degree_nodes(&self, top_k: usize) -> Vec<(String, f32)> {
417        let n = self.adjacency.len() as f32;
418        let mut scores: Vec<_> = self
419            .adjacency
420            .keys()
421            .map(|node| {
422                let degree = self.degree_centrality(node, n);
423                (node.clone(), degree)
424            })
425            .collect();
426
427        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
428        scores.truncate(top_k);
429        scores
430    }
431
432    /// Get graph density
433    ///
434    /// # Returns
435    /// Graph density (0.0 to 1.0)
436    pub fn density(&self) -> f32 {
437        let n = self.adjacency.len() as f32;
438        let edge_count: usize = self.adjacency.values().map(|v| v.len()).sum();
439        let actual_edges = (edge_count / 2) as f32; // Undirected graph
440
441        if n > 1.0 {
442            (2.0 * actual_edges) / (n * (n - 1.0))
443        } else {
444            0.0
445        }
446    }
447
448    /// Get clustering coefficient
449    ///
450    /// # Returns
451    /// Average clustering coefficient
452    pub fn clustering_coefficient(&self) -> f32 {
453        let mut total = 0.0;
454        let mut count = 0;
455
456        for neighbors in self.adjacency.values() {
457            if neighbors.len() < 2 {
458                continue;
459            }
460
461            let neighbor_set: HashSet<_> = neighbors.iter().map(|(n, _)| n).collect();
462            let mut triangles = 0;
463
464            for (n1, _) in neighbors {
465                if let Some(n1_neighbors) = self.adjacency.get(n1) {
466                    for (n2, _) in n1_neighbors {
467                        if neighbor_set.contains(&n2) {
468                            triangles += 1;
469                        }
470                    }
471                }
472            }
473
474            let k = neighbors.len() as f32;
475            let coefficient = triangles as f32 / (k * (k - 1.0));
476            total += coefficient;
477            count += 1;
478        }
479
480        if count > 0 {
481            total / count as f32
482        } else {
483            0.0
484        }
485    }
486}
487
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used when comparing floating-point results
    const EPSILON: f32 = 1e-5;

    /// Builds a 4-node "diamond" graph with unit weights:
    /// A-B, A-C, B-C, B-D, C-D (A and D are not directly connected).
    fn create_test_graph() -> GraphAnalytics {
        let edges = vec![
            ("A".to_string(), "B".to_string(), 1.0),
            ("A".to_string(), "C".to_string(), 1.0),
            ("B".to_string(), "C".to_string(), 1.0),
            ("B".to_string(), "D".to_string(), 1.0),
            ("C".to_string(), "D".to_string(), 1.0),
        ];
        GraphAnalytics::new(edges)
    }

    #[test]
    fn test_shortest_path() {
        let graph = create_test_graph();
        let path = graph.shortest_path("A", "D").unwrap();

        assert_eq!(path.nodes.len(), 3); // A -> B -> D or A -> C -> D
        assert_eq!(path.nodes.first().map(String::as_str), Some("A"));
        assert_eq!(path.nodes.last().map(String::as_str), Some("D"));
        assert_eq!(path.weight, 2.0);
    }

    #[test]
    fn test_shortest_path_unknown_target() {
        let graph = create_test_graph();

        assert!(graph.shortest_path("A", "Z").is_none());
    }

    #[test]
    fn test_all_paths() {
        let graph = create_test_graph();

        // With at most 3 nodes per path only A-B-D and A-C-D qualify.
        let paths = graph.all_paths("A", "D", 3);
        assert_eq!(paths.len(), 2);
        assert!(paths.iter().all(|p| p.nodes.len() == 3 && p.weight == 2.0));
    }

    #[test]
    fn test_centrality() {
        let graph = create_test_graph();
        let scores = graph.calculate_centrality();

        assert!(scores.contains_key("A"));
        assert!(scores.contains_key("B"));
        assert!(scores.contains_key("C"));
        assert!(scores.contains_key("D"));

        // B and C should have higher betweenness (they're central)
        let b_score = &scores["B"];
        let a_score = &scores["A"];
        assert!(b_score.betweenness >= a_score.betweenness);

        // A lies on no shortest path between other nodes: B, C and D are
        // all directly connected to each other or share a shorter route.
        assert_eq!(a_score.betweenness, 0.0);

        // Degree centrality: B touches all 3 other nodes, A touches 2 of 3.
        assert!((b_score.degree - 1.0).abs() < EPSILON);
        assert!((a_score.degree - 2.0 / 3.0).abs() < EPSILON);
    }

    #[test]
    fn test_community_detection() {
        let graph = create_test_graph();
        let communities = graph.detect_communities();

        assert_eq!(communities.len(), 1); // Should be one connected component
        assert_eq!(communities[0].nodes.len(), 4);
    }

    #[test]
    fn test_density() {
        let graph = create_test_graph();
        let density = graph.density();

        // 5 edges out of 6 possible pairs among 4 nodes
        assert!((density - 5.0 / 6.0).abs() < EPSILON);
    }

    #[test]
    fn test_clustering() {
        let graph = create_test_graph();
        let coeff = graph.clustering_coefficient();

        // Local coefficients: A = 1, B = 2/3, C = 2/3, D = 1
        assert!((0.0..=1.0).contains(&coeff));
        assert!((coeff - 5.0 / 6.0).abs() < EPSILON);
    }
}
552}