Skip to main content

scirs2_graph/
compressed.rs

1//! Compressed Sparse Row (CSR) graph representation for large-scale graphs
2//!
3//! This module provides a memory-efficient CSR graph format optimized for
4//! large graph analytics workloads. It offers O(1) neighbor access start,
5//! O(degree) neighbor iteration, and efficient parallel construction.
6//!
7//! # Key Features
8//!
9//! - **Memory efficient**: Contiguous arrays minimize cache misses and allocator overhead
10//! - **Fast neighbor access**: O(1) to locate neighbor list start, O(degree) iteration
11//! - **Weighted/unweighted**: Optional edge weights with zero-cost abstraction
12//! - **Directed/undirected**: Support for both graph types
13//! - **Parallel construction**: Feature-gated parallel edge list sorting and prefix sum
14//! - **Conversions**: Convert to/from adjacency list and `Graph<usize, f64>` types
15
16use crate::error::{GraphError, Result};
17
18#[cfg(feature = "parallel")]
19use scirs2_core::parallel_ops::*;
20
/// Graph stored in compressed sparse row (CSR) layout.
///
/// The out-edges of node `i` occupy the half-open range
/// `row_ptr[i]..row_ptr[i + 1]` of two parallel arrays:
/// - `col_indices` holds the target node of each edge;
/// - `values` holds the weight of each edge.
///
/// The layout is the same as the CSR format used for sparse matrices in
/// numerical linear algebra, so the graph can serve directly as the operand
/// of a sparse matrix-vector product (as needed by e.g. PageRank).
#[derive(Debug, Clone)]
pub struct CsrGraph {
    /// Count of nodes; valid node ids are `0..num_nodes`.
    num_nodes: usize,
    /// Count of stored directed edges (the edge-list constructors store each
    /// undirected edge once per direction).
    num_edges: usize,
    /// Row offsets into `col_indices` / `values`; holds `num_nodes + 1` entries.
    row_ptr: Vec<usize>,
    /// Target node of every stored edge; holds `num_edges` entries.
    col_indices: Vec<usize>,
    /// Weight of every stored edge; holds `num_edges` entries.
    values: Vec<f64>,
    /// `true` for a directed graph, `false` for an undirected one.
    directed: bool,
}
46
/// Incremental builder that collects an edge list and turns it into a [`CsrGraph`].
///
/// Edges are buffered as `(src, dst, weight)` triples; the CSR arrays are
/// produced in a single construction step when the builder is finished.
#[derive(Debug, Clone)]
pub struct CsrGraphBuilder {
    /// Node count; every edge endpoint must be `< num_nodes`.
    num_nodes: usize,
    /// Buffered `(src, dst, weight)` triples in insertion order.
    edges: Vec<(usize, usize, f64)>,
    /// Whether the resulting graph is directed.
    directed: bool,
}
57
/// Adjacency-list form of a graph, used to exchange data with [`CsrGraph`].
#[derive(Debug, Clone)]
pub struct AdjacencyList {
    /// Node count of the graph.
    pub num_nodes: usize,
    /// `adjacency[u]` lists the `(neighbor, weight)` pairs of node `u`.
    pub adjacency: Vec<Vec<(usize, f64)>>,
    /// `true` if the edges are directed.
    pub directed: bool,
}
68
69// ────────────────────────────────────────────────────────────────────────────
70// CsrGraphBuilder
71// ────────────────────────────────────────────────────────────────────────────
72
73impl CsrGraphBuilder {
74    /// Create a new CSR graph builder.
75    ///
76    /// # Arguments
77    /// * `num_nodes` - Number of nodes in the graph
78    /// * `directed` - Whether the graph is directed
79    pub fn new(num_nodes: usize, directed: bool) -> Self {
80        Self {
81            num_nodes,
82            edges: Vec::new(),
83            directed,
84        }
85    }
86
87    /// Create a builder with pre-allocated edge capacity.
88    pub fn with_capacity(num_nodes: usize, edge_capacity: usize, directed: bool) -> Self {
89        Self {
90            num_nodes,
91            edges: Vec::with_capacity(edge_capacity),
92            directed,
93        }
94    }
95
96    /// Add a single edge.
97    ///
98    /// For undirected graphs, the reverse edge is added automatically during build.
99    pub fn add_edge(&mut self, src: usize, dst: usize, weight: f64) -> Result<()> {
100        if src >= self.num_nodes {
101            return Err(GraphError::node_not_found_with_context(
102                src,
103                self.num_nodes,
104                "CsrGraphBuilder::add_edge (source)",
105            ));
106        }
107        if dst >= self.num_nodes {
108            return Err(GraphError::node_not_found_with_context(
109                dst,
110                self.num_nodes,
111                "CsrGraphBuilder::add_edge (destination)",
112            ));
113        }
114        self.edges.push((src, dst, weight));
115        Ok(())
116    }
117
118    /// Add an unweighted edge (weight = 1.0).
119    pub fn add_unweighted_edge(&mut self, src: usize, dst: usize) -> Result<()> {
120        self.add_edge(src, dst, 1.0)
121    }
122
123    /// Add edges from an iterator of `(src, dst, weight)` triples.
124    pub fn add_edges<I>(&mut self, edges: I) -> Result<()>
125    where
126        I: IntoIterator<Item = (usize, usize, f64)>,
127    {
128        for (src, dst, weight) in edges {
129            self.add_edge(src, dst, weight)?;
130        }
131        Ok(())
132    }
133
134    /// Build the CSR graph (sequential).
135    pub fn build(self) -> Result<CsrGraph> {
136        build_csr_sequential(self.num_nodes, self.edges, self.directed)
137    }
138
139    /// Build the CSR graph using parallel sorting and construction.
140    #[cfg(feature = "parallel")]
141    pub fn build_parallel(self) -> Result<CsrGraph> {
142        build_csr_parallel(self.num_nodes, self.edges, self.directed)
143    }
144}
145
146// ────────────────────────────────────────────────────────────────────────────
147// Sequential CSR construction
148// ────────────────────────────────────────────────────────────────────────────
149
150fn build_csr_sequential(
151    num_nodes: usize,
152    mut edges: Vec<(usize, usize, f64)>,
153    directed: bool,
154) -> Result<CsrGraph> {
155    // For undirected graphs, add reverse edges
156    if !directed {
157        let reverse: Vec<(usize, usize, f64)> = edges.iter().map(|&(s, d, w)| (d, s, w)).collect();
158        edges.extend(reverse);
159    }
160
161    let num_edges = edges.len();
162
163    // Validate all node indices
164    for &(src, dst, _) in &edges {
165        if src >= num_nodes {
166            return Err(GraphError::node_not_found_with_context(
167                src,
168                num_nodes,
169                "CSR construction (source)",
170            ));
171        }
172        if dst >= num_nodes {
173            return Err(GraphError::node_not_found_with_context(
174                dst,
175                num_nodes,
176                "CSR construction (destination)",
177            ));
178        }
179    }
180
181    // Count degrees (first pass)
182    let mut degree = vec![0usize; num_nodes];
183    for &(src, _, _) in &edges {
184        degree[src] += 1;
185    }
186
187    // Build row_ptr via prefix sum
188    let mut row_ptr = Vec::with_capacity(num_nodes + 1);
189    row_ptr.push(0);
190    for &deg in &degree {
191        let last = row_ptr.last().copied().unwrap_or(0);
192        row_ptr.push(last + deg);
193    }
194
195    // Place edges into CSR arrays using counting sort
196    let mut col_indices = vec![0usize; num_edges];
197    let mut values = vec![0.0f64; num_edges];
198    let mut current_pos: Vec<usize> = row_ptr[..num_nodes].to_vec();
199
200    for (src, dst, weight) in &edges {
201        let pos = current_pos[*src];
202        col_indices[pos] = *dst;
203        values[pos] = *weight;
204        current_pos[*src] += 1;
205    }
206
207    // Sort neighbors within each row for cache-friendly access and binary search
208    for node in 0..num_nodes {
209        let start = row_ptr[node];
210        let end = row_ptr[node + 1];
211        if end > start + 1 {
212            // Build (col, value) pairs, sort by col, write back
213            let mut pairs: Vec<(usize, f64)> = col_indices[start..end]
214                .iter()
215                .zip(&values[start..end])
216                .map(|(&c, &v)| (c, v))
217                .collect();
218            pairs.sort_unstable_by_key(|&(c, _)| c);
219            for (i, (c, v)) in pairs.into_iter().enumerate() {
220                col_indices[start + i] = c;
221                values[start + i] = v;
222            }
223        }
224    }
225
226    Ok(CsrGraph {
227        num_nodes,
228        num_edges,
229        row_ptr,
230        col_indices,
231        values,
232        directed,
233    })
234}
235
236// ────────────────────────────────────────────────────────────────────────────
237// Parallel CSR construction
238// ────────────────────────────────────────────────────────────────────────────
239
/// Build a [`CsrGraph`] from an edge list using a parallel sort.
///
/// The edge list is sorted in parallel by `(source, target)`, which is
/// exactly CSR order. Each row's neighbors therefore come out sorted by
/// target id, and no separate per-row sort pass is needed. Parallel edges are
/// kept (not merged); the relative order of duplicate `(source, target)`
/// edges is unspecified.
///
/// For undirected graphs every input edge `(u, v, w)` is mirrored as
/// `(v, u, w)`, so each logical edge is stored twice (self-loops included).
///
/// # Errors
/// Returns a node-not-found error for the first edge whose source or
/// destination is `>= num_nodes`.
#[cfg(feature = "parallel")]
fn build_csr_parallel(
    num_nodes: usize,
    mut edges: Vec<(usize, usize, f64)>,
    directed: bool,
) -> Result<CsrGraph> {
    // Validate before mirroring. Mirrored edges are out of range only if
    // their (earlier) input edge is, so the first reported error is unchanged.
    for &(src, dst, _) in &edges {
        if src >= num_nodes {
            return Err(GraphError::node_not_found_with_context(
                src,
                num_nodes,
                "CSR parallel construction (source)",
            ));
        }
        if dst >= num_nodes {
            return Err(GraphError::node_not_found_with_context(
                dst,
                num_nodes,
                "CSR parallel construction (destination)",
            ));
        }
    }

    if !directed {
        // Mirror in place instead of building a temporary reverse list.
        let original_len = edges.len();
        edges.extend_from_within(..);
        for edge in &mut edges[original_len..] {
            let (src, dst, weight) = *edge;
            *edge = (dst, src, weight);
        }
    }

    let num_edges = edges.len();

    // One parallel sort by (source, target) produces the final CSR ordering:
    // rows are contiguous and the neighbors within each row are ascending.
    edges.par_sort_unstable_by_key(|&(src, dst, _)| (src, dst));

    // row_ptr[i + 1] accumulates the out-degree of node i, then the running
    // sum turns it into the end offset of row i.
    let mut row_ptr = vec![0usize; num_nodes + 1];
    for &(src, _, _) in &edges {
        row_ptr[src + 1] += 1;
    }
    let mut running = 0usize;
    for slot in row_ptr.iter_mut().skip(1) {
        running += *slot;
        *slot = running;
    }

    // Edges are already in CSR order, so the arrays are a straight projection.
    let (col_indices, values): (Vec<usize>, Vec<f64>) = edges
        .into_iter()
        .map(|(_, dst, weight)| (dst, weight))
        .unzip();

    Ok(CsrGraph {
        num_nodes,
        num_edges,
        row_ptr,
        col_indices,
        values,
        directed,
    })
}
328
329// ────────────────────────────────────────────────────────────────────────────
330// CsrGraph core API
331// ────────────────────────────────────────────────────────────────────────────
332
333impl CsrGraph {
334    /// Construct a CSR graph directly from raw arrays.
335    ///
336    /// # Safety contract (logical, not `unsafe`)
337    /// The caller must ensure:
338    /// - `row_ptr.len() == num_nodes + 1`
339    /// - All values in `col_indices` are `< num_nodes`
340    /// - `col_indices.len() == values.len() == row_ptr[num_nodes]`
341    pub fn from_raw(
342        num_nodes: usize,
343        row_ptr: Vec<usize>,
344        col_indices: Vec<usize>,
345        values: Vec<f64>,
346        directed: bool,
347    ) -> Result<Self> {
348        if row_ptr.len() != num_nodes + 1 {
349            return Err(GraphError::InvalidGraph(format!(
350                "row_ptr length {} does not match num_nodes + 1 = {}",
351                row_ptr.len(),
352                num_nodes + 1
353            )));
354        }
355        if col_indices.len() != values.len() {
356            return Err(GraphError::InvalidGraph(format!(
357                "col_indices length {} != values length {}",
358                col_indices.len(),
359                values.len()
360            )));
361        }
362        let num_edges = col_indices.len();
363        let last_ptr = row_ptr.last().copied().unwrap_or(0);
364        if last_ptr != num_edges {
365            return Err(GraphError::InvalidGraph(format!(
366                "row_ptr last element {} != num_edges {}",
367                last_ptr, num_edges
368            )));
369        }
370        // Validate column indices
371        for (i, &col) in col_indices.iter().enumerate() {
372            if col >= num_nodes {
373                return Err(GraphError::node_not_found_with_context(
374                    col,
375                    num_nodes,
376                    &format!("from_raw validation at edge index {i}"),
377                ));
378            }
379        }
380        Ok(Self {
381            num_nodes,
382            num_edges,
383            row_ptr,
384            col_indices,
385            values,
386            directed,
387        })
388    }
389
390    /// Construct from an edge list (convenience wrapper).
391    ///
392    /// For undirected graphs, reverse edges are added automatically.
393    pub fn from_edges(
394        num_nodes: usize,
395        edges: Vec<(usize, usize, f64)>,
396        directed: bool,
397    ) -> Result<Self> {
398        build_csr_sequential(num_nodes, edges, directed)
399    }
400
401    /// Construct from an edge list using parallel construction.
402    #[cfg(feature = "parallel")]
403    pub fn from_edges_parallel(
404        num_nodes: usize,
405        edges: Vec<(usize, usize, f64)>,
406        directed: bool,
407    ) -> Result<Self> {
408        build_csr_parallel(num_nodes, edges, directed)
409    }
410
411    /// Construct an unweighted CSR graph from an edge list.
412    pub fn from_unweighted_edges(
413        num_nodes: usize,
414        edges: &[(usize, usize)],
415        directed: bool,
416    ) -> Result<Self> {
417        let weighted: Vec<(usize, usize, f64)> = edges.iter().map(|&(s, d)| (s, d, 1.0)).collect();
418        build_csr_sequential(num_nodes, weighted, directed)
419    }
420
421    /// Number of nodes.
422    #[inline]
423    pub fn num_nodes(&self) -> usize {
424        self.num_nodes
425    }
426
427    /// Number of stored (directed) edges.
428    ///
429    /// For undirected graphs built from `n` input edges, this returns `2n`
430    /// because both directions are stored.
431    #[inline]
432    pub fn num_edges(&self) -> usize {
433        self.num_edges
434    }
435
436    /// Number of logical edges.
437    ///
438    /// For undirected graphs, returns `num_edges / 2` (the original count).
439    /// For directed graphs, returns `num_edges`.
440    #[inline]
441    pub fn num_logical_edges(&self) -> usize {
442        if self.directed {
443            self.num_edges
444        } else {
445            self.num_edges / 2
446        }
447    }
448
449    /// Whether the graph is directed.
450    #[inline]
451    pub fn is_directed(&self) -> bool {
452        self.directed
453    }
454
455    /// Out-degree of a node (number of outgoing edges).
456    ///
457    /// Returns 0 if the node index is out of range.
458    #[inline]
459    pub fn degree(&self, node: usize) -> usize {
460        if node >= self.num_nodes {
461            return 0;
462        }
463        self.row_ptr[node + 1] - self.row_ptr[node]
464    }
465
466    /// Iterator over neighbors of `node` as `(neighbor_id, weight)` pairs.
467    ///
468    /// Returns an empty iterator if `node` is out of range.
469    #[inline]
470    pub fn neighbors(&self, node: usize) -> NeighborIter<'_> {
471        if node >= self.num_nodes {
472            return NeighborIter {
473                col_iter: [].iter(),
474                val_iter: [].iter(),
475            };
476        }
477        let start = self.row_ptr[node];
478        let end = self.row_ptr[node + 1];
479        NeighborIter {
480            col_iter: self.col_indices[start..end].iter(),
481            val_iter: self.values[start..end].iter(),
482        }
483    }
484
485    /// Check if an edge exists from `src` to `dst`.
486    ///
487    /// Uses binary search on the sorted neighbor list. O(log(degree)).
488    pub fn has_edge(&self, src: usize, dst: usize) -> bool {
489        if src >= self.num_nodes || dst >= self.num_nodes {
490            return false;
491        }
492        let start = self.row_ptr[src];
493        let end = self.row_ptr[src + 1];
494        self.col_indices[start..end].binary_search(&dst).is_ok()
495    }
496
497    /// Get the weight of an edge from `src` to `dst`.
498    ///
499    /// Returns `None` if the edge does not exist.
500    pub fn edge_weight(&self, src: usize, dst: usize) -> Option<f64> {
501        if src >= self.num_nodes || dst >= self.num_nodes {
502            return None;
503        }
504        let start = self.row_ptr[src];
505        let end = self.row_ptr[src + 1];
506        match self.col_indices[start..end].binary_search(&dst) {
507            Ok(idx) => Some(self.values[start + idx]),
508            Err(_) => None,
509        }
510    }
511
512    /// Get the raw row pointer array (read-only).
513    #[inline]
514    pub fn row_ptr(&self) -> &[usize] {
515        &self.row_ptr
516    }
517
518    /// Get the raw column index array (read-only).
519    #[inline]
520    pub fn col_indices(&self) -> &[usize] {
521        &self.col_indices
522    }
523
524    /// Get the raw values array (read-only).
525    #[inline]
526    pub fn values(&self) -> &[f64] {
527        &self.values
528    }
529
530    /// Memory usage in bytes (approximate).
531    pub fn memory_bytes(&self) -> usize {
532        use std::mem::size_of;
533        // Struct overhead
534        size_of::<Self>()
535            // row_ptr
536            + self.row_ptr.capacity() * size_of::<usize>()
537            // col_indices
538            + self.col_indices.capacity() * size_of::<usize>()
539            // values
540            + self.values.capacity() * size_of::<f64>()
541    }
542
543    /// Sparse matrix-vector product: y = A * x.
544    ///
545    /// This is the core operation for iterative algorithms like PageRank.
546    /// `x` and the returned vector both have length `num_nodes`.
547    pub fn spmv(&self, x: &[f64]) -> Result<Vec<f64>> {
548        if x.len() != self.num_nodes {
549            return Err(GraphError::InvalidGraph(format!(
550                "spmv: vector length {} != num_nodes {}",
551                x.len(),
552                self.num_nodes
553            )));
554        }
555        let mut y = vec![0.0f64; self.num_nodes];
556        for row in 0..self.num_nodes {
557            let start = self.row_ptr[row];
558            let end = self.row_ptr[row + 1];
559            let mut sum = 0.0;
560            for idx in start..end {
561                sum += self.values[idx] * x[self.col_indices[idx]];
562            }
563            y[row] = sum;
564        }
565        Ok(y)
566    }
567
568    /// Parallel sparse matrix-vector product: y = A * x.
569    #[cfg(feature = "parallel")]
570    pub fn spmv_parallel(&self, x: &[f64]) -> Result<Vec<f64>> {
571        if x.len() != self.num_nodes {
572            return Err(GraphError::InvalidGraph(format!(
573                "spmv_parallel: vector length {} != num_nodes {}",
574                x.len(),
575                self.num_nodes
576            )));
577        }
578        let y: Vec<f64> = (0..self.num_nodes)
579            .into_par_iter()
580            .map(|row| {
581                let start = self.row_ptr[row];
582                let end = self.row_ptr[row + 1];
583                let mut sum = 0.0;
584                for idx in start..end {
585                    sum += self.values[idx] * x[self.col_indices[idx]];
586                }
587                sum
588            })
589            .collect();
590        Ok(y)
591    }
592
593    /// Transpose the graph (reverse all edge directions).
594    ///
595    /// For undirected graphs, the transpose is the same graph.
596    pub fn transpose(&self) -> Result<Self> {
597        if !self.directed {
598            return Ok(self.clone());
599        }
600        // Collect all edges in reversed form
601        let mut edges = Vec::with_capacity(self.num_edges);
602        for src in 0..self.num_nodes {
603            for (dst, weight) in self.neighbors(src) {
604                edges.push((dst, src, weight));
605            }
606        }
607        build_csr_sequential(self.num_nodes, edges, true)
608    }
609}
610
611// ────────────────────────────────────────────────────────────────────────────
612// NeighborIter
613// ────────────────────────────────────────────────────────────────────────────
614
/// Iterator over the `(neighbor_id, weight)` pairs of one node.
///
/// Produced by [`CsrGraph::neighbors`]. Items are yielded in storage order,
/// which for graphs built from edge lists is ascending neighbor id.
#[derive(Debug, Clone)]
pub struct NeighborIter<'a> {
    /// Neighbor ids of the node's row; same length as `val_iter`.
    col_iter: std::slice::Iter<'a, usize>,
    /// Edge weights of the node's row; same length as `col_iter`.
    val_iter: std::slice::Iter<'a, f64>,
}

impl Iterator for NeighborIter<'_> {
    type Item = (usize, f64);

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        match (self.col_iter.next(), self.val_iter.next()) {
            (Some(&col), Some(&val)) => Some((col, val)),
            _ => None,
        }
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        // Both slices have the same length, so either iterator's hint is exact.
        self.col_iter.size_hint()
    }
}

impl DoubleEndedIterator for NeighborIter<'_> {
    #[inline]
    fn next_back(&mut self) -> Option<Self::Item> {
        match (self.col_iter.next_back(), self.val_iter.next_back()) {
            (Some(&col), Some(&val)) => Some((col, val)),
            _ => None,
        }
    }
}

impl ExactSizeIterator for NeighborIter<'_> {}

// Slice iterators keep returning `None` once exhausted, and so does this one.
impl std::iter::FusedIterator for NeighborIter<'_> {}
639
640// ────────────────────────────────────────────────────────────────────────────
641// Conversions: CsrGraph <-> AdjacencyList
642// ────────────────────────────────────────────────────────────────────────────
643
644impl CsrGraph {
645    /// Convert to an adjacency list representation.
646    pub fn to_adjacency_list(&self) -> AdjacencyList {
647        let mut adjacency = Vec::with_capacity(self.num_nodes);
648        for node in 0..self.num_nodes {
649            let neighbors: Vec<(usize, f64)> = self.neighbors(node).collect();
650            adjacency.push(neighbors);
651        }
652        AdjacencyList {
653            num_nodes: self.num_nodes,
654            adjacency,
655            directed: self.directed,
656        }
657    }
658
659    /// Construct a CSR graph from an adjacency list.
660    ///
661    /// The adjacency list is consumed and edges are extracted directly.
662    /// For undirected graphs, the adjacency list should already contain
663    /// both directions (i.e., if `(u,v)` is present in `adj[u]`, then `(v,u)`
664    /// should be in `adj[v]`).
665    pub fn from_adjacency_list(adj: &AdjacencyList) -> Result<Self> {
666        let num_nodes = adj.num_nodes;
667        let mut edges = Vec::new();
668        for (src, neighbors) in adj.adjacency.iter().enumerate() {
669            for &(dst, weight) in neighbors {
670                edges.push((src, dst, weight));
671            }
672        }
673        // Since adjacency list already has both directions for undirected,
674        // we build as directed to avoid doubling
675        let num_edges = edges.len();
676
677        // Validate
678        for &(src, dst, _) in &edges {
679            if src >= num_nodes {
680                return Err(GraphError::node_not_found_with_context(
681                    src,
682                    num_nodes,
683                    "from_adjacency_list (source)",
684                ));
685            }
686            if dst >= num_nodes {
687                return Err(GraphError::node_not_found_with_context(
688                    dst,
689                    num_nodes,
690                    "from_adjacency_list (destination)",
691                ));
692            }
693        }
694
695        // Count degrees
696        let mut degree = vec![0usize; num_nodes];
697        for &(src, _, _) in &edges {
698            degree[src] += 1;
699        }
700
701        // Build row_ptr
702        let mut row_ptr = Vec::with_capacity(num_nodes + 1);
703        row_ptr.push(0);
704        for &deg in &degree {
705            let last = row_ptr.last().copied().unwrap_or(0);
706            row_ptr.push(last + deg);
707        }
708
709        // Fill arrays
710        let mut col_indices = vec![0usize; num_edges];
711        let mut values = vec![0.0f64; num_edges];
712        let mut current_pos: Vec<usize> = row_ptr[..num_nodes].to_vec();
713
714        for &(src, dst, weight) in &edges {
715            let pos = current_pos[src];
716            col_indices[pos] = dst;
717            values[pos] = weight;
718            current_pos[src] += 1;
719        }
720
721        // Sort neighbors within each row
722        for node in 0..num_nodes {
723            let start = row_ptr[node];
724            let end = row_ptr[node + 1];
725            if end > start + 1 {
726                let mut pairs: Vec<(usize, f64)> = col_indices[start..end]
727                    .iter()
728                    .zip(&values[start..end])
729                    .map(|(&c, &v)| (c, v))
730                    .collect();
731                pairs.sort_unstable_by_key(|&(c, _)| c);
732                for (i, (c, v)) in pairs.into_iter().enumerate() {
733                    col_indices[start + i] = c;
734                    values[start + i] = v;
735                }
736            }
737        }
738
739        Ok(Self {
740            num_nodes,
741            num_edges,
742            row_ptr,
743            col_indices,
744            values,
745            directed: adj.directed,
746        })
747    }
748}
749
750// ────────────────────────────────────────────────────────────────────────────
751// Conversions: CsrGraph <-> Graph<usize, f64>
752// ────────────────────────────────────────────────────────────────────────────
753
754impl CsrGraph {
755    /// Convert to a `Graph<usize, f64>` (undirected adjacency-list graph).
756    ///
757    /// For directed CSR graphs, the resulting `Graph` will be undirected
758    /// (edges are treated as undirected).
759    pub fn to_graph(&self) -> crate::Graph<usize, f64> {
760        let mut graph = crate::Graph::new();
761        for i in 0..self.num_nodes {
762            graph.add_node(i);
763        }
764        // For undirected CSR, each edge is stored twice. Only add once.
765        for src in 0..self.num_nodes {
766            for (dst, weight) in self.neighbors(src) {
767                if self.directed || src <= dst {
768                    // Ignore errors from duplicate edges
769                    let _ = graph.add_edge(src, dst, weight);
770                }
771            }
772        }
773        graph
774    }
775
776    /// Construct a CSR graph from a `Graph<usize, f64>`.
777    pub fn from_graph(graph: &crate::Graph<usize, f64>) -> Result<Self> {
778        let num_nodes = graph.node_count();
779        let edges: Vec<(usize, usize, f64)> = graph
780            .edges()
781            .into_iter()
782            .map(|e| (e.source, e.target, e.weight))
783            .collect();
784        // Graph<usize, f64> is undirected
785        build_csr_sequential(num_nodes, edges, false)
786    }
787
788    /// Construct a CSR graph from a `DiGraph<usize, f64>`.
789    pub fn from_digraph(graph: &crate::DiGraph<usize, f64>) -> Result<Self> {
790        let num_nodes = graph.node_count();
791        let edges: Vec<(usize, usize, f64)> = graph
792            .edges()
793            .into_iter()
794            .map(|e| (e.source, e.target, e.weight))
795            .collect();
796        build_csr_sequential(num_nodes, edges, true)
797    }
798}
799
800// ────────────────────────────────────────────────────────────────────────────
801// Display
802// ────────────────────────────────────────────────────────────────────────────
803
804impl std::fmt::Display for CsrGraph {
805    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
806        write!(
807            f,
808            "CsrGraph(nodes={}, edges={}, directed={}, mem={}KB)",
809            self.num_nodes,
810            self.num_logical_edges(),
811            self.directed,
812            self.memory_bytes() / 1024
813        )
814    }
815}
816
817// ────────────────────────────────────────────────────────────────────────────
818// Tests
819// ────────────────────────────────────────────────────────────────────────────
820
821#[cfg(test)]
822mod tests {
823    use super::*;
824
825    #[test]
826    fn test_csr_from_edges_directed() {
827        let edges = vec![(0, 1, 1.0), (0, 2, 2.0), (1, 2, 3.0), (2, 0, 0.5)];
828        let g = CsrGraph::from_edges(4, edges, true).expect("build failed");
829
830        assert_eq!(g.num_nodes(), 4);
831        assert_eq!(g.num_edges(), 4);
832        assert_eq!(g.num_logical_edges(), 4);
833        assert!(g.is_directed());
834
835        // Node 0 neighbors: 1, 2
836        assert_eq!(g.degree(0), 2);
837        let n0: Vec<(usize, f64)> = g.neighbors(0).collect();
838        assert_eq!(n0, vec![(1, 1.0), (2, 2.0)]);
839
840        // Node 2 neighbors: 0
841        assert_eq!(g.degree(2), 1);
842
843        // Node 3 has no outgoing edges
844        assert_eq!(g.degree(3), 0);
845
846        // Edge checks
847        assert!(g.has_edge(0, 1));
848        assert!(g.has_edge(2, 0));
849        assert!(!g.has_edge(1, 0)); // directed
850        assert!(!g.has_edge(3, 0));
851
852        // Weight lookup
853        assert_eq!(g.edge_weight(0, 2), Some(2.0));
854        assert_eq!(g.edge_weight(1, 0), None);
855    }
856
857    #[test]
858    fn test_csr_from_edges_undirected() {
859        let edges = vec![(0, 1, 1.0), (1, 2, 2.0), (2, 3, 3.0)];
860        let g = CsrGraph::from_edges(4, edges, false).expect("build failed");
861
862        assert_eq!(g.num_nodes(), 4);
863        assert_eq!(g.num_edges(), 6); // 3 edges * 2 directions
864        assert_eq!(g.num_logical_edges(), 3);
865        assert!(!g.is_directed());
866
867        // Both directions present
868        assert!(g.has_edge(0, 1));
869        assert!(g.has_edge(1, 0));
870        assert!(g.has_edge(2, 3));
871        assert!(g.has_edge(3, 2));
872        assert!(!g.has_edge(0, 3));
873    }
874
875    #[test]
876    fn test_csr_builder() {
877        let mut builder = CsrGraphBuilder::with_capacity(5, 4, true);
878        builder.add_edge(0, 1, 1.0).expect("add edge failed");
879        builder.add_edge(0, 2, 2.0).expect("add edge failed");
880        builder.add_unweighted_edge(3, 4).expect("add edge failed");
881        builder.add_unweighted_edge(4, 0).expect("add edge failed");
882
883        let g = builder.build().expect("build failed");
884        assert_eq!(g.num_nodes(), 5);
885        assert_eq!(g.degree(0), 2);
886        assert_eq!(g.degree(3), 1);
887        assert!(g.has_edge(4, 0));
888    }
889
890    #[test]
891    fn test_csr_builder_validation() {
892        let mut builder = CsrGraphBuilder::new(3, true);
893        assert!(builder.add_edge(0, 1, 1.0).is_ok());
894        assert!(builder.add_edge(5, 1, 1.0).is_err()); // src out of range
895        assert!(builder.add_edge(0, 5, 1.0).is_err()); // dst out of range
896    }
897
898    #[test]
899    fn test_csr_from_raw() {
900        let row_ptr = vec![0, 2, 3, 4];
901        let col_indices = vec![1, 2, 0, 1];
902        let values = vec![1.0, 2.0, 3.0, 4.0];
903        let g = CsrGraph::from_raw(3, row_ptr, col_indices, values, true).expect("from_raw");
904        assert_eq!(g.num_nodes(), 3);
905        assert_eq!(g.num_edges(), 4);
906        assert_eq!(g.degree(0), 2);
907    }
908
909    #[test]
910    fn test_csr_from_raw_validation() {
911        // Wrong row_ptr length
912        let r = CsrGraph::from_raw(3, vec![0, 1], vec![0], vec![1.0], true);
913        assert!(r.is_err());
914
915        // Mismatched col/val lengths
916        let r = CsrGraph::from_raw(2, vec![0, 1, 2], vec![1, 0], vec![1.0], true);
917        assert!(r.is_err());
918
919        // Column index out of range
920        let r = CsrGraph::from_raw(2, vec![0, 1, 2], vec![1, 5], vec![1.0, 2.0], true);
921        assert!(r.is_err());
922    }
923
924    #[test]
925    fn test_csr_spmv() {
926        // Simple directed: 0->1 (w=2), 1->0 (w=3)
927        let g =
928            CsrGraph::from_edges(2, vec![(0, 1, 2.0), (1, 0, 3.0)], true).expect("build failed");
929        let x = vec![1.0, 2.0];
930        let y = g.spmv(&x).expect("spmv failed");
931        // y[0] = 2.0 * 2.0 = 4.0
932        // y[1] = 3.0 * 1.0 = 3.0
933        assert!((y[0] - 4.0).abs() < 1e-10);
934        assert!((y[1] - 3.0).abs() < 1e-10);
935    }
936
937    #[test]
938    fn test_csr_spmv_wrong_length() {
939        let g = CsrGraph::from_edges(3, vec![(0, 1, 1.0)], true).expect("build failed");
940        let r = g.spmv(&[1.0, 2.0]);
941        assert!(r.is_err());
942    }
943
944    #[test]
945    fn test_csr_transpose() {
946        let g = CsrGraph::from_edges(3, vec![(0, 1, 1.0), (0, 2, 2.0), (2, 1, 3.0)], true)
947            .expect("build");
948        let gt = g.transpose().expect("transpose");
949
950        assert_eq!(gt.num_nodes(), 3);
951        assert!(gt.has_edge(1, 0));
952        assert!(gt.has_edge(2, 0));
953        assert!(gt.has_edge(1, 2));
954        assert!(!gt.has_edge(0, 1)); // reversed
955    }
956
957    #[test]
958    fn test_csr_transpose_undirected() {
959        let g = CsrGraph::from_edges(3, vec![(0, 1, 1.0)], false).expect("build");
960        let gt = g.transpose().expect("transpose");
961        // For undirected, transpose is identity
962        assert_eq!(gt.num_edges(), g.num_edges());
963        assert!(gt.has_edge(0, 1));
964        assert!(gt.has_edge(1, 0));
965    }
966
967    #[test]
968    fn test_csr_adjacency_list_roundtrip() {
969        let edges = vec![(0, 1, 1.5), (1, 2, 2.5), (2, 0, 3.5)];
970        let g = CsrGraph::from_edges(3, edges, true).expect("build");
971
972        let adj = g.to_adjacency_list();
973        assert_eq!(adj.num_nodes, 3);
974        assert_eq!(adj.adjacency[0].len(), 1); // 0->1
975        assert_eq!(adj.adjacency[1].len(), 1); // 1->2
976        assert_eq!(adj.adjacency[2].len(), 1); // 2->0
977
978        let g2 = CsrGraph::from_adjacency_list(&adj).expect("from adj");
979        assert_eq!(g2.num_nodes(), 3);
980        assert_eq!(g2.num_edges(), 3);
981        assert!(g2.has_edge(0, 1));
982        assert!(g2.has_edge(1, 2));
983        assert!(g2.has_edge(2, 0));
984        assert_eq!(g2.edge_weight(0, 1), Some(1.5));
985    }
986
987    #[test]
988    fn test_csr_graph_conversion() {
989        let mut graph: crate::Graph<usize, f64> = crate::Graph::new();
990        for i in 0..5 {
991            graph.add_node(i);
992        }
993        graph.add_edge(0, 1, 1.0).expect("add edge");
994        graph.add_edge(1, 2, 2.0).expect("add edge");
995        graph.add_edge(2, 3, 3.0).expect("add edge");
996        graph.add_edge(3, 4, 4.0).expect("add edge");
997
998        let csr = CsrGraph::from_graph(&graph).expect("from_graph");
999        assert_eq!(csr.num_nodes(), 5);
1000        assert_eq!(csr.num_logical_edges(), 4);
1001        assert!(!csr.is_directed());
1002        assert!(csr.has_edge(0, 1));
1003        assert!(csr.has_edge(1, 0)); // undirected
1004
1005        // Convert back
1006        let graph2 = csr.to_graph();
1007        assert_eq!(graph2.node_count(), 5);
1008        assert_eq!(graph2.edge_count(), 4);
1009    }
1010
1011    #[test]
1012    fn test_csr_empty_graph() {
1013        let g = CsrGraph::from_edges(5, vec![], true).expect("build");
1014        assert_eq!(g.num_nodes(), 5);
1015        assert_eq!(g.num_edges(), 0);
1016        assert_eq!(g.degree(0), 0);
1017        assert!(!g.has_edge(0, 1));
1018        let neighbors: Vec<_> = g.neighbors(0).collect();
1019        assert!(neighbors.is_empty());
1020    }
1021
1022    #[test]
1023    fn test_csr_single_node() {
1024        let g = CsrGraph::from_edges(1, vec![], true).expect("build");
1025        assert_eq!(g.num_nodes(), 1);
1026        assert_eq!(g.degree(0), 0);
1027    }
1028
1029    #[test]
1030    fn test_csr_self_loop() {
1031        let g = CsrGraph::from_edges(2, vec![(0, 0, 1.0), (0, 1, 2.0)], true).expect("build");
1032        assert_eq!(g.degree(0), 2);
1033        assert!(g.has_edge(0, 0));
1034        assert!(g.has_edge(0, 1));
1035    }
1036
1037    #[test]
1038    fn test_csr_memory_bytes() {
1039        let g = CsrGraph::from_edges(100, vec![(0, 1, 1.0)], true).expect("build");
1040        let mem = g.memory_bytes();
1041        assert!(mem > 0);
1042        // Should be at least the row_ptr size
1043        assert!(mem >= 101 * std::mem::size_of::<usize>());
1044    }
1045
1046    #[test]
1047    fn test_csr_display() {
1048        let g = CsrGraph::from_edges(10, vec![(0, 1, 1.0), (2, 3, 1.0)], false).expect("build");
1049        let s = format!("{g}");
1050        assert!(s.contains("CsrGraph"));
1051        assert!(s.contains("nodes=10"));
1052    }
1053
1054    #[test]
1055    fn test_csr_neighbor_iter_exact_size() {
1056        let g = CsrGraph::from_edges(4, vec![(0, 1, 1.0), (0, 2, 1.0), (0, 3, 1.0)], true)
1057            .expect("build");
1058        let iter = g.neighbors(0);
1059        assert_eq!(iter.len(), 3);
1060    }
1061
1062    #[test]
1063    fn test_csr_out_of_range_node() {
1064        let g = CsrGraph::from_edges(3, vec![(0, 1, 1.0)], true).expect("build");
1065        // Out-of-range should return empty / 0
1066        assert_eq!(g.degree(100), 0);
1067        let n: Vec<_> = g.neighbors(100).collect();
1068        assert!(n.is_empty());
1069        assert!(!g.has_edge(100, 0));
1070        assert_eq!(g.edge_weight(100, 0), None);
1071    }
1072
1073    #[cfg(feature = "parallel")]
1074    #[test]
1075    fn test_csr_parallel_build() {
1076        let edges: Vec<(usize, usize, f64)> = (0..100).map(|i| (i, (i + 1) % 100, 1.0)).collect();
1077        let g = CsrGraph::from_edges_parallel(100, edges, false).expect("parallel build");
1078        assert_eq!(g.num_nodes(), 100);
1079        assert_eq!(g.num_logical_edges(), 100);
1080        for i in 0..100 {
1081            assert!(g.has_edge(i, (i + 1) % 100));
1082            assert!(g.has_edge((i + 1) % 100, i));
1083        }
1084    }
1085
1086    #[cfg(feature = "parallel")]
1087    #[test]
1088    fn test_csr_builder_parallel() {
1089        let mut builder = CsrGraphBuilder::with_capacity(10, 20, true);
1090        for i in 0..9 {
1091            builder.add_edge(i, i + 1, (i + 1) as f64).expect("add");
1092        }
1093        let g = builder.build_parallel().expect("build parallel");
1094        assert_eq!(g.num_nodes(), 10);
1095        assert_eq!(g.num_edges(), 9);
1096    }
1097
1098    #[cfg(feature = "parallel")]
1099    #[test]
1100    fn test_csr_spmv_parallel() {
1101        let g = CsrGraph::from_edges(2, vec![(0, 1, 2.0), (1, 0, 3.0)], true).expect("build");
1102        let x = vec![1.0, 2.0];
1103        let y = g.spmv_parallel(&x).expect("spmv");
1104        assert!((y[0] - 4.0).abs() < 1e-10);
1105        assert!((y[1] - 3.0).abs() < 1e-10);
1106    }
1107
1108    #[test]
1109    fn test_csr_unweighted_edges() {
1110        let edges = [(0, 1), (1, 2), (2, 3)];
1111        let g = CsrGraph::from_unweighted_edges(4, &edges, false).expect("build");
1112        assert_eq!(g.num_nodes(), 4);
1113        assert_eq!(g.num_logical_edges(), 3);
1114        assert_eq!(g.edge_weight(0, 1), Some(1.0));
1115    }
1116
1117    #[test]
1118    fn test_csr_dense_graph() {
1119        // Complete graph K5
1120        let mut edges = Vec::new();
1121        for i in 0..5 {
1122            for j in 0..5 {
1123                if i != j {
1124                    edges.push((i, j, 1.0));
1125                }
1126            }
1127        }
1128        let g = CsrGraph::from_edges(5, edges, true).expect("build");
1129        assert_eq!(g.num_nodes(), 5);
1130        assert_eq!(g.num_edges(), 20); // 5*4 directed edges
1131        for i in 0..5 {
1132            assert_eq!(g.degree(i), 4);
1133        }
1134    }
1135}