// rustkernel_graph/paths.rs

//! Shortest path kernels.
//!
//! This module provides shortest path algorithms:
//! - Single-source shortest path (SSSP) via BFS (unweighted) or Dijkstra (weighted)
//! - All-pairs shortest path (APSP)
//! - K-shortest paths (Yen's algorithm)

8use crate::types::CsrGraph;
9use rustkernel_core::{domain::Domain, kernel::KernelMetadata, traits::GpuKernel};
10use std::cmp::Ordering;
11use std::collections::{BinaryHeap, HashMap, VecDeque};
12
// ============================================================================
// Shortest Path Results
// ============================================================================

/// Result of a single-source shortest path calculation for one node.
#[derive(Debug, Clone, PartialEq)]
pub struct ShortestPathResult {
    /// Node index.
    pub node_index: usize,
    /// Shortest distance from source (`f64::INFINITY` if unreachable).
    pub distance: f64,
    /// Predecessor node index on the shortest path, or `-1` when no predecessor is
    /// recorded (e.g. for the source itself or an unreachable node).
    pub predecessor: i64,
    /// Whether node is reachable from source.
    pub is_reachable: bool,
    /// Number of edges on the recorded shortest path (`0` for unreachable nodes).
    pub hop_count: u32,
}

/// A single path between two nodes.
#[derive(Debug, Clone, PartialEq)]
pub struct PathResult {
    /// Source node index.
    pub source: usize,
    /// Target node index.
    pub target: usize,
    /// Total path length (sum of edge weights). The path builders in this module treat
    /// every edge as weight `1.0`, so this currently equals `hop_count`.
    pub path_length: f64,
    /// Number of hops (edges) in path.
    pub hop_count: usize,
    /// Ordered list of node indices along the path, from `source` to `target`.
    pub node_path: Vec<usize>,
}

/// All-pairs shortest path result.
///
/// Both matrices hold `node_count * node_count` entries in row-major order. The row
/// selects the source node and the column selects the target node.
#[derive(Debug, Clone, PartialEq)]
pub struct AllPairsResult {
    /// Number of nodes.
    pub node_count: usize,
    /// Distance matrix in row-major order.
    /// `distances[i * node_count + j]` = shortest distance from node `i` to node `j`
    /// (`f64::INFINITY` if unreachable).
    pub distances: Vec<f64>,
    /// Predecessor matrix for path reconstruction.
    /// `predecessors[i * node_count + j]` = node preceding `j` on the shortest path from
    /// `i`, or `-1` if none is recorded.
    pub predecessors: Vec<i64>,
}

impl AllPairsResult {
    /// Get distance from `source` to `target`.
    ///
    /// # Panics
    /// Panics if `source` or `target` is not less than `node_count`. Without this check an
    /// out-of-range `target` would silently read an entry from the next row. Also panics
    /// if `distances` holds fewer than `node_count * node_count` entries.
    pub fn distance(&self, source: usize, target: usize) -> f64 {
        assert!(
            source < self.node_count && target < self.node_count,
            "node index out of range: ({}, {}) with node_count {}",
            source,
            target,
            self.node_count
        );
        self.distances[source * self.node_count + target]
    }

    /// Reconstruct the node sequence of the shortest path from `source` to `target`.
    ///
    /// Returns `Some(vec![source])` when `source == target`. Returns `None` in these cases:
    /// - `target` is unreachable;
    /// - either index is out of range;
    /// - the predecessor chain is broken (negative or out-of-range entry, missing data);
    /// - the predecessor chain is cyclic.
    pub fn reconstruct_path(&self, source: usize, target: usize) -> Option<Vec<usize>> {
        let n = self.node_count;
        if source >= n || target >= n {
            return None;
        }

        let row = source * n;
        if !self.distances.get(row + target)?.is_finite() {
            return None;
        }

        let mut path = vec![target];
        let mut current = target;

        while current != source {
            // A simple path has at most `n` nodes; needing more means the predecessor
            // links form a cycle and would otherwise loop forever.
            if path.len() >= n {
                return None;
            }
            let pred = *self.predecessors.get(row + current)?;
            if pred < 0 || pred as usize >= n {
                return None;
            }
            current = pred as usize;
            path.push(current);
        }

        path.reverse();
        Some(path)
    }
}

89// ============================================================================
90// Shortest Path Kernel
91// ============================================================================
92
93/// Shortest path kernel using BFS (unweighted) or Delta-Stepping (weighted).
94#[derive(Debug, Clone)]
95pub struct ShortestPath {
96    metadata: KernelMetadata,
97}
98
99impl Default for ShortestPath {
100    fn default() -> Self {
101        Self::new()
102    }
103}
104
105impl ShortestPath {
106    /// Create a new shortest path kernel.
107    #[must_use]
108    pub fn new() -> Self {
109        Self {
110            metadata: KernelMetadata::batch("graph/shortest-path", Domain::GraphAnalytics)
111                .with_description("Shortest path via BFS/Delta-Stepping")
112                .with_throughput(50_000)
113                .with_latency_us(80.0),
114        }
115    }
116
117    /// Compute single-source shortest paths using BFS (for unweighted graphs).
118    ///
119    /// # Arguments
120    /// * `graph` - Input graph (CSR format)
121    /// * `source` - Source node index
122    pub fn compute_sssp_bfs(graph: &CsrGraph, source: usize) -> Vec<ShortestPathResult> {
123        let n = graph.num_nodes;
124        let mut distances = vec![f64::INFINITY; n];
125        let mut predecessors = vec![-1i64; n];
126        let mut hop_counts = vec![0u32; n];
127
128        distances[source] = 0.0;
129
130        let mut queue = VecDeque::new();
131        queue.push_back(source);
132
133        while let Some(v) = queue.pop_front() {
134            let current_dist = distances[v];
135
136            for &w in graph.neighbors(v as u64) {
137                let w = w as usize;
138                if distances[w].is_infinite() {
139                    distances[w] = current_dist + 1.0;
140                    predecessors[w] = v as i64;
141                    hop_counts[w] = hop_counts[v] + 1;
142                    queue.push_back(w);
143                }
144            }
145        }
146
147        (0..n)
148            .map(|i| ShortestPathResult {
149                node_index: i,
150                distance: distances[i],
151                predecessor: predecessors[i],
152                is_reachable: distances[i].is_finite(),
153                hop_count: hop_counts[i],
154            })
155            .collect()
156    }
157
158    /// Compute single-source shortest paths using Dijkstra (for weighted graphs).
159    ///
160    /// # Arguments
161    /// * `graph` - Input graph (CSR format)
162    /// * `source` - Source node index
163    /// * `weights` - Edge weights (parallel to graph edges)
164    pub fn compute_sssp_dijkstra(
165        graph: &CsrGraph,
166        source: usize,
167        weights: &[f64],
168    ) -> Vec<ShortestPathResult> {
169        let n = graph.num_nodes;
170        let mut distances = vec![f64::INFINITY; n];
171        let mut predecessors = vec![-1i64; n];
172        let mut hop_counts = vec![0u32; n];
173
174        distances[source] = 0.0;
175
176        // Priority queue: (negative distance, node) - negated for min-heap behavior
177        let mut heap = BinaryHeap::new();
178        heap.push(HeapNode {
179            dist: 0.0,
180            node: source,
181        });
182
183        while let Some(HeapNode { dist, node: v }) = heap.pop() {
184            if dist > distances[v] {
185                continue; // Already processed with shorter distance
186            }
187
188            let neighbors = graph.neighbors(v as u64);
189            let edge_start = if v == 0 {
190                0
191            } else {
192                graph.row_offsets[v] as usize
193            };
194
195            for (i, &w) in neighbors.iter().enumerate() {
196                let w = w as usize;
197                let weight = weights.get(edge_start + i).copied().unwrap_or(1.0);
198                let new_dist = distances[v] + weight;
199
200                if new_dist < distances[w] {
201                    distances[w] = new_dist;
202                    predecessors[w] = v as i64;
203                    hop_counts[w] = hop_counts[v] + 1;
204                    heap.push(HeapNode {
205                        dist: new_dist,
206                        node: w,
207                    });
208                }
209            }
210        }
211
212        (0..n)
213            .map(|i| ShortestPathResult {
214                node_index: i,
215                distance: distances[i],
216                predecessor: predecessors[i],
217                is_reachable: distances[i].is_finite(),
218                hop_count: hop_counts[i],
219            })
220            .collect()
221    }
222
223    /// Compute all-pairs shortest paths.
224    pub fn compute_apsp(graph: &CsrGraph) -> AllPairsResult {
225        let n = graph.num_nodes;
226        let mut distances = vec![f64::INFINITY; n * n];
227        let mut predecessors = vec![-1i64; n * n];
228
229        // Run SSSP from each node
230        for source in 0..n {
231            let sssp = Self::compute_sssp_bfs(graph, source);
232
233            for result in sssp {
234                let idx = source * n + result.node_index;
235                distances[idx] = result.distance;
236                predecessors[idx] = result.predecessor;
237            }
238        }
239
240        AllPairsResult {
241            node_count: n,
242            distances,
243            predecessors,
244        }
245    }
246
247    /// Reconstruct path from source to target.
248    pub fn reconstruct_path(
249        sssp: &[ShortestPathResult],
250        source: usize,
251        target: usize,
252    ) -> Option<Vec<usize>> {
253        if !sssp[target].is_reachable {
254            return None;
255        }
256
257        let mut path = Vec::new();
258        let mut current = target;
259
260        while current != source {
261            path.push(current);
262            let pred = sssp[current].predecessor;
263            if pred < 0 {
264                return None;
265            }
266            current = pred as usize;
267        }
268
269        path.push(source);
270        path.reverse();
271        Some(path)
272    }
273
274    /// Compute shortest path between two nodes.
275    pub fn compute_path(graph: &CsrGraph, source: usize, target: usize) -> Option<PathResult> {
276        let sssp = Self::compute_sssp_bfs(graph, source);
277
278        if !sssp[target].is_reachable {
279            return None;
280        }
281
282        let node_path = Self::reconstruct_path(&sssp, source, target)?;
283
284        Some(PathResult {
285            source,
286            target,
287            path_length: sssp[target].distance,
288            hop_count: node_path.len() - 1,
289            node_path,
290        })
291    }
292
293    /// Find k shortest paths using Yen's algorithm.
294    pub fn compute_k_shortest(
295        graph: &CsrGraph,
296        source: usize,
297        target: usize,
298        k: usize,
299    ) -> Vec<PathResult> {
300        let mut result_paths = Vec::new();
301
302        // First, find the shortest path
303        if let Some(first_path) = Self::compute_path(graph, source, target) {
304            result_paths.push(first_path);
305        } else {
306            return result_paths;
307        }
308
309        // Candidate paths
310        let mut candidates: Vec<PathResult> = Vec::new();
311
312        for _i in 1..k {
313            let prev_path = &result_paths[result_paths.len() - 1];
314
315            // For each deviation point on the previous path
316            for j in 0..(prev_path.node_path.len() - 1) {
317                let spur_node = prev_path.node_path[j];
318                let root_path: Vec<usize> = prev_path.node_path[..=j].to_vec();
319
320                // Create modified graph (remove edges used by previous paths at this deviation)
321                // For simplicity, we'll use a less efficient but correct approach
322                let edges_to_avoid = Self::collect_edges_to_avoid(&result_paths, &root_path);
323
324                // Find path in modified graph
325                if let Some(spur_path) =
326                    Self::compute_path_avoiding(graph, spur_node, target, &edges_to_avoid)
327                {
328                    let mut total_path = root_path.clone();
329                    total_path.extend(spur_path.node_path.into_iter().skip(1));
330
331                    let path_length = (total_path.len() - 1) as f64;
332                    let candidate = PathResult {
333                        source,
334                        target,
335                        path_length,
336                        hop_count: total_path.len() - 1,
337                        node_path: total_path,
338                    };
339
340                    // Add if not already in candidates or results
341                    if !Self::path_exists(&candidates, &candidate.node_path)
342                        && !Self::path_exists_in_results(&result_paths, &candidate.node_path)
343                    {
344                        candidates.push(candidate);
345                    }
346                }
347            }
348
349            if candidates.is_empty() {
350                break;
351            }
352
353            // Sort candidates by path length and take the best one
354            candidates.sort_by(|a, b| {
355                a.path_length
356                    .partial_cmp(&b.path_length)
357                    .unwrap_or(Ordering::Equal)
358            });
359
360            result_paths.push(candidates.remove(0));
361        }
362
363        result_paths
364    }
365
366    /// Compute path avoiding certain edges.
367    fn compute_path_avoiding(
368        graph: &CsrGraph,
369        source: usize,
370        target: usize,
371        avoid_edges: &[(usize, usize)],
372    ) -> Option<PathResult> {
373        let n = graph.num_nodes;
374        let mut distances = vec![f64::INFINITY; n];
375        let mut predecessors = vec![-1i64; n];
376
377        distances[source] = 0.0;
378
379        let mut queue = VecDeque::new();
380        queue.push_back(source);
381
382        while let Some(v) = queue.pop_front() {
383            if v == target {
384                break;
385            }
386
387            let current_dist = distances[v];
388
389            for &w in graph.neighbors(v as u64) {
390                let w = w as usize;
391
392                // Skip avoided edges
393                if avoid_edges.contains(&(v, w)) {
394                    continue;
395                }
396
397                if distances[w].is_infinite() {
398                    distances[w] = current_dist + 1.0;
399                    predecessors[w] = v as i64;
400                    queue.push_back(w);
401                }
402            }
403        }
404
405        if distances[target].is_infinite() {
406            return None;
407        }
408
409        // Reconstruct path
410        let mut path = Vec::new();
411        let mut current = target;
412
413        while current != source {
414            path.push(current);
415            let pred = predecessors[current];
416            if pred < 0 {
417                return None;
418            }
419            current = pred as usize;
420        }
421
422        path.push(source);
423        path.reverse();
424
425        Some(PathResult {
426            source,
427            target,
428            path_length: distances[target],
429            hop_count: path.len() - 1,
430            node_path: path,
431        })
432    }
433
434    fn collect_edges_to_avoid(
435        result_paths: &[PathResult],
436        root_path: &[usize],
437    ) -> Vec<(usize, usize)> {
438        let mut edges = Vec::new();
439
440        for path in result_paths {
441            // Check if this path shares the root
442            if path.node_path.len() >= root_path.len()
443                && path.node_path[..root_path.len()] == *root_path
444            {
445                // Add the edge right after root_path
446                if path.node_path.len() > root_path.len() {
447                    let from = root_path[root_path.len() - 1];
448                    let to = path.node_path[root_path.len()];
449                    edges.push((from, to));
450                }
451            }
452        }
453
454        edges
455    }
456
457    fn path_exists(candidates: &[PathResult], path: &[usize]) -> bool {
458        candidates.iter().any(|c| c.node_path == path)
459    }
460
461    fn path_exists_in_results(results: &[PathResult], path: &[usize]) -> bool {
462        results.iter().any(|r| r.node_path == path)
463    }
464
465    /// Compute eccentricity for each node (max distance to any other node).
466    pub fn compute_eccentricity(graph: &CsrGraph) -> Vec<f64> {
467        let n = graph.num_nodes;
468        let mut eccentricities = vec![0.0; n];
469
470        for source in 0..n {
471            let sssp = Self::compute_sssp_bfs(graph, source);
472            let max_dist = sssp
473                .iter()
474                .filter(|r| r.is_reachable)
475                .map(|r| r.distance)
476                .fold(0.0, f64::max);
477            eccentricities[source] = max_dist;
478        }
479
480        eccentricities
481    }
482
483    /// Compute graph diameter (max eccentricity).
484    pub fn compute_diameter(graph: &CsrGraph) -> f64 {
485        Self::compute_eccentricity(graph)
486            .into_iter()
487            .fold(0.0, f64::max)
488    }
489
490    /// Compute graph radius (min eccentricity).
491    pub fn compute_radius(graph: &CsrGraph) -> f64 {
492        Self::compute_eccentricity(graph)
493            .into_iter()
494            .filter(|&e| e > 0.0)
495            .fold(f64::INFINITY, f64::min)
496    }
497}
498
499impl GpuKernel for ShortestPath {
500    fn metadata(&self) -> &KernelMetadata {
501        &self.metadata
502    }
503}
504
/// Entry in Dijkstra's priority queue: a tentative distance paired with a node index.
///
/// `BinaryHeap` is a max-heap, so the ordering is reversed on `dist` to make the smallest
/// tentative distance pop first. Ties on `dist` are broken by `node` (smaller index first).
/// `f64::total_cmp` keeps the order total even for NaN. Together these make the ordering
/// consistent with `PartialEq`/`Eq`, as the `Ord` contract requires.
#[derive(Clone, Debug)]
struct HeapNode {
    dist: f64,
    node: usize,
}

impl PartialEq for HeapNode {
    fn eq(&self, other: &Self) -> bool {
        // Defined via `cmp` so equality and ordering can never disagree.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for HeapNode {}

impl Ord for HeapNode {
    fn cmp(&self, other: &Self) -> Ordering {
        // Reversed on purpose (min-heap behaviour); see the type-level docs.
        other
            .dist
            .total_cmp(&self.dist)
            .then_with(|| other.node.cmp(&self.node))
    }
}

impl PartialOrd for HeapNode {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Path graph `0 - 1 - 2 - 3`; each undirected edge is stored in both directions.
    fn create_line_graph() -> CsrGraph {
        CsrGraph::from_edges(4, &[(0, 1), (1, 0), (1, 2), (2, 1), (2, 3), (3, 2)])
    }

    /// Complete graph K4: every ordered pair of distinct nodes is an edge.
    fn create_complete_graph() -> CsrGraph {
        CsrGraph::from_edges(
            4,
            &[
                (0, 1),
                (0, 2),
                (0, 3),
                (1, 0),
                (1, 2),
                (1, 3),
                (2, 0),
                (2, 1),
                (2, 3),
                (3, 0),
                (3, 1),
                (3, 2),
            ],
        )
    }

    /// Two components with no edges between them: `0 - 1` and `2 - 3`.
    fn create_disconnected_graph() -> CsrGraph {
        CsrGraph::from_edges(4, &[(0, 1), (1, 0), (2, 3), (3, 2)])
    }

    #[test]
    fn test_shortest_path_metadata() {
        let kernel = ShortestPath::new();
        let metadata = kernel.metadata();
        assert_eq!(metadata.id, "graph/shortest-path");
        assert_eq!(metadata.domain, Domain::GraphAnalytics);
    }

    #[test]
    fn test_sssp_bfs_line() {
        let sssp = ShortestPath::compute_sssp_bfs(&create_line_graph(), 0);

        for (node, expected) in [0.0, 1.0, 2.0, 3.0].iter().enumerate() {
            assert_eq!(sssp[node].distance, *expected, "distance to node {}", node);
        }
    }

    #[test]
    fn test_sssp_bfs_complete() {
        let sssp = ShortestPath::compute_sssp_bfs(&create_complete_graph(), 0);

        // The source sits at distance 0; every other node of K4 is one hop away.
        assert_eq!(sssp[0].distance, 0.0);
        for result in &sssp[1..4] {
            assert_eq!(result.distance, 1.0, "distance to node {}", result.node_index);
        }
    }

    #[test]
    fn test_sssp_disconnected() {
        let sssp = ShortestPath::compute_sssp_bfs(&create_disconnected_graph(), 0);

        let reachable: Vec<bool> = sssp[..4].iter().map(|r| r.is_reachable).collect();
        assert_eq!(reachable, vec![true, true, false, false]);
    }

    #[test]
    fn test_reconstruct_path() {
        let sssp = ShortestPath::compute_sssp_bfs(&create_line_graph(), 0);

        let path = ShortestPath::reconstruct_path(&sssp, 0, 3).expect("node 3 is reachable");
        assert_eq!(path, vec![0, 1, 2, 3]);
    }

    #[test]
    fn test_compute_path() {
        let path = ShortestPath::compute_path(&create_line_graph(), 0, 3)
            .expect("node 3 is reachable");

        assert_eq!(path.hop_count, 3);
        assert_eq!(path.node_path, vec![0, 1, 2, 3]);
    }

    #[test]
    fn test_apsp() {
        let apsp = ShortestPath::compute_apsp(&create_line_graph());

        for &(from, to, expected) in &[(0, 3, 3.0), (1, 2, 1.0), (0, 0, 0.0)] {
            assert_eq!(apsp.distance(from, to), expected);
        }
    }

    #[test]
    fn test_diameter() {
        assert_eq!(ShortestPath::compute_diameter(&create_line_graph()), 3.0);
    }

    #[test]
    fn test_k_shortest() {
        let paths = ShortestPath::compute_k_shortest(&create_complete_graph(), 0, 3, 3);

        assert!(!paths.is_empty());
        // The direct edge 0 -> 3 is the unique one-hop path.
        assert_eq!(paths[0].hop_count, 1);
    }
}