Skip to main content

cyanea_omics/
sparse.rs

1//! Coordinate-format (COO) sparse matrix.
2//!
3//! [`SparseMatrix`] stores non-zero entries as `(row, col, value)` triplets.
4//! This format is efficient for incremental construction and iteration, and is
5//! the natural import format for single-cell count matrices and other sparse
6//! omics data.
7
8use cyanea_core::{CyaneaError, Result, Summarizable};
9
10/// A sparse matrix in COO (coordinate) format.
11#[derive(Debug, Clone)]
12pub struct SparseMatrix {
13    rows: Vec<usize>,
14    cols: Vec<usize>,
15    values: Vec<f64>,
16    n_rows: usize,
17    n_cols: usize,
18}
19
20impl SparseMatrix {
21    /// Create an empty sparse matrix with the given dimensions.
22    pub fn new(n_rows: usize, n_cols: usize) -> Self {
23        Self {
24            rows: Vec::new(),
25            cols: Vec::new(),
26            values: Vec::new(),
27            n_rows,
28            n_cols,
29        }
30    }
31
32    /// Create a sparse matrix from triplet vectors.
33    ///
34    /// All three vectors must have the same length, and all indices must be
35    /// within bounds.
36    pub fn from_triplets(
37        rows: Vec<usize>,
38        cols: Vec<usize>,
39        values: Vec<f64>,
40        n_rows: usize,
41        n_cols: usize,
42    ) -> Result<Self> {
43        if rows.len() != cols.len() || cols.len() != values.len() {
44            return Err(CyaneaError::InvalidInput(
45                "rows, cols, and values must have the same length".into(),
46            ));
47        }
48        for (i, (&r, &c)) in rows.iter().zip(cols.iter()).enumerate() {
49            if r >= n_rows || c >= n_cols {
50                return Err(CyaneaError::InvalidInput(format!(
51                    "triplet {i} index ({r}, {c}) out of bounds for ({n_rows}, {n_cols})"
52                )));
53            }
54        }
55        Ok(Self {
56            rows,
57            cols,
58            values,
59            n_rows,
60            n_cols,
61        })
62    }
63
64    /// Insert a single entry. Returns an error if indices are out of bounds.
65    pub fn insert(&mut self, row: usize, col: usize, value: f64) -> Result<()> {
66        if row >= self.n_rows || col >= self.n_cols {
67            return Err(CyaneaError::InvalidInput(format!(
68                "index ({row}, {col}) out of bounds for ({}, {})",
69                self.n_rows, self.n_cols
70            )));
71        }
72        self.rows.push(row);
73        self.cols.push(col);
74        self.values.push(value);
75        Ok(())
76    }
77
78    /// Get the value at `(row, col)`. Returns 0.0 if no entry is stored.
79    ///
80    /// This is an O(nnz) scan.
81    pub fn get(&self, row: usize, col: usize) -> f64 {
82        for i in 0..self.values.len() {
83            if self.rows[i] == row && self.cols[i] == col {
84                return self.values[i];
85            }
86        }
87        0.0
88    }
89
90    /// Number of stored (non-zero) entries.
91    pub fn nnz(&self) -> usize {
92        self.values.len()
93    }
94
95    /// Fraction of entries that are stored: `nnz / (n_rows * n_cols)`.
96    pub fn density(&self) -> f64 {
97        let total = self.n_rows as f64 * self.n_cols as f64;
98        if total == 0.0 {
99            return 0.0;
100        }
101        self.values.len() as f64 / total
102    }
103
104    /// (n_rows, n_cols).
105    pub fn shape(&self) -> (usize, usize) {
106        (self.n_rows, self.n_cols)
107    }
108
109    /// Convert to a dense 2D vector.
110    pub fn to_dense(&self) -> Vec<Vec<f64>> {
111        let mut dense = vec![vec![0.0; self.n_cols]; self.n_rows];
112        for i in 0..self.values.len() {
113            dense[self.rows[i]][self.cols[i]] = self.values[i];
114        }
115        dense
116    }
117
118    /// Create a sparse matrix from dense data, storing only values where `|value| > threshold`.
119    pub fn from_dense(data: &[Vec<f64>], threshold: f64) -> Self {
120        let n_rows = data.len();
121        let n_cols = data.first().map_or(0, |r| r.len());
122        let mut rows = Vec::new();
123        let mut cols = Vec::new();
124        let mut values = Vec::new();
125
126        for (r, row) in data.iter().enumerate() {
127            for (c, &val) in row.iter().enumerate() {
128                if val.abs() > threshold {
129                    rows.push(r);
130                    cols.push(c);
131                    values.push(val);
132                }
133            }
134        }
135
136        Self {
137            rows,
138            cols,
139            values,
140            n_rows,
141            n_cols,
142        }
143    }
144
145    /// Number of stored entries in a given row.
146    pub fn row_nnz(&self, row: usize) -> usize {
147        self.rows.iter().filter(|&&r| r == row).count()
148    }
149
150    /// Number of stored entries in a given column.
151    pub fn col_nnz(&self, col: usize) -> usize {
152        self.cols.iter().filter(|&&c| c == col).count()
153    }
154
155    /// Convert COO to CSR format.
156    ///
157    /// Returns `(data, indices, indptr)` where:
158    /// - `data` contains the non-zero values
159    /// - `indices` contains the column index for each value
160    /// - `indptr[i]..indptr[i+1]` gives the range of entries for row `i`
161    pub fn to_csr(&self) -> (Vec<f64>, Vec<usize>, Vec<usize>) {
162        // Build sorted order by (row, col)
163        let nnz = self.values.len();
164        let mut order: Vec<usize> = (0..nnz).collect();
165        order.sort_by_key(|&i| (self.rows[i], self.cols[i]));
166
167        let mut data = Vec::with_capacity(nnz);
168        let mut indices = Vec::with_capacity(nnz);
169        let mut indptr = vec![0usize; self.n_rows + 1];
170
171        for &i in &order {
172            data.push(self.values[i]);
173            indices.push(self.cols[i]);
174            indptr[self.rows[i] + 1] += 1;
175        }
176
177        // Cumulative sum to build indptr
178        for i in 1..=self.n_rows {
179            indptr[i] += indptr[i - 1];
180        }
181
182        (data, indices, indptr)
183    }
184
185    /// Create a sparse matrix from CSR format.
186    ///
187    /// - `data` — non-zero values
188    /// - `indices` — column index for each value
189    /// - `indptr` — row pointer array (length `n_rows + 1`)
190    pub fn from_csr(
191        data: Vec<f64>,
192        indices: Vec<usize>,
193        indptr: Vec<usize>,
194        n_rows: usize,
195        n_cols: usize,
196    ) -> Result<Self> {
197        if data.len() != indices.len() {
198            return Err(CyaneaError::InvalidInput(
199                "CSR data and indices must have the same length".into(),
200            ));
201        }
202        if indptr.len() != n_rows + 1 {
203            return Err(CyaneaError::InvalidInput(format!(
204                "CSR indptr length ({}) must be n_rows + 1 ({})",
205                indptr.len(),
206                n_rows + 1
207            )));
208        }
209
210        let nnz = data.len();
211        let mut rows = Vec::with_capacity(nnz);
212        let mut cols = Vec::with_capacity(nnz);
213
214        for row in 0..n_rows {
215            let start = indptr[row];
216            let end = indptr[row + 1];
217            for idx in start..end {
218                if idx >= nnz {
219                    return Err(CyaneaError::InvalidInput(format!(
220                        "CSR indptr references index {idx} but nnz is {nnz}"
221                    )));
222                }
223                if indices[idx] >= n_cols {
224                    return Err(CyaneaError::InvalidInput(format!(
225                        "CSR column index {} out of bounds for n_cols={}",
226                        indices[idx], n_cols
227                    )));
228                }
229                rows.push(row);
230                cols.push(indices[idx]);
231            }
232        }
233
234        Ok(Self {
235            rows,
236            cols,
237            values: data,
238            n_rows,
239            n_cols,
240        })
241    }
242
243    /// Iterate over stored triplets `(row, col, value)`.
244    pub fn iter(&self) -> impl Iterator<Item = (usize, usize, f64)> + '_ {
245        self.rows
246            .iter()
247            .zip(self.cols.iter())
248            .zip(self.values.iter())
249            .map(|((&r, &c), &v)| (r, c, v))
250    }
251
252    /// Sum of values in each column.
253    pub fn column_sums(&self) -> Vec<f64> {
254        let mut sums = vec![0.0; self.n_cols];
255        for i in 0..self.values.len() {
256            sums[self.cols[i]] += self.values[i];
257        }
258        sums
259    }
260
261    /// Mean value of each column (dividing by n_rows, treating missing as zero).
262    pub fn column_means(&self) -> Vec<f64> {
263        if self.n_rows == 0 {
264            return vec![0.0; self.n_cols];
265        }
266        let sums = self.column_sums();
267        let n = self.n_rows as f64;
268        sums.into_iter().map(|s| s / n).collect()
269    }
270
271    /// Sum of values in each row.
272    pub fn row_sums(&self) -> Vec<f64> {
273        let mut sums = vec![0.0; self.n_rows];
274        for i in 0..self.values.len() {
275            sums[self.rows[i]] += self.values[i];
276        }
277        sums
278    }
279
280    /// Multiply each row's values by the corresponding factor.
281    ///
282    /// `factors` must have length `n_rows`.
283    pub fn scale_rows(&mut self, factors: &[f64]) {
284        for i in 0..self.values.len() {
285            self.values[i] *= factors[self.rows[i]];
286        }
287    }
288
289    /// Apply a function to every stored value in place.
290    pub fn map_values(&mut self, f: impl Fn(f64) -> f64) {
291        for v in &mut self.values {
292            *v = f(*v);
293        }
294    }
295
296    /// Number of rows.
297    pub fn n_rows(&self) -> usize {
298        self.n_rows
299    }
300
301    /// Number of columns.
302    pub fn n_cols(&self) -> usize {
303        self.n_cols
304    }
305}
306
307impl Summarizable for SparseMatrix {
308    fn summary(&self) -> String {
309        format!(
310            "SparseMatrix: {}\u{00d7}{}, {} nonzeros ({:.2}% density)",
311            self.n_rows,
312            self.n_cols,
313            self.nnz(),
314            self.density() * 100.0
315        )
316    }
317}
318
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used for approximate floating-point comparisons.
    const EPS: f64 = 1e-10;

    /// Build a matrix from `(row, col, value)` triplets, panicking if
    /// `from_triplets` rejects them.
    fn matrix_of(entries: &[(usize, usize, f64)], n_rows: usize, n_cols: usize) -> SparseMatrix {
        let rows: Vec<usize> = entries.iter().map(|e| e.0).collect();
        let cols: Vec<usize> = entries.iter().map(|e| e.1).collect();
        let values: Vec<f64> = entries.iter().map(|e| e.2).collect();
        SparseMatrix::from_triplets(rows, cols, values, n_rows, n_cols).unwrap()
    }

    /// Assert that two floats differ by less than `EPS`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < EPS,
            "expected {expected}, got {actual}"
        );
    }

    // --- construction -----------------------------------------------------

    #[test]
    fn test_new_empty() {
        let empty = SparseMatrix::new(10, 20);
        assert_eq!(empty.shape(), (10, 20));
        assert_eq!(empty.nnz(), 0);
        assert_eq!(empty.density(), 0.0);
    }

    #[test]
    fn test_from_triplets() {
        let diag = matrix_of(&[(0, 0, 1.0), (1, 1, 2.0), (2, 2, 3.0)], 3, 3);
        assert_eq!(diag.nnz(), 3);
        assert_eq!(diag.get(0, 0), 1.0);
        assert_eq!(diag.get(1, 1), 2.0);
        assert_eq!(diag.get(0, 1), 0.0);
    }

    #[test]
    fn test_from_triplets_bounds_check() {
        // Row 5 does not exist in a 3x3 matrix.
        let outcome = SparseMatrix::from_triplets(vec![5], vec![0], vec![1.0], 3, 3);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_from_triplets_length_mismatch() {
        // Two row indices but only one column index and one value.
        let outcome = SparseMatrix::from_triplets(vec![0, 1], vec![0], vec![1.0], 3, 3);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_insert() {
        let mut mat = SparseMatrix::new(3, 3);
        mat.insert(0, 0, 5.0).unwrap();
        assert_eq!(mat.get(0, 0), 5.0);
        assert_eq!(mat.nnz(), 1);

        let out_of_range = mat.insert(10, 0, 1.0);
        assert!(out_of_range.is_err());
    }

    // --- basic queries ----------------------------------------------------

    #[test]
    fn test_density() {
        let mat = matrix_of(&[(0, 0, 1.0), (1, 1, 2.0)], 10, 10);
        assert_close(mat.density(), 0.02);
    }

    #[test]
    fn test_to_dense() {
        let mat = matrix_of(&[(0, 1, 3.0), (1, 0, 7.0)], 2, 2);
        let expected = vec![vec![0.0, 3.0], vec![7.0, 0.0]];
        assert_eq!(mat.to_dense(), expected);
    }

    #[test]
    fn test_from_dense() {
        let dense = vec![vec![0.0, 3.0], vec![7.0, 0.0]];
        let mat = SparseMatrix::from_dense(&dense, 0.0);
        assert_eq!(mat.nnz(), 2);
        assert_eq!(mat.get(0, 1), 3.0);
        assert_eq!(mat.get(1, 0), 7.0);
    }

    #[test]
    fn test_from_dense_with_threshold() {
        let dense = vec![vec![0.1, 3.0], vec![7.0, 0.05]];
        let mat = SparseMatrix::from_dense(&dense, 0.5);
        // Only 3.0 and 7.0 exceed the threshold.
        assert_eq!(mat.nnz(), 2);
    }

    #[test]
    fn test_row_col_nnz() {
        let mat = matrix_of(&[(0, 0, 1.0), (0, 1, 2.0), (1, 0, 3.0)], 2, 2);
        assert_eq!(mat.row_nnz(0), 2);
        assert_eq!(mat.row_nnz(1), 1);
        assert_eq!(mat.col_nnz(0), 2);
        assert_eq!(mat.col_nnz(1), 1);
    }

    #[test]
    fn test_iter() {
        let mat = matrix_of(&[(0, 0, 1.0), (1, 1, 2.0)], 2, 2);
        let expected = vec![(0, 0, 1.0), (1, 1, 2.0)];
        assert_eq!(mat.iter().collect::<Vec<_>>(), expected);
    }

    #[test]
    fn test_summary() {
        let mat = matrix_of(&[(0, 0, 1.0)], 100, 50);
        assert_eq!(
            mat.summary(),
            "SparseMatrix: 100\u{00d7}50, 1 nonzeros (0.02% density)"
        );
    }

    #[test]
    fn test_zero_dimension_density() {
        let degenerate = SparseMatrix::new(0, 0);
        assert_eq!(degenerate.density(), 0.0);
    }

    // --- CSR conversion ---------------------------------------------------

    #[test]
    fn test_csr_roundtrip() {
        let entries = [
            (0, 0, 1.0),
            (0, 2, 2.0),
            (1, 1, 3.0),
            (2, 0, 4.0),
            (2, 2, 5.0),
        ];
        let original = matrix_of(&entries, 3, 3);

        let (data, indices, indptr) = original.to_csr();
        let restored = SparseMatrix::from_csr(data, indices, indptr, 3, 3).unwrap();

        assert_eq!(restored.shape(), (3, 3));
        assert_eq!(restored.nnz(), 5);
        assert_eq!(restored.get(0, 0), 1.0);
        assert_eq!(restored.get(0, 2), 2.0);
        assert_eq!(restored.get(1, 1), 3.0);
        assert_eq!(restored.get(2, 0), 4.0);
        assert_eq!(restored.get(2, 2), 5.0);
        assert_eq!(restored.get(1, 0), 0.0);
    }

    #[test]
    fn test_csr_empty() {
        let empty = SparseMatrix::new(3, 4);
        let (data, indices, indptr) = empty.to_csr();
        assert!(data.is_empty());
        assert!(indices.is_empty());
        assert_eq!(indptr, vec![0, 0, 0, 0]);

        let restored = SparseMatrix::from_csr(data, indices, indptr, 3, 4).unwrap();
        assert_eq!(restored.nnz(), 0);
        assert_eq!(restored.shape(), (3, 4));
    }

    // --- aggregates -------------------------------------------------------

    #[test]
    fn test_column_sums() {
        let mat = matrix_of(&[(0, 0, 1.0), (0, 1, 2.0), (1, 0, 3.0), (1, 2, 4.0)], 2, 3);
        assert_eq!(mat.column_sums(), vec![4.0, 2.0, 4.0]);
    }

    #[test]
    fn test_column_sums_empty() {
        let empty = SparseMatrix::new(3, 4);
        assert_eq!(empty.column_sums(), vec![0.0, 0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_column_means() {
        let mat = matrix_of(&[(0, 0, 4.0), (1, 0, 6.0)], 2, 2);
        let means = mat.column_means();
        assert_close(means[0], 5.0);
        assert_close(means[1], 0.0);
    }

    #[test]
    fn test_column_means_zero_rows() {
        let no_rows = SparseMatrix::new(0, 3);
        assert_eq!(no_rows.column_means(), vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_row_sums() {
        let mat = matrix_of(&[(0, 0, 1.0), (0, 1, 2.0), (1, 0, 3.0), (2, 2, 4.0)], 3, 3);
        assert_eq!(mat.row_sums(), vec![3.0, 3.0, 4.0]);
    }

    // --- in-place transforms ----------------------------------------------

    #[test]
    fn test_scale_rows() {
        let mut mat = matrix_of(&[(0, 0, 2.0), (0, 1, 4.0), (1, 0, 6.0), (1, 1, 8.0)], 2, 2);
        mat.scale_rows(&[0.5, 2.0]);
        assert_close(mat.get(0, 0), 1.0);
        assert_close(mat.get(0, 1), 2.0);
        assert_close(mat.get(1, 0), 12.0);
        assert_close(mat.get(1, 1), 16.0);
    }

    #[test]
    fn test_map_values() {
        let mut mat = matrix_of(&[(0, 0, 4.0), (1, 1, 9.0)], 2, 2);
        mat.map_values(|v| v.sqrt());
        assert_close(mat.get(0, 0), 2.0);
        assert_close(mat.get(1, 1), 3.0);
    }

    #[test]
    fn test_n_rows_n_cols() {
        let mat = SparseMatrix::new(5, 8);
        assert_eq!(mat.n_rows(), 5);
        assert_eq!(mat.n_cols(), 8);
    }

    #[test]
    fn test_csr_single_row() {
        let single = matrix_of(&[(0, 0, 1.0), (0, 2, 2.0), (0, 4, 3.0)], 1, 5);

        let (data, indices, indptr) = single.to_csr();
        assert_eq!(data, vec![1.0, 2.0, 3.0]);
        assert_eq!(indices, vec![0, 2, 4]);
        assert_eq!(indptr, vec![0, 3]);

        let restored = SparseMatrix::from_csr(data, indices, indptr, 1, 5).unwrap();
        assert_eq!(restored.nnz(), 3);
        assert_eq!(restored.get(0, 0), 1.0);
        assert_eq!(restored.get(0, 2), 2.0);
        assert_eq!(restored.get(0, 4), 3.0);
    }
}