Skip to main content

scry_learn/
sparse.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2//! Sparse matrix types: CSR (Compressed Sparse Row) and CSC (Compressed Sparse Column).
3//!
4//! Designed for NLP/recommender workloads with 50K+ features and >99% zeros.
5//! Provides efficient row-oriented (CSR) and column-oriented (CSC) access,
6//! plus conversion between formats.
7
8use crate::error::{Result, ScryLearnError};
9use std::ops;
10
11// ---------------------------------------------------------------------------
12// SparseRow / SparseCol views
13// ---------------------------------------------------------------------------
14
/// Borrowed, read-only view of one row of a [`CsrMatrix`].
///
/// The two slices are parallel: `indices[k]` is the column of the `k`-th
/// stored entry of the row and `data[k]` is its value.
#[derive(Clone, Debug)]
pub struct SparseRow<'a> {
    /// Column index of every stored entry in this row.
    indices: &'a [usize],
    /// Value of every stored entry, aligned with `indices`.
    data: &'a [f64],
}
21
22impl<'a> SparseRow<'a> {
23    /// Iterate over `(col_idx, value)` pairs in this row.
24    pub fn iter(&self) -> impl Iterator<Item = (usize, f64)> + 'a {
25        self.indices.iter().copied().zip(self.data.iter().copied())
26    }
27
28    /// Number of non-zero entries in this row.
29    pub fn nnz(&self) -> usize {
30        self.indices.len()
31    }
32
33    /// Column indices of non-zero entries (sorted).
34    pub fn indices(&self) -> &[usize] {
35        self.indices
36    }
37
38    /// Values of non-zero entries (parallel to `indices()`).
39    pub fn values(&self) -> &[f64] {
40        self.data
41    }
42
43    /// Sparse dot product with a dense vector.
44    pub fn dot(&self, other: &[f64]) -> f64 {
45        self.indices
46            .iter()
47            .zip(self.data.iter())
48            .map(|(&j, &v)| v * other[j])
49            .sum()
50    }
51}
52
/// Borrowed, read-only view of one column of a [`CscMatrix`].
///
/// The two slices are parallel: `indices[k]` is the row of the `k`-th stored
/// entry of the column and `data[k]` is its value.
#[derive(Clone, Debug)]
pub struct SparseCol<'a> {
    /// Row index of every stored entry in this column.
    indices: &'a [usize],
    /// Value of every stored entry, aligned with `indices`.
    data: &'a [f64],
}
59
60impl<'a> SparseCol<'a> {
61    /// Iterate over `(row_idx, value)` pairs in this column.
62    pub fn iter(&self) -> impl Iterator<Item = (usize, f64)> + 'a {
63        self.indices.iter().copied().zip(self.data.iter().copied())
64    }
65
66    /// Number of non-zero entries in this column.
67    pub fn nnz(&self) -> usize {
68        self.indices.len()
69    }
70
71    /// Sparse dot product with a dense vector.
72    pub fn dot(&self, other: &[f64]) -> f64 {
73        self.indices
74            .iter()
75            .zip(self.data.iter())
76            .map(|(&i, &v)| v * other[i])
77            .sum()
78    }
79}
80
81// ---------------------------------------------------------------------------
82// CsrMatrix
83// ---------------------------------------------------------------------------
84
/// Compressed Sparse Row (CSR) matrix of `f64` values.
///
/// The layout is row-oriented and suits per-row access (e.g. KNN predict,
/// tree predict). The stored entries of row `i` occupy positions
/// `indptr[i]..indptr[i + 1]` of the parallel `indices` (column) and `data`
/// (value) arrays.
///
/// NOTE(review): with the `serde` feature, deserialization fills these fields
/// directly, so the layout invariants are not re-validated. Confirm that
/// serialized inputs are trusted.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct CsrMatrix {
    /// Row offsets into `indices`/`data`; holds `n_rows + 1` entries.
    indptr: Vec<usize>,
    /// Column index of every stored entry.
    indices: Vec<usize>,
    /// Value of every stored entry.
    data: Vec<f64>,
    /// Number of rows.
    n_rows: usize,
    /// Number of columns.
    n_cols: usize,
}
103
104impl CsrMatrix {
105    /// Build a CSR matrix from COO (triplet) format.
106    ///
107    /// Duplicate entries at the same `(row, col)` are summed.
108    pub fn from_triplets(
109        rows: &[usize],
110        cols: &[usize],
111        vals: &[f64],
112        n_rows: usize,
113        n_cols: usize,
114    ) -> Result<Self> {
115        let nnz = rows.len();
116        if cols.len() != nnz || vals.len() != nnz {
117            return Err(ScryLearnError::InvalidParameter(format!(
118                "triplet arrays must have equal length (rows={}, cols={}, vals={})",
119                nnz,
120                cols.len(),
121                vals.len()
122            )));
123        }
124
125        // Validate indices.
126        for i in 0..nnz {
127            if rows[i] >= n_rows || cols[i] >= n_cols {
128                return Err(ScryLearnError::InvalidParameter(format!(
129                    "triplet index ({}, {}) out of bounds for {}x{} matrix",
130                    rows[i], cols[i], n_rows, n_cols
131                )));
132            }
133        }
134
135        // Count entries per row.
136        let mut row_counts = vec![0usize; n_rows];
137        for &r in rows {
138            row_counts[r] += 1;
139        }
140
141        // Build indptr.
142        let mut indptr = vec![0usize; n_rows + 1];
143        for i in 0..n_rows {
144            indptr[i + 1] = indptr[i] + row_counts[i];
145        }
146
147        // Scatter triplets into CSR arrays.
148        let total = indptr[n_rows];
149        let mut csr_indices = vec![0usize; total];
150        let mut csr_data = vec![0.0f64; total];
151        let mut offsets = indptr[..n_rows].to_vec();
152
153        for k in 0..nnz {
154            let r = rows[k];
155            let pos = offsets[r];
156            csr_indices[pos] = cols[k];
157            csr_data[pos] = vals[k];
158            offsets[r] += 1;
159        }
160
161        // Sort each row by column index and merge duplicates.
162        let mut final_indices = Vec::with_capacity(total);
163        let mut final_data = Vec::with_capacity(total);
164        let mut new_indptr = vec![0usize; n_rows + 1];
165
166        for i in 0..n_rows {
167            let start = indptr[i];
168            let end = indptr[i + 1];
169
170            // Sort by column index.
171            let mut pairs: Vec<(usize, f64)> = csr_indices[start..end]
172                .iter()
173                .copied()
174                .zip(csr_data[start..end].iter().copied())
175                .collect();
176            pairs.sort_by_key(|&(c, _)| c);
177
178            // Merge duplicates by summing (only within this row).
179            let row_start = final_indices.len();
180            for &(col, val) in &pairs {
181                // The guard `final_indices.len() > row_start` ensures non-empty.
182                if final_indices.len() > row_start && final_indices[final_indices.len() - 1] == col
183                {
184                    let last = final_data.len() - 1;
185                    final_data[last] += val;
186                    continue;
187                }
188                final_indices.push(col);
189                final_data.push(val);
190            }
191            new_indptr[i + 1] = final_indices.len();
192        }
193
194        Ok(Self {
195            indptr: new_indptr,
196            indices: final_indices,
197            data: final_data,
198            n_rows,
199            n_cols,
200        })
201    }
202
203    /// Convert a dense row-major matrix to CSR (zeros are skipped).
204    pub fn from_dense(rows: &[Vec<f64>]) -> Self {
205        let n_rows = rows.len();
206        let n_cols = if n_rows > 0 { rows[0].len() } else { 0 };
207
208        let mut indptr = vec![0usize; n_rows + 1];
209        let mut indices = Vec::new();
210        let mut data = Vec::new();
211
212        for (i, row) in rows.iter().enumerate() {
213            for (j, &val) in row.iter().enumerate() {
214                if val != 0.0 {
215                    indices.push(j);
216                    data.push(val);
217                }
218            }
219            indptr[i + 1] = indices.len();
220        }
221
222        Self {
223            indptr,
224            indices,
225            data,
226            n_rows,
227            n_cols,
228        }
229    }
230
231    /// Number of rows.
232    #[inline]
233    pub fn n_rows(&self) -> usize {
234        self.n_rows
235    }
236
237    /// Number of columns.
238    #[inline]
239    pub fn n_cols(&self) -> usize {
240        self.n_cols
241    }
242
243    /// Number of stored non-zero entries.
244    #[inline]
245    pub fn nnz(&self) -> usize {
246        self.data.len()
247    }
248
249    /// Fraction of non-zero entries: `nnz / (n_rows * n_cols)`.
250    ///
251    /// Returns 0.0 for an empty (0×0) matrix.
252    #[inline]
253    pub fn density(&self) -> f64 {
254        let total = self.n_rows * self.n_cols;
255        if total == 0 {
256            return 0.0;
257        }
258        self.nnz() as f64 / total as f64
259    }
260
261    /// View of row `i` as sparse `(col, value)` pairs.
262    pub fn row(&self, i: usize) -> SparseRow<'_> {
263        let start = self.indptr[i];
264        let end = self.indptr[i + 1];
265        SparseRow {
266            indices: &self.indices[start..end],
267            data: &self.data[start..end],
268        }
269    }
270
271    /// Retrieve a single element. Returns `0.0` if the entry is not stored.
272    pub fn get(&self, row: usize, col: usize) -> f64 {
273        let start = self.indptr[row];
274        let end = self.indptr[row + 1];
275        self.indices[start..end]
276            .binary_search(&col)
277            .map_or(0.0, |pos| self.data[start + pos])
278    }
279
280    /// Convert to CSC format in O(nnz).
281    pub fn to_csc(&self) -> CscMatrix {
282        let nnz = self.nnz();
283
284        // Count entries per column.
285        let mut col_counts = vec![0usize; self.n_cols];
286        for &c in &self.indices {
287            col_counts[c] += 1;
288        }
289
290        let mut indptr = vec![0usize; self.n_cols + 1];
291        for j in 0..self.n_cols {
292            indptr[j + 1] = indptr[j] + col_counts[j];
293        }
294
295        let mut csc_indices = vec![0usize; nnz];
296        let mut csc_data = vec![0.0f64; nnz];
297        let mut offsets = indptr[..self.n_cols].to_vec();
298
299        for i in 0..self.n_rows {
300            let start = self.indptr[i];
301            let end = self.indptr[i + 1];
302            for k in start..end {
303                let col = self.indices[k];
304                let pos = offsets[col];
305                csc_indices[pos] = i;
306                csc_data[pos] = self.data[k];
307                offsets[col] += 1;
308            }
309        }
310
311        CscMatrix {
312            indptr,
313            indices: csc_indices,
314            data: csc_data,
315            n_rows: self.n_rows,
316            n_cols: self.n_cols,
317        }
318    }
319
320    /// Convert to dense row-major format.
321    pub fn to_dense(&self) -> Vec<Vec<f64>> {
322        let mut dense = vec![vec![0.0; self.n_cols]; self.n_rows];
323        for (i, row) in dense.iter_mut().enumerate() {
324            let start = self.indptr[i];
325            let end = self.indptr[i + 1];
326            for k in start..end {
327                row[self.indices[k]] = self.data[k];
328            }
329        }
330        dense
331    }
332
333    /// Sparse matrix-vector multiply: `y = A * x`.
334    pub fn dot_vec(&self, x: &[f64]) -> Vec<f64> {
335        let mut y = vec![0.0; self.n_rows];
336        for (yi, i) in y.iter_mut().zip(0..self.n_rows) {
337            *yi = self.row(i).dot(x);
338        }
339        y
340    }
341}
342
343// ---------------------------------------------------------------------------
344// CscMatrix
345// ---------------------------------------------------------------------------
346
/// Compressed Sparse Column (CSC) matrix of `f64` values.
///
/// The layout is column-oriented and suits per-column access (e.g. tree
/// fitting, linear algebra). The stored entries of column `j` occupy
/// positions `indptr[j]..indptr[j + 1]` of the parallel `indices` (row) and
/// `data` (value) arrays.
///
/// NOTE(review): with the `serde` feature, deserialization fills these fields
/// directly, so the layout invariants are not re-validated. Confirm that
/// serialized inputs are trusted.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct CscMatrix {
    /// Column offsets into `indices`/`data`; holds `n_cols + 1` entries.
    indptr: Vec<usize>,
    /// Row index of every stored entry.
    indices: Vec<usize>,
    /// Value of every stored entry.
    data: Vec<f64>,
    /// Number of rows.
    n_rows: usize,
    /// Number of columns.
    n_cols: usize,
}
364
365impl CscMatrix {
366    /// Build a CSC matrix from COO (triplet) format.
367    ///
368    /// Duplicate entries at the same `(row, col)` are summed.
369    pub fn from_triplets(
370        rows: &[usize],
371        cols: &[usize],
372        vals: &[f64],
373        n_rows: usize,
374        n_cols: usize,
375    ) -> Result<Self> {
376        // Build as CSR then transpose — reuses all the validation/dedup logic.
377        let csr = CsrMatrix::from_triplets(rows, cols, vals, n_rows, n_cols)?;
378        Ok(csr.to_csc())
379    }
380
381    /// Convert a column-major dense matrix to CSC (zeros are skipped).
382    ///
383    /// `cols[j][i]` = value at row `i`, column `j`.
384    pub fn from_dense(cols: &[Vec<f64>]) -> Self {
385        let n_cols = cols.len();
386        let n_rows = if n_cols > 0 { cols[0].len() } else { 0 };
387
388        let mut indptr = vec![0usize; n_cols + 1];
389        let mut indices = Vec::new();
390        let mut data = Vec::new();
391
392        for (j, col) in cols.iter().enumerate() {
393            for (i, &val) in col.iter().enumerate() {
394                if val != 0.0 {
395                    indices.push(i);
396                    data.push(val);
397                }
398            }
399            indptr[j + 1] = indices.len();
400        }
401
402        Self {
403            indptr,
404            indices,
405            data,
406            n_rows,
407            n_cols,
408        }
409    }
410
411    /// Number of rows.
412    #[inline]
413    pub fn n_rows(&self) -> usize {
414        self.n_rows
415    }
416
417    /// Number of columns.
418    #[inline]
419    pub fn n_cols(&self) -> usize {
420        self.n_cols
421    }
422
423    /// Number of stored non-zero entries.
424    #[inline]
425    pub fn nnz(&self) -> usize {
426        self.data.len()
427    }
428
429    /// Fraction of non-zero entries.
430    #[inline]
431    pub fn density(&self) -> f64 {
432        let total = self.n_rows * self.n_cols;
433        if total == 0 {
434            return 0.0;
435        }
436        self.nnz() as f64 / total as f64
437    }
438
439    /// View of column `j` as sparse `(row, value)` pairs.
440    pub fn col(&self, j: usize) -> SparseCol<'_> {
441        let start = self.indptr[j];
442        let end = self.indptr[j + 1];
443        SparseCol {
444            indices: &self.indices[start..end],
445            data: &self.data[start..end],
446        }
447    }
448
449    /// Retrieve a single element. Returns `0.0` if not stored.
450    pub fn get(&self, row: usize, col: usize) -> f64 {
451        let start = self.indptr[col];
452        let end = self.indptr[col + 1];
453        self.indices[start..end]
454            .binary_search(&row)
455            .map_or(0.0, |pos| self.data[start + pos])
456    }
457
458    /// Convert to CSR format in O(nnz).
459    pub fn to_csr(&self) -> CsrMatrix {
460        let nnz = self.nnz();
461
462        // Count entries per row.
463        let mut row_counts = vec![0usize; self.n_rows];
464        for &r in &self.indices {
465            row_counts[r] += 1;
466        }
467
468        let mut indptr = vec![0usize; self.n_rows + 1];
469        for i in 0..self.n_rows {
470            indptr[i + 1] = indptr[i] + row_counts[i];
471        }
472
473        let mut csr_indices = vec![0usize; nnz];
474        let mut csr_data = vec![0.0f64; nnz];
475        let mut offsets = indptr[..self.n_rows].to_vec();
476
477        for j in 0..self.n_cols {
478            let start = self.indptr[j];
479            let end = self.indptr[j + 1];
480            for k in start..end {
481                let row = self.indices[k];
482                let pos = offsets[row];
483                csr_indices[pos] = j;
484                csr_data[pos] = self.data[k];
485                offsets[row] += 1;
486            }
487        }
488
489        CsrMatrix {
490            indptr,
491            indices: csr_indices,
492            data: csr_data,
493            n_rows: self.n_rows,
494            n_cols: self.n_cols,
495        }
496    }
497
498    /// Convert to dense row-major format.
499    pub fn to_dense(&self) -> Vec<Vec<f64>> {
500        // Build via CSR for clippy-friendly iteration.
501        self.to_csr().to_dense()
502    }
503
504    /// Sparse matrix-vector multiply: `y = A * x`.
505    pub fn dot_vec(&self, x: &[f64]) -> Vec<f64> {
506        let mut y = vec![0.0; self.n_rows];
507        for (j, &xj) in x.iter().enumerate() {
508            let start = self.indptr[j];
509            let end = self.indptr[j + 1];
510            for k in start..end {
511                y[self.indices[k]] += self.data[k] * xj;
512            }
513        }
514        y
515    }
516}
517
518// ---------------------------------------------------------------------------
519// Arithmetic: CsrMatrix + CsrMatrix, CsrMatrix * f64
520// ---------------------------------------------------------------------------
521
522impl ops::Add for &CsrMatrix {
523    type Output = CsrMatrix;
524
525    /// Element-wise addition of two CSR matrices with the same shape.
526    ///
527    /// # Panics
528    ///
529    /// Panics if the matrices have different shapes.
530    fn add(self, rhs: &CsrMatrix) -> CsrMatrix {
531        assert_eq!(
532            (self.n_rows, self.n_cols),
533            (rhs.n_rows, rhs.n_cols),
534            "CsrMatrix addition requires same shape"
535        );
536
537        let mut indptr = vec![0usize; self.n_rows + 1];
538        let mut indices = Vec::new();
539        let mut data = Vec::new();
540
541        for i in 0..self.n_rows {
542            let a_start = self.indptr[i];
543            let a_end = self.indptr[i + 1];
544            let b_start = rhs.indptr[i];
545            let b_end = rhs.indptr[i + 1];
546
547            let mut a = a_start;
548            let mut b = b_start;
549
550            // Merge two sorted column-index streams.
551            while a < a_end && b < b_end {
552                match self.indices[a].cmp(&rhs.indices[b]) {
553                    std::cmp::Ordering::Less => {
554                        indices.push(self.indices[a]);
555                        data.push(self.data[a]);
556                        a += 1;
557                    }
558                    std::cmp::Ordering::Greater => {
559                        indices.push(rhs.indices[b]);
560                        data.push(rhs.data[b]);
561                        b += 1;
562                    }
563                    std::cmp::Ordering::Equal => {
564                        let sum = self.data[a] + rhs.data[b];
565                        if sum != 0.0 {
566                            indices.push(self.indices[a]);
567                            data.push(sum);
568                        }
569                        a += 1;
570                        b += 1;
571                    }
572                }
573            }
574            while a < a_end {
575                indices.push(self.indices[a]);
576                data.push(self.data[a]);
577                a += 1;
578            }
579            while b < b_end {
580                indices.push(rhs.indices[b]);
581                data.push(rhs.data[b]);
582                b += 1;
583            }
584
585            indptr[i + 1] = indices.len();
586        }
587
588        CsrMatrix {
589            indptr,
590            indices,
591            data,
592            n_rows: self.n_rows,
593            n_cols: self.n_cols,
594        }
595    }
596}
597
598impl ops::Mul<f64> for &CsrMatrix {
599    type Output = CsrMatrix;
600
601    /// Scalar multiplication.
602    fn mul(self, rhs: f64) -> CsrMatrix {
603        CsrMatrix {
604            indptr: self.indptr.clone(),
605            indices: self.indices.clone(),
606            data: self.data.iter().map(|&v| v * rhs).collect(),
607            n_rows: self.n_rows,
608            n_cols: self.n_cols,
609        }
610    }
611}
612
613// ---------------------------------------------------------------------------
614// Tests
615// ---------------------------------------------------------------------------
616
#[cfg(test)]
// Exact float comparisons are intentional: the expected values below are small
// numbers that `f64` represents exactly.
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;

    #[test]
    fn test_from_triplets_basic() {
        // 3x3 matrix:
        // [1 0 2]
        // [0 3 0]
        // [4 0 5]
        let rows = vec![0, 0, 1, 2, 2];
        let cols = vec![0, 2, 1, 0, 2];
        let vals = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let csr = CsrMatrix::from_triplets(&rows, &cols, &vals, 3, 3).unwrap();
        assert_eq!(csr.n_rows(), 3);
        assert_eq!(csr.n_cols(), 3);
        assert_eq!(csr.nnz(), 5);
        assert_eq!(csr.get(0, 0), 1.0);
        assert_eq!(csr.get(0, 2), 2.0);
        assert_eq!(csr.get(1, 1), 3.0);
        assert_eq!(csr.get(2, 0), 4.0);
        assert_eq!(csr.get(2, 2), 5.0);
        assert_eq!(csr.get(0, 1), 0.0);
        assert_eq!(csr.get(1, 0), 0.0);
    }

    #[test]
    fn test_duplicate_entries_summed() {
        let rows = vec![0, 0, 0];
        let cols = vec![1, 1, 1];
        let vals = vec![1.0, 2.0, 3.0];

        let csr = CsrMatrix::from_triplets(&rows, &cols, &vals, 2, 3).unwrap();
        assert_eq!(csr.nnz(), 1);
        assert_eq!(csr.get(0, 1), 6.0);
    }

    #[test]
    fn test_csr_csc_roundtrip() {
        let rows = vec![0, 0, 1, 2, 2];
        let cols = vec![0, 2, 1, 0, 2];
        let vals = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let csr = CsrMatrix::from_triplets(&rows, &cols, &vals, 3, 3).unwrap();
        let csc = csr.to_csc();
        let csr2 = csc.to_csr();

        assert_eq!(csr.to_dense(), csr2.to_dense());
    }

    #[test]
    fn test_dense_roundtrip() {
        let dense = vec![
            vec![1.0, 0.0, 2.0],
            vec![0.0, 3.0, 0.0],
            vec![4.0, 0.0, 5.0],
        ];

        let csr = CsrMatrix::from_dense(&dense);
        assert_eq!(csr.to_dense(), dense);
    }

    #[test]
    fn test_get_existing_and_missing() {
        let csr = CsrMatrix::from_dense(&[vec![0.0, 7.0], vec![8.0, 0.0]]);
        assert_eq!(csr.get(0, 1), 7.0);
        assert_eq!(csr.get(1, 0), 8.0);
        assert_eq!(csr.get(0, 0), 0.0);
        assert_eq!(csr.get(1, 1), 0.0);
    }

    #[test]
    fn test_dot_vec_csr() {
        // [1 2] * [3] = [1*3+2*4] = [11]
        // [0 3]   [4]   [0*3+3*4]   [12]
        let csr = CsrMatrix::from_dense(&[vec![1.0, 2.0], vec![0.0, 3.0]]);
        let result = csr.dot_vec(&[3.0, 4.0]);
        assert_eq!(result, vec![11.0, 12.0]);
    }

    #[test]
    fn test_dot_vec_csc() {
        let dense = vec![vec![1.0, 2.0], vec![0.0, 3.0]];
        let csr = CsrMatrix::from_dense(&dense);
        let csc = csr.to_csc();
        let result = csc.dot_vec(&[3.0, 4.0]);
        assert_eq!(result, vec![11.0, 12.0]);
    }

    #[test]
    fn test_sparse_row_iteration() {
        let csr = CsrMatrix::from_dense(&[vec![0.0, 5.0, 0.0, 7.0]]);
        let row = csr.row(0);
        let pairs: Vec<(usize, f64)> = row.iter().collect();
        assert_eq!(pairs, vec![(1, 5.0), (3, 7.0)]);
        assert_eq!(row.nnz(), 2);
    }

    #[test]
    fn test_sparse_col_iteration() {
        let csr = CsrMatrix::from_dense(&[vec![1.0, 0.0], vec![0.0, 0.0], vec![3.0, 0.0]]);
        let csc = csr.to_csc();
        let col = csc.col(0);
        let pairs: Vec<(usize, f64)> = col.iter().collect();
        assert_eq!(pairs, vec![(0, 1.0), (2, 3.0)]);
        assert_eq!(col.nnz(), 2);
    }

    #[test]
    fn test_empty_matrix() {
        // 0x0
        let csr = CsrMatrix::from_triplets(&[], &[], &[], 0, 0).unwrap();
        assert_eq!(csr.n_rows(), 0);
        assert_eq!(csr.n_cols(), 0);
        assert_eq!(csr.nnz(), 0);
        assert_eq!(csr.density(), 0.0);

        // 5x5 with no entries
        let csr = CsrMatrix::from_triplets(&[], &[], &[], 5, 5).unwrap();
        assert_eq!(csr.n_rows(), 5);
        assert_eq!(csr.n_cols(), 5);
        assert_eq!(csr.nnz(), 0);
        assert_eq!(csr.density(), 0.0);
        assert_eq!(csr.get(2, 3), 0.0);
    }

    #[test]
    fn test_density() {
        // 3x3 with 5 entries → 5/9
        let rows = vec![0, 0, 1, 2, 2];
        let cols = vec![0, 2, 1, 0, 2];
        let vals = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let csr = CsrMatrix::from_triplets(&rows, &cols, &vals, 3, 3).unwrap();
        assert!((csr.density() - 5.0 / 9.0).abs() < 1e-10);
    }

    #[test]
    fn test_large_sparse() {
        // 1000x1000 with ~0.1% density.
        let n = 1000;
        let mut rng = fastrand::Rng::with_seed(42);
        let target_nnz = (n * n) / 1000; // 0.1%

        let mut rows = Vec::with_capacity(target_nnz);
        let mut cols = Vec::with_capacity(target_nnz);
        let mut vals = Vec::with_capacity(target_nnz);

        for _ in 0..target_nnz {
            rows.push(rng.usize(..n));
            cols.push(rng.usize(..n));
            vals.push(rng.f64() * 10.0);
        }

        let csr = CsrMatrix::from_triplets(&rows, &cols, &vals, n, n).unwrap();
        assert_eq!(csr.n_rows(), n);
        assert_eq!(csr.n_cols(), n);
        // nnz may be less than target_nnz due to duplicate merging.
        assert!(csr.nnz() <= target_nnz);
        assert!(csr.nnz() > 0);
        assert!(csr.density() < 0.002);

        // Spot-check round-trip.
        let csc = csr.to_csc();
        let csr2 = csc.to_csr();
        assert_eq!(csr.nnz(), csr2.nnz());
    }

    #[test]
    fn test_from_dense_skips_zeros() {
        let dense = vec![
            vec![0.0, 0.0, 1.0],
            vec![0.0, 0.0, 0.0],
            vec![2.0, 0.0, 0.0],
        ];
        let csr = CsrMatrix::from_dense(&dense);
        assert_eq!(csr.nnz(), 2);
        assert_eq!(csr.get(0, 2), 1.0);
        assert_eq!(csr.get(2, 0), 2.0);
    }

    #[test]
    fn test_csr_add() {
        let a = CsrMatrix::from_dense(&[vec![1.0, 0.0, 2.0], vec![0.0, 3.0, 0.0]]);
        let b = CsrMatrix::from_dense(&[vec![0.0, 4.0, 0.0], vec![5.0, 0.0, 6.0]]);
        let c = &a + &b;
        assert_eq!(
            c.to_dense(),
            vec![vec![1.0, 4.0, 2.0], vec![5.0, 3.0, 6.0],]
        );
    }

    #[test]
    fn test_csr_scalar_mul() {
        let a = CsrMatrix::from_dense(&[vec![1.0, 0.0, 2.0], vec![0.0, 3.0, 0.0]]);
        let b = &a * 2.0;
        assert_eq!(
            b.to_dense(),
            vec![vec![2.0, 0.0, 4.0], vec![0.0, 6.0, 0.0],]
        );
    }

    #[test]
    fn test_csc_from_triplets() {
        let rows = vec![0, 1, 2];
        let cols = vec![0, 1, 2];
        let vals = vec![1.0, 2.0, 3.0];
        let csc = CscMatrix::from_triplets(&rows, &cols, &vals, 3, 3).unwrap();
        assert_eq!(csc.n_rows(), 3);
        assert_eq!(csc.n_cols(), 3);
        assert_eq!(csc.nnz(), 3);
        assert_eq!(csc.get(0, 0), 1.0);
        assert_eq!(csc.get(1, 1), 2.0);
        assert_eq!(csc.get(2, 2), 3.0);
        assert_eq!(csc.get(0, 1), 0.0);
    }

    #[test]
    fn test_csc_from_dense() {
        // Column-major: cols[j][i]
        let cols = vec![
            vec![1.0, 0.0, 4.0], // column 0
            vec![0.0, 3.0, 0.0], // column 1
            vec![2.0, 0.0, 5.0], // column 2
        ];
        let csc = CscMatrix::from_dense(&cols);
        assert_eq!(csc.n_rows(), 3);
        assert_eq!(csc.n_cols(), 3);
        assert_eq!(csc.nnz(), 5);
        assert_eq!(csc.get(0, 0), 1.0);
        assert_eq!(csc.get(2, 0), 4.0);
        assert_eq!(csc.get(1, 1), 3.0);
        assert_eq!(csc.get(0, 2), 2.0);
        assert_eq!(csc.get(2, 2), 5.0);
    }

    #[test]
    fn test_sparse_row_dot() {
        let csr = CsrMatrix::from_dense(&[vec![0.0, 2.0, 3.0]]);
        let row = csr.row(0);
        assert!((row.dot(&[1.0, 10.0, 100.0]) - 320.0).abs() < 1e-10);
    }

    #[test]
    fn test_csr_add_cancellation() {
        // When elements cancel to zero, they should be dropped.
        let a = CsrMatrix::from_dense(&[vec![1.0, 2.0]]);
        let b = CsrMatrix::from_dense(&[vec![-1.0, -2.0]]);
        let c = &a + &b;
        assert_eq!(c.nnz(), 0);
        assert_eq!(c.to_dense(), vec![vec![0.0, 0.0]]);
    }

    #[test]
    fn test_from_triplets_cross_row_dedup() {
        // Regression: rows ending with the same column as the next row's start
        // were once incorrectly merged across row boundaries.
        // Row 0: col 2 = 1.0
        // Row 1: col 2 = 3.0
        // These must NOT merge.
        let rows = vec![0, 1];
        let cols = vec![2, 2];
        let vals = vec![1.0, 3.0];

        let csr = CsrMatrix::from_triplets(&rows, &cols, &vals, 2, 3).unwrap();
        assert_eq!(csr.nnz(), 2);
        assert_eq!(csr.get(0, 2), 1.0);
        assert_eq!(csr.get(1, 2), 3.0);
    }

    #[test]
    fn test_from_triplets_intra_row_dedup() {
        // Duplicate entries within the same row should still be summed.
        let rows = vec![0, 0, 1, 1];
        let cols = vec![1, 1, 2, 2];
        let vals = vec![1.0, 2.0, 3.0, 4.0];

        let csr = CsrMatrix::from_triplets(&rows, &cols, &vals, 2, 3).unwrap();
        assert_eq!(csr.nnz(), 2);
        assert_eq!(csr.get(0, 1), 3.0); // 1.0 + 2.0
        assert_eq!(csr.get(1, 2), 7.0); // 3.0 + 4.0
    }

    #[test]
    fn test_csc_from_triplets_cross_row_dedup() {
        // Same regression via the CscMatrix path (CSR → transpose).
        let rows = vec![0, 1];
        let cols = vec![2, 2];
        let vals = vec![1.0, 3.0];

        let csc = CscMatrix::from_triplets(&rows, &cols, &vals, 2, 3).unwrap();
        assert_eq!(csc.nnz(), 2);
        assert_eq!(csc.get(0, 2), 1.0);
        assert_eq!(csc.get(1, 2), 3.0);
    }

    #[test]
    fn test_csc_from_triplets_roundtrip_with_dupes() {
        // Build CSC with duplicate entries, convert to CSR and back.
        let rows = vec![0, 0, 1, 2, 2];
        let cols = vec![0, 0, 1, 0, 2]; // row 0, col 0 has duplicates
        let vals = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let csc = CscMatrix::from_triplets(&rows, &cols, &vals, 3, 3).unwrap();
        assert_eq!(csc.get(0, 0), 3.0); // 1.0 + 2.0
        assert_eq!(csc.get(1, 1), 3.0);
        assert_eq!(csc.get(2, 0), 4.0);
        assert_eq!(csc.get(2, 2), 5.0);

        // Round-trip.
        let csr = csc.to_csr();
        let csc2 = csr.to_csc();
        assert_eq!(csc.to_dense(), csc2.to_dense());
    }

    #[test]
    fn test_sparse_row_accessors() {
        let csr = CsrMatrix::from_dense(&[vec![0.0, 5.0, 0.0, 7.0]]);
        let row = csr.row(0);
        assert_eq!(row.indices(), &[1, 3]);
        assert_eq!(row.values(), &[5.0, 7.0]);
    }

    #[test]
    fn test_from_triplets_rejects_mismatched_lengths() {
        assert!(CsrMatrix::from_triplets(&[0, 1], &[0], &[1.0, 2.0], 2, 2).is_err());
        assert!(CsrMatrix::from_triplets(&[0], &[0], &[], 2, 2).is_err());
        assert!(CscMatrix::from_triplets(&[0], &[0, 1], &[1.0], 2, 2).is_err());
    }

    #[test]
    fn test_from_triplets_rejects_out_of_bounds() {
        // An index equal to the dimension is already out of range.
        assert!(CsrMatrix::from_triplets(&[2], &[0], &[1.0], 2, 2).is_err());
        assert!(CsrMatrix::from_triplets(&[0], &[2], &[1.0], 2, 2).is_err());
        assert!(CscMatrix::from_triplets(&[0], &[2], &[1.0], 2, 2).is_err());
    }

    #[test]
    #[should_panic(expected = "CsrMatrix addition requires same shape")]
    fn test_csr_add_shape_mismatch_panics() {
        let a = CsrMatrix::from_dense(&[vec![1.0, 0.0]]);
        let b = CsrMatrix::from_dense(&[vec![1.0], vec![2.0]]);
        let _ = &a + &b;
    }

    #[test]
    fn test_sparse_col_dot() {
        // A single column with rows 1 and 2 stored.
        let csc = CscMatrix::from_dense(&[vec![0.0, 2.0, 3.0]]);
        assert!((csc.col(0).dot(&[1.0, 10.0, 100.0]) - 320.0).abs() < 1e-10);
    }

    #[test]
    fn test_csc_to_dense_and_density() {
        // 3x2 with 3 stored entries → density 0.5.
        let dense = vec![vec![1.0, 0.0], vec![0.0, 0.0], vec![3.0, 4.0]];
        let csc = CsrMatrix::from_dense(&dense).to_csc();
        assert_eq!(csc.to_dense(), dense);
        assert!((csc.density() - 0.5).abs() < 1e-12);
    }
}