// arrow_graph/algorithms/centrality.rs

1use arrow::record_batch::RecordBatch;
2use arrow::array::{StringArray, Float64Array};
3use arrow::datatypes::{DataType, Field, Schema};
4use std::sync::Arc;
5use std::collections::HashMap;
6use crate::algorithms::{GraphAlgorithm, AlgorithmParams};
7use crate::graph::ArrowGraph;
8use crate::error::{GraphError, Result};
9
/// PageRank centrality: power iteration with early termination.
/// Stateless marker type; all work happens in the `GraphAlgorithm` impl.
pub struct PageRank;
11
12impl PageRank {
13    /// PageRank algorithm with power iteration and early termination
14    fn compute_pagerank(
15        &self,
16        graph: &ArrowGraph,
17        damping_factor: f64,
18        max_iterations: usize,
19        tolerance: f64,
20    ) -> Result<HashMap<String, f64>> {
21        let node_count = graph.node_count();
22        if node_count == 0 {
23            return Ok(HashMap::new());
24        }
25        
26        // Initialize PageRank scores
27        let initial_score = 1.0 / node_count as f64;
28        let mut current_scores: HashMap<String, f64> = HashMap::new();
29        let mut next_scores: HashMap<String, f64> = HashMap::new();
30        
31        // Initialize all nodes with equal probability
32        for node_id in graph.node_ids() {
33            current_scores.insert(node_id.clone(), initial_score);
34            next_scores.insert(node_id.clone(), 0.0);
35        }
36        
37        // Calculate out-degrees for each node
38        let mut out_degrees: HashMap<String, usize> = HashMap::new();
39        for node_id in graph.node_ids() {
40            let degree = graph.neighbors(node_id).map(|n| n.len()).unwrap_or(0);
41            out_degrees.insert(node_id.clone(), degree);
42        }
43        
44        // Power iteration
45        for iteration in 0..max_iterations {
46            // Reset next scores
47            for score in next_scores.values_mut() {
48                *score = (1.0 - damping_factor) / node_count as f64;
49            }
50            
51            // Distribute PageRank scores
52            for node_id in graph.node_ids() {
53                let current_score = current_scores.get(node_id).unwrap_or(&0.0);
54                let out_degree = out_degrees.get(node_id).unwrap_or(&0);
55                
56                if *out_degree > 0 {
57                    let contribution = current_score * damping_factor / *out_degree as f64;
58                    
59                    if let Some(neighbors) = graph.neighbors(node_id) {
60                        for neighbor in neighbors {
61                            if let Some(neighbor_score) = next_scores.get_mut(neighbor) {
62                                *neighbor_score += contribution;
63                            }
64                        }
65                    }
66                } else {
67                    // Handle dangling nodes - distribute equally to all nodes
68                    let dangling_contribution = current_score * damping_factor / node_count as f64;
69                    for score in next_scores.values_mut() {
70                        *score += dangling_contribution;
71                    }
72                }
73            }
74            
75            // Check for convergence
76            let mut diff = 0.0;
77            for node_id in graph.node_ids() {
78                let old_score = current_scores.get(node_id).unwrap_or(&0.0);
79                let new_score = next_scores.get(node_id).unwrap_or(&0.0);
80                diff += (new_score - old_score).abs();
81            }
82            
83            // Early termination if converged
84            if diff < tolerance {
85                log::debug!("PageRank converged after {} iterations", iteration + 1);
86                break;
87            }
88            
89            // Swap scores for next iteration
90            std::mem::swap(&mut current_scores, &mut next_scores);
91        }
92        
93        Ok(current_scores)
94    }
95}
96
97impl GraphAlgorithm for PageRank {
98    fn execute(&self, graph: &ArrowGraph, params: &AlgorithmParams) -> Result<RecordBatch> {
99        let damping_factor: f64 = params.get("damping_factor").unwrap_or(0.85);
100        let max_iterations: usize = params.get("max_iterations").unwrap_or(100);
101        let tolerance: f64 = params.get("tolerance").unwrap_or(1e-6);
102        
103        // Validate parameters
104        if !(0.0..=1.0).contains(&damping_factor) {
105            return Err(GraphError::invalid_parameter(
106                "damping_factor must be between 0.0 and 1.0"
107            ));
108        }
109        
110        if max_iterations == 0 {
111            return Err(GraphError::invalid_parameter(
112                "max_iterations must be greater than 0"
113            ));
114        }
115        
116        if tolerance <= 0.0 {
117            return Err(GraphError::invalid_parameter(
118                "tolerance must be greater than 0.0"
119            ));
120        }
121        
122        let scores = self.compute_pagerank(graph, damping_factor, max_iterations, tolerance)?;
123        
124        // Convert to Arrow RecordBatch
125        let schema = Arc::new(Schema::new(vec![
126            Field::new("node_id", DataType::Utf8, false),
127            Field::new("pagerank_score", DataType::Float64, false),
128        ]));
129        
130        let mut node_ids = Vec::new();
131        let mut pagerank_scores = Vec::new();
132        
133        // Sort by PageRank score (descending) for consistent output
134        let mut sorted_scores: Vec<(&String, &f64)> = scores.iter().collect();
135        sorted_scores.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap_or(std::cmp::Ordering::Equal));
136        
137        for (node_id, score) in sorted_scores {
138            node_ids.push(node_id.clone());
139            pagerank_scores.push(*score);
140        }
141        
142        RecordBatch::try_new(
143            schema,
144            vec![
145                Arc::new(StringArray::from(node_ids)),
146                Arc::new(Float64Array::from(pagerank_scores)),
147            ],
148        ).map_err(GraphError::from)
149    }
150    
151    fn name(&self) -> &'static str {
152        "pagerank"
153    }
154    
155    fn description(&self) -> &'static str {
156        "Calculate PageRank scores using power iteration with early termination"
157    }
158}
159
/// Betweenness centrality via Brandes' algorithm (unweighted BFS variant).
/// Stateless marker type; all work happens in the `GraphAlgorithm` impl.
pub struct BetweennessCentrality;
161
162impl BetweennessCentrality {
163    /// Calculate betweenness centrality using Brandes' algorithm
164    fn compute_betweenness_centrality(&self, graph: &ArrowGraph) -> Result<HashMap<String, f64>> {
165        let mut centrality: HashMap<String, f64> = HashMap::new();
166        
167        // Initialize centrality scores
168        for node_id in graph.node_ids() {
169            centrality.insert(node_id.clone(), 0.0);
170        }
171        
172        // For each node as source
173        for source in graph.node_ids() {
174            let mut stack = Vec::new();
175            let mut paths: HashMap<String, Vec<String>> = HashMap::new();
176            let mut num_paths: HashMap<String, f64> = HashMap::new();
177            let mut distances: HashMap<String, i32> = HashMap::new();
178            let mut delta: HashMap<String, f64> = HashMap::new();
179            
180            // Initialize
181            for node_id in graph.node_ids() {
182                paths.insert(node_id.clone(), Vec::new());
183                num_paths.insert(node_id.clone(), 0.0);
184                distances.insert(node_id.clone(), -1);
185                delta.insert(node_id.clone(), 0.0);
186            }
187            
188            num_paths.insert(source.clone(), 1.0);
189            distances.insert(source.clone(), 0);
190            
191            // BFS
192            let mut queue = std::collections::VecDeque::new();
193            queue.push_back(source.clone());
194            
195            while let Some(current) = queue.pop_front() {
196                stack.push(current.clone());
197                
198                if let Some(neighbors) = graph.neighbors(&current) {
199                    for neighbor in neighbors {
200                        let current_dist = *distances.get(&current).unwrap_or(&-1);
201                        let neighbor_dist = *distances.get(neighbor).unwrap_or(&-1);
202                        
203                        // First time we reach this neighbor
204                        if neighbor_dist < 0 {
205                            queue.push_back(neighbor.clone());
206                            distances.insert(neighbor.clone(), current_dist + 1);
207                        }
208                        
209                        // Shortest path to neighbor via current
210                        if neighbor_dist == current_dist + 1 {
211                            let current_paths = *num_paths.get(&current).unwrap_or(&0.0);
212                            let neighbor_paths = num_paths.get_mut(neighbor).unwrap();
213                            *neighbor_paths += current_paths;
214                            
215                            paths.get_mut(neighbor).unwrap().push(current.clone());
216                        }
217                    }
218                }
219            }
220            
221            // Accumulation
222            while let Some(w) = stack.pop() {
223                if let Some(predecessors) = paths.get(&w) {
224                    for predecessor in predecessors {
225                        let w_delta = *delta.get(&w).unwrap_or(&0.0);
226                        let w_paths = *num_paths.get(&w).unwrap_or(&0.0);
227                        let pred_paths = *num_paths.get(predecessor).unwrap_or(&0.0);
228                        
229                        if pred_paths > 0.0 {
230                            let contribution = (pred_paths / w_paths) * (1.0 + w_delta);
231                            *delta.get_mut(predecessor).unwrap() += contribution;
232                        }
233                    }
234                }
235                
236                if w != *source {
237                    let w_delta = *delta.get(&w).unwrap_or(&0.0);
238                    *centrality.get_mut(&w).unwrap() += w_delta;
239                }
240            }
241        }
242        
243        // Normalize for undirected graphs
244        let node_count = graph.node_count() as f64;
245        if node_count > 2.0 {
246            let normalization = 2.0 / ((node_count - 1.0) * (node_count - 2.0));
247            for score in centrality.values_mut() {
248                *score *= normalization;
249            }
250        }
251        
252        Ok(centrality)
253    }
254}
255
256impl GraphAlgorithm for BetweennessCentrality {
257    fn execute(&self, graph: &ArrowGraph, _params: &AlgorithmParams) -> Result<RecordBatch> {
258        let centrality = self.compute_betweenness_centrality(graph)?;
259        
260        if centrality.is_empty() {
261            let schema = Arc::new(Schema::new(vec![
262                Field::new("node_id", DataType::Utf8, false),
263                Field::new("betweenness_centrality", DataType::Float64, false),
264            ]));
265            
266            return RecordBatch::try_new(
267                schema,
268                vec![
269                    Arc::new(StringArray::from(Vec::<String>::new())),
270                    Arc::new(Float64Array::from(Vec::<f64>::new())),
271                ],
272            ).map_err(GraphError::from);
273        }
274        
275        // Sort by centrality score (descending)
276        let mut sorted_nodes: Vec<(&String, &f64)> = centrality.iter().collect();
277        sorted_nodes.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap_or(std::cmp::Ordering::Equal));
278        
279        let node_ids: Vec<String> = sorted_nodes.iter().map(|(node, _)| (*node).clone()).collect();
280        let scores: Vec<f64> = sorted_nodes.iter().map(|(_, &score)| score).collect();
281        
282        let schema = Arc::new(Schema::new(vec![
283            Field::new("node_id", DataType::Utf8, false),
284            Field::new("betweenness_centrality", DataType::Float64, false),
285        ]));
286        
287        RecordBatch::try_new(
288            schema,
289            vec![
290                Arc::new(StringArray::from(node_ids)),
291                Arc::new(Float64Array::from(scores)),
292            ],
293        ).map_err(GraphError::from)
294    }
295    
296    fn name(&self) -> &'static str {
297        "betweenness_centrality"
298    }
299    
300    fn description(&self) -> &'static str {
301        "Calculate betweenness centrality using Brandes' algorithm"
302    }
303}
304
/// Eigenvector centrality via power iteration.
/// Stateless marker type; all work happens in the `GraphAlgorithm` impl.
pub struct EigenvectorCentrality;
306
307impl EigenvectorCentrality {
308    /// Calculate eigenvector centrality using power iteration
309    fn compute_eigenvector_centrality(
310        &self,
311        graph: &ArrowGraph,
312        max_iterations: usize,
313        tolerance: f64,
314    ) -> Result<HashMap<String, f64>> {
315        let node_count = graph.node_count();
316        if node_count == 0 {
317            return Ok(HashMap::new());
318        }
319        
320        let node_ids: Vec<String> = graph.node_ids().cloned().collect();
321        let mut centrality: HashMap<String, f64> = HashMap::new();
322        let mut new_centrality: HashMap<String, f64> = HashMap::new();
323        
324        // Initialize with equal values
325        let initial_value = 1.0 / (node_count as f64).sqrt();
326        for node_id in &node_ids {
327            centrality.insert(node_id.clone(), initial_value);
328            new_centrality.insert(node_id.clone(), 0.0);
329        }
330        
331        // Power iteration
332        for iteration in 0..max_iterations {
333            // Reset new centrality values
334            for value in new_centrality.values_mut() {
335                *value = 0.0;
336            }
337            
338            // Compute new centrality values
339            for node_id in &node_ids {
340                let current_score = *centrality.get(node_id).unwrap_or(&0.0);
341                
342                if let Some(neighbors) = graph.neighbors(node_id) {
343                    for neighbor in neighbors {
344                        if let Some(neighbor_score) = new_centrality.get_mut(neighbor) {
345                            *neighbor_score += current_score;
346                        }
347                    }
348                }
349            }
350            
351            // Normalize to prevent overflow
352            let norm: f64 = new_centrality.values().map(|x| x * x).sum::<f64>().sqrt();
353            if norm > 0.0 {
354                for value in new_centrality.values_mut() {
355                    *value /= norm;
356                }
357            }
358            
359            // Check for convergence
360            let mut diff = 0.0;
361            for node_id in &node_ids {
362                let old_score = *centrality.get(node_id).unwrap_or(&0.0);
363                let new_score = *new_centrality.get(node_id).unwrap_or(&0.0);
364                diff += (new_score - old_score).abs();
365            }
366            
367            if diff < tolerance {
368                log::debug!("Eigenvector centrality converged after {} iterations", iteration + 1);
369                break;
370            }
371            
372            // Swap for next iteration
373            std::mem::swap(&mut centrality, &mut new_centrality);
374        }
375        
376        Ok(centrality)
377    }
378}
379
380impl GraphAlgorithm for EigenvectorCentrality {
381    fn execute(&self, graph: &ArrowGraph, params: &AlgorithmParams) -> Result<RecordBatch> {
382        let max_iterations: usize = params.get("max_iterations").unwrap_or(100);
383        let tolerance: f64 = params.get("tolerance").unwrap_or(1e-6);
384        
385        // Validate parameters
386        if max_iterations == 0 {
387            return Err(GraphError::invalid_parameter(
388                "max_iterations must be greater than 0"
389            ));
390        }
391        
392        if tolerance <= 0.0 {
393            return Err(GraphError::invalid_parameter(
394                "tolerance must be greater than 0.0"
395            ));
396        }
397        
398        let centrality = self.compute_eigenvector_centrality(graph, max_iterations, tolerance)?;
399        
400        if centrality.is_empty() {
401            let schema = Arc::new(Schema::new(vec![
402                Field::new("node_id", DataType::Utf8, false),
403                Field::new("eigenvector_centrality", DataType::Float64, false),
404            ]));
405            
406            return RecordBatch::try_new(
407                schema,
408                vec![
409                    Arc::new(StringArray::from(Vec::<String>::new())),
410                    Arc::new(Float64Array::from(Vec::<f64>::new())),
411                ],
412            ).map_err(GraphError::from);
413        }
414        
415        // Sort by centrality score (descending)
416        let mut sorted_nodes: Vec<(&String, &f64)> = centrality.iter().collect();
417        sorted_nodes.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap_or(std::cmp::Ordering::Equal));
418        
419        let node_ids: Vec<String> = sorted_nodes.iter().map(|(node, _)| (*node).clone()).collect();
420        let scores: Vec<f64> = sorted_nodes.iter().map(|(_, &score)| score).collect();
421        
422        let schema = Arc::new(Schema::new(vec![
423            Field::new("node_id", DataType::Utf8, false),
424            Field::new("eigenvector_centrality", DataType::Float64, false),
425        ]));
426        
427        RecordBatch::try_new(
428            schema,
429            vec![
430                Arc::new(StringArray::from(node_ids)),
431                Arc::new(Float64Array::from(scores)),
432            ],
433        ).map_err(GraphError::from)
434    }
435    
436    fn name(&self) -> &'static str {
437        "eigenvector_centrality"
438    }
439    
440    fn description(&self) -> &'static str {
441        "Calculate eigenvector centrality using power iteration"
442    }
443}
444
/// Closeness centrality based on BFS shortest-path distances.
/// Stateless marker type; all work happens in the `GraphAlgorithm` impl.
pub struct ClosenessCentrality;
446
447impl ClosenessCentrality {
448    /// Calculate closeness centrality with batched distance calculations
449    fn compute_closeness_centrality(&self, graph: &ArrowGraph) -> Result<HashMap<String, f64>> {
450        let mut centrality: HashMap<String, f64> = HashMap::new();
451        let node_count = graph.node_count();
452        
453        if node_count <= 1 {
454            for node_id in graph.node_ids() {
455                centrality.insert(node_id.clone(), 0.0);
456            }
457            return Ok(centrality);
458        }
459        
460        // For each node, calculate shortest paths to all other nodes
461        for source in graph.node_ids() {
462            let distances = self.single_source_shortest_path_lengths(graph, source)?;
463            
464            // Calculate sum of distances and count reachable nodes
465            let mut total_distance = 0.0;
466            let mut reachable_count = 0;
467            
468            for (target, distance) in &distances {
469                if target != source && *distance >= 0.0 {
470                    total_distance += distance;
471                    reachable_count += 1;
472                }
473            }
474            
475            // Calculate closeness centrality
476            let closeness = if total_distance > 0.0 && reachable_count > 0 {
477                let avg_distance = total_distance / reachable_count as f64;
478                // Normalize by the fraction of nodes that are reachable
479                let connectivity = reachable_count as f64 / (node_count - 1) as f64;
480                connectivity / avg_distance
481            } else {
482                0.0
483            };
484            
485            centrality.insert(source.clone(), closeness);
486        }
487        
488        Ok(centrality)
489    }
490    
491    /// Single-source shortest path lengths using BFS
492    fn single_source_shortest_path_lengths(
493        &self,
494        graph: &ArrowGraph,
495        source: &str,
496    ) -> Result<HashMap<String, f64>> {
497        let mut distances: HashMap<String, f64> = HashMap::new();
498        let mut queue = std::collections::VecDeque::new();
499        
500        // Initialize distances
501        for node_id in graph.node_ids() {
502            distances.insert(node_id.clone(), -1.0); // -1 means unreachable
503        }
504        
505        // Start BFS from source
506        distances.insert(source.to_string(), 0.0);
507        queue.push_back(source.to_string());
508        
509        while let Some(current) = queue.pop_front() {
510            let current_distance = *distances.get(&current).unwrap_or(&-1.0);
511            
512            if let Some(neighbors) = graph.neighbors(&current) {
513                for neighbor in neighbors {
514                    let neighbor_distance = *distances.get(neighbor).unwrap_or(&-1.0);
515                    
516                    // If neighbor not visited yet
517                    if neighbor_distance < 0.0 {
518                        distances.insert(neighbor.clone(), current_distance + 1.0);
519                        queue.push_back(neighbor.clone());
520                    }
521                }
522            }
523        }
524        
525        Ok(distances)
526    }
527}
528
529impl GraphAlgorithm for ClosenessCentrality {
530    fn execute(&self, graph: &ArrowGraph, _params: &AlgorithmParams) -> Result<RecordBatch> {
531        let centrality = self.compute_closeness_centrality(graph)?;
532        
533        if centrality.is_empty() {
534            let schema = Arc::new(Schema::new(vec![
535                Field::new("node_id", DataType::Utf8, false),
536                Field::new("closeness_centrality", DataType::Float64, false),
537            ]));
538            
539            return RecordBatch::try_new(
540                schema,
541                vec![
542                    Arc::new(StringArray::from(Vec::<String>::new())),
543                    Arc::new(Float64Array::from(Vec::<f64>::new())),
544                ],
545            ).map_err(GraphError::from);
546        }
547        
548        // Sort by centrality score (descending)
549        let mut sorted_nodes: Vec<(&String, &f64)> = centrality.iter().collect();
550        sorted_nodes.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap_or(std::cmp::Ordering::Equal));
551        
552        let node_ids: Vec<String> = sorted_nodes.iter().map(|(node, _)| (*node).clone()).collect();
553        let scores: Vec<f64> = sorted_nodes.iter().map(|(_, &score)| score).collect();
554        
555        let schema = Arc::new(Schema::new(vec![
556            Field::new("node_id", DataType::Utf8, false),
557            Field::new("closeness_centrality", DataType::Float64, false),
558        ]));
559        
560        RecordBatch::try_new(
561            schema,
562            vec![
563                Arc::new(StringArray::from(node_ids)),
564                Arc::new(Float64Array::from(scores)),
565            ],
566        ).map_err(GraphError::from)
567    }
568    
569    fn name(&self) -> &'static str {
570        "closeness_centrality"
571    }
572    
573    fn description(&self) -> &'static str {
574        "Calculate closeness centrality using batched distance calculations"
575    }
576}