use arrow::record_batch::RecordBatch;
use arrow::array::{StringArray, Float64Array};
use arrow::datatypes::{DataType, Field, Schema};
use std::sync::Arc;
use std::collections::HashMap;
use crate::algorithms::{GraphAlgorithm, AlgorithmParams};
use crate::graph::ArrowGraph;
use crate::error::{GraphError, Result};

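/// PageRank scores computed by power iteration.
///
/// Each score follows the standard recurrence
/// `PR(v) = (1 - d) / N + d * sum(PR(u) / out_degree(u))` over the nodes `u`
/// that link to `v`, where `d` is the damping factor and `N` the node count;
/// mass from dangling nodes (zero out-degree) is redistributed uniformly.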
pub struct PageRank;

impl PageRank {
    fn compute_pagerank(
        &self,
        graph: &ArrowGraph,
        damping_factor: f64,
        max_iterations: usize,
        tolerance: f64,
    ) -> Result<HashMap<String, f64>> {
        let node_count = graph.node_count();
        if node_count == 0 {
            return Ok(HashMap::new());
        }

        // Start from a uniform distribution over all nodes.
        let initial_score = 1.0 / node_count as f64;
        let mut current_scores: HashMap<String, f64> = HashMap::new();
        let mut next_scores: HashMap<String, f64> = HashMap::new();

        for node_id in graph.node_ids() {
            current_scores.insert(node_id.clone(), initial_score);
            next_scores.insert(node_id.clone(), 0.0);
        }

        // Pre-compute out-degrees so each node's score can be split evenly
        // among its neighbors on every iteration.
        let mut out_degrees: HashMap<String, usize> = HashMap::new();
        for node_id in graph.node_ids() {
            let degree = graph.neighbors(node_id).map(|n| n.len()).unwrap_or(0);
            out_degrees.insert(node_id.clone(), degree);
        }

        for iteration in 0..max_iterations {
            // Reset every score to the teleport term (1 - d) / N.
            for score in next_scores.values_mut() {
                *score = (1.0 - damping_factor) / node_count as f64;
            }

            for node_id in graph.node_ids() {
                let current_score = current_scores.get(node_id).unwrap_or(&0.0);
                let out_degree = out_degrees.get(node_id).unwrap_or(&0);

                if *out_degree > 0 {
                    let contribution = current_score * damping_factor / *out_degree as f64;

                    if let Some(neighbors) = graph.neighbors(node_id) {
                        for neighbor in neighbors {
                            if let Some(neighbor_score) = next_scores.get_mut(neighbor) {
                                *neighbor_score += contribution;
                            }
                        }
                    }
                } else {
                    // Dangling node: redistribute its damped mass uniformly.
                    let dangling_contribution = current_score * damping_factor / node_count as f64;
                    for score in next_scores.values_mut() {
                        *score += dangling_contribution;
                    }
                }
            }

            // L1 distance between successive iterations drives the convergence test.
            let mut diff = 0.0;
            for node_id in graph.node_ids() {
                let old_score = current_scores.get(node_id).unwrap_or(&0.0);
                let new_score = next_scores.get(node_id).unwrap_or(&0.0);
                diff += (new_score - old_score).abs();
            }

            // Swap before checking convergence so the freshly computed scores
            // are the ones returned when the loop stops.
            std::mem::swap(&mut current_scores, &mut next_scores);

            if diff < tolerance {
                log::debug!("PageRank converged after {} iterations", iteration + 1);
                break;
            }
        }

        Ok(current_scores)
    }
}

impl GraphAlgorithm for PageRank {
    fn execute(&self, graph: &ArrowGraph, params: &AlgorithmParams) -> Result<RecordBatch> {
        let damping_factor: f64 = params.get("damping_factor").unwrap_or(0.85);
        let max_iterations: usize = params.get("max_iterations").unwrap_or(100);
        let tolerance: f64 = params.get("tolerance").unwrap_or(1e-6);

        if !(0.0..=1.0).contains(&damping_factor) {
            return Err(GraphError::invalid_parameter(
                "damping_factor must be between 0.0 and 1.0"
            ));
        }

        if max_iterations == 0 {
            return Err(GraphError::invalid_parameter(
                "max_iterations must be greater than 0"
            ));
        }

        if tolerance <= 0.0 {
            return Err(GraphError::invalid_parameter(
                "tolerance must be greater than 0.0"
            ));
        }

        let scores = self.compute_pagerank(graph, damping_factor, max_iterations, tolerance)?;

        let schema = Arc::new(Schema::new(vec![
            Field::new("node_id", DataType::Utf8, false),
            Field::new("pagerank_score", DataType::Float64, false),
        ]));

        let mut node_ids = Vec::new();
        let mut pagerank_scores = Vec::new();

        // Emit results in descending score order.
        let mut sorted_scores: Vec<(&String, &f64)> = scores.iter().collect();
        sorted_scores.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap_or(std::cmp::Ordering::Equal));

        for (node_id, score) in sorted_scores {
            node_ids.push(node_id.clone());
            pagerank_scores.push(*score);
        }

        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(node_ids)),
                Arc::new(Float64Array::from(pagerank_scores)),
            ],
        ).map_err(GraphError::from)
    }

    fn name(&self) -> &'static str {
        "pagerank"
    }

    fn description(&self) -> &'static str {
        "Calculate PageRank scores using power iteration with early termination"
    }
}

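/// Betweenness centrality computed with Brandes' algorithm on unweighted edges.
///
/// For every source node a BFS counts shortest paths and records predecessors,
/// then a reverse pass accumulates each node's dependency. Scores are
/// normalized by `2 / ((N - 1) * (N - 2))` when the graph has more than two nodes.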
pub struct BetweennessCentrality;

impl BetweennessCentrality {
    fn compute_betweenness_centrality(&self, graph: &ArrowGraph) -> Result<HashMap<String, f64>> {
        let mut centrality: HashMap<String, f64> = HashMap::new();

        for node_id in graph.node_ids() {
            centrality.insert(node_id.clone(), 0.0);
        }

        // Brandes' algorithm: one BFS per source, then a dependency
        // accumulation pass in reverse BFS order.
        for source in graph.node_ids() {
            let mut stack = Vec::new();
            let mut paths: HashMap<String, Vec<String>> = HashMap::new();
            let mut num_paths: HashMap<String, f64> = HashMap::new();
            let mut distances: HashMap<String, i32> = HashMap::new();
            let mut delta: HashMap<String, f64> = HashMap::new();

            for node_id in graph.node_ids() {
                paths.insert(node_id.clone(), Vec::new());
                num_paths.insert(node_id.clone(), 0.0);
                distances.insert(node_id.clone(), -1);
                delta.insert(node_id.clone(), 0.0);
            }

            num_paths.insert(source.clone(), 1.0);
            distances.insert(source.clone(), 0);

            let mut queue = std::collections::VecDeque::new();
            queue.push_back(source.clone());

            // Forward BFS: count shortest paths and record predecessors.
            while let Some(current) = queue.pop_front() {
                stack.push(current.clone());
                let current_dist = *distances.get(&current).unwrap_or(&-1);

                if let Some(neighbors) = graph.neighbors(&current) {
                    for neighbor in neighbors {
                        // First visit: enqueue and set the BFS distance.
                        if *distances.get(neighbor).unwrap_or(&-1) < 0 {
                            queue.push_back(neighbor.clone());
                            distances.insert(neighbor.clone(), current_dist + 1);
                        }

                        // Re-read the distance so neighbors discovered just above
                        // are also counted as lying on a shortest path.
                        if *distances.get(neighbor).unwrap_or(&-1) == current_dist + 1 {
                            let current_paths = *num_paths.get(&current).unwrap_or(&0.0);
                            let neighbor_paths = num_paths.get_mut(neighbor).unwrap();
                            *neighbor_paths += current_paths;

                            paths.get_mut(neighbor).unwrap().push(current.clone());
                        }
                    }
                }
            }

            // Back-propagation: accumulate dependencies from the BFS frontier
            // back towards the source.
            while let Some(w) = stack.pop() {
                if let Some(predecessors) = paths.get(&w) {
                    for predecessor in predecessors {
                        let w_delta = *delta.get(&w).unwrap_or(&0.0);
                        let w_paths = *num_paths.get(&w).unwrap_or(&0.0);
                        let pred_paths = *num_paths.get(predecessor).unwrap_or(&0.0);

                        if pred_paths > 0.0 {
                            let contribution = (pred_paths / w_paths) * (1.0 + w_delta);
                            *delta.get_mut(predecessor).unwrap() += contribution;
                        }
                    }
                }

                if w != *source {
                    let w_delta = *delta.get(&w).unwrap_or(&0.0);
                    *centrality.get_mut(&w).unwrap() += w_delta;
                }
            }
        }

        // Normalize so scores are comparable across graph sizes.
        let node_count = graph.node_count() as f64;
        if node_count > 2.0 {
            let normalization = 2.0 / ((node_count - 1.0) * (node_count - 2.0));
            for score in centrality.values_mut() {
                *score *= normalization;
            }
        }

        Ok(centrality)
    }
}

impl GraphAlgorithm for BetweennessCentrality {
    fn execute(&self, graph: &ArrowGraph, _params: &AlgorithmParams) -> Result<RecordBatch> {
        let centrality = self.compute_betweenness_centrality(graph)?;

        if centrality.is_empty() {
            let schema = Arc::new(Schema::new(vec![
                Field::new("node_id", DataType::Utf8, false),
                Field::new("betweenness_centrality", DataType::Float64, false),
            ]));

            return RecordBatch::try_new(
                schema,
                vec![
                    Arc::new(StringArray::from(Vec::<String>::new())),
                    Arc::new(Float64Array::from(Vec::<f64>::new())),
                ],
            ).map_err(GraphError::from);
        }

        let mut sorted_nodes: Vec<(&String, &f64)> = centrality.iter().collect();
        sorted_nodes.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap_or(std::cmp::Ordering::Equal));

        let node_ids: Vec<String> = sorted_nodes.iter().map(|(node, _)| (*node).clone()).collect();
        let scores: Vec<f64> = sorted_nodes.iter().map(|(_, &score)| score).collect();

        let schema = Arc::new(Schema::new(vec![
            Field::new("node_id", DataType::Utf8, false),
            Field::new("betweenness_centrality", DataType::Float64, false),
        ]));

        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(node_ids)),
                Arc::new(Float64Array::from(scores)),
            ],
        ).map_err(GraphError::from)
    }

    fn name(&self) -> &'static str {
        "betweenness_centrality"
    }

    fn description(&self) -> &'static str {
        "Calculate betweenness centrality using Brandes' algorithm"
    }
}

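/// Eigenvector centrality computed by power iteration.
///
/// Scores are repeatedly propagated along edges and L2-normalized until the
/// change between iterations drops below `tolerance` or `max_iterations` is reached.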
pub struct EigenvectorCentrality;

impl EigenvectorCentrality {
    fn compute_eigenvector_centrality(
        &self,
        graph: &ArrowGraph,
        max_iterations: usize,
        tolerance: f64,
    ) -> Result<HashMap<String, f64>> {
        let node_count = graph.node_count();
        if node_count == 0 {
            return Ok(HashMap::new());
        }

        let node_ids: Vec<String> = graph.node_ids().cloned().collect();
        let mut centrality: HashMap<String, f64> = HashMap::new();
        let mut new_centrality: HashMap<String, f64> = HashMap::new();

        // Start from a uniform unit vector so the first iteration is well defined.
        let initial_value = 1.0 / (node_count as f64).sqrt();
        for node_id in &node_ids {
            centrality.insert(node_id.clone(), initial_value);
            new_centrality.insert(node_id.clone(), 0.0);
        }

        for iteration in 0..max_iterations {
            for value in new_centrality.values_mut() {
                *value = 0.0;
            }

            // Propagate each node's score to its neighbors
            // (one sparse matrix-vector product).
            for node_id in &node_ids {
                let current_score = *centrality.get(node_id).unwrap_or(&0.0);

                if let Some(neighbors) = graph.neighbors(node_id) {
                    for neighbor in neighbors {
                        if let Some(neighbor_score) = new_centrality.get_mut(neighbor) {
                            *neighbor_score += current_score;
                        }
                    }
                }
            }

            // Re-normalize to unit length to keep the iteration numerically stable.
            let norm: f64 = new_centrality.values().map(|x| x * x).sum::<f64>().sqrt();
            if norm > 0.0 {
                for value in new_centrality.values_mut() {
                    *value /= norm;
                }
            }

            let mut diff = 0.0;
            for node_id in &node_ids {
                let old_score = *centrality.get(node_id).unwrap_or(&0.0);
                let new_score = *new_centrality.get(node_id).unwrap_or(&0.0);
                diff += (new_score - old_score).abs();
            }

            // Swap before checking convergence so the freshly computed vector
            // is the one returned when the loop stops.
            std::mem::swap(&mut centrality, &mut new_centrality);

            if diff < tolerance {
                log::debug!("Eigenvector centrality converged after {} iterations", iteration + 1);
                break;
            }
        }

        Ok(centrality)
    }
}

impl GraphAlgorithm for EigenvectorCentrality {
    fn execute(&self, graph: &ArrowGraph, params: &AlgorithmParams) -> Result<RecordBatch> {
        let max_iterations: usize = params.get("max_iterations").unwrap_or(100);
        let tolerance: f64 = params.get("tolerance").unwrap_or(1e-6);

        if max_iterations == 0 {
            return Err(GraphError::invalid_parameter(
                "max_iterations must be greater than 0"
            ));
        }

        if tolerance <= 0.0 {
            return Err(GraphError::invalid_parameter(
                "tolerance must be greater than 0.0"
            ));
        }

        let centrality = self.compute_eigenvector_centrality(graph, max_iterations, tolerance)?;

        if centrality.is_empty() {
            let schema = Arc::new(Schema::new(vec![
                Field::new("node_id", DataType::Utf8, false),
                Field::new("eigenvector_centrality", DataType::Float64, false),
            ]));

            return RecordBatch::try_new(
                schema,
                vec![
                    Arc::new(StringArray::from(Vec::<String>::new())),
                    Arc::new(Float64Array::from(Vec::<f64>::new())),
                ],
            ).map_err(GraphError::from);
        }

        let mut sorted_nodes: Vec<(&String, &f64)> = centrality.iter().collect();
        sorted_nodes.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap_or(std::cmp::Ordering::Equal));

        let node_ids: Vec<String> = sorted_nodes.iter().map(|(node, _)| (*node).clone()).collect();
        let scores: Vec<f64> = sorted_nodes.iter().map(|(_, &score)| score).collect();

        let schema = Arc::new(Schema::new(vec![
            Field::new("node_id", DataType::Utf8, false),
            Field::new("eigenvector_centrality", DataType::Float64, false),
        ]));

        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(node_ids)),
                Arc::new(Float64Array::from(scores)),
            ],
        ).map_err(GraphError::from)
    }

    fn name(&self) -> &'static str {
        "eigenvector_centrality"
    }

    fn description(&self) -> &'static str {
        "Calculate eigenvector centrality using power iteration"
    }
}

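/// Closeness centrality based on BFS hop counts from each node.
///
/// Uses the connectivity-weighted form `C(u) = (r / (N - 1)) * (r / total_distance)`,
/// where `r` is the number of nodes reachable from `u`, so scores stay
/// comparable when the graph is disconnected.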
pub struct ClosenessCentrality;

impl ClosenessCentrality {
    fn compute_closeness_centrality(&self, graph: &ArrowGraph) -> Result<HashMap<String, f64>> {
        let mut centrality: HashMap<String, f64> = HashMap::new();
        let node_count = graph.node_count();

        if node_count <= 1 {
            for node_id in graph.node_ids() {
                centrality.insert(node_id.clone(), 0.0);
            }
            return Ok(centrality);
        }

        for source in graph.node_ids() {
            let distances = self.single_source_shortest_path_lengths(graph, source)?;

            let mut total_distance = 0.0;
            let mut reachable_count = 0;

            for (target, distance) in &distances {
                if target != source && *distance >= 0.0 {
                    total_distance += distance;
                    reachable_count += 1;
                }
            }

            // Connectivity-weighted closeness: scale by the fraction of nodes
            // reachable from the source so disconnected graphs stay comparable.
            let closeness = if total_distance > 0.0 && reachable_count > 0 {
                let avg_distance = total_distance / reachable_count as f64;
                let connectivity = reachable_count as f64 / (node_count - 1) as f64;
                connectivity / avg_distance
            } else {
                0.0
            };

            centrality.insert(source.clone(), closeness);
        }

        Ok(centrality)
    }

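    /// Breadth-first search from `source`, returning hop counts as `f64`;
    /// `-1.0` marks nodes that are unreachable from the source.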
    fn single_source_shortest_path_lengths(
        &self,
        graph: &ArrowGraph,
        source: &str,
    ) -> Result<HashMap<String, f64>> {
        let mut distances: HashMap<String, f64> = HashMap::new();
        let mut queue = std::collections::VecDeque::new();

        for node_id in graph.node_ids() {
            distances.insert(node_id.clone(), -1.0);
        }

        distances.insert(source.to_string(), 0.0);
        queue.push_back(source.to_string());

        while let Some(current) = queue.pop_front() {
            let current_distance = *distances.get(&current).unwrap_or(&-1.0);

            if let Some(neighbors) = graph.neighbors(&current) {
                for neighbor in neighbors {
                    let neighbor_distance = *distances.get(neighbor).unwrap_or(&-1.0);

                    if neighbor_distance < 0.0 {
                        distances.insert(neighbor.clone(), current_distance + 1.0);
                        queue.push_back(neighbor.clone());
                    }
                }
            }
        }

        Ok(distances)
    }
}

impl GraphAlgorithm for ClosenessCentrality {
    fn execute(&self, graph: &ArrowGraph, _params: &AlgorithmParams) -> Result<RecordBatch> {
        let centrality = self.compute_closeness_centrality(graph)?;

        if centrality.is_empty() {
            let schema = Arc::new(Schema::new(vec![
                Field::new("node_id", DataType::Utf8, false),
                Field::new("closeness_centrality", DataType::Float64, false),
            ]));

            return RecordBatch::try_new(
                schema,
                vec![
                    Arc::new(StringArray::from(Vec::<String>::new())),
                    Arc::new(Float64Array::from(Vec::<f64>::new())),
                ],
            ).map_err(GraphError::from);
        }

        let mut sorted_nodes: Vec<(&String, &f64)> = centrality.iter().collect();
        sorted_nodes.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap_or(std::cmp::Ordering::Equal));

        let node_ids: Vec<String> = sorted_nodes.iter().map(|(node, _)| (*node).clone()).collect();
        let scores: Vec<f64> = sorted_nodes.iter().map(|(_, &score)| score).collect();

        let schema = Arc::new(Schema::new(vec![
            Field::new("node_id", DataType::Utf8, false),
            Field::new("closeness_centrality", DataType::Float64, false),
        ]));

        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(node_ids)),
                Arc::new(Float64Array::from(scores)),
            ],
        ).map_err(GraphError::from)
    }

    fn name(&self) -> &'static str {
        "closeness_centrality"
    }

    fn description(&self) -> &'static str {
        "Calculate closeness centrality using per-source BFS shortest paths"
    }
}