Skip to main content

neco_sparse/
sparse.rs

1use std::ops::AddAssign;
2
/// Sparse matrix stored in CSR (Compressed Sparse Row) layout.
///
/// The entries of row `i` live at positions
/// `row_offsets[i]..row_offsets[i + 1]` of `col_indices` and `values`.
#[derive(Debug, Clone)]
pub struct CsrMat<T> {
    /// Number of rows.
    nrows: usize,
    /// Number of columns.
    ncols: usize,
    /// Start position of each row in `col_indices` / `values`, followed by
    /// one trailing end position.
    row_offsets: Vec<usize>,
    /// Column index of every stored entry.
    col_indices: Vec<usize>,
    /// Stored values, parallel to `col_indices`.
    values: Vec<T>,
}
12
/// Borrowed view of a single row of a CSR matrix.
#[derive(Debug)]
pub struct CsrRow<'a, T> {
    /// Column indices of the entries stored in this row.
    col_indices: &'a [usize],
    /// Values of those entries, parallel to `col_indices`.
    values: &'a [T],
}
19
20impl<'a, T> CsrRow<'a, T> {
21    pub fn col_indices(&self) -> &'a [usize] {
22        self.col_indices
23    }
24
25    pub fn values(&self) -> &'a [T] {
26        self.values
27    }
28
29    pub fn nnz(&self) -> usize {
30        self.col_indices.len()
31    }
32}
33
34impl<T> CsrMat<T> {
35    pub fn try_from_csr_data(
36        nrows: usize,
37        ncols: usize,
38        row_offsets: Vec<usize>,
39        col_indices: Vec<usize>,
40        values: Vec<T>,
41    ) -> Result<Self, String> {
42        if row_offsets.len() != nrows + 1 {
43            return Err(format!(
44                "row_offsets length {} does not match nrows+1={}",
45                row_offsets.len(),
46                nrows + 1
47            ));
48        }
49        if col_indices.len() != values.len() {
50            return Err(format!(
51                "col_indices length {} does not match values length {}",
52                col_indices.len(),
53                values.len()
54            ));
55        }
56        let nnz = *row_offsets.last().unwrap_or(&0);
57        if col_indices.len() != nnz {
58            return Err(format!(
59                "col_indices length {} does not match last row_offset {}",
60                col_indices.len(),
61                nnz
62            ));
63        }
64        for window in row_offsets.windows(2) {
65            if window[0] > window[1] {
66                return Err("row_offsets is not monotonically non-decreasing".into());
67            }
68        }
69        for row in 0..nrows {
70            let start = row_offsets[row];
71            let end = row_offsets[row + 1];
72            let row_cols = &col_indices[start..end];
73            if row_cols.windows(2).any(|window| window[0] >= window[1]) {
74                return Err(format!(
75                    "col_indices in row {} are not strictly increasing",
76                    row
77                ));
78            }
79        }
80        for &col in &col_indices {
81            if col >= ncols {
82                return Err(format!(
83                    "col_index {} out of range for ncols={}",
84                    col, ncols
85                ));
86            }
87        }
88        Ok(Self {
89            nrows,
90            ncols,
91            row_offsets,
92            col_indices,
93            values,
94        })
95    }
96
97    pub fn nrows(&self) -> usize {
98        self.nrows
99    }
100
101    pub fn ncols(&self) -> usize {
102        self.ncols
103    }
104
105    pub fn nnz(&self) -> usize {
106        self.values.len()
107    }
108
109    pub fn row_offsets(&self) -> &[usize] {
110        &self.row_offsets
111    }
112
113    pub fn col_indices(&self) -> &[usize] {
114        &self.col_indices
115    }
116
117    pub fn values(&self) -> &[T] {
118        &self.values
119    }
120
121    pub fn values_mut(&mut self) -> &mut [T] {
122        &mut self.values
123    }
124
125    pub fn row(&self, i: usize) -> CsrRow<'_, T> {
126        let start = self.row_offsets[i];
127        let end = self.row_offsets[i + 1];
128        CsrRow {
129            col_indices: &self.col_indices[start..end],
130            values: &self.values[start..end],
131        }
132    }
133
134    pub fn row_iter(&self) -> impl Iterator<Item = CsrRow<'_, T>> {
135        (0..self.nrows).map(move |i| self.row(i))
136    }
137}
138
139impl<T: PartialEq + Copy> CsrMat<T> {
140    pub fn get(&self, row: usize, col: usize) -> Option<&T> {
141        if row >= self.nrows || col >= self.ncols {
142            return None;
143        }
144        let start = self.row_offsets[row];
145        let end = self.row_offsets[row + 1];
146        let slice = &self.col_indices[start..end];
147        match slice.binary_search(&col) {
148            Ok(pos) => Some(&self.values[start + pos]),
149            Err(_) => None,
150        }
151    }
152}
153
154impl<T: Copy> CsrMat<T> {
155    pub fn triplet_iter(&self) -> impl Iterator<Item = (usize, usize, &T)> {
156        (0..self.nrows).flat_map(move |row| {
157            let start = self.row_offsets[row];
158            let end = self.row_offsets[row + 1];
159            (start..end).map(move |idx| (row, self.col_indices[idx], &self.values[idx]))
160        })
161    }
162}
163
164impl CsrMat<f64> {
165    pub fn identity(n: usize) -> Self {
166        let row_offsets: Vec<usize> = (0..=n).collect();
167        let col_indices: Vec<usize> = (0..n).collect();
168        let values = vec![1.0; n];
169        Self {
170            nrows: n,
171            ncols: n,
172            row_offsets,
173            col_indices,
174            values,
175        }
176    }
177
178    pub fn zeros(nrows: usize, ncols: usize) -> Self {
179        let row_offsets = vec![0; nrows + 1];
180        Self {
181            nrows,
182            ncols,
183            row_offsets,
184            col_indices: Vec::new(),
185            values: Vec::new(),
186        }
187    }
188
189    pub fn linear_combination(&self, alpha: f64, other: &Self, beta: f64) -> Result<Self, String> {
190        if self.nrows != other.nrows || self.ncols != other.ncols {
191            return Err(format!(
192                "matrix shape mismatch: lhs={}x{}, rhs={}x{}",
193                self.nrows, self.ncols, other.nrows, other.ncols
194            ));
195        }
196        if self.row_offsets != other.row_offsets || self.col_indices != other.col_indices {
197            return Err("linear_combination requires identical CSR sparsity patterns".into());
198        }
199
200        let values = self
201            .values
202            .iter()
203            .zip(other.values.iter())
204            .map(|(&lhs, &rhs)| alpha * lhs + beta * rhs)
205            .collect();
206
207        Self::try_from_csr_data(
208            self.nrows,
209            self.ncols,
210            self.row_offsets.clone(),
211            self.col_indices.clone(),
212            values,
213        )
214    }
215
216    pub fn diagonal(&self) -> Result<Vec<f64>, String> {
217        let ndiag = self.nrows.min(self.ncols);
218        let mut diagonal = Vec::with_capacity(ndiag);
219        for i in 0..ndiag {
220            let row = self.row(i);
221            let diag_pos = row
222                .col_indices()
223                .binary_search(&i)
224                .map_err(|_| format!("missing diagonal entry at row {i}"))?;
225            diagonal.push(row.values()[diag_pos]);
226        }
227        Ok(diagonal)
228    }
229
230    pub fn submatrix(&self, rows: &[usize], cols: &[usize]) -> Result<Self, String> {
231        let mut col_positions = vec![usize::MAX; self.ncols];
232        for (local_col, &global_col) in cols.iter().enumerate() {
233            if global_col >= self.ncols {
234                return Err(format!(
235                    "column index {} out of range for ncols={}",
236                    global_col, self.ncols
237                ));
238            }
239            if col_positions[global_col] != usize::MAX {
240                return Err(format!(
241                    "duplicate column index {global_col} in submatrix request"
242                ));
243            }
244            col_positions[global_col] = local_col;
245        }
246
247        let mut row_offsets = Vec::with_capacity(rows.len() + 1);
248        let mut col_indices = Vec::new();
249        let mut values = Vec::new();
250        row_offsets.push(0);
251
252        let mut seen_rows = vec![false; self.nrows];
253        for &global_row in rows {
254            if global_row >= self.nrows {
255                return Err(format!(
256                    "row index {} out of range for nrows={}",
257                    global_row, self.nrows
258                ));
259            }
260            if seen_rows[global_row] {
261                return Err(format!(
262                    "duplicate row index {global_row} in submatrix request"
263                ));
264            }
265            seen_rows[global_row] = true;
266
267            let row = self.row(global_row);
268            let mut entries: Vec<(usize, f64)> = row
269                .col_indices()
270                .iter()
271                .zip(row.values().iter())
272                .filter_map(|(&global_col, &value)| {
273                    let local_col = col_positions[global_col];
274                    (local_col != usize::MAX).then_some((local_col, value))
275                })
276                .collect();
277            entries.sort_unstable_by_key(|(local_col, _)| *local_col);
278
279            for (local_col, value) in entries {
280                col_indices.push(local_col);
281                values.push(value);
282            }
283            row_offsets.push(col_indices.len());
284        }
285
286        Self::try_from_csr_data(rows.len(), cols.len(), row_offsets, col_indices, values)
287    }
288}
289
/// Sparse matrix in COO (coordinate) form: a list of `(row, col, value)`
/// triplets held in three parallel vectors.
#[derive(Debug, Clone)]
pub struct CooMat<T> {
    /// Number of rows.
    nrows: usize,
    /// Number of columns.
    ncols: usize,
    /// Row index of each triplet.
    rows: Vec<usize>,
    /// Column index of each triplet.
    cols: Vec<usize>,
    /// Value of each triplet.
    vals: Vec<T>,
}
299
300impl<T> CooMat<T> {
301    pub fn new(nrows: usize, ncols: usize) -> Self {
302        Self {
303            nrows,
304            ncols,
305            rows: Vec::new(),
306            cols: Vec::new(),
307            vals: Vec::new(),
308        }
309    }
310
311    pub fn push(&mut self, row: usize, col: usize, val: T) {
312        self.rows.push(row);
313        self.cols.push(col);
314        self.vals.push(val);
315    }
316}
317
318impl<T: Copy + Default + AddAssign + PartialEq> From<&CooMat<T>> for CsrMat<T> {
319    fn from(coo: &CooMat<T>) -> Self {
320        let nrows = coo.nrows;
321        let ncols = coo.ncols;
322        let nnz_raw = coo.rows.len();
323
324        if nnz_raw == 0 {
325            return Self {
326                nrows,
327                ncols,
328                row_offsets: vec![0; nrows + 1],
329                col_indices: Vec::new(),
330                values: Vec::new(),
331            };
332        }
333
334        // Sort indices by (row, col) without copying triplet data
335        let mut order: Vec<usize> = (0..nnz_raw).collect();
336        order.sort_unstable_by(|&a, &b| {
337            coo.rows[a]
338                .cmp(&coo.rows[b])
339                .then_with(|| coo.cols[a].cmp(&coo.cols[b]))
340        });
341
342        // Merge duplicates and build CSR arrays in a single pass
343        let mut row_offsets = Vec::with_capacity(nrows + 1);
344        let mut col_indices = Vec::with_capacity(nnz_raw);
345        let mut values = Vec::with_capacity(nnz_raw);
346
347        row_offsets.push(0);
348
349        // Fill leading empty rows
350        let first_row = coo.rows[order[0]];
351        if first_row > 0 {
352            row_offsets.extend(std::iter::repeat_n(0, first_row));
353        }
354
355        let mut prev_row = first_row;
356        let mut prev_col = coo.cols[order[0]];
357        let mut acc = T::default();
358        acc += coo.vals[order[0]];
359
360        for &idx in &order[1..] {
361            let r = coo.rows[idx];
362            let c = coo.cols[idx];
363
364            if r == prev_row && c == prev_col {
365                acc += coo.vals[idx];
366            } else {
367                // Flush previous entry
368                col_indices.push(prev_col);
369                values.push(acc);
370
371                // Fill empty rows between prev_row and r
372                for _ in prev_row..r {
373                    row_offsets.push(col_indices.len());
374                }
375
376                prev_row = r;
377                prev_col = c;
378                acc = T::default();
379                acc += coo.vals[idx];
380            }
381        }
382
383        // Flush last entry
384        col_indices.push(prev_col);
385        values.push(acc);
386
387        // Fill remaining rows (including closing the last occupied row)
388        while row_offsets.len() <= nrows {
389            row_offsets.push(col_indices.len());
390        }
391
392        debug_assert_eq!(row_offsets.len(), nrows + 1);
393
394        Self {
395            nrows,
396            ncols,
397            row_offsets,
398            col_indices,
399            values,
400        }
401    }
402}
403
// Unit tests for CSR construction and lookup, COO-to-CSR conversion, and the
// f64-only helpers (linear_combination, diagonal, submatrix).
404#[cfg(test)]
405mod tests {
406    use super::*;
407
    // Test helper: pushes every triplet into a COO matrix and converts it to CSR.
408    fn csr_from_triplets(
409        nrows: usize,
410        ncols: usize,
411        triplets: &[(usize, usize, f64)],
412    ) -> CsrMat<f64> {
413        let mut coo = CooMat::new(nrows, ncols);
414        for &(row, col, value) in triplets {
415            coo.push(row, col, value);
416        }
417        CsrMat::from(&coo)
418    }
419
    // Shape accessors, row views and `get` on a hand-built 2x2 matrix,
    // including a position with no stored entry.
420    #[test]
421    fn csr_basic() {
422        let mat =
423            CsrMat::try_from_csr_data(2, 2, vec![0, 2, 3], vec![0, 1, 1], vec![1.0, 2.0, 3.0])
424                .expect("CSR construction");
425        assert_eq!(mat.nrows(), 2);
426        assert_eq!(mat.ncols(), 2);
427        assert_eq!(mat.nnz(), 3);
428        assert_eq!(mat.row(0).col_indices(), &[0, 1]);
429        assert_eq!(mat.row(0).values(), &[1.0, 2.0]);
430        assert_eq!(mat.get(0, 0), Some(&1.0));
431        assert_eq!(mat.get(0, 1), Some(&2.0));
432        assert_eq!(mat.get(1, 1), Some(&3.0));
433        assert_eq!(mat.get(1, 0), None);
434    }
435
    // Descending columns inside a row are rejected with the exact message;
    // the same data with ascending columns is accepted.
436    #[test]
437    fn csr_rejects_unsorted_row_columns() {
438        let err = CsrMat::try_from_csr_data(1, 3, vec![0, 2], vec![2, 1], vec![1.0, 2.0])
439            .expect_err("unsorted row must be rejected");
440        assert_eq!(err, "col_indices in row 0 are not strictly increasing");
441
442        let mat = CsrMat::try_from_csr_data(1, 3, vec![0, 2], vec![0, 2], vec![1.0, 2.0])
443            .expect("sorted CSR construction");
444        assert_eq!(mat.get(0, 0), Some(&1.0));
445        assert_eq!(mat.get(0, 2), Some(&2.0));
446    }
447
    // Two pushes at (0, 1) collapse into a single entry holding their sum.
448    #[test]
449    fn coo_to_csr_accumulates_duplicates() {
450        let mut coo = CooMat::new(2, 2);
451        coo.push(0, 1, 1.5);
452        coo.push(0, 1, 2.5);
453        coo.push(1, 0, 4.0);
454        let csr = CsrMat::from(&coo);
455        assert_eq!(csr.nnz(), 2);
456        assert_eq!(csr.get(0, 1), Some(&4.0));
457        assert_eq!(csr.get(1, 0), Some(&4.0));
458    }
459
    // A COO matrix with no triplets converts to a CSR matrix whose rows are all empty.
460    #[test]
461    fn coo_to_csr_empty() {
462        let coo: CooMat<f64> = CooMat::new(3, 3);
463        let csr = CsrMat::from(&coo);
464        assert_eq!(csr.nrows(), 3);
465        assert_eq!(csr.ncols(), 3);
466        assert_eq!(csr.nnz(), 0);
467        for i in 0..3 {
468            assert_eq!(csr.row(i).nnz(), 0);
469        }
470    }
471
    // A single triplet in the middle of a 5x5 matrix (empty rows before and after it).
472    #[test]
473    fn coo_to_csr_single_element() {
474        let mut coo = CooMat::new(5, 5);
475        coo.push(2, 3, 7.0);
476        let csr = CsrMat::from(&coo);
477        assert_eq!(csr.nnz(), 1);
478        assert_eq!(csr.get(2, 3), Some(&7.0));
479        assert_eq!(csr.get(0, 0), None);
480    }
481
    // Triplets pushed in scrambled column order come out sorted by column.
482    #[test]
483    fn coo_to_csr_reverse_column_order() {
484        let mut coo = CooMat::new(1, 4);
485        coo.push(0, 3, 4.0);
486        coo.push(0, 1, 2.0);
487        coo.push(0, 0, 1.0);
488        coo.push(0, 2, 3.0);
489        let csr = CsrMat::from(&coo);
490        assert_eq!(csr.row(0).col_indices(), &[0, 1, 2, 3]);
491        assert_eq!(csr.row(0).values(), &[1.0, 2.0, 3.0, 4.0]);
492    }
493
    // Four pushes at the same position are summed into one entry.
494    #[test]
495    fn coo_to_csr_multiple_duplicates() {
496        let mut coo = CooMat::new(2, 2);
497        coo.push(0, 0, 1.0);
498        coo.push(0, 0, 2.0);
499        coo.push(0, 0, 3.0);
500        coo.push(0, 0, 4.0);
501        coo.push(1, 1, 5.0);
502        let csr = CsrMat::from(&coo);
503        assert_eq!(csr.nnz(), 2);
504        assert_eq!(csr.get(0, 0), Some(&10.0));
505        assert_eq!(csr.get(1, 1), Some(&5.0));
506    }
507
    // Column indices are ascending within every row after conversion.
508    #[test]
509    fn coo_to_csr_sorted_columns_per_row() {
510        let mut coo = CooMat::new(3, 5);
511        coo.push(0, 4, 1.0);
512        coo.push(0, 0, 2.0);
513        coo.push(1, 3, 3.0);
514        coo.push(1, 1, 4.0);
515        coo.push(2, 2, 5.0);
516        let csr = CsrMat::from(&coo);
517        assert_eq!(csr.row(0).col_indices(), &[0, 4]);
518        assert_eq!(csr.row(1).col_indices(), &[1, 3]);
519        assert_eq!(csr.row(2).col_indices(), &[2]);
520    }
521
    // Rows 1..4 receive no triplets and must come out empty.
522    #[test]
523    fn coo_to_csr_sparse_rows_with_interior_gaps() {
524        let mut coo = CooMat::new(5, 3);
525        coo.push(0, 1, 1.0);
526        coo.push(4, 2, 2.0);
527        let csr = CsrMat::from(&coo);
528        assert_eq!(csr.nnz(), 2);
529        assert_eq!(csr.get(0, 1), Some(&1.0));
530        assert_eq!(csr.get(4, 2), Some(&2.0));
531        for i in 1..4 {
532            assert_eq!(csr.row(i).nnz(), 0);
533        }
534    }
535
    // The conversion also works for an integer value type (i32).
536    #[test]
537    fn coo_to_csr_integer_type() {
538        let mut coo: CooMat<i32> = CooMat::new(2, 2);
539        coo.push(0, 0, 10);
540        coo.push(0, 0, 20);
541        coo.push(1, 1, 30);
542        let csr = CsrMat::from(&coo);
543        assert_eq!(csr.get(0, 0), Some(&30));
544        assert_eq!(csr.get(1, 1), Some(&30));
545    }
546
    // 1.0 * k + (-2.0) * m on two matrices with the same sparsity pattern;
    // an entry that cancels to 0.0 stays stored.
547    #[test]
548    fn linear_combination_matches_shifted_matrix_pattern() {
549        let k = csr_from_triplets(2, 2, &[(0, 0, 4.0), (0, 1, 1.0), (1, 1, 3.0)]);
550        let m = csr_from_triplets(2, 2, &[(0, 0, 1.0), (0, 1, 0.5), (1, 1, 2.0)]);
551
552        let shifted = k.linear_combination(1.0, &m, -2.0).unwrap();
553        assert_eq!(shifted.row(0).col_indices(), &[0, 1]);
554        assert_eq!(shifted.row(0).values(), &[2.0, 0.0]);
555        assert_eq!(shifted.row(1).values(), &[-1.0]);
556    }
557
    // Same shape but different sparsity patterns must be rejected.
558    #[test]
559    fn linear_combination_rejects_pattern_mismatch() {
560        let lhs = csr_from_triplets(2, 2, &[(0, 0, 1.0), (1, 1, 2.0)]);
561        let rhs = csr_from_triplets(2, 2, &[(0, 0, 1.0), (0, 1, 2.0), (1, 1, 3.0)]);
562
563        let err = lhs.linear_combination(1.0, &rhs, -1.0).unwrap_err();
564        assert!(err.contains("identical CSR sparsity patterns"), "err={err}");
565    }
566
    // Every diagonal position is stored, so the whole diagonal is returned.
567    #[test]
568    fn diagonal_extracts_all_present_entries() {
569        let mat = csr_from_triplets(3, 3, &[(0, 0, 2.0), (0, 2, 9.0), (1, 1, 3.0), (2, 2, 5.0)]);
570        assert_eq!(mat.diagonal().unwrap(), vec![2.0, 3.0, 5.0]);
571    }
572
    // Row 0 has no entry at column 0, so `diagonal` reports an error.
573    #[test]
574    fn diagonal_rejects_missing_entry() {
575        let mat = csr_from_triplets(2, 2, &[(0, 1, 1.0), (1, 1, 2.0)]);
576        let err = mat.diagonal().unwrap_err();
577        assert!(err.contains("missing diagonal entry"), "err={err}");
578    }
579
    // Rows [2, 0] and columns [3, 1]: output rows follow the requested order,
    // and entries are sorted by their new (local) column index.
580    #[test]
581    fn submatrix_preserves_requested_order_with_sorted_local_columns() {
582        let mat = csr_from_triplets(
583            3,
584            4,
585            &[
586                (0, 0, 1.0),
587                (0, 1, 2.0),
588                (0, 3, 3.0),
589                (1, 0, 4.0),
590                (1, 2, 5.0),
591                (2, 1, 6.0),
592                (2, 3, 7.0),
593            ],
594        );
595
596        let sub = mat.submatrix(&[2, 0], &[3, 1]).unwrap();
597        assert_eq!(sub.nrows(), 2);
598        assert_eq!(sub.ncols(), 2);
599        assert_eq!(sub.row(0).col_indices(), &[0, 1]);
600        assert_eq!(sub.row(0).values(), &[7.0, 6.0]);
601        assert_eq!(sub.row(1).col_indices(), &[0, 1]);
602        assert_eq!(sub.row(1).values(), &[3.0, 2.0]);
603    }
604
    // A repeated row index and a repeated column index are each rejected.
605    #[test]
606    fn submatrix_rejects_duplicate_indices() {
607        let mat = csr_from_triplets(2, 3, &[(0, 0, 1.0), (0, 1, 2.0), (1, 2, 3.0)]);
608
609        let row_err = mat.submatrix(&[0, 0], &[1]).unwrap_err();
610        assert!(row_err.contains("duplicate row index"), "err={row_err}");
611
612        let col_err = mat.submatrix(&[0], &[1, 1]).unwrap_err();
613        assert!(col_err.contains("duplicate column index"), "err={col_err}");
614    }
615}