rustkernel_graph/
centrality.rs

1//! Centrality measure kernels.
2//!
3//! This module provides GPU-accelerated centrality algorithms:
4//! - Degree centrality
5//! - Betweenness centrality (Brandes algorithm)
6//! - Closeness centrality (BFS-based)
7//! - Eigenvector centrality (power iteration)
8//! - PageRank (power iteration with teleport)
9//! - Katz centrality (attenuated paths)
10
11use crate::messages::{CentralityInput, CentralityOutput, CentralityParams};
12use crate::ring_messages::{
13    K2KBarrier, K2KBarrierRelease, K2KIterationSync, K2KIterationSyncResponse,
14    PageRankConvergeResponse, PageRankConvergeRing, PageRankIterateResponse, PageRankIterateRing,
15    PageRankQueryResponse, PageRankQueryRing, from_fixed_point, to_fixed_point,
16};
17use crate::types::{CentralityResult, CsrGraph, NodeScore};
18use async_trait::async_trait;
19use ringkernel_core::RingContext;
20use rustkernel_core::{
21    domain::Domain,
22    error::Result,
23    k2k::IterativeState,
24    kernel::KernelMetadata,
25    traits::{BatchKernel, GpuKernel, RingKernelHandler},
26};
27use std::collections::VecDeque;
28use std::time::Instant;
29
30// ============================================================================
31// PageRank Kernel
32// ============================================================================
33
34/// PageRank kernel state.
35#[derive(Debug, Clone, Default)]
36pub struct PageRankState {
37    /// Current scores.
38    pub scores: Vec<f64>,
39    /// Previous scores (for convergence check).
40    pub prev_scores: Vec<f64>,
41    /// Graph in CSR format.
42    pub graph: Option<CsrGraph>,
43    /// Damping factor.
44    pub damping: f32,
45    /// Current iteration.
46    pub iteration: u32,
47    /// Whether converged.
48    pub converged: bool,
49}
50
51/// PageRank centrality kernel.
52///
53/// Calculates PageRank centrality using power iteration with teleportation.
54/// This is a Ring kernel for low-latency queries after graph is loaded.
55#[derive(Debug)]
56pub struct PageRank {
57    metadata: KernelMetadata,
58    /// Internal state for Ring mode operations.
59    state: std::sync::RwLock<PageRankState>,
60}
61
62impl Clone for PageRank {
63    fn clone(&self) -> Self {
64        Self {
65            metadata: self.metadata.clone(),
66            state: std::sync::RwLock::new(self.state.read().unwrap().clone()),
67        }
68    }
69}
70
71impl PageRank {
72    /// Create a new PageRank kernel.
73    #[must_use]
74    pub fn new() -> Self {
75        Self {
76            metadata: KernelMetadata::ring("graph/pagerank", Domain::GraphAnalytics)
77                .with_description("PageRank centrality via power iteration")
78                .with_throughput(100_000)
79                .with_latency_us(1.0)
80                .with_gpu_native(true),
81            state: std::sync::RwLock::new(PageRankState::default()),
82        }
83    }
84
85    /// Initialize the kernel with a graph for Ring mode operations.
86    pub fn initialize(&self, graph: CsrGraph, damping: f32) {
87        let mut state = self.state.write().unwrap();
88        *state = Self::initialize_state(graph, damping);
89    }
90
91    /// Query the score for a specific node.
92    pub fn query_score(&self, node_id: u64) -> Option<f64> {
93        let state = self.state.read().unwrap();
94        state.scores.get(node_id as usize).copied()
95    }
96
97    /// Get current iteration count.
98    pub fn current_iteration(&self) -> u32 {
99        self.state.read().unwrap().iteration
100    }
101
102    /// Check if converged.
103    pub fn is_converged(&self) -> bool {
104        self.state.read().unwrap().converged
105    }
106
107    /// Perform one iteration step using internal state.
108    pub fn iterate(&self) -> f64 {
109        let mut state = self.state.write().unwrap();
110        Self::iterate_step(&mut state)
111    }
112
113    /// Perform one iteration of PageRank on the given state.
114    pub fn iterate_step(state: &mut PageRankState) -> f64 {
115        let Some(ref graph) = state.graph else {
116            return 0.0;
117        };
118
119        let n = graph.num_nodes;
120        if n == 0 {
121            return 0.0;
122        }
123
124        let d = state.damping as f64;
125        let teleport = (1.0 - d) / n as f64;
126
127        // Swap buffers
128        std::mem::swap(&mut state.scores, &mut state.prev_scores);
129
130        // Calculate new scores
131        let mut max_diff = 0.0f64;
132
133        for i in 0..n {
134            let mut rank_sum = 0.0f64;
135
136            // Sum contributions from incoming edges
137            for &neighbor in graph.neighbors(i as u64) {
138                let out_degree = graph.out_degree(neighbor) as f64;
139                if out_degree > 0.0 {
140                    rank_sum += state.prev_scores[neighbor as usize] / out_degree;
141                }
142            }
143
144            let new_score = teleport + d * rank_sum;
145            state.scores[i] = new_score;
146
147            let diff = (new_score - state.prev_scores[i]).abs();
148            if diff > max_diff {
149                max_diff = diff;
150            }
151        }
152
153        state.iteration += 1;
154        max_diff
155    }
156
157    /// Initialize state for a graph.
158    pub fn initialize_state(graph: CsrGraph, damping: f32) -> PageRankState {
159        let n = graph.num_nodes;
160        PageRankState {
161            scores: vec![1.0 / n as f64; n],
162            prev_scores: vec![0.0; n],
163            graph: Some(graph),
164            damping,
165            iteration: 0,
166            converged: false,
167        }
168    }
169
170    /// Run PageRank to convergence.
171    pub fn run_to_convergence(
172        graph: CsrGraph,
173        damping: f32,
174        max_iterations: u32,
175        threshold: f64,
176    ) -> Result<CentralityResult> {
177        let mut state = Self::initialize_state(graph, damping);
178
179        for _ in 0..max_iterations {
180            let diff = Self::iterate_step(&mut state);
181            if diff < threshold {
182                state.converged = true;
183                break;
184            }
185        }
186
187        Ok(CentralityResult {
188            scores: state
189                .scores
190                .iter()
191                .enumerate()
192                .map(|(i, &score)| NodeScore {
193                    node_id: i as u64,
194                    score,
195                })
196                .collect(),
197            iterations: Some(state.iteration),
198            converged: state.converged,
199        })
200    }
201}
202
203impl Default for PageRank {
204    fn default() -> Self {
205        Self::new()
206    }
207}
208
209impl GpuKernel for PageRank {
210    fn metadata(&self) -> &KernelMetadata {
211        &self.metadata
212    }
213}
214
215// ============================================================================
216// PageRank RingKernelHandler Implementations
217// ============================================================================
218
219/// RingKernelHandler for PageRank queries.
220///
221/// Enables low-latency score queries for individual nodes in Ring mode.
222#[async_trait]
223impl RingKernelHandler<PageRankQueryRing, PageRankQueryResponse> for PageRank {
224    async fn handle(
225        &self,
226        _ctx: &mut RingContext,
227        msg: PageRankQueryRing,
228    ) -> Result<PageRankQueryResponse> {
229        let state = self.state.read().unwrap();
230        let score = state
231            .scores
232            .get(msg.node_id as usize)
233            .copied()
234            .unwrap_or(0.0);
235
236        Ok(PageRankQueryResponse {
237            request_id: msg.id.0,
238            node_id: msg.node_id,
239            score_fp: to_fixed_point(score),
240            iteration: state.iteration,
241            converged: state.converged,
242        })
243    }
244}
245
246/// RingKernelHandler for PageRank single iteration.
247///
248/// Performs one power iteration step in Ring mode.
249#[async_trait]
250impl RingKernelHandler<PageRankIterateRing, PageRankIterateResponse> for PageRank {
251    async fn handle(
252        &self,
253        _ctx: &mut RingContext,
254        msg: PageRankIterateRing,
255    ) -> Result<PageRankIterateResponse> {
256        // Perform one iteration on internal state
257        let max_delta = self.iterate();
258
259        // Check convergence using default threshold
260        let state = self.state.read().unwrap();
261        let converged = max_delta < 1e-6;
262
263        Ok(PageRankIterateResponse {
264            request_id: msg.id.0,
265            iteration: state.iteration,
266            max_delta_fp: to_fixed_point(max_delta),
267            converged,
268        })
269    }
270}
271
272/// RingKernelHandler for PageRank convergence.
273///
274/// Runs PageRank to convergence using K2K coordination for iterative state.
275#[async_trait]
276impl RingKernelHandler<PageRankConvergeRing, PageRankConvergeResponse> for PageRank {
277    async fn handle(
278        &self,
279        _ctx: &mut RingContext,
280        msg: PageRankConvergeRing,
281    ) -> Result<PageRankConvergeResponse> {
282        let threshold = from_fixed_point(msg.threshold_fp);
283        let max_iterations = msg.max_iterations as u64;
284
285        // Use K2K IterativeState for convergence tracking
286        let mut iterative_state = IterativeState::new(threshold, max_iterations);
287
288        // Run actual iterations on internal state
289        while iterative_state.should_continue() {
290            let max_delta = self.iterate();
291            iterative_state.update(max_delta);
292        }
293
294        // Update convergence status in internal state
295        {
296            let mut state = self.state.write().unwrap();
297            state.converged = iterative_state.summary().converged;
298        }
299
300        let summary = iterative_state.summary();
301
302        Ok(PageRankConvergeResponse {
303            request_id: msg.id.0,
304            iterations: summary.iterations as u32,
305            final_delta_fp: to_fixed_point(summary.final_delta),
306            converged: summary.converged,
307        })
308    }
309}
310
311/// RingKernelHandler for K2K iteration synchronization.
312///
313/// Used in distributed PageRank to synchronize iterations across partitions.
314/// In a single-instance setting, this validates the worker's iteration state
315/// and returns convergence status based on the reported delta.
316#[async_trait]
317impl RingKernelHandler<K2KIterationSync, K2KIterationSyncResponse> for PageRank {
318    async fn handle(
319        &self,
320        _ctx: &mut RingContext,
321        msg: K2KIterationSync,
322    ) -> Result<K2KIterationSyncResponse> {
323        let state = self.state.read().unwrap();
324
325        // For single-instance, verify iteration matches internal state
326        // In distributed setting, would aggregate deltas from all workers
327        let current_iteration = state.iteration as u64;
328        let all_synced = msg.iteration <= current_iteration;
329
330        // Use reported local delta as global delta (single worker case)
331        // In distributed setting, would compute max across all workers
332        let local_delta = from_fixed_point(msg.local_delta_fp);
333        let global_converged = local_delta < 1e-6 || state.converged;
334
335        Ok(K2KIterationSyncResponse {
336            request_id: msg.id.0,
337            iteration: msg.iteration,
338            all_synced,
339            global_delta_fp: msg.local_delta_fp,
340            global_converged,
341        })
342    }
343}
344
345/// RingKernelHandler for K2K barrier synchronization.
346///
347/// Implements barrier synchronization for distributed PageRank iterations.
348#[async_trait]
349impl RingKernelHandler<K2KBarrier, K2KBarrierRelease> for PageRank {
350    async fn handle(&self, _ctx: &mut RingContext, msg: K2KBarrier) -> Result<K2KBarrierRelease> {
351        // In a distributed setting, this would:
352        // 1. Record this worker as ready
353        // 2. Check if all workers are ready
354        // 3. Release barrier when all ready
355        let all_ready = msg.ready_count >= msg.total_workers;
356
357        Ok(K2KBarrierRelease {
358            barrier_id: msg.barrier_id,
359            all_ready,
360            next_iteration: msg.barrier_id + 1,
361        })
362    }
363}
364
365// ============================================================================
366// Degree Centrality Kernel
367// ============================================================================
368
369/// Degree centrality kernel.
370///
371/// Simple O(1) lookup of node degrees after graph is loaded.
372#[derive(Debug, Clone)]
373pub struct DegreeCentrality {
374    metadata: KernelMetadata,
375}
376
377impl DegreeCentrality {
378    /// Create a new degree centrality kernel.
379    #[must_use]
380    pub fn new() -> Self {
381        Self {
382            metadata: KernelMetadata::ring("graph/degree-centrality", Domain::GraphAnalytics)
383                .with_description("Degree centrality (O(1) lookup)")
384                .with_throughput(1_000_000)
385                .with_latency_us(0.1),
386        }
387    }
388
389    /// Calculate degree centrality for all nodes.
390    ///
391    /// Returns normalized degree centrality (degree / (n-1)).
392    pub fn compute(graph: &CsrGraph) -> CentralityResult {
393        let n = graph.num_nodes;
394        let normalizer = if n > 1 { (n - 1) as f64 } else { 1.0 };
395
396        let scores: Vec<NodeScore> = (0..n)
397            .map(|i| NodeScore {
398                node_id: i as u64,
399                score: graph.out_degree(i as u64) as f64 / normalizer,
400            })
401            .collect();
402
403        CentralityResult {
404            scores,
405            iterations: None,
406            converged: true,
407        }
408    }
409}
410
411impl Default for DegreeCentrality {
412    fn default() -> Self {
413        Self::new()
414    }
415}
416
417impl GpuKernel for DegreeCentrality {
418    fn metadata(&self) -> &KernelMetadata {
419        &self.metadata
420    }
421}
422
423// ============================================================================
424// Betweenness Centrality Kernel (Brandes Algorithm)
425// ============================================================================
426
427/// Betweenness centrality kernel.
428///
429/// Uses Brandes algorithm for efficient computation in O(VE) time.
430#[derive(Debug, Clone)]
431pub struct BetweennessCentrality {
432    metadata: KernelMetadata,
433}
434
435impl BetweennessCentrality {
436    /// Create a new betweenness centrality kernel.
437    #[must_use]
438    pub fn new() -> Self {
439        Self {
440            metadata: KernelMetadata::batch("graph/betweenness-centrality", Domain::GraphAnalytics)
441                .with_description("Betweenness centrality (Brandes algorithm)")
442                .with_throughput(10_000)
443                .with_latency_us(100.0),
444        }
445    }
446
447    /// Compute betweenness centrality using Brandes algorithm.
448    ///
449    /// The algorithm runs BFS from each vertex and accumulates
450    /// dependency scores in a single backward pass.
451    pub fn compute(graph: &CsrGraph, normalized: bool) -> CentralityResult {
452        let n = graph.num_nodes;
453        let mut centrality = vec![0.0f64; n];
454
455        // Run Brandes algorithm from each source
456        for s in 0..n {
457            // BFS structures
458            let mut stack: Vec<usize> = Vec::with_capacity(n);
459            let mut predecessors: Vec<Vec<usize>> = vec![Vec::new(); n];
460            let mut sigma = vec![0.0f64; n]; // Number of shortest paths
461            let mut dist = vec![-1i64; n]; // Distance from source
462
463            sigma[s] = 1.0;
464            dist[s] = 0;
465
466            let mut queue = VecDeque::new();
467            queue.push_back(s);
468
469            // Forward BFS
470            while let Some(v) = queue.pop_front() {
471                stack.push(v);
472
473                for &w in graph.neighbors(v as u64) {
474                    let w = w as usize;
475
476                    // First time visiting w?
477                    if dist[w] < 0 {
478                        dist[w] = dist[v] + 1;
479                        queue.push_back(w);
480                    }
481
482                    // Is this a shortest path to w via v?
483                    if dist[w] == dist[v] + 1 {
484                        sigma[w] += sigma[v];
485                        predecessors[w].push(v);
486                    }
487                }
488            }
489
490            // Backward pass - accumulate dependencies
491            let mut delta = vec![0.0f64; n];
492
493            while let Some(w) = stack.pop() {
494                for &v in &predecessors[w] {
495                    let contribution = (sigma[v] / sigma[w]) * (1.0 + delta[w]);
496                    delta[v] += contribution;
497                }
498
499                if w != s {
500                    centrality[w] += delta[w];
501                }
502            }
503        }
504
505        // Normalize if requested
506        if normalized && n > 2 {
507            let scale = 1.0 / ((n - 1) * (n - 2)) as f64;
508            for c in &mut centrality {
509                *c *= scale;
510            }
511        }
512
513        CentralityResult {
514            scores: centrality
515                .into_iter()
516                .enumerate()
517                .map(|(i, score)| NodeScore {
518                    node_id: i as u64,
519                    score,
520                })
521                .collect(),
522            iterations: None,
523            converged: true,
524        }
525    }
526}
527
528impl Default for BetweennessCentrality {
529    fn default() -> Self {
530        Self::new()
531    }
532}
533
534impl GpuKernel for BetweennessCentrality {
535    fn metadata(&self) -> &KernelMetadata {
536        &self.metadata
537    }
538}
539
540// ============================================================================
541// Closeness Centrality Kernel
542// ============================================================================
543
544/// Closeness centrality kernel.
545///
546/// BFS-based closeness centrality calculation.
547/// Closeness = (n-1) / sum(shortest_path_distances)
548#[derive(Debug, Clone)]
549pub struct ClosenessCentrality {
550    metadata: KernelMetadata,
551}
552
553impl ClosenessCentrality {
554    /// Create a new closeness centrality kernel.
555    #[must_use]
556    pub fn new() -> Self {
557        Self {
558            metadata: KernelMetadata::batch("graph/closeness-centrality", Domain::GraphAnalytics)
559                .with_description("Closeness centrality (BFS-based)")
560                .with_throughput(10_000)
561                .with_latency_us(100.0),
562        }
563    }
564
565    /// Compute closeness centrality using BFS from each node.
566    ///
567    /// For disconnected graphs, uses harmonic mean variant.
568    pub fn compute(graph: &CsrGraph, harmonic: bool) -> CentralityResult {
569        let n = graph.num_nodes;
570        let mut centrality = vec![0.0f64; n];
571
572        for source in 0..n {
573            let distances = Self::bfs_distances(graph, source);
574
575            if harmonic {
576                // Harmonic centrality: sum(1/d) for all reachable nodes
577                let sum: f64 = distances
578                    .iter()
579                    .enumerate()
580                    .filter(|(i, d)| *i != source && **d > 0)
581                    .map(|(_, d)| 1.0 / *d as f64)
582                    .sum();
583                centrality[source] = sum / (n - 1) as f64;
584            } else {
585                // Classic closeness: (n-1) / sum(distances)
586                let sum: i64 = distances.iter().sum();
587                let reachable: usize = distances.iter().filter(|&&d| d > 0).count();
588
589                if sum > 0 && reachable > 0 {
590                    centrality[source] = reachable as f64 / sum as f64;
591                }
592            }
593        }
594
595        CentralityResult {
596            scores: centrality
597                .into_iter()
598                .enumerate()
599                .map(|(i, score)| NodeScore {
600                    node_id: i as u64,
601                    score,
602                })
603                .collect(),
604            iterations: None,
605            converged: true,
606        }
607    }
608
609    /// BFS to compute distances from source to all other nodes.
610    fn bfs_distances(graph: &CsrGraph, source: usize) -> Vec<i64> {
611        let n = graph.num_nodes;
612        let mut distances = vec![0i64; n];
613        let mut visited = vec![false; n];
614
615        let mut queue = VecDeque::new();
616        queue.push_back(source);
617        visited[source] = true;
618
619        while let Some(v) = queue.pop_front() {
620            for &w in graph.neighbors(v as u64) {
621                let w = w as usize;
622                if !visited[w] {
623                    visited[w] = true;
624                    distances[w] = distances[v] + 1;
625                    queue.push_back(w);
626                }
627            }
628        }
629
630        distances
631    }
632}
633
634impl Default for ClosenessCentrality {
635    fn default() -> Self {
636        Self::new()
637    }
638}
639
640impl GpuKernel for ClosenessCentrality {
641    fn metadata(&self) -> &KernelMetadata {
642        &self.metadata
643    }
644}
645
646// ============================================================================
647// Eigenvector Centrality Kernel
648// ============================================================================
649
650/// Eigenvector centrality kernel.
651///
652/// Power iteration method for eigenvector centrality.
653/// A node's score is proportional to the sum of its neighbors' scores.
654#[derive(Debug, Clone)]
655pub struct EigenvectorCentrality {
656    metadata: KernelMetadata,
657}
658
659impl EigenvectorCentrality {
660    /// Create a new eigenvector centrality kernel.
661    #[must_use]
662    pub fn new() -> Self {
663        Self {
664            metadata: KernelMetadata::batch("graph/eigenvector-centrality", Domain::GraphAnalytics)
665                .with_description("Eigenvector centrality (power iteration)")
666                .with_throughput(50_000)
667                .with_latency_us(10.0),
668        }
669    }
670
671    /// Compute eigenvector centrality using power iteration.
672    pub fn compute(graph: &CsrGraph, max_iterations: u32, tolerance: f64) -> CentralityResult {
673        let n = graph.num_nodes;
674        if n == 0 {
675            return CentralityResult {
676                scores: Vec::new(),
677                iterations: Some(0),
678                converged: true,
679            };
680        }
681
682        // Initialize with uniform scores
683        let mut scores = vec![1.0 / (n as f64).sqrt(); n];
684        let mut new_scores = vec![0.0f64; n];
685        let mut converged = false;
686        let mut iterations = 0u32;
687
688        for iter in 0..max_iterations {
689            iterations = iter + 1;
690
691            // Compute new scores: x_i = sum(A_ij * x_j)
692            for i in 0..n {
693                let mut sum = 0.0f64;
694                for &j in graph.neighbors(i as u64) {
695                    sum += scores[j as usize];
696                }
697                new_scores[i] = sum;
698            }
699
700            // Normalize
701            let norm: f64 = new_scores.iter().map(|x| x * x).sum::<f64>().sqrt();
702            if norm > 0.0 {
703                for x in &mut new_scores {
704                    *x /= norm;
705                }
706            }
707
708            // Check convergence
709            let diff: f64 = scores
710                .iter()
711                .zip(new_scores.iter())
712                .map(|(a, b)| (a - b).abs())
713                .fold(0.0f64, |acc, x| acc.max(x));
714
715            std::mem::swap(&mut scores, &mut new_scores);
716
717            if diff < tolerance {
718                converged = true;
719                break;
720            }
721        }
722
723        CentralityResult {
724            scores: scores
725                .into_iter()
726                .enumerate()
727                .map(|(i, score)| NodeScore {
728                    node_id: i as u64,
729                    score,
730                })
731                .collect(),
732            iterations: Some(iterations),
733            converged,
734        }
735    }
736}
737
738impl Default for EigenvectorCentrality {
739    fn default() -> Self {
740        Self::new()
741    }
742}
743
744impl GpuKernel for EigenvectorCentrality {
745    fn metadata(&self) -> &KernelMetadata {
746        &self.metadata
747    }
748}
749
750// ============================================================================
751// Katz Centrality Kernel
752// ============================================================================
753
754/// Katz centrality kernel.
755///
756/// Measures influence through attenuated paths.
757/// Katz(i) = sum over all paths from j to i of alpha^(path_length)
758#[derive(Debug, Clone)]
759pub struct KatzCentrality {
760    metadata: KernelMetadata,
761}
762
763impl KatzCentrality {
764    /// Create a new Katz centrality kernel.
765    #[must_use]
766    pub fn new() -> Self {
767        Self {
768            metadata: KernelMetadata::batch("graph/katz-centrality", Domain::GraphAnalytics)
769                .with_description("Katz centrality (attenuated paths)")
770                .with_throughput(50_000)
771                .with_latency_us(10.0),
772        }
773    }
774
775    /// Compute Katz centrality.
776    ///
777    /// # Arguments
778    /// * `graph` - The input graph
779    /// * `alpha` - Attenuation factor (should be < 1/lambda_max)
780    /// * `beta` - Base score for each node (default 1.0)
781    /// * `max_iterations` - Maximum iterations for power iteration
782    /// * `tolerance` - Convergence threshold
783    pub fn compute(
784        graph: &CsrGraph,
785        alpha: f64,
786        beta: f64,
787        max_iterations: u32,
788        tolerance: f64,
789    ) -> CentralityResult {
790        let n = graph.num_nodes;
791        if n == 0 {
792            return CentralityResult {
793                scores: Vec::new(),
794                iterations: Some(0),
795                converged: true,
796            };
797        }
798
799        // Initialize scores
800        let mut scores = vec![0.0f64; n];
801        let mut new_scores = vec![0.0f64; n];
802        let mut converged = false;
803        let mut iterations = 0u32;
804
805        // Power iteration: x = alpha * A * x + beta
806        for iter in 0..max_iterations {
807            iterations = iter + 1;
808
809            for i in 0..n {
810                let mut sum = 0.0f64;
811                for &j in graph.neighbors(i as u64) {
812                    sum += scores[j as usize];
813                }
814                new_scores[i] = alpha * sum + beta;
815            }
816
817            // Check convergence
818            let diff: f64 = scores
819                .iter()
820                .zip(new_scores.iter())
821                .map(|(a, b)| (a - b).abs())
822                .fold(0.0f64, |acc, x| acc.max(x));
823
824            std::mem::swap(&mut scores, &mut new_scores);
825
826            if diff < tolerance {
827                converged = true;
828                break;
829            }
830        }
831
832        // Normalize by maximum score
833        let max_score = scores.iter().cloned().fold(0.0f64, f64::max);
834        if max_score > 0.0 {
835            for s in &mut scores {
836                *s /= max_score;
837            }
838        }
839
840        CentralityResult {
841            scores: scores
842                .into_iter()
843                .enumerate()
844                .map(|(i, score)| NodeScore {
845                    node_id: i as u64,
846                    score,
847                })
848                .collect(),
849            iterations: Some(iterations),
850            converged,
851        }
852    }
853}
854
855impl Default for KatzCentrality {
856    fn default() -> Self {
857        Self::new()
858    }
859}
860
861impl GpuKernel for KatzCentrality {
862    fn metadata(&self) -> &KernelMetadata {
863        &self.metadata
864    }
865}
866
867// ============================================================================
868// BatchKernel Implementations
869// ============================================================================
870
871/// Batch execution wrapper for all centrality kernels.
872///
873/// Since centrality algorithms are computationally intensive,
874/// they benefit from batch execution with CPU orchestration.
875
876#[async_trait]
877impl BatchKernel<CentralityInput, CentralityOutput> for BetweennessCentrality {
878    async fn execute(&self, input: CentralityInput) -> Result<CentralityOutput> {
879        let start = Instant::now();
880        let normalized = input.normalize;
881        let result = Self::compute(&input.graph, normalized);
882        let compute_time_us = start.elapsed().as_micros() as u64;
883
884        Ok(CentralityOutput {
885            result,
886            compute_time_us,
887        })
888    }
889}
890
891#[async_trait]
892impl BatchKernel<CentralityInput, CentralityOutput> for ClosenessCentrality {
893    async fn execute(&self, input: CentralityInput) -> Result<CentralityOutput> {
894        let start = Instant::now();
895        let harmonic = match input.params {
896            CentralityParams::Closeness { harmonic } => harmonic,
897            _ => false,
898        };
899        let result = Self::compute(&input.graph, harmonic);
900        let compute_time_us = start.elapsed().as_micros() as u64;
901
902        Ok(CentralityOutput {
903            result,
904            compute_time_us,
905        })
906    }
907}
908
909#[async_trait]
910impl BatchKernel<CentralityInput, CentralityOutput> for EigenvectorCentrality {
911    async fn execute(&self, input: CentralityInput) -> Result<CentralityOutput> {
912        let start = Instant::now();
913        let max_iterations = input.max_iterations.unwrap_or(1000);
914        let tolerance = input.tolerance.unwrap_or(1e-6);
915        let result = Self::compute(&input.graph, max_iterations, tolerance);
916        let compute_time_us = start.elapsed().as_micros() as u64;
917
918        Ok(CentralityOutput {
919            result,
920            compute_time_us,
921        })
922    }
923}
924
925#[async_trait]
926impl BatchKernel<CentralityInput, CentralityOutput> for KatzCentrality {
927    async fn execute(&self, input: CentralityInput) -> Result<CentralityOutput> {
928        let start = Instant::now();
929        let (alpha, beta) = match input.params {
930            CentralityParams::Katz { alpha, beta } => (alpha, beta),
931            _ => (0.1, 1.0),
932        };
933        let max_iterations = input.max_iterations.unwrap_or(100);
934        let tolerance = input.tolerance.unwrap_or(1e-6);
935        let result = Self::compute(&input.graph, alpha, beta, max_iterations, tolerance);
936        let compute_time_us = start.elapsed().as_micros() as u64;
937
938        Ok(CentralityOutput {
939            result,
940            compute_time_us,
941        })
942    }
943}
944
945/// PageRank can be used in both batch and ring modes.
946/// This is the batch mode implementation.
947impl PageRank {
948    /// Execute PageRank as a batch operation.
949    ///
950    /// Convenience method that runs the algorithm to convergence.
951    pub async fn compute_batch(
952        &self,
953        graph: CsrGraph,
954        damping: f32,
955        max_iterations: u32,
956        threshold: f64,
957    ) -> Result<CentralityResult> {
958        Self::run_to_convergence(graph, damping, max_iterations, threshold)
959    }
960}
961
962#[async_trait]
963impl BatchKernel<CentralityInput, CentralityOutput> for PageRank {
964    async fn execute(&self, input: CentralityInput) -> Result<CentralityOutput> {
965        let start = Instant::now();
966        let damping = match input.params {
967            CentralityParams::PageRank { damping } => damping,
968            _ => 0.85,
969        };
970        let max_iterations = input.max_iterations.unwrap_or(100);
971        let tolerance = input.tolerance.unwrap_or(1e-6);
972        let result = Self::run_to_convergence(input.graph, damping, max_iterations, tolerance)?;
973        let compute_time_us = start.elapsed().as_micros() as u64;
974
975        Ok(CentralityOutput {
976            result,
977            compute_time_us,
978        })
979    }
980}
981
982/// Degree centrality batch implementation.
983#[async_trait]
984impl BatchKernel<CentralityInput, CentralityOutput> for DegreeCentrality {
985    async fn execute(&self, input: CentralityInput) -> Result<CentralityOutput> {
986        let start = Instant::now();
987        let result = Self::compute(&input.graph);
988        let compute_time_us = start.elapsed().as_micros() as u64;
989
990        Ok(CentralityOutput {
991            result,
992            compute_time_us,
993        })
994    }
995}
996
#[cfg(test)]
mod tests {
    use super::*;

    /// Directed 4-cycle: 0 -> 1 -> 2 -> 3 -> 0.
    fn create_test_graph() -> CsrGraph {
        CsrGraph::from_edges(4, &[(0, 1), (1, 2), (2, 3), (3, 0)])
    }

    /// Star graph: center node 0 linked in both directions to nodes 1..=4.
    fn create_star_graph() -> CsrGraph {
        CsrGraph::from_edges(
            5,
            &[
                (0, 1),
                (0, 2),
                (0, 3),
                (0, 4),
                (1, 0),
                (2, 0),
                (3, 0),
                (4, 0),
            ],
        )
    }

    #[test]
    fn test_pagerank_metadata() {
        let kernel = PageRank::new();
        assert_eq!(kernel.metadata().id, "graph/pagerank");
        assert_eq!(kernel.metadata().domain, Domain::GraphAnalytics);
    }

    #[test]
    fn test_pagerank_iteration() {
        let graph = create_test_graph();
        let mut state = PageRank::initialize_state(graph, 0.85);

        let diff = PageRank::iterate_step(&mut state);
        assert!(diff >= 0.0);
        assert_eq!(state.iteration, 1);
    }

    #[test]
    fn test_pagerank_convergence() {
        let graph = create_test_graph();
        let result = PageRank::run_to_convergence(graph, 0.85, 100, 1e-6).unwrap();

        assert!(result.converged);
        assert_eq!(result.scores.len(), 4);

        // In a cycle, all nodes should have equal PageRank.
        let first_score = result.scores[0].score;
        for score in &result.scores {
            assert!((score.score - first_score).abs() < 0.01);
        }
    }

    #[test]
    fn test_degree_centrality() {
        let graph = create_star_graph();
        let result = DegreeCentrality::compute(&graph);

        assert_eq!(result.scores.len(), 5);

        // Center node (0) should have the highest degree.
        let center_score = result.scores[0].score;
        for score in &result.scores[1..] {
            assert!(center_score > score.score);
        }
    }

    #[test]
    fn test_betweenness_centrality() {
        // Line graph: 0 - 1 - 2 - 3
        let graph = CsrGraph::from_edges(4, &[(0, 1), (1, 0), (1, 2), (2, 1), (2, 3), (3, 2)]);

        let result = BetweennessCentrality::compute(&graph, false);

        assert_eq!(result.scores.len(), 4);

        // Middle nodes (1, 2) lie on shortest paths between other nodes; the
        // endpoints (0, 3) do not, so each middle node must outrank its endpoint.
        let node_0_score = result.scores[0].score;
        let node_1_score = result.scores[1].score;
        let node_2_score = result.scores[2].score;
        let node_3_score = result.scores[3].score;
        assert!(node_1_score > node_0_score);
        assert!(node_2_score > node_3_score);

        // The line graph is symmetric, so both middle nodes score the same.
        assert!((node_1_score - node_2_score).abs() < 1e-9);
    }

    #[test]
    fn test_closeness_centrality() {
        let graph = create_star_graph();
        let result = ClosenessCentrality::compute(&graph, false);

        assert_eq!(result.scores.len(), 5);

        // Center node should have the highest closeness.
        let center_score = result.scores[0].score;
        for score in &result.scores[1..] {
            assert!(center_score >= score.score);
        }
    }

    #[test]
    fn test_eigenvector_centrality() {
        let graph = create_star_graph();
        let result = EigenvectorCentrality::compute(&graph, 1000, 1e-4);

        // May or may not converge depending on graph structure.
        assert_eq!(result.scores.len(), 5);

        // Center node should have positive eigenvector centrality
        // (it may not be strictly highest due to star graph properties).
        let center_score = result.scores[0].score;
        assert!(center_score > 0.0);
    }

    #[test]
    fn test_katz_centrality() {
        let graph = create_star_graph();
        let result = KatzCentrality::compute(&graph, 0.1, 1.0, 100, 1e-6);

        assert!(result.converged);
        assert_eq!(result.scores.len(), 5);

        // The center collects attenuated walks from every leaf, so it must
        // outrank each leaf (the star is symmetric, so edge direction is moot).
        let center_score = result.scores[0].score;
        for score in &result.scores[1..] {
            assert!(center_score > score.score);
        }
    }
}
1118        assert_eq!(result.scores.len(), 5);
1119    }
1120}