rustkernel_graph/
centrality.rs

1//! Centrality measure kernels.
2//!
3//! This module provides GPU-accelerated centrality algorithms:
4//! - Degree centrality
5//! - Betweenness centrality (Brandes algorithm)
6//! - Closeness centrality (BFS-based)
7//! - Eigenvector centrality (power iteration)
8//! - PageRank (power iteration with teleport)
9//! - Katz centrality (attenuated paths)
10
11use crate::messages::{CentralityInput, CentralityOutput, CentralityParams};
12use crate::ring_messages::{
13    K2KBarrier, K2KBarrierRelease, K2KIterationSync, K2KIterationSyncResponse,
14    PageRankConvergeResponse, PageRankConvergeRing, PageRankIterateResponse, PageRankIterateRing,
15    PageRankQueryResponse, PageRankQueryRing, from_fixed_point, to_fixed_point,
16};
17use crate::types::{CentralityResult, CsrGraph, NodeScore};
18use async_trait::async_trait;
19use ringkernel_core::RingContext;
20use rustkernel_core::{
21    domain::Domain,
22    error::Result,
23    k2k::IterativeState,
24    kernel::KernelMetadata,
25    traits::{BatchKernel, GpuKernel, RingKernelHandler},
26};
27use serde::{Deserialize, Serialize};
28use std::collections::VecDeque;
29use std::time::Instant;
30
31// ============================================================================
32// PageRank Kernel
33// ============================================================================
34
/// PageRank operation type.
///
/// Selects what a [`PageRankRequest`] asks the kernel to do (it is carried in
/// the request's `operation` field).
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum PageRankOp {
    /// Query the current PageRank score for a node (uses the request's `node_id`).
    Query,
    /// Perform one iteration of PageRank.
    Iterate,
    /// Reset all scores.
    Reset,
    /// Initialize with a graph (uses the request's `graph`).
    Initialize,
}
47
/// PageRank request.
///
/// Which optional fields are meaningful depends on `operation`; see the
/// per-field notes below.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PageRankRequest {
    /// Node ID to query (for Query operation).
    pub node_id: Option<u64>,
    /// Operation type.
    pub operation: PageRankOp,
    /// Graph data (for Initialize operation).
    pub graph: Option<CsrGraph>,
    /// Damping factor (default: 0.85).
    // NOTE(review): the 0.85 default is not applied by this type itself;
    // confirm the request handler substitutes it when this is `None`.
    pub damping: Option<f32>,
}
60
/// PageRank response.
///
/// Reports the outcome of a [`PageRankRequest`] together with the kernel's
/// iteration progress at the time the request was served.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PageRankResponse {
    /// Score for the queried node.
    pub score: Option<f64>,
    /// Whether the algorithm has converged.
    pub converged: bool,
    /// Current iteration count.
    pub iteration: u32,
    /// Full result (for Query after convergence).
    pub result: Option<CentralityResult>,
}
73
/// PageRank kernel state.
///
/// The derived `Default` is an empty state: no graph, empty score buffers,
/// damping `0.0`, iteration `0`, not converged. `PageRank::initialize_state`
/// builds a state that is ready to iterate.
#[derive(Debug, Clone, Default)]
pub struct PageRankState {
    /// Current scores, indexed by node id.
    pub scores: Vec<f64>,
    /// Previous scores (for convergence check). Swapped with `scores` at the
    /// start of every iteration step.
    pub prev_scores: Vec<f64>,
    /// Graph in CSR format; `None` until a graph is loaded.
    pub graph: Option<CsrGraph>,
    /// Damping factor: weight given to rank flowing along edges, with the
    /// remaining `1 - damping` spread uniformly as teleportation.
    pub damping: f32,
    /// Number of iteration steps performed so far.
    pub iteration: u32,
    /// Whether the state has been marked converged.
    pub converged: bool,
}
90
/// PageRank centrality kernel.
///
/// Calculates PageRank centrality using power iteration with teleportation.
/// Registered as a Ring kernel: a graph is loaded once (see `initialize`) and
/// scores can then be queried or iterated with low latency through the Ring
/// message handlers.
#[derive(Debug)]
pub struct PageRank {
    /// Static kernel descriptor returned by `GpuKernel::metadata`.
    metadata: KernelMetadata,
    /// Internal state for Ring mode operations, behind a lock so handlers that
    /// only take `&self` can still advance it.
    state: std::sync::RwLock<PageRankState>,
}
101
102impl Clone for PageRank {
103    fn clone(&self) -> Self {
104        Self {
105            metadata: self.metadata.clone(),
106            state: std::sync::RwLock::new(self.state.read().unwrap().clone()),
107        }
108    }
109}
110
111impl PageRank {
112    /// Create a new PageRank kernel.
113    #[must_use]
114    pub fn new() -> Self {
115        Self {
116            metadata: KernelMetadata::ring("graph/pagerank", Domain::GraphAnalytics)
117                .with_description("PageRank centrality via power iteration")
118                .with_throughput(100_000)
119                .with_latency_us(1.0)
120                .with_gpu_native(true),
121            state: std::sync::RwLock::new(PageRankState::default()),
122        }
123    }
124
125    /// Initialize the kernel with a graph for Ring mode operations.
126    pub fn initialize(&self, graph: CsrGraph, damping: f32) {
127        let mut state = self.state.write().unwrap();
128        *state = Self::initialize_state(graph, damping);
129    }
130
131    /// Query the score for a specific node.
132    pub fn query_score(&self, node_id: u64) -> Option<f64> {
133        let state = self.state.read().unwrap();
134        state.scores.get(node_id as usize).copied()
135    }
136
137    /// Get current iteration count.
138    pub fn current_iteration(&self) -> u32 {
139        self.state.read().unwrap().iteration
140    }
141
142    /// Check if converged.
143    pub fn is_converged(&self) -> bool {
144        self.state.read().unwrap().converged
145    }
146
147    /// Perform one iteration step using internal state.
148    pub fn iterate(&self) -> f64 {
149        let mut state = self.state.write().unwrap();
150        Self::iterate_step(&mut state)
151    }
152
153    /// Perform one iteration of PageRank on the given state.
154    pub fn iterate_step(state: &mut PageRankState) -> f64 {
155        let Some(ref graph) = state.graph else {
156            return 0.0;
157        };
158
159        let n = graph.num_nodes;
160        if n == 0 {
161            return 0.0;
162        }
163
164        let d = state.damping as f64;
165        let teleport = (1.0 - d) / n as f64;
166
167        // Swap buffers
168        std::mem::swap(&mut state.scores, &mut state.prev_scores);
169
170        // Calculate new scores
171        let mut max_diff = 0.0f64;
172
173        for i in 0..n {
174            let mut rank_sum = 0.0f64;
175
176            // Sum contributions from incoming edges
177            for &neighbor in graph.neighbors(i as u64) {
178                let out_degree = graph.out_degree(neighbor) as f64;
179                if out_degree > 0.0 {
180                    rank_sum += state.prev_scores[neighbor as usize] / out_degree;
181                }
182            }
183
184            let new_score = teleport + d * rank_sum;
185            state.scores[i] = new_score;
186
187            let diff = (new_score - state.prev_scores[i]).abs();
188            if diff > max_diff {
189                max_diff = diff;
190            }
191        }
192
193        state.iteration += 1;
194        max_diff
195    }
196
197    /// Initialize state for a graph.
198    pub fn initialize_state(graph: CsrGraph, damping: f32) -> PageRankState {
199        let n = graph.num_nodes;
200        PageRankState {
201            scores: vec![1.0 / n as f64; n],
202            prev_scores: vec![0.0; n],
203            graph: Some(graph),
204            damping,
205            iteration: 0,
206            converged: false,
207        }
208    }
209
210    /// Run PageRank to convergence.
211    pub fn run_to_convergence(
212        graph: CsrGraph,
213        damping: f32,
214        max_iterations: u32,
215        threshold: f64,
216    ) -> Result<CentralityResult> {
217        let mut state = Self::initialize_state(graph, damping);
218
219        for _ in 0..max_iterations {
220            let diff = Self::iterate_step(&mut state);
221            if diff < threshold {
222                state.converged = true;
223                break;
224            }
225        }
226
227        Ok(CentralityResult {
228            scores: state
229                .scores
230                .iter()
231                .enumerate()
232                .map(|(i, &score)| NodeScore {
233                    node_id: i as u64,
234                    score,
235                })
236                .collect(),
237            iterations: Some(state.iteration),
238            converged: state.converged,
239        })
240    }
241}
242
243impl Default for PageRank {
244    fn default() -> Self {
245        Self::new()
246    }
247}
248
249impl GpuKernel for PageRank {
250    fn metadata(&self) -> &KernelMetadata {
251        &self.metadata
252    }
253}
254
255// ============================================================================
256// PageRank RingKernelHandler Implementations
257// ============================================================================
258
259/// RingKernelHandler for PageRank queries.
260///
261/// Enables low-latency score queries for individual nodes in Ring mode.
262#[async_trait]
263impl RingKernelHandler<PageRankQueryRing, PageRankQueryResponse> for PageRank {
264    async fn handle(
265        &self,
266        _ctx: &mut RingContext,
267        msg: PageRankQueryRing,
268    ) -> Result<PageRankQueryResponse> {
269        let state = self.state.read().unwrap();
270        let score = state
271            .scores
272            .get(msg.node_id as usize)
273            .copied()
274            .unwrap_or(0.0);
275
276        Ok(PageRankQueryResponse {
277            request_id: msg.id.0,
278            node_id: msg.node_id,
279            score_fp: to_fixed_point(score),
280            iteration: state.iteration,
281            converged: state.converged,
282        })
283    }
284}
285
286/// RingKernelHandler for PageRank single iteration.
287///
288/// Performs one power iteration step in Ring mode.
289#[async_trait]
290impl RingKernelHandler<PageRankIterateRing, PageRankIterateResponse> for PageRank {
291    async fn handle(
292        &self,
293        _ctx: &mut RingContext,
294        msg: PageRankIterateRing,
295    ) -> Result<PageRankIterateResponse> {
296        // Perform one iteration on internal state
297        let max_delta = self.iterate();
298
299        // Check convergence using default threshold
300        let state = self.state.read().unwrap();
301        let converged = max_delta < 1e-6;
302
303        Ok(PageRankIterateResponse {
304            request_id: msg.id.0,
305            iteration: state.iteration,
306            max_delta_fp: to_fixed_point(max_delta),
307            converged,
308        })
309    }
310}
311
312/// RingKernelHandler for PageRank convergence.
313///
314/// Runs PageRank to convergence using K2K coordination for iterative state.
315#[async_trait]
316impl RingKernelHandler<PageRankConvergeRing, PageRankConvergeResponse> for PageRank {
317    async fn handle(
318        &self,
319        _ctx: &mut RingContext,
320        msg: PageRankConvergeRing,
321    ) -> Result<PageRankConvergeResponse> {
322        let threshold = from_fixed_point(msg.threshold_fp);
323        let max_iterations = msg.max_iterations as u64;
324
325        // Use K2K IterativeState for convergence tracking
326        let mut iterative_state = IterativeState::new(threshold, max_iterations);
327
328        // Run actual iterations on internal state
329        while iterative_state.should_continue() {
330            let max_delta = self.iterate();
331            iterative_state.update(max_delta);
332        }
333
334        // Update convergence status in internal state
335        {
336            let mut state = self.state.write().unwrap();
337            state.converged = iterative_state.summary().converged;
338        }
339
340        let summary = iterative_state.summary();
341
342        Ok(PageRankConvergeResponse {
343            request_id: msg.id.0,
344            iterations: summary.iterations as u32,
345            final_delta_fp: to_fixed_point(summary.final_delta),
346            converged: summary.converged,
347        })
348    }
349}
350
351/// RingKernelHandler for K2K iteration synchronization.
352///
353/// Used in distributed PageRank to synchronize iterations across partitions.
354/// In a single-instance setting, this validates the worker's iteration state
355/// and returns convergence status based on the reported delta.
356#[async_trait]
357impl RingKernelHandler<K2KIterationSync, K2KIterationSyncResponse> for PageRank {
358    async fn handle(
359        &self,
360        _ctx: &mut RingContext,
361        msg: K2KIterationSync,
362    ) -> Result<K2KIterationSyncResponse> {
363        let state = self.state.read().unwrap();
364
365        // For single-instance, verify iteration matches internal state
366        // In distributed setting, would aggregate deltas from all workers
367        let current_iteration = state.iteration as u64;
368        let all_synced = msg.iteration <= current_iteration;
369
370        // Use reported local delta as global delta (single worker case)
371        // In distributed setting, would compute max across all workers
372        let local_delta = from_fixed_point(msg.local_delta_fp);
373        let global_converged = local_delta < 1e-6 || state.converged;
374
375        Ok(K2KIterationSyncResponse {
376            request_id: msg.id.0,
377            iteration: msg.iteration,
378            all_synced,
379            global_delta_fp: msg.local_delta_fp,
380            global_converged,
381        })
382    }
383}
384
385/// RingKernelHandler for K2K barrier synchronization.
386///
387/// Implements barrier synchronization for distributed PageRank iterations.
388#[async_trait]
389impl RingKernelHandler<K2KBarrier, K2KBarrierRelease> for PageRank {
390    async fn handle(&self, _ctx: &mut RingContext, msg: K2KBarrier) -> Result<K2KBarrierRelease> {
391        // In a distributed setting, this would:
392        // 1. Record this worker as ready
393        // 2. Check if all workers are ready
394        // 3. Release barrier when all ready
395        let all_ready = msg.ready_count >= msg.total_workers;
396
397        Ok(K2KBarrierRelease {
398            barrier_id: msg.barrier_id,
399            all_ready,
400            next_iteration: msg.barrier_id + 1,
401        })
402    }
403}
404
405// ============================================================================
406// Degree Centrality Kernel
407// ============================================================================
408
/// Degree centrality kernel.
///
/// Simple O(1) lookup of node degrees after graph is loaded. Scores are out-degrees
/// normalized by `n - 1`.
#[derive(Debug, Clone)]
pub struct DegreeCentrality {
    /// Static kernel descriptor returned by `GpuKernel::metadata`.
    metadata: KernelMetadata,
}
416
417impl DegreeCentrality {
418    /// Create a new degree centrality kernel.
419    #[must_use]
420    pub fn new() -> Self {
421        Self {
422            metadata: KernelMetadata::ring("graph/degree-centrality", Domain::GraphAnalytics)
423                .with_description("Degree centrality (O(1) lookup)")
424                .with_throughput(1_000_000)
425                .with_latency_us(0.1),
426        }
427    }
428
429    /// Calculate degree centrality for all nodes.
430    ///
431    /// Returns normalized degree centrality (degree / (n-1)).
432    pub fn compute(graph: &CsrGraph) -> CentralityResult {
433        let n = graph.num_nodes;
434        let normalizer = if n > 1 { (n - 1) as f64 } else { 1.0 };
435
436        let scores: Vec<NodeScore> = (0..n)
437            .map(|i| NodeScore {
438                node_id: i as u64,
439                score: graph.out_degree(i as u64) as f64 / normalizer,
440            })
441            .collect();
442
443        CentralityResult {
444            scores,
445            iterations: None,
446            converged: true,
447        }
448    }
449}
450
451impl Default for DegreeCentrality {
452    fn default() -> Self {
453        Self::new()
454    }
455}
456
457impl GpuKernel for DegreeCentrality {
458    fn metadata(&self) -> &KernelMetadata {
459        &self.metadata
460    }
461}
462
463// ============================================================================
464// Betweenness Centrality Kernel (Brandes Algorithm)
465// ============================================================================
466
/// Betweenness centrality kernel.
///
/// Uses Brandes algorithm for efficient computation in O(VE) time on
/// unweighted graphs (one BFS per source vertex plus a backward pass).
#[derive(Debug, Clone)]
pub struct BetweennessCentrality {
    /// Static kernel descriptor returned by `GpuKernel::metadata`.
    metadata: KernelMetadata,
}
474
475impl BetweennessCentrality {
476    /// Create a new betweenness centrality kernel.
477    #[must_use]
478    pub fn new() -> Self {
479        Self {
480            metadata: KernelMetadata::batch("graph/betweenness-centrality", Domain::GraphAnalytics)
481                .with_description("Betweenness centrality (Brandes algorithm)")
482                .with_throughput(10_000)
483                .with_latency_us(100.0),
484        }
485    }
486
487    /// Compute betweenness centrality using Brandes algorithm.
488    ///
489    /// The algorithm runs BFS from each vertex and accumulates
490    /// dependency scores in a single backward pass.
491    pub fn compute(graph: &CsrGraph, normalized: bool) -> CentralityResult {
492        let n = graph.num_nodes;
493        let mut centrality = vec![0.0f64; n];
494
495        // Run Brandes algorithm from each source
496        for s in 0..n {
497            // BFS structures
498            let mut stack: Vec<usize> = Vec::with_capacity(n);
499            let mut predecessors: Vec<Vec<usize>> = vec![Vec::new(); n];
500            let mut sigma = vec![0.0f64; n]; // Number of shortest paths
501            let mut dist = vec![-1i64; n]; // Distance from source
502
503            sigma[s] = 1.0;
504            dist[s] = 0;
505
506            let mut queue = VecDeque::new();
507            queue.push_back(s);
508
509            // Forward BFS
510            while let Some(v) = queue.pop_front() {
511                stack.push(v);
512
513                for &w in graph.neighbors(v as u64) {
514                    let w = w as usize;
515
516                    // First time visiting w?
517                    if dist[w] < 0 {
518                        dist[w] = dist[v] + 1;
519                        queue.push_back(w);
520                    }
521
522                    // Is this a shortest path to w via v?
523                    if dist[w] == dist[v] + 1 {
524                        sigma[w] += sigma[v];
525                        predecessors[w].push(v);
526                    }
527                }
528            }
529
530            // Backward pass - accumulate dependencies
531            let mut delta = vec![0.0f64; n];
532
533            while let Some(w) = stack.pop() {
534                for &v in &predecessors[w] {
535                    let contribution = (sigma[v] / sigma[w]) * (1.0 + delta[w]);
536                    delta[v] += contribution;
537                }
538
539                if w != s {
540                    centrality[w] += delta[w];
541                }
542            }
543        }
544
545        // Normalize if requested
546        if normalized && n > 2 {
547            let scale = 1.0 / ((n - 1) * (n - 2)) as f64;
548            for c in &mut centrality {
549                *c *= scale;
550            }
551        }
552
553        CentralityResult {
554            scores: centrality
555                .into_iter()
556                .enumerate()
557                .map(|(i, score)| NodeScore {
558                    node_id: i as u64,
559                    score,
560                })
561                .collect(),
562            iterations: None,
563            converged: true,
564        }
565    }
566}
567
568impl Default for BetweennessCentrality {
569    fn default() -> Self {
570        Self::new()
571    }
572}
573
574impl GpuKernel for BetweennessCentrality {
575    fn metadata(&self) -> &KernelMetadata {
576        &self.metadata
577    }
578}
579
580// ============================================================================
581// Closeness Centrality Kernel
582// ============================================================================
583
/// Closeness centrality kernel.
///
/// BFS-based (unweighted, hop-count) closeness centrality.
///
/// In classic mode a node scores `r / sum(d)`, where `r` is the number of
/// nodes reachable from it and `d` ranges over their distances. For a
/// connected graph this equals `(n-1) / sum(shortest_path_distances)`.
///
/// In harmonic mode a node scores `sum(1/d) / (n-1)`.
#[derive(Debug, Clone)]
pub struct ClosenessCentrality {
    /// Static kernel descriptor returned by `GpuKernel::metadata`.
    metadata: KernelMetadata,
}
592
593impl ClosenessCentrality {
594    /// Create a new closeness centrality kernel.
595    #[must_use]
596    pub fn new() -> Self {
597        Self {
598            metadata: KernelMetadata::batch("graph/closeness-centrality", Domain::GraphAnalytics)
599                .with_description("Closeness centrality (BFS-based)")
600                .with_throughput(10_000)
601                .with_latency_us(100.0),
602        }
603    }
604
605    /// Compute closeness centrality using BFS from each node.
606    ///
607    /// For disconnected graphs, uses harmonic mean variant.
608    pub fn compute(graph: &CsrGraph, harmonic: bool) -> CentralityResult {
609        let n = graph.num_nodes;
610        let mut centrality = vec![0.0f64; n];
611
612        for source in 0..n {
613            let distances = Self::bfs_distances(graph, source);
614
615            if harmonic {
616                // Harmonic centrality: sum(1/d) for all reachable nodes
617                let sum: f64 = distances
618                    .iter()
619                    .enumerate()
620                    .filter(|(i, d)| *i != source && **d > 0)
621                    .map(|(_, d)| 1.0 / *d as f64)
622                    .sum();
623                centrality[source] = sum / (n - 1) as f64;
624            } else {
625                // Classic closeness: (n-1) / sum(distances)
626                let sum: i64 = distances.iter().sum();
627                let reachable: usize = distances.iter().filter(|&&d| d > 0).count();
628
629                if sum > 0 && reachable > 0 {
630                    centrality[source] = reachable as f64 / sum as f64;
631                }
632            }
633        }
634
635        CentralityResult {
636            scores: centrality
637                .into_iter()
638                .enumerate()
639                .map(|(i, score)| NodeScore {
640                    node_id: i as u64,
641                    score,
642                })
643                .collect(),
644            iterations: None,
645            converged: true,
646        }
647    }
648
649    /// BFS to compute distances from source to all other nodes.
650    fn bfs_distances(graph: &CsrGraph, source: usize) -> Vec<i64> {
651        let n = graph.num_nodes;
652        let mut distances = vec![0i64; n];
653        let mut visited = vec![false; n];
654
655        let mut queue = VecDeque::new();
656        queue.push_back(source);
657        visited[source] = true;
658
659        while let Some(v) = queue.pop_front() {
660            for &w in graph.neighbors(v as u64) {
661                let w = w as usize;
662                if !visited[w] {
663                    visited[w] = true;
664                    distances[w] = distances[v] + 1;
665                    queue.push_back(w);
666                }
667            }
668        }
669
670        distances
671    }
672}
673
674impl Default for ClosenessCentrality {
675    fn default() -> Self {
676        Self::new()
677    }
678}
679
680impl GpuKernel for ClosenessCentrality {
681    fn metadata(&self) -> &KernelMetadata {
682        &self.metadata
683    }
684}
685
686// ============================================================================
687// Eigenvector Centrality Kernel
688// ============================================================================
689
/// Eigenvector centrality kernel.
///
/// Power iteration method for eigenvector centrality.
/// A node's score is proportional to the sum of its neighbors' scores; results
/// are normalized to unit L2 norm.
#[derive(Debug, Clone)]
pub struct EigenvectorCentrality {
    /// Static kernel descriptor returned by `GpuKernel::metadata`.
    metadata: KernelMetadata,
}
698
699impl EigenvectorCentrality {
700    /// Create a new eigenvector centrality kernel.
701    #[must_use]
702    pub fn new() -> Self {
703        Self {
704            metadata: KernelMetadata::batch("graph/eigenvector-centrality", Domain::GraphAnalytics)
705                .with_description("Eigenvector centrality (power iteration)")
706                .with_throughput(50_000)
707                .with_latency_us(10.0),
708        }
709    }
710
711    /// Compute eigenvector centrality using power iteration.
712    pub fn compute(graph: &CsrGraph, max_iterations: u32, tolerance: f64) -> CentralityResult {
713        let n = graph.num_nodes;
714        if n == 0 {
715            return CentralityResult {
716                scores: Vec::new(),
717                iterations: Some(0),
718                converged: true,
719            };
720        }
721
722        // Initialize with uniform scores
723        let mut scores = vec![1.0 / (n as f64).sqrt(); n];
724        let mut new_scores = vec![0.0f64; n];
725        let mut converged = false;
726        let mut iterations = 0u32;
727
728        for iter in 0..max_iterations {
729            iterations = iter + 1;
730
731            // Compute new scores: x_i = sum(A_ij * x_j)
732            for i in 0..n {
733                let mut sum = 0.0f64;
734                for &j in graph.neighbors(i as u64) {
735                    sum += scores[j as usize];
736                }
737                new_scores[i] = sum;
738            }
739
740            // Normalize
741            let norm: f64 = new_scores.iter().map(|x| x * x).sum::<f64>().sqrt();
742            if norm > 0.0 {
743                for x in &mut new_scores {
744                    *x /= norm;
745                }
746            }
747
748            // Check convergence
749            let diff: f64 = scores
750                .iter()
751                .zip(new_scores.iter())
752                .map(|(a, b)| (a - b).abs())
753                .fold(0.0f64, |acc, x| acc.max(x));
754
755            std::mem::swap(&mut scores, &mut new_scores);
756
757            if diff < tolerance {
758                converged = true;
759                break;
760            }
761        }
762
763        CentralityResult {
764            scores: scores
765                .into_iter()
766                .enumerate()
767                .map(|(i, score)| NodeScore {
768                    node_id: i as u64,
769                    score,
770                })
771                .collect(),
772            iterations: Some(iterations),
773            converged,
774        }
775    }
776}
777
778impl Default for EigenvectorCentrality {
779    fn default() -> Self {
780        Self::new()
781    }
782}
783
784impl GpuKernel for EigenvectorCentrality {
785    fn metadata(&self) -> &KernelMetadata {
786        &self.metadata
787    }
788}
789
790// ============================================================================
791// Katz Centrality Kernel
792// ============================================================================
793
/// Katz centrality kernel.
///
/// Measures influence through attenuated paths.
/// Katz(i) = sum over all paths from j to i of alpha^(path_length)
///
/// Returned scores are divided by the maximum score, so the top node scores 1.0.
// NOTE(review): `compute` sums over `graph.neighbors(i)`; if those are
// out-neighbours, it counts paths leaving `i` rather than arriving at it.
// Confirm the intended direction against the formula above.
#[derive(Debug, Clone)]
pub struct KatzCentrality {
    /// Static kernel descriptor returned by `GpuKernel::metadata`.
    metadata: KernelMetadata,
}
802
803impl KatzCentrality {
804    /// Create a new Katz centrality kernel.
805    #[must_use]
806    pub fn new() -> Self {
807        Self {
808            metadata: KernelMetadata::batch("graph/katz-centrality", Domain::GraphAnalytics)
809                .with_description("Katz centrality (attenuated paths)")
810                .with_throughput(50_000)
811                .with_latency_us(10.0),
812        }
813    }
814
815    /// Compute Katz centrality.
816    ///
817    /// # Arguments
818    /// * `graph` - The input graph
819    /// * `alpha` - Attenuation factor (should be < 1/lambda_max)
820    /// * `beta` - Base score for each node (default 1.0)
821    /// * `max_iterations` - Maximum iterations for power iteration
822    /// * `tolerance` - Convergence threshold
823    pub fn compute(
824        graph: &CsrGraph,
825        alpha: f64,
826        beta: f64,
827        max_iterations: u32,
828        tolerance: f64,
829    ) -> CentralityResult {
830        let n = graph.num_nodes;
831        if n == 0 {
832            return CentralityResult {
833                scores: Vec::new(),
834                iterations: Some(0),
835                converged: true,
836            };
837        }
838
839        // Initialize scores
840        let mut scores = vec![0.0f64; n];
841        let mut new_scores = vec![0.0f64; n];
842        let mut converged = false;
843        let mut iterations = 0u32;
844
845        // Power iteration: x = alpha * A * x + beta
846        for iter in 0..max_iterations {
847            iterations = iter + 1;
848
849            for i in 0..n {
850                let mut sum = 0.0f64;
851                for &j in graph.neighbors(i as u64) {
852                    sum += scores[j as usize];
853                }
854                new_scores[i] = alpha * sum + beta;
855            }
856
857            // Check convergence
858            let diff: f64 = scores
859                .iter()
860                .zip(new_scores.iter())
861                .map(|(a, b)| (a - b).abs())
862                .fold(0.0f64, |acc, x| acc.max(x));
863
864            std::mem::swap(&mut scores, &mut new_scores);
865
866            if diff < tolerance {
867                converged = true;
868                break;
869            }
870        }
871
872        // Normalize by maximum score
873        let max_score = scores.iter().cloned().fold(0.0f64, f64::max);
874        if max_score > 0.0 {
875            for s in &mut scores {
876                *s /= max_score;
877            }
878        }
879
880        CentralityResult {
881            scores: scores
882                .into_iter()
883                .enumerate()
884                .map(|(i, score)| NodeScore {
885                    node_id: i as u64,
886                    score,
887                })
888                .collect(),
889            iterations: Some(iterations),
890            converged,
891        }
892    }
893}
894
895impl Default for KatzCentrality {
896    fn default() -> Self {
897        Self::new()
898    }
899}
900
901impl GpuKernel for KatzCentrality {
902    fn metadata(&self) -> &KernelMetadata {
903        &self.metadata
904    }
905}
906
907// ============================================================================
908// BatchKernel Implementations
909// ============================================================================
910
911/// Batch execution wrapper for all centrality kernels.
912///
913/// Since centrality algorithms are computationally intensive,
914/// they benefit from batch execution with CPU orchestration.
915
916#[async_trait]
917impl BatchKernel<CentralityInput, CentralityOutput> for BetweennessCentrality {
918    async fn execute(&self, input: CentralityInput) -> Result<CentralityOutput> {
919        let start = Instant::now();
920        let normalized = input.normalize;
921        let result = Self::compute(&input.graph, normalized);
922        let compute_time_us = start.elapsed().as_micros() as u64;
923
924        Ok(CentralityOutput {
925            result,
926            compute_time_us,
927        })
928    }
929}
930
931#[async_trait]
932impl BatchKernel<CentralityInput, CentralityOutput> for ClosenessCentrality {
933    async fn execute(&self, input: CentralityInput) -> Result<CentralityOutput> {
934        let start = Instant::now();
935        let harmonic = match input.params {
936            CentralityParams::Closeness { harmonic } => harmonic,
937            _ => false,
938        };
939        let result = Self::compute(&input.graph, harmonic);
940        let compute_time_us = start.elapsed().as_micros() as u64;
941
942        Ok(CentralityOutput {
943            result,
944            compute_time_us,
945        })
946    }
947}
948
949#[async_trait]
950impl BatchKernel<CentralityInput, CentralityOutput> for EigenvectorCentrality {
951    async fn execute(&self, input: CentralityInput) -> Result<CentralityOutput> {
952        let start = Instant::now();
953        let max_iterations = input.max_iterations.unwrap_or(1000);
954        let tolerance = input.tolerance.unwrap_or(1e-6);
955        let result = Self::compute(&input.graph, max_iterations, tolerance);
956        let compute_time_us = start.elapsed().as_micros() as u64;
957
958        Ok(CentralityOutput {
959            result,
960            compute_time_us,
961        })
962    }
963}
964
965#[async_trait]
966impl BatchKernel<CentralityInput, CentralityOutput> for KatzCentrality {
967    async fn execute(&self, input: CentralityInput) -> Result<CentralityOutput> {
968        let start = Instant::now();
969        let (alpha, beta) = match input.params {
970            CentralityParams::Katz { alpha, beta } => (alpha, beta),
971            _ => (0.1, 1.0),
972        };
973        let max_iterations = input.max_iterations.unwrap_or(100);
974        let tolerance = input.tolerance.unwrap_or(1e-6);
975        let result = Self::compute(&input.graph, alpha, beta, max_iterations, tolerance);
976        let compute_time_us = start.elapsed().as_micros() as u64;
977
978        Ok(CentralityOutput {
979            result,
980            compute_time_us,
981        })
982    }
983}
984
985/// PageRank can be used in both batch and ring modes.
986/// This is the batch mode implementation.
987impl PageRank {
988    /// Execute PageRank as a batch operation.
989    ///
990    /// Convenience method that runs the algorithm to convergence.
991    pub async fn compute_batch(
992        &self,
993        graph: CsrGraph,
994        damping: f32,
995        max_iterations: u32,
996        threshold: f64,
997    ) -> Result<CentralityResult> {
998        Self::run_to_convergence(graph, damping, max_iterations, threshold)
999    }
1000}
1001
1002#[async_trait]
1003impl BatchKernel<CentralityInput, CentralityOutput> for PageRank {
1004    async fn execute(&self, input: CentralityInput) -> Result<CentralityOutput> {
1005        let start = Instant::now();
1006        let damping = match input.params {
1007            CentralityParams::PageRank { damping } => damping,
1008            _ => 0.85,
1009        };
1010        let max_iterations = input.max_iterations.unwrap_or(100);
1011        let tolerance = input.tolerance.unwrap_or(1e-6);
1012        let result = Self::run_to_convergence(input.graph, damping, max_iterations, tolerance)?;
1013        let compute_time_us = start.elapsed().as_micros() as u64;
1014
1015        Ok(CentralityOutput {
1016            result,
1017            compute_time_us,
1018        })
1019    }
1020}
1021
1022/// Degree centrality batch implementation.
1023#[async_trait]
1024impl BatchKernel<CentralityInput, CentralityOutput> for DegreeCentrality {
1025    async fn execute(&self, input: CentralityInput) -> Result<CentralityOutput> {
1026        let start = Instant::now();
1027        let result = Self::compute(&input.graph);
1028        let compute_time_us = start.elapsed().as_micros() as u64;
1029
1030        Ok(CentralityOutput {
1031            result,
1032            compute_time_us,
1033        })
1034    }
1035}
1036
#[cfg(test)]
mod tests {
    use super::*;

    /// Directed 4-cycle: every node has exactly one in- and one out-edge.
    fn create_test_graph() -> CsrGraph {
        // Simple graph: 0 -> 1 -> 2 -> 3 -> 0 (cycle)
        CsrGraph::from_edges(4, &[(0, 1), (1, 2), (2, 3), (3, 0)])
    }

    /// Star on 5 nodes with edges in both directions between the hub and each leaf.
    fn create_star_graph() -> CsrGraph {
        // Star graph: center node 0 connected to all others
        CsrGraph::from_edges(
            5,
            &[
                (0, 1),
                (0, 2),
                (0, 3),
                (0, 4),
                (1, 0),
                (2, 0),
                (3, 0),
                (4, 0),
            ],
        )
    }

    #[test]
    fn test_pagerank_metadata() {
        let kernel = PageRank::new();
        assert_eq!(kernel.metadata().id, "graph/pagerank");
        assert_eq!(kernel.metadata().domain, Domain::GraphAnalytics);
    }

    #[test]
    fn test_pagerank_iteration() {
        let graph = create_test_graph();
        let mut state = PageRank::initialize_state(graph, 0.85);

        let diff = PageRank::iterate_step(&mut state);
        assert!(diff >= 0.0);
        assert_eq!(state.iteration, 1);
    }

    #[test]
    fn test_pagerank_convergence() {
        let graph = create_test_graph();
        let result = PageRank::run_to_convergence(graph, 0.85, 100, 1e-6).unwrap();

        assert!(result.converged);
        assert_eq!(result.scores.len(), 4);

        // In a cycle, all nodes should have equal PageRank
        let first_score = result.scores[0].score;
        for score in &result.scores {
            assert!((score.score - first_score).abs() < 0.01);
        }
    }

    #[test]
    fn test_degree_centrality() {
        let graph = create_star_graph();
        let result = DegreeCentrality::compute(&graph);

        assert_eq!(result.scores.len(), 5);

        // Center node (0) should have highest degree
        let center_score = result.scores[0].score;
        for score in &result.scores[1..] {
            assert!(center_score > score.score);
        }
    }

    #[test]
    fn test_betweenness_centrality() {
        // Line graph: 0 - 1 - 2 - 3
        let graph = CsrGraph::from_edges(4, &[(0, 1), (1, 0), (1, 2), (2, 1), (2, 3), (3, 2)]);

        let result = BetweennessCentrality::compute(&graph, false);

        assert_eq!(result.scores.len(), 4);

        // Middle nodes (1, 2) should have the highest betweenness. Every shortest
        // path between the two halves of the line runs through them, while the
        // endpoints are never intermediate nodes on a shortest path.
        let scores: Vec<_> = result.scores.iter().map(|s| s.score).collect();
        assert!(scores[1] > scores[0]);
        assert!(scores[2] > scores[3]);

        // The line is symmetric under the relabelling 0<->3, 1<->2, so the
        // scores must be symmetric too.
        assert!((scores[1] - scores[2]).abs() < 1e-6);
        assert!((scores[0] - scores[3]).abs() < 1e-6);
    }

    #[test]
    fn test_closeness_centrality() {
        let graph = create_star_graph();
        let result = ClosenessCentrality::compute(&graph, false);

        assert_eq!(result.scores.len(), 5);

        // Center node should have highest closeness
        let center_score = result.scores[0].score;
        for score in &result.scores[1..] {
            assert!(center_score >= score.score);
        }
    }

    #[test]
    fn test_eigenvector_centrality() {
        let graph = create_star_graph();
        let result = EigenvectorCentrality::compute(&graph, 1000, 1e-4);

        // May or may not converge depending on graph structure
        assert_eq!(result.scores.len(), 5);

        // Center node should have high eigenvector centrality
        // (may not be highest due to star graph properties)
        let center_score = result.scores[0].score;
        assert!(center_score > 0.0);
    }

    #[test]
    fn test_katz_centrality() {
        let graph = create_star_graph();
        let result = KatzCentrality::compute(&graph, 0.1, 1.0, 100, 1e-6);

        assert!(result.converged);
        assert_eq!(result.scores.len(), 5);
    }
}