Skip to main content

trueno_sparse/
csr.rs

1//! Compressed Sparse Row (CSR) matrix format.
2//!
3//! CSR is the primary format for sparse matrix arithmetic (SpMV, SpMM).
4//! Stores row offsets, column indices, and values in three contiguous arrays.
5//!
6//! # Contract: sparse-spmv-v1.yaml
7//!
8//! Format invariants validated at construction:
9//! - `offsets[0] == 0 && offsets[rows] == nnz`
10//! - `offsets` monotonically non-decreasing
11//! - All column indices in `[0, cols)`
12
13use crate::coo::CooMatrix;
14use crate::error::SparseError;
15use crate::validate::validate_csr_invariants;
16
/// Sparse matrix stored in Compressed Sparse Row (CSR) layout.
///
/// The nonzeros live in three flat arrays:
///
/// * `offsets` — `rows + 1` row pointers; row `i` owns the half-open range
///   `offsets[i]..offsets[i + 1]` of the other two arrays.
/// * `col_indices` — one column index per stored entry (`nnz` total).
/// * `values` — one value per stored entry, parallel to `col_indices`.
///
/// Fields are private so that the format invariants established by the
/// constructors cannot be broken from outside this module.
#[derive(Clone, Debug)]
pub struct CsrMatrix<T> {
    /// Number of rows.
    rows: usize,
    /// Number of columns.
    cols: usize,
    /// Row pointers into `col_indices` / `values` (length `rows + 1`).
    offsets: Vec<u32>,
    /// Column index of every stored entry (length `nnz`).
    col_indices: Vec<u32>,
    /// Value of every stored entry (length `nnz`).
    values: Vec<T>,
}
31
32impl<T: Clone + Default> CsrMatrix<T> {
33    /// Create a CSR matrix with full validation.
34    ///
35    /// # Contract: sparse-spmv-v1.yaml / format_validation
36    ///
37    /// Validates all CSR invariants before construction.
38    ///
39    /// # Errors
40    ///
41    /// Returns error if any CSR invariant is violated.
42    pub fn new(
43        rows: usize,
44        cols: usize,
45        offsets: Vec<u32>,
46        col_indices: Vec<u32>,
47        values: Vec<T>,
48    ) -> Result<Self, SparseError> {
49        validate_csr_invariants(rows, cols, &offsets, &col_indices, values.len())?;
50        Ok(Self {
51            rows,
52            cols,
53            offsets,
54            col_indices,
55            values,
56        })
57    }
58
59    /// Convert from COO format to CSR.
60    ///
61    /// Sorts triplets by row, then by column within each row.
62    /// Duplicate entries are summed (standard convention).
63    #[must_use]
64    pub fn from_coo(coo: &CooMatrix<T>) -> Self
65    where
66        T: std::ops::AddAssign + Copy,
67    {
68        let rows = coo.rows;
69        let cols = coo.cols;
70        let nnz = coo.nnz();
71
72        if nnz == 0 {
73            return Self {
74                rows,
75                cols,
76                offsets: vec![0; rows + 1],
77                col_indices: Vec::new(),
78                values: Vec::new(),
79            };
80        }
81
82        // Count nonzeros per row
83        let mut row_counts = vec![0u32; rows];
84        for &r in &coo.row_indices {
85            row_counts[r as usize] += 1;
86        }
87
88        // Build offsets via prefix sum
89        let mut offsets = vec![0u32; rows + 1];
90        for i in 0..rows {
91            offsets[i + 1] = offsets[i] + row_counts[i];
92        }
93
94        // Fill col_indices and values (sort by row)
95        let mut col_indices = vec![0u32; nnz];
96        let mut values = vec![T::default(); nnz];
97        let mut write_pos = offsets.clone();
98
99        for idx in 0..nnz {
100            let r = coo.row_indices[idx] as usize;
101            let pos = write_pos[r] as usize;
102            col_indices[pos] = coo.col_indices[idx];
103            values[pos] = coo.values[idx];
104            write_pos[r] += 1;
105        }
106
107        // Sort columns within each row
108        for i in 0..rows {
109            let start = offsets[i] as usize;
110            let end = offsets[i + 1] as usize;
111            if end - start > 1 {
112                // Simple insertion sort (rows are typically short)
113                for j in (start + 1)..end {
114                    let mut k = j;
115                    while k > start && col_indices[k - 1] > col_indices[k] {
116                        col_indices.swap(k - 1, k);
117                        values.swap(k - 1, k);
118                        k -= 1;
119                    }
120                }
121            }
122        }
123
124        Self {
125            rows,
126            cols,
127            offsets,
128            col_indices,
129            values,
130        }
131    }
132
133    /// Create an identity matrix of size n.
134    #[must_use]
135    pub fn identity(n: usize) -> Self
136    where
137        T: From<f32>,
138    {
139        let offsets: Vec<u32> = (0..=n).map(|i| i as u32).collect();
140        let col_indices: Vec<u32> = (0..n).map(|i| i as u32).collect();
141        let values: Vec<T> = (0..n).map(|_| T::from(1.0)).collect();
142        Self {
143            rows: n,
144            cols: n,
145            offsets,
146            col_indices,
147            values,
148        }
149    }
150
151    /// Number of rows.
152    #[must_use]
153    pub fn rows(&self) -> usize {
154        self.rows
155    }
156
157    /// Number of columns.
158    #[must_use]
159    pub fn cols(&self) -> usize {
160        self.cols
161    }
162
163    /// Number of stored nonzero entries.
164    #[must_use]
165    pub fn nnz(&self) -> usize {
166        self.values.len()
167    }
168
169    /// Row offsets array (len = rows + 1).
170    #[must_use]
171    pub fn offsets(&self) -> &[u32] {
172        &self.offsets
173    }
174
175    /// Column indices array (len = nnz).
176    #[must_use]
177    pub fn col_indices(&self) -> &[u32] {
178        &self.col_indices
179    }
180
181    /// Values array (len = nnz).
182    #[must_use]
183    pub fn values(&self) -> &[T] {
184        &self.values
185    }
186
187    /// Average number of nonzeros per row.
188    #[must_use]
189    #[allow(clippy::cast_precision_loss)]
190    pub fn avg_nnz_per_row(&self) -> f64 {
191        if self.rows == 0 {
192            0.0
193        } else {
194            self.nnz() as f64 / self.rows as f64
195        }
196    }
197
198    /// Variance of row lengths (key metric for algorithm selection).
199    ///
200    /// High variance → merge-based SpMV; low variance → row-split SpMV.
201    #[must_use]
202    #[allow(clippy::cast_precision_loss)]
203    pub fn row_length_variance(&self) -> f64 {
204        if self.rows == 0 {
205            return 0.0;
206        }
207        let mean = self.avg_nnz_per_row();
208        let sum_sq: f64 = (0..self.rows)
209            .map(|i| {
210                let len = f64::from(self.offsets[i + 1] - self.offsets[i]);
211                (len - mean) * (len - mean)
212            })
213            .sum();
214        sum_sq / self.rows as f64
215    }
216
217    /// Convert to dense matrix (row-major).
218    #[must_use]
219    pub fn to_dense(&self) -> Vec<T>
220    where
221        T: Copy + std::ops::AddAssign,
222    {
223        let mut dense = vec![T::default(); self.rows * self.cols];
224        for i in 0..self.rows {
225            let start = self.offsets[i] as usize;
226            let end = self.offsets[i + 1] as usize;
227            for idx in start..end {
228                let j = self.col_indices[idx] as usize;
229                dense[i * self.cols + j] += self.values[idx];
230            }
231        }
232        dense
233    }
234}